import argparse
import math
import os
import time
import random
from datetime import datetime
from functools import partial  # Import partial for wrapping

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from scipy.special import softmax
from scipy import stats
import rescale_residual as RR
from model_utils import train_model_feature

# =======================================================================
# Scale factor multiplied into the negated per-model training-sigma totals
# before they are passed through softmax to form the training prior
# (see experiment_rr).
ETA_PRIOR = 1


# =======================================================================
# Helper Functions (Unchanged)
# =======================================================================
def set_seed(seed):
    """Seed NumPy's global RNG and Python's ``random`` module.

    Args:
        seed: Value forwarded to ``np.random.seed`` and then ``random.seed``.
    """
    for seeder in (np.random.seed, random.seed):
        seeder(seed)


def sin_noise(n_sample, std=0.1):
    """Sample a 10-feature regression problem with a sine link.

    Features are i.i.d. standard normal. The response is the sine of the
    equal-weight (1/10 each) linear combination of the features, plus noise
    drawn uniformly from ``[0, std)`` (``np.random.random`` scaled by ``std``).

    Args:
        n_sample: Number of rows to draw.
        std: Multiplier applied to the uniform noise term.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 10)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features = 10
    features = np.random.randn(n_sample, n_features)
    weights = np.ones(n_features) / n_features
    signal = np.sin(features @ weights)
    # Drawn after the features so the global RNG stream is consumed in the
    # same order on every call.
    noise = np.random.random(n_sample) * std
    return features, signal + noise


# New dataset generators
def easyX_sparse_normal(n_sample):
    """Sample a sparse linear model with Gaussian features and Gaussian noise.

    There are 300 standard-normal features. The coefficient vector is
    ``max(0, idx % 20 - 18)`` for zero-based feature index ``idx``, i.e. 1 at
    every index with ``idx % 20 == 19`` and 0 elsewhere. Standard-normal
    noise is added to the linear signal.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features, offset = 300, 18
    features = np.random.randn(n_sample, n_features)
    coef = np.array([max(0, idx % 20 - offset) for idx in range(n_features)])
    # Noise is drawn after the features to keep the RNG consumption order.
    noise = np.random.normal(size=n_sample)
    return features, noise + features @ coef


def easyX_sparse_simpleT(n_sample):
    """Sample a sparse linear model with Gaussian features and Student-t noise.

    Same design as the Gaussian-noise sparse generator (300 standard-normal
    features, coefficient ``max(0, idx % 20 - 18)`` per zero-based index), but
    the additive noise is Student-t with 3 degrees of freedom.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features, offset, dof = 300, 18, 3
    features = np.random.randn(n_sample, n_features)
    coef = np.array([max(0, idx % 20 - offset) for idx in range(n_features)])
    # Heavy-tailed noise, drawn after the features (RNG order matters).
    noise = np.random.standard_t(df=dof, size=n_sample)
    return features, noise + features @ coef


def easyX_dense(n_sample):
    """Sample a dense linear model: the response is the feature mean plus noise.

    There are 300 standard-normal features; the response is their row-wise
    mean plus standard-normal noise scaled by ``1 / 300``.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features = 300
    features = np.random.randn(n_sample, n_features)
    row_means = features.mean(axis=1)
    noise = np.random.randn(n_sample) * (1 / n_features)
    return features, row_means + noise


