"""MySQL database connector.""" import pymysql from typing import List, Dict, Any, Optional, Generator from config import get_settings from src.utils.logger import get_logger logger = get_logger(__name__) class MySQLConnector: """Connector for MySQL database.""" def __init__(self): """Initialize MySQL connector with settings.""" self.settings = get_settings() self.connection = None def connect(self) -> None: """Establish connection to MySQL database.""" try: self.connection = pymysql.connect( host=self.settings.mysql.host, port=self.settings.mysql.port, user=self.settings.mysql.user, password=self.settings.mysql.password, database=self.settings.mysql.database, charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor, read_timeout=300, # 5 minutes read timeout write_timeout=300, # 5 minutes write timeout max_allowed_packet=67108864, # 64MB max packet ) logger.info( f"Connected to MySQL: {self.settings.mysql.host}:" f"{self.settings.mysql.port}/{self.settings.mysql.database}" ) except pymysql.Error as e: logger.error(f"Failed to connect to MySQL: {e}") raise def disconnect(self) -> None: """Close connection to MySQL database.""" if self.connection: self.connection.close() logger.info("Disconnected from MySQL") def __enter__(self): """Context manager entry.""" self.connect() return self def __exit__(self, exc_type, exc_val, exc_tb): """Context manager exit.""" self.disconnect() def get_row_count(self, table: str) -> int: """Get total row count for a table. Args: table: Table name Returns: Number of rows in the table """ try: with self.connection.cursor() as cursor: cursor.execute(f"SELECT COUNT(*) as count FROM `{table}`") result = cursor.fetchone() return result["count"] except pymysql.Error as e: logger.error(f"Failed to get row count for {table}: {e}") raise def fetch_all_rows( self, table: str, batch_size: Optional[int] = None ) -> Generator[List[Dict[str, Any]], None, None]: """Fetch all rows from a table in batches. Args: table: Table name batch_size: Number of rows per batch (uses config default if None) Yields: Batches of row dictionaries """ if batch_size is None: batch_size = self.settings.migration.batch_size offset = 0 max_retries = 3 while True: retries = 0 while retries < max_retries: try: with self.connection.cursor() as cursor: query = f"SELECT * FROM `{table}` LIMIT %s OFFSET %s" cursor.execute(query, (batch_size, offset)) rows = cursor.fetchall() if not rows: return yield rows offset += len(rows) break # Success, exit retry loop except pymysql.Error as e: retries += 1 if retries >= max_retries: logger.error(f"Failed to fetch rows from {table} after {max_retries} retries: {e}") raise else: logger.warning(f"Fetch failed (retry {retries}/{max_retries}): {e}") # Reconnect and retry try: self.disconnect() self.connect() except Exception as reconnect_error: logger.error(f"Failed to reconnect: {reconnect_error}") raise def fetch_rows_since( self, table: str, since_timestamp: str, batch_size: Optional[int] = None ) -> Generator[List[Dict[str, Any]], None, None]: """Fetch rows modified since a timestamp. Args: table: Table name since_timestamp: ISO format timestamp (e.g., '2024-01-01T00:00:00') batch_size: Number of rows per batch (uses config default if None) Yields: Batches of row dictionaries """ if batch_size is None: batch_size = self.settings.migration.batch_size offset = 0 timestamp_col = "updated_at" if table == "ELABDATADISP" else "created_at" while True: try: with self.connection.cursor() as cursor: query = ( f"SELECT * FROM `{table}` " f"WHERE `{timestamp_col}` > %s " f"ORDER BY `{timestamp_col}` ASC " f"LIMIT %s OFFSET %s" ) cursor.execute(query, (since_timestamp, batch_size, offset)) rows = cursor.fetchall() if not rows: break yield rows offset += len(rows) except pymysql.Error as e: logger.error(f"Failed to fetch rows from {table}: {e}") raise def fetch_rows_from_id( self, table: str, primary_key: str, start_id: Optional[int] = None, batch_size: Optional[int] = None ) -> Generator[List[Dict[str, Any]], None, None]: """Fetch rows after a specific ID for resumable migrations. Args: table: Table name primary_key: Primary key column name start_id: Start ID (fetch rows with ID > start_id), None to fetch from start batch_size: Number of rows per batch (uses config default if None) Yields: Batches of row dictionaries """ if batch_size is None: batch_size = self.settings.migration.batch_size offset = 0 while True: try: with self.connection.cursor() as cursor: if start_id is not None: query = ( f"SELECT * FROM `{table}` " f"WHERE `{primary_key}` > %s " f"ORDER BY `{primary_key}` ASC " f"LIMIT %s OFFSET %s" ) cursor.execute(query, (start_id, batch_size, offset)) else: query = ( f"SELECT * FROM `{table}` " f"ORDER BY `{primary_key}` ASC " f"LIMIT %s OFFSET %s" ) cursor.execute(query, (batch_size, offset)) rows = cursor.fetchall() if not rows: break yield rows offset += len(rows) except pymysql.Error as e: logger.error(f"Failed to fetch rows from {table}: {e}") raise def get_table_structure(self, table: str) -> Dict[str, Any]: """Get table structure (column info). Args: table: Table name Returns: Dictionary with column information """ try: with self.connection.cursor() as cursor: cursor.execute(f"DESCRIBE `{table}`") columns = cursor.fetchall() return {col["Field"]: col for col in columns} except pymysql.Error as e: logger.error(f"Failed to get structure for {table}: {e}") raise