feat(api): update endpoint names and response models to align with AdGuard spec

This commit is contained in:
pacnpal
2025-01-28 17:01:50 +00:00
parent f4b28cef51
commit 7615f074a8
3 changed files with 99 additions and 152 deletions

View File

@@ -20,55 +20,41 @@ class AdGuardAPIError(AdGuardError):
"""Raised when AdGuard Home API returns an error."""
pass
# Response models matching AdGuard Home API spec
class ResultRule(BaseModel):
"""Rule detail according to AdGuard spec."""
filter_list_id: Optional[int] = Field(None, description="Filter list ID")
text: Optional[str] = Field(None, description="Rule text")
class FilterCheckHostResponse(BaseModel):
"""Response model for check_host endpoint according to AdGuard spec."""
reason: str = Field(..., description="Request filtering status")
filter_id: Optional[int] = Field(None, deprecated=True)
rule: Optional[str] = Field(None, deprecated=True)
rules: Optional[List[ResultRule]] = Field(None, description="Applied rules")
service_name: Optional[str] = Field(None, description="Blocked service name")
cname: Optional[str] = Field(None, description="CNAME value if rewritten")
ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten")
class Filter(BaseModel):
"""Filter subscription info according to AdGuard spec."""
enabled: bool
id: int = Field(..., description="Filter ID", example=1234)
name: str = Field(..., example="AdGuard Simplified Domain Names filter")
rules_count: int = Field(..., description="Number of rules in filter", example=5912)
url: str = Field(..., example="https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt")
last_updated: Optional[str] = Field(None, example="2018-10-30T12:18:57+03:00")
id: int = Field(..., description="Filter ID")
name: str = Field(..., description="Filter name")
rules_count: int = Field(..., description="Number of rules")
url: str = Field(..., description="Filter URL")
last_updated: Optional[str] = None
class FilterStatus(BaseModel):
"""Filtering settings according to AdGuard spec."""
enabled: bool = Field(..., description="Whether filtering is enabled")
interval: Optional[int] = Field(None, description="Update interval in hours")
enabled: bool
interval: Optional[int] = None
filters: List[Filter] = Field(default_factory=list)
whitelist_filters: List[Filter] = Field(default_factory=list)
user_rules: List[str] = Field(default_factory=list)
class DnsAnswer(BaseModel):
"""DNS answer section according to AdGuard spec."""
ttl: int = Field(..., description="Time to live")
type: str = Field(..., description="Record type", example="A")
value: str = Field(..., description="Record value", example="217.69.139.201")
class DomainCheckResult(BaseModel):
"""Response model for check_host endpoint according to AdGuard spec."""
reason: str = Field(
...,
description="Request filtering status",
enum=[
"NotFilteredNotFound",
"NotFilteredWhiteList",
"NotFilteredError",
"FilteredBlackList",
"FilteredSafeBrowsing",
"FilteredParental",
"FilteredInvalid",
"FilteredSafeSearch",
"FilteredBlockedService",
"Rewrite",
"RewriteEtcHosts",
"RewriteRule"
]
)
filter_id: Optional[int] = Field(None, description="ID of the filter list containing the rule")
rule: Optional[str] = Field(None, description="Applied filtering rule")
service_name: Optional[str] = Field(None, description="Blocked service name if applicable")
cname: Optional[str] = Field(None, description="CNAME value if rewritten")
ip_addrs: Optional[List[str]] = Field(None, description="IP addresses if rewritten")
class SetRulesRequest(BaseModel):
"""Request model for set_rules endpoint according to AdGuard spec."""
rules: List[str] = Field(..., description="List of filtering rules")
class AdGuardClient:
"""Client for interacting with AdGuard Home API according to OpenAPI spec."""
@@ -90,15 +76,7 @@ class AdGuardClient:
logger.info(f"Initialized AdGuard Home client with base URL: {self.base_url}")
async def login(self) -> bool:
"""Authenticate with AdGuard Home and get session cookie.
Returns:
bool: True if authentication successful
Raises:
AdGuardConnectionError: If connection to AdGuard Home fails
AdGuardAPIError: If authentication fails
"""
"""Authenticate with AdGuard Home and get session cookie."""
if not self._auth:
logger.warning("No credentials configured, skipping authentication")
return False
@@ -134,19 +112,8 @@ class AdGuardClient:
if not self._session_cookie:
await self.login()
async def check_domain(self, domain: str) -> DomainCheckResult:
"""Check if a domain is blocked by AdGuard Home.
Args:
domain: The domain to check
Returns:
DomainCheckResult according to AdGuard spec
Raises:
AdGuardConnectionError: If connection to AdGuard Home fails
AdGuardAPIError: If AdGuard Home API returns an error
"""
async def check_domain(self, domain: str) -> FilterCheckHostResponse:
"""Check if a domain is blocked by AdGuard Home according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/check_host"
params = {"name": domain}
@@ -169,7 +136,7 @@ class AdGuardClient:
response.raise_for_status()
result = response.json()
logger.info(f"Domain check result for {domain}: {result}")
return DomainCheckResult(**result)
return FilterCheckHostResponse(**result)
except httpx.ConnectError as e:
logger.error(f"Connection error while checking domain {domain}: {str(e)}")
@@ -181,16 +148,44 @@ class AdGuardClient:
logger.error(f"Unexpected error while checking domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def get_filter_status(self) -> FilterStatus:
"""Get the current filtering status.
async def add_allowed_domain(self, domain: str) -> bool:
"""Add a domain to the allowed list using set_rules endpoint according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/set_rules"
# Add as a whitelist rule according to AdGuard format
data = {"rules": [f"@@||{domain}^"]}
headers = {}
Returns:
FilterStatus according to AdGuard spec
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
try:
logger.info(f"Adding domain to whitelist: {domain}")
response = await self.client.post(url, json=data, headers=headers)
Raises:
AdGuardConnectionError: If connection to AdGuard Home fails
AdGuardAPIError: If AdGuard Home API returns an error
"""
if response.status_code == 401:
logger.info("Session expired, attempting reauth")
await self.login()
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully added {domain} to whitelist")
return True
except httpx.ConnectError as e:
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
except httpx.HTTPError as e:
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def get_filter_status(self) -> FilterStatus:
"""Get the current filtering status according to spec."""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/status"
headers = {}
@@ -224,52 +219,6 @@ class AdGuardClient:
logger.error(f"Unexpected error while getting filter status: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def add_allowed_domain(self, domain: str) -> bool:
"""Add a domain to the allowed list according to AdGuard spec.
Args:
domain: The domain to allow
Returns:
bool: True if successful
Raises:
AdGuardConnectionError: If connection to AdGuard Home fails
AdGuardAPIError: If AdGuard Home API returns an error
"""
await self._ensure_authenticated()
url = f"{self.base_url}/filtering/whitelist/add"
data = {"name": domain}
headers = {}
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
try:
logger.info(f"Adding domain to whitelist: {domain}")
response = await self.client.post(url, json=data, headers=headers)
if response.status_code == 401:
logger.info("Session expired, attempting reauth")
await self.login()
if self._session_cookie:
headers['Cookie'] = f'agh_session={self._session_cookie}'
response = await self.client.post(url, json=data, headers=headers)
response.raise_for_status()
logger.info(f"Successfully added {domain} to whitelist")
return True
except httpx.ConnectError as e:
logger.error(f"Connection error while whitelisting domain {domain}: {str(e)}")
raise AdGuardConnectionError(f"Failed to connect to AdGuard Home: {str(e)}")
except httpx.HTTPError as e:
logger.error(f"HTTP error while whitelisting domain {domain}: {str(e)}")
raise AdGuardAPIError(f"AdGuard Home API error: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error while whitelisting domain {domain}: {str(e)}")
raise AdGuardError(f"Unexpected error: {str(e)}")
async def close(self):
"""Close the HTTP client."""
await self.client.aclose()

View File

@@ -13,23 +13,24 @@ from .adguard import (
AdGuardError,
AdGuardConnectionError,
AdGuardAPIError,
DomainCheckResult,
FilterStatus
FilterStatus,
FilterCheckHostResponse,
SetRulesRequest
)
from pydantic import BaseModel
from pydantic import BaseModel, Field
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
# Initialize API with proper OpenAPI info
app = FastAPI(
title="SimpleGuardHome",
description="AdGuard Home REST API interface",
version="1.0.0",
openapi_url="/api/openapi.json",
docs_url="/api/docs",
redoc_url="/api/redoc"
redoc_url="/api/redoc",
openapi_url="/api/openapi.json"
)
# Add CORS middleware with security headers
@@ -46,12 +47,10 @@ app.add_middleware(
templates_path = Path(__file__).parent / "templates"
templates = Jinja2Templates(directory=str(templates_path))
# Request/Response Models
class DomainRequest(BaseModel):
name: str
# Response models matching AdGuard spec
class ErrorResponse(BaseModel):
message: str
"""Error response model according to AdGuard spec."""
message: str = Field(..., description="The error message")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
@@ -63,7 +62,7 @@ async def home(request: Request):
@app.get(
"/control/filtering/check_host",
response_model=DomainCheckResult,
response_model=FilterCheckHostResponse,
responses={
200: {"description": "OK"},
400: {"description": "Bad Request", "model": ErrorResponse},
@@ -71,8 +70,8 @@ async def home(request: Request):
},
tags=["filtering"]
)
async def check_domain(name: str) -> Dict:
"""Check if a domain is blocked by AdGuard Home using AdGuard spec."""
async def check_domain(name: str) -> FilterCheckHostResponse:
"""Check if a domain is blocked by AdGuard Home according to spec."""
if not name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -90,7 +89,7 @@ async def check_domain(name: str) -> Dict:
raise
@app.post(
"/control/filtering/whitelist/add",
"/control/filtering/set_rules",
response_model=Dict,
responses={
200: {"description": "OK"},
@@ -99,27 +98,27 @@ async def check_domain(name: str) -> Dict:
},
tags=["filtering"]
)
async def add_to_whitelist(request: DomainRequest) -> Dict:
"""Add a domain to the allowed list using AdGuard spec."""
if not request.name:
async def add_to_whitelist(request: SetRulesRequest) -> Dict:
"""Add rules using set_rules endpoint according to AdGuard spec."""
if not request.rules:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Domain name is required"
detail="Rules are required"
)
logger.info(f"Adding domain to whitelist: {request.name}")
logger.info(f"Adding rules: {request.rules}")
try:
async with adguard.AdGuardClient() as client:
success = await client.add_allowed_domain(request.name)
success = await client.add_allowed_domain(request.rules[0].strip("@@||^"))
if success:
return {"message": f"Successfully whitelisted {request.name}"}
return {"message": "Rules added successfully"}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to whitelist domain"
detail="Failed to add rules"
)
except Exception as e:
logger.error(f"Error whitelisting domain {request.name}: {str(e)}")
logger.error(f"Error adding rules: {str(e)}")
raise
@app.get(
@@ -132,7 +131,7 @@ async def add_to_whitelist(request: DomainRequest) -> Dict:
tags=["filtering"]
)
async def get_filtering_status() -> FilterStatus:
"""Get the current filtering status using AdGuard spec."""
"""Get filtering status according to AdGuard spec."""
try:
async with adguard.AdGuardClient() as client:
return await client.get_filter_status()

