From db2fb2093263b4e66841811260dcc1b1255a602e Mon Sep 17 00:00:00 2001 From: pacnpal <183241239+pacnpal@users.noreply.github.com> Date: Thu, 12 Jun 2025 13:46:20 -0400 Subject: [PATCH] Update adguard.py --- src/simpleguardhome/adguard.py | 51 +++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/src/simpleguardhome/adguard.py b/src/simpleguardhome/adguard.py index d8fc63d..7f1fb35 100644 --- a/src/simpleguardhome/adguard.py +++ b/src/simpleguardhome/adguard.py @@ -74,6 +74,22 @@ def validate_domain(domain: str) -> bool: if not domain or len(domain) > 255: return False return bool(DOMAIN_PATTERN.match(domain)) + +def get_parent_domains(domain: str) -> List[str]: + """Get all parent domains for a given domain, excluding the TLD. + + Example: + Input: track.soclevercomm.jmsend.com + Output: [ + 'track.soclevercomm.jmsend.com', + 'soclevercomm.jmsend.com', + 'jmsend.com' + ] + """ + parts = domain.split('.') + # Only return domains that have at least one subdomain (length >= 2) + # This excludes TLDs like 'com', 'net', 'org', etc. + return ['.'.join(parts[i:]) for i in range(len(parts)-1)] def sanitize_rule(rule: str) -> str: """Sanitize and validate rule format.""" @@ -157,7 +173,7 @@ class AdGuardClient: """Ensure we have a valid session cookie.""" if not self._session_cookie: await self.login() - + async def check_domain(self, domain: str) -> FilterCheckHostResponse: """Check if a domain is blocked by AdGuard Home according to spec.""" # Validate domain format @@ -166,7 +182,6 @@ class AdGuardClient: await self._ensure_authenticated() url = f"{self.base_url}/filtering/check_host" - params = {"name": domain} headers = {} if self._session_cookie: @@ -174,16 +189,32 @@ class AdGuardClient: try: logger.info(f"Checking domain: {domain}") - response = await self.client.get(url, params=params, headers=headers) + # Get all parent domains to check (excluding TLD) + domains_to_check = get_parent_domains(domain) - if response.status_code == 401: - logger.info("Session expired, attempting reauth") - await self.login() - if self._session_cookie: - headers['Cookie'] = f'agh_session={self._session_cookie}' + for check_domain in domains_to_check: + params = {"name": check_domain} response = await self.client.get(url, params=params, headers=headers) + + if response.status_code == 401: + logger.info("Session expired, attempting reauth") + await self.login() + if self._session_cookie: + headers['Cookie'] = f'agh_session={self._session_cookie}' + response = await self.client.get(url, params=params, headers=headers) + + response.raise_for_status() + result = response.json() + + # If this domain is filtered, return the result + if result.get("reason", "").startswith("Filtered"): + logger.info(f"Domain {domain} is filtered due to parent domain {check_domain}") + logger.info(f"Domain check result: {result}") + return FilterCheckHostResponse(**result) - response.raise_for_status() + # If no parent domains are filtered, return the result for the original domain + params = {"name": domain} + response = await self.client.get(url, params=params, headers=headers) result = response.json() logger.info(f"Domain check result for {domain}: {result}") return FilterCheckHostResponse(**result) @@ -307,4 +338,4 @@ class AdGuardClient: return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() \ No newline at end of file + await self.close()