mirror of
https://github.com/pacnpal/simpleguardhome.git
synced 2025-12-19 20:11:14 -05:00
Update adguard.py
This commit is contained in:
@@ -75,6 +75,22 @@ def validate_domain(domain: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
return bool(DOMAIN_PATTERN.match(domain))
|
return bool(DOMAIN_PATTERN.match(domain))
|
||||||
|
|
||||||
|
def get_parent_domains(domain: str) -> List[str]:
|
||||||
|
"""Get all parent domains for a given domain, excluding the TLD.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Input: track.soclevercomm.jmsend.com
|
||||||
|
Output: [
|
||||||
|
'track.soclevercomm.jmsend.com',
|
||||||
|
'soclevercomm.jmsend.com',
|
||||||
|
'jmsend.com'
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
parts = domain.split('.')
|
||||||
|
# Only return domains that have at least one subdomain (length >= 2)
|
||||||
|
# This excludes TLDs like 'com', 'net', 'org', etc.
|
||||||
|
return ['.'.join(parts[i:]) for i in range(len(parts)-1)]
|
||||||
|
|
||||||
def sanitize_rule(rule: str) -> str:
|
def sanitize_rule(rule: str) -> str:
|
||||||
"""Sanitize and validate rule format."""
|
"""Sanitize and validate rule format."""
|
||||||
# Remove any whitespace and normalize
|
# Remove any whitespace and normalize
|
||||||
@@ -166,7 +182,6 @@ class AdGuardClient:
|
|||||||
|
|
||||||
await self._ensure_authenticated()
|
await self._ensure_authenticated()
|
||||||
url = f"{self.base_url}/filtering/check_host"
|
url = f"{self.base_url}/filtering/check_host"
|
||||||
params = {"name": domain}
|
|
||||||
headers = {}
|
headers = {}
|
||||||
|
|
||||||
if self._session_cookie:
|
if self._session_cookie:
|
||||||
@@ -174,16 +189,32 @@ class AdGuardClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Checking domain: {domain}")
|
logger.info(f"Checking domain: {domain}")
|
||||||
response = await self.client.get(url, params=params, headers=headers)
|
# Get all parent domains to check (excluding TLD)
|
||||||
|
domains_to_check = get_parent_domains(domain)
|
||||||
|
|
||||||
if response.status_code == 401:
|
for check_domain in domains_to_check:
|
||||||
logger.info("Session expired, attempting reauth")
|
params = {"name": check_domain}
|
||||||
await self.login()
|
|
||||||
if self._session_cookie:
|
|
||||||
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
|
||||||
response = await self.client.get(url, params=params, headers=headers)
|
response = await self.client.get(url, params=params, headers=headers)
|
||||||
|
|
||||||
response.raise_for_status()
|
if response.status_code == 401:
|
||||||
|
logger.info("Session expired, attempting reauth")
|
||||||
|
await self.login()
|
||||||
|
if self._session_cookie:
|
||||||
|
headers['Cookie'] = f'agh_session={self._session_cookie}'
|
||||||
|
response = await self.client.get(url, params=params, headers=headers)
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# If this domain is filtered, return the result
|
||||||
|
if result.get("reason", "").startswith("Filtered"):
|
||||||
|
logger.info(f"Domain {domain} is filtered due to parent domain {check_domain}")
|
||||||
|
logger.info(f"Domain check result: {result}")
|
||||||
|
return FilterCheckHostResponse(**result)
|
||||||
|
|
||||||
|
# If no parent domains are filtered, return the result for the original domain
|
||||||
|
params = {"name": domain}
|
||||||
|
response = await self.client.get(url, params=params, headers=headers)
|
||||||
result = response.json()
|
result = response.json()
|
||||||
logger.info(f"Domain check result for {domain}: {result}")
|
logger.info(f"Domain check result for {domain}: {result}")
|
||||||
return FilterCheckHostResponse(**result)
|
return FilterCheckHostResponse(**result)
|
||||||
|
|||||||
Reference in New Issue
Block a user