mirror of
https://github.com/pacnpal/simpleguardhome.git
synced 2025-12-21 04:51:13 -05:00
feat(api): enhance domain check and whitelist functionality with improved error handling and validation
This commit is contained in:
@@ -1,13 +1,22 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional, Set
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import httpx
|
import httpx
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
from .config import settings
|
from .config import settings
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Constants for rule validation and backup
|
||||||
|
DOMAIN_PATTERN = re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$')
|
||||||
|
RULES_BACKUP_DIR = Path("rules_backup")
|
||||||
|
RULES_BACKUP_DIR.mkdir(exist_ok=True)
|
||||||
|
|
||||||
class AdGuardError(Exception):
|
class AdGuardError(Exception):
|
||||||
"""Base exception for AdGuard Home API errors."""
|
"""Base exception for AdGuard Home API errors."""
|
||||||
pass
|
pass
|
||||||
@@ -20,6 +29,10 @@ class AdGuardAPIError(AdGuardError):
|
|||||||
"""Raised when AdGuard Home API returns an error."""
|
"""Raised when AdGuard Home API returns an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class AdGuardValidationError(AdGuardError):
|
||||||
|
"""Raised when input validation fails."""
|
||||||
|
pass
|
||||||
|
|
||||||
class ResultRule(BaseModel):
|
class ResultRule(BaseModel):
|
||||||
"""Rule detail according to AdGuard spec."""
|
"""Rule detail according to AdGuard spec."""
|
||||||
filter_list_id: Optional[int] = Field(None, description="Filter list ID")
|
filter_list_id: Optional[int] = Field(None, description="Filter list ID")
|
||||||
@@ -56,6 +69,39 @@ class SetRulesRequest(BaseModel):
|
|||||||
"""Request model for set_rules endpoint according to AdGuard spec."""
|
"""Request model for set_rules endpoint according to AdGuard spec."""
|
||||||
rules: List[str] = Field(..., description="List of filtering rules")
|
rules: List[str] = Field(..., description="List of filtering rules")
|
||||||
|
|
||||||
|
def validate_domain(domain: str) -> bool:
|
||||||
|
"""Validate domain name format."""
|
||||||
|
if not domain or len(domain) > 255:
|
||||||
|
return False
|
||||||
|
return bool(DOMAIN_PATTERN.match(domain))
|
||||||
|
|
||||||
|
def sanitize_rule(rule: str) -> str:
|
||||||
|
"""Sanitize and validate rule format."""
|
||||||
|
# Remove any whitespace and normalize
|
||||||
|
rule = rule.strip()
|
||||||
|
# Basic XSS/injection prevention
|
||||||
|
rule = rule.replace('<', '').replace('>', '').replace('"', '').replace("'", '')
|
||||||
|
return rule
|
||||||
|
|
||||||
|
def save_rules_backup(rules: List[str], action: str = "update") -> Path:
|
||||||
|
"""Save rules to backup file with timestamp."""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
backup_file = RULES_BACKUP_DIR / f"rules_{action}_{timestamp}.json"
|
||||||
|
with open(backup_file, 'w') as f:
|
||||||
|
json.dump({"rules": rules, "timestamp": timestamp}, f, indent=2)
|
||||||
|
logger.info(f"Saved rules backup to {backup_file}")
|
||||||
|
return backup_file
|
||||||
|
|
||||||
|
def load_rules_backup(backup_file: Path) -> List[str]:
|
||||||
|
"""Load rules from backup file."""
|
||||||
|
try:
|
||||||
|
with open(backup_file) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return data.get("rules", [])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading rules backup: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
class AdGuardClient:
|
class AdGuardClient:
|
||||||
"""Client for interacting with AdGuard Home API according to OpenAPI spec."""
|
"""Client for interacting with AdGuard Home API according to OpenAPI spec."""
|
||||||
|
|
||||||
@@ -114,6 +160,10 @@ class AdGuardClient:
|
|||||||
|
|
||||||
async def check_domain(self, domain: str) -> FilterCheckHostResponse:
|
async def check_domain(self, domain: str) -> FilterCheckHostResponse:
|
||||||
"""Check if a domain is blocked by AdGuard Home according to spec."""
|
"""Check if a domain is blocked by AdGuard Home according to spec."""
|
||||||
|
# Validate domain format
|
||||||
|
if not validate_domain(domain):
|
||||||
|
raise AdGuardValidationError(f"Invalid domain format: {domain}")
|
||||||
|
|
||||||
await self._ensure_authenticated()
|
await self._ensure_authenticated()
|
||||||
url = f"{self.base_url}/filtering/check_host"
|
url = f"{self.base_url}/filtering/check_host"
|
||||||
params = {"name": domain}
|
params = {"name": domain}
|
||||||
@@ -148,42 +198,6 @@ class AdGuardClient:
|
|||||||
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
|
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
|
||||||
raise AdGuardError(f"Unexpected error: {str(e)}")
|
raise AdGuardError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
async def add_allowed_domain(self, domain: str) -> bool:
|
|
||||||
"""Add a domain to the allowed list using set_rules endpoint according to spec."""
|
|
||||||
await self._ensure_authenticated()
|
|
||||||
url = f"{self.base_url}/filtering/set_rules"
|
|
||||||
# Add as a whitelist rule according to AdGuard format
|
|
||||||
data = {"rules": [f"@@||{domain}^"]}
|
|
||||||
headers = {}
|
|
||||||
|
|
||||||
if self._session_cookie:
|
|
||||||
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"Adding domain to whitelist: {domain}")
|
|
||||||
response = await self.client.post(url, json=data, headers=headers)
|
|
||||||
|
|
||||||
if response.status_code == 401:
|
|
||||||
logger.info("Session expired, attempting reauth")
|
|
||||||
await self.login()
|
|
||||||
if self._session_cookie:
|
|
||||||
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
|
||||||
response = await self.client.post(url, json=data, headers=headers)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
logger.info(f"Successfully added {domain} to whitelist")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except httpx.ConnectError as e:
|
|
||||||
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
|
|
||||||
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
|
|
||||||
except httpx.HTTPError as e:
|
|
||||||
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
|
|
||||||
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
|
|
||||||
raise AdGuardError(f"Unexpected error: {str(e)}")
|
|
||||||
|
|
||||||
async def get_filter_status(self) -> FilterStatus:
|
async def get_filter_status(self) -> FilterStatus:
|
||||||
"""Get the current filtering status according to spec."""
|
"""Get the current filtering status according to spec."""
|
||||||
await self._ensure_authenticated()
|
await self._ensure_authenticated()
|
||||||
@@ -219,6 +233,71 @@ class AdGuardClient:
|
|||||||
logger.error(f"Unexpected error while getting filter status: {str(e)}")
|
logger.error(f"Unexpected error while getting filter status: {str(e)}")
|
||||||
raise AdGuardError(f"Unexpected error: {str(e)}")
|
raise AdGuardError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
|
async def add_allowed_domain(self, domain: str) -> bool:
|
||||||
|
"""Add a domain to the allowed list using AdGuard filtering API."""
|
||||||
|
# Validate domain format
|
||||||
|
if not validate_domain(domain):
|
||||||
|
raise AdGuardValidationError(f"Invalid domain format: {domain}")
|
||||||
|
|
||||||
|
await self._ensure_authenticated()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# First get current user rules
|
||||||
|
status = await self.get_filter_status()
|
||||||
|
current_rules = status.user_rules if status else []
|
||||||
|
|
||||||
|
# Create sanitized whitelist rule
|
||||||
|
new_rule = sanitize_rule(f"@@||{domain}^")
|
||||||
|
|
||||||
|
# Save backup of current rules
|
||||||
|
old_backup = save_rules_backup(current_rules, "before")
|
||||||
|
|
||||||
|
# Add rule if not already present
|
||||||
|
if new_rule not in current_rules:
|
||||||
|
current_rules.append(new_rule)
|
||||||
|
|
||||||
|
# Save backup of new rules before updating
|
||||||
|
new_backup = save_rules_backup(current_rules, "after")
|
||||||
|
|
||||||
|
# Update rules
|
||||||
|
url = f"{self.base_url}/filtering/set_rules"
|
||||||
|
data = {"rules": current_rules}
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
if self._session_cookie:
|
||||||
|
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
||||||
|
|
||||||
|
logger.info(f"Updating rules list with whitelisted domain: {domain}")
|
||||||
|
response = await self.client.post(url, json=data, headers=headers)
|
||||||
|
|
||||||
|
if response.status_code == 401:
|
||||||
|
logger.info("Session expired, attempting reauth")
|
||||||
|
await self.login()
|
||||||
|
if self._session_cookie:
|
||||||
|
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
||||||
|
response = await self.client.post(url, json=data, headers=headers)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
logger.info(f"Successfully updated rules list with whitelisted domain: {domain}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except (httpx.ConnectError, httpx.HTTPError) as e:
|
||||||
|
# On error, try to restore from backup
|
||||||
|
logger.error(f"Error updating rules, attempting to restore from backup: {str(e)}")
|
||||||
|
if old_backup.exists():
|
||||||
|
try:
|
||||||
|
old_rules = load_rules_backup(old_backup)
|
||||||
|
if old_rules:
|
||||||
|
await self.client.post(url, json={"rules": old_rules}, headers=headers)
|
||||||
|
logger.info("Successfully restored rules from backup")
|
||||||
|
except Exception as restore_error:
|
||||||
|
logger.error(f"Failed to restore from backup: {str(restore_error)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
|
||||||
|
raise AdGuardError(f"Unexpected error: {str(e)}")
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Close the HTTP client."""
|
"""Close the HTTP client."""
|
||||||
await self.client.aclose()
|
await self.client.aclose()
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from .adguard import (
|
|||||||
AdGuardError,
|
AdGuardError,
|
||||||
AdGuardConnectionError,
|
AdGuardConnectionError,
|
||||||
AdGuardAPIError,
|
AdGuardAPIError,
|
||||||
|
AdGuardValidationError,
|
||||||
FilterStatus,
|
FilterStatus,
|
||||||
FilterCheckHostResponse,
|
FilterCheckHostResponse,
|
||||||
SetRulesRequest
|
SetRulesRequest
|
||||||
@@ -84,6 +85,11 @@ async def check_domain(name: str) -> FilterCheckHostResponse:
|
|||||||
result = await client.check_domain(name)
|
result = await client.check_domain(name)
|
||||||
logger.info(f"Domain check result: {result}")
|
logger.info(f"Domain check result: {result}")
|
||||||
return result
|
return result
|
||||||
|
except AdGuardValidationError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error checking domain {name}: {str(e)}")
|
logger.error(f"Error checking domain {name}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
@@ -106,19 +112,34 @@ async def add_to_whitelist(request: SetRulesRequest) -> Dict:
|
|||||||
detail="Rules are required"
|
detail="Rules are required"
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Adding rules: {request.rules}")
|
# Extract domain from whitelist rule
|
||||||
|
rule = request.rules[0]
|
||||||
|
if not rule.startswith("@@||") or not rule.endswith("^"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid whitelist rule format"
|
||||||
|
)
|
||||||
|
|
||||||
|
domain = rule[4:-1] # Remove @@|| prefix and ^ suffix
|
||||||
|
logger.info(f"Adding domain to whitelist: {domain}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with adguard.AdGuardClient() as client:
|
async with adguard.AdGuardClient() as client:
|
||||||
success = await client.add_allowed_domain(request.rules[0].strip("@@||^"))
|
success = await client.add_allowed_domain(domain)
|
||||||
if success:
|
if success:
|
||||||
return {"message": "Rules added successfully"}
|
return {"message": f"Domain {domain} added to whitelist"}
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to add rules"
|
detail="Failed to add domain to whitelist"
|
||||||
)
|
)
|
||||||
|
except AdGuardValidationError as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error adding rules: {str(e)}")
|
logger.error(f"Error adding domain to whitelist: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@app.get(
|
@app.get(
|
||||||
@@ -144,6 +165,8 @@ async def adguard_exception_handler(request: Request, exc: AdGuardError) -> JSON
|
|||||||
"""Handle AdGuard-related exceptions according to spec."""
|
"""Handle AdGuard-related exceptions according to spec."""
|
||||||
if isinstance(exc, AdGuardConnectionError):
|
if isinstance(exc, AdGuardConnectionError):
|
||||||
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
status_code = status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
elif isinstance(exc, AdGuardValidationError):
|
||||||
|
status_code = status.HTTP_400_BAD_REQUEST
|
||||||
elif isinstance(exc, AdGuardAPIError):
|
elif isinstance(exc, AdGuardAPIError):
|
||||||
status_code = status.HTTP_502_BAD_GATEWAY
|
status_code = status.HTTP_502_BAD_GATEWAY
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -53,10 +53,14 @@
|
|||||||
unblockDiv.innerHTML = '';
|
unblockDiv.innerHTML = '';
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
let errorMsg = data.message || 'Unknown error occurred';
|
||||||
|
let errorType = response.status === 400 ? 'warning' : 'error';
|
||||||
|
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
|
||||||
|
|
||||||
resultDiv.innerHTML = `
|
resultDiv.innerHTML = `
|
||||||
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
|
<div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
|
||||||
<p class="font-bold">Error checking domain</p>
|
<p class="font-bold">Error checking domain</p>
|
||||||
<p class="text-sm">${data.message || 'Unknown error occurred'}</p>
|
<p class="text-sm">${errorMsg}</p>
|
||||||
</div>`;
|
</div>`;
|
||||||
unblockDiv.innerHTML = '';
|
unblockDiv.innerHTML = '';
|
||||||
}
|
}
|
||||||
@@ -94,18 +98,25 @@
|
|||||||
});
|
});
|
||||||
|
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
|
const data = await response.json();
|
||||||
resultDiv.innerHTML = `
|
resultDiv.innerHTML = `
|
||||||
<div class="bg-green-100 border-l-4 border-green-500 text-green-700 p-4">
|
<div class="bg-green-100 border-l-4 border-green-500 text-green-700 p-4">
|
||||||
<p class="font-bold">Success!</p>
|
<p class="font-bold">Success!</p>
|
||||||
<p class="text-sm">Domain ${domain} has been added to the whitelist</p>
|
<p class="text-sm">${data.message}</p>
|
||||||
|
<p class="text-xs mt-2">A backup of the rules has been saved for safety.</p>
|
||||||
</div>`;
|
</div>`;
|
||||||
unblockDiv.innerHTML = '';
|
unblockDiv.innerHTML = '';
|
||||||
} else {
|
} else {
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
let errorMsg = data.message || 'Unknown error occurred';
|
||||||
|
let errorType = response.status === 400 ? 'warning' : 'error';
|
||||||
|
let bgColor = errorType === 'warning' ? 'yellow' : 'red';
|
||||||
|
|
||||||
resultDiv.innerHTML = `
|
resultDiv.innerHTML = `
|
||||||
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
|
<div class="bg-${bgColor}-100 border-l-4 border-${bgColor}-500 text-${bgColor}-700 p-4">
|
||||||
<p class="font-bold">Error unblocking domain</p>
|
<p class="font-bold">Error unblocking domain</p>
|
||||||
<p class="text-sm">${data.message || 'Unknown error occurred'}</p>
|
<p class="text-sm">${errorMsg}</p>
|
||||||
|
${errorType !== 'warning' ? '<p class="text-xs mt-2">Previous rules have been restored from backup.</p>' : ''}
|
||||||
</div>`;
|
</div>`;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -113,6 +124,7 @@
|
|||||||
<div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4">
|
<div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4">
|
||||||
<p class="font-bold">Error unblocking domain</p>
|
<p class="font-bold">Error unblocking domain</p>
|
||||||
<p class="text-sm">${error.message}</p>
|
<p class="text-sm">${error.message}</p>
|
||||||
|
<p class="text-xs mt-2">Previous rules have been restored from backup.</p>
|
||||||
</div>`;
|
</div>`;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -128,9 +140,10 @@
|
|||||||
<label for="domain" class="block text-gray-700 text-sm font-bold mb-2">
|
<label for="domain" class="block text-gray-700 text-sm font-bold mb-2">
|
||||||
Enter Domain to Check
|
Enter Domain to Check
|
||||||
</label>
|
</label>
|
||||||
<input type="text" id="domain" name="domain" required
|
<input type="text" id="domain" name="domain" required pattern="^[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9](\.[a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9])*$"
|
||||||
class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline"
|
class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 leading-tight focus:outline-none focus:shadow-outline"
|
||||||
placeholder="example.com">
|
placeholder="example.com"
|
||||||
|
title="Please enter a valid domain name">
|
||||||
</div>
|
</div>
|
||||||
<button id="submit-btn" type="submit"
|
<button id="submit-btn" type="submit"
|
||||||
class="bg-green-500 hover:bg-green-700 text-white font-bold py-2 px-4 rounded w-full transition-colors duration-200">
|
class="bg-green-500 hover:bg-green-700 text-white font-bold py-2 px-4 rounded w-full transition-colors duration-200">
|
||||||
@@ -144,6 +157,8 @@
|
|||||||
|
|
||||||
<div class="mt-4 text-center text-gray-600 text-sm">
|
<div class="mt-4 text-center text-gray-600 text-sm">
|
||||||
Make sure your AdGuard Home instance is running and properly configured in the .env file.
|
Make sure your AdGuard Home instance is running and properly configured in the .env file.
|
||||||
|
<br>
|
||||||
|
<span class="text-xs">Rules are automatically backed up before any changes.</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
|
|||||||
Reference in New Issue
Block a user