# -*- coding: utf-8 -*-
"""
python run_benchmark_compare.py --cores 4 --chains 2 --sample_ratio 0.01 --slope_mode item --engine pymc --draws 1000 --tune 1000
python run_benchmark_compare.py --train_per_bench 500 --val_per_bench 100 --test_per_bench 100 --slope_mode rasch --cores 4 --chains 4 --draws 500 --tune 800 --seed 123
"""

import os
import argparse
from pathlib import Path
import sys
from typing import Optional, List, Tuple, Dict

sys.setrecursionlimit(10000)

import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from scipy.special import expit
from scipy.optimize import minimize
import arviz as az
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
import logging

# Set logging level
logger = logging.getLogger("pymc")
logger.setLevel(logging.INFO)


# ======= IRT cite Method Related Functions =======
def sigmoid(z):
    """Return the logistic function ``1 / (1 + exp(-z))``, element-wise for arrays."""
    denominator = 1 + np.exp(-z)
    return 1 / denominator

def item_curve(theta, a, b):
    """Return the logistic response probability for abilities and item parameters.

    The per-dimension terms ``a * theta - b`` are clipped to ``[-30, 30]``,
    summed over axis 1, and the sum is passed through ``sigmoid``.
    The inputs must broadcast against each other and have an axis 1.
    """
    per_dim_terms = a * theta - b
    bounded_terms = np.clip(per_dim_terms, -30, 30)
    return sigmoid(np.sum(bounded_terms, axis=1))

def fit_theta(responses_test, seen_items, A, B, theta_init=None, eps=1e-10, optimizer="BFGS"):
    """Estimate an ability vector by maximising the Bernoulli likelihood of the seen items.

    Parameters
    ----------
    responses_test : array indexed by ``seen_items``; the responses used in
        the likelihood (combined as ``r * log(P) + (1 - r) * log(1 - P)``).
    seen_items : index array selecting the items used for fitting; it indexes
        ``responses_test`` and the last axis of ``A`` and ``B``.
    A, B : 3-D arrays whose axis 1 has length ``D`` (``D = A.shape[1]``) and
        whose last axis is indexed by ``seen_items``.
    theta_init : optional starting point for the optimiser; anything that can
        be reshaped to ``(D,)``. Defaults to a zero vector.
    eps : small constant added inside the logarithms to avoid ``log(0)``.
    optimizer : ``method`` name forwarded to ``scipy.optimize.minimize``.

    Returns
    -------
    numpy.ndarray of shape ``(1, D, 1)`` holding the optimiser's solution.
    """
    D = A.shape[1]

    # These slices do not depend on the optimisation variable, so compute
    # them once instead of on every objective evaluation.
    observed = responses_test[seen_items]
    A_seen = A[:, :, seen_items]
    B_seen = B[:, :, seen_items]

    def neg_log_like(x):
        P = item_curve(x.reshape(1, D, 1), A_seen, B_seen).squeeze()
        log_likelihood = np.sum(observed * np.log(P + eps) + (1 - observed) * np.log(1 - P + eps))
        return -log_likelihood

    # Honour the caller-supplied starting point; previously it was ignored
    # and the search always started from zeros.
    if theta_init is None:
        x0 = np.zeros(D)
    else:
        x0 = np.asarray(theta_init, dtype=float).reshape(D)

    optimal_theta = minimize(neg_log_like, x0, method=optimizer).x[None, :, None]
    return optimal_theta

