"""Utility functions for refactored scripts.""" import asyncio import logging from datetime import datetime from typing import Any, Optional import aiomysql logger = logging.getLogger(__name__) async def get_db_connection(config: dict) -> aiomysql.Connection: """ Create an async database connection. Args: config: Database configuration dictionary Returns: aiomysql.Connection: Async database connection Raises: Exception: If connection fails """ try: conn = await aiomysql.connect(**config) logger.debug("Database connection established") return conn except Exception as e: logger.error(f"Failed to connect to database: {e}") raise async def execute_query( conn: aiomysql.Connection, query: str, params: tuple | list = None, fetch_one: bool = False, fetch_all: bool = False, ) -> Any | None: """ Execute a database query safely with proper error handling. Args: conn: Database connection query: SQL query string params: Query parameters fetch_one: Whether to fetch one result fetch_all: Whether to fetch all results Returns: Query results or None Raises: Exception: If query execution fails """ async with conn.cursor(aiomysql.DictCursor) as cursor: try: await cursor.execute(query, params or ()) if fetch_one: return await cursor.fetchone() elif fetch_all: return await cursor.fetchall() return None except Exception as e: logger.error(f"Query execution failed: {e}") logger.debug(f"Query: {query}") logger.debug(f"Params: {params}") raise async def execute_many(conn: aiomysql.Connection, query: str, params_list: list) -> int: """ Execute a query with multiple parameter sets (batch insert). Args: conn: Database connection query: SQL query string params_list: List of parameter tuples Returns: Number of affected rows Raises: Exception: If query execution fails """ if not params_list: logger.warning("execute_many called with empty params_list") return 0 async with conn.cursor() as cursor: try: await cursor.executemany(query, params_list) affected_rows = cursor.rowcount logger.debug(f"Batch insert completed: {affected_rows} rows affected") return affected_rows except Exception as e: logger.error(f"Batch query execution failed: {e}") logger.debug(f"Query: {query}") logger.debug(f"Number of parameter sets: {len(params_list)}") raise def parse_datetime(date_str: str, time_str: str = None) -> datetime: """ Parse date and optional time strings into datetime object. Args: date_str: Date string (various formats supported) time_str: Optional time string Returns: datetime object Examples: >>> parse_datetime("2024-10-11", "14:30:00") datetime(2024, 10, 11, 14, 30, 0) >>> parse_datetime("2024-10-11T14:30:00") datetime(2024, 10, 11, 14, 30, 0) """ # Handle ISO format with T separator if "T" in date_str: return datetime.fromisoformat(date_str.replace("T", " ")) # Handle separate date and time if time_str: return datetime.strptime(f"{date_str} {time_str}", "%Y-%m-%d %H:%M:%S") # Handle date only return datetime.strptime(date_str, "%Y-%m-%d") async def retry_on_failure( coro_func, max_retries: int = 3, delay: float = 1.0, backoff: float = 2.0, *args, **kwargs, ): """ Retry an async function on failure with exponential backoff. Args: coro_func: Async function to retry max_retries: Maximum number of retry attempts delay: Initial delay between retries (seconds) backoff: Backoff multiplier for delay *args: Arguments to pass to coro_func **kwargs: Keyword arguments to pass to coro_func Returns: Result from coro_func Raises: Exception: If all retries fail """ last_exception = None for attempt in range(max_retries): try: return await coro_func(*args, **kwargs) except Exception as e: last_exception = e if attempt < max_retries - 1: wait_time = delay * (backoff**attempt) logger.warning(f"Attempt {attempt + 1}/{max_retries} failed: {e}. Retrying in {wait_time}s...") await asyncio.sleep(wait_time) else: logger.error(f"All {max_retries} attempts failed") raise last_exception