diff --git a/src/simpleguardhome/adguard.py b/src/simpleguardhome/adguard.py index aaf9e75..b8ef86e 100644 --- a/src/simpleguardhome/adguard.py +++ b/src/simpleguardhome/adguard.py @@ -20,55 +20,41 @@ class AdGuardAPIError(AdGuardError): """Raised when AdGuard Home API returns an error.""" pass -# Response models matching AdGuard Home API spec +class ResultRule(BaseModel): + """Rule detail according to AdGuard spec.""" + filter_list_id: Optional[int] = Field(None, description="Filter list ID") + text: Optional[str] = Field(None, description="Rule text") + +class FilterCheckHostResponse(BaseModel): + """Response model for check_host endpoint according to AdGuard spec.""" + reason: str = Field(..., description="Request filtering status") + filter_id: Optional[int] = Field(None, deprecated=True) + rule: Optional[str] = Field(None, deprecated=True) + rules: Optional[List[ResultRule]] = Field(None, description="Applied rules") + service_name: Optional[str] = Field(None, description="Blocked service name") + cname: Optional[str] = Field(None, description="CNAME value if rewritten") + ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten") + class Filter(BaseModel): """Filter subscription info according to AdGuard spec.""" enabled: bool - id: int = Field(..., description="Filter ID", example=1234) - name: str = Field(..., example="AdGuard Simplified Domain Names filter") - rules_count: int = Field(..., description="Number of rules in filter", example=5912) - url: str = Field(..., example="https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt") - last_updated: Optional[str] = Field(None, example="2018-10-30T12:18:57+03:00") + id: int = Field(..., description="Filter ID") + name: str = Field(..., description="Filter name") + rules_count: int = Field(..., description="Number of rules") + url: str = Field(..., description="Filter URL") + last_updated: Optional[str] = None class FilterStatus(BaseModel): """Filtering settings according to AdGuard spec.""" - enabled: bool = Field(..., description="Whether filtering is enabled") - interval: Optional[int] = Field(None, description="Update interval in hours") + enabled: bool + interval: Optional[int] = None filters: List[Filter] = Field(default_factory=list) whitelist_filters: List[Filter] = Field(default_factory=list) user_rules: List[str] = Field(default_factory=list) -class DnsAnswer(BaseModel): - """DNS answer section according to AdGuard spec.""" - ttl: int = Field(..., description="Time to live") - type: str = Field(..., description="Record type", example="A") - value: str = Field(..., description="Record value", example="217.69.139.201") - -class DomainCheckResult(BaseModel): - """Response model for check_host endpoint according to AdGuard spec.""" - reason: str = Field( - ..., - description="Request filtering status", - enum=[ - "NotFilteredNotFound", - "NotFilteredWhiteList", - "NotFilteredError", - "FilteredBlackList", - "FilteredSafeBrowsing", - "FilteredParental", - "FilteredInvalid", - "FilteredSafeSearch", - "FilteredBlockedService", - "Rewrite", - "RewriteEtcHosts", - "RewriteRule" - ] - ) - filter_id: Optional[int] = Field(None, description="ID of the filter list containing the rule") - rule: Optional[str] = Field(None, description="Applied filtering rule") - service_name: Optional[str] = Field(None, description="Blocked service name if applicable") - cname: Optional[str] = Field(None, description="CNAME value if rewritten") - ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten") +class SetRulesRequest(BaseModel): + """Request model for set_rules endpoint according to AdGuard spec.""" + rules: List[str] = Field(..., description="List of filtering rules") class AdGuardClient: """Client for interacting with AdGuard Home API according to OpenAPI spec.""" @@ -90,15 +76,7 @@ class AdGuardClient: logger.info(f"Initialized AdGuard Home client with base URL: {self.base_url}") async def login(self) -> bool: - """Authenticate with AdGuard Home and get session cookie. - - Returns: - bool: True if authentication successful - - Raises: - AdGuardConnectionError: If connection to AdGuard Home fails - AdGuardAPIError: If authentication fails - """ + """Authenticate with AdGuard Home and get session cookie.""" if not self._auth: logger.warning("No credentials configured, skipping authentication") return False @@ -134,19 +112,8 @@ class AdGuardClient: if not self._session_cookie: await self.login() - async def check_domain(self, domain: str) -> DomainCheckResult: - """Check if a domain is blocked by AdGuard Home. - - Args: - domain: The domain to check - - Returns: - DomainCheckResult according to AdGuard spec - - Raises: - AdGuardConnectionError: If connection to AdGuard Home fails - AdGuardAPIError: If AdGuard Home API returns an error - """ + async def check_domain(self, domain: str) -> FilterCheckHostResponse: + """Check if a domain is blocked by AdGuard Home according to spec.""" await self._ensure_authenticated() url = f"{self.base_url}/filtering/check_host" params = {"name": domain} @@ -169,7 +136,7 @@ class AdGuardClient: response.raise_for_status() result = response.json() logger.info(f"Domain check result for {domain}: {result}") - return DomainCheckResult(**result) + return FilterCheckHostResponse(**result) except httpx.ConnectError as e: logger.error(f"Connection error while checking domain {domain}: {str(e)}") @@ -181,16 +148,44 @@ class AdGuardClient: logger.error(f"Unexpected error while checking domain {domain}: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}") - async def get_filter_status(self) -> FilterStatus: - """Get the current filtering status. + async def add_allowed_domain(self, domain: str) -> bool: + """Add a domain to the allowed list using set_rules endpoint according to spec.""" + await self._ensure_authenticated() + url = f"{self.base_url}/filtering/set_rules" + # Add as a whitelist rule according to AdGuard format + data = {"rules": [f"@@||{domain}^"]} + headers = {} - Returns: - FilterStatus according to AdGuard spec + if self._session_cookie: + headers['Cookie'] = f'agh_session={self._session_cookie}' + + try: + logger.info(f"Adding domain to whitelist: {domain}") + response = await self.client.post(url, json=data, headers=headers) - Raises: - AdGuardConnectionError: If connection to AdGuard Home fails - AdGuardAPIError: If AdGuard Home API returns an error - """ + if response.status_code == 401: + logger.info("Session expired, attempting reauth") + await self.login() + if self._session_cookie: + headers['Cookie'] = f'agh_session={self._session_cookie}' + response = await self.client.post(url, json=data, headers=headers) + + response.raise_for_status() + logger.info(f"Successfully added {domain} to whitelist") + return True + + except httpx.ConnectError as e: + logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}") + raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}") + except httpx.HTTPError as e: + logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}") + raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}") + raise AdGuardError(f"Unexpected error: {str(e)}") + + async def get_filter_status(self) -> FilterStatus: + """Get the current filtering status according to spec.""" await self._ensure_authenticated() url = f"{self.base_url}/filtering/status" headers = {} @@ -224,52 +219,6 @@ class AdGuardClient: logger.error(f"Unexpected error while getting filter status: {str(e)}") raise AdGuardError(f"Unexpected error: {str(e)}") - async def add_allowed_domain(self, domain: str) -> bool: - """Add a domain to the allowed list according to AdGuard spec. - - Args: - domain: The domain to allow - - Returns: - bool: True if successful - - Raises: - AdGuardConnectionError: If connection to AdGuard Home fails - AdGuardAPIError: If AdGuard Home API returns an error - """ - await self._ensure_authenticated() - url = f"{self.base_url}/filtering/whitelist/add" - data = {"name": domain} - headers = {} - - if self._session_cookie: - headers['Cookie'] = f'agh_session={self._session_cookie}' - - try: - logger.info(f"Adding domain to whitelist: {domain}") - response = await self.client.post(url, json=data, headers=headers) - - if response.status_code == 401: - logger.info("Session expired, attempting reauth") - await self.login() - if self._session_cookie: - headers['Cookie'] = f'agh_session={self._session_cookie}' - response = await self.client.post(url, json=data, headers=headers) - - response.raise_for_status() - logger.info(f"Successfully added {domain} to whitelist") - return True - - except httpx.ConnectError as e: - logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}") - raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}") - except httpx.HTTPError as e: - logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}") - raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}") - except Exception as e: - logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}") - raise AdGuardError(f"Unexpected error: {str(e)}") - async def close(self): """Close the HTTP client.""" await self.client.aclose() diff --git a/src/simpleguardhome/main.py b/src/simpleguardhome/main.py index ccbbe7d..7c30a3f 100644 --- a/src/simpleguardhome/main.py +++ b/src/simpleguardhome/main.py @@ -13,23 +13,24 @@ from .adguard import ( AdGuardError, AdGuardConnectionError, AdGuardAPIError, - DomainCheckResult, - FilterStatus + FilterStatus, + FilterCheckHostResponse, + SetRulesRequest ) -from pydantic import BaseModel +from pydantic import BaseModel, Field # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Initialize FastAPI app +# Initialize API with proper OpenAPI info app = FastAPI( title="SimpleGuardHome", description="AdGuard Home REST API interface", version="1.0.0", - openapi_url="/api/openapi.json", docs_url="/api/docs", - redoc_url="/api/redoc" + redoc_url="/api/redoc", + openapi_url="/api/openapi.json" ) # Add CORS middleware with security headers @@ -46,12 +47,10 @@ app.add_middleware( templates_path = Path(__file__).parent / "templates" templates = Jinja2Templates(directory=str(templates_path)) -# Request/Response Models -class DomainRequest(BaseModel): - name: str - +# Response models matching AdGuard spec class ErrorResponse(BaseModel): - message: str + """Error response model according to AdGuard spec.""" + message: str = Field(..., description="The error message") @app.get("/", response_class=HTMLResponse) async def home(request: Request): @@ -63,7 +62,7 @@ async def home(request: Request): @app.get( "/control/filtering/check_host", - response_model=DomainCheckResult, + response_model=FilterCheckHostResponse, responses={ 200: {"description": "OK"}, 400: {"description": "Bad Request", "model": ErrorResponse}, @@ -71,8 +70,8 @@ async def home(request: Request): }, tags=["filtering"] ) -async def check_domain(name: str) -> Dict: - """Check if a domain is blocked by AdGuard Home using AdGuard spec.""" +async def check_domain(name: str) -> FilterCheckHostResponse: + """Check if a domain is blocked by AdGuard Home according to spec.""" if not name: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -90,7 +89,7 @@ async def check_domain(name: str) -> Dict: raise @app.post( - "/control/filtering/whitelist/add", + "/control/filtering/set_rules", response_model=Dict, responses={ 200: {"description": "OK"}, @@ -99,27 +98,27 @@ async def check_domain(name: str) -> Dict: }, tags=["filtering"] ) -async def add_to_whitelist(request: DomainRequest) -> Dict: - """Add a domain to the allowed list using AdGuard spec.""" - if not request.name: +async def add_to_whitelist(request: SetRulesRequest) -> Dict: + """Add rules using set_rules endpoint according to AdGuard spec.""" + if not request.rules: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Domain name is required" + detail="Rules are required" ) - logger.info(f"Adding domain to whitelist: {request.name}") + logger.info(f"Adding rules: {request.rules}") try: async with adguard.AdGuardClient() as client: - success = await client.add_allowed_domain(request.name) + success = await client.add_allowed_domain(request.rules[0].strip("@@||^")) if success: - return {"message": f"Successfully whitelisted {request.name}"} + return {"message": "Rules added successfully"} else: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to whitelist domain" + detail="Failed to add rules" ) except Exception as e: - logger.error(f"Error whitelisting domain {request.name}: {str(e)}") + logger.error(f"Error adding rules: {str(e)}") raise @app.get( @@ -132,7 +131,7 @@ async def add_to_whitelist(request: DomainRequest) -> Dict: tags=["filtering"] ) async def get_filtering_status() -> FilterStatus: - """Get the current filtering status using AdGuard spec.""" + """Get filtering status according to AdGuard spec.""" try: async with adguard.AdGuardClient() as client: return await client.get_filter_status() diff --git a/src/simpleguardhome/templates/index.html b/src/simpleguardhome/templates/index.html index 46f7901..acdd839 100644 --- a/src/simpleguardhome/templates/index.html +++ b/src/simpleguardhome/templates/index.html @@ -7,7 +7,7 @@