import os, sys, inspect
from conformal import * 
from utils import *
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax
import torch
import random
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import csv

class OptimalFunction(torch.nn.Module):
    """Trainable score transform parameterised by ``m`` coefficients.

    The module owns a single float32 parameter vector, ``coefficients``,
    initialised to zeros. The forward pass delegates to
    ``fourier_integral_exp_vectorized`` with that vector.
    """

    def __init__(self, m):
        """Create the module with ``m`` zero-initialised coefficients."""
        super(OptimalFunction, self).__init__()
        initial = torch.zeros(m, dtype=torch.float32)
        self.register_parameter("coefficients", torch.nn.Parameter(initial))

    def forward(self, s):
        """Return ``fourier_integral_exp_vectorized(s, self.coefficients)``."""
        weights = self.coefficients
        return fourier_integral_exp_vectorized(s, weights)

def search_optimal_function(logits_cal, logits_train, beta, gamma, T, epsilon, m, alpha, kreg, lamda, bsz):
    """Fit the coefficients of an ``OptimalFunction`` with mini-batch AdamW.

    Before every optimisation step the current coefficients are applied to
    the calibration data to pick a reference slice ``tau``; the training
    batch is then used to take one gradient step on ``calculate_loss``.

    Args:
        logits_cal: Calibration data, passed to ``get_softmax_targets``.
        logits_train: Dataset wrapped in a ``DataLoader``; each batch is
            unpacked as ``(logits, targets)`` and only the logits are used.
        beta: Forwarded unchanged to ``calculate_loss``.
        gamma: Learning rate for AdamW.
        T: Number of passes over ``logits_train``.
        epsilon: Stop early once the L2 norm of the coefficient change
            between two consecutive steps drops below this value.
        m: Number of coefficients in the transform.
        alpha: Forwarded to ``find_quantile_sample``.
        kreg: Forwarded to ``index_save``.
        lamda: Forwarded to ``index_save``.
        bsz: Mini-batch size for the training loader.

    Returns:
        A detached copy of the fitted coefficient tensor.
    """
    calib_scores, calib_labels = get_softmax_targets(logits_cal)
    calib_I, calib_ordered = index_save(calib_scores, kreg, lamda)

    transform = OptimalFunction(m)
    adamw = torch.optim.AdamW(transform.parameters(), lr=gamma, weight_decay=1e-4)
    batches = torch.utils.data.DataLoader(logits_train, batch_size=bsz, shuffle=False, pin_memory=True)

    last_coefficients = None  # coefficients after the previous step, if any

    for _ in range(T):
        for batch_logits, _targets in batches:
            # Derive tau from the calibration data using the current
            # coefficients; no gradient is tracked for this part.
            with torch.no_grad():
                snapshot = transform.coefficients.clone()
                transformed = fourier_integral_exp_vectorized(calib_ordered, snapshot)
                sample_idx, true_label_pos = find_quantile_sample(transformed, calib_I, calib_labels, alpha)
                tau = calib_ordered[sample_idx][:true_label_pos + 1]

            # One gradient step on the training batch.
            adamw.zero_grad()
            batch_probs = torch.softmax(batch_logits, dim=1)
            _, batch_ordered = index_save(batch_probs, kreg, lamda)
            batch_loss = calculate_loss(transform, batch_ordered, beta, tau)
            batch_loss.backward()
            adamw.step()

            # Convergence test on the size of the update just applied.
            with torch.no_grad():
                updated = transform.coefficients.clone()
                if last_coefficients is not None and torch.norm(updated - last_coefficients) < epsilon:
                    return updated.detach()
                last_coefficients = updated

    return transform.coefficients.clone().detach()

