feat(api): enhance domain check and whitelist functionality with improved error handling and validation

This commit is contained in:
pacnpal
2025-01-28 17:20:22 +00:00
parent 7615f074a8
commit 91e20e3711
3 changed files with 166 additions and 49 deletions

View File

@@ -1,13 +1,22 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Set
from pydantic import BaseModel, Field
import httpx
import logging
import re
import json
from pathlib import Path
from datetime import datetime
from .config import settings
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Constants for rule validation and backup
DOMAIN_PATTERN = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$')
RULES_BACKUP_DIR = Path("rules_backup")
RULES_BACKUP_DIR.mkdir(exist_ok=True)
class AdGuardError(Exception):
"""Base exception for AdGuard Home API errors."""
pass
@@ -20,6 +29,10 @@ class AdGuardAPIError(AdGuardError):
"""Raised when AdGuard Home API returns an error."""
pass
class AdGuardValidationError(AdGuardError):
"""Raised when input validation fails."""
pass
class ResultRule(BaseModel):
"""Rule detail according to AdGuard spec."""
filter_list_id: Optional[int] = Field(None, description="Filter list ID")
@@ -56,6 +69,39 @@ class SetRulesRequest(BaseModel):
"""Request model for set_rules endpoint according to AdGuard spec."""
rules: List[str] = Field(..., description="List of filtering rules")
def validate_domain(domain: str) -> bool:
"""Validate domain name format."""
if not domain or len(domain) > 255:
return False
return bool(DOMAIN_PATTERN.match(domain))
def sanitize_rule(rule: str) -> str:
"""Sanitize and validate rule format."""
# Remove any whitespace and normalize
rule = rule.strip()
# Basic XSS/injection prevention
rule = rule.replace('<', '').replace('>', '').replace('"', '').replace("'", '')
return rule
def save_rules_backup(rules: List[str], action: str = "update") -> Path:
"""Save rules to backup file with timestamp."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = RULES_BACKUP_DIR / f"rules_{action}_{timestamp}.json"
with open(backup_file, 'w') as f:
json.dump({"rules": rules, "timestamp": timestamp}, f, indent=2)
logger.info(f"Saved rules backup to {backup_file}")
return backup_file
def load_rules_backup(backup_file: Path) -> List[str]:
"""Load rules from backup file."""
try:
with open(backup_file) as f:
data = json.load(f)
return data.get("rules", [])
except Exception as e:
logger.error(f"Error loading rules backup: {str(e)}")
return []
class AdGuardClient:
"""Client for interacting with AdGuard Home API according to OpenAPI spec."""
@@ -114,6 +160,10 @@ class AdGuardClient:
async def check_domain(self, domain: str) -> FilterCheckHostResponse:
"""Check if a domain is blocked by AdGuard Home according to spec."""
# Validate domain format
if not validate_domain(domain):
raise AdGuardValidationError(f"Invalid domain format: {domain}")
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/check_host"
params = {"name": domain}
@@ -147,42 +197,6 @@ class AdGuardClient:
except Exception as e:
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def add_allowed_domain(self, domain: str) -> bool:
"""Add a domain to the allowed list using set_rules endpoint according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/set_rules"
# Add as a whitelist rule according to AdGuard format
data = {"rules": [f"@@||{domain}^"]}
headers = {}
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
try:
logger.info(f"Adding domain to whitelist: {domain}")
response = await self.client.post(url, json=data, headers=headers)
if response.status_code == 401:
logger.info("Session expired, attempting reauth")
await self.login()
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully added {domain} to whitelist")
return True
except httpx.ConnectError as e:
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
except httpx.HTTPError as e:
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def get_filter_status(self) -> FilterStatus:
"""Get the current filtering status according to spec."""
@@ -219,6 +233,71 @@ class AdGuardClient:
logger.error(f"Unexpected error while getting filter status: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def add_allowed_domain(self, domain: str) -> bool:
"""Add a domain to the allowed list using AdGuard filtering API."""
# Validate domain format
if not validate_domain(domain):
raise AdGuardValidationError(f"Invalid domain format: {domain}")
await self._ensure_authenticated()
try:
# First get current user rules
status = await self.get_filter_status()
current_rules = status.user_rules if status else []
# Create sanitized whitelist rule
new_rule = sanitize_rule(f"@@||{domain}^")
# Save backup of current rules
old_backup = save_rules_backup(current_rules, "before")
# Add rule if not already present
if new_rule not in current_rules:
current_rules.append(new_rule)
# Save backup of new rules before updating
new_backup = save_rules_backup(current_rules, "after")
# Update rules
url = f"{self.base_url}/filtering/set_rules"
data = {"rules": current_rules}
headers = {}
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
logger.info(f"Updating rules list with whitelisted domain: {domain}")
response = await self.client.post(url, json=data, headers=headers)
if response.status_code == 401:
logger.info("Session expired, attempting reauth")
await self.login()
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully updated rules list with whitelisted domain: {domain}")
return True
except (httpx.ConnectError, httpx.HTTPError) as e:
# On error, try to restore from backup
logger.error(f"Error updating rules, attempting to restore from backup: {str(e)}")
if old_backup.exists():
try:
old_rules = load_rules_backup(old_backup)
if old_rules:
await self.client.post(url, json={"rules": old_rules}, headers=headers)
logger.info("Successfully restored rules from backup")
except Exception as restore_error:
logger.error(f"Failed to restore from backup: {str(restore_error)}")
raise
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def close(self):
"""Close the HTTP client."""
await self.client.aclose()

View File

@@ -13,6 +13,7 @@ from .adguard import (
AdGuardError,
AdGuardConnectionError,
AdGuardAPIError,
AdGuardValidationError,
FilterStatus,
FilterCheckHostResponse,
SetRulesRequest
@@ -84,6 +85,11 @@ async def check_domain(name: str) -> FilterCheckHostResponse:
result = await client.check_domain(name)
logger.info(f"Domain check result: {result}")
return result
except AdGuardValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Error checking domain {name}: {str(e)}")
raise
@@ -106,19 +112,34 @@ async def add_to_whitelist(request: SetRulesRequest) -> Dict:
detail="Rules are required"
)
logger.info(f"Adding rules: {request.rules}")
# Extract domain from whitelist rule
rule = request.rules[0]
if not rule.startswith("@@||") or not rule.endswith("^"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid whitelist rule format"
)
domain = rule[4:-1] # Remove @@|| prefix and ^ suffix
logger.info(f"Adding domain to whitelist: {domain}")
try:
async with adguard.AdGuardClient() as client:
success = await client.add_allowed_domain(request.rules[0].strip("@@||^"))
success = await client.add_allowed_domain(domain)
if success:
return {"message": "Rules added successfully"}
return {"message": f"Domain {domain} added to whitelist"}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add rules"
detail="Failed to add domain to whitelist"
)
except AdGuardValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)
except Exception as e:
logger.error(f"Error adding rules: {str(e)}")
logger.error(f"Error adding domain to whitelist: {str(e)}")
raise
@app.get(
@@ -144,6 +165,8 @@ async def adguard_exception_handler(request: Request, exc: AdGuardError) -> JSON
"""Handle AdGuard-related exceptions according to spec."""
if isinstance(exc, AdGuardConnectionError):
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
elif isinstance(exc, AdGuardValidationError):
status_code = status.HTTP_400_BAD_REQUEST
elif isinstance(exc, AdGuardAPIError):
status_code = status.HTTP_502_BAD_GATEWAY
else:

View File

@@ -53,10 +53,14 @@
unblockDiv.innerHTML = '';
}
} else {
let errorMsg = data.message || 'Unknown error occurred';
let errorType = response.status === 400 ? 'warning' : 'error';
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
resultDiv.innerHTML = `
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
<div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
<p class="font-bold">Error checking domain</p>
<p class="text-sm">${data.message || 'Unknown error occurred'}</p>
<p class="text-sm">${errorMsg}</p>
</div>`;
unblockDiv.innerHTML = '';
}
@@ -94,18 +98,25 @@
});
if (response.ok) {
const data = await response.json();
resultDiv.innerHTML = `
<div class="bg-green-100 border-l-4 border-green-500 text-green-700 p-4">
<p class="font-bold">Success!</p>
<p class="text-sm">Domain ${domain} has been added to the whitelist</p>
<p class="text-sm">${data.message}</p>
<p class="text-xs mt-2">A backup of the rules has been saved for safety.</p>
</div>`;
unblockDiv.innerHTML = '';
} else {
const data = await response.json();
let errorMsg = data.message || 'Unknown error occurred';
let errorType = response.status === 400 ? 'warning' : 'error';
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
resultDiv.innerHTML = `
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
<div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
<p class="font-bold">Error unblocking domain</p>
<p class="text-sm">${data.message || 'Unknown error occurred'}</p>
<p class="text-sm">${errorMsg}</p>
${errorType !== 'warning' ? '<p class="text-xs mt-2">Previous rules have been restored from backup.</p>' : ''}
</div>`;
}
} catch (error) {
@@ -113,6 +124,7 @@
<div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4">
<p class="font-bold">Error unblocking domain</p>
<p class="text-sm">${error.message}</p>
<p class="text-xs mt-2">Previous rules have been restored from backup.</p>
</div>`;
}
}
@@ -128,9 +140,10 @@
<label for="domain" class="block text-gray-700 text-sm font-bold mb-2">
Enter Domain to Check
</label>
<input type="text" id="domain" name="domain" required
<input type="text" id="domain" name="domain" required pattern="^[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9](\.[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9])*$"
class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline"
placeholder="example.com">
placeholder="example.com"
title="Please enter a valid domain name">
</div>
<button id="submit-btn" type="submit"
class="bg-green-500 hover:bg-green-700 text-white font-bold py-2 px-4 rounded w-full transition-colors duration-200">
@@ -144,6 +157,8 @@
<div class="mt-4 text-center text-gray-600 text-sm">
Make sure your AdGuard Home instance is running and properly configured in the .env file.
<br>
<span class="text-xs">Rules are automatically backed up before any changes.</span>
</div>
</div>
</body>