import argparse
import os
import time
from datetime import datetime
import random

import numpy as np
import pandas as pd
from tqdm import tqdm
from joblib import Parallel, delayed
from scipy import stats

from model_utils import train_model_feature


def set_seed(seed):
    """Seed both the stdlib ``random`` module and NumPy's global RNG.

    Args:
        seed: Value forwarded unchanged to ``random.seed`` and
            ``np.random.seed``.
    """
    # The two generators are independent, so the seeding order is irrelevant.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


# ------------------------- Data Generators (same as synthetic_exp) -------------------------
def sin_noise(n_sample, std=0.1):
    """Generate a 10-feature regression sample with a sinusoidal response.

    Features are i.i.d. standard normal. The response is the sine of the
    equally weighted average of the features plus additive noise.

    Args:
        n_sample: Number of rows to draw.
        std: Multiplier applied to the noise. The noise comes from
            ``np.random.random``, i.e. it is uniform on ``[0, 1)`` before
            scaling, so this is a scale factor rather than a standard
            deviation.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 10)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features = 10
    features = np.random.randn(n_sample, n_features)
    weights = np.full(n_features, 1.0 / n_features)
    signal = features @ weights
    # Noise is drawn after the features so the global RNG stream is consumed
    # in a fixed order.
    noise = np.random.random(n_sample) * std
    return features, np.sin(signal) + noise


def easyX_sparse_normal(n_sample):
    """Generate a sparse linear model with Gaussian features and Gaussian noise.

    Features are 300-dimensional i.i.d. standard normal. The coefficient at
    0-based position ``i`` is ``max(0, i % 20 - 18)``, which is 1 at every
    20th position (19, 39, ...) and 0 elsewhere. The response is the linear
    signal plus standard-normal noise.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features, offset = 300, 18
    features = np.random.randn(n_sample, n_features)
    coef = np.maximum(np.arange(n_features) % 20 - offset, 0)
    # Noise is drawn after the features to keep the RNG stream order fixed.
    noise = np.random.normal(size=n_sample)
    return features, noise + features @ coef


def easyX_sparse_simpleT(n_sample):
    """Generate a sparse linear model with Gaussian features and Student-t noise.

    Features are 300-dimensional i.i.d. standard normal. The coefficient at
    0-based position ``i`` is ``max(0, i % 20 - 18)``, which is 1 at every
    20th position (19, 39, ...) and 0 elsewhere. The additive noise is
    Student-t with 3 degrees of freedom.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features, offset, dof = 300, 18, 3
    features = np.random.randn(n_sample, n_features)
    coef = np.maximum(np.arange(n_features) % 20 - offset, 0)
    # Noise is drawn after the features to keep the RNG stream order fixed.
    noise = np.random.standard_t(df=dof, size=n_sample)
    return features, noise + features @ coef


def easyX_dense(n_sample):
    """Generate a dense linear model where every feature contributes equally.

    Features are 300-dimensional i.i.d. standard normal. The response is the
    per-row mean of the features plus standard-normal noise scaled by
    ``1 / 300``.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features = 300
    features = np.random.randn(n_sample, n_features)
    row_means = features.mean(axis=1)
    noise = np.random.randn(n_sample) * (1 / n_features)
    return features, row_means + noise


def easytX_sparse_normal(n_sample):
    """Generate a sparse linear model with multivariate-t features.

    Features are drawn from a 300-dimensional multivariate Student-t with
    zero location, identity shape matrix and 3 degrees of freedom. The
    coefficient at 0-based position ``i`` is ``max(0, i % 20 - 18)``, which
    is 1 at every 20th position (19, 39, ...) and 0 elsewhere. The additive
    noise is standard normal.

    Args:
        n_sample: Number of rows to draw.

    Returns:
        Tuple ``(X, Y)`` with ``X`` of shape ``(n_sample, 300)`` and ``Y`` of
        shape ``(n_sample,)``.
    """
    n_features, offset, dof = 300, 18, 3
    # Adding in place into a pre-shaped buffer pins X to (n_sample, d)
    # regardless of the exact shape returned by ``rvs``.
    features = np.zeros((n_sample, n_features))
    features += stats.multivariate_t.rvs(
        np.zeros(n_features), np.eye(n_features), df=dof, size=n_sample
    )
    coef = np.maximum(np.arange(n_features) % 20 - offset, 0)
    # Noise is drawn after the features to keep the RNG stream order fixed.
    noise = np.random.normal(size=n_sample)
    return features, noise + features @ coef