# ======= Data Preparation: Construct Y by Single Benchmark =======
def build_inputs_single_bench(csv_path: Path,
                              bench_name: str,
                              sample_ratio: float,
                              seed: int = 42,
                              train_per_bench: Optional[int] = None,
                              val_per_bench: Optional[int] = None,
                              test_per_bench: Optional[int] = None
                              ) -> Tuple[int,int,List[str],np.ndarray,Optional[np.ndarray],Optional[np.ndarray],Optional[np.ndarray]]:
    """
    Select and sample the columns of one benchmark from the merged *_is_correct matrix.

    A column belongs to ``bench_name`` when the text before its first
    underscore equals ``bench_name``. The CSV is read with its first column
    as the index (used as the model names).

    Sampling modes:
    - Set-wise mode (any of train/val/test_per_bench is not None): draw
      ``min(train + val + test, available)`` columns without replacement and
      assign them to train, then val, then test, in that order of priority
      when there are too few columns. ``sample_ratio`` is ignored. Missing or
      negative counts are treated as 0.
    - Ratio mode (all three are None): draw
      ``max(1, int(available * sample_ratio))`` columns, capped at the number
      available.

    Returns ``(N, J, models, Y_full, train_pos, val_pos, test_pos)``:
    N models, J sampled items, the list of model names, the (N, J) integer
    response matrix with columns in their original CSV order, and, in
    set-wise mode only, the column positions inside ``Y_full`` of each split
    (``None`` in ratio mode). If the benchmark has no columns, or the
    requested set-wise total is 0, a warning is printed and
    ``(0, 0, [], empty array, None, None, None)`` is returned.
    """
    df = pd.read_csv(csv_path, index_col=0)

    # Benchmark name is the prefix before the first underscore of the column name.
    bench_cols_idx = [i for i, col in enumerate(df.columns)
                      if col.split("_", 1)[0] == bench_name]
    if not bench_cols_idx:
        print(f"[WARN] No columns found under {bench_name}, skipping.")
        return 0, 0, [], np.zeros((0,0)), None, None, None

    rng = np.random.default_rng(seed)
    targets = (train_per_bench, val_per_bench, test_per_bench)
    use_setwise = any(x is not None for x in targets)

    if use_setwise:
        # Clamp to >= 0: a negative count would otherwise become a negative
        # slice bound below and silently produce wrong split sizes.
        t_tr, t_va, t_te = (max(0, int(t or 0)) for t in targets)
        t_total = t_tr + t_va + t_te
        if t_total <= 0:
            print(f"[WARN] {bench_name}: train/val/test target total is 0, skipping.")
            return 0, 0, [], np.zeros((0,0)), None, None, None

        # bench_cols_idx is non-empty here, so k_total >= 1.
        k_total = min(t_total, len(bench_cols_idx))
        chosen = rng.choice(bench_cols_idx, size=k_total, replace=False)

        # Fill train first, then val, then test from whatever remains.
        k_tr = min(t_tr, k_total)
        k_va = min(t_va, k_total - k_tr)
        k_te = min(t_te, k_total - k_tr - k_va)
        tr = chosen[:k_tr]
        va = chosen[k_tr:k_tr + k_va]
        te = chosen[k_tr + k_va:k_tr + k_va + k_te]

        sorted_chosen = np.sort(chosen)
        df_sub = df.iloc[:, sorted_chosen]
        models = df_sub.index.tolist()
        Y_full = df_sub.values.astype(int)
        N, J = Y_full.shape

        # Chosen indices are unique, so the position of an original column
        # index inside the sorted selection is its column in Y_full.
        train_pos = np.searchsorted(sorted_chosen, tr).astype(int)
        val_pos = np.searchsorted(sorted_chosen, va).astype(int)
        test_pos = np.searchsorted(sorted_chosen, te).astype(int)
        print(f"[{bench_name}] set-wise sampling: train={len(train_pos)}, val={len(val_pos)}, test={len(test_pos)} / original {len(bench_cols_idx)}")
        return N, J, models, Y_full, train_pos, val_pos, test_pos

    # Ratio mode: sample a proportion of this benchmark's columns (at least one).
    k = max(1, int(len(bench_cols_idx) * sample_ratio))
    k = min(k, len(bench_cols_idx))
    chosen = rng.choice(bench_cols_idx, size=k, replace=False)
    df_sub = df.iloc[:, np.sort(chosen)]
    models = df_sub.index.tolist()
    Y_full = df_sub.values.astype(int)
    N, J = Y_full.shape
    print(f"[{bench_name}] After sampling: J={J} (original {len(bench_cols_idx)} columns, sample ratio {sample_ratio:.3f})")
    return N, J, models, Y_full, None, None, None


