"""SQLite-backed experiment database with atomic task claiming."""

import json
import os
import sqlite3
import socket
import hashlib
from datetime import datetime
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from enum import Enum


class ExperimentStatus(str, Enum):
    """Lifecycle state of an experiment row.

    The ``str`` mixin makes each member compare equal to its plain string
    value. The ``.value`` strings are what ``ExperimentDB`` writes to, and
    reads back from, the ``status`` column of the ``experiments`` table.
    """
    # Initial state assigned by ExperimentDB.insert_experiment.
    PENDING = "pending"
    # Set by ExperimentDB.claim_pending_task when a worker takes the task.
    RUNNING = "running"
    # Set by ExperimentDB.mark_completed.
    COMPLETED = "completed"
    # Set by ExperimentDB.mark_failed.
    FAILED = "failed"


@dataclass
class Experiment:
    """In-memory view of one row of the ``experiments`` table.

    A row is identified by the pair ``(config_hash, seed)``, which the table
    declares UNIQUE. The optional fields stay ``None`` until the matching
    ``ExperimentDB`` method fills them in.
    """
    # Primary key of the row in the experiments table.
    id: int
    # Value of compute_config_hash(config), stored when the row is inserted.
    config_hash: str
    # Configuration dictionary, decoded from the row's config_json column.
    config: dict
    # Random seed for this run.
    seed: int
    # Current lifecycle state of the experiment.
    status: ExperimentStatus
    # Identifier of the worker that claimed the task (set by claim_pending_task).
    worker_id: Optional[str] = None
    # Time at which the task was claimed.
    started_at: Optional[datetime] = None
    # Time at which mark_completed or mark_failed recorded the outcome.
    completed_at: Optional[datetime] = None
    # Error text recorded by mark_failed.
    error_message: Optional[str] = None
    # Optional model path recorded by mark_completed.
    best_model_path: Optional[str] = None
    # Optional run id recorded by mark_completed.
    wandb_run_id: Optional[str] = None


@dataclass
class Evaluation:
    """In-memory view of one row of the ``evaluations`` table.

    The table keeps at most one row per ``(experiment_id, epoch)`` pair.
    """
    # Primary key of the row in the evaluations table.
    id: int
    # id of the experiment these metrics belong to.
    experiment_id: int
    # Epoch at which the metrics were logged.
    epoch: int
    # Metrics dictionary, decoded from the row's metrics_json column.
    metrics: dict
    # Time at which the row was written.
    created_at: datetime


def compute_config_hash(config: dict) -> str:
    """Return a short fingerprint of ``config`` that ignores key order.

    The configuration is rebuilt with its dictionary keys in sorted order
    (descending into nested dicts and lists), serialised to JSON, and hashed
    with SHA-256. Only the first 16 hex characters of the digest are kept.

    Args:
        config: Configuration dictionary to fingerprint.

    Returns:
        A 16-character lowercase hexadecimal string.
    """
    def _canonical(node):
        # Lists keep their element order; only their contents are normalised.
        if isinstance(node, list):
            return [_canonical(item) for item in node]
        # Anything that is not a container is returned untouched.
        if not isinstance(node, dict):
            return node
        return {key: _canonical(value) for key, value in sorted(node.items())}

    payload = json.dumps(_canonical(config), sort_keys=True).encode()
    digest = hashlib.sha256(payload).hexdigest()
    return digest[:16]


