mirror of
https://github.com/pacnpal/simpleguardhome.git
synced 2025-12-20 04:21:13 -05:00
feat(api): enhance domain check and whitelist functionality with improved error handling and validation
This commit is contained in:
@@ -1,13 +1,22 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Set
|
||||
from pydantic import BaseModel, Field
|
||||
import httpx
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from .config import settings
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants for rule validation and backup
|
||||
DOMAIN_PATTERN = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$')
|
||||
RULES_BACKUP_DIR = Path("rules_backup")
|
||||
RULES_BACKUP_DIR.mkdir(exist_ok=True)
|
||||
|
||||
class AdGuardError(Exception):
|
||||
"""Base exception for AdGuard Home API errors."""
|
||||
pass
|
||||
@@ -20,6 +29,10 @@ class AdGuardAPIError(AdGuardError):
|
||||
"""Raised when AdGuard Home API returns an error."""
|
||||
pass
|
||||
|
||||
class AdGuardValidationError(AdGuardError):
|
||||
"""Raised when input validation fails."""
|
||||
pass
|
||||
|
||||
class ResultRule(BaseModel):
|
||||
"""Rule detail according to AdGuard spec."""
|
||||
filter_list_id: Optional[int] = Field(None, description="Filter list ID")
|
||||
@@ -56,6 +69,39 @@ class SetRulesRequest(BaseModel):
|
||||
"""Request model for set_rules endpoint according to AdGuard spec."""
|
||||
rules: List[str] = Field(..., description="List of filtering rules")
|
||||
|
||||
def validate_domain(domain: str) -> bool:
|
||||
"""Validate domain name format."""
|
||||
if not domain or len(domain) > 255:
|
||||
return False
|
||||
return bool(DOMAIN_PATTERN.match(domain))
|
||||
|
||||
def sanitize_rule(rule: str) -> str:
|
||||
"""Sanitize and validate rule format."""
|
||||
# Remove any whitespace and normalize
|
||||
rule = rule.strip()
|
||||
# Basic XSS/injection prevention
|
||||
rule = rule.replace('<', '').replace('>', '').replace('"', '').replace("'", '')
|
||||
return rule
|
||||
|
||||
def save_rules_backup(rules: List[str], action: str = "update") -> Path:
|
||||
"""Save rules to backup file with timestamp."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_file = RULES_BACKUP_DIR / f"rules_{action}_{timestamp}.json"
|
||||
with open(backup_file, 'w') as f:
|
||||
json.dump({"rules": rules, "timestamp": timestamp}, f, indent=2)
|
||||
logger.info(f"Saved rules backup to {backup_file}")
|
||||
return backup_file
|
||||
|
||||
def load_rules_backup(backup_file: Path) -> List[str]:
|
||||
"""Load rules from backup file."""
|
||||
try:
|
||||
with open(backup_file) as f:
|
||||
data = json.load(f)
|
||||
return data.get("rules", [])
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading rules backup: {str(e)}")
|
||||
return []
|
||||
|
||||
class AdGuardClient:
|
||||
"""Client for interacting with AdGuard Home API according to OpenAPI spec."""
|
||||
|
||||
@@ -114,6 +160,10 @@ class AdGuardClient:
|
||||
|
||||
async def check_domain(self, domain: str) -> FilterCheckHostResponse:
|
||||
"""Check if a domain is blocked by AdGuard Home according to spec."""
|
||||
# Validate domain format
|
||||
if not validate_domain(domain):
|
||||
raise AdGuardValidationError(f"Invalid domain format: {domain}")
|
||||
|
||||
await self._ensure_authenticated()
|
||||
url = f"{self.base_url}/filtering/check_host"
|
||||
params = {"name": domain}
|
||||
@@ -147,42 +197,6 @@ class AdGuardClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
|
||||
raise AdGuardError(f"Unexpected error: {str(e)}")
|
||||
|
||||
async def add_allowed_domain(self, domain: str) -> bool:
|
||||
"""Add a domain to the allowed list using set_rules endpoint according to spec."""
|
||||
await self._ensure_authenticated()
|
||||
url = f"{self.base_url}/filtering/set_rules"
|
||||
# Add as a whitelist rule according to AdGuard format
|
||||
data = {"rules": [f"@@||{domain}^"]}
|
||||
headers = {}
|
||||
|
||||
if self._session_cookie:
|
||||
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
||||
|
||||
try:
|
||||
logger.info(f"Adding domain to whitelist: {domain}")
|
||||
response = await self.client.post(url, json=data, headers=headers)
|
||||
|
||||
if response.status_code == 401:
|
||||
logger.info("Session expired, attempting reauth")
|
||||
await self.login()
|
||||
if self._session_cookie:
|
||||
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
||||
response = await self.client.post(url, json=data, headers=headers)
|
||||
|
||||
response.raise_for_status()
|
||||
logger.info(f"Successfully added {domain} to whitelist")
|
||||
return True
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
|
||||
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
|
||||
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
|
||||
raise AdGuardError(f"Unexpected error: {str(e)}")
|
||||
|
||||
async def get_filter_status(self) -> FilterStatus:
|
||||
"""Get the current filtering status according to spec."""
|
||||
@@ -219,6 +233,71 @@ class AdGuardClient:
|
||||
logger.error(f"Unexpected error while getting filter status: {str(e)}")
|
||||
raise AdGuardError(f"Unexpected error: {str(e)}")
|
||||
|
||||
async def add_allowed_domain(self, domain: str) -> bool:
|
||||
"""Add a domain to the allowed list using AdGuard filtering API."""
|
||||
# Validate domain format
|
||||
if not validate_domain(domain):
|
||||
raise AdGuardValidationError(f"Invalid domain format: {domain}")
|
||||
|
||||
await self._ensure_authenticated()
|
||||
|
||||
try:
|
||||
# First get current user rules
|
||||
status = await self.get_filter_status()
|
||||
current_rules = status.user_rules if status else []
|
||||
|
||||
# Create sanitized whitelist rule
|
||||
new_rule = sanitize_rule(f"@@||{domain}^")
|
||||
|
||||
# Save backup of current rules
|
||||
old_backup = save_rules_backup(current_rules, "before")
|
||||
|
||||
# Add rule if not already present
|
||||
if new_rule not in current_rules:
|
||||
current_rules.append(new_rule)
|
||||
|
||||
# Save backup of new rules before updating
|
||||
new_backup = save_rules_backup(current_rules, "after")
|
||||
|
||||
# Update rules
|
||||
url = f"{self.base_url}/filtering/set_rules"
|
||||
data = {"rules": current_rules}
|
||||
headers = {}
|
||||
|
||||
if self._session_cookie:
|
||||
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
||||
|
||||
logger.info(f"Updating rules list with whitelisted domain: {domain}")
|
||||
response = await self.client.post(url, json=data, headers=headers)
|
||||
|
||||
if response.status_code == 401:
|
||||
logger.info("Session expired, attempting reauth")
|
||||
await self.login()
|
||||
if self._session_cookie:
|
||||
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
||||
response = await self.client.post(url, json=data, headers=headers)
|
||||
|
||||
response.raise_for_status()
|
||||
logger.info(f"Successfully updated rules list with whitelisted domain: {domain}")
|
||||
return True
|
||||
|
||||
except (httpx.ConnectError, httpx.HTTPError) as e:
|
||||
# On error, try to restore from backup
|
||||
logger.error(f"Error updating rules, attempting to restore from backup: {str(e)}")
|
||||
if old_backup.exists():
|
||||
try:
|
||||
old_rules = load_rules_backup(old_backup)
|
||||
if old_rules:
|
||||
await self.client.post(url, json={"rules": old_rules}, headers=headers)
|
||||
logger.info("Successfully restored rules from backup")
|
||||
except Exception as restore_error:
|
||||
logger.error(f"Failed to restore from backup: {str(restore_error)}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
|
||||
raise AdGuardError(f"Unexpected error: {str(e)}")
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -13,6 +13,7 @@ from .adguard import (
|
||||
AdGuardError,
|
||||
AdGuardConnectionError,
|
||||
AdGuardAPIError,
|
||||
AdGuardValidationError,
|
||||
FilterStatus,
|
||||
FilterCheckHostResponse,
|
||||
SetRulesRequest
|
||||
@@ -84,6 +85,11 @@ async def check_domain(name: str) -> FilterCheckHostResponse:
|
||||
result = await client.check_domain(name)
|
||||
logger.info(f"Domain check result: {result}")
|
||||
return result
|
||||
except AdGuardValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking domain {name}: {str(e)}")
|
||||
raise
|
||||
@@ -106,19 +112,34 @@ async def add_to_whitelist(request: SetRulesRequest) -> Dict:
|
||||
detail="Rules are required"
|
||||
)
|
||||
|
||||
logger.info(f"Adding rules: {request.rules}")
|
||||
# Extract domain from whitelist rule
|
||||
rule = request.rules[0]
|
||||
if not rule.startswith("@@||") or not rule.endswith("^"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid whitelist rule format"
|
||||
)
|
||||
|
||||
domain = rule[4:-1] # Remove @@|| prefix and ^ suffix
|
||||
logger.info(f"Adding domain to whitelist: {domain}")
|
||||
|
||||
try:
|
||||
async with adguard.AdGuardClient() as client:
|
||||
success = await client.add_allowed_domain(request.rules[0].strip("@@||^"))
|
||||
success = await client.add_allowed_domain(domain)
|
||||
if success:
|
||||
return {"message": "Rules added successfully"}
|
||||
return {"message": f"Domain {domain} added to whitelist"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add rules"
|
||||
detail="Failed to add domain to whitelist"
|
||||
)
|
||||
except AdGuardValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding rules: {str(e)}")
|
||||
logger.error(f"Error adding domain to whitelist: {str(e)}")
|
||||
raise
|
||||
|
||||
@app.get(
|
||||
@@ -144,6 +165,8 @@ async def adguard_exception_handler(request: Request, exc: AdGuardError) -> JSON
|
||||
"""Handle AdGuard-related exceptions according to spec."""
|
||||
if isinstance(exc, AdGuardConnectionError):
|
||||
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||
elif isinstance(exc, AdGuardValidationError):
|
||||
status_code = status.HTTP_400_BAD_REQUEST
|
||||
elif isinstance(exc, AdGuardAPIError):
|
||||
status_code = status.HTTP_502_BAD_GATEWAY
|
||||
else:
|
||||
|
||||
@@ -53,10 +53,14 @@
|
||||
unblockDiv.innerHTML = '';
|
||||
}
|
||||
} else {
|
||||
let errorMsg = data.message || 'Unknown error occurred';
|
||||
let errorType = response.status === 400 ? 'warning' : 'error';
|
||||
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
|
||||
|
||||
resultDiv.innerHTML = `
|
||||
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
|
||||
<div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
|
||||
<p class="font-bold">Error checking domain</p>
|
||||
<p class="text-sm">${data.message || 'Unknown error occurred'}</p>
|
||||
<p class="text-sm">${errorMsg}</p>
|
||||
</div>`;
|
||||
unblockDiv.innerHTML = '';
|
||||
}
|
||||
@@ -94,18 +98,25 @@
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
resultDiv.innerHTML = `
|
||||
<div class="bg-green-100 border-l-4 border-green-500 text-green-700 p-4">
|
||||
<p class="font-bold">Success!</p>
|
||||
<p class="text-sm">Domain ${domain} has been added to the whitelist</p>
|
||||
<p class="text-sm">${data.message}</p>
|
||||
<p class="text-xs mt-2">A backup of the rules has been saved for safety.</p>
|
||||
</div>`;
|
||||
unblockDiv.innerHTML = '';
|
||||
} else {
|
||||
const data = await response.json();
|
||||
let errorMsg = data.message || 'Unknown error occurred';
|
||||
let errorType = response.status === 400 ? 'warning' : 'error';
|
||||
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
|
||||
|
||||
resultDiv.innerHTML = `
|
||||
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
|
||||
<div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
|
||||
<p class="font-bold">Error unblocking domain</p>
|
||||
<p class="text-sm">${data.message || 'Unknown error occurred'}</p>
|
||||
<p class="text-sm">${errorMsg}</p>
|
||||
${errorType !== 'warning' ? '<p class="text-xs mt-2">Previous rules have been restored from backup.</p>' : ''}
|
||||
</div>`;
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -113,6 +124,7 @@
|
||||
<div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4">
|
||||
<p class="font-bold">Error unblocking domain</p>
|
||||
<p class="text-sm">${error.message}</p>
|
||||
<p class="text-xs mt-2">Previous rules have been restored from backup.</p>
|
||||
</div>`;
|
||||
}
|
||||
}
|
||||
@@ -128,9 +140,10 @@
|
||||
<label for="domain" class="block text-gray-700 text-sm font-bold mb-2">
|
||||
Enter Domain to Check
|
||||
</label>
|
||||
<input type="text" id="domain" name="domain" required
|
||||
<input type="text" id="domain" name="domain" required pattern="^[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9](\.[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9])*$"
|
||||
class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline"
|
||||
placeholder="example.com">
|
||||
placeholder="example.com"
|
||||
title="Please enter a valid domain name">
|
||||
</div>
|
||||
<button id="submit-btn" type="submit"
|
||||
class="bg-green-500 hover:bg-green-700 text-white font-bold py-2 px-4 rounded w-full transition-colors duration-200">
|
||||
@@ -144,6 +157,8 @@
|
||||
|
||||
<div class="mt-4 text-center text-gray-600 text-sm">
|
||||
Make sure your AdGuard Home instance is running and properly configured in the .env file.
|
||||
<br>
|
||||
<span class="text-xs">Rules are automatically backed up before any changes.</span>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
Reference in New Issue
Block a user