#!/usr/bin/env python
"""Fully Gaussian CARESL evaluation with validation-based gamma tuning."""

from __future__ import annotations

import argparse
import csv
import random
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
from IPython.display import display
from sklearn.model_selection import train_test_split

import sys
sys.path.append('../src')

from eval_tools import collect_metrics
import pgm_tools

csv.field_size_limit(10**9)

from data_tools import load_judge_dataset_bundle


REPO_NAME = 'llm-judge-bias'


def _locate_repo_root(start: Path) -> Path:
    """Return the repository root as seen from *start*.

    *start* itself is the root when it has both ``src/`` and ``notebooks/``.
    Otherwise the first existing ``<dir>/REPO_NAME`` is used, checking *start*
    first and then each of its ancestors from nearest to farthest.

    Raises:
        RuntimeError: If no candidate directory exists.
    """
    if (start / 'src').exists() and (start / 'notebooks').exists():
        return start
    for base in (start, *start.parents):
        candidate = base / REPO_NAME
        if candidate.exists():
            return candidate
    raise RuntimeError(f'Could not locate {REPO_NAME} repository from {start}')


cwd = Path.cwd().resolve()
LLM_ROOT = _locate_repo_root(cwd)

# Prefer the notebooks/ folder for outputs when the checkout has one.
_notebooks_candidate = LLM_ROOT / 'notebooks'
NOTEBOOK_DIR = _notebooks_candidate if _notebooks_candidate.exists() else LLM_ROOT
SRC_ROOT = LLM_ROOT / 'src'

# Fraction of each dataset's rows held out (train_test_split test_size) to score gamma candidates.
VAL_FRACTION = 0.1
# Default value of the --seed CLI option; also used as the split's random_state.
DEFAULT_RANDOM_SEED = 2024
# Candidate gamma values tried, in order, during validation; the first entry is
# also the fallback when no candidate can be scored.
GAMMA_GRID = [0.1, 0.2, 0.25, 0.5, 0.75, 1, 2, 3, 5, 7, 10]
# Extra keyword arguments forwarded to pgm_tools.caresl_aggregate when fitting.
# NOTE(review): presumably an iteration cap for the structure-learning solver — confirm in pgm_tools.
LEARN_STRUCTURE_SOLVER_KW = dict(max_iters=10000)

# Datasets processed by run(), in this order; datasets that fail to load are skipped.
DATASETS = [
    'feedbackqa',
    'helpsteer2',
    'helpsteer3',
    'ultrafeedback',
    'summarize_from_feedback',
    'yelp_with_scores',
    'tripadvisor_reviews',
    'asset',
    'review_5k',
]

# Row order of the MAE summary table (values of METHOD_LABELS).
ORDERED_METHODS = ['MV', 'AVG', 'WS', 'UWS', 'CARE']
# Maps the prediction keys used in run() to display labels.
METHOD_LABELS = {
    'mv': 'MV',
    'avg': 'AVG',
    'ws': 'WS',
    'uws': 'UWS',
    'care': 'CARE',
}
# Maps dataset names to the shorter column labels of the MAE summary table.
DATASET_LABELS = {
    'feedbackqa': 'feedbackqa',
    'helpsteer2': 'helpsteer2',
    'ultrafeedback': 'ultrafeedback',
    'helpsteer3': 'helpsteer3',
    'summarize_from_feedback': 'summarize',
    'yelp_with_scores': 'yelp',
    'tripadvisor_reviews': 'tripadvisor',
    'asset': 'asset',
    'review_5k': 'review5k',
}


