From 7615f074a865addddda2da5f48fdcd2ac8b22ed5 Mon Sep 17 00:00:00 2001
From: pacnpal <183241239+pacnpal@users.noreply.github.com>
Date: Tue, 28 Jan 2025 17:01:50 +0000
Subject: [PATCH] feat(api): update endpoint names and response models to align
with AdGuard spec
---
src/simpleguardhome/adguard.py | 181 ++++++++---------------
src/simpleguardhome/main.py | 49 +++---
src/simpleguardhome/templates/index.html | 21 ++-
3 files changed, 99 insertions(+), 152 deletions(-)
diff --git a/src/simpleguardhome/adguard.py b/src/simpleguardhome/adguard.py
index aaf9e75..b8ef86e 100644
--- a/src/simpleguardhome/adguard.py
+++ b/src/simpleguardhome/adguard.py
@@ -20,55 +20,41 @@ class AdGuardAPIError(AdGuardError):
"""Raised when AdGuard Home API returns an error."""
pass
-# Response models matching AdGuard Home API spec
+class ResultRule(BaseModel):
+ """Rule detail according to AdGuard spec."""
+ filter_list_id: Optional[int] = Field(None, description="Filter list ID")
+ text: Optional[str] = Field(None, description="Rule text")
+
+class FilterCheckHostResponse(BaseModel):
+ """Response model for check_host endpoint according to AdGuard spec."""
+ reason: str = Field(..., description="Request filtering status")
+ filter_id: Optional[int] = Field(None, deprecated=True)
+ rule: Optional[str] = Field(None, deprecated=True)
+ rules: Optional[List[ResultRule]] = Field(None, description="Applied rules")
+ service_name: Optional[str] = Field(None, description="Blocked service name")
+ cname: Optional[str] = Field(None, description="CNAME value if rewritten")
+ ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten")
+
class Filter(BaseModel):
"""Filter subscription info according to AdGuard spec."""
enabled: bool
- id: int = Field(..., description="Filter ID", example=1234)
- name: str = Field(..., example="AdGuard Simplified Domain Names filter")
- rules_count: int = Field(..., description="Number of rules in filter", example=5912)
- url: str = Field(..., example="https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt")
- last_updated: Optional[str] = Field(None, example="2018-10-30T12:18:57+03:00")
+ id: int = Field(..., description="Filter ID")
+ name: str = Field(..., description="Filter name")
+ rules_count: int = Field(..., description="Number of rules")
+ url: str = Field(..., description="Filter URL")
+ last_updated: Optional[str] = None
class FilterStatus(BaseModel):
"""Filtering settings according to AdGuard spec."""
- enabled: bool = Field(..., description="Whether filtering is enabled")
- interval: Optional[int] = Field(None, description="Update interval in hours")
+ enabled: bool
+ interval: Optional[int] = None
filters: List[Filter] = Field(default_factory=list)
whitelist_filters: List[Filter] = Field(default_factory=list)
user_rules: List[str] = Field(default_factory=list)
-class DnsAnswer(BaseModel):
- """DNS answer section according to AdGuard spec."""
- ttl: int = Field(..., description="Time to live")
- type: str = Field(..., description="Record type", example="A")
- value: str = Field(..., description="Record value", example="217.69.139.201")
-
-class DomainCheckResult(BaseModel):
- """Response model for check_host endpoint according to AdGuard spec."""
- reason: str = Field(
- ...,
- description="Request filtering status",
- enum=[
- "NotFilteredNotFound",
- "NotFilteredWhiteList",
- "NotFilteredError",
- "FilteredBlackList",
- "FilteredSafeBrowsing",
- "FilteredParental",
- "FilteredInvalid",
- "FilteredSafeSearch",
- "FilteredBlockedService",
- "Rewrite",
- "RewriteEtcHosts",
- "RewriteRule"
- ]
- )
- filter_id: Optional[int] = Field(None, description="ID of the filter list containing the rule")
- rule: Optional[str] = Field(None, description="Applied filtering rule")
- service_name: Optional[str] = Field(None, description="Blocked service name if applicable")
- cname: Optional[str] = Field(None, description="CNAME value if rewritten")
- ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten")
+class SetRulesRequest(BaseModel):
+ """Request model for set_rules endpoint according to AdGuard spec."""
+ rules: List[str] = Field(..., description="List of filtering rules")
class AdGuardClient:
"""Client for interacting with AdGuard Home API according to OpenAPI spec."""
@@ -90,15 +76,7 @@ class AdGuardClient:
logger.info(f"Initialized AdGuard Home client with base URL: {self.base_url}")
async def login(self) -> bool:
- """Authenticate with AdGuard Home and get session cookie.
-
- Returns:
- bool: True if authentication successful
-
- Raises:
- AdGuardConnectionError: If connection to AdGuard Home fails
- AdGuardAPIError: If authentication fails
- """
+ """Authenticate with AdGuard Home and get session cookie."""
if not self._auth:
logger.warning("No credentials configured, skipping authentication")
return False
@@ -134,19 +112,8 @@ class AdGuardClient:
if not self._session_cookie:
await self.login()
- async def check_domain(self, domain: str) -> DomainCheckResult:
- """Check if a domain is blocked by AdGuard Home.
-
- Args:
- domain: The domain to check
-
- Returns:
- DomainCheckResult according to AdGuard spec
-
- Raises:
- AdGuardConnectionError: If connection to AdGuard Home fails
- AdGuardAPIError: If AdGuard Home API returns an error
- """
+ async def check_domain(self, domain: str) -> FilterCheckHostResponse:
+ """Check if a domain is blocked by AdGuard Home according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/check_host"
params = {"name": domain}
@@ -169,7 +136,7 @@ class AdGuardClient:
response.raise_for_status()
result = response.json()
logger.info(f"Domain check result for {domain}: {result}")
- return DomainCheckResult(**result)
+ return FilterCheckHostResponse(**result)
except httpx.ConnectError as e:
logger.error(f"Connection error while checking domain {domain}: {str(e)}")
@@ -181,16 +148,44 @@ class AdGuardClient:
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
- async def get_filter_status(self) -> FilterStatus:
- """Get the current filtering status.
+ async def add_allowed_domain(self, domain: str) -> bool:
+ """Add a domain to the allowed list using set_rules endpoint according to spec."""
+ await self._ensure_authenticated()
+ url = f"{self.base_url}/filtering/set_rules"
+ # Add as a whitelist rule according to AdGuard format
+ data = {"rules": [f"@@||{domain}^"]}
+ headers = {}
- Returns:
- FilterStatus according to AdGuard spec
+ if self._session_cookie:
+ headers['Cookie'] = f'agh_session={self._session_cookie}'
+
+ try:
+ logger.info(f"Adding domain to whitelist: {domain}")
+ response = await self.client.post(url, json=data, headers=headers)
- Raises:
- AdGuardConnectionError: If connection to AdGuard Home fails
- AdGuardAPIError: If AdGuard Home API returns an error
- """
+ if response.status_code == 401:
+ logger.info("Session expired, attempting reauth")
+ await self.login()
+ if self._session_cookie:
+ headers['Cookie'] = f'agh_session={self._session_cookie}'
+ response = await self.client.post(url, json=data, headers=headers)
+
+ response.raise_for_status()
+ logger.info(f"Successfully added {domain} to whitelist")
+ return True
+
+ except httpx.ConnectError as e:
+ logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
+ raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
+ except httpx.HTTPError as e:
+ logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
+ raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
+ except Exception as e:
+ logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
+ raise AdGuardError(f"Unexpected error: {str(e)}")
+
+ async def get_filter_status(self) -> FilterStatus:
+ """Get the current filtering status according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/status"
headers = {}
@@ -224,52 +219,6 @@ class AdGuardClient:
logger.error(f"Unexpected error while getting filter status: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
- async def add_allowed_domain(self, domain: str) -> bool:
- """Add a domain to the allowed list according to AdGuard spec.
-
- Args:
- domain: The domain to allow
-
- Returns:
- bool: True if successful
-
- Raises:
- AdGuardConnectionError: If connection to AdGuard Home fails
- AdGuardAPIError: If AdGuard Home API returns an error
- """
- await self._ensure_authenticated()
- url = f"{self.base_url}/filtering/whitelist/add"
- data = {"name": domain}
- headers = {}
-
- if self._session_cookie:
- headers['Cookie'] = f'agh_session={self._session_cookie}'
-
- try:
- logger.info(f"Adding domain to whitelist: {domain}")
- response = await self.client.post(url, json=data, headers=headers)
-
- if response.status_code == 401:
- logger.info("Session expired, attempting reauth")
- await self.login()
- if self._session_cookie:
- headers['Cookie'] = f'agh_session={self._session_cookie}'
- response = await self.client.post(url, json=data, headers=headers)
-
- response.raise_for_status()
- logger.info(f"Successfully added {domain} to whitelist")
- return True
-
- except httpx.ConnectError as e:
- logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
- raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
- except httpx.HTTPError as e:
- logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
- raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
- except Exception as e:
- logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
- raise AdGuardError(f"Unexpected error: {str(e)}")
-
async def close(self):
"""Close the HTTP client."""
await self.client.aclose()
diff --git a/src/simpleguardhome/main.py b/src/simpleguardhome/main.py
index ccbbe7d..7c30a3f 100644
--- a/src/simpleguardhome/main.py
+++ b/src/simpleguardhome/main.py
@@ -13,23 +13,24 @@ from .adguard import (
AdGuardError,
AdGuardConnectionError,
AdGuardAPIError,
- DomainCheckResult,
- FilterStatus
+ FilterStatus,
+ FilterCheckHostResponse,
+ SetRulesRequest
)
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-# Initialize FastAPI app
+# Initialize API with proper OpenAPI info
app = FastAPI(
title="SimpleGuardHome",
description="AdGuard Home REST API interface",
version="1.0.0",
- openapi_url="/api/openapi.json",
docs_url="/api/docs",
- redoc_url="/api/redoc"
+ redoc_url="/api/redoc",
+ openapi_url="/api/openapi.json"
)
# Add CORS middleware with security headers
@@ -46,12 +47,10 @@ app.add_middleware(
templates_path = Path(__file__).parent / "templates"
templates = Jinja2Templates(directory=str(templates_path))
-# Request/Response Models
-class DomainRequest(BaseModel):
- name: str
-
+# Response models matching AdGuard spec
class ErrorResponse(BaseModel):
- message: str
+ """Error response model according to AdGuard spec."""
+ message: str = Field(..., description="The error message")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
@@ -63,7 +62,7 @@ async def home(request: Request):
@app.get(
"/control/filtering/check_host",
- response_model=DomainCheckResult,
+ response_model=FilterCheckHostResponse,
responses={
200: {"description": "OK"},
400: {"description": "Bad Request", "model": ErrorResponse},
@@ -71,8 +70,8 @@ async def home(request: Request):
},
tags=["filtering"]
)
-async def check_domain(name: str) -> Dict:
- """Check if a domain is blocked by AdGuard Home using AdGuard spec."""
+async def check_domain(name: str) -> FilterCheckHostResponse:
+ """Check if a domain is blocked by AdGuard Home according to spec."""
if not name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -90,7 +89,7 @@ async def check_domain(name: str) -> Dict:
raise
@app.post(
- "/control/filtering/whitelist/add",
+ "/control/filtering/set_rules",
response_model=Dict,
responses={
200: {"description": "OK"},
@@ -99,27 +98,27 @@ async def check_domain(name: str) -> Dict:
},
tags=["filtering"]
)
-async def add_to_whitelist(request: DomainRequest) -> Dict:
- """Add a domain to the allowed list using AdGuard spec."""
- if not request.name:
+async def add_to_whitelist(request: SetRulesRequest) -> Dict:
+ """Add rules using set_rules endpoint according to AdGuard spec."""
+ if not request.rules:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
- detail="Domain name is required"
+ detail="Rules are required"
)
- logger.info(f"Adding domain to whitelist: {request.name}")
+ logger.info(f"Adding rules: {request.rules}")
try:
async with adguard.AdGuardClient() as client:
- success = await client.add_allowed_domain(request.name)
+ success = await client.add_allowed_domain(request.rules[0].strip("@@||^"))
if success:
- return {"message": f"Successfully whitelisted {request.name}"}
+ return {"message": "Rules added successfully"}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="Failed to whitelist domain"
+ detail="Failed to add rules"
)
except Exception as e:
- logger.error(f"Error whitelisting domain {request.name}: {str(e)}")
+ logger.error(f"Error adding rules: {str(e)}")
raise
@app.get(
@@ -132,7 +131,7 @@ async def add_to_whitelist(request: DomainRequest) -> Dict:
tags=["filtering"]
)
async def get_filtering_status() -> FilterStatus:
- """Get the current filtering status using AdGuard spec."""
+ """Get filtering status according to AdGuard spec."""
try:
async with adguard.AdGuardClient() as client:
return await client.get_filter_status()
diff --git a/src/simpleguardhome/templates/index.html b/src/simpleguardhome/templates/index.html
index 46f7901..acdd839 100644
--- a/src/simpleguardhome/templates/index.html
+++ b/src/simpleguardhome/templates/index.html
@@ -7,7 +7,7 @@