"""Data loading utilities for sweep results."""

from pathlib import Path

import polars as pl
import yaml

from sweep import SweepConfig, get_experiment_configs

# Default path for cached parquet data. It is the default `path` argument of
# save_sweep_data.
# NOTE(review): load_sweep_data does not fall back to this file; it uses the
# config-specific path from get_parquet_path. Confirm that the two defaults
# are meant to differ.
DEFAULT_PARQUET_PATH = Path("assets/sweep_data.parquet")


def get_parquet_path(sweep_config_path: str) -> Path:
    """Derive the parquet cache location that belongs to a sweep config.

    The cache file is named after the stem of the config file, so
    ``configs/sweep/default.yaml`` maps to
    ``assets/sweep_data_default.parquet``.

    Args:
        sweep_config_path: Path to the sweep config YAML.

    Returns:
        Path of the parquet cache file for that config.
    """
    # Only the file stem matters; the directory and extension are dropped.
    cache_file = f"assets/sweep_data_{Path(sweep_config_path).stem}.parquet"
    return Path(cache_file)


def save_sweep_data(df: pl.DataFrame, path: Path = DEFAULT_PARQUET_PATH) -> None:
    """Write sweep data to a parquet file so it can be shared.

    Any missing parent directories of ``path`` are created before writing.

    Args:
        df: DataFrame to write.
        path: Destination of the parquet file.
    """
    target_dir = path.parent
    # exist_ok keeps repeated saves into the same directory from failing.
    target_dir.mkdir(parents=True, exist_ok=True)
    df.write_parquet(path)


def load_sweep_data(
    sweep_config_path: str = "configs/sweep/default.yaml",
    parquet_path: Path | None = None,
    force_logs: bool = False,
) -> pl.DataFrame:
    """Load all experiment results from a sweep into a single DataFrame.

    Results are read from the .logs directory first. If none are found
    (including when the sweep config file itself does not exist), the cached
    parquet file is used instead, unless ``force_logs`` is set.

    Args:
        sweep_config_path: Path to the sweep config YAML.
        parquet_path: Path to fallback parquet file. If None, uses the
            config-specific path from ``get_parquet_path``.
        force_logs: If True, only load from .logs (no parquet fallback).

    Returns:
        DataFrame with columns: name, value, t, type, g, batch_size, tol,
        output_dim, top_k, seed, eid. When the parquet fallback is used, the
        frame is returned exactly as stored in that file.

    Raises:
        ValueError: If no results are found in .logs and the parquet fallback
            is either disabled via ``force_logs`` or does not exist.
    """
    # Try loading from .logs first
    logs_df = _load_from_logs(sweep_config_path)
    if logs_df is not None:
        return logs_df

    if force_logs:
        # The fallback was never consulted, so the message must not claim
        # that a parquet file was checked.
        raise ValueError(
            f"No experiment results found for sweep: {sweep_config_path}. "
            "Parquet fallback was skipped because force_logs=True."
        )

    # Resolve the fallback location only once it is actually needed.
    if parquet_path is None:
        parquet_path = get_parquet_path(sweep_config_path)
    else:
        # Tolerate plain string paths as well as Path objects.
        parquet_path = Path(parquet_path)

    if parquet_path.exists():
        return pl.read_parquet(parquet_path)

    raise ValueError(
        f"No experiment results found for sweep: {sweep_config_path}. "
        f"Also checked parquet fallback at: {parquet_path}"
    )


def _load_from_logs(sweep_config_path: str) -> pl.DataFrame | None:
    """Load sweep data from .logs directory.

    The sweep config is expanded into its experiment configs, and for each
    experiment the file ``.logs/<experiment id>/metrics.jsonl`` is read if it
    exists. Experiments without a metrics file are skipped.

    Args:
        sweep_config_path: Path to the sweep config YAML.

    Returns:
        DataFrame if results found, None otherwise. None is also returned
        when the sweep config file does not exist.
    """
    config_path = Path(sweep_config_path)
    if not config_path.exists():
        return None

    # Read the YAML as UTF-8 explicitly so parsing does not depend on the
    # platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        sweep_config_dict = yaml.safe_load(f)
    config = SweepConfig.from_dict(sweep_config_dict)

    dfs = []
    # Iterate the experiment configs directly; no intermediate list is needed.
    for exp in get_experiment_configs(config):
        path = Path(f".logs/{exp.id}/metrics.jsonl")

        if not path.exists():
            continue

        _df = pl.read_ndjson(path)

        # Add experiment parameters as columns
        _df = _df.with_columns(
            pl.lit(exp.calibrator_config["batch_size"]).alias("batch_size"),
            pl.lit(exp.calibrator_config["tol"]).alias("tol"),
            pl.lit(exp.calibrator_config["output_dim"]).alias("output_dim"),
            # top_k is optional in the dataset config and defaults to 1.
            pl.lit(exp.dataset_config.get("top_k", 1)).alias("top_k"),
            pl.lit(exp.seed).alias("seed"),
            pl.lit(exp.id).alias("eid"),
        )
        dfs.append(_df)

    if not dfs:
        return None

    # Diagonal concat takes the union of columns across experiments, so
    # metrics files need not all share the same set of columns.
    return pl.concat(dfs, how="diagonal")