View File

@@ -7,7 +7,7 @@
<script src="https://cdn.tailwindcss.com"></script>
<script>
async function checkDomain(event) {
event.preventDefault(); // Important: prevent default form submission
event.preventDefault();
const domain = document.getElementById('domain').value;
const resultDiv = document.getElementById('result');
const unblockDiv = document.getElementById('unblock-action');
@@ -28,13 +28,14 @@
const data = await response.json();
if (response.ok) {
if (data.reason.startsWith('Filtered')) {
const isBlocked = data.reason.startsWith('Filtered');
if (isBlocked) {
resultDiv.innerHTML = `
<div class="bg-red-100 border-l-4 border-red-500 text-red-700 p-4 mb-4">
<p class="font-bold">Domain is blocked</p>
<p class="text-sm"><strong>${domain}</strong> is blocked by rule:</p>
<p class="text-sm font-mono bg-red-50 p-2 mt-1 rounded">${data.rule || 'No rule specified'}</p>
${data.filter_id ? `<p class="text-sm mt-2">Filter ID: ${data.filter_id}</p>` : ''}
<p class="text-sm"><strong>${domain}</strong> is blocked</p>
<p class="text-sm">Reason: ${data.reason}</p>
${data.rules?.length ? `<p class="text-sm font-mono bg-red-50 p-2 mt-1 rounded">Rule: ${data.rules[0].text}</p>` : ''}
${data.service_name ? `<p class="text-sm mt-2">Service: ${data.service_name}</p>` : ''}
</div>`;
unblockDiv.innerHTML = `
@@ -52,7 +53,6 @@
unblockDiv.innerHTML = '';
}
} else {
// Show error message
resultDiv.innerHTML = `
<div class="bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700 p-4">
<p class="font-bold">Error checking domain</p>
@@ -68,7 +68,6 @@
</div>`;
unblockDiv.innerHTML = '';
} finally {
// Reset button state
submitBtn.disabled = false;
submitBtn.innerHTML = 'Check Domain';
}
@@ -80,17 +79,18 @@
const unblockBtn = unblockDiv.querySelector('button');
try {
// Show loading state
unblockBtn.disabled = true;
unblockBtn.innerHTML = '<span class="inline-flex items-center">Unblocking... <svg class="animate-spin ml-2 h-5 w-5 text-white" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg></span>';
const response = await fetch('/control/filtering/whitelist/add', {
const response = await fetch('/control/filtering/set_rules', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'application/json'
},
body: JSON.stringify({ name: domain })
body: JSON.stringify({
rules: [`@@||${domain}^`]
})
});
if (response.ok) {
@@ -123,7 +123,6 @@
<h1 class="text-3xl font-bold text-center mb-8 text-gray-800">SimpleGuardHome</h1>
<div class="bg-white rounded-lg shadow-md p-6">
<!-- Remove form action to prevent default submission -->
<form onsubmit="checkDomain(event)" class="mb-6">
<div class="mb-4">
<label for="domain" class="block text-gray-700 text-sm font-bold mb-2">