"""Logging utilities for budgeted optimization experiments.

This module provides structured logging for optimization runs,
outputting CSV logs and JSON metadata for reproducibility and plotting.
"""

from __future__ import annotations

import csv
import json
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Sequence

import numpy as np

from moltenflow.utils.logging import get_logger

# Module-level logger, obtained from the project's get_logger helper and
# named after this module; shared by every class and function below.
logger = get_logger(__name__)


@dataclass
class StepRecord:
    """One row of the per-step optimization log.

    Every instance pairs the outcome of a single optimization step with the
    identifiers (method, seed, init) of the run that produced it.

    Attributes:
        step: Index of the oracle call, counted from 0
        method: Name of the proposer method
        seed: Random seed used by the run
        init: Name of the initialization method
        smiles: SMILES string obtained after decoding
        qed: QED score
        neg_sa: Negated SA score (-SA)
        valid: True when the molecule was valid
        hv: Hypervolume at this step
        hvi: Hypervolume improvement relative to the initial value
        cumulative_validity: Validity rate accumulated so far
    """

    # Field order matters: it defines the positional constructor signature
    # and the key order produced by dataclasses.asdict().
    step: int
    method: str
    seed: int
    init: str
    smiles: str
    qed: float
    neg_sa: float
    valid: bool
    hv: float
    hvi: float
    cumulative_validity: float


@dataclass
class RunMetadata:
    """Metadata for an optimization run.

    Attributes:
        method: Proposer method name
        init: Initialization method
        seed: Random seed
        budget: Total oracle budget
        n_init: Number of initial molecules
        ref_point: Reference point for HV computation
        sense: Optimization sense per objective
        timestamp: Run start time
    """

    method: str
    init: str
    seed: int
    budget: int
    n_init: int
    ref_point: list[float]
    sense: list[str] = None
    timestamp: str = None

    def __post_init__(self):
        if self.sense is None:
            self.sense = ["max", "max"]
        if self.timestamp is None:
            self.timestamp = datetime.now().isoformat()


class OptimizationLogger:
    """Logger for optimization experiments.

    Constructing the logger creates ``output_dir`` (including parents),
    (re)creates ``optimization_log.csv`` with a header row and writes
    ``run_metadata.json``. Afterwards it handles:
    - Per-step CSV logging (``log_step``)
    - Pareto front snapshots (``save_pareto_snapshot``)
    - A JSON run summary (``finalize``)

    Args:
        output_dir: Directory for output files
        method: Proposer method name
        init: Initialization method
        seed: Random seed
        budget: Oracle budget
        n_init: Initial dataset size
        ref_point: Reference point for HV
        pareto_snapshot_interval: Steps between Pareto snapshots (0 to
            disable). The value is only stored on the instance; the logger
            never triggers a snapshot itself, the caller decides when to
            call ``save_pareto_snapshot``.

    Example:
        >>> logger = OptimizationLogger(
        ...     output_dir="outputs/run_001",
        ...     method="moltenflow",
        ...     init="random",
        ...     seed=42,
        ...     budget=100,
        ... )
        >>> logger.log_step(step=0, smiles="CCO", qed=0.5, ...)
        >>> logger.finalize()
    """

    # CSV column names; these must stay in sync with the StepRecord fields
    # because rows are produced with dataclasses.asdict(record).
    CSV_COLUMNS = [
        "step",
        "method",
        "seed",
        "init",
        "smiles",
        "qed",
        "neg_sa",
        "valid",
        "hv",
        "hvi",
        "cumulative_validity",
    ]

    def __init__(
        self,
        output_dir: str | Path,
        method: str,
        init: str,
        seed: int,
        budget: int,
        n_init: int = 20,
        ref_point: Sequence[float] = (0.0, -10.0),
        pareto_snapshot_interval: int = 25,
    ) -> None:
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        self.method = method
        self.init = init
        self.seed = seed
        self.budget = budget
        self.n_init = n_init
        # Copied into a list so the default tuple (or any other sequence)
        # is stored in a JSON-friendly, mutable form.
        self.ref_point = list(ref_point)
        self.pareto_snapshot_interval = pareto_snapshot_interval

        # In-memory copy of every logged step, used by finalize()
        self._records: list[StepRecord] = []

        # Create metadata
        self.metadata = RunMetadata(
            method=method,
            init=init,
            seed=seed,
            budget=budget,
            n_init=n_init,
            ref_point=self.ref_point,
        )

        # Initialize CSV file (truncates any existing log in output_dir)
        self.csv_path = self.output_dir / "optimization_log.csv"
        self._init_csv()

        # Save metadata
        self._save_metadata()

        logger.info(f"Optimization logger initialized: {self.output_dir}")

    def _init_csv(self) -> None:
        """Create (or truncate) the CSV log and write the header row."""
        with open(self.csv_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=self.CSV_COLUMNS)
            writer.writeheader()

    def _save_metadata(self) -> None:
        """Save run metadata to ``run_metadata.json`` in the output directory."""
        metadata_path = self.output_dir / "run_metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(asdict(self.metadata), f, indent=2)

    def log_step(
        self,
        step: int,
        smiles: str,
        qed: float,
        neg_sa: float,
        valid: bool,
        hv: float,
        hvi: float,
        cumulative_validity: float,
    ) -> None:
        """Log a single optimization step.

        The step is kept in memory and appended as one row to the CSV log.
        The run-level columns (method, seed, init) are filled in from the
        values given to the constructor.

        Args:
            step: Oracle call index
            smiles: Decoded SMILES
            qed: QED value
            neg_sa: -SA value
            valid: Whether valid
            hv: Current HV
            hvi: HV improvement
            cumulative_validity: Running validity rate
        """
        record = StepRecord(
            step=step,
            method=self.method,
            seed=self.seed,
            init=self.init,
            smiles=smiles,
            qed=qed,
            neg_sa=neg_sa,
            valid=valid,
            hv=hv,
            hvi=hvi,
            cumulative_validity=cumulative_validity,
        )

        self._records.append(record)

        # The file is opened and closed on every call, so the row has been
        # handed to the OS before log_step returns.
        with open(self.csv_path, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=self.CSV_COLUMNS)
            writer.writerow(asdict(record))

        # NOTE: no Pareto snapshot is written here. Snapshots are saved
        # externally by the runner (see save_pareto_snapshot); the former
        # interval check at this point had an empty body and was removed.

    def save_pareto_snapshot(
        self,
        step: int,
        smiles: Sequence[str],
        objectives: np.ndarray,
    ) -> None:
        """Save a snapshot of the Pareto front.

        Writes ``pareto_t{step}.csv`` into the output directory with the
        columns ``smiles``, ``qed`` and ``neg_sa``; an existing file for the
        same step is overwritten.

        Args:
            step: Current step
            smiles: Pareto-optimal SMILES
            objectives: Objectives array (n, 2); row ``i`` belongs to
                ``smiles[i]``, column 0 is written as ``qed`` and column 1
                as ``neg_sa``
        """
        snapshot_path = self.output_dir / f"pareto_t{step}.csv"

        with open(snapshot_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["smiles", "qed", "neg_sa"])
            for i, smi in enumerate(smiles):
                writer.writerow([smi, objectives[i, 0], objectives[i, 1]])

        logger.debug(f"Saved Pareto snapshot at step {step}: {len(smiles)} molecules")

    def finalize(self) -> dict[str, Any]:
        """Finalize logging and return summary statistics.

        The summary is also written to ``summary.json`` in the output
        directory. If no step was logged, nothing is written and an empty
        dictionary is returned.

        Returns:
            Dictionary with summary statistics
        """
        if not self._records:
            return {}

        # The "final_*" values are taken from the most recently logged step
        final_record = self._records[-1]
        valid_records = [r for r in self._records if r.valid]

        summary = {
            "method": self.method,
            "init": self.init,
            "seed": self.seed,
            "budget": self.budget,
            "n_evaluated": len(self._records),
            "n_valid": len(valid_records),
            "final_hv": final_record.hv,
            "final_hvi": final_record.hvi,
            "final_validity": final_record.cumulative_validity,
        }

        # Save summary
        summary_path = self.output_dir / "summary.json"
        with open(summary_path, "w") as f:
            json.dump(summary, f, indent=2)

        logger.info(
            f"Optimization summary: HVI={summary['final_hvi']:.4f}, validity={summary['final_validity']:.2%}"
        )

        return summary

    def get_records(self) -> list[StepRecord]:
        """Return a shallow copy of the list of logged records."""
        return self._records.copy()


