"""Migration state management."""

import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional

from src.utils.logger import get_logger

logger = get_logger(__name__)


class MigrationState:
    """Manage migration state for incremental migrations.

    The state is a JSON object keyed by table name. Each value is a per-table
    dict that may contain ``last_timestamp``, ``last_updated`` and
    ``total_migrated``. Every mutating method persists the full state to
    ``state_file`` immediately.
    """

    DEFAULT_STATE_FILE = "migration_state.json"

    def __init__(self, state_file: str = DEFAULT_STATE_FILE):
        """Initialize migration state.

        Args:
            state_file: Path to state file. If it does not exist (or cannot be
                parsed), the state starts out empty.
        """
        self.state_file = Path(state_file)
        self.state = self._load_state()

    def _load_state(self) -> Dict[str, Any]:
        """Load state from file.

        Best-effort: a missing, unreadable, malformed or non-object state file
        yields an empty state (with a warning for the latter cases) rather
        than an exception.

        Returns:
            The decoded state mapping, or ``{}``.
        """
        if not self.state_file.exists():
            return {}
        try:
            with open(self.state_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Failed to load state file: {e}")
            return {}
        if not isinstance(data, dict):
            # Any other JSON type would break every later `.get` lookup.
            logger.warning(
                f"Failed to load state file: expected a JSON object, "
                f"got {type(data).__name__}"
            )
            return {}
        return data

    def _save_state(self) -> None:
        """Atomically save state to file.

        The state is written to a sibling temporary file, flushed to disk and
        then moved over the real file with ``os.replace``. A crash mid-write
        therefore never leaves a truncated/corrupt state file (which would
        otherwise be silently discarded on the next load, restarting the
        migration from scratch).

        Raises:
            Exception: Any error while writing is logged and re-raised.
        """
        tmp_path = self.state_file.with_name(self.state_file.name + ".tmp")
        try:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.state, f, indent=2)
                f.flush()
                os.fsync(f.fileno())
            os.replace(tmp_path, self.state_file)
        except Exception as e:
            logger.error(f"Failed to save state file: {e}")
            # Remove any partial temp file; the previous state file is intact.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _table_state(self, table: str) -> Dict[str, Any]:
        """Return the mutable per-table state dict, creating it if absent."""
        return self.state.setdefault(table, {})

    def get_last_timestamp(self, table: str) -> Optional[str]:
        """Get last migration timestamp for a table.

        Args:
            table: Table name

        Returns:
            ISO format timestamp or None if not found
        """
        return self.state.get(table, {}).get("last_timestamp")

    def set_last_timestamp(self, table: str, timestamp: str) -> None:
        """Set last migration timestamp for a table and persist the state.

        Also records ``last_updated`` as the current UTC time (naive ISO
        format, no offset suffix).

        Args:
            table: Table name
            timestamp: ISO format timestamp
        """
        entry = self._table_state(table)
        entry["last_timestamp"] = timestamp
        # datetime.utcnow() is deprecated (3.12+); strip tzinfo to keep the
        # stored string format identical to what utcnow().isoformat() produced.
        entry["last_updated"] = (
            datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        )
        self._save_state()

    def get_migration_count(self, table: str) -> int:
        """Get total migration count for a table.

        Args:
            table: Table name

        Returns:
            Total rows migrated (0 if the table has no recorded state)
        """
        return self.state.get(table, {}).get("total_migrated", 0)

    def increment_migration_count(self, table: str, count: int) -> None:
        """Increment migration count for a table and persist the state.

        Args:
            table: Table name
            count: Number of rows to add
        """
        entry = self._table_state(table)
        entry["total_migrated"] = entry.get("total_migrated", 0) + count
        self._save_state()

    def reset(self, table: Optional[str] = None) -> None:
        """Reset migration state and persist it.

        Args:
            table: Table name to reset, or None to reset all. An empty-string
                table name resets only that (empty-named) entry, not all
                tables.
        """
        if table is not None:
            self.state[table] = {}
        else:
            self.state = {}
        self._save_state()