feat(api): enhance domain check and whitelist functionality with improved error handling and validation

This commit is contained in:
pacnpal
2025-01-28 17:20:22 +00:00
parent 7615f074a8
commit 91e20e3711
3 changed files with 166 additions and 49 deletions

View File

@@ -1,13 +1,22 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import httpx import httpx
import logging import logging
import re
import json
from pathlib import Path
from datetime import datetime
from .config import settings from .config import settings
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Constants for rule validation and backup
DOMAIN_PATTERN = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$')
RULES_BACKUP_DIR = Path("rules_backup")
RULES_BACKUP_DIR.mkdir(exist_ok=True)
class AdGuardError(Exception): class AdGuardError(Exception):
"""Base exception for AdGuard Home API errors.""" """Base exception for AdGuard Home API errors."""
pass pass
@@ -20,6 +29,10 @@ class AdGuardAPIError(AdGuardError):
"""Raised when AdGuard Home API returns an error.""" """Raised when AdGuard Home API returns an error."""
pass pass
class AdGuardValidationError(AdGuardError):
"""Raised when input validation fails."""
pass
class ResultRule(BaseModel): class ResultRule(BaseModel):
"""Rule detail according to AdGuard spec.""" """Rule detail according to AdGuard spec."""
filter_list_id: Optional[int] = Field(None, description="Filter list ID") filter_list_id: Optional[int] = Field(None, description="Filter list ID")
@@ -56,6 +69,39 @@ class SetRulesRequest(BaseModel):
"""Request model for set_rules endpoint according to AdGuard spec.""" """Request model for set_rules endpoint according to AdGuard spec."""
rules: List[str] = Field(..., description="List of filtering rules") rules: List[str] = Field(..., description="List of filtering rules")
def validate_domain(domain: str) -> bool:
"""Validate domain name format."""
if not domain or len(domain) > 255:
return False
return bool(DOMAIN_PATTERN.match(domain))
def sanitize_rule(rule: str) -> str:
"""Sanitize and validate rule format."""
# Remove any whitespace and normalize
rule = rule.strip()
# Basic XSS/injection prevention
rule = rule.replace('<', '').replace('>', '').replace('"', '').replace("'", '')
return rule
def save_rules_backup(rules: List[str], action: str = "update") -> Path:
"""Save rules to backup file with timestamp."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = RULES_BACKUP_DIR / f"rules_{action}_{timestamp}.json"
with open(backup_file, 'w') as f:
json.dump({"rules": rules, "timestamp": timestamp}, f, indent=2)
logger.info(f"Saved rules backup to {backup_file}")
return backup_file
def load_rules_backup(backup_file: Path) -> List[str]:
"""Load rules from backup file."""
try:
with open(backup_file) as f:
data = json.load(f)
return data.get("rules", [])
except Exception as e:
logger.error(f"Error loading rules backup: {str(e)}")
return []
class AdGuardClient: class AdGuardClient:
"""Client for interacting with AdGuard Home API according to OpenAPI spec.""" """Client for interacting with AdGuard Home API according to OpenAPI spec."""
@@ -114,6 +160,10 @@ class AdGuardClient:
async def check_domain(self, domain: str) -> FilterCheckHostResponse: async def check_domain(self, domain: str) -> FilterCheckHostResponse:
"""Check if a domain is blocked by AdGuard Home according to spec.""" """Check if a domain is blocked by AdGuard Home according to spec."""
# Validate domain format
if not validate_domain(domain):
raise AdGuardValidationError(f"Invalid domain format: {domain}")
await self._ensure_authenticated() await self._ensure_authenticated()
url = f"{self.base_url}/filtering/check_host" url = f"{self.base_url}/filtering/check_host"
params = {"name": domain} params = {"name": domain}
@@ -148,42 +198,6 @@ class AdGuardClient:
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}") logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}")
async def add_allowed_domain(self, domain: str) -> bool:
"""Add a domain to the allowed list using set_rules endpoint according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/set_rules"
# Add as a whitelist rule according to AdGuard format
data = {"rules": [f"@@||{domain}^"]}
headers = {}
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
try:
logger.info(f"Adding domain to whitelist: {domain}")
response = await self.client.post(url, json=data, headers=headers)
if response.status_code == 401:
logger.info("Session expired, attempting reauth")
await self.login()
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully added {domain} to whitelist")
return True
except httpx.ConnectError as e:
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
except httpx.HTTPError as e:
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def get_filter_status(self) -> FilterStatus: async def get_filter_status(self) -> FilterStatus:
"""Get the current filtering status according to spec.""" """Get the current filtering status according to spec."""
await self._ensure_authenticated() await self._ensure_authenticated()
@@ -219,6 +233,71 @@ class AdGuardClient:
logger.error(f"Unexpected error while getting filter status: {str(e)}") logger.error(f"Unexpected error while getting filter status: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}")
async def add_allowed_domain(self, domain: str) -> bool:
"""Add a domain to the allowed list using AdGuard filtering API."""
# Validate domain format
if not validate_domain(domain):
raise AdGuardValidationError(f"Invalid domain format: {domain}")
await self._ensure_authenticated()
try:
# First get current user rules
status = await self.get_filter_status()
current_rules = status.user_rules if status else []
# Create sanitized whitelist rule
new_rule = sanitize_rule(f"@@||{domain}^")
# Save backup of current rules
old_backup = save_rules_backup(current_rules, "before")
# Add rule if not already present
if new_rule not in current_rules:
current_rules.append(new_rule)
# Save backup of new rules before updating
new_backup = save_rules_backup(current_rules, "after")
# Update rules
url = f"{self.base_url}/filtering/set_rules"
data = {"rules": current_rules}
headers = {}
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
logger.info(f"Updating rules list with whitelisted domain: {domain}")
response = await self.client.post(url, json=data, headers=headers)
if response.status_code == 401:
logger.info("Session expired, attempting reauth")
await self.login()
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully updated rules list with whitelisted domain: {domain}")
return True
except (httpx.ConnectError, httpx.HTTPError) as e:
# On error, try to restore from backup
logger.error(f"Error updating rules, attempting to restore from backup: {str(e)}")
if old_backup.exists():
try:
old_rules = load_rules_backup(old_backup)
if old_rules:
await self.client.post(url, json={"rules": old_rules}, headers=headers)
logger.info("Successfully restored rules from backup")
except Exception as restore_error:
logger.error(f"Failed to restore from backup: {str(restore_error)}")
raise
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def close(self): async def close(self):
"""Close the HTTP client.""" """Close the HTTP client."""
await self.client.aclose() await self.client.aclose()