def load_optimization_log(log_path: str | Path) -> tuple[list[dict], dict]:
    """Load optimization log from CSV file.

    Args:
        log_path: Path to optimization_log.csv

    Returns:
        Tuple of (records list, metadata dict)
    """
    log_path = Path(log_path)

    # Load CSV
    records = []
    with open(log_path, "r") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Convert types
            row["step"] = int(row["step"])
            row["seed"] = int(row["seed"])
            row["qed"] = float(row["qed"])
            row["neg_sa"] = float(row["neg_sa"])
            row["valid"] = row["valid"].lower() == "true"
            row["hv"] = float(row["hv"])
            row["hvi"] = float(row["hvi"])
            row["cumulative_validity"] = float(row["cumulative_validity"])
            records.append(row)

    # Load metadata if available
    metadata_path = log_path.parent / "run_metadata.json"
    if metadata_path.exists():
        with open(metadata_path, "r") as f:
            metadata = json.load(f)
    else:
        metadata = {}

    return records, metadata


def load_experiment_logs(
    log_dir: str | Path,
    methods: Sequence[str] | None = None,
    seeds: Sequence[int] | None = None,
) -> dict[str, list[dict]]:
    """Load logs from multiple runs in an experiment directory.

    Expects directory structure:
        log_dir/
            method_init_seed/
                optimization_log.csv
                run_metadata.json

    Subdirectories without an ``optimization_log.csv``, and logs that
    contain no rows, are skipped. The method and seed filters are checked
    against the first record of each log only.

    Args:
        log_dir: Root directory containing run subdirectories
        methods: Filter to specific methods (default: all)
        seeds: Filter to specific seeds (default: all)

    Returns:
        Dictionary mapping run_id (the subdirectory name) to list of
        records, with keys inserted in sorted path order
    """
    log_dir = Path(log_dir)
    all_logs: dict[str, list[dict]] = {}

    # iterdir() yields entries in arbitrary order; sorting makes the key
    # order of the result reproducible across filesystems.
    for run_dir in sorted(log_dir.iterdir()):
        if not run_dir.is_dir():
            continue

        log_path = run_dir / "optimization_log.csv"
        if not log_path.exists():
            continue

        # Only the records are needed here; the metadata is discarded
        records, _ = load_optimization_log(log_path)
        if not records:
            continue

        # Run-level columns are repeated on every row, so the first record
        # is used to decide whether the run passes the filters.
        first = records[0]
        if methods is not None and first["method"] not in methods:
            continue
        if seeds is not None and first["seed"] not in seeds:
            continue

        all_logs[run_dir.name] = records

    logger.info(f"Loaded {len(all_logs)} optimization logs from {log_dir}")

    return all_logs
