From 91e20e37113f081c520abb6ba928d1ed0da9be8d Mon Sep 17 00:00:00 2001 From: pacnpal <183241239+pacnpal@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:20:22 +0000 Subject: [PATCH] feat(api): enhance domain check and whitelist functionality with improved error handling and validation --- src/simpleguardhome/adguard.py | 153 +++++++++++++++++------ src/simpleguardhome/main.py | 33 ++++- src/simpleguardhome/templates/index.html | 29 +++-- 3 files changed, 166 insertions(+), 49 deletions(-) diff --git a/src/simpleguardhome/adguard.py b/src/simpleguardhome/adguard.py index b8ef86e..d8fc63d 100644 --- a/src/simpleguardhome/adguard.py +++ b/src/simpleguardhome/adguard.py @@ -1,13 +1,22 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set from pydantic import BaseModel, Field import httpx import logging +import re +import json +from pathlib import Path +from datetime import datetime from .config import settings # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# Constants for rule validation and backup +DOMAIN_PATTERN = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$') +RULES_BACKUP_DIR = Path("rules_backup") +RULES_BACKUP_DIR.mkdir(exist_ok=True) + class AdGuardError(Exception): """Base exception for AdGuard Home API errors.""" pass @@ -20,6 +29,10 @@ class AdGuardAPIError(AdGuardError): """Raised when AdGuard Home API returns an error.""" pass +class AdGuardValidationError(AdGuardError): + """Raised when input validation fails.""" + pass + class ResultRule(BaseModel): """Rule detail according to AdGuard spec.""" filter_list_id: Optional[int] = Field(None, description="Filter list ID") @@ -56,6 +69,39 @@ class SetRulesRequest(BaseModel): """Request model for set_rules endpoint according to AdGuard spec.""" rules: List[str] = Field(..., description="List of filtering rules") +def validate_domain(domain: str) -> bool: + """Validate domain name format.""" + if not domain or len(domain) > 255: + return False + return bool(DOMAIN_PATTERN.match(domain)) + +def sanitize_rule(rule: str) -> str: + """Sanitize and validate rule format.""" + # Remove any whitespace and normalize + rule = rule.strip() + # Basic XSS/injection prevention + rule = rule.replace('<', '').replace('>', '').replace('"', '').replace("'", '') + return rule + +def save_rules_backup(rules: List[str], action: str = "update") -> Path: + """Save rules to backup file with timestamp.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = RULES_BACKUP_DIR / f"rules_{action}_{timestamp}.json" + with open(backup_file, 'w') as f: + json.dump({"rules": rules, "timestamp": timestamp}, f, indent=2) + logger.info(f"Saved rules backup to {backup_file}") + return backup_file + +def load_rules_backup(backup_file: Path) -> List[str]: + """Load rules from backup file.""" + try: + with open(backup_file) as f: + data = json.load(f) + return data.get("rules", []) + except Exception as e: + logger.error(f"Error loading rules backup: {str(e)}") + return [] + class AdGuardClient: """Client for interacting with AdGuard Home API according to OpenAPI spec.""" @@ -114,6 +160,10 @@ class AdGuardClient: async def check_domain(self, domain: str) -> FilterCheckHostResponse: """Check if a domain is blocked by AdGuard Home according to spec.""" + # Validate domain format + if not validate_domain(domain): + raise AdGuardValidationError(f"Invalid domain format: {domain}") + await self._ensure_authenticated() url = f"{self.base_url}/filtering/check_host" params = {"name": domain} @@ -147,42 +197,6 @@ class AdGuardClient: except Exception as e: logger.error(f"Unexpected error while checking domain {domain}: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}") - - async def add_allowed_domain(self, domain: str) -> bool: - """Add a domain to the allowed list using set_rules endpoint according to spec.""" - await self._ensure_authenticated() - url = f"{self.base_url}/filtering/set_rules" - # Add as a whitelist rule according to AdGuard format - data = {"rules": [f"@@||{domain}^"]} - headers = {} - - if self._session_cookie: - headers['Cookie'] = f'agh_session={self._session_cookie}' - - try: - logger.info(f"Adding domain to whitelist: {domain}") - response = await self.client.post(url, json=data, headers=headers) - - if response.status_code == 401: - logger.info("Session expired, attempting reauth") - await self.login() - if self._session_cookie: - headers['Cookie'] = f'agh_session={self._session_cookie}' - response = await self.client.post(url, json=data, headers=headers) - - response.raise_for_status() - logger.info(f"Successfully added {domain} to whitelist") - return True - - except httpx.ConnectError as e: - logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}") - raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}") - except httpx.HTTPError as e: - logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}") - raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}") - except Exception as e: - logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}") - raise AdGuardError(f"Unexpected error: {str(e)}") async def get_filter_status(self) -> FilterStatus: """Get the current filtering status according to spec.""" @@ -219,6 +233,71 @@ class AdGuardClient: logger.error(f"Unexpected error while getting filter status: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}") + async def add_allowed_domain(self, domain: str) -> bool: + """Add a domain to the allowed list using AdGuard filtering API.""" + # Validate domain format + if not validate_domain(domain): + raise AdGuardValidationError(f"Invalid domain format: {domain}") + + await self._ensure_authenticated() + + try: + # First get current user rules + status = await self.get_filter_status() + current_rules = status.user_rules if status else [] + + # Create sanitized whitelist rule + new_rule = sanitize_rule(f"@@||{domain}^") + + # Save backup of current rules + old_backup = save_rules_backup(current_rules, "before") + + # Add rule if not already present + if new_rule not in current_rules: + current_rules.append(new_rule) + + # Save backup of new rules before updating + new_backup = save_rules_backup(current_rules, "after") + + # Update rules + url = f"{self.base_url}/filtering/set_rules" + data = {"rules": current_rules} + headers = {} + + if self._session_cookie: + headers['Cookie'] = f'agh_session={self._session_cookie}' + + logger.info(f"Updating rules list with whitelisted domain: {domain}") + response = await self.client.post(url, json=data, headers=headers) + + if response.status_code == 401: + logger.info("Session expired, attempting reauth") + await self.login() + if self._session_cookie: + headers['Cookie'] = f'agh_session={self._session_cookie}' + response = await self.client.post(url, json=data, headers=headers) + + response.raise_for_status() + logger.info(f"Successfully updated rules list with whitelisted domain: {domain}") + return True + + except (httpx.ConnectError, httpx.HTTPError) as e: + # On error, try to restore from backup + logger.error(f"Error updating rules, attempting to restore from backup: {str(e)}") + if old_backup.exists(): + try: + old_rules = load_rules_backup(old_backup) + if old_rules: + await self.client.post(url, json={"rules": old_rules}, headers=headers) + logger.info("Successfully restored rules from backup") + except Exception as restore_error: + logger.error(f"Failed to restore from backup: {str(restore_error)}") + raise + + except Exception as e: + logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}") + raise AdGuardError(f"Unexpected error: {str(e)}") + async def close(self): """Close the HTTP client.""" await self.client.aclose() diff --git a/src/simpleguardhome/main.py b/src/simpleguardhome/main.py index 7c30a3f..d02332d 100644 --- a/src/simpleguardhome/main.py +++ b/src/simpleguardhome/main.py @@ -13,6 +13,7 @@ from .adguard import ( AdGuardError, AdGuardConnectionError, AdGuardAPIError, + AdGuardValidationError, FilterStatus, FilterCheckHostResponse, SetRulesRequest @@ -84,6 +85,11 @@ async def check_domain(name: str) -> FilterCheckHostResponse: result = await client.check_domain(name) logger.info(f"Domain check result: {result}") return result + except AdGuardValidationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) except Exception as e: logger.error(f"Error checking domain {name}: {str(e)}") raise @@ -106,19 +112,34 @@ async def add_to_whitelist(request: SetRulesRequest) -> Dict: detail="Rules are required" ) - logger.info(f"Adding rules: {request.rules}") + # Extract domain from whitelist rule + rule = request.rules[0] + if not rule.startswith("@@||") or not rule.endswith("^"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid whitelist rule format" + ) + + domain = rule[4:-1] # Remove @@|| prefix and ^ suffix + logger.info(f"Adding domain to whitelist: {domain}") + try: async with adguard.AdGuardClient() as client: - success = await client.add_allowed_domain(request.rules[0].strip("@@||^")) + success = await client.add_allowed_domain(domain) if success: - return {"message": "Rules added successfully"} + return {"message": f"Domain {domain} added to whitelist"} else: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to add rules" + detail="Failed to add domain to whitelist" ) + except AdGuardValidationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) except Exception as e: - logger.error(f"Error adding rules: {str(e)}") + logger.error(f"Error adding domain to whitelist: {str(e)}") raise @app.get( @@ -144,6 +165,8 @@ async def adguard_exception_handler(request: Request, exc: AdGuardError) -> JSON """Handle AdGuard-related exceptions according to spec.""" if isinstance(exc, AdGuardConnectionError): status_code = status.HTTP_503_SERVICE_UNAVAILABLE + elif isinstance(exc, AdGuardValidationError): + status_code = status.HTTP_400_BAD_REQUEST elif isinstance(exc, AdGuardAPIError): status_code = status.HTTP_502_BAD_GATEWAY else: diff --git a/src/simpleguardhome/templates/index.html b/src/simpleguardhome/templates/index.html index acdd839..87e125f 100644 --- a/src/simpleguardhome/templates/index.html +++ b/src/simpleguardhome/templates/index.html @@ -53,10 +53,14 @@ unblockDiv.innerHTML = ''; } } else { + let errorMsg = data.message || 'Unknown error occurred'; + let errorType = response.status === 400 ? 'warning' : 'error'; + let bgColor = errorType === 'warning' ? 'yellow' : 'red'; + resultDiv.innerHTML = ` -
Error checking domain
-${data.message || 'Unknown error occurred'}
+${errorMsg}
Success!
-Domain ${domain} has been added to the whitelist
+${data.message}
+A backup of the rules has been saved for safety.
Error unblocking domain
-${data.message || 'Unknown error occurred'}
+${errorMsg}
+ ${errorType !== 'warning' ? 'Previous rules have been restored from backup.
' : ''}Error unblocking domain
${error.message}
+Previous rules have been restored from backup.