def easytX_sparse_normal(n_sample):
    """Sample a sparse linear model with multivariate-t features, Gaussian noise.

    Features are drawn from a 300-dimensional multivariate Student-t with
    zero location, identity shape matrix and 3 degrees of freedom. The
    coefficient vector is ``max(0, idx % 20 - 18)`` per zero-based index and
    the additive noise is standard normal.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features, offset, dof = 300, 18, 3
    draws = stats.multivariate_t.rvs(
        np.zeros(n_features), np.eye(n_features), df=dof, size=n_sample
    )
    # Broadcasting against a fixed-shape zero matrix keeps X two-dimensional
    # even when the sampler hands back a squeezed array (e.g. n_sample == 1).
    features = np.zeros((n_sample, n_features)) + draws
    coef = np.array([max(0, idx % 20 - offset) for idx in range(n_features)])
    noise = np.random.normal(size=n_sample)
    return features, noise + features @ coef


# =======================================================================
# Wrapper Definitions (Unchanged from V11)
# =======================================================================


def wrap_existing_conformal(base_func, split_portion=None):
    """Wrapper 1: adapt a prior-agnostic method to the common call signature.

    The returned callable accepts ``b_prior`` like every other configured
    aggregator but never forwards it to ``base_func``.

    Args:
        base_func: Callable accepting ``cal_preds``, ``cal_y``, ``cal_sigma``,
            ``test_pred``, ``test_sigma``, ``test_y`` and ``alpha`` as
            keyword arguments.
        split_portion: When not ``None``, forwarded to ``base_func`` as the
            ``split_portion`` keyword; otherwise the keyword is omitted.

    Returns:
        A callable with the shared aggregator signature.
    """
    # Resolved once at wrap time; empty when no split portion is configured.
    optional_kwargs = {} if split_portion is None else {"split_portion": split_portion}

    def configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        return base_func(
            cal_preds=cal_preds,
            cal_y=cal_y,
            cal_sigma=cal_sigma,
            test_pred=test_pred,
            test_sigma=test_sigma,
            test_y=test_y,
            alpha=alpha,
            **optional_kwargs,
        )

    return configured_func


def wrap_stable_conformal(base_func, eta=1.0, ignore_prior=False):
    """Wrapper 2: bind ``eta`` and the prior policy for a stable-conformal method.

    Args:
        base_func: Callable accepting the shared calibration/test keywords
            plus ``b_prior`` and ``eta``.
        eta: Value forwarded unchanged as the ``eta`` keyword.
        ignore_prior: When true, the caller-supplied ``b_prior`` is replaced
            by a uniform vector over the first axis of ``cal_preds``.

    Returns:
        A callable with the shared aggregator signature.
    """

    def configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        # Read unconditionally, so cal_preds must expose ``.shape`` even
        # when the supplied prior is used as-is.
        n_models = cal_preds.shape[0]
        if ignore_prior:
            prior = np.ones(n_models) / n_models
        else:
            prior = b_prior
        forwarded = {
            "cal_preds": cal_preds,
            "cal_y": cal_y,
            "cal_sigma": cal_sigma,
            "test_pred": test_pred,
            "test_sigma": test_sigma,
            "test_y": test_y,
            "alpha": alpha,
            "b_prior": prior,
            "eta": eta,
        }
        return base_func(**forwarded)

    return configured_func


def wrap_adaptive_stable_conformal(base_func, ratio=0.75, ignore_prior=False):
    """Wrapper 3: bind the alpha ratio and prior policy for an adaptive method.

    Args:
        base_func: Callable accepting the shared calibration/test keywords
            plus ``alpha_prime``, ``alpha`` and ``b_prior``.
        ratio: Factor such that ``alpha_prime = alpha * ratio`` is passed to
            ``base_func`` alongside the original ``alpha``.
        ignore_prior: When true, the caller-supplied ``b_prior`` is replaced
            by a uniform vector over the first axis of ``cal_preds``.

    Returns:
        A callable with the shared aggregator signature.
    """

    def configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        # Read unconditionally, so cal_preds must expose ``.shape`` even
        # when the supplied prior is used as-is.
        n_models = cal_preds.shape[0]
        if ignore_prior:
            prior = np.ones(n_models) / n_models
        else:
            prior = b_prior
        forwarded = {
            "cal_preds": cal_preds,
            "cal_y": cal_y,
            "cal_sigma": cal_sigma,
            "test_pred": test_pred,
            "test_sigma": test_sigma,
            "test_y": test_y,
            "alpha_prime": alpha * ratio,
            "alpha": alpha,
            "b_prior": prior,
        }
        return base_func(**forwarded)

    return configured_func


def wrap_internal_split_calibration(
    base_func,
    alpha_pre_selection=0.1,
    alpha_post_selection=0.75,
    N_resamples=10,
    preliminary_gamma=0.9,
    aux_split_ratio=0.5,
    ignore_prior=False,
):
    """Wrapper 4: bind the settings of a calibrate-after-selection method.

    Args:
        base_func: Callable accepting the shared calibration/test keywords,
            ``alpha``, ``b_prior`` and the five configuration keywords below.
        alpha_pre_selection: Forwarded unchanged to ``base_func``.
        alpha_post_selection: Forwarded unchanged to ``base_func``.
        N_resamples: Forwarded unchanged to ``base_func``.
        preliminary_gamma: Forwarded unchanged to ``base_func``.
        aux_split_ratio: Forwarded unchanged to ``base_func``.
        ignore_prior: When true, the caller-supplied ``b_prior`` is replaced
            by a uniform vector over the first axis of ``cal_preds``.

    Returns:
        A callable with the shared aggregator signature.
    """

    def configured_func(
        cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, b_prior
    ):
        # Read unconditionally, so cal_preds must expose ``.shape`` even
        # when the supplied prior is used as-is.
        n_models = cal_preds.shape[0]
        if ignore_prior:
            prior = np.ones(n_models) / n_models
        else:
            prior = b_prior

        forwarded = {
            "cal_preds": cal_preds,
            "cal_y": cal_y,
            "cal_sigma": cal_sigma,
            "test_pred": test_pred,
            "test_sigma": test_sigma,
            "test_y": test_y,
            "alpha": alpha,
            "b_prior": prior,
            "alpha_pre_selection": alpha_pre_selection,
            "alpha_post_selection": alpha_post_selection,
            "N_resamples": N_resamples,
            "preliminary_gamma": preliminary_gamma,
            "aux_split_ratio": aux_split_ratio,
        }
        return base_func(**forwarded)

    return configured_func


def process_rep(
    t,
    n_cal,
    cal_preds_all,
    cal_sigma_all,
    Y_cal_all,
    test_pred_all,
    test_sigma_all,
    Y_test_all,
    configured_aggregators,
    train_prior,
    alpha,
):
    """Run every configured aggregator on repetition ``t``.

    Repetition ``t`` uses calibration columns ``[t * n_cal, (t + 1) * n_cal)``
    of the precomputed arrays and the single test point at column ``t``.

    Args:
        t: Zero-based repetition index.
        n_cal: Number of calibration points per repetition.
        cal_preds_all: 2-D array, sliced as ``[:, start:end]``.
        cal_sigma_all: 2-D array, sliced like ``cal_preds_all``.
        Y_cal_all: 1-D array of calibration responses, sliced ``[start:end]``.
        test_pred_all: 2-D array; column ``t`` is used.
        test_sigma_all: 2-D array; column ``t`` is used.
        Y_test_all: 1-D array; element ``t`` is used.
        configured_aggregators: Mapping from method name to a callable with
            the shared aggregator signature. Each callable must return a
            sequence whose first two items are coverage and length.
        train_prior: Passed to every aggregator as ``b_prior``.
        alpha: Passed to every aggregator as ``alpha``.

    Returns:
        Dict mapping method name to ``(coverage, length)`` as Python floats.
    """
    cal_window = slice(t * n_cal, (t + 1) * n_cal)
    # The same keyword bundle (and thus the same array views) is handed to
    # every aggregator.
    shared_kwargs = {
        "cal_preds": cal_preds_all[:, cal_window],
        "cal_y": Y_cal_all[cal_window],
        "cal_sigma": cal_sigma_all[:, cal_window],
        "test_pred": test_pred_all[:, t],
        "test_sigma": test_sigma_all[:, t],
        "test_y": Y_test_all[t],
        "alpha": alpha,
        "b_prior": train_prior,  # Always pass train_prior
    }

    outcome = {}
    for method_name, aggregator in configured_aggregators.items():
        # Any return values beyond the first two are discarded.
        coverage, length, *_ = aggregator(**shared_kwargs)
        outcome[method_name] = (float(coverage), float(length))
    return outcome


# =======================================================================
# Experiment Runner Function (Unchanged from V11)
# =======================================================================
def experiment_rr(
    data_gene,
    configured_aggregators,  # Single dict of ALL methods
    N_rep,
    n_tr,
    n_cal,
    M,
    alpha,
    num_partitions,
    seed,
):
    """Train the model pool once, then score every aggregator over ``N_rep`` reps.

    Steps: seed the RNGs, draw training data and fit the models, draw all
    calibration and test data up front, precompute every model's outputs,
    build the training prior, evaluate the repetitions in parallel and print
    and return per-method summary statistics.

    Args:
        data_gene: Callable ``n -> (X, Y)`` producing a synthetic sample.
        configured_aggregators: Mapping from method name to a callable with
            the shared aggregator signature.
        N_rep: Number of repetitions; each uses one test point and its own
            block of ``n_cal`` calibration points.
        n_tr: Training-set size.
        n_cal: Calibration points per repetition.
        M: Requested model count, forwarded to ``train_model_feature``.
        alpha: Level forwarded to every aggregator.
        num_partitions: Forwarded to ``train_model_feature`` as
            ``number_partitions``.
        seed: Seed applied through ``set_seed`` before any data is drawn.

    Returns:
        Tuple ``(report_coverage, report_length, all_names)``. The two dicts
        map each method name to ``[mean, std]`` over the non-NaN repetitions
        (``[nan, nan]`` when there are none); ``all_names`` is the sorted
        list of method names.
    """
    set_seed(seed)

    # 1. Training
    X_tr, Y_tr = data_gene(n_tr)
    models = train_model_feature(M, X_tr, Y_tr, number_partitions=num_partitions)
    n_models = len(models)
    print(f"Trained {n_models} models.")

    # 2. Draw all calibration data, then all test data (RNG order matters).
    X_cal_all, Y_cal_all = data_gene(N_rep * n_cal)
    X_test_all, Y_test_all = data_gene(N_rep)

    def stack_point_outputs(X):
        # Each model entry is indexable; element 0 is used as the predictor
        # whose output is stored under the *_preds / *_pred names.
        return np.array([entry[0].predict(X) for entry in models])

    def stack_sigma_outputs(X):
        # Element 1 supplies the sigma estimates; floor them at 1e-6.
        raw = np.array([entry[1].predict(X) for entry in models])
        return np.clip(raw, 1e-6, np.inf)

    # 3. Precompute predictions, one row per model.
    print("Precomputing predictions...")
    start_pred = time.time()
    cal_preds_all = stack_point_outputs(X_cal_all)
    cal_sigma_all = stack_sigma_outputs(X_cal_all)
    test_pred_all = stack_point_outputs(X_test_all)
    test_sigma_all = stack_sigma_outputs(X_test_all)
    print(f"Prediction precomputation time: {time.time()-start_pred:.2f}s")

    # 3'. Training prior: softmax over the negated per-model sigma totals.
    # NOTE(review): this is the per-model SUM of training sigmas, not the
    # mean, so the prior's sharpness grows with n_tr -- confirm intended.
    sigma_totals = np.sum(stack_sigma_outputs(X_tr), axis=1)
    train_prior = softmax(ETA_PRIOR * -sigma_totals)
    print("Training Prior computed.")

    # 4. One parallel job per repetition; all jobs share the same prior.
    print("Starting parallel processing...")
    start_par = time.time()
    rep_jobs = (
        delayed(process_rep)(
            t=t,
            n_cal=n_cal,
            cal_preds_all=cal_preds_all,
            cal_sigma_all=cal_sigma_all,
            Y_cal_all=Y_cal_all,
            test_pred_all=test_pred_all,
            test_sigma_all=test_sigma_all,
            Y_test_all=Y_test_all,
            configured_aggregators=configured_aggregators,
            train_prior=train_prior,
            alpha=alpha,
        )
        for t in tqdm(range(N_rep), desc="Processing reps")
    )
    results_per_rep = Parallel(n_jobs=-1, backend="loky")(rep_jobs)
    print(f"Parallel processing finished. Elapsed: {time.time()-start_par:.2f}s")

    # 5. Collect per-repetition values; a method missing from a repetition's
    # result is recorded as NaN.
    all_names = sorted(configured_aggregators)
    missing = (np.nan, np.nan)
    coverage_by_method = {name: np.zeros(N_rep) for name in all_names}
    length_by_method = {name: np.zeros(N_rep) for name in all_names}
    for rep_idx, rep_result in enumerate(results_per_rep):
        for name in all_names:
            cov_value, len_value = rep_result.get(name, missing)
            coverage_by_method[name][rep_idx] = cov_value
            length_by_method[name][rep_idx] = len_value

    # 6. Summarise, ignoring NaN entries.
    report_coverage, report_length = {}, {}
    print("\n--- Aggregated Results ---")
    for name in all_names:
        cov_values = coverage_by_method[name]
        len_values = length_by_method[name]
        valid_cov = cov_values[~np.isnan(cov_values)]
        valid_len = len_values[~np.isnan(len_values)]
        if len(valid_cov) == 0:
            print(f"{name:<30}: No valid results.")
            report_coverage[name] = [np.nan, np.nan]
            report_length[name] = [np.nan, np.nan]
            continue
        cov_mean, cov_std = np.mean(valid_cov), np.std(valid_cov)
        len_mean, len_std = np.mean(valid_len), np.std(valid_len)
        report_coverage[name] = [cov_mean, cov_std]
        report_length[name] = [len_mean, len_std]
        print(
            f"{name:<30}: Coverage={cov_mean:.4f} ± {cov_std:.4f}, Length={len_mean:.4f} ± {len_std:.4f} ({len(valid_cov)} reps)"
        )
    print("-" * 31)
    return report_coverage, report_length, all_names


# =======================================================================
# Main Execution Block
# =======================================================================
def _parse_m_values(raw):
    """Convert the ``--M_values`` string into a list of positive ints.

    Args:
        raw: Comma-separated integers, or an empty string for the defaults.

    Returns:
        The parsed list, or ``[10, 20, ..., 100]`` when ``raw`` is empty.

    Raises:
        ValueError: If a token is not an integer or any value is not positive.
    """
    if not raw:
        return [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]  # Default values
    try:
        values = [int(token.strip()) for token in raw.split(",")]
        if not all(value > 0 for value in values):
            raise ValueError("All M values must be positive integers")
    except ValueError as e:
        # Chain the cause so the traceback shows the offending token.
        raise ValueError(
            f"Invalid M values format. Please provide comma-separated integers. Error: {e}"
        ) from e
    return values


def _build_aggregators(args):
    """Assemble the name -> configured aggregator mapping used by every run.

    Args:
        args: Parsed CLI namespace; reads ``split_portion``, ``N_resamples``
            and ``alpha``.

    Returns:
        Dict mapping method name to a callable with the shared aggregator
        signature, in registration order.
    """
    aggregators = {}

    # 1. Prior-agnostic existing methods; only YKsplit takes a split portion.
    existing_base_funcs = {
        "ModSel": RR.ModSel_rescaleRes,
        "YKbaseline": RR.YKbaseline_rescaleRes,
        "YK_adj": RR.YK_adj_rescaleRes,
        "YKsplit": RR.YKsplit_rescaleRes,
    }
    existing_fixed_params = {"YKsplit": {"split_portion": args.split_portion}}
    for name, base_func in existing_base_funcs.items():
        aggregators[name] = wrap_existing_conformal(
            base_func, **existing_fixed_params.get(name, {})
        )

    # 2. Stable methods (uniform prior). Only names present in
    # stable_method_base_funcs are registered, so the eta=2.0 and eta=0.1
    # parameter sets below are currently inactive.
    stable_method_base_funcs = {"StableEta1.0": RR.stable_conformal}
    stable_method_fixed_params = {
        "StableEta1.0": {"eta": 1.0},
        "StableEta2.0": {"eta": 2.0},
        "StableEta0.1": {"eta": 0.1},
    }
    for base_name, base_func in stable_method_base_funcs.items():
        aggregators[f"UP {base_name}"] = wrap_stable_conformal(
            base_func,
            **stable_method_fixed_params.get(base_name, {}),
            ignore_prior=True,
        )

    # 3. Adaptive stable methods (uniform prior), one per alpha ratio.
    adaptive_stable_configs = {
        "AdaStable0.50": {"ratio": 0.50},
        "AdaStable0.75": {"ratio": 0.75},
        "AdaStable0.90": {"ratio": 0.90},
    }
    for base_name, fixed_params in adaptive_stable_configs.items():
        aggregators[f"UP {base_name}"] = wrap_adaptive_stable_conformal(
            base_func=RR.adaptive_stable_conformal, **fixed_params, ignore_prior=True
        )

    # 4. Calibrate-after-selection methods (uniform prior). The variants
    # differ only in alpha_post_selection, which also names them.
    for post_alpha in (0.2, 0.5, 0.8, 1.0):
        aggregators[f"InstCal{post_alpha}-UP"] = wrap_internal_split_calibration(
            base_func=RR.calibrate_after_selection_resampling,
            alpha_pre_selection=0.1,
            alpha_post_selection=post_alpha,
            N_resamples=args.N_resamples,
            preliminary_gamma=1 - args.alpha,
            aux_split_ratio=0.5,
            ignore_prior=True,
        )

    return aggregators


def main(args):
    """Run the synthetic experiment grid and write the results to CSV.

    For every ``M`` in the requested list and every seed in
    ``range(args.num_seeds)``, runs ``experiment_rr`` with all configured
    aggregators and appends one summary row. The accumulated table is
    rewritten to a timestamped CSV under ``synthetic_results/`` after each
    run, so partial results survive an interrupted job.

    Args:
        args: Parsed CLI namespace with ``distribution``, ``M_values``,
            ``split_portion``, ``N_resamples``, ``alpha``, ``num_partition``,
            ``N_cal``, ``N_rep``, ``n_tr`` and ``num_seeds``.

    Raises:
        ValueError: If the distribution name is unknown or ``M_values``
            cannot be parsed into positive integers.
    """
    distribution_map = {
        "sin_noise": sin_noise,
        "easyX_sparse_normal": easyX_sparse_normal,
        "easyX_sparse_simpleT": easyX_sparse_simpleT,
        "easyX_dense": easyX_dense,
        "easytX_sparse_normal": easytX_sparse_normal,
    }
    if args.distribution not in distribution_map:
        raise ValueError(f"Invalid distribution: {args.distribution}")
    data_generator = distribution_map[args.distribution]
    print(f"Using distribution: {args.distribution}")

    M_values = _parse_m_values(args.M_values)
    all_configured_aggregators = _build_aggregators(args)

    # --- Results File Setup ---
    dist_name = args.distribution
    part_str = f"P{args.num_partition}" if args.num_partition > 0 else "Homog"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    base_filename = f"SYNTH_{dist_name}_{part_str}_Ncal{args.N_cal}"
    results_dir = "synthetic_results"
    os.makedirs(results_dir, exist_ok=True)
    filename = os.path.join(results_dir, f"{base_filename}_{timestamp}.csv")
    print(f"Results will be saved to: {filename}")

    # --- Main Loop ---
    all_results_data = []  # One summary row per (M, seed) run

    for M in M_values:
        for seed in range(args.num_seeds):
            print(f"\n{'='*50}")
            print(
                f"STARTING: M={M}, Seed={seed+1}/{args.num_seeds}, N_cal={args.N_cal}, Partitioning={part_str}, Dist={dist_name}"
            )
            print(f"{'='*50}")
            set_seed(seed)

            cov_rep, leng_rep, result_names = experiment_rr(
                data_gene=data_generator,
                configured_aggregators=all_configured_aggregators,
                N_rep=args.N_rep,
                n_tr=args.n_tr,
                n_cal=args.N_cal,  # Total cal size per rep
                M=M,
                alpha=args.alpha,
                num_partitions=args.num_partition,
                seed=seed,
            )

            # --- Store results for this run ---
            row_data = {
                "M": M,
                "N_cal": args.N_cal,
                "Partitions": args.num_partition,
                "Seed": seed,
                "Distribution": dist_name,
            }
            for name in result_names:
                cov_stats = cov_rep.get(name, [np.nan, np.nan])
                len_stats = leng_rep.get(name, [np.nan, np.nan])
                row_data[f"{name}_CovMean"] = cov_stats[0]
                row_data[f"{name}_CovStd"] = cov_stats[1]
                row_data[f"{name}_LenMean"] = len_stats[0]
                row_data[f"{name}_LenStd"] = len_stats[1]
            all_results_data.append(row_data)

            # --- Save accumulated results to CSV after each run ---
            # The whole table is rewritten each time so the header always
            # matches the current set of columns.
            current_results_df = pd.DataFrame(all_results_data)
            current_results_df.to_csv(filename, index=False, mode="w", header=True)
            print(f"Results updated and saved to {filename} after M={M}, Seed={seed}")

    # Final message (saving happens inside the loop)
    print(f"\nAll experiments complete. Final results saved in: {filename}")


if __name__ == "__main__":
    # Command-line entry point: parse the options, validate them, run main().
    parser = argparse.ArgumentParser(
        description="Run Synthetic Conformal Prediction Experiments"
    )
    # Data Generation & Sizes
    parser.add_argument(
        "--distribution", type=str, default="sin_noise", help="Data distribution"
    )
    parser.add_argument("--n_tr", type=int, default=2000, help="Training set size")
    parser.add_argument(
        "--N_cal", type=int, default=400, help="Total calibration set size (per rep)"
    )
    parser.add_argument(
        "--N_rep", type=int, default=500, help="Number of repetitions (test points)"
    )
    # Model & Conformal Params
    # alpha is the miscoverage level: main() derives preliminary_gamma as
    # 1 - alpha, so the default 0.1 targets 90% coverage.
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.1,
        help="Miscoverage level (target coverage is 1 - alpha)",
    )
    parser.add_argument(
        "--num_partition",
        type=int,
        default=5,
        help="Num partitions (-1 for homogeneous)",
    )
    parser.add_argument(
        "--split_portion",
        type=float,
        default=0.5,
        help="Internal split portion for YKsplit, Recalibration",
    )
    parser.add_argument(
        "--N_resamples",
        type=int,
        default=1,
        help="N for resampling in post-hoc calibration",
    )
    parser.add_argument(
        "--M_values",
        type=str,
        default="",
        help="Comma-separated list of M values to test (e.g., '10,20,30'). If not provided, uses default values [10,20,...,100]",
    )
    # Execution Params
    parser.add_argument(
        "--num_seeds", type=int, default=20, help="Number of random seeds"
    )

    args = parser.parse_args()
    if args.N_cal < 2:
        raise ValueError("N_cal must be at least 2.")
    # Reject a level outside (0, 1) up front, e.g. a coverage value such as
    # 90 passed by mistake, before any training starts.
    if not 0 < args.alpha < 1:
        raise ValueError("alpha must be strictly between 0 and 1.")
    main(args)