def trial(model, logits, alpha, kreg, lamda, randomized, n_data_conf, n_data_val, n_data_train_conf, n_data_train, bsz, gamma, T, m):
    """Run one split of the data and evaluate baseline vs. transformed predictor.

    The logits are divided by ``split4`` into four parts. A baseline
    ``ConformalModelLogits`` is calibrated and validated first; its
    ``kreg`` and ``lamda`` attributes are then reused to search for
    transform coefficients and to build a second conformal model that
    receives those coefficients.

    Args:
        model: Passed through to ``ConformalModelLogits``.
        logits: Full dataset handed to ``split4``.
        alpha: Forwarded to both conformal models and to the coefficient search.
        kreg, lamda: Initial values for the baseline conformal model.
        randomized: Forwarded to both conformal models.
        n_data_conf, n_data_val, n_data_train_conf, n_data_train:
            Split sizes forwarded to ``split4``.
        bsz: Batch size for every data loader.
        gamma, T, m: Forwarded to ``search_optimal_function``.

    Returns:
        ``(cvg_avg, sz_avg, cvg_avg_1, sz_avg_1)``: the two values returned
        by ``validate`` for the transformed model, followed by the two
        values returned for the baseline model.
    """
    logits_cal, logits_val, logits_train_conf, logits_train = split4(
        logits, n_data_conf, n_data_val, n_data_train_conf, n_data_train
    )

    def make_loader(dataset):
        # Sequential, pinned-memory loader shared by calibration and validation.
        return torch.utils.data.DataLoader(dataset, batch_size=bsz, shuffle=False, pin_memory=True)

    loader_cal = make_loader(logits_cal)
    loader_val = make_loader(logits_val)

    # Keyword arguments common to both conformal models.
    conformal_kwargs = dict(alpha=alpha, randomized=randomized, allow_zero_sets=True)

    # Baseline predictor without transform coefficients.
    baseline = ConformalModelLogits(model, loader_cal, kreg=kreg, lamda=lamda, **conformal_kwargs)
    cvg_avg_1, sz_avg_1 = validate(loader_val, baseline, print_bool=False)

    # Reuse whatever kreg/lamda the baseline model ended up holding.
    tuned_kreg, tuned_lamda = baseline.kreg, baseline.lamda
    beta, epsilon = 1.0, 1e-6
    optimal_coefficients = search_optimal_function(
        logits_train_conf, logits_train, beta, gamma, T, epsilon, m, alpha, tuned_kreg, tuned_lamda, bsz
    )

    transformed = ConformalModelLogits(
        model,
        loader_cal,
        transform_coefficients=optimal_coefficients,
        kreg=tuned_kreg,
        lamda=tuned_lamda,
        **conformal_kwargs,
    )
    cvg_avg, sz_avg = validate(loader_val, transformed, print_bool=False)

    return cvg_avg, sz_avg, cvg_avg_1, sz_avg_1