def compute_mvs_summary(df: pl.DataFrame) -> pl.DataFrame:
    """Summarise the ``mvs`` metric with one row per experiment.

    Aggregated columns, all taken from rows where ``name == "mvs"``:
    - oracle_mvs: first value among rows of type ``oracle``
    - best_grid_mvs: maximum value among rows of type ``grid``
    - ew_final_mvs: value of type ``ew`` at the largest ``t``
    - ew_pre_mvs: value of type ``ew`` at the smallest ``t``

    Derived columns compare ``ew_final_mvs`` against the oracle and against
    two baselines: ``best_grid_mvs`` (suffix ``_bestgrid``) and
    ``ew_pre_mvs`` (suffix ``_self``). The oracle gap and oracle ratio do not
    depend on the baseline, so they carry the same values under both suffixes.

    Args:
        df: Raw metrics DataFrame from load_sweep_data.

    Returns:
        Summary DataFrame with one row per experiment. Aggregates are null
        for experiments that have no rows of the corresponding type.
    """
    value = pl.col("value")
    step = pl.col("t")
    row_type = pl.col("type")

    mvs_rows = df.filter(pl.col("name") == "mvs")

    # Oracle value per experiment.
    oracle_per_exp = (
        mvs_rows.filter(row_type == "oracle")
        .group_by("eid")
        .agg(value.first().alias("oracle_mvs"))
    )

    # Largest value over all grid rows of an experiment.
    grid_per_exp = (
        mvs_rows.filter(row_type == "grid")
        .group_by("eid")
        .agg(value.max().alias("best_grid_mvs"))
    )

    # Equally-weighted value at the last and at the first recorded step.
    ew_per_exp = (
        mvs_rows.filter(row_type == "ew")
        .group_by("eid")
        .agg(
            value.filter(step == step.max()).first().alias("ew_final_mvs"),
            value.filter(step == step.min()).first().alias("ew_pre_mvs"),
        )
    )

    # Start from one parameter row per experiment and attach each aggregate.
    summary = df.select(
        "eid", "batch_size", "tol", "output_dim", "top_k", "seed"
    ).unique(subset=["eid"])
    for per_exp in (oracle_per_exp, grid_per_exp, ew_per_exp):
        summary = summary.join(per_exp, on="eid", how="left")

    # Shared building blocks for the derived metrics.
    final = pl.col("ew_final_mvs")
    oracle = pl.col("oracle_mvs")
    grid = pl.col("best_grid_mvs")
    pre = pl.col("ew_pre_mvs")

    oracle_ratio = final / oracle * 100
    oracle_gap = oracle - final
    gain_vs_grid = final - grid
    gain_vs_pre = final - pre

    return summary.with_columns(
        oracle_ratio.alias("mvs_oracle_ratio"),
        # Relative to the best grid value
        gain_vs_grid.alias("mvs_diff_bestgrid"),
        (gain_vs_grid / grid * 100).alias("mvs_pct_change_bestgrid"),
        oracle_gap.alias("mvs_oracle_gap_bestgrid"),
        (gain_vs_grid / (oracle - grid) * 100).alias("mvs_gap_closure_pct_bestgrid"),
        oracle_ratio.alias("mvs_oracle_ratio_bestgrid"),
        # Relative to the experiment's own first equally-weighted value
        gain_vs_pre.alias("mvs_diff_self"),
        (gain_vs_pre / pre * 100).alias("mvs_pct_change_self"),
        oracle_gap.alias("mvs_oracle_gap_self"),
        (gain_vs_pre / (oracle - pre) * 100).alias("mvs_gap_closure_pct_self"),
        oracle_ratio.alias("mvs_oracle_ratio_self"),
    )


def compute_mse_summary(df: pl.DataFrame) -> pl.DataFrame:
    """Summarise the ``mse`` metric with one row per experiment.

    Aggregated columns, taken from rows where ``name == "mse"``:
    - mse_first: value at the smallest ``t`` of the experiment
    - mse_last: value at the largest ``t`` of the experiment

    Derived columns:
    - mse_diff: ``mse_last - mse_first`` (negative means improvement)
    - mse_pct_change: ``mse_diff`` as a percentage of ``mse_first``

    Args:
        df: Raw metrics DataFrame from load_sweep_data.

    Returns:
        Summary DataFrame with one row per experiment.
    """
    value = pl.col("value")
    step = pl.col("t")

    # First and last recorded MSE of every experiment in a single aggregation.
    endpoints = (
        df.filter(pl.col("name") == "mse")
        .group_by("eid")
        .agg(
            value.filter(step == step.min()).first().alias("mse_first"),
            value.filter(step == step.max()).first().alias("mse_last"),
        )
    )

    # One parameter row per experiment, extended with the two endpoints.
    per_experiment = df.select(
        "eid", "batch_size", "tol", "output_dim", "top_k", "seed"
    ).unique(subset=["eid"])
    summary = per_experiment.join(endpoints, on="eid", how="left")

    change = pl.col("mse_last") - pl.col("mse_first")
    return summary.with_columns(
        change.alias("mse_diff"),
        (change / pl.col("mse_first") * 100).alias("mse_pct_change"),
    )


def compute_full_summary(df: pl.DataFrame) -> pl.DataFrame:
    """Combine the MVS and MSE summaries into one row per experiment.

    Args:
        df: Raw metrics DataFrame from load_sweep_data.

    Returns:
        The MVS summary extended with the MSE-specific columns
        (mse_first, mse_last, mse_diff, mse_pct_change), matched on ``eid``.
    """
    # The experiment parameter columns already come with the MVS summary, so
    # only the key and the MSE metrics are taken from the MSE summary.
    mse_columns = ["eid", "mse_first", "mse_last", "mse_diff", "mse_pct_change"]

    mvs_part = compute_mvs_summary(df)
    mse_part = compute_mse_summary(df).select(mse_columns)

    return mvs_part.join(mse_part, on="eid", how="left")