View File

@@ -13,6 +13,7 @@ from .adguard import (
AdGuardError, AdGuardError,
AdGuardConnectionError, AdGuardConnectionError,
AdGuardAPIError, AdGuardAPIError,
AdGuardValidationError,
FilterStatus, FilterStatus,
FilterCheckHostResponse, FilterCheckHostResponse,
SetRulesRequest SetRulesRequest
@@ -84,6 +85,11 @@ async def check_domain(name: str) -> FilterCheckHostResponse:
result = await client.check_domain(name) result = await client.check_domain(name)
logger.info(f"Domain check result: {result}") logger.info(f"Domain check result: {result}")
return result return result
except AdGuardValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e: except Exception as e:
logger.error(f"Error checking domain {name}: {str(e)}") logger.error(f"Error checking domain {name}: {str(e)}")
raise raise
@@ -106,19 +112,34 @@ async def add_to_whitelist(request: SetRulesRequest) -> Dict:
detail="Rules are required" detail="Rules are required"
) )
logger.info(f"Adding rules: {request.rules}") # Extract domain from whitelist rule
rule = request.rules[0]
if not rule.startswith("@@||") or not rule.endswith("^"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid whitelist rule format"
)
domain = rule[4:-1] # Remove @@|| prefix and ^ suffix
logger.info(f"Adding domain to whitelist: {domain}")
try: try:
async with adguard.AdGuardClient() as client: async with adguard.AdGuardClient() as client:
success = await client.add_allowed_domain(request.rules[0].strip("@@||^")) success = await client.add_allowed_domain(domain)
if success: if success:
return {"message": "Rules added successfully"} return {"message": f"Domain {domain} added to whitelist"}
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add rules" detail="Failed to add domain to whitelist"
) )
except AdGuardValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e: except Exception as e:
logger.error(f"Error adding rules: {str(e)}") logger.error(f"Error adding domain to whitelist: {str(e)}")
raise raise
@app.get( @app.get(
@@ -144,6 +165,8 @@ async def adguard_exception_handler(request: Request, exc: AdGuardError) -> JSON
"""Handle AdGuard-related exceptions according to spec.""" """Handle AdGuard-related exceptions according to spec."""
if isinstance(exc, AdGuardConnectionError): if isinstance(exc, AdGuardConnectionError):
status_code = status.HTTP_503_SERVICE_UNAVAILABLE status_code = status.HTTP_503_SERVICE_UNAVAILABLE
elif isinstance(exc, AdGuardValidationError):
status_code = status.HTTP_400_BAD_REQUEST
elif isinstance(exc, AdGuardAPIError): elif isinstance(exc, AdGuardAPIError):
status_code = status.HTTP_502_BAD_GATEWAY status_code = status.HTTP_502_BAD_GATEWAY
else: else:

View File

@@ -53,10 +53,14 @@
unblockDiv.innerHTML = ''; unblockDiv.innerHTML = '';
} }
} else { } else {
let errorMsg = data.message || 'Unknown error occurred';
let errorType = response.status === 400 ? 'warning' : 'error';
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
resultDiv.innerHTML = ` resultDiv.innerHTML = `
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4"> <div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
<p class="font-bold">Error checking domain</p> <p class="font-bold">Error checking domain</p>
<p class="text-sm">${data.message || 'Unknown error occurred'}</p> <p class="text-sm">${errorMsg}</p>
</div>`; </div>`;
unblockDiv.innerHTML = ''; unblockDiv.innerHTML = '';
} }
@@ -94,18 +98,25 @@
}); });
if (response.ok) { if (response.ok) {
const data = await response.json();
resultDiv.innerHTML = ` resultDiv.innerHTML = `
<div class="bg-green-100 border-l-4 border-green-500 text-green-700 p-4"> <div class="bg-green-100 border-l-4 border-green-500 text-green-700 p-4">
<p class="font-bold">Success!</p> <p class="font-bold">Success!</p>
<p class="text-sm">Domain ${domain} has been added to the whitelist</p> <p class="text-sm">${data.message}</p>
<p class="text-xs mt-2">A backup of the rules has been saved for safety.</p>
</div>`; </div>`;
unblockDiv.innerHTML = ''; unblockDiv.innerHTML = '';
} else { } else {
const data = await response.json(); const data = await response.json();
let errorMsg = data.message || 'Unknown error occurred';
let errorType = response.status === 400 ? 'warning' : 'error';
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
resultDiv.innerHTML = ` resultDiv.innerHTML = `
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4"> <div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
<p class="font-bold">Error unblocking domain</p> <p class="font-bold">Error unblocking domain</p>
<p class="text-sm">${data.message || 'Unknown error occurred'}</p> <p class="text-sm">${errorMsg}</p>
${errorType !== 'warning' ? '<p class="text-xs mt-2">Previous rules have been restored from backup.</p>' : ''}
</div>`; </div>`;
} }
} catch (error) { } catch (error) {
@@ -113,6 +124,7 @@
<div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4"> <div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4">
<p class="font-bold">Error unblocking domain</p> <p class="font-bold">Error unblocking domain</p>
<p class="text-sm">${error.message}</p> <p class="text-sm">${error.message}</p>
<p class="text-xs mt-2">Previous rules have been restored from backup.</p>
</div>`; </div>`;
} }
} }
@@ -128,9 +140,10 @@
<label for="domain" class="block text-gray-700 text-sm font-bold mb-2"> <label for="domain" class="block text-gray-700 text-sm font-bold mb-2">
Enter Domain to Check Enter Domain to Check
</label> </label>
<input type="text" id="domain" name="domain" required <input type="text" id="domain" name="domain" required pattern="^[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9](\.[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9])*$"
class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline"
placeholder="example.com"> placeholder="example.com"
title="Please enter a valid domain name">
</div> </div>
<button id="submit-btn" type="submit" <button id="submit-btn" type="submit"
class="bg-green-500 hover:bg-green-700 text-white font-bold py-2 px-4 rounded w-full transition-colors duration-200"> class="bg-green-500 hover:bg-green-700 text-white font-bold py-2 px-4 rounded w-full transition-colors duration-200">
@@ -144,6 +157,8 @@
<div class="mt-4 text-center text-gray-600 text-sm"> <div class="mt-4 text-center text-gray-600 text-sm">
Make sure your AdGuard Home instance is running and properly configured in the .env file. Make sure your AdGuard Home instance is running and properly configured in the .env file.
<br>
<span class="text-xs">Rules are automatically backed up before any changes.</span>
</div> </div>
</div> </div>
</body> </body>