def experiment(modelname, datasetname, datasetpath, model, logits,
               num_trials, alpha, randomized,
               n_data_conf, n_data_val, n_data_train_conf, n_data_train,
               bsz, predictor, gamma, T, m,
               output_dir="results"):
    """Repeat ``trial`` ``num_trials`` times and append a summary row to a CSV.

    For predictors ``'APS'`` and ``'APS+'`` both ``kreg`` and ``lamda`` are
    fixed to 0; for any other predictor they are passed to ``trial`` as
    ``None``. Mean, median and standard error (``np.std / sqrt(n)``) of
    coverage and size are computed for the transformed and the original
    predictor.

    Args:
        modelname: Stored in the ``"model"`` column of the result row.
        datasetname, datasetpath: Accepted but not used inside this function.
        model, logits: Forwarded to ``trial``.
        num_trials: Number of repetitions.
        alpha, randomized: Forwarded to ``trial``; ``alpha`` is also recorded.
        n_data_conf, n_data_val, n_data_train_conf, n_data_train:
            Split sizes forwarded to ``trial``.
        bsz: Batch size forwarded to ``trial``.
        predictor: Predictor name; selects ``kreg``/``lamda`` and names the CSV.
        gamma, T, m: Forwarded to ``trial`` and recorded in the result row.
        output_dir: Directory for the CSV file; created if missing.

    Returns:
        The dictionary that was written as one CSV row.
    """
    os.makedirs(output_dir, exist_ok=True)

    if predictor in ['APS', 'APS+']:
        kreg, lamda = 0, 0
    else:
        kreg, lamda = None, None

    # Each outcome is (cvg_avg, sz_avg, cvg_avg_1, sz_avg_1) from one trial.
    outcomes = [
        trial(
            model, logits, alpha, kreg, lamda,
            randomized, n_data_conf, n_data_val,
            n_data_train_conf, n_data_train, bsz, gamma, T, m
        )
        for _ in tqdm(range(num_trials))
    ]
    coverages = [outcome[0] for outcome in outcomes]
    sizes = [outcome[1] for outcome in outcomes]
    coverages_1 = [outcome[2] for outcome in outcomes]
    sizes_1 = [outcome[3] for outcome in outcomes]

    def mean_of(values):
        return float(np.mean(values))

    def median_of(values):
        return float(np.median(values))

    def stderr_of(values):
        return float(np.std(values) / np.sqrt(len(values)))

    # Key order defines the CSV column order.
    results = {
        "model": modelname,
        "predictor": predictor,
        "alpha": alpha,
        "gamma": gamma,
        "T": T,
        "m": m,
        "coverage_mean": mean_of(coverages),
        "coverage_median": median_of(coverages),
        "size_mean": mean_of(sizes),
        "size_median": median_of(sizes),
        "coverage_se": stderr_of(coverages),
        "size_se": stderr_of(sizes),
        "original_coverage_mean": mean_of(coverages_1),
        "original_coverage_median": median_of(coverages_1),
        "original_size_mean": mean_of(sizes_1),
        "original_size_median": median_of(sizes_1),
        "original_coverage_se": stderr_of(coverages_1),
        "original_size_se": stderr_of(sizes_1),
    }

    csv_path = os.path.join(output_dir, f"{predictor}_a{alpha}_g{gamma}_m{m}.csv")
    # Only a freshly created file gets a header line.
    is_new_file = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as handle:
        row_writer = csv.DictWriter(handle, fieldnames=list(results))
        if is_new_file:
            row_writer.writeheader()
        row_writer.writerow(results)

    print(f"Saved results to {csv_path}")
    return results

if __name__ == "__main__":
    # Seed every random number generator used by the script.
    seed = 0
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Sweep definition: models, miscoverage levels, predictors, learning
    # rates and number of transform coefficients.
    modelnames = ['ResNeXt101', 'ResNet152', 'ResNet101', 'DenseNet161', 'VGG16', 'ShuffleNet']
    alphas = [0.05, 0.10]
    predictors = ['RAPS+', 'APS+']
    gamma_list = [1e-2, 1e-3, 1e-4]
    m_list = [3, 5, 7, 9]
    T = 1

    # Dataset configuration and split sizes.
    datasetname = 'Imagenet'
    datasetpath = 'imagenet_val'
    num_trials = 10
    randomized = True
    n_data_conf = 10000
    n_data_val = 20000
    n_data_train_conf = 10000
    n_data_train = 10000

    # Alternative setup for imagenetv2 (disabled):
    #   datasetname = 'Imagenetv2'
    #   datasetpath = 'imagenetv2-matched-frequency'
    #   num_trials = 100
    #   randomized = True
    #   n_data_conf = 2000
    #   n_data_val = 4000
    #   n_data_train_conf = 2000
    #   n_data_train = 2000

    bsz = 128
    cudnn.benchmark = True

    for modelname in modelnames:
        # Model and logits are loaded once per architecture and reused
        # across the whole hyper-parameter grid.
        model = get_model(modelname)
        logits = get_logits_dataset(modelname, datasetname, datasetpath)

        # Same nesting order as alpha > predictor > gamma > m.
        grid = (
            (a, p, g, k)
            for a in alphas
            for p in predictors
            for g in gamma_list
            for k in m_list
        )
        for alpha, predictor, gamma, m in grid:
            experiment(
                modelname, datasetname, datasetpath, model, logits,
                num_trials, alpha, randomized,
                n_data_conf, n_data_val, n_data_train_conf, n_data_train,
                bsz, predictor, gamma, T, m
            )
