# Mirror of https://github.com/pacnpal/Pac-cogs.git
# Synced 2025-12-20 10:51:05 -05:00
# 191 lines, 6.4 KiB, Python
"""Module for managing database connections"""
import logging
import sqlite3
import threading
from contextlib import contextmanager
from pathlib import Path
from queue import Empty, Full, Queue
from typing import Generator, Optional
# Module-level logger used by the connection manager below. The logger name is
# the fixed string "DBConnectionManager" rather than being derived from __name__.
logger = logging.getLogger("DBConnectionManager")


class DatabaseConnectionManager:
    """Manages SQLite database connections and connection pooling.

    A bounded queue holds up to ``pool_size`` idle connections. Callers borrow
    a connection through :meth:`get_connection` or :meth:`transaction`; when
    the pool stays empty for ``acquire_timeout`` seconds an extra "overflow"
    connection is created, and it is closed again on release if the pool is
    already full.
    """

    def __init__(
        self,
        db_path: Path,
        pool_size: int = 5,
        acquire_timeout: Optional[float] = 5.0,
    ):
        """Initialize the connection manager and pre-fill the pool.

        Args:
            db_path: Path to the SQLite database file
            pool_size: Maximum number of connections in the pool
            acquire_timeout: Seconds to wait for an idle pooled connection
                before creating an overflow connection. ``None`` waits
                indefinitely (``queue.Queue.get`` semantics).
        """
        self.db_path = db_path
        self.pool_size = pool_size
        self.acquire_timeout = acquire_timeout
        self._connection_pool: "Queue[sqlite3.Connection]" = Queue(maxsize=pool_size)
        # Holds the connection bound to the current thread's open transaction.
        self._local = threading.local()
        self._lock = threading.Lock()

        # Initialize connection pool
        self._initialize_pool()

    def _initialize_pool(self) -> None:
        """Create up to ``pool_size`` connections and place them in the pool.

        Connections that fail to open are skipped; ``_create_connection`` has
        already logged the cause. Any other exception is logged and re-raised.
        """
        try:
            for _ in range(self.pool_size):
                conn = self._create_connection()
                if conn is not None:
                    self._connection_pool.put(conn)
        except Exception as e:
            logger.error(f"Error initializing connection pool: {e}")
            raise

    def _create_connection(self) -> Optional[sqlite3.Connection]:
        """Create a new database connection with proper settings.

        Returns:
            The configured connection, or ``None`` if opening or configuring
            it raised ``sqlite3.Error`` (the error is logged).
        """
        conn = None
        try:
            conn = sqlite3.connect(
                self.db_path,
                detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
                timeout=30.0,  # 30 second timeout
                # The pool hands each connection to one borrower at a time,
                # but that borrower may run in a different thread than the one
                # that created the connection. With the default (True) every
                # such use raises sqlite3.ProgrammingError.
                check_same_thread=False,
            )

            # Enable foreign keys
            conn.execute("PRAGMA foreign_keys = ON")

            # Set journal mode to WAL for better concurrency
            conn.execute("PRAGMA journal_mode = WAL")

            # Set synchronous mode to NORMAL for better performance
            conn.execute("PRAGMA synchronous = NORMAL")

            # Enable extended result codes for better error handling
            # NOTE(review): this name does not appear among SQLite's documented
            # pragmas, and SQLite ignores unknown pragmas silently -- confirm
            # whether this statement has any effect.
            conn.execute("PRAGMA extended_result_codes = ON")

            return conn

        except sqlite3.Error as e:
            logger.error(f"Error creating database connection: {e}")
            # connect() may have succeeded before a PRAGMA failed; do not leak
            # the half-configured handle.
            if conn is not None:
                self._close_quietly(conn)
            return None

    @staticmethod
    def _close_quietly(conn: sqlite3.Connection) -> None:
        """Close ``conn`` on a best-effort basis; failures are only logged."""
        try:
            conn.close()
        except Exception as e:
            logger.debug("Ignoring error while closing connection: %s", e)

    @staticmethod
    def _rollback_quietly(conn: sqlite3.Connection) -> None:
        """Roll back ``conn`` on a best-effort basis; failures are only logged."""
        try:
            conn.rollback()
        except Exception as e:
            logger.debug("Ignoring error while rolling back connection: %s", e)

    def _acquire(self) -> sqlite3.Connection:
        """Borrow a pooled connection, or create an overflow one on timeout.

        Raises:
            sqlite3.Error: If the pool is exhausted and a new connection
                cannot be created.
        """
        try:
            return self._connection_pool.get(timeout=self.acquire_timeout)
        except Empty:
            logger.warning("Connection pool exhausted, creating new connection")

        conn = self._create_connection()
        if conn is None:
            raise sqlite3.Error("Failed to create database connection")
        return conn

    def _release(self, conn: sqlite3.Connection, failure_message: str) -> None:
        """Reset ``conn`` and return it to the pool, closing it if that fails.

        The rollback discards any transaction still open on the connection so
        the next borrower starts clean. ``put_nowait`` is used because a
        blocking ``put`` on the bounded queue would hang forever when an
        overflow connection is released into an already full pool.

        Args:
            conn: The connection being released.
            failure_message: Prefix for the error logged if the release fails.
        """
        try:
            conn.rollback()  # Reset connection state
            self._connection_pool.put_nowait(conn)
        except Full:
            # Overflow connection created while the pool was exhausted.
            self._close_quietly(conn)
        except Exception as e:
            logger.error(f"{failure_message}: {e}")
            self._close_quietly(conn)

    @contextmanager
    def get_connection(self) -> Generator[sqlite3.Connection, None, None]:
        """Get a database connection from the pool.

        If the current thread is inside :meth:`transaction`, the connection
        bound to that transaction is yielded instead and is not released here.

        Yields:
            sqlite3.Connection: A database connection

        Raises:
            sqlite3.Error: If unable to get a connection
        """
        # Check if we have a transaction-bound connection
        conn = getattr(self._local, 'transaction_connection', None)
        # Only a connection acquired here is ours to hand back to the pool.
        owns_connection = conn is None
        try:
            if owns_connection:
                conn = self._acquire()
            yield conn

        except Exception as e:
            logger.error(f"Error getting database connection: {e}")
            if conn is not None:
                self._rollback_quietly(conn)
            raise

        finally:
            if owns_connection and conn is not None:
                self._release(conn, "Error returning connection to pool")

    @contextmanager
    def transaction(self) -> Generator[sqlite3.Connection, None, None]:
        """Start a database transaction.

        The connection is bound to the current thread for the duration of the
        block, issues ``BEGIN`` on entry and commits on normal exit. Any exit
        other than a successful commit leaves the work rolled back, including
        exits caused by ``BaseException`` subclasses such as
        ``KeyboardInterrupt``.

        Yields:
            sqlite3.Connection: A database connection for the transaction

        Raises:
            sqlite3.Error: If unable to start transaction, or if a transaction
                is already open on this thread.
        """
        if hasattr(self._local, 'transaction_connection'):
            raise sqlite3.Error("Nested transactions are not supported")

        conn = None
        try:
            # Get connection from pool
            conn = self._acquire()

            # Bind connection to current thread
            self._local.transaction_connection = conn

            # Start transaction
            conn.execute("BEGIN")

            yield conn

            # Commit transaction
            conn.commit()

        except Exception as e:
            logger.error(f"Error in database transaction: {e}")
            if conn is not None:
                self._rollback_quietly(conn)
            raise

        finally:
            if conn is not None:
                # Remove thread-local binding
                if getattr(self._local, 'transaction_connection', None) is conn:
                    del self._local.transaction_connection

                # Return connection to pool. _release rolls back first, so a
                # transaction left open by a BaseException never reaches the
                # next borrower.
                self._release(conn, "Error cleaning up transaction")

    def close_all(self) -> None:
        """Close all connections currently idle in the pool.

        Connections that are checked out at the time of the call are not
        tracked by the pool and are therefore not closed here.
        """
        with self._lock:
            while True:
                try:
                    conn = self._connection_pool.get_nowait()
                except Empty:
                    break
                try:
                    conn.close()
                except Exception as e:
                    logger.error(f"Error closing connection: {e}")