#!/usr/bin/env python3
"""Run CARESL gamma tuning on injected-bias judge outputs and summarize metrics."""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import scipy as sp
from sklearn.metrics import mean_absolute_error as mae
from sklearn.model_selection import train_test_split

# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Source directory added to the import path below.
# NOTE(review): presumably where the pgm_tools package lives -- confirm.
SRC_ROOT = REPO_ROOT / "src"
# Only append when absent so the entry is never duplicated in sys.path.
if str(SRC_ROOT) not in sys.path:
    sys.path.append(str(SRC_ROOT))

from pgm_tools import caresl_aggregate, sanitize_correlation, ws_aggregate, uws_aggregate, majority_vote  # noqa: E402


# Maps a bias folder name (the sub-directory of the judge root read for that
# bias) to the label written to the "Bias" column of the summary table.
BIAS_NAME_MAP: Dict[str, str] = {
    "rich_content": "Beauty Bias",
    "factual_error": "Fallacy Oversight Bias",
    "gender": "Gender Bias",
    "reference": "Authority Bias",
}

# Biases processed when --biases is not supplied: every key of BIAS_NAME_MAP.
DEFAULT_BIASES: Tuple[str, ...] = tuple(BIAS_NAME_MAP.keys())
# Gamma candidates searched when --gamma-grid is not supplied.
DEFAULT_GAMMA_GRID: Tuple[float, ...] = (0.1, 0.2, 0.5, 1, 2, 3, 5, 7, 10)
# Column removed from both score tables before analysis. Columns are named
# after CSV file stems, so this refers to the CSV file with this stem.
# NOTE(review): the reason this column is excluded is not recorded here.
DROP_COLUMN: str = "Phi-4-mini-instruct"


def prepare_judgement(folder: Path) -> pd.DataFrame:
    """Collect the ``parsed_output`` column of every CSV in ``folder``.

    Files are visited in sorted path order. Each CSV contributes one column
    named after the file stem; only its ``parsed_output`` column is read.

    Args:
        folder: Directory searched (non-recursively) for ``*.csv`` files.

    Returns:
        A DataFrame with one column per CSV file; empty when the folder
        contains no CSV files.
    """
    columns = {}
    for csv_file in sorted(folder.glob("*.csv")):
        frame = pd.read_csv(csv_file, usecols=["parsed_output"])
        columns[csv_file.stem] = frame["parsed_output"]
    return pd.DataFrame(columns)


def is_valid_score(value: object, *, min_rating: float = 1.0, max_rating: float = 10.0) -> bool:
    """Tell whether ``value`` converts to a float inside the rating range.

    Args:
        value: Candidate score; anything ``float()`` accepts.
        min_rating: Inclusive lower bound.
        max_rating: Inclusive upper bound.

    Returns:
        ``True`` when the converted value lies within both bounds. ``False``
        when conversion raises ``TypeError``/``ValueError`` or the value is
        out of range (NaN fails both comparisons and is therefore rejected).
    """
    try:
        score = float(value)
    except (TypeError, ValueError):
        return False
    # Two positive comparisons (rather than a negated "outside" test) so NaN,
    # which compares False against everything, is rejected.
    at_least_min = score >= min_rating
    at_most_max = score <= max_rating
    return at_least_min and at_most_max


