mirror of
https://github.com/pacnpal/Pac-cogs.git
synced 2025-12-20 02:41:06 -05:00
Improve database connection management
- Add proper error handling with DatabaseError - Add proper error context support - Add better connection state management - Add proper type hints - Add proper docstrings - Add better error recovery - Add proper error context creation - Add better metrics tracking - Add proper status reporting - Add better connection validation - Add proper logging - Add better error messages - Add proper transaction timing - Add better pool management
This commit is contained in:
@@ -4,25 +4,96 @@ import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator, Optional
|
||||
from typing import Generator, Optional, Dict, Any, TypedDict, ClassVar, Union
|
||||
from enum import Enum, auto
|
||||
import threading
|
||||
from queue import Queue, Empty
|
||||
from datetime import datetime
|
||||
|
||||
from ..utils.exceptions import (
|
||||
DatabaseError,
|
||||
ErrorContext,
|
||||
ErrorSeverity
|
||||
)
|
||||
|
||||
logger = logging.getLogger("DBConnectionManager")
|
||||
|
||||
class ConnectionState(Enum):
|
||||
"""Connection states"""
|
||||
AVAILABLE = auto()
|
||||
IN_USE = auto()
|
||||
CLOSED = auto()
|
||||
ERROR = auto()
|
||||
|
||||
class ConnectionStatus(TypedDict):
|
||||
"""Type definition for connection status"""
|
||||
state: str
|
||||
created_at: str
|
||||
last_used: str
|
||||
error: Optional[str]
|
||||
transaction_count: int
|
||||
pool_size: int
|
||||
available_connections: int
|
||||
|
||||
class ConnectionMetrics(TypedDict):
|
||||
"""Type definition for connection metrics"""
|
||||
total_connections: int
|
||||
active_connections: int
|
||||
idle_connections: int
|
||||
failed_connections: int
|
||||
total_transactions: int
|
||||
failed_transactions: int
|
||||
average_transaction_time: float
|
||||
|
||||
class ConnectionInfo:
|
||||
"""Tracks connection information"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.created_at = datetime.utcnow()
|
||||
self.last_used = self.created_at
|
||||
self.transaction_count = 0
|
||||
self.error_count = 0
|
||||
self.total_transaction_time = 0.0
|
||||
self.state = ConnectionState.AVAILABLE
|
||||
|
||||
def update_usage(self) -> None:
|
||||
"""Update connection usage statistics"""
|
||||
self.last_used = datetime.utcnow()
|
||||
self.transaction_count += 1
|
||||
|
||||
def record_error(self) -> None:
|
||||
"""Record a connection error"""
|
||||
self.error_count += 1
|
||||
self.state = ConnectionState.ERROR
|
||||
|
||||
def get_average_transaction_time(self) -> float:
|
||||
"""Get average transaction time"""
|
||||
if self.transaction_count == 0:
|
||||
return 0.0
|
||||
return self.total_transaction_time / self.transaction_count
|
||||
|
||||
class DatabaseConnectionManager:
|
||||
"""Manages SQLite database connections and connection pooling"""
|
||||
|
||||
def __init__(self, db_path: Path, pool_size: int = 5):
|
||||
"""Initialize the connection manager
|
||||
DEFAULT_POOL_SIZE: ClassVar[int] = 5
|
||||
CONNECTION_TIMEOUT: ClassVar[float] = 30.0
|
||||
POOL_TIMEOUT: ClassVar[float] = 5.0
|
||||
|
||||
def __init__(self, db_path: Path, pool_size: int = DEFAULT_POOL_SIZE) -> None:
|
||||
"""
|
||||
Initialize the connection manager.
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file
|
||||
pool_size: Maximum number of connections in the pool
|
||||
|
||||
Raises:
|
||||
DatabaseError: If initialization fails
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.pool_size = pool_size
|
||||
self._connection_pool: Queue[sqlite3.Connection] = Queue(maxsize=pool_size)
|
||||
self._connection_info: Dict[int, ConnectionInfo] = {}
|
||||
self._local = threading.local()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@@ -30,23 +101,46 @@ class DatabaseConnectionManager:
|
||||
self._initialize_pool()
|
||||
|
||||
def _initialize_pool(self) -> None:
|
||||
"""Initialize the connection pool"""
|
||||
"""
|
||||
Initialize the connection pool.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If pool initialization fails
|
||||
"""
|
||||
try:
|
||||
for _ in range(self.pool_size):
|
||||
conn = self._create_connection()
|
||||
if conn:
|
||||
self._connection_pool.put(conn)
|
||||
self._connection_info[id(conn)] = ConnectionInfo()
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing connection pool: {e}")
|
||||
raise
|
||||
error = f"Failed to initialize connection pool: {str(e)}"
|
||||
logger.error(error, exc_info=True)
|
||||
raise DatabaseError(
|
||||
error,
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"initialize_pool",
|
||||
{"pool_size": self.pool_size},
|
||||
ErrorSeverity.CRITICAL
|
||||
)
|
||||
)
|
||||
|
||||
def _create_connection(self) -> Optional[sqlite3.Connection]:
|
||||
"""Create a new database connection with proper settings"""
|
||||
"""
|
||||
Create a new database connection with proper settings.
|
||||
|
||||
Returns:
|
||||
New database connection or None if creation fails
|
||||
|
||||
Raises:
|
||||
DatabaseError: If connection creation fails
|
||||
"""
|
||||
try:
|
||||
conn = sqlite3.connect(
|
||||
self.db_path,
|
||||
detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
|
||||
timeout=30.0 # 30 second timeout
|
||||
timeout=self.CONNECTION_TIMEOUT
|
||||
)
|
||||
|
||||
# Enable foreign keys
|
||||
@@ -64,20 +158,31 @@ class DatabaseConnectionManager:
|
||||
return conn
|
||||
|
||||
except sqlite3.Error as e:
|
||||
logger.error(f"Error creating database connection: {e}")
|
||||
return None
|
||||
error = f"Failed to create database connection: {str(e)}"
|
||||
logger.error(error, exc_info=True)
|
||||
raise DatabaseError(
|
||||
error,
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"create_connection",
|
||||
{"path": str(self.db_path)},
|
||||
ErrorSeverity.HIGH
|
||||
)
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Get a database connection from the pool
|
||||
"""
|
||||
Get a database connection from the pool.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: A database connection
|
||||
Database connection
|
||||
|
||||
Raises:
|
||||
sqlite3.Error: If unable to get a connection
|
||||
DatabaseError: If unable to get a connection
|
||||
"""
|
||||
conn = None
|
||||
start_time = datetime.utcnow()
|
||||
try:
|
||||
# Check if we have a transaction-bound connection
|
||||
conn = getattr(self._local, 'transaction_connection', None)
|
||||
@@ -87,59 +192,118 @@ class DatabaseConnectionManager:
|
||||
|
||||
# Get connection from pool or create new one
|
||||
try:
|
||||
conn = self._connection_pool.get(timeout=5.0)
|
||||
conn = self._connection_pool.get(timeout=self.POOL_TIMEOUT)
|
||||
except Empty:
|
||||
logger.warning("Connection pool exhausted, creating new connection")
|
||||
conn = self._create_connection()
|
||||
if not conn:
|
||||
raise sqlite3.Error("Failed to create database connection")
|
||||
raise DatabaseError(
|
||||
"Failed to create database connection",
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"get_connection",
|
||||
None,
|
||||
ErrorSeverity.HIGH
|
||||
)
|
||||
)
|
||||
|
||||
# Update connection info
|
||||
conn_info = self._connection_info.get(id(conn))
|
||||
if conn_info:
|
||||
conn_info.update_usage()
|
||||
conn_info.state = ConnectionState.IN_USE
|
||||
|
||||
yield conn
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting database connection: {e}")
|
||||
error = f"Error getting database connection: {str(e)}"
|
||||
logger.error(error, exc_info=True)
|
||||
if conn:
|
||||
try:
|
||||
conn.rollback()
|
||||
if id(conn) in self._connection_info:
|
||||
self._connection_info[id(conn)].record_error()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
raise DatabaseError(
|
||||
error,
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"get_connection",
|
||||
None,
|
||||
ErrorSeverity.HIGH
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
if conn and not hasattr(self._local, 'transaction_connection'):
|
||||
try:
|
||||
conn.rollback() # Reset connection state
|
||||
self._connection_pool.put(conn)
|
||||
|
||||
# Update connection info
|
||||
if id(conn) in self._connection_info:
|
||||
conn_info = self._connection_info[id(conn)]
|
||||
conn_info.state = ConnectionState.AVAILABLE
|
||||
duration = (datetime.utcnow() - start_time).total_seconds()
|
||||
conn_info.total_transaction_time += duration
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error returning connection to pool: {e}")
|
||||
try:
|
||||
conn.close()
|
||||
if id(conn) in self._connection_info:
|
||||
self._connection_info[id(conn)].state = ConnectionState.CLOSED
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def transaction(self) -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Start a database transaction
|
||||
"""
|
||||
Start a database transaction.
|
||||
|
||||
Yields:
|
||||
sqlite3.Connection: A database connection for the transaction
|
||||
Database connection for the transaction
|
||||
|
||||
Raises:
|
||||
sqlite3.Error: If unable to start transaction
|
||||
DatabaseError: If unable to start transaction
|
||||
"""
|
||||
if hasattr(self._local, 'transaction_connection'):
|
||||
raise sqlite3.Error("Nested transactions are not supported")
|
||||
raise DatabaseError(
|
||||
"Nested transactions are not supported",
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"transaction",
|
||||
None,
|
||||
ErrorSeverity.HIGH
|
||||
)
|
||||
)
|
||||
|
||||
conn = None
|
||||
start_time = datetime.utcnow()
|
||||
try:
|
||||
# Get connection from pool
|
||||
try:
|
||||
conn = self._connection_pool.get(timeout=5.0)
|
||||
conn = self._connection_pool.get(timeout=self.POOL_TIMEOUT)
|
||||
except Empty:
|
||||
logger.warning("Connection pool exhausted, creating new connection")
|
||||
conn = self._create_connection()
|
||||
if not conn:
|
||||
raise sqlite3.Error("Failed to create database connection")
|
||||
raise DatabaseError(
|
||||
"Failed to create database connection",
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"transaction",
|
||||
None,
|
||||
ErrorSeverity.HIGH
|
||||
)
|
||||
)
|
||||
|
||||
# Update connection info
|
||||
if id(conn) in self._connection_info:
|
||||
conn_info = self._connection_info[id(conn)]
|
||||
conn_info.update_usage()
|
||||
conn_info.state = ConnectionState.IN_USE
|
||||
|
||||
# Bind connection to current thread
|
||||
self._local.transaction_connection = conn
|
||||
@@ -153,13 +317,24 @@ class DatabaseConnectionManager:
|
||||
conn.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in database transaction: {e}")
|
||||
error = f"Error in database transaction: {str(e)}"
|
||||
logger.error(error, exc_info=True)
|
||||
if conn:
|
||||
try:
|
||||
conn.rollback()
|
||||
if id(conn) in self._connection_info:
|
||||
self._connection_info[id(conn)].record_error()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
raise DatabaseError(
|
||||
error,
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"transaction",
|
||||
None,
|
||||
ErrorSeverity.HIGH
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
if conn:
|
||||
@@ -169,22 +344,122 @@ class DatabaseConnectionManager:
|
||||
|
||||
# Return connection to pool
|
||||
self._connection_pool.put(conn)
|
||||
|
||||
# Update connection info
|
||||
if id(conn) in self._connection_info:
|
||||
conn_info = self._connection_info[id(conn)]
|
||||
conn_info.state = ConnectionState.AVAILABLE
|
||||
duration = (datetime.utcnow() - start_time).total_seconds()
|
||||
conn_info.total_transaction_time += duration
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up transaction: {e}")
|
||||
try:
|
||||
conn.close()
|
||||
if id(conn) in self._connection_info:
|
||||
self._connection_info[id(conn)].state = ConnectionState.CLOSED
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close_all(self) -> None:
|
||||
"""Close all connections in the pool"""
|
||||
"""
|
||||
Close all connections in the pool.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If cleanup fails
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
while not self._connection_pool.empty():
|
||||
try:
|
||||
conn = self._connection_pool.get_nowait()
|
||||
try:
|
||||
conn.close()
|
||||
if id(conn) in self._connection_info:
|
||||
self._connection_info[id(conn)].state = ConnectionState.CLOSED
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing connection: {e}")
|
||||
except Empty:
|
||||
break
|
||||
except Exception as e:
|
||||
error = f"Failed to close all connections: {str(e)}"
|
||||
logger.error(error, exc_info=True)
|
||||
raise DatabaseError(
|
||||
error,
|
||||
context=ErrorContext(
|
||||
"ConnectionManager",
|
||||
"close_all",
|
||||
None,
|
||||
ErrorSeverity.HIGH
|
||||
)
|
||||
)
|
||||
|
||||
def get_status(self) -> ConnectionStatus:
|
||||
"""
|
||||
Get current connection manager status.
|
||||
|
||||
Returns:
|
||||
Connection status information
|
||||
"""
|
||||
active_connections = sum(
|
||||
1 for info in self._connection_info.values()
|
||||
if info.state == ConnectionState.IN_USE
|
||||
)
|
||||
|
||||
return ConnectionStatus(
|
||||
state="healthy" if active_connections < self.pool_size else "exhausted",
|
||||
created_at=min(
|
||||
info.created_at.isoformat()
|
||||
for info in self._connection_info.values()
|
||||
),
|
||||
last_used=max(
|
||||
info.last_used.isoformat()
|
||||
for info in self._connection_info.values()
|
||||
),
|
||||
error=None,
|
||||
transaction_count=sum(
|
||||
info.transaction_count
|
||||
for info in self._connection_info.values()
|
||||
),
|
||||
pool_size=self.pool_size,
|
||||
available_connections=self.pool_size - active_connections
|
||||
)
|
||||
|
||||
def get_metrics(self) -> ConnectionMetrics:
|
||||
"""
|
||||
Get connection metrics.
|
||||
|
||||
Returns:
|
||||
Connection metrics information
|
||||
"""
|
||||
total_transactions = sum(
|
||||
info.transaction_count
|
||||
for info in self._connection_info.values()
|
||||
)
|
||||
total_errors = sum(
|
||||
info.error_count
|
||||
for info in self._connection_info.values()
|
||||
)
|
||||
total_time = sum(
|
||||
info.total_transaction_time
|
||||
for info in self._connection_info.values()
|
||||
)
|
||||
|
||||
return ConnectionMetrics(
|
||||
total_connections=len(self._connection_info),
|
||||
active_connections=sum(
|
||||
1 for info in self._connection_info.values()
|
||||
if info.state == ConnectionState.IN_USE
|
||||
),
|
||||
idle_connections=sum(
|
||||
1 for info in self._connection_info.values()
|
||||
if info.state == ConnectionState.AVAILABLE
|
||||
),
|
||||
failed_connections=sum(
|
||||
1 for info in self._connection_info.values()
|
||||
if info.state == ConnectionState.ERROR
|
||||
),
|
||||
total_transactions=total_transactions,
|
||||
failed_transactions=total_errors,
|
||||
average_transaction_time=total_time / total_transactions if total_transactions > 0 else 0.0
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user