def avg_split_conformal(cal_mu, cal_y, cal_sigma, test_mu, test_sigma, test_y, alpha):
    """Evaluate a scale-normalized split conformal interval on one test point.

    The calibration scores are ``|cal_mu - cal_y| / (cal_sigma + 1e-8)``. The
    interval half-width is the ``k``-th smallest score, with
    ``k = ceil((n + 1) * (1 - alpha))``, multiplied by ``test_sigma``.

    Args:
        cal_mu: 1-D array of point predictions on the calibration set.
        cal_y: 1-D array of calibration responses, same length as ``cal_mu``.
        cal_sigma: 1-D array of scale estimates on the calibration set.
        test_mu: Point prediction for the test point.
        test_sigma: Scale estimate for the test point.
        test_y: True response of the test point.
        alpha: Target miscoverage level.

    Returns:
        Tuple ``(cover, length)``: ``cover`` is 1.0 if ``test_y`` lies in the
        interval and 0.0 otherwise; ``length`` is the interval width. When
        the calibration set is too small for the requested level
        (``k > n``), the interval is the whole real line and the result is
        ``(1.0, inf)``.
    """
    n = cal_mu.shape[0]
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:
        # The required order statistic does not exist (including n == 0).
        # The valid conformal quantile is +inf, so the interval is unbounded
        # and trivially covers; indexing scores[k - 1] would raise.
        return 1.0, float("inf")
    # For alpha >= 1 the rank is <= 0, and scores[k - 1] would wrap around to
    # the largest score; fall back to the smallest score instead.
    k = max(k, 1)

    scores = np.sort(np.abs(cal_mu - cal_y) / (cal_sigma + 1e-8))
    S = scores[k - 1]
    cover = float(np.abs(test_y - test_mu) <= S * test_sigma)
    length = float(2 * S * test_sigma)
    return cover, length


def run_avg_baseline(
    data_gene,
    N_rep,
    n_tr,
    n_cal,
    M,
    alpha,
    num_partitions,
    seed,
):
    """Run the averaged-ensemble split conformal baseline for one seed.

    An ensemble is trained once on ``n_tr`` samples. Its member predictions
    and scale estimates are averaged, and the averaged predictor is
    evaluated over ``N_rep`` repetitions. Each repetition uses its own block
    of ``n_cal`` calibration points and a single test point.

    Args:
        data_gene: Callable ``data_gene(n) -> (X, Y)`` producing ``n`` samples.
        N_rep: Number of repetitions (one test point each).
        n_tr: Training set size.
        n_cal: Calibration points per repetition.
        M: Passed to ``train_model_feature`` as its first argument.
        alpha: Target miscoverage level.
        num_partitions: Forwarded as ``number_partitions`` to
            ``train_model_feature``.
        seed: Seed applied via ``set_seed`` before any data is drawn.

    Returns:
        Tuple ``(cov_mean, cov_std, len_mean, len_std)`` of Python floats:
        mean and standard deviation of the coverage indicator and of the
        interval length across repetitions.
    """
    set_seed(seed)

    # Fit the ensemble on a fresh training draw. Each returned element is
    # indexed below as [0] for the point predictor and [1] for the scale
    # predictor; both are expected to expose ``predict``.
    train_X, train_y = data_gene(n_tr)
    ensemble = train_model_feature(M, train_X, train_y, number_partitions=num_partitions)
    M = len(ensemble)

    # Draw every calibration block first and then all test points, so the
    # global RNG stream is consumed in a fixed order.
    cal_X, cal_y_pool = data_gene(N_rep * n_cal)
    test_X, test_y_pool = data_gene(N_rep)

    def _ensemble_mean(component, inputs, floor=None):
        # Stack one prediction row per ensemble member, optionally floor the
        # values, then average across members.
        stacked = np.array([member[component].predict(inputs) for member in ensemble])
        if floor is not None:
            stacked = np.clip(stacked, floor, np.inf)
        return np.mean(stacked, axis=0)

    # Scale estimates are floored at 1e-6 before averaging.
    cal_mu_pool = _ensemble_mean(0, cal_X)  # (N_rep * n_cal,)
    cal_sigma_pool = _ensemble_mean(1, cal_X, floor=1e-6)  # (N_rep * n_cal,)
    test_mu_pool = _ensemble_mean(0, test_X)  # (N_rep,)
    test_sigma_pool = _ensemble_mean(1, test_X, floor=1e-6)  # (N_rep,)

    def _evaluate(rep):
        # Repetition ``rep`` owns the rep-th contiguous calibration block and
        # the rep-th test point.
        block = slice(rep * n_cal, (rep + 1) * n_cal)
        return avg_split_conformal(
            cal_mu_pool[block],
            cal_y_pool[block],
            cal_sigma_pool[block],
            test_mu_pool[rep],
            test_sigma_pool[rep],
            test_y_pool[rep],
            alpha,
        )

    outcomes = Parallel(n_jobs=-1, backend="loky")(
        delayed(_evaluate)(rep) for rep in range(N_rep)
    )
    covers = np.array([outcome[0] for outcome in outcomes])
    lengths = np.array([outcome[1] for outcome in outcomes])

    return (
        float(np.mean(covers)),
        float(np.std(covers)),
        float(np.mean(lengths)),
        float(np.std(lengths)),
    )