def map_valid_score(raw_df: pd.DataFrame, bias_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Keep only the rows whose scores are valid in both tables.

    A row is kept when every cell of that row passes ``is_valid_score`` in
    ``raw_df`` and in ``bias_df``. Rows are matched by index label.

    Args:
        raw_df: Scores for the unmodified inputs, one column per judge.
        bias_df: Scores for the bias-injected inputs, same row labels.

    Returns:
        ``(raw, bias)`` restricted to the shared valid rows, each with a fresh
        0..n-1 index so the two frames stay positionally paired.
    """
    raw_mask = raw_df.map(is_valid_score).all(axis=1)
    bias_mask = bias_df.map(is_valid_score).all(axis=1)
    # ``&`` aligns the two masks on the union of their labels; a label missing
    # from either side ends up False. Select from the combined mask's own
    # index: indexing ``raw_df.index`` with it raises IndexError whenever
    # ``bias_df`` carries labels that ``raw_df`` does not (mask is longer).
    both_valid = raw_mask & bias_mask
    valid_idx = both_valid.index[both_valid.to_numpy(dtype=bool)]
    return (
        raw_df.loc[valid_idx].reset_index(drop=True),
        bias_df.loc[valid_idx].reset_index(drop=True),
    )


def standardize(values: Iterable[float]) -> np.ndarray:
    """Z-score ``values`` using the population standard deviation.

    Args:
        values: Any iterable of numbers; it is consumed once.

    Returns:
        A float array of ``(x - mean) / std``. When the standard deviation is
        exactly zero (all values equal) an all-zero array of the same shape is
        returned instead of dividing by zero.
    """
    data = np.asarray(list(values), dtype=float)
    center = data.mean()
    spread = data.std()
    if spread == 0:
        return np.zeros_like(data)
    deviations = data - center
    return deviations / spread


def compute_metrics(raw_scores: np.ndarray, bias_scores: np.ndarray) -> Dict[str, float]:
    """Compare two score vectors with error and correlation metrics.

    Args:
        raw_scores: Aggregated scores for the unmodified inputs.
        bias_scores: Aggregated scores for the bias-injected inputs.

    Returns:
        A dict with, in this order: ``mae`` (mean absolute error on the given
        values), ``mae_std`` (the same error after each vector is passed
        through ``standardize``), ``pearson_r``, ``spearman_rho`` and
        ``kendall_tau``.
    """
    reference = np.asarray(raw_scores, dtype=float)
    shifted = np.asarray(bias_scores, dtype=float)

    absolute_error = mae(reference, shifted)
    standardized_error = mae(standardize(reference), standardize(shifted))
    # Off-diagonal entry of the 2x2 correlation matrix is the Pearson r.
    correlation_matrix = np.corrcoef(reference, shifted)
    spearman = sp.stats.spearmanr(reference, shifted)
    kendall = sp.stats.kendalltau(reference, shifted)

    return {
        "mae": float(absolute_error),
        "mae_std": float(standardized_error),
        "pearson_r": float(correlation_matrix[0, 1]),
        "spearman_rho": float(spearman.correlation),
        "kendall_tau": float(kendall.correlation),
    }


def tune_gamma(
    raw_train: pd.DataFrame,
    bias_train: pd.DataFrame,
    raw_val: pd.DataFrame,
    bias_val: pd.DataFrame,
    gamma_grid: Iterable[float],
    penalize_dependency: bool,
    solver_kwargs: Dict[str, object],
) -> float:
    """Select the gamma whose aggregates differ least between raw and biased data.

    For every candidate, ``caresl_aggregate`` is fitted separately on the raw
    and on the biased training scores; the returned weights are then applied
    to the matching validation frame. The candidate with the smallest mean
    absolute error between the two validation aggregates wins, and the smaller
    gamma wins when two errors are numerically equal.

    Args:
        raw_train: Training scores for the unmodified inputs.
        bias_train: Training scores for the bias-injected inputs.
        raw_val: Validation scores for the unmodified inputs.
        bias_val: Validation scores for the bias-injected inputs.
        gamma_grid: Candidate gamma values; any iterable, consumed once.
        penalize_dependency: Forwarded unchanged to ``caresl_aggregate``.
        solver_kwargs: Extra keyword arguments forwarded to the fitting calls.

    Returns:
        The selected gamma as a float. When no candidate yields a finite
        validation error, the first candidate is returned and a warning is
        printed to stderr.

    Raises:
        ValueError: If ``gamma_grid`` contains no values.
    """
    # Materialize once: the grid supplies both the fallback value and the
    # search loop, so a one-shot iterator would otherwise lose its first
    # candidate, and an empty grid would surface as a bare StopIteration.
    candidates = list(gamma_grid)
    if not candidates:
        raise ValueError("gamma_grid must contain at least one value")

    best_gamma = float(candidates[0])
    best_mae = float("inf")
    # Correlations depend only on the training data, so compute them once.
    corr_raw = sanitize_correlation(raw_train.corr())
    corr_bias = sanitize_correlation(bias_train.corr())
    for gamma in candidates:
        try:
            _, raw_weights = caresl_aggregate(
                raw_train,
                gamma=gamma,
                verbose=False,
                corr_matrix=corr_raw,
                return_weights=True,
                penalize_dependency=penalize_dependency,
                **solver_kwargs,
            )
            raw_val_pred = caresl_aggregate(
                raw_val,
                weights=raw_weights,
                penalize_dependency=penalize_dependency,
            )
            _, bias_weights = caresl_aggregate(
                bias_train,
                gamma=gamma,
                verbose=False,
                corr_matrix=corr_bias,
                return_weights=True,
                penalize_dependency=penalize_dependency,
                **solver_kwargs,
            )
            bias_val_pred = caresl_aggregate(
                bias_val,
                weights=bias_weights,
                penalize_dependency=penalize_dependency,
            )
            val_mae = mae(raw_val_pred, bias_val_pred)
        except Exception as exc:
            # Best effort: one failing candidate must not abort the search,
            # but the failure is reported instead of being silently dropped.
            print(f"Warning: gamma={gamma} skipped during tuning ({exc!r})", file=sys.stderr)
            continue
        if np.isnan(val_mae):
            continue
        if (val_mae + 1e-9) < best_mae or (np.isclose(val_mae, best_mae) and gamma < best_gamma):
            best_mae = float(val_mae)
            best_gamma = float(gamma)

    if best_mae == float("inf"):
        print(
            f"Warning: no gamma candidate produced a usable validation error; "
            f"falling back to gamma={best_gamma}",
            file=sys.stderr,
        )
    return best_gamma


def aggregate_methods(df: pd.DataFrame) -> Dict[str, np.ndarray]:
    """Run the baseline aggregators on a table of judge scores.

    Args:
        df: Score table with one row per sample and one column per judge.

    Returns:
        A dict keyed ``"MV"``, ``"AVG"``, ``"WS"`` and ``"UWS"`` holding the
        outputs of ``majority_vote``, the per-row mean, ``ws_aggregate`` and
        ``uws_aggregate`` respectively. All but ``"AVG"`` are converted to
        float arrays.
    """
    aggregated: Dict[str, np.ndarray] = {}
    aggregated["MV"] = np.asarray(majority_vote(df), dtype=float)
    aggregated["AVG"] = df.mean(axis=1).to_numpy()
    aggregated["WS"] = np.asarray(ws_aggregate(df), dtype=float)
    aggregated["UWS"] = np.asarray(uws_aggregate(df), dtype=float)
    return aggregated


def _split_train_val(
    n_samples: int,
    val_fraction: float,
    random_state: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return (train, validation) row positions used for gamma tuning."""
    indices = np.arange(n_samples)
    if n_samples == 1:
        # A single row cannot be split, so it serves as both sides.
        return indices, indices
    # Clamp so each side keeps at least one row whatever val_fraction is.
    val_size = int(np.floor(n_samples * val_fraction))
    val_size = max(1, min(val_size, n_samples - 1))
    train_idx, val_idx = train_test_split(
        indices,
        test_size=val_size,
        random_state=random_state,
        shuffle=True,
    )
    return train_idx, val_idx


def _caresl_or_mean(
    scores: pd.DataFrame,
    gamma: float,
    penalize_dependency: bool,
    solver_kwargs: Dict[str, object],
    label: str,
):
    """Aggregate ``scores`` with CARESL; use the per-row mean if that raises."""
    try:
        aggregated, _ = caresl_aggregate(
            scores,
            gamma=gamma,
            verbose=False,
            corr_matrix=sanitize_correlation(scores.corr()),
            return_weights=True,
            penalize_dependency=penalize_dependency,
            **solver_kwargs,
        )
    except Exception as exc:
        # Deliberate best effort, but say so: otherwise the "CARE" figure
        # would silently be a plain average.
        print(
            f"Warning: CARESL aggregation failed for {label} ({exc!r}); "
            f"using the row mean instead",
            file=sys.stderr,
        )
        return scores.mean(axis=1).to_numpy()
    return aggregated


def run_bias_analysis(
    judge_root: Path,
    results_path: Path,
    *,
    val_fraction: float,
    random_state: int,
    gamma_grid: Iterable[float],
    penalize_dependency: bool,
    solver_kwargs: Dict[str, object],
    biases: Iterable[str],
) -> None:
    """Tune gamma per bias and write a table of raw-vs-biased MAE values.

    For each bias the judge scores are read from ``judge_root/<bias>`` and
    ``judge_root/<bias>_raw``. The ``DROP_COLUMN`` column is removed, values
    are coerced to numbers, and only rows valid in both tables are kept. Gamma
    is tuned on a train/validation split, then every aggregation method is
    applied to the full raw and biased tables and compared.

    Args:
        judge_root: Directory holding the per-bias folders.
        results_path: Destination of the summary CSV. The chosen gammas are
            written next to it with the suffix replaced by ``.gammas.json``.
        val_fraction: Share of rows held out for validation.
        random_state: Seed for the train/validation split.
        gamma_grid: Candidate gamma values; any iterable.
        penalize_dependency: Forwarded to the CARESL calls.
        solver_kwargs: Extra keyword arguments forwarded to the CARESL fit.
        biases: Names of the bias folders to process.

    Raises:
        RuntimeError: If no bias produced a result (all were skipped).
    """
    results_path.parent.mkdir(parents=True, exist_ok=True)
    # The grid is searched once per bias; freeze it so a one-shot iterable is
    # not exhausted after the first bias.
    gamma_grid = tuple(gamma_grid)
    summary_rows: List[Dict[str, object]] = []
    chosen_gammas: Dict[str, float] = {}

    for bias in biases:
        bias_dir = judge_root / bias
        raw_dir = judge_root / f"{bias}_raw"
        if not bias_dir.exists() or not raw_dir.exists():
            print(f"Skipping {bias}: missing judge outputs", file=sys.stderr)
            continue

        bias_df = prepare_judgement(bias_dir)
        raw_df = prepare_judgement(raw_dir)
        if DROP_COLUMN in bias_df.columns:
            bias_df = bias_df.drop(columns=[DROP_COLUMN])
        if DROP_COLUMN in raw_df.columns:
            raw_df = raw_df.drop(columns=[DROP_COLUMN])

        # Non-numeric cells become NaN here and are then rejected row-wise.
        bias_df = bias_df.apply(pd.to_numeric, errors="coerce")
        raw_df = raw_df.apply(pd.to_numeric, errors="coerce")
        raw_df, bias_df = map_valid_score(raw_df, bias_df)

        bias_df = bias_df.astype(float)
        raw_df = raw_df.astype(float)
        n_samples = len(bias_df)
        if n_samples == 0:
            print(f"Skipping {bias}: no overlapping samples", file=sys.stderr)
            continue

        train_idx, val_idx = _split_train_val(n_samples, val_fraction, random_state)
        raw_train = raw_df.iloc[train_idx].reset_index(drop=True)
        raw_val = raw_df.iloc[val_idx].reset_index(drop=True)
        bias_train = bias_df.iloc[train_idx].reset_index(drop=True)
        bias_val = bias_df.iloc[val_idx].reset_index(drop=True)

        gamma = tune_gamma(
            raw_train,
            bias_train,
            raw_val,
            bias_val,
            gamma_grid,
            penalize_dependency,
            solver_kwargs,
        )
        chosen_gammas[bias] = gamma

        # The final aggregates use all valid rows, not just the training part.
        raw_care = _caresl_or_mean(
            raw_df, gamma, penalize_dependency, solver_kwargs, f"{bias} (raw)"
        )
        bias_care = _caresl_or_mean(
            bias_df, gamma, penalize_dependency, solver_kwargs, f"{bias} (biased)"
        )

        base_aggs = aggregate_methods(bias_df)
        raw_aggs = aggregate_methods(raw_df)

        metrics = {
            name: compute_metrics(raw_aggs[name], base_aggs[name])
            for name in base_aggs
        }
        metrics["CARE"] = compute_metrics(raw_care, bias_care)

        # Only the MAE of each method is reported in the summary table.
        summary_rows.append(
            {
                "Bias": BIAS_NAME_MAP.get(bias, bias),
                "MV": metrics["MV"]["mae"],
                "AVG": metrics["AVG"]["mae"],
                "WS": metrics["WS"]["mae"],
                "UWS": metrics["UWS"]["mae"],
                "CARE": metrics["CARE"]["mae"],
            }
        )

    if not summary_rows:
        raise RuntimeError("No bias results computed")

    summary_df = pd.DataFrame(summary_rows).round(4)
    summary_df.to_csv(results_path, index=False)

    print("=================== MAE ===================")
    print(summary_df.set_index("Bias").T)

    gamma_path = results_path.with_suffix(".gammas.json")
    gamma_path.write_text(json.dumps(chosen_gammas, indent=2, sort_keys=True))
    print(f"Saved MAE table to {results_path}")
    print(f"Saved gamma selections to {gamma_path}")


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read the process arguments, as before.

    Returns:
        The parsed namespace with ``judge_root``, ``results_path``,
        ``val_fraction``, ``random_state``, ``gamma_grid``,
        ``penalize_dependency``, ``solver_max_iters`` and ``biases``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--judge-root",
        type=Path,
        default=REPO_ROOT / "judge_outputs" / "miscellaneous" / "injected_bias_1-10",
        help="Directory containing per-bias judge CSV folders.",
    )
    parser.add_argument(
        "--results-path",
        type=Path,
        default=REPO_ROOT / "results" / "injected_bias_mae.csv",
        help="Where to write the summary CSV (gamma JSON written alongside).",
    )
    parser.add_argument(
        "--val-fraction",
        type=float,
        default=0.1,
        help="Fraction of examples to hold out for validation when tuning gamma.",
    )
    parser.add_argument(
        "--random-state",
        type=int,
        default=2024,
        help="Random seed for the train/validation split.",
    )
    # "+" rather than "*": a bare ``--gamma-grid`` would yield an empty grid
    # that only fails later, deep inside the tuning step.
    parser.add_argument(
        "--gamma-grid",
        type=float,
        nargs="+",
        default=DEFAULT_GAMMA_GRID,
        help="List of gamma values to search (space separated).",
    )
    parser.add_argument(
        "--penalize-dependency",
        action="store_true",
        help="Forward the penalize_dependency flag to CARESL solver.",
    )
    parser.add_argument(
        "--solver-max-iters",
        type=int,
        default=10000,
        help="Maximum iterations for the convex solver.",
    )
    # Same reasoning as --gamma-grid: an empty selection can never produce
    # results, so reject it at the command line.
    parser.add_argument(
        "--biases",
        type=str,
        nargs="+",
        default=DEFAULT_BIASES,
        help="Subset of bias folders to process.",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Entry point: parse the CLI options and run the bias analysis."""
    cli = parse_args()
    # The only solver option exposed on the command line is the iteration cap.
    run_bias_analysis(
        cli.judge_root,
        cli.results_path,
        val_fraction=cli.val_fraction,
        random_state=cli.random_state,
        gamma_grid=cli.gamma_grid,
        penalize_dependency=cli.penalize_dependency,
        solver_kwargs={"max_iters": cli.solver_max_iters},
        biases=cli.biases,
    )


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