def set_global_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and, when installed, PyTorch.

    PyTorch is optional: if it cannot be imported, only the first two are seeded.
    """
    for reseed in (random.seed, np.random.seed):
        reseed(seed)

    try:
        import torch
    except ImportError:  # pragma: no cover - optional dependency
        torch = None

    if torch is not None:
        torch.manual_seed(seed)
        # Seed every visible CUDA device as well when a GPU is available.
        if torch.cuda.is_available():  # pragma: no cover - environment specific
            torch.cuda.manual_seed_all(seed)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``seed`` (int) and ``output_dir`` (Path or None).
    """
    parser = argparse.ArgumentParser(description='Fully Gaussian CARESL evaluation with validation-based gamma search.')
    # %(default)s is expanded by argparse, so the help text cannot drift from DEFAULT_RANDOM_SEED.
    parser.add_argument('--seed', type=int, default=DEFAULT_RANDOM_SEED, help='Random seed for validation splits (default: %(default)s).')
    parser.add_argument('--output-dir', type=Path, help='Optional override for the notebooks/results directory.')
    return parser.parse_args(argv)


def tune_gamma_with_validation(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    val_eval: pd.Series,
    dataset_name: str,
    corr_matrix_train: pd.DataFrame,
    *,
    gamma_grid: list[float] | tuple[float, ...] | None = None,
) -> tuple[float, list[dict]]:
    """Pick CARE's ``gamma`` by validation MAE.

    For each candidate gamma, CARE weights are fitted on ``train_df`` (with the
    precomputed ``corr_matrix_train``), applied to ``val_df`` and scored against
    ``val_eval`` with ``collect_metrics``.

    Args:
        train_df: Judge scores used to fit the CARE weights.
        val_df: Held-out judge scores the fitted weights are applied to.
        val_eval: Human scores aligned row-by-row with ``val_df``.
        dataset_name: Used for log messages and in the telemetry rows.
        corr_matrix_train: Correlation matrix passed to ``caresl_aggregate`` when fitting.
        gamma_grid: Candidates to try, in order; defaults to ``GAMMA_GRID``.

    Returns:
        ``(best_gamma, validation_entries)``. ``best_gamma`` has the lowest
        non-missing validation MAE, with ties going to the smaller gamma.
        ``validation_entries`` holds one dict per gamma that ran without error.
        If no gamma yields a usable MAE, the first candidate is returned.

    Raises:
        ValueError: If ``gamma_grid`` is empty.
    """
    candidates = list(GAMMA_GRID if gamma_grid is None else gamma_grid)
    if not candidates:
        raise ValueError('gamma_grid must contain at least one value')

    validation_entries = []
    scored = []
    for gamma in candidates:
        try:
            _, weights = pgm_tools.caresl_aggregate(
                train_df,
                gamma=gamma,
                verbose=False,
                corr_matrix=corr_matrix_train,
                return_weights=True,
                **LEARN_STRUCTURE_SOLVER_KW,
            )
            val_pred = pgm_tools.caresl_aggregate(val_df, weights=weights)
            val_mae, val_tau = collect_metrics(val_pred, val_eval)
        except Exception as exc:
            # Best effort: a gamma that fails to fit or apply is skipped, not fatal.
            print(f'Skipping gamma {gamma} for {dataset_name}: {exc}')
            continue
        validation_entries.append({
            'dataset': dataset_name,
            'gamma': gamma,
            'val_mae': val_mae,
            'val_kendall_tau': val_tau,
        })
        # pd.isna treats both None and NaN as missing. np.isnan(None) would raise
        # TypeError here, outside the per-gamma try, and abort the whole search.
        if not pd.isna(val_mae):
            scored.append((gamma, val_mae))

    if not scored:
        print(f'No usable validation MAE for {dataset_name}; falling back to gamma {candidates[0]}')
        return candidates[0], validation_entries

    best_gamma = min(scored, key=lambda item: (item[1], item[0]))[0]
    return best_gamma, validation_entries


# Dataset loader pre-bound to this checkout's root. run() calls it with only a
# dataset name, unpacks the result as (judge_df, human_eval), and skips the
# dataset when it raises FileNotFoundError or ValueError.
# NOTE(review): the exact semantics of allow_trim / valid_ratio_threshold live in
# data_tools.load_judge_dataset_bundle — confirm there before changing them.
load_dataset_bundle = partial(
    load_judge_dataset_bundle,
    project_root=LLM_ROOT,
    allow_trim=True,
    valid_ratio_threshold=0.7,
)