def main(args):
    """Run the averaged-predictor baseline over the requested grid and save CSVs.

    For every selected distribution, each combination of ensemble size ``M``
    and seed is evaluated with ``run_avg_baseline``. The accumulated results
    are written to one timestamped CSV per distribution inside the
    ``avg_baseline`` directory, which is created if missing.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``distribution``, ``M_values``, ``num_partition``, ``N_cal``,
            ``N_rep``, ``n_tr``, ``alpha`` and ``num_seeds``.

    Raises:
        ValueError: If ``args.distribution`` is neither ``"all"`` (any case)
            nor a known generator name, or if ``args.M_values`` is not a
            comma-separated list of positive integers.
    """
    distribution_map = {
        "sin_noise": sin_noise,
        "easyX_sparse_normal": easyX_sparse_normal,
        "easyX_sparse_simpleT": easyX_sparse_simpleT,
        "easyX_dense": easyX_dense,
        "easytX_sparse_normal": easytX_sparse_normal,
    }

    # Resolve which generators to run; an empty value or "all" selects all.
    requested = args.distribution
    if not requested or requested.lower() == "all":
        distributions = list(distribution_map)
    elif requested in distribution_map:
        distributions = [requested]
    else:
        raise ValueError(f"Invalid distribution: {args.distribution}")

    # Resolve the ensemble sizes to sweep.
    if not args.M_values:
        M_values = list(range(10, 90, 10))
    else:
        try:
            M_values = [int(token.strip()) for token in args.M_values.split(',')]
            if any(m <= 0 for m in M_values):
                raise ValueError("All M values must be positive integers")
        except ValueError as e:
            raise ValueError(
                f"Invalid M values format. Please provide comma-separated integers. Error: {e}"
            )

    part_str = f"P{args.num_partition}" if args.num_partition > 0 else "Homog"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = "avg_baseline"
    os.makedirs(results_dir, exist_ok=True)

    for dist_name in distributions:
        generator = distribution_map[dist_name]
        base_filename = f"SYNTH_{dist_name}_{part_str}_Ncal{args.N_cal}"
        filename = os.path.join(results_dir, f"{base_filename}_{timestamp}.csv")
        print(f"Results will be saved to: {filename}")

        records = []
        grid = ((M, seed) for M in M_values for seed in range(args.num_seeds))
        for M, seed in grid:
            print(f"\n{'='*50}")
            print(
                f"AVG BASELINE: M={M}, Seed={seed+1}/{args.num_seeds}, N_cal={args.N_cal}, Partitioning={part_str}, Dist={dist_name}"
            )
            print(f"{'='*50}")

            start = time.time()
            cov_mean, cov_std, len_mean, len_std = run_avg_baseline(
                data_gene=generator,
                N_rep=args.N_rep,
                n_tr=args.n_tr,
                n_cal=args.N_cal,
                M=M,
                alpha=args.alpha,
                num_partitions=args.num_partition,
                seed=seed,
            )
            print(
                f"AvgSplit: Coverage={cov_mean:.4f} +/- {cov_std:.4f}, Length={len_mean:.4f} +/- {len_std:.4f} (elapsed {time.time()-start:.2f}s)"
            )

            records.append(
                {
                    "M": M,
                    "N_cal": args.N_cal,
                    "Partitions": args.num_partition,
                    "Seed": seed,
                    "Distribution": dist_name,
                    "AvgSplit_CovMean": cov_mean,
                    "AvgSplit_CovStd": cov_std,
                    "AvgSplit_LenMean": len_mean,
                    "AvgSplit_LenStd": len_std,
                }
            )

            # Rewrite the whole file after every run so partial results
            # survive an interrupted sweep.
            pd.DataFrame(records).to_csv(filename, index=False, mode="w", header=True)
            print(f"Results updated and saved to {filename}")

        print(f"\nAll AVG baseline experiments complete for {dist_name}. File: {filename}")


# Command-line entry point: parse the options, validate them, run the sweep.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Average predictor split conformal baseline")
    parser.add_argument("--distribution", type=str, default="all", help="Data distribution (or 'all')")
    parser.add_argument("--n_tr", type=int, default=2000, help="Training set size")
    parser.add_argument("--N_cal", type=int, default=400, help="Calibration set size per repetition")
    parser.add_argument("--N_rep", type=int, default=500, help="Number of repetitions (test points)")
    parser.add_argument("--alpha", type=float, default=0.1, help="Miscoverage level")
    parser.add_argument(
        "--num_partition",
        type=int,
        default=5,
        help="Num partitions (-1 for homogeneous)",
    )
    # The help text states the default grid that main() actually uses
    # (10 through 80 in steps of 10).
    parser.add_argument(
        "--M_values",
        type=str,
        default="",
        help="Comma-separated list of M values to test (e.g., '10,20,30'). If not provided, uses default values [10,20,...,80]",
    )
    parser.add_argument("--num_seeds", type=int, default=20, help="Number of random seeds")

    args = parser.parse_args()
    # Reject calibration sets with fewer than two points before any work starts.
    if args.N_cal < 2:
        raise ValueError("N_cal must be at least 2.")
    main(args)