class ExperimentDB:
    """
    SQLite-backed experiment database with support for distributed workers.

    Every method opens its own short-lived connection and closes it before
    returning. Connections use WAL journal mode and autocommit; the
    multi-statement operations (schema creation and task claiming) wrap their
    work in an explicit transaction so that concurrent workers cannot claim
    the same task.
    """

    # Column list shared by every query that builds a full Experiment.
    _EXPERIMENT_COLUMNS = (
        "id, config_hash, config_json, seed, status, "
        "worker_id, started_at, completed_at, error_message, "
        "best_model_path, wandb_run_id"
    )

    def __init__(self, db_path: str | Path):
        """
        Initialize the experiment database.

        Creates the parent directory of ``db_path`` if needed and makes sure
        the schema exists.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _get_connection(self) -> sqlite3.Connection:
        """Open a new connection configured for concurrent access.

        Returns:
            A connection in autocommit mode with ``sqlite3.Row`` rows, WAL
            journaling and a 30 second busy timeout. The caller must close it.
        """
        conn = sqlite3.connect(
            str(self.db_path),
            timeout=30.0,  # Wait up to 30s for locks
            isolation_level=None,  # Autocommit mode, we manage transactions manually
        )
        try:
            conn.row_factory = sqlite3.Row
            # Enable WAL mode for better concurrent access
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA busy_timeout=30000")  # 30s busy timeout
        except Exception:
            # Do not leak the handle when the connection cannot be configured.
            conn.close()
            raise
        return conn

    @staticmethod
    def _rollback_if_active(conn: sqlite3.Connection) -> None:
        """Roll back the open transaction on ``conn``, if there is one.

        An unconditional ROLLBACK raises when no transaction is active (for
        example when BEGIN itself failed, or the error happened after COMMIT),
        which would hide the exception that is actually being handled.
        """
        if conn.in_transaction:
            conn.execute("ROLLBACK")

    @staticmethod
    def _row_to_experiment(row: sqlite3.Row) -> Experiment:
        """Build an Experiment from a row selected with ``_EXPERIMENT_COLUMNS``."""
        def parse_timestamp(text: Optional[str]) -> Optional[datetime]:
            # NULL (or empty) timestamp columns map to None.
            return datetime.fromisoformat(text) if text else None

        return Experiment(
            id=row["id"],
            config_hash=row["config_hash"],
            config=json.loads(row["config_json"]),
            seed=row["seed"],
            status=ExperimentStatus(row["status"]),
            worker_id=row["worker_id"],
            started_at=parse_timestamp(row["started_at"]),
            completed_at=parse_timestamp(row["completed_at"]),
            error_message=row["error_message"],
            best_model_path=row["best_model_path"],
            wandb_run_id=row["wandb_run_id"],
        )

    def _init_db(self) -> None:
        """Create the tables and indices if they do not exist yet.

        All statements run inside one ``BEGIN EXCLUSIVE`` transaction, so the
        schema is created completely or not at all.
        """
        conn = self._get_connection()
        try:
            conn.execute("BEGIN EXCLUSIVE")

            # Experiments table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS experiments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    config_hash TEXT NOT NULL,
                    config_json TEXT NOT NULL,
                    seed INTEGER NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    worker_id TEXT,
                    started_at TEXT,
                    completed_at TEXT,
                    error_message TEXT,
                    best_model_path TEXT,
                    wandb_run_id TEXT,
                    created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(config_hash, seed)
                )
            """)

            # Evaluations table (time-series metrics)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS evaluations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    experiment_id INTEGER NOT NULL,
                    epoch INTEGER NOT NULL,
                    metrics_json TEXT NOT NULL,
                    created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (experiment_id) REFERENCES experiments(id),
                    UNIQUE(experiment_id, epoch)
                )
            """)

            # Indices for faster queries
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_experiments_status
                ON experiments(status)
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_experiments_config_hash
                ON experiments(config_hash)
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_evaluations_experiment
                ON evaluations(experiment_id)
            """)

            conn.execute("COMMIT")
        except Exception:
            self._rollback_if_active(conn)
            raise
        finally:
            conn.close()

    def insert_experiment(
        self,
        config: dict,
        seed: int,
        skip_existing: bool = True
    ) -> Optional[int]:
        """
        Insert a new experiment configuration in the pending state.

        Args:
            config: Experiment configuration dictionary.
            seed: Random seed for this run.
            skip_existing: If True, look up (config_hash, seed) first and
                return early when it already exists. If False the lookup is
                skipped; a duplicate is then rejected by the table's UNIQUE
                constraint and still results in None.

        Returns:
            The new experiment ID, or None if the (config_hash, seed) pair
            already exists.
        """
        config_hash = compute_config_hash(config)
        config_json = json.dumps(config, sort_keys=True)

        conn = self._get_connection()
        try:
            if skip_existing:
                existing = conn.execute(
                    "SELECT id FROM experiments WHERE config_hash = ? AND seed = ?",
                    (config_hash, seed),
                ).fetchone()
                if existing:
                    return None

            cursor = conn.execute(
                """
                INSERT INTO experiments (config_hash, config_json, seed, status)
                VALUES (?, ?, ?, ?)
                """,
                (config_hash, config_json, seed, ExperimentStatus.PENDING.value),
            )
            return cursor.lastrowid
        except sqlite3.IntegrityError:
            # The pair already exists, e.g. another process inserted it
            # between the lookup above and the INSERT.
            return None
        finally:
            conn.close()

    def claim_pending_task(self, worker_id: Optional[str] = None) -> Optional[Experiment]:
        """
        Atomically claim a pending task for processing.

        The select-then-update pair runs inside a ``BEGIN IMMEDIATE``
        transaction, which takes the database write lock up front so two
        workers cannot claim the same task.

        Tasks are claimed in order by seed first, then randomly within each seed.
        This ensures good coverage: all unique configurations get at least one run
        (seed 0) before any configuration gets a second run (seed 1), etc., while
        maintaining diversity through random ordering within each seed group.

        Args:
            worker_id: Identifier for the worker claiming the task.
                       Defaults to "<hostname>-<pid>".

        Returns:
            The claimed Experiment (status RUNNING), or None if no pending tasks.
        """
        if worker_id is None:
            worker_id = f"{socket.gethostname()}-{os.getpid()}"

        conn = self._get_connection()
        try:
            conn.execute("BEGIN IMMEDIATE")

            # Lowest seed first; RANDOM() breaks ties within a seed group.
            row = conn.execute(
                """
                SELECT id, config_hash, config_json, seed, status
                FROM experiments
                WHERE status = ?
                ORDER BY seed, RANDOM()
                LIMIT 1
                """,
                (ExperimentStatus.PENDING.value,),
            ).fetchone()

            if row is None:
                conn.execute("ROLLBACK")
                return None

            # Claim it
            now = datetime.now().isoformat()
            conn.execute(
                """
                UPDATE experiments
                SET status = ?, worker_id = ?, started_at = ?
                WHERE id = ?
                """,
                (ExperimentStatus.RUNNING.value, worker_id, now, row["id"]),
            )

            conn.execute("COMMIT")

            return Experiment(
                id=row["id"],
                config_hash=row["config_hash"],
                config=json.loads(row["config_json"]),
                seed=row["seed"],
                status=ExperimentStatus.RUNNING,
                worker_id=worker_id,
                started_at=datetime.fromisoformat(now),
            )
        except Exception:
            # The failure may have happened before BEGIN succeeded or after
            # COMMIT; only roll back when a transaction is really open so the
            # original exception is the one that propagates.
            self._rollback_if_active(conn)
            raise
        finally:
            conn.close()

    def mark_completed(
        self,
        experiment_id: int,
        best_model_path: Optional[str] = None,
        wandb_run_id: Optional[str] = None
    ) -> None:
        """Mark an experiment as completed.

        Sets the status to completed, stamps ``completed_at`` with the current
        local time and stores the two optional identifiers (NULL when omitted).
        Nothing happens if no row has the given id.

        Args:
            experiment_id: ID of the experiment to update.
            best_model_path: Value stored in the best_model_path column.
            wandb_run_id: Value stored in the wandb_run_id column.
        """
        conn = self._get_connection()
        try:
            conn.execute(
                """
                UPDATE experiments
                SET status = ?, completed_at = ?, best_model_path = ?, wandb_run_id = ?
                WHERE id = ?
                """,
                (
                    ExperimentStatus.COMPLETED.value,
                    datetime.now().isoformat(),
                    best_model_path,
                    wandb_run_id,
                    experiment_id,
                ),
            )
        finally:
            conn.close()

    def mark_failed(self, experiment_id: int, error_message: str) -> None:
        """Mark an experiment as failed with an error message.

        Sets the status to failed, stamps ``completed_at`` with the current
        local time and stores ``error_message``. Nothing happens if no row has
        the given id.
        """
        conn = self._get_connection()
        try:
            conn.execute(
                """
                UPDATE experiments
                SET status = ?, completed_at = ?, error_message = ?
                WHERE id = ?
                """,
                (
                    ExperimentStatus.FAILED.value,
                    datetime.now().isoformat(),
                    error_message,
                    experiment_id,
                ),
            )
        finally:
            conn.close()

    def log_evaluation(
        self,
        experiment_id: int,
        epoch: int,
        metrics: dict
    ) -> None:
        """
        Log evaluation metrics for an experiment at a specific epoch.

        Uses INSERT OR REPLACE, so logging the same (experiment_id, epoch)
        again overwrites the earlier row instead of failing.

        Args:
            experiment_id: ID of the experiment the metrics belong to.
            epoch: Epoch the metrics were measured at.
            metrics: JSON-serialisable metrics dictionary.
        """
        # NOTE(review): the schema declares a foreign key on experiment_id,
        # but this class never issues "PRAGMA foreign_keys=ON", so SQLite is
        # presumably not enforcing it here - confirm whether that is intended.
        conn = self._get_connection()
        try:
            conn.execute(
                """
                INSERT OR REPLACE INTO evaluations (experiment_id, epoch, metrics_json, created_at)
                VALUES (?, ?, ?, ?)
                """,
                (experiment_id, epoch, json.dumps(metrics), datetime.now().isoformat()),
            )
        finally:
            conn.close()

    def get_experiment(self, experiment_id: int) -> Optional[Experiment]:
        """Get an experiment by ID, or None if no such row exists."""
        conn = self._get_connection()
        try:
            row = conn.execute(
                f"SELECT {self._EXPERIMENT_COLUMNS} FROM experiments WHERE id = ?",
                (experiment_id,),
            ).fetchone()
            if row is None:
                return None
            return self._row_to_experiment(row)
        finally:
            conn.close()

    def get_evaluations(self, experiment_id: int) -> list[Evaluation]:
        """Get all evaluations for an experiment, ordered by epoch."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                SELECT id, experiment_id, epoch, metrics_json, created_at
                FROM evaluations
                WHERE experiment_id = ?
                ORDER BY epoch
                """,
                (experiment_id,),
            )
            return [
                Evaluation(
                    id=row["id"],
                    experiment_id=row["experiment_id"],
                    epoch=row["epoch"],
                    metrics=json.loads(row["metrics_json"]),
                    created_at=datetime.fromisoformat(row["created_at"]),
                )
                for row in cursor.fetchall()
            ]
        finally:
            conn.close()

    def get_status_summary(self) -> dict[str, int]:
        """Get counts of experiments by status.

        Returns:
            A dict with one entry per ExperimentStatus value (0 when no row
            has that status) plus a "total" entry holding the sum.
        """
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                SELECT status, COUNT(*) as count
                FROM experiments
                GROUP BY status
                """
            )
            # Start from zero so statuses without rows still appear.
            summary = {s.value: 0 for s in ExperimentStatus}
            for row in cursor.fetchall():
                summary[row["status"]] = row["count"]
            summary["total"] = sum(summary.values())
            return summary
        finally:
            conn.close()

    def reset_failed_to_pending(self) -> int:
        """Reset all failed experiments to pending. Returns count reset.

        Also clears worker_id, started_at, completed_at and error_message on
        the affected rows.
        """
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                UPDATE experiments
                SET status = ?, worker_id = NULL, started_at = NULL,
                    completed_at = NULL, error_message = NULL
                WHERE status = ?
                """,
                (ExperimentStatus.PENDING.value, ExperimentStatus.FAILED.value),
            )
            return cursor.rowcount
        finally:
            conn.close()

    def reset_running_to_pending(self) -> int:
        """
        Reset all running experiments to pending.

        Useful for recovering from interrupted runs. Every running row is
        reset, and its worker_id and started_at are cleared.

        Returns count reset.
        """
        conn = self._get_connection()
        try:
            cursor = conn.execute(
                """
                UPDATE experiments
                SET status = ?, worker_id = NULL, started_at = NULL
                WHERE status = ?
                """,
                (ExperimentStatus.PENDING.value, ExperimentStatus.RUNNING.value),
            )
            return cursor.rowcount
        finally:
            conn.close()

    def get_all_experiments(
        self,
        status: Optional[ExperimentStatus] = None
    ) -> list[Experiment]:
        """Get all experiments ordered by id, optionally filtered by status."""
        query = f"SELECT {self._EXPERIMENT_COLUMNS} FROM experiments"
        params: tuple = ()
        if status is not None:
            query += " WHERE status = ?"
            params = (status.value,)
        query += " ORDER BY id"

        conn = self._get_connection()
        try:
            rows = conn.execute(query, params).fetchall()
            return [self._row_to_experiment(row) for row in rows]
        finally:
            conn.close()

    def export_results(self) -> list[dict]:
        """
        Export all completed experiments with their final evaluation.

        Each result holds the experiment's identifying fields, its top-level
        config entries under "config.<key>", and, when at least one evaluation
        was logged, "final_epoch" plus the metrics of the highest epoch under
        "final.<key>".

        Returns a list of dicts suitable for conversion to DataFrame/JSON.
        """
        results = []

        for exp in self.get_all_experiments(ExperimentStatus.COMPLETED):
            # Flatten config into result
            result = {
                "experiment_id": exp.id,
                "config_hash": exp.config_hash,
                "seed": exp.seed,
                "best_model_path": exp.best_model_path,
                "wandb_run_id": exp.wandb_run_id,
                "started_at": exp.started_at.isoformat() if exp.started_at else None,
                "completed_at": exp.completed_at.isoformat() if exp.completed_at else None,
                **{f"config.{k}": v for k, v in exp.config.items()},
            }

            # Evaluations come back ordered by epoch, so the last one is final.
            evaluations = self.get_evaluations(exp.id)
            if evaluations:
                final_eval = evaluations[-1]
                result["final_epoch"] = final_eval.epoch
                for k, v in final_eval.metrics.items():
                    result[f"final.{k}"] = v

            results.append(result)

        return results

