feat(api): update endpoint names and response models to align with AdGuard spec

This commit is contained in:
pacnpal
2025-01-28 17:01:50 +00:00
parent f4b28cef51
commit 7615f074a8
3 changed files with 99 additions and 152 deletions

View File

@@ -20,55 +20,41 @@ class AdGuardAPIError(AdGuardError):
"""Raised when AdGuard Home API returns an error.""" """Raised when AdGuard Home API returns an error."""
pass pass
# Response models matching AdGuard Home API spec class ResultRule(BaseModel):
"""Rule detail according to AdGuard spec."""
filter_list_id: Optional[int] = Field(None, description="Filter list ID")
text: Optional[str] = Field(None, description="Rule text")
class FilterCheckHostResponse(BaseModel):
"""Response model for check_host endpoint according to AdGuard spec."""
reason: str = Field(..., description="Request filtering status")
filter_id: Optional[int] = Field(None, deprecated=True)
rule: Optional[str] = Field(None, deprecated=True)
rules: Optional[List[ResultRule]] = Field(None, description="Applied rules")
service_name: Optional[str] = Field(None, description="Blocked service name")
cname: Optional[str] = Field(None, description="CNAME value if rewritten")
ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten")
class Filter(BaseModel): class Filter(BaseModel):
"""Filter subscription info according to AdGuard spec.""" """Filter subscription info according to AdGuard spec."""
enabled: bool enabled: bool
id: int = Field(..., description="Filter ID", example=1234) id: int = Field(..., description="Filter ID")
name: str = Field(..., example="AdGuard Simplified Domain Names filter") name: str = Field(..., description="Filter name")
rules_count: int = Field(..., description="Number of rules in filter", example=5912) rules_count: int = Field(..., description="Number of rules")
url: str = Field(..., example="https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt") url: str = Field(..., description="Filter URL")
last_updated: Optional[str] = Field(None, example="2018-10-30T12:18:57+03:00") last_updated: Optional[str] = None
class FilterStatus(BaseModel): class FilterStatus(BaseModel):
"""Filtering settings according to AdGuard spec.""" """Filtering settings according to AdGuard spec."""
enabled: bool = Field(..., description="Whether filtering is enabled") enabled: bool
interval: Optional[int] = Field(None, description="Update interval in hours") interval: Optional[int] = None
filters: List[Filter] = Field(default_factory=list) filters: List[Filter] = Field(default_factory=list)
whitelist_filters: List[Filter] = Field(default_factory=list) whitelist_filters: List[Filter] = Field(default_factory=list)
user_rules: List[str] = Field(default_factory=list) user_rules: List[str] = Field(default_factory=list)
class DnsAnswer(BaseModel): class SetRulesRequest(BaseModel):
"""DNS answer section according to AdGuard spec.""" """Request model for set_rules endpoint according to AdGuard spec."""
ttl: int = Field(..., description="Time to live") rules: List[str] = Field(..., description="List of filtering rules")
type: str = Field(..., description="Record type", example="A")
value: str = Field(..., description="Record value", example="217.69.139.201")
class DomainCheckResult(BaseModel):
"""Response model for check_host endpoint according to AdGuard spec."""
reason: str = Field(
...,
description="Request filtering status",
enum=[
"NotFilteredNotFound",
"NotFilteredWhiteList",
"NotFilteredError",
"FilteredBlackList",
"FilteredSafeBrowsing",
"FilteredParental",
"FilteredInvalid",
"FilteredSafeSearch",
"FilteredBlockedService",
"Rewrite",
"RewriteEtcHosts",
"RewriteRule"
]
)
filter_id: Optional[int] = Field(None, description="ID of the filter list containing the rule")
rule: Optional[str] = Field(None, description="Applied filtering rule")
service_name: Optional[str] = Field(None, description="Blocked service name if applicable")
cname: Optional[str] = Field(None, description="CNAME value if rewritten")
ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten")
class AdGuardClient: class AdGuardClient:
"""Client for interacting with AdGuard Home API according to OpenAPI spec.""" """Client for interacting with AdGuard Home API according to OpenAPI spec."""
@@ -90,15 +76,7 @@ class AdGuardClient:
logger.info(f"Initialized AdGuard Home client with base URL: {self.base_url}") logger.info(f"Initialized AdGuard Home client with base URL: {self.base_url}")
async def login(self) -> bool: async def login(self) -> bool:
"""Authenticate with AdGuard Home and get session cookie. """Authenticate with AdGuard Home and get session cookie."""
Returns:
bool: True if authentication successful
Raises:
AdGuardConnectionError: If connection to AdGuard Home fails
AdGuardAPIError: If authentication fails
"""
if not self._auth: if not self._auth:
logger.warning("No credentials configured, skipping authentication") logger.warning("No credentials configured, skipping authentication")
return False return False
@@ -134,19 +112,8 @@ class AdGuardClient:
if not self._session_cookie: if not self._session_cookie:
await self.login() await self.login()
async def check_domain(self, domain: str) -> DomainCheckResult: async def check_domain(self, domain: str) -> FilterCheckHostResponse:
"""Check if a domain is blocked by AdGuard Home. """Check if a domain is blocked by AdGuard Home according to spec."""
Args:
domain: The domain to check
Returns:
DomainCheckResult according to AdGuard spec
Raises:
AdGuardConnectionError: If connection to AdGuard Home fails
AdGuardAPIError: If AdGuard Home API returns an error
"""
await self._ensure_authenticated() await self._ensure_authenticated()
url = f"{self.base_url}/filtering/check_host" url = f"{self.base_url}/filtering/check_host"
params = {"name": domain} params = {"name": domain}
@@ -169,7 +136,7 @@ class AdGuardClient:
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
logger.info(f"Domain check result for {domain}: {result}") logger.info(f"Domain check result for {domain}: {result}")
return DomainCheckResult(**result) return FilterCheckHostResponse(**result)
except httpx.ConnectError as e: except httpx.ConnectError as e:
logger.error(f"Connection error while checking domain {domain}: {str(e)}") logger.error(f"Connection error while checking domain {domain}: {str(e)}")
@@ -181,16 +148,44 @@ class AdGuardClient:
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}") logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}")
async def get_filter_status(self) -> FilterStatus: async def add_allowed_domain(self, domain: str) -> bool:
"""Get the current filtering status. """Add a domain to the allowed list using set_rules endpoint according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/set_rules"
# Add as a whitelist rule according to AdGuard format
data = {"rules": [f"@@||{domain}^"]}
headers = {}
Returns: if self._session_cookie:
FilterStatus according to AdGuard spec headers['Cookie'] = f'agh_session={self._session_cookie}'
try:
logger.info(f"Adding domain to whitelist: {domain}")
response = await self.client.post(url, json=data, headers=headers)
Raises: if response.status_code == 401:
AdGuardConnectionError: If connection to AdGuard Home fails logger.info("Session expired, attempting reauth")
AdGuardAPIError: If AdGuard Home API returns an error await self.login()
""" if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully added {domain} to whitelist")
return True
except httpx.ConnectError as e:
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
except httpx.HTTPError as e:
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def get_filter_status(self) -> FilterStatus:
"""Get the current filtering status according to spec."""
await self._ensure_authenticated() await self._ensure_authenticated()
url = f"{self.base_url}/filtering/status" url = f"{self.base_url}/filtering/status"
headers = {} headers = {}
@@ -224,52 +219,6 @@ class AdGuardClient:
logger.error(f"Unexpected error while getting filter status: {str(e)}") logger.error(f"Unexpected error while getting filter status: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}")
async def add_allowed_domain(self, domain: str) -> bool:
"""Add a domain to the allowed list according to AdGuard spec.
Args:
domain: The domain to allow
Returns:
bool: True if successful
Raises:
AdGuardConnectionError: If connection to AdGuard Home fails
AdGuardAPIError: If AdGuard Home API returns an error
"""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/whitelist/add"
data = {"name": domain}
headers = {}
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
try:
logger.info(f"Adding domain to whitelist: {domain}")
response = await self.client.post(url, json=data, headers=headers)
if response.status_code == 401:
logger.info("Session expired, attempting reauth")
await self.login()
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully added {domain} to whitelist")
return True
except httpx.ConnectError as e:
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
except httpx.HTTPError as e:
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def close(self): async def close(self):
"""Close the HTTP client.""" """Close the HTTP client."""
await self.client.aclose() await self.client.aclose()

View File

@@ -13,23 +13,24 @@ from .adguard import (
AdGuardError, AdGuardError,
AdGuardConnectionError, AdGuardConnectionError,
AdGuardAPIError, AdGuardAPIError,
DomainCheckResult, FilterStatus,
FilterStatus FilterCheckHostResponse,
SetRulesRequest
) )
from pydantic import BaseModel from pydantic import BaseModel, Field
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Initialize FastAPI app # Initialize API with proper OpenAPI info
app = FastAPI( app = FastAPI(
title="SimpleGuardHome", title="SimpleGuardHome",
description="AdGuard Home REST API interface", description="AdGuard Home REST API interface",
version="1.0.0", version="1.0.0",
openapi_url="/api/openapi.json",
docs_url="/api/docs", docs_url="/api/docs",
redoc_url="/api/redoc" redoc_url="/api/redoc",
openapi_url="/api/openapi.json"
) )
# Add CORS middleware with security headers # Add CORS middleware with security headers
@@ -46,12 +47,10 @@ app.add_middleware(
templates_path = Path(__file__).parent / "templates" templates_path = Path(__file__).parent / "templates"
templates = Jinja2Templates(directory=str(templates_path)) templates = Jinja2Templates(directory=str(templates_path))
# Request/Response Models # Response models matching AdGuard spec
class DomainRequest(BaseModel):
name: str
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
message: str """Error response model according to AdGuard spec."""
message: str = Field(..., description="The error message")
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
async def home(request: Request): async def home(request: Request):
@@ -63,7 +62,7 @@ async def home(request: Request):
@app.get( @app.get(
"/control/filtering/check_host", "/control/filtering/check_host",
response_model=DomainCheckResult, response_model=FilterCheckHostResponse,
responses={ responses={
200: {"description": "OK"}, 200: {"description": "OK"},
400: {"description": "Bad Request", "model": ErrorResponse}, 400: {"description": "Bad Request", "model": ErrorResponse},
@@ -71,8 +70,8 @@ async def home(request: Request):
}, },
tags=["filtering"] tags=["filtering"]
) )
async def check_domain(name: str) -> Dict: async def check_domain(name: str) -> FilterCheckHostResponse:
"""Check if a domain is blocked by AdGuard Home using AdGuard spec.""" """Check if a domain is blocked by AdGuard Home according to spec."""
if not name: if not name:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@@ -90,7 +89,7 @@ async def check_domain(name: str) -> Dict:
raise raise
@app.post( @app.post(
"/control/filtering/whitelist/add", "/control/filtering/set_rules",
response_model=Dict, response_model=Dict,
responses={ responses={
200: {"description": "OK"}, 200: {"description": "OK"},
@@ -99,27 +98,27 @@ async def check_domain(name: str) -> Dict:
}, },
tags=["filtering"] tags=["filtering"]
) )
async def add_to_whitelist(request: DomainRequest) -> Dict: async def add_to_whitelist(request: SetRulesRequest) -> Dict:
"""Add a domain to the allowed list using AdGuard spec.""" """Add rules using set_rules endpoint according to AdGuard spec."""
if not request.name: if not request.rules:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Domain name is required" detail="Rules are required"
) )
logger.info(f"Adding domain to whitelist: {request.name}") logger.info(f"Adding rules: {request.rules}")
try: try:
async with adguard.AdGuardClient() as client: async with adguard.AdGuardClient() as client:
success = await client.add_allowed_domain(request.name) success = await client.add_allowed_domain(request.rules[0].strip("@@||^"))
if success: if success:
return {"message": f"Successfully whitelisted {request.name}"} return {"message": "Rules added successfully"}
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to whitelist domain" detail="Failed to add rules"
) )
except Exception as e: except Exception as e:
logger.error(f"Error whitelisting domain {request.name}: {str(e)}") logger.error(f"Error adding rules: {str(e)}")
raise raise
@app.get( @app.get(
@@ -132,7 +131,7 @@ async def add_to_whitelist(request: DomainRequest) -> Dict:
tags=["filtering"] tags=["filtering"]
) )
async def get_filtering_status() -> FilterStatus: async def get_filtering_status() -> FilterStatus:
"""Get the current filtering status using AdGuard spec.""" """Get filtering status according to AdGuard spec."""
try: try:
async with adguard.AdGuardClient() as client: async with adguard.AdGuardClient() as client:
return await client.get_filter_status() return await client.get_filter_status()

View File

@@ -7,7 +7,7 @@
<script src="https://cdn.tailwindcss.com"></script> <script src="https://cdn.tailwindcss.com"></script>
<script> <script>
async function checkDomain(event) { async function checkDomain(event) {
event.preventDefault(); // Important: prevent default form submission event.preventDefault();
const domain = document.getElementById('domain').value; const domain = document.getElementById('domain').value;
const resultDiv = document.getElementById('result'); const resultDiv = document.getElementById('result');
const unblockDiv = document.getElementById('unblock-action'); const unblockDiv = document.getElementById('unblock-action');
@@ -28,13 +28,14 @@
const data = await response.json(); const data = await response.json();
if (response.ok) { if (response.ok) {
if (data.reason.startsWith('Filtered')) { const isBlocked = data.reason.startsWith('Filtered');
if (isBlocked) {
resultDiv.innerHTML = ` resultDiv.innerHTML = `
<div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4 mb-4"> <div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4 mb-4">
<p class="font-bold">Domain is blocked</p> <p class="font-bold">Domain is blocked</p>
<p class="text-sm"><strong>${domain}</strong> is blocked by rule:</p> <p class="text-sm"><strong>${domain}</strong> is blocked</p>
<p class="text-sm font-mono bg-red-50 p-2 mt-1 rounded">${data.rule || 'No rule specified'}</p> <p class="text-sm">Reason: ${data.reason}</p>
${data.filter_id ? `<p class="text-sm mt-2">Filter ID: ${data.filter_id}</p>` : ''} ${data.rules?.length ? `<p class="text-sm font-mono bg-red-50 p-2 mt-1 rounded">Rule: ${data.rules[0].text}</p>` : ''}
${data.service_name ? `<p class="text-sm mt-2">Service: ${data.service_name}</p>` : ''} ${data.service_name ? `<p class="text-sm mt-2">Service: ${data.service_name}</p>` : ''}
</div>`; </div>`;
unblockDiv.innerHTML = ` unblockDiv.innerHTML = `
@@ -52,7 +53,6 @@
unblockDiv.innerHTML = ''; unblockDiv.innerHTML = '';
} }
} else { } else {
// Show error message
resultDiv.innerHTML = ` resultDiv.innerHTML = `
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4"> <div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
<p class="font-bold">Error checking domain</p> <p class="font-bold">Error checking domain</p>
@@ -68,7 +68,6 @@
</div>`; </div>`;
unblockDiv.innerHTML = ''; unblockDiv.innerHTML = '';
} finally { } finally {
// Reset button state
submitBtn.disabled = false; submitBtn.disabled = false;
submitBtn.innerHTML = 'Check Domain'; submitBtn.innerHTML = 'Check Domain';
} }
@@ -80,17 +79,18 @@
const unblockBtn = unblockDiv.querySelector('button'); const unblockBtn = unblockDiv.querySelector('button');
try { try {
// Show loading state
unblockBtn.disabled = true; unblockBtn.disabled = true;
unblockBtn.innerHTML = '<span class="inline-flex items-center">Unblocking... <svg class="animate-spin ml-2 h-5 w-5 text-white" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg></span>'; unblockBtn.innerHTML = '<span class="inline-flex items-center">Unblocking... <svg class="animate-spin ml-2 h-5 w-5 text-white" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg></span>';
const response = await fetch('/control/filtering/whitelist/add', { const response = await fetch('/control/filtering/set_rules', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Accept': 'application/json' 'Accept': 'application/json'
}, },
body: JSON.stringify({ name: domain }) body: JSON.stringify({
rules: [`@@||${domain}^`]
})
}); });
if (response.ok) { if (response.ok) {
@@ -123,7 +123,6 @@
<h1 class="text-3xl font-bold text-center mb-8 text-gray-800">SimpleGuardHome</h1> <h1 class="text-3xl font-bold text-center mb-8 text-gray-800">SimpleGuardHome</h1>
<div class="bg-white rounded-lg shadow-md p-6"> <div class="bg-white rounded-lg shadow-md p-6">
<!-- Remove form action to prevent default submission -->
<form onsubmit="checkDomain(event)" class="mb-6"> <form onsubmit="checkDomain(event)" class="mb-6">
<div class="mb-4"> <div class="mb-4">
<label for="domain" class="block text-gray-700 text-sm font-bold mb-2"> <label for="domain" class="block text-gray-700 text-sm font-bold mb-2">