# ======= Train/Test Split =======
def split_train_test(Y_full: np.ndarray, train_ratio: float = 0.9, seed: int = 42):
    """Randomly hold out individual entries of ``Y_full`` for testing.

    Each entry is kept for training when a uniform draw from a generator
    seeded with ``seed`` is below ``train_ratio``; otherwise it is held out.

    Returns ``(Y_train, Y_test_mask, Y_full)``: a float copy of ``Y_full``
    with held-out entries set to NaN, the boolean mask of held-out entries,
    and the untouched input array.
    """
    generator = np.random.default_rng(seed)
    held_out = ~(generator.random(Y_full.shape) < train_ratio)
    Y_train = np.where(held_out, np.nan, Y_full.astype(float))
    return Y_train, held_out, Y_full


# ======= Single Benchmark 2PL/Rasch IRT Modeling =======
def build_MCMC_single_bench(N: int,
                            J: int,
                            Y_train: np.ndarray,
                            draws: int = 1000,
                            tune: int = 1000,
                            chains: int = 2,
                            cores: int = 1,
                            engine: str = "pymc",
                            slope_mode: str = "item",
                            target_accept: float = 0.9) -> az.InferenceData:
    """
    Fit a standard IRT model for a single benchmark with NUTS and return the trace.

    Model:
      theta_i ~ N(0,1)                      (one ability per model, shape N)
      If slope_mode='rasch': a_j = 1        (stored as a Deterministic named "a")
      Otherwise: a_j ~ LogNormal(log 1, 0.5)
      b_j ~ N(0,1)
      y_ij ~ Bernoulli(logit_p = a_j*(theta_i - b_j))

    Parameters:
      N, J          : number of models (rows) and items (columns) of Y_train.
      Y_train       : (N, J) float matrix; NaN entries are treated as
                      unobserved and excluded from the likelihood.
      draws, tune, chains, cores, target_accept : forwarded to ``pm.sample``.
      engine        : forwarded to ``pm.sample`` as ``nuts_sampler``.
      slope_mode    : 'rasch' fixes all discriminations to 1; any other value
                      uses per-item discriminations.

    Returns the object produced by ``pm.sample(..., return_inferencedata=True)``.
    """
    with pm.Model() as model:
        theta = pm.Normal("theta", mu=0.0, sigma=1.0, shape=N)
        if slope_mode == "rasch":
            # Register the constant slopes under the name "a" so the trace
            # exposes the same variable names in both modes.
            a = pt.ones((J,), dtype="float64")
            a = pm.Deterministic("a", a)
        else:
            a = pm.LogNormal("a", mu=np.log(1.0), sigma=0.5, shape=J)
        b = pm.Normal("b", mu=0.0, sigma=1.0, shape=J)

        # (N, J) matrix of logits via broadcasting.
        logits = a[None, :] * (theta[:, None] - b[None, :])

        # Place likelihood only at observed positions
        obs_i, obs_j = np.where(~np.isnan(Y_train))
        y_obs_vec = Y_train[obs_i, obs_j].astype(int)
        logits_obs = logits[obs_i, obs_j]
        pm.Bernoulli("y_obs", logit_p=logits_obs, observed=y_obs_vec)

        print(f"  [MCMC-SINGLE] Start sampling: draws={draws}, tune={tune}, chains={chains}, cores={cores}, engine={engine}, slope_mode={slope_mode}")
        trace = pm.sample(
            draws=draws,
            tune=tune,
            chains=chains,
            cores=cores,
            target_accept=target_accept,
            return_inferencedata=True,
            progressbar=True,
            nuts_sampler=engine,
            # This flag is a boolean in the pm.sample API; the previous value
            # "eager" only worked because a non-empty string is truthy.
            compute_convergence_checks=True
        )
        print("  [MCMC-SINGLE] Sampling completed")
    return trace


