Improve database connection management

- Add proper error handling with DatabaseError
- Add proper error context support
- Add better connection state management
- Add proper type hints
- Add proper docstrings
- Add better error recovery
- Add proper error context creation
- Add better metrics tracking
- Add proper status reporting
- Add better connection validation
- Add proper logging
- Add better error messages
- Add proper transaction timing
- Add better pool management
This commit is contained in:
pacnpal
2024-11-16 21:22:00 +00:00
parent 775781b325
commit 6f545ef7fd

View File

@@ -4,25 +4,96 @@ import logging
import sqlite3
from pathlib import Path
from contextlib import contextmanager
from typing import Generator, Optional
from typing import Generator, Optional, Dict, Any, TypedDict, ClassVar, Union
from enum import Enum, auto
import threading
from queue import Queue, Empty
from datetime import datetime
from ..utils.exceptions import (
DatabaseError,
ErrorContext,
ErrorSeverity
)
logger = logging.getLogger("DBConnectionManager")
class ConnectionState(Enum):
"""Connection states"""
AVAILABLE = auto()
IN_USE = auto()
CLOSED = auto()
ERROR = auto()
class ConnectionStatus(TypedDict):
"""Type definition for connection status"""
state: str
created_at: str
last_used: str
error: Optional[str]
transaction_count: int
pool_size: int
available_connections: int
class ConnectionMetrics(TypedDict):
"""Type definition for connection metrics"""
total_connections: int
active_connections: int
idle_connections: int
failed_connections: int
total_transactions: int
failed_transactions: int
average_transaction_time: float
class ConnectionInfo:
"""Tracks connection information"""
def __init__(self) -> None:
self.created_at = datetime.utcnow()
self.last_used = self.created_at
self.transaction_count = 0
self.error_count = 0
self.total_transaction_time = 0.0
self.state = ConnectionState.AVAILABLE
def update_usage(self) -> None:
"""Update connection usage statistics"""
self.last_used = datetime.utcnow()
self.transaction_count += 1
def record_error(self) -> None:
"""Record a connection error"""
self.error_count += 1
self.state = ConnectionState.ERROR
def get_average_transaction_time(self) -> float:
"""Get average transaction time"""
if self.transaction_count == 0:
return 0.0
return self.total_transaction_time / self.transaction_count
class DatabaseConnectionManager:
"""Manages SQLite database connections and connection pooling"""
def __init__(self, db_path: Path, pool_size: int = 5):
"""Initialize the connection manager
DEFAULT_POOL_SIZE: ClassVar[int] = 5
CONNECTION_TIMEOUT: ClassVar[float] = 30.0
POOL_TIMEOUT: ClassVar[float] = 5.0
def __init__(self, db_path: Path, pool_size: int = DEFAULT_POOL_SIZE) -> None:
"""
Initialize the connection manager.
Args:
db_path: Path to the SQLite database file
pool_size: Maximum number of connections in the pool
Raises:
DatabaseError: If initialization fails
"""
self.db_path = db_path
self.pool_size = pool_size
self._connection_pool: Queue[sqlite3.Connection] = Queue(maxsize=pool_size)
self._connection_info: Dict[int, ConnectionInfo] = {}
self._local = threading.local()
self._lock = threading.Lock()
@@ -30,23 +101,46 @@ class DatabaseConnectionManager:
self._initialize_pool()
def _initialize_pool(self) -> None:
"""Initialize the connection pool"""
"""
Initialize the connection pool.
Raises:
DatabaseError: If pool initialization fails
"""
try:
for _ in range(self.pool_size):
conn = self._create_connection()
if conn:
self._connection_pool.put(conn)
self._connection_info[id(conn)] = ConnectionInfo()
except Exception as e:
logger.error(f"Error initializing connection pool: {e}")
raise
error = f"Failed to initialize connection pool: {str(e)}"
logger.error(error, exc_info=True)
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"initialize_pool",
{"pool_size": self.pool_size},
ErrorSeverity.CRITICAL
)
)
def _create_connection(self) -> Optional[sqlite3.Connection]:
"""Create a new database connection with proper settings"""
"""
Create a new database connection with proper settings.
Returns:
New database connection or None if creation fails
Raises:
DatabaseError: If connection creation fails
"""
try:
conn = sqlite3.connect(
self.db_path,
detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
timeout=30.0 # 30 second timeout
timeout=self.CONNECTION_TIMEOUT
)
# Enable foreign keys
@@ -64,20 +158,31 @@ class DatabaseConnectionManager:
return conn
except sqlite3.Error as e:
logger.error(f"Error creating database connection: {e}")
return None
error = f"Failed to create database connection: {str(e)}"
logger.error(error, exc_info=True)
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"create_connection",
{"path": str(self.db_path)},
ErrorSeverity.HIGH
)
)
@contextmanager
def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
"""Get a database connection from the pool
"""
Get a database connection from the pool.
Yields:
sqlite3.Connection: A database connection
Database connection
Raises:
sqlite3.Error: If unable to get a connection
DatabaseError: If unable to get a connection
"""
conn = None
start_time = datetime.utcnow()
try:
# Check if we have a transaction-bound connection
conn = getattr(self._local, 'transaction_connection', None)
@@ -87,59 +192,118 @@ class DatabaseConnectionManager:
# Get connection from pool or create new one
try:
conn = self._connection_pool.get(timeout=5.0)
conn = self._connection_pool.get(timeout=self.POOL_TIMEOUT)
except Empty:
logger.warning("Connection pool exhausted, creating new connection")
conn = self._create_connection()
if not conn:
raise sqlite3.Error("Failed to create database connection")
raise DatabaseError(
"Failed to create database connection",
context=ErrorContext(
"ConnectionManager",
"get_connection",
None,
ErrorSeverity.HIGH
)
)
# Update connection info
conn_info = self._connection_info.get(id(conn))
if conn_info:
conn_info.update_usage()
conn_info.state = ConnectionState.IN_USE
yield conn
except Exception as e:
logger.error(f"Error getting database connection: {e}")
error = f"Error getting database connection: {str(e)}"
logger.error(error, exc_info=True)
if conn:
try:
conn.rollback()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].record_error()
except Exception:
pass
raise
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"get_connection",
None,
ErrorSeverity.HIGH
)
)
finally:
if conn and not hasattr(self._local, 'transaction_connection'):
try:
conn.rollback() # Reset connection state
self._connection_pool.put(conn)
# Update connection info
if id(conn) in self._connection_info:
conn_info = self._connection_info[id(conn)]
conn_info.state = ConnectionState.AVAILABLE
duration = (datetime.utcnow() - start_time).total_seconds()
conn_info.total_transaction_time += duration
except Exception as e:
logger.error(f"Error returning connection to pool: {e}")
try:
conn.close()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].state = ConnectionState.CLOSED
except Exception:
pass
@contextmanager
def transaction(self) -> Generator[sqlite3.Connection, None, None]:
"""Start a database transaction
"""
Start a database transaction.
Yields:
sqlite3.Connection: A database connection for the transaction
Database connection for the transaction
Raises:
sqlite3.Error: If unable to start transaction
DatabaseError: If unable to start transaction
"""
if hasattr(self._local, 'transaction_connection'):
raise sqlite3.Error("Nested transactions are not supported")
raise DatabaseError(
"Nested transactions are not supported",
context=ErrorContext(
"ConnectionManager",
"transaction",
None,
ErrorSeverity.HIGH
)
)
conn = None
start_time = datetime.utcnow()
try:
# Get connection from pool
try:
conn = self._connection_pool.get(timeout=5.0)
conn = self._connection_pool.get(timeout=self.POOL_TIMEOUT)
except Empty:
logger.warning("Connection pool exhausted, creating new connection")
conn = self._create_connection()
if not conn:
raise sqlite3.Error("Failed to create database connection")
raise DatabaseError(
"Failed to create database connection",
context=ErrorContext(
"ConnectionManager",
"transaction",
None,
ErrorSeverity.HIGH
)
)
# Update connection info
if id(conn) in self._connection_info:
conn_info = self._connection_info[id(conn)]
conn_info.update_usage()
conn_info.state = ConnectionState.IN_USE
# Bind connection to current thread
self._local.transaction_connection = conn
@@ -153,13 +317,24 @@ class DatabaseConnectionManager:
conn.commit()
except Exception as e:
logger.error(f"Error in database transaction: {e}")
error = f"Error in database transaction: {str(e)}"
logger.error(error, exc_info=True)
if conn:
try:
conn.rollback()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].record_error()
except Exception:
pass
raise
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"transaction",
None,
ErrorSeverity.HIGH
)
)
finally:
if conn:
@@ -169,22 +344,122 @@ class DatabaseConnectionManager:
# Return connection to pool
self._connection_pool.put(conn)
# Update connection info
if id(conn) in self._connection_info:
conn_info = self._connection_info[id(conn)]
conn_info.state = ConnectionState.AVAILABLE
duration = (datetime.utcnow() - start_time).total_seconds()
conn_info.total_transaction_time += duration
except Exception as e:
logger.error(f"Error cleaning up transaction: {e}")
try:
conn.close()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].state = ConnectionState.CLOSED
except Exception:
pass
def close_all(self) -> None:
"""Close all connections in the pool"""
"""
Close all connections in the pool.
Raises:
DatabaseError: If cleanup fails
"""
with self._lock:
try:
while not self._connection_pool.empty():
try:
conn = self._connection_pool.get_nowait()
try:
conn.close()
if id(conn) in self._connection_info:
self._connection_info[id(conn)].state = ConnectionState.CLOSED
except Exception as e:
logger.error(f"Error closing connection: {e}")
except Empty:
break
except Exception as e:
error = f"Failed to close all connections: {str(e)}"
logger.error(error, exc_info=True)
raise DatabaseError(
error,
context=ErrorContext(
"ConnectionManager",
"close_all",
None,
ErrorSeverity.HIGH
)
)
def get_status(self) -> ConnectionStatus:
"""
Get current connection manager status.
Returns:
Connection status information
"""
active_connections = sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.IN_USE
)
return ConnectionStatus(
state="healthy" if active_connections < self.pool_size else "exhausted",
created_at=min(
info.created_at.isoformat()
for info in self._connection_info.values()
),
last_used=max(
info.last_used.isoformat()
for info in self._connection_info.values()
),
error=None,
transaction_count=sum(
info.transaction_count
for info in self._connection_info.values()
),
pool_size=self.pool_size,
available_connections=self.pool_size - active_connections
)
def get_metrics(self) -> ConnectionMetrics:
"""
Get connection metrics.
Returns:
Connection metrics information
"""
total_transactions = sum(
info.transaction_count
for info in self._connection_info.values()
)
total_errors = sum(
info.error_count
for info in self._connection_info.values()
)
total_time = sum(
info.total_transaction_time
for info in self._connection_info.values()
)
return ConnectionMetrics(
total_connections=len(self._connection_info),
active_connections=sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.IN_USE
),
idle_connections=sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.AVAILABLE
),
failed_connections=sum(
1 for info in self._connection_info.values()
if info.state == ConnectionState.ERROR
),
total_transactions=total_transactions,
failed_transactions=total_errors,
average_transaction_time=total_time / total_transactions if total_transactions > 0 else 0.0
)