def _judge_mae_records(dataset_name: str, judge_df: pd.DataFrame, human_eval: pd.Series) -> list[dict]:
    """Return one MAE record per numeric judge column of *judge_df*."""
    numeric_judges = judge_df.select_dtypes(include=[np.number])
    records = []
    for judge_name in numeric_judges.columns:
        judge_mae, _ = collect_metrics(numeric_judges[judge_name], human_eval)
        records.append({
            'dataset': dataset_name,
            'judge': judge_name,
            'mae': judge_mae,
        })
    return records


def _select_gamma(dataset_name: str, judge_df: pd.DataFrame, human_eval: pd.Series, seed: int) -> tuple[float, list[dict]]:
    """Split off a validation set and tune CARE's gamma on it.

    Returns ``(best_gamma, validation_entries)``. If splitting or tuning raises,
    the failure is logged and ``(GAMMA_GRID[0], [])`` is returned instead.
    """
    try:
        train_df, val_df, _, val_eval = train_test_split(
            judge_df,
            human_eval,
            test_size=VAL_FRACTION,
            random_state=seed,
            shuffle=True,
        )
        train_df = train_df.reset_index(drop=True)
        val_df = val_df.reset_index(drop=True)
        val_eval = val_eval.reset_index(drop=True)
        corr_matrix_train = pgm_tools.sanitize_correlation(train_df.corr())
        return tune_gamma_with_validation(
            train_df,
            val_df,
            val_eval,
            dataset_name,
            corr_matrix_train,
        )
    except Exception as exc:
        print(f'Validation split failed for {dataset_name}: {exc}')
        return GAMMA_GRID[0], []


def _aggregate_predictions(dataset_name: str, judge_df: pd.DataFrame, best_gamma: float) -> dict:
    """Compute every aggregation method on the full *judge_df*.

    MV and AVG are expected to succeed. WS, UWS and CARE are best-effort: on
    error, a message is printed and that method's prediction is ``None``.
    """
    predictions = {
        'mv': pgm_tools.majority_vote(judge_df),
        'avg': judge_df.mean(axis=1),
    }

    try:
        encoded_df, inverse_mapping = pgm_tools.encode_for_label_models(judge_df)
        ws_indices = pgm_tools.run_label_model(encoded_df)
        predictions['ws'] = np.array([inverse_mapping.get(idx, np.nan) for idx in ws_indices], dtype=float)
    except Exception as exc:
        print(f'Skipping WS for {dataset_name} due to error: {exc}')
        predictions['ws'] = None

    try:
        predictions['uws'] = pgm_tools.uws_aggregate(judge_df)
    except Exception as exc:
        print(f'Skipping UWS for {dataset_name} due to error: {exc}')
        predictions['uws'] = None

    try:
        # The correlation matrix is only needed by CARE. Computing it inside this
        # try means a failure skips CARE instead of aborting the whole run.
        corr_matrix_full = pgm_tools.sanitize_correlation(judge_df.corr())
        predictions['care'], _ = pgm_tools.caresl_aggregate(
            judge_df,
            gamma=best_gamma,
            verbose=False,
            corr_matrix=corr_matrix_full,
            return_weights=True,
            **LEARN_STRUCTURE_SOLVER_KW,
        )
    except Exception as exc:
        print(f'Skipping CARE for {dataset_name}: {exc}')
        predictions['care'] = None

    return predictions


def _write_results(
    results_dir: Path,
    results_df: pd.DataFrame,
    best_gamma_df: pd.DataFrame,
    validation_df: pd.DataFrame,
    judge_results_df: pd.DataFrame,
) -> Path:
    """Write the result tables as CSVs under *results_dir*; return the main results path."""
    results_dir.mkdir(parents=True, exist_ok=True)
    results_path = results_dir / 'fully_gaussian_main.csv'
    results_df.to_csv(results_path, index=False)
    best_gamma_df.to_csv(results_dir / 'fully_gaussian_main_best_gamma.csv', index=False)
    # Optional tables are only written when they contain rows.
    if not validation_df.empty:
        validation_df.to_csv(results_dir / 'fully_gaussian_main_validation.csv', index=False)
    if not judge_results_df.empty:
        judge_results_df.to_csv(results_dir / 'fully_gaussian_main_judges.csv', index=False)
    return results_path