# ======= IRT cite Method Modeling =======
def build_IRT_cite_single_bench(Y_train: np.ndarray, 
                                Y_true: np.ndarray, 
                                Y_test_mask: np.ndarray) -> Dict[str, float]:
    """
    Model and evaluate a single benchmark using the IRT cite method.

    Item parameters are fixed to a one-dimensional simplification
    (A = 1, B = 0 for every item). One ability value per model (row) is
    fitted with ``fit_theta`` on that row's non-NaN training entries; rows
    with no observed entries keep ability 0. Probabilities from
    ``item_curve`` are then scored on the entries selected by the mask.

    Parameters:
      Y_train     : (N, J) float matrix, NaN where an entry is not used for training.
      Y_true      : (N, J) matrix of true responses.
      Y_test_mask : (N, J) boolean mask of entries to evaluate on.

    Returns {"accuracy": ..., "auc": ...}. Both are NaN when the mask
    selects nothing; AUC is NaN when it cannot be computed.
    """
    print("  [IRT-CITE] Starting modeling with IRT cite method...")

    N, J = Y_train.shape
    test_mask = np.asarray(Y_test_mask, dtype=bool)
    if not test_mask.any():
        # Nothing to score, so skip the per-model optimisation entirely.
        print("  [IRT-CITE] No test data available for evaluation")
        return {"accuracy": float('nan'), "auc": float('nan')}

    # Simplified one-dimensional item parameters, shaped (1, D=1, J).
    A = np.ones((1, 1, J))
    B = np.zeros((1, 1, J))

    # One ability per model; rows without observed items stay at 0.
    theta_estimates = np.zeros((N, 1))
    for i in range(N):
        responses = Y_train[i, :]
        seen_items = np.where(~np.isnan(responses))[0]
        if len(seen_items) > 0:
            theta_estimates[i, :] = fit_theta(responses, seen_items, A, B).flatten()

    # (N, 1, 1) abilities against (1, 1, J) item parameters -> (N, J)
    # probabilities in one call. Boolean indexing yields 1-D arrays in
    # row-major order, i.e. the order of a nested i/j loop.
    prob_matrix = item_curve(theta_estimates[:, :, None], A, B)
    predictions = prob_matrix[test_mask]
    true_values = np.asarray(Y_true)[test_mask]

    pred_binary = (predictions > 0.5).astype(int)
    acc = accuracy_score(true_values, pred_binary)
    try:
        auc = roc_auc_score(true_values, predictions)
    except Exception:
        # Best effort: AUC is undefined e.g. when only one class is present.
        auc = float('nan')

    print(f"  [IRT-CITE] Modeling completed, Accuracy: {acc:.4f}, AUC: {auc:.4f}")
    return {"accuracy": acc, "auc": auc}


# ======= MCMC Diagnostics =======
def evaluate_mcmc_reliability(trace: az.InferenceData, var_names: List[str], save_dir: Path):
    """Write the ArviZ summary and diagnostic plots for ``var_names`` into ``save_dir``.

    Saves ``mcmc_summary.csv``, ``trace_plots.png`` and ``energy_plot.png``,
    and prints which parameters (if any) have ``r_hat`` above 1.01 or
    ``ess_bulk`` below 200.
    """
    stats = az.summary(trace, var_names=var_names)
    stats.to_csv(save_dir / "mcmc_summary.csv")

    r_hat_col = stats["r_hat"]
    ess_col = stats["ess_bulk"]

    poorly_mixed = r_hat_col[r_hat_col > 1.01]
    if len(poorly_mixed) > 0:
        print("⚠️ Convergence warning (r_hat > 1.01):")
        print(poorly_mixed)
    else:
        print("✅ Parameters converged well")

    under_sampled = ess_col[ess_col < 200]
    if len(under_sampled) > 0:
        print("⚠️ Insufficient samples (ess_bulk < 200):")
        print(under_sampled)
    else:
        print("✅ Sufficient effective samples")

    def _save_current_figure(file_name):
        # Tighten, write and close whatever figure is currently active.
        plt.tight_layout()
        plt.savefig(save_dir / file_name, dpi=200)
        plt.close()

    az.plot_trace(trace, var_names=var_names)
    _save_current_figure("trace_plots.png")

    az.plot_energy(trace)
    _save_current_figure("energy_plot.png")


