179 lines
4.7 KiB
Python
179 lines
4.7 KiB
Python
"""Utility functions for refactored scripts."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any, Optional
|
|
|
|
import aiomysql
|
|
|
|
# Module-level logger named after this module (``__name__``), so the
# application's logging configuration can route or filter these records.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def get_db_connection(config: dict) -> aiomysql.Connection:
    """
    Open an aiomysql connection described by ``config``.

    Every key/value pair in ``config`` is passed straight through to
    ``aiomysql.connect`` as a keyword argument. The connection is not
    closed here; the caller takes ownership of it.

    Args:
        config: Keyword arguments for ``aiomysql.connect``.

    Returns:
        aiomysql.Connection: The newly opened connection.

    Raises:
        Exception: Any error raised while connecting is logged at ERROR
            level and then re-raised unchanged.
    """
    try:
        connection = await aiomysql.connect(**config)
    except Exception as exc:
        logger.error(f"Failed to connect to database: {exc}")
        raise
    else:
        logger.debug("Database connection established")
        return connection
|
|
|
|
|
|
async def execute_query(
    conn: aiomysql.Connection,
    query: str,
    params: tuple | list | None = None,
    fetch_one: bool = False,
    fetch_all: bool = False,
) -> Any | None:
    """
    Execute one SQL statement on a ``DictCursor`` and optionally fetch rows.

    The statement and its parameters are handed to ``cursor.execute``
    separately, so values belong in ``params`` rather than being formatted
    into ``query``. This helper never calls ``conn.commit()``; whether a
    write is persisted depends on how ``conn`` was configured.

    Args:
        conn: Open connection to run the statement on.
        query: SQL text, with placeholders for any values in ``params``.
        params: Values bound to the placeholders. ``None`` (or an empty
            sequence) executes the statement with ``()``.
        fetch_one: Return the result of ``cursor.fetchone()``. Takes
            precedence over ``fetch_all`` when both flags are set.
        fetch_all: Return the result of ``cursor.fetchall()``.

    Returns:
        The fetched row (``fetch_one``) or rows (``fetch_all``) as produced
        by ``aiomysql.DictCursor``; ``None`` when neither flag is set.

    Raises:
        Exception: Any error from executing or fetching is logged (message
            at ERROR, query and params at DEBUG) and re-raised unchanged.
    """
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        try:
            await cursor.execute(query, params or ())

            if fetch_one:
                return await cursor.fetchone()
            if fetch_all:
                return await cursor.fetchall()
            return None

        except Exception as e:
            logger.error(f"Query execution failed: {e}")
            # Statement text and bound values only at DEBUG, so they are not
            # emitted at normal log levels.
            logger.debug(f"Query: {query}")
            logger.debug(f"Params: {params}")
            raise
|
|
|
|
|
|
async def execute_many(conn: aiomysql.Connection, query: str, params_list: list) -> int:
    """
    Run ``query`` once for each parameter set in ``params_list``.

    The whole batch is submitted with a single ``cursor.executemany`` call
    on the cursor returned by ``conn.cursor()`` (its default cursor class).

    Args:
        conn: Open connection to run the batch on.
        query: SQL statement with placeholders.
        params_list: One parameter tuple per execution.

    Returns:
        ``cursor.rowcount`` after the batch, or 0 when ``params_list`` is
        empty (in that case no cursor is opened and a warning is logged).

    Raises:
        Exception: Any error from ``executemany`` is logged (message at
            ERROR, query and batch size at DEBUG) and re-raised unchanged.
    """
    if not params_list:
        logger.warning("execute_many called with empty params_list")
        return 0

    async with conn.cursor() as cur:
        try:
            await cur.executemany(query, params_list)
        except Exception as exc:
            logger.error(f"Batch query execution failed: {exc}")
            logger.debug(f"Query: {query}")
            logger.debug(f"Number of parameter sets: {len(params_list)}")
            raise

        row_total = cur.rowcount
        logger.debug(f"Batch insert completed: {row_total} rows affected")
        return row_total
|
|
|
|
|
|
def parse_datetime(date_str: str, time_str: str = None) -> datetime:
|
|
"""
|
|
Parse date and optional time strings into datetime object.
|
|
|
|
Args:
|
|
date_str: Date string (various formats supported)
|
|
time_str: Optional time string
|
|
|
|
Returns:
|
|
datetime object
|
|
|
|
Examples:
|
|
>>> parse_datetime("2024-10-11", "14:30:00")
|
|
datetime(2024, 10, 11, 14, 30, 0)
|
|
|
|
>>> parse_datetime("2024-10-11T14:30:00")
|
|
datetime(2024, 10, 11, 14, 30, 0)
|
|
"""
|
|
# Handle ISO format with T separator
|
|
if "T" in date_str:
|
|
return datetime.fromisoformat(date_str.replace("T", " "))
|
|
|
|
# Handle separate date and time
|
|
if time_str:
|
|
return datetime.strptime(f"{date_str} {time_str}", "%Y-%m-%d %H:%M:%S")
|
|
|
|
# Handle date only
|
|
return datetime.strptime(date_str, "%Y-%m-%d")
|
|
|
|
|
|
async def retry_on_failure(
    coro_func,
    max_retries: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    *args,
    **kwargs,
):
    """
    Await ``coro_func(*args, **kwargs)``, retrying with exponential backoff.

    The call is attempted at most ``max_retries`` times in total. After failed
    attempt ``i`` (0-based) the helper sleeps ``delay * backoff**i`` seconds
    before trying again; there is no sleep after the final attempt. Only
    ``Exception`` subclasses trigger a retry; anything else propagates at once.

    Note:
        ``*args`` follows the three tuning parameters. Positional arguments
        meant for ``coro_func`` are therefore forwarded only when
        ``max_retries``, ``delay`` and ``backoff`` are all given positionally
        first, e.g. ``retry_on_failure(fetch, 3, 1.0, 2.0, url)``. Otherwise
        pass ``coro_func``'s arguments as keywords.

    Args:
        coro_func: Async function to call.
        max_retries: Total number of attempts; must be at least 1.
        delay: Sleep before the second attempt, in seconds.
        backoff: Multiplier applied to the sleep after each further failure.
        *args: Positional arguments forwarded to ``coro_func``.
        **kwargs: Keyword arguments forwarded to ``coro_func``.

    Returns:
        The value returned by the first successful attempt.

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        Exception: The exception from the last attempt, re-raised once all
            attempts have failed.
    """
    if max_retries < 1:
        # Otherwise no attempt runs and there is no exception to re-raise.
        raise ValueError(f"max_retries must be at least 1, got {max_retries!r}")

    for attempt in range(max_retries):
        try:
            return await coro_func(*args, **kwargs)
        except Exception as exc:
            if attempt == max_retries - 1:
                logger.error(f"All {max_retries} attempts failed")
                # Bare raise keeps the original exception and its traceback.
                raise
            wait_time = delay * (backoff**attempt)
            logger.warning(f"Attempt {attempt + 1}/{max_retries} failed: {exc}. Retrying in {wait_time}s...")
            await asyncio.sleep(wait_time)
|