def _build_mae_table(results_df: pd.DataFrame) -> pd.DataFrame:
    """Pivot *results_df* into a method x dataset MAE table using the display labels."""
    return (
        results_df
        .replace({'pred': METHOD_LABELS})
        .replace({'dataset': DATASET_LABELS})
        .pivot_table(index='pred', columns='dataset', values='mae')
        .rename_axis(index='method_label', columns='dataset_label')
    ).reindex(ORDERED_METHODS)


def run(seed: int, output_dir: Path | str | None = None) -> Path:
    """Evaluate all aggregation methods on every dataset in ``DATASETS``.

    For each dataset that loads, the function:

    1. records per-judge MAE;
    2. tunes CARE's gamma on a validation split;
    3. runs MV/AVG/WS/UWS/CARE on the full data;
    4. scores each method against the human labels.

    Result tables are displayed when possible and written as CSV files.

    Args:
        seed: ``random_state`` for the validation split.
        output_dir: Directory for the CSV outputs; defaults to ``NOTEBOOK_DIR / 'results'``.

    Returns:
        Path of the main results CSV (``fully_gaussian_main.csv``).
    """
    results = []
    validation_records = []
    best_gamma_records = []
    judge_performance_records = []

    for dataset_name in DATASETS:
        print('Processing dataset:', dataset_name)
        try:
            judge_df, human_eval = load_dataset_bundle(dataset_name)
        except (FileNotFoundError, ValueError) as exc:
            print(f'Skipping {dataset_name}: {exc}')
            continue

        judge_df = judge_df.reset_index(drop=True)
        human_eval = human_eval.reset_index(drop=True)

        judge_performance_records.extend(_judge_mae_records(dataset_name, judge_df, human_eval))

        best_gamma, gamma_validation_entries = _select_gamma(dataset_name, judge_df, human_eval, seed)
        validation_records.extend(gamma_validation_entries)
        best_gamma_records.append({'dataset': dataset_name, 'best_gamma': best_gamma})

        # NOTE(review): CARE is scored on the full dataset, which includes the
        # validation rows used to choose gamma.
        predictions = _aggregate_predictions(dataset_name, judge_df, best_gamma)
        for name, pred in predictions.items():
            mae, kendall = collect_metrics(pred, human_eval)
            row = {
                'dataset': dataset_name,
                'pred': name,
                'mae': mae,
                'kendall_tau': kendall,
            }
            if name == 'care':
                row['gamma'] = best_gamma
            results.append(row)

    results_df = pd.DataFrame(results)
    best_gamma_df = pd.DataFrame(best_gamma_records)
    validation_df = pd.DataFrame(validation_records)
    judge_results_df = pd.DataFrame(judge_performance_records)

    try:
        display(best_gamma_df)
        if not validation_df.empty:
            display(validation_df)
        display(results_df)
        if not judge_results_df.empty:
            display(judge_results_df)
    except Exception:  # pragma: no cover - display is cosmetic only
        pass

    results_dir = Path(output_dir) if output_dir else NOTEBOOK_DIR / 'results'
    results_path = _write_results(results_dir, results_df, best_gamma_df, validation_df, judge_results_df)

    # With no rows there is no 'pred'/'dataset' column and the pivot would raise KeyError.
    if results_df.empty:
        print('No dataset produced results; skipping the MAE summary table.')
        return results_path

    mae_table = _build_mae_table(results_df)
    try:
        display(mae_table.style.set_caption('MAE by method and dataset'))
    except Exception:  # pragma: no cover - Styler needs jinja2 / a notebook
        # Fall back to plain text so the summary is not silently lost.
        print(mae_table.to_string())

    return results_path


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: parse *argv*, seed every RNG, then run the evaluation.

    Always returns 0, which the ``__main__`` guard uses as the exit status.
    """
    options = parse_args(argv)
    run_seed = int(options.seed)
    set_global_seed(run_seed)
    run(output_dir=options.output_dir, seed=run_seed)
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    sys.exit(main())