# ======= Metric Evaluation (Aligned with mix standard)=======
def evaluate_model_single(trace: az.InferenceData,
                          Y_true: np.ndarray,
                          Y_mask: np.ndarray,
                          save_dir: Path,
                          tag: str = "eval"):
    """
    Produce Overall Accuracy / AUC / Confusion Matrix aligned with run_mix_benchmark.py;
    And save scatter: True Overall Acc vs theta_hat

    Posterior means of 'theta', 'a' and 'b' (averaged over chain and draw)
    give point-estimate probabilities sigmoid(a_j * (theta_i - b_j)), which
    are thresholded at 0.5 and scored on the entries selected by ``Y_mask``
    (or on all non-NaN entries of ``Y_true`` when ``Y_mask`` is None).

    Writes ``{tag}_report.txt`` and ``{tag}_overall_scatter.png`` into
    ``save_dir`` and returns {"accuracy": ..., "auc": ...}. When the mask
    selects no entries, both metrics are NaN and the confusion matrix is
    empty. AUC is NaN whenever it cannot be computed.
    """
    hat = {var: trace.posterior[var].mean(dim=("chain","draw")).values
           for var in ['theta','a','b']}
    theta_hat, a_hat, b_hat = hat['theta'], hat['a'], hat['b']

    logits = a_hat[None, :] * (theta_hat[:, None] - b_hat[None, :])
    p_pred = 1.0 / (1.0 + np.exp(-logits))
    y_pred = (p_pred > 0.5).astype(int)

    mask_all = ~np.isnan(Y_true) if Y_mask is None else Y_mask
    y_true_all = Y_true[mask_all]
    p_all      = p_pred[mask_all]
    yhat_all   = y_pred[mask_all]

    if y_true_all.size == 0:
        # An empty split (e.g. no test items requested) has no defined
        # metrics; report NaN explicitly rather than passing empty arrays
        # to the metric functions.
        print(f"[{tag}][Overall] No entries selected for evaluation; reporting NaN metrics.")
        acc_all = float('nan')
        auc_all = float('nan')
        cm_all = np.zeros((0, 0), dtype=int)
    else:
        acc_all = accuracy_score(y_true_all, yhat_all)
        # AUC may error out in all 0/1 extremes or single class, add protection
        try:
            auc_all = roc_auc_score(y_true_all, p_all)
        except Exception:
            auc_all = float('nan')
        cm_all  = confusion_matrix(y_true_all, yhat_all)

    print(f"[{tag}][Overall] Accuracy: {acc_all:.4f}, AUC: {auc_all:.4f}")
    print("Confusion Matrix:\n", cm_all)

    with open(save_dir / f"{tag}_report.txt", "w", encoding="utf-8") as f:
        f.write(f"[{tag}] Overall Accuracy: {acc_all:.6f}, AUC: {auc_all:.6f}\n")
        f.write(f"Confusion Matrix:\n{cm_all}\n")

    # Per-model mean of the true responses over all items (independent of the mask).
    true_acc_overall = np.nanmean(Y_true, axis=1)
    plt.figure(figsize=(5,4))
    plt.scatter(true_acc_overall, theta_hat, alpha=0.6)
    plt.xlabel("True Overall Accuracy"); plt.ylabel("Estimated theta")
    plt.title(f"Overall: True Acc vs. theta_hat ({tag})")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_dir / f"{tag}_overall_scatter.png", dpi=200)
    plt.close()

    # Return metrics for aggregation
    return {"accuracy": acc_all, "auc": auc_all}


