Improve database connection management

- Add proper error handling with DatabaseError
- Add proper error context support
- Add better connection state management
- Add proper type hints
- Add proper docstrings
- Add better error recovery
- Add proper error context creation
- Add better metrics tracking
- Add proper status reporting
- Add better connection validation
- Add proper logging
- Add better error messages
- Add proper transaction timing
- Add better pool management
This commit is contained in:
pacnpal
2024-11-16 21:22:00 +00:00
parent 775781b325
commit 6f545ef7fd

View File

@@ -4,25 +4,96 @@ import logging
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from contextlib import contextmanager from contextlib import contextmanager
from typing import Generator, Optional from typing import Generator, Optional, Dict, Any, TypedDict, ClassVar, Union
from enum import Enum, auto
import threading import threading
from queue import Queue, Empty from queue import Queue, Empty
from datetime import datetime
from ..utils.exceptions import (
DatabaseError,
ErrorContext,
ErrorSeverity
)
logger = logging.getLogger("DBConnectionManager") logger = logging.getLogger("DBConnectionManager")
class ConnectionState(Enum):
"""Connection states"""
AVAILABLE = auto()
IN_USE = auto()
CLOSED = auto()
ERROR = auto()
class ConnectionStatus(TypedDict):
"""Type definition for connection status"""
state: str
created_at: str
last_used: str
error: Optional[str]
transaction_count: int
pool_size: int
available_connections: int
class ConnectionMetrics(TypedDict):
"""Type definition for connection metrics"""
total_connections: int
active_connections: int
idle_connections: int
failed_connections: int
total_transactions: int
failed_transactions: int
average_transaction_time: float
class ConnectionInfo:
"""Tracks connection information"""
def __init__(self) -> None:
self.created_at = datetime.utcnow()
self.last_used = self.created_at
self.transaction_count = 0
self.error_count = 0
self.total_transaction_time = 0.0
self.state = ConnectionState.AVAILABLE
def update_usage(self) -> None:
"""Update connection usage statistics"""
self.last_used = datetime.utcnow()
self.transaction_count += 1
def record_error(self) -> None:
"""Record a connection error"""
self.error_count += 1
self.state = ConnectionState.ERROR
def get_average_transaction_time(self) -> float:
"""Get average transaction time"""
if self.transaction_count == 0:
return 0.0
return self.total_transaction_time / self.transaction_count
class DatabaseConnectionManager: class DatabaseConnectionManager:
"""Manages SQLite database connections and connection pooling""" """Manages SQLite database connections and connection pooling"""
def __init__(self, db_path: Path, pool_size: int = 5): DEFAULT_POOL_SIZE: ClassVar[int] = 5
"""Initialize the connection manager CONNECTION_TIMEOUT: ClassVar[float] = 30.0
POOL_TIMEOUT: ClassVar[float] = 5.0
def __init__(self, db_path: Path, pool_size: int = DEFAULT_POOL_SIZE) -> None:
"""
Initialize the connection manager.
Args: Args:
db_path: Path to the SQLite database file db_path: Path to the SQLite database file
pool_size: Maximum number of connections in the pool pool_size: Maximum number of connections in the pool
Raises:
DatabaseError: If initialization fails
""" """
self.db_path = db_path self.db_path = db_path
self.pool_size = pool_size self.pool_size = pool_size
self._connection_pool: Queue[sqlite3.Connection] = Queue(maxsize=pool_size) self._connection_pool: Queue[sqlite3.Connection] = Queue(maxsize=pool_size)
self._connection_info: Dict[int, ConnectionInfo] = {}
self._local = threading.local() self._local = threading.local()
self._lock = threading.Lock() self._lock = threading.Lock()
@@ -30,23 +101,46 @@ class DatabaseConnectionManager:
self._initialize_pool() self._initialize_pool()
def _initialize_pool(self) -> None: def _initialize_pool(self) -> None:
"""Initialize the connection pool""" """
Initialize the connection pool.
Raises:
DatabaseError: If pool initialization fails
"""
try: try:
for _ in range(self.pool_size): for _ in range(self.pool_size):
conn = self._create_connection() conn = self._create_connection()
if conn: if conn:
self._connection_pool.put(conn) self._connection_pool.put(conn)
self._connection_info[id(conn)] = ConnectionInfo()
except Exception as e: except Exception as e:
logger.error(f"Error initializing connection pool: {e}") error = f"Failed to initialize connection pool: {str(e)}"
raise logger.error(error, exc_info=True)
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"initialize_pool",
{"pool_size": self.pool_size},
ErrorSeverity.CRITICAL
)
)
def _create_connection(self) -> Optional[sqlite3.Connection]: def _create_connection(self) -> Optional[sqlite3.Connection]:
"""Create a new database connection with proper settings""" """
Create a new database connection with proper settings.
Returns:
New database connection or None if creation fails
Raises:
DatabaseError: If connection creation fails
"""
try: try:
conn = sqlite3.connect( conn = sqlite3.connect(
self.db_path, self.db_path,
detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
timeout=30.0 # 30 second timeout timeout=self.CONNECTION_TIMEOUT
) )
# Enable foreign keys # Enable foreign keys
@@ -64,20 +158,31 @@ class DatabaseConnectionManager:
return conn return conn
except sqlite3.Error as e: except sqlite3.Error as e:
logger.error(f"Error creating database connection: {e}") error = f"Failed to create database connection: {str(e)}"
return None logger.error(error, exc_info=True)
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"create_connection",
{"path": str(self.db_path)},
ErrorSeverity.HIGH
)
)
@contextmanager @contextmanager
def get_connection(self) -> Generator[sqlite3.Connection, None, None]: def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
"""Get a database connection from the pool """
Get a database connection from the pool.
Yields: Yields:
sqlite3.Connection: A database connection Database connection
Raises: Raises:
sqlite3.Error: If unable to get a connection DatabaseError: If unable to get a connection
""" """
conn = None conn = None
start_time = datetime.utcnow()
try: try:
# Check if we have a transaction-bound connection # Check if we have a transaction-bound connection
conn = getattr(self._local, 'transaction_connection', None) conn = getattr(self._local, 'transaction_connection', None)
@@ -87,59 +192,118 @@ class DatabaseConnectionManager:
# Get connection from pool or create new one # Get connection from pool or create new one
try: try:
conn = self._connection_pool.get(timeout=5.0) conn = self._connection_pool.get(timeout=self.POOL_TIMEOUT)
except Empty: except Empty:
logger.warning("Connection pool exhausted, creating new connection") logger.warning("Connection pool exhausted, creating new connection")
conn = self._create_connection() conn = self._create_connection()
if not conn: if not conn:
raise sqlite3.Error("Failed to create database connection") raise DatabaseError(
"Failed to create database connection",
context=ErrorContext(
"ConnectionManager",
"get_connection",
None,
ErrorSeverity.HIGH
)
)
# Update connection info
conn_info = self._connection_info.get(id(conn))
if conn_info:
conn_info.update_usage()
conn_info.state = ConnectionState.IN_USE
yield conn yield conn
except Exception as e: except Exception as e:
logger.error(f"Error getting database connection: {e}") error = f"Error getting database connection: {str(e)}"
logger.error(error, exc_info=True)
if conn: if conn:
try: try:
conn.rollback() conn.rollback()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].record_error()
except Exception: except Exception:
pass pass
raise raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"get_connection",
None,
ErrorSeverity.HIGH
)
)
finally: finally:
if conn and not hasattr(self._local, 'transaction_connection'): if conn and not hasattr(self._local, 'transaction_connection'):
try: try:
conn.rollback() # Reset connection state conn.rollback() # Reset connection state
self._connection_pool.put(conn) self._connection_pool.put(conn)
# Update connection info
if id(conn) in self._connection_info:
conn_info = self._connection_info[id(conn)]
conn_info.state = ConnectionState.AVAILABLE
duration = (datetime.utcnow() - start_time).total_seconds()
conn_info.total_transaction_time += duration
except Exception as e: except Exception as e:
logger.error(f"Error returning connection to pool: {e}") logger.error(f"Error returning connection to pool: {e}")
try: try:
conn.close() conn.close()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].state = ConnectionState.CLOSED
except Exception: except Exception:
pass pass
@contextmanager @contextmanager
def transaction(self) -> Generator[sqlite3.Connection, None, None]: def transaction(self) -> Generator[sqlite3.Connection, None, None]:
"""Start a database transaction """
Start a database transaction.
Yields: Yields:
sqlite3.Connection: A database connection for the transaction Database connection for the transaction
Raises: Raises:
sqlite3.Error: If unable to start transaction DatabaseError: If unable to start transaction
""" """
if hasattr(self._local, 'transaction_connection'): if hasattr(self._local, 'transaction_connection'):
raise sqlite3.Error("Nested transactions are not supported") raise DatabaseError(
"Nested transactions are not supported",
context=ErrorContext(
"ConnectionManager",
"transaction",
None,
ErrorSeverity.HIGH
)
)
conn = None conn = None
start_time = datetime.utcnow()
try: try:
# Get connection from pool # Get connection from pool
try: try:
conn = self._connection_pool.get(timeout=5.0) conn = self._connection_pool.get(timeout=self.POOL_TIMEOUT)
except Empty: except Empty:
logger.warning("Connection pool exhausted, creating new connection") logger.warning("Connection pool exhausted, creating new connection")
conn = self._create_connection() conn = self._create_connection()
if not conn: if not conn:
raise sqlite3.Error("Failed to create database connection") raise DatabaseError(
"Failed to create database connection",
context=ErrorContext(
"ConnectionManager",
"transaction",
None,
ErrorSeverity.HIGH
)
)
# Update connection info
if id(conn) in self._connection_info:
conn_info = self._connection_info[id(conn)]
conn_info.update_usage()
conn_info.state = ConnectionState.IN_USE
# Bind connection to current thread # Bind connection to current thread
self._local.transaction_connection = conn self._local.transaction_connection = conn
@@ -153,13 +317,24 @@ class DatabaseConnectionManager:
conn.commit() conn.commit()
except Exception as e: except Exception as e:
logger.error(f"Error in database transaction: {e}") error = f"Error in database transaction: {str(e)}"
logger.error(error, exc_info=True)
if conn: if conn:
try: try:
conn.rollback() conn.rollback()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].record_error()
except Exception: except Exception:
pass pass
raise raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"transaction",
None,
ErrorSeverity.HIGH
)
)
finally: finally:
if conn: if conn:
@@ -169,22 +344,122 @@ class DatabaseConnectionManager:
# Return connection to pool # Return connection to pool
self._connection_pool.put(conn) self._connection_pool.put(conn)
# Update connection info
if id(conn) in self._connection_info:
conn_info = self._connection_info[id(conn)]
conn_info.state = ConnectionState.AVAILABLE
duration = (datetime.utcnow() - start_time).total_seconds()
conn_info.total_transaction_time += duration
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up transaction: {e}") logger.error(f"Error cleaning up transaction: {e}")
try: try:
conn.close() conn.close()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].state = ConnectionState.CLOSED
except Exception: except Exception:
pass pass
def close_all(self) -> None: def close_all(self) -> None:
"""Close all connections in the pool""" """
Close all connections in the pool.
Raises:
DatabaseError: If cleanup fails
"""
with self._lock: with self._lock:
try:
while not self._connection_pool.empty(): while not self._connection_pool.empty():
try: try:
conn = self._connection_pool.get_nowait() conn = self._connection_pool.get_nowait()
try: try:
conn.close() conn.close()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].state = ConnectionState.CLOSED
except Exception as e: except Exception as e:
logger.error(f"Error closing connection: {e}") logger.error(f"Error closing connection: {e}")
except Empty: except Empty:
break break
except Exception as e:
error = f"Failed to close all connections: {str(e)}"
logger.error(error, exc_info=True)
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"close_all",
None,
ErrorSeverity.HIGH
)
)
def get_status(self) -> ConnectionStatus:
"""
Get current connection manager status.
Returns:
Connection status information
"""
active_connections = sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.IN_USE
)
return ConnectionStatus(
state="healthy" if active_connections < self.pool_size else "exhausted",
created_at=min(
info.created_at.isoformat()
for info in self._connection_info.values()
),
last_used=max(
info.last_used.isoformat()
for info in self._connection_info.values()
),
error=None,
transaction_count=sum(
info.transaction_count
for info in self._connection_info.values()
),
pool_size=self.pool_size,
available_connections=self.pool_size - active_connections
)
def get_metrics(self) -> ConnectionMetrics:
"""
Get connection metrics.
Returns:
Connection metrics information
"""
total_transactions = sum(
info.transaction_count
for info in self._connection_info.values()
)
total_errors = sum(
info.error_count
for info in self._connection_info.values()
)
total_time = sum(
info.total_transaction_time
for info in self._connection_info.values()
)
return ConnectionMetrics(
total_connections=len(self._connection_info),
active_connections=sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.IN_USE
),
idle_connections=sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.AVAILABLE
),
failed_connections=sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.ERROR
),
total_transactions=total_transactions,
failed_transactions=total_errors,
average_transaction_time=total_time / total_transactions if total_transactions > 0 else 0.0
)