# ======= CLI Main Process =======
def main() -> None:
    """Command-line entry point: fit and evaluate one IRT model per benchmark.

    For each benchmark in a fixed list, this reads and samples the
    benchmark's columns, builds train/val/test masks, runs MCMC, writes
    convergence diagnostics and evaluation reports into a per-benchmark
    sub-directory of ``--outputdir``, runs the IRT cite method on the same
    split, and finally writes one summary row per benchmark to
    ``per_benchmark_summary_compare.csv``.
    """
    parser = argparse.ArgumentParser(description="Run per-benchmark IRT on merged is_correct matrix with comparison to IRT cite method.")
    parser.add_argument("--input_csv", type=str,
                        default="/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/data_result_raw/merged_mix_benchmark/merged_is_correct_matrix.csv",
                        help="Path to merged *_is_correct matrix CSV.")
    parser.add_argument("--outputdir", type=str,
                        default="yourpath/result_benchmark_compare",
                        help="Directory to write results; subfolders per BENCH will be created.")
    parser.add_argument("--sample_ratio", type=float, default=0.1,
                        help="Column sampling ratio within each benchmark (0,1].")
    parser.add_argument("--train_ratio", type=float, default=0.9,
                        help="Train ratio for masking.")
    parser.add_argument("--train_per_bench", type=int, default=None, help="Per-benchmark number of items for TRAIN split (overrides sample_ratio).")
    parser.add_argument("--val_per_bench", type=int, default=None, help="Per-benchmark number of items for VALID split.")
    parser.add_argument("--test_per_bench", type=int, default=None, help="Per-benchmark number of items for TEST split.")
    parser.add_argument("--draws", type=int, default=1000, help="MCMC draws.")
    parser.add_argument("--tune", type=int, default=1000, help="MCMC tune.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--cores", type=int, default=1, help="Number of CPU cores for sampling.")
    parser.add_argument("--chains", type=int, default=2, help="Number of MCMC chains.")
    parser.add_argument("--engine", type=str, default="pymc", choices=["pymc","numpyro","blackjax"], help="NUTS backend.")
    parser.add_argument("--slope_mode", type=str, default="item", choices=["item","rasch"],
                        help="Discrimination parameterization: per-item (2PL) or Rasch (a=1).")
    parser.add_argument("--target_accept", type=float, default=0.9, help="NUTS target_accept.")
    args = parser.parse_args()

    # The core count (clamped to >= 1) is stored in this environment variable
    # and read back below when calling the sampler and when recording the
    # summary row.
    # NOTE(review): whether PyMC itself reads PYMC_CORES_OVERRIDE is not
    # visible from this file — confirm.
    os.environ["PYMC_CORES_OVERRIDE"] = str(max(1, int(args.cores)))
    print(f"[Config] Using cores={os.environ['PYMC_CORES_OVERRIDE']} (can be modified via --cores)")

    out_root = Path(args.outputdir)
    out_root.mkdir(parents=True, exist_ok=True)

    # Only run these three types, consistent with mix code
    allowed_benches = ["CEVAL", "CSQA", "MMLU"]
    aggregate_rows = []  # one summary dict per successfully processed benchmark

    for bench in allowed_benches:
        print("\n" + "="*80)
        print(f"[Step 1/{bench}] Reading and building single benchmark matrix: {bench}")
        N, J, models, Y_full, train_pos, val_pos, test_pos = build_inputs_single_bench(
            Path(args.input_csv), bench_name=bench,
            sample_ratio=args.sample_ratio, seed=args.seed,
            train_per_bench=args.train_per_bench, val_per_bench=args.val_per_bench, test_per_bench=args.test_per_bench
        )
        # An empty result means the benchmark had nothing to sample; skip it.
        if N == 0 or J == 0:
            continue

        bench_dir = out_root / bench
        bench_dir.mkdir(parents=True, exist_ok=True)

        print(f"[{bench}] N={N}, J={J}")

        print(f"[Step 2/{bench}] Splitting train/val/test sets...")
        use_setwise = any(x is not None for x in (args.train_per_bench, args.val_per_bench, args.test_per_bench))
        if use_setwise:
            # Set-wise mode: whole item columns belong to train, val or test.
            # Only the train columns stay observed in Y_train; all other
            # columns are set to NaN.
            Y_true = Y_full.copy()
            Y_train = Y_full.astype(float).copy()
            col_mask = np.zeros(J, dtype=bool)
            if train_pos is not None and len(train_pos) > 0:
                col_mask[train_pos] = True
            Y_train[:, ~col_mask] = np.nan
            # Construct 2D mask
            def cols_to_mask(cols: Optional[np.ndarray]) -> np.ndarray:
                """Return a boolean mask shaped like Y_true that is True on the given columns."""
                if cols is None or len(cols) == 0:
                    return np.zeros_like(Y_true, dtype=bool)
                m = np.zeros_like(Y_true, dtype=bool)
                m[:, cols] = True
                return m
            Y_train_mask = cols_to_mask(train_pos)
            Y_val_mask   = cols_to_mask(val_pos)
            Y_test_mask  = cols_to_mask(test_pos)
        else:
            # Ratio mode: individual entries are randomly held out for testing;
            # there is no validation split.
            Y_train, Y_test_mask, Y_true = split_train_test(Y_full, train_ratio=args.train_ratio, seed=args.seed)
            Y_val_mask = None
            Y_train_mask = ~np.isnan(Y_train)

        print(f"[Step 3/{bench}] Starting MCMC sampling...")
        trace = build_MCMC_single_bench(
            N=N, J=J, Y_train=Y_train,
            draws=args.draws, tune=args.tune,
            chains=args.chains, cores=int(os.environ["PYMC_CORES_OVERRIDE"]),
            engine=args.engine, slope_mode=args.slope_mode,
            target_accept=args.target_accept
        )

        print(f"[Step 4/{bench}] Sampling completed, convergence diagnosis...")
        evaluate_mcmc_reliability(trace, ["theta", "a", "b"], save_dir=bench_dir)

        print(f"[Step 5/{bench}] Evaluating train/val/test set performance...")
        train_metrics = evaluate_model_single(trace, Y_true, Y_train_mask if use_setwise else ~np.isnan(Y_train), bench_dir, tag="train")
        # Validation is evaluated only in set-wise mode and only when the mask
        # selects something; its metrics are discarded and are not part of the
        # summary row.
        if use_setwise and Y_val_mask is not None and Y_val_mask.any():
            print(f"[{bench}] Valid set evaluation...")
            _ = evaluate_model_single(trace, Y_true, Y_val_mask, bench_dir, tag="valid")
        # Unlike validation, the test evaluation runs unconditionally, even if
        # the test mask selects nothing.
        print(f"[{bench}] Test set evaluation...")
        test_metrics = evaluate_model_single(trace, Y_true, Y_test_mask if use_setwise else np.isnan(Y_train), bench_dir, tag="test")

        # Run IRT cite method
        print(f"[Step 6/{bench}] Running IRT cite method...")
        irt_cite_metrics = build_IRT_cite_single_bench(Y_train, Y_true, Y_test_mask)

        # Save aggregated results for current benchmark
        aggregate_rows.append({
            "benchmark": bench,
            "N_models": N,
            "J_items_sampled": J,
            "train_accuracy": train_metrics["accuracy"],
            "train_auc": train_metrics["auc"],
            "test_accuracy": test_metrics["accuracy"],
            "test_auc": test_metrics["auc"],
            "irt_cite_test_accuracy": irt_cite_metrics["accuracy"],
            "irt_cite_test_auc": irt_cite_metrics["auc"],
            "slope_mode": args.slope_mode,
            "engine": args.engine,
            "draws": args.draws,
            "tune": args.tune,
            "chains": args.chains,
            "cores": int(os.environ["PYMC_CORES_OVERRIDE"]),
            "target_accept": args.target_accept,
            "train_ratio": args.train_ratio,
            "sample_ratio": args.sample_ratio,
            "train_per_bench": args.train_per_bench,
            "val_per_bench": args.val_per_bench,
            "test_per_bench": args.test_per_bench,
            "seed": args.seed
        })

        # Also save main parameter point estimates for later comparison
        # (posterior means over chain and draw).
        theta_hat = trace.posterior["theta"].mean(dim=("chain","draw")).values
        a_hat     = trace.posterior["a"].mean(dim=("chain","draw")).values
        b_hat     = trace.posterior["b"].mean(dim=("chain","draw")).values
        pd.DataFrame({"model": models, "theta_hat": theta_hat}).to_csv(bench_dir / "theta_hat.csv", index=False)
        pd.DataFrame({"a_hat": a_hat, "b_hat": b_hat}).to_csv(bench_dir / "item_params.csv", index=False)

    # Output aggregated summary (for horizontal comparison with mix model)
    if len(aggregate_rows) > 0:
        df_aggr = pd.DataFrame(aggregate_rows)
        df_aggr.sort_values(by="benchmark", inplace=True)
        df_aggr.to_csv(out_root / "per_benchmark_summary_compare.csv", index=False)
        print("\n[Summary] Written to per_benchmark_summary_compare.csv")

        # Print comparison results
        print("\n" + "="*80)
        print("Method comparison results:")
        print("="*80)
        for _, row in df_aggr.iterrows():
            print(f"\nBenchmark: {row['benchmark']}")
            print(f"  Original method (MCMC) Test Accuracy: {row['test_accuracy']:.4f}, AUC: {row['test_auc']:.4f}")
            print(f"  IRT cite method Test Accuracy: {row['irt_cite_test_accuracy']:.4f}, AUC: {row['irt_cite_test_auc']:.4f}")
    else:
        print("\n[Summary] No available benchmarks, no aggregated results generated.")


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()