# Standard library
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2,1,3,0"
import time
import argparse
from ast import literal_eval
from dataset import *
from unlearn import *
from metrics import UnLearningScore
from utils import *

# Numerical & plotting
import numpy as np


# PyTorch core
import torch
import torch.optim as optim
import torch.nn.functional as F

# PyTorch data utils
from torch.utils.data import DataLoader

# Project-specific utilities
from metrics import *
from helper_functions import *
import itertools




def _num_classes_for(dataset):
    """Return the number of output classes the model head is built with for ``dataset``.

    Args:
        dataset: dataset name as given on the command line. Any name containing
            ``"kin400"`` is treated as a Kinetics-400 variant.

    Returns:
        int: number of classes.

    Raises:
        ValueError: if the dataset name is not recognised (failing here is much
            clearer than an unbound ``num_classes`` deep inside model creation).
    """
    fixed_sizes = {
        "cifar100": 100,
        "imagenet": 100,
        "news": 4,
        "20_newsgroups": 20,
        "ucf101": 101,
    }
    if dataset in fixed_sizes:
        return fixed_sizes[dataset]
    if "kin400" in dataset:
        return 400
    raise ValueError(f"Unknown dataset {dataset!r}; cannot determine the number of classes.")


def _open_results(baseline, dataset, mode):
    """Open ``../../results/{baseline}_{dataset}_{mode}.txt`` for appending.

    The results directory is created if missing (``open(..., 'a+')`` alone
    fails on a fresh checkout). UTF-8 is forced because the reports contain
    non-ASCII characters such as ``±``.
    """
    os.makedirs("../../results", exist_ok=True)
    return open(f"../../results/{baseline}_{dataset}_{mode}.txt", "a+", encoding="utf-8")


def _write_summary(all_stats, baseline, dataset, mode):
    """Group per-seed records and append mean ± std of every metric to the results file.

    Records are grouped by (model_name, forgotten, alpha, c, non-conformity), so
    each group aggregates over seeds. ``H`` is the harmonic mean of the six
    coverage/efficiency metrics, with uncertainty propagated via ``ufloat``.

    Args:
        all_stats: list of dicts as produced by ``evaluate_all_metrics`` in ``main``.
        baseline, dataset, mode: identify the results file (see ``_open_results``).
    """
    # Kept local (as before) so importing this module does not require `uncertainties`.
    from collections import defaultdict
    from uncertainties import ufloat

    group_keys = ['model_name', 'forgotten', 'alpha', 'c', 'non-conformity']

    buckets = defaultdict(list)
    for st in all_stats:
        # List-valued fields are grouped by their length so the key stays hashable.
        key = tuple((len(st[k]) if isinstance(st[k], list) else st[k]) for k in group_keys)
        buckets[key].append(st)
    print(f"The buckets dict is: {buckets}")

    def mean_std(recs, field):
        """Return (mean, population std) of ``field`` across ``recs``."""
        values = np.array([r[field] for r in recs])
        return values.mean(), values.std()

    with _open_results(baseline, dataset, mode) as f:
        for key, recs in buckets.items():
            mean_time, std_time = mean_std(recs, 'total_unlearn_time')
            mean_mia, std_mia = mean_std(recs, 'mia_score_difference')
            mean_Sa, std_Sa = mean_std(recs, 'acc_S_after')
            mean_Ra, std_Ra = mean_std(recs, 'acc_R_after')
            mean_Dr, std_Dr = mean_std(recs, 'acc_Dr')
            mean_Df, std_Df = mean_std(recs, 'acc_Df')
            mean_val_for, std_val_for = mean_std(recs, 'acc_val_for')
            mean_val_ret, std_val_ret = mean_std(recs, 'acc_val_ret')
            mean_efn_for, std_efn_for = mean_std(recs, 'efn_for')
            mean_efn_ret, std_efn_ret = mean_std(recs, 'cover_ret')
            mean_efn_s, std_efn_s = mean_std(recs, 'efn_s')
            mean_efn_r, std_efn_r = mean_std(recs, 'cover_r')
            mean_efn_val_for, std_efn_val_for = mean_std(recs, 'efn_val_for')
            mean_efn_val_ret, std_efn_val_ret = mean_std(recs, 'cover_val_ret')

            # Harmonic mean over the six efn/coverage metrics, with error propagation.
            e1 = ufloat(mean_efn_for, std_efn_for)
            e2 = ufloat(mean_efn_ret, std_efn_ret)
            e3 = ufloat(mean_efn_s, std_efn_s)
            e4 = ufloat(mean_efn_r, std_efn_r)
            e5 = ufloat(mean_efn_val_for, std_efn_val_for)
            e6 = ufloat(mean_efn_val_ret, std_efn_val_ret)
            try:
                H = 6 / (1/e1 + 1/e2 + 1/e3 + 1/e4 + 1/e5 + 1/e6)
            except ZeroDivisionError:
                # Any metric at exactly zero drives the harmonic mean to zero.
                H = ufloat(0.0, 0.0)

            f.write("\n" + "#"*50 + "\n")
            params_str = ", ".join(f"{k}={v}" for k, v in zip(group_keys, key))
            f.write(f"{params_str}\n\n")
            f.write(f"Unlearn cycle time:        {mean_time:.2f} ± {std_time:.2f}s\n\n")
            f.write(f"MIA score difference:      {mean_mia*100:.2f}% ± {std_mia*100:.2f}%\n\n")
            f.write(f"A_Tr after:                {mean_Ra*100:.2f}% ± {std_Ra*100:.2f}%\n")
            f.write(f"A_Tf after:                {mean_Sa*100:.2f}% ± {std_Sa*100:.2f}%\n\n")
            f.write(f"A_Dr:                      {mean_Dr*100:.2f}% ± {std_Dr*100:.2f}%\n")
            f.write(f"A_Df:                      {mean_Df*100:.2f}% ± {std_Df*100:.2f}%\n\n")
            f.write(f"A_Vr:                      {mean_val_ret*100:.2f}% ± {std_val_ret*100:.2f}%\n")
            f.write(f"A_Vf:                      {mean_val_for*100:.2f}% ± {std_val_for*100:.2f}%\n\n")
            f.write(f"frakC_Dr:                  {mean_efn_ret:.2f} ± {std_efn_ret:.2f}\n")
            f.write(f"frakN_Df:                  {mean_efn_for:.2f} ± {std_efn_for:.2f}\n")
            f.write(f"frakC_Tr:                  {mean_efn_r:.2f} ± {std_efn_r:.2f}\n")
            f.write(f"frakN_Tf:                  {mean_efn_s:.2f} ± {std_efn_s:.2f}\n")
            f.write(f"frakC_Vr:                  {mean_efn_val_ret:.2f} ± {std_efn_val_ret:.2f}\n")
            f.write(f"frakN_Vf:                  {mean_efn_val_for:.2f} ± {std_efn_val_for:.2f}\n")
            f.write(f"H:                         {H.nominal_value:.2f} ± {H.std_dev:.2f}\n\n\n")

        f.write(f"\n\n-----------------###################################-------------------\n\n\n\n\n\n")


def main(scenarios, seeds, lr, cp_calib, dataset, model_type, batch_size, mode, baseline, text, parameters):
    """Run the selected unlearning baseline(s) over every scenario and seed and log metrics.

    Args:
        scenarios: iterable of ``(c_s, alphas, forgot_set)`` triples.
        seeds: unlearning seeds; each (scenario, seed) pair is one unlearning run.
        lr: unused. Each baseline trains with ``parameters[name]['lr']``; the
            argument is kept so existing callers keep working.
        cp_calib: ``"calib_val"`` calibrates conformal prediction on ``D_calib_val``;
            any other value (default ``"calib"``) uses ``D_calib``.
        dataset, model_type, batch_size, mode: forwarded to data loading / model creation.
        baseline: ``"BADT"``, ``"UNSIR"`` or ``"all"`` (runs both).
        text: ``1`` if the dataset is text; forwarded everywhere as ``trans=True``.
        parameters: dict mapping baseline name -> hyperparameter dict.

    Raises:
        ValueError: for an unknown dataset or baseline name.
    """
    import random  # explicit, rather than relying on the wildcard imports to provide it

    print(f"\nRunning {baseline} with the following parameters:")
    print(f"Scenarios: {scenarios}")
    print(f"Seeds: {seeds}")
    print(f"Dataset: {dataset}, Model Type: {model_type}, Batch Size: {batch_size}, Mode: {mode}")
    trans = text == 1
    num_classes = _num_classes_for(dataset)

    # -----------------------------------------------------------------------------
    def load_pretrained_model(model_type, file_name, device):
        """Build ``model_type`` and load weights from ``../../models/{file_name}.pth``.

        ``map_location`` lets a checkpoint saved on GPU load on a CPU-only host.
        The model is returned on ``device`` in eval mode.
        """
        model = create_model(model_type, device, num_classes=num_classes)
        model.load_state_dict(torch.load(f"../../models/{file_name}.pth", map_location=device))
        model = model.to(device)
        model.eval()
        return model

    ## Evaluate Functions ###############################################
    def evaluate_all_metrics(baseline,
                             original_model,
                             model_unlearned,
                             parameters,
                             retained_loader,
                             forget_loader,
                             calibration_loader,
                             S_loader,
                             R_loader,
                             forget_val_loader,
                             retained_val_loader,
                             calibration_val_loader,
                             test_dataset,
                             calibration_forget_loader,
                             seed,
                             device,
                             number_of_forget,
                             unlearn_time,
                             alphas,
                             all_stats,
                             c_s=(0,),
                             nonconf_func="one_minus",
                             *,
                             calib_set):
        """Evaluate one unlearned model and append a record per (alpha, c) to ``all_stats``.

        Conformal quantiles are computed on ``calib_set``. ``calibration_loader``
        and ``calibration_val_loader`` are accepted but not used. Each record is
        also appended to the results file.

        Returns:
            list: ``all_stats`` (mutated in place and returned).
        """
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
        initial_test_accuracy = evaluate(original_model, test_loader, device, trans=trans)
        print(f"Initial model test accuracy: {initial_test_accuracy*100:.2f}%")

        unlearned_accuracy_retained = evaluate(model_unlearned, retained_loader, device, trans=trans)
        unlearned_accuracy_forget = evaluate(model_unlearned, forget_loader, device, trans=trans)
        print(f"Unlearned model accuracy on retained set (D_r): {unlearned_accuracy_retained*100:.2f}%")
        print(f"Unlearned model accuracy on forgotten set (D_f): {unlearned_accuracy_forget*100:.2f}%")

        accuracy_S = evaluate(model_unlearned, S_loader, device, trans=trans)
        accuracy_R = evaluate(model_unlearned, R_loader, device, trans=trans)
        print(f"Accuracy after on D_train subset S (forgotten classes): {accuracy_S*100:.2f}%")
        print(f"Accuracy after on D_train subset R (retained classes): {accuracy_R*100:.2f}%")

        # Original-model accuracies on S/R are deliberately not recomputed; they
        # are reported as 0.
        accuracy_S_o = 0
        accuracy_R_o = 0

        accuracy_val_for = evaluate(model_unlearned, forget_val_loader, device, trans=trans)
        accuracy_val_ret = evaluate(model_unlearned, retained_val_loader, device, trans=trans)

        print(f"Accuracy before on D_train subset S (forgotten classes): {accuracy_S_o*100:.2f}%")
        print(f"Accuracy before on D_train subset R (retained classes): {accuracy_R_o*100:.2f}%")

        # Membership inference: S (forgotten training data) as members vs.
        # calibration_forget_loader as non-members.
        _, mia_mean_acc_un = compute_membership_inference_attack(
            model=model_unlearned,
            member_loader=S_loader,
            nonmember_loader=calibration_forget_loader,
            device=device,
            n_splits=10,
            random_state=seed, trans=trans
        )

        for alpha in alphas:
            q_hat_original, _ = conformal_prediction_quantile_and_returnall(original_model, calib_set, alpha, nonconf_func, trans=trans)
            q_hat_unlearned, _ = conformal_prediction_quantile_and_returnall(model_unlearned, calib_set, alpha, nonconf_func, trans=trans)
            print(f"q_hat (original model): {q_hat_original:.4f}")
            print(f"q_hat (unlearned model): {q_hat_unlearned:.4f}")
            for c in c_s:
                # CCUCR on (D_f, D_r), (S, R) and the validation forget/retain splits.
                # Every call passes (forget, retain) and unpacks (retain_metric, forget_metric).
                efn_ret, efn_for = compute_ccucr(
                    model_unlearned, forget_loader, retained_loader,
                    q_hat_unlearned, nonconf_func, c=c, trans=trans
                )
                efn_r, efn_s = compute_ccucr(
                    model_unlearned, S_loader, R_loader,
                    q_hat_unlearned, nonconf_func, c=c, trans=trans
                )
                efn_ret_val, efn_for_val = compute_ccucr(
                    model_unlearned, forget_val_loader, retained_val_loader,
                    q_hat_unlearned, nonconf_func, c=c, trans=trans
                )

                H1 = harmonic_mean([efn_ret, efn_r, efn_ret_val])
                H2 = harmonic_mean([efn_for, efn_s, efn_for_val])
                H = harmonic_mean([H1, H2])

                all_stats.append({
                    'model_name': baseline,
                    'seed': seed,
                    'params': parameters,
                    'forgotten': number_of_forget,
                    'alpha': alpha,
                    'non-conformity': nonconf_func,
                    'c': c,
                    'total_unlearn_time': unlearn_time,
                    'initial_test_acc': initial_test_accuracy,
                    'qhat_orig': q_hat_original,
                    'qhat_unlearn': q_hat_unlearned,
                    'acc_S_before': accuracy_S_o,
                    'acc_S_after': accuracy_S,
                    'acc_R_before': accuracy_R_o,
                    'acc_R_after': accuracy_R,
                    'acc_Dr': unlearned_accuracy_retained,
                    'acc_Df': unlearned_accuracy_forget,
                    'acc_val_for': accuracy_val_for,
                    'acc_val_ret': accuracy_val_ret,
                    'mia_score_difference': mia_mean_acc_un,
                    'efn_for': efn_for,
                    'cover_ret': efn_ret,
                    'efn_s': efn_s,
                    'cover_r': efn_r,
                    'efn_val_for': efn_for_val,
                    'cover_val_ret': efn_ret_val,
                    'H_retain': H1,
                    'H_forget': H2,
                    'H': H,
                })
                with _open_results(baseline, dataset, mode) as f:
                    f.write(f"\n\n------------------------------------\nauxiliary result:\n{all_stats[-1]}")
        return all_stats

    # -----------------------------------------------------------------------------
    # Device Setup
    # -----------------------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPUs!")

    baselines = ["BADT", "UNSIR"] if baseline == "all" else [baseline]

    for baseline_name in baselines:
        params = parameters[baseline_name]
        all_stats = []  # per-seed records for this baseline only

        # Header goes into the same file as this baseline's results.
        with _open_results(baseline_name, dataset, mode) as f:
            f.write(f"\n\n\n\n-----------------###################################-------------------\n")
            f.write(f"Results for scenarios: {scenarios}, over seeds: {seeds}\n")
            f.write(f"Baseline: {baseline_name}, Dataset: {dataset}, Devices: {num_gpus},\nModel Type: {model_type}, Mode: {mode}\n, Batch_size: {batch_size}, CP Calibration: {cp_calib}\n")

        # The data RNG is re-seeded with a fixed seed for every baseline;
        # seed_data is also part of the pretrained-model filename.
        seed_data = 0
        rng_data = set_seed(seed_data)
        for scenario in scenarios:
            c_s, alphas, forgot_set = scenario
            print(f"Working on scenario: {scenario}")
            loaders, forgotten_classes = load_dataset_and_transform(dataset, forgot_set, rng_data, mode, batch_size, 1)

            train_loader               = loaders["train_loader"]
            calibration_loader         = loaders["calibration_loader"]
            forget_loader              = loaders["forget_loader"]
            retained_loader            = loaders["retained_loader"]
            calibration_val_loader     = loaders["calibration_val_loader"]
            forget_val_loader          = loaders["forget_val_loader"]
            retained_val_loader        = loaders["retained_val_loader"]
            S_loader                   = loaders["S_loader"]
            R_loader                   = loaders["R_loader"]
            calibration_forget_loader  = loaders["calibration_forget_loader"]
            calibration_retain_loader  = loaders["calibration_retain_loader"]
            D_calib                    = loaders["D_calib"]
            D_calib_val                = loaders["D_calib_val"]
            D_f                        = loaders["D_f"]
            D_r                        = loaders["D_r"]
            test_dataset               = loaders["test_dataset"]

            # Honour --cp_calib (previously ignored: D_calib was always used).
            calib_set = D_calib_val if cp_calib == "calib_val" else D_calib

            print(f"Loaded dataset with {len(train_loader.dataset)} training samples, "
                f"{len(calibration_loader.dataset)} calibration samples, "
                f"{len(forget_loader.dataset)} forget samples, "
                f"{len(retained_loader.dataset)} retained samples, "
                f"{len(calibration_val_loader.dataset)} calibration validation samples, "
                f"{len(forget_val_loader.dataset)} forget validation samples, "
                f"{len(retained_val_loader.dataset)} retained validation samples, "
                f"{len(S_loader.dataset)} S samples, "
                f"{len(R_loader.dataset)} R samples.")
            print(f"{len(train_loader.dataset)/1000}k training samples, ")

            # Depends only on the dataset split, not on the unlearning seed.
            model_filename = f"{dataset}_{model_type}_base_{int(len(train_loader.dataset)/1000)}k_seed{seed_data}"

            for seed in seeds:
                set_seed(seed)
                random.seed(seed)

                # -------------------------------------------------------------
                # Initial model (trained on all of D_train, or loaded from disk)
                # -------------------------------------------------------------
                if os.path.exists(f"../../models/{model_filename}.pth"):
                    original_model = load_pretrained_model(model_type, model_filename, device)
                else:
                    # NOTE(review): assumes build_base_model saves the checkpoint to
                    # ../../models/{model_filename}.pth, which the loads below need.
                    original_model, train_time = build_base_model(model_type, model_filename, train_loader, device)
                    print(f"Initial training time: {train_time:.2f} seconds")
                # Separate copy of the pretrained weights (BADT's fully trained teacher).
                # Loaded here for every baseline to keep the RNG sequence unchanged.
                full_teacher = load_pretrained_model(model_type, model_filename, device)

                # -------------------------------------------------------------
                # Unlearning phase
                # -------------------------------------------------------------
                unlearn_lr = params['lr']
                if baseline_name == "BADT":
                    unlearning_teacher = create_model(model_type, device, num_classes=num_classes)
                    student_model = load_pretrained_model(model_type, model_filename, device)
                    optimizer = torch.optim.Adam(student_model.parameters(), lr=unlearn_lr)
                    # NOTE(review): BADT uses a fixed batch size of 256, not --batch_size.
                    unlearn_time = blindspot_unlearner(
                        model=student_model, unlearning_teacher=unlearning_teacher,
                        full_trained_teacher=full_teacher, retain_data=D_r, forget_data=D_f,
                        epochs=params['epochs'], optimizer=optimizer, lr=unlearn_lr,
                        batch_size=256, num_workers=2, device=device,
                        KL_temperature=params['kl_temp'], trans=trans)

                elif baseline_name == "UNSIR":
                    noise_dim = get_input_dims(D_r)
                    student_model = load_pretrained_model(model_type, model_filename, device)
                    noise = UNSIR_noise(batch_size, *noise_dim).to(device)

                    noise_time_start = time.time()
                    noise = UNSIR_noise_train(noise, student_model, forgotten_classes, params['noise_epochs'],
                                              batch_size, device=device, trans=trans)
                    noisy_loader = UNSIR_create_noisy_loader(noise, forgotten_classes,
                                                             retained_loader, batch_size, device=device, trans=trans)
                    noise_time = time.time() - noise_time_start

                    # Impair step (trains on noisy_loader).
                    _, impair_time = fit_one_unlearning_cycle(
                        params['impair_epochs'], student_model, noisy_loader, retained_loader,
                        device=device, lr=unlearn_lr, trans=trans)
                    # Repair step (trains on retained_loader).
                    _, repair_time = fit_one_unlearning_cycle(
                        params['repair_epochs'], student_model, retained_loader, calibration_retain_loader,
                        device=device, lr=unlearn_lr, trans=trans)

                    unlearn_time = impair_time + repair_time + noise_time
                else:
                    raise ValueError(f"Unsupported baseline {baseline_name!r}; expected 'BADT' or 'UNSIR'.")

                all_stats = evaluate_all_metrics(baseline_name,
                                                 original_model,
                                                 student_model,
                                                 params,
                                                 retained_loader,
                                                 forget_loader,
                                                 calibration_loader,
                                                 S_loader,
                                                 R_loader,
                                                 forget_val_loader,
                                                 retained_val_loader,
                                                 calibration_val_loader,
                                                 test_dataset,
                                                 calibration_forget_loader,
                                                 seed,
                                                 device,
                                                 forgot_set,
                                                 unlearn_time,
                                                 alphas,
                                                 all_stats,
                                                 c_s=c_s,
                                                 nonconf_func="one_minus",
                                                 calib_set=calib_set)

        # -----------------------------------------------------------------------------
        _write_summary(all_stats, baseline_name, dataset, mode)




if __name__ == "__main__":
    p = argparse.ArgumentParser(
        description="Run experiment with various scenarios and seeds"
    )

    # scenarios: list of [[c_s...], [alphas...], forget_spec] triples
    p.add_argument(
        "--scenarios",
        type=literal_eval,
        default=[[[100],[0.1],20]],
        help=(
            "list of scenarios, e.g. "
            "[[[c_s],[alphas],number_of_classes_or_clusters_or_points], [[...]]]"
        )
    )

    # seeds: list of ints
    p.add_argument(
        "--seeds",
        type=literal_eval,
        default=[0],
        help="list of integer seeds, e.g. [0,1,2]"
    )

    # learning rate: main() requires this argument (reading args.lr previously
    # raised AttributeError because the flag was never defined). Unlearning
    # itself trains with --badt-lr / --unsir-lr.
    p.add_argument(
        "--lr",
        type=float,
        default=1e-3,
        help="global learning rate (accepted for compatibility; unlearning uses --badt-lr / --unsir-lr)"
    )

    # dataset: simple string
    p.add_argument(
        "--dataset",
        type=str,
        choices=["cifar100", "imagenet", "20_newsgroups", "news"],
        default="cifar100",
        help="dataset name, e.g. cifar100 or imagenet"
    )

    # model type: simple string
    p.add_argument(
        "--model_type",
        type=str,
        choices=["resnet18", "resnet18_imagenet", "vit", "berta_distill"],
        default="resnet18",
        help="model architecture, e.g. resnet18 or berta_distill"
    )

    # batch size: integer
    p.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="batch size for training, unlearning, and evaluation (integer)"
    )

    # mode: simple string
    p.add_argument(
        "--mode",
        type=str,
        choices=["label", "pca", "cluster", "instance-label", "instance-pca", "instance-cluster", "random", "instance-random"],
        default="label",
        help="operating mode, e.g. label or cluster"
    )

    # calibration set used for conformal prediction: simple string
    p.add_argument(
        "--cp_calib",
        type=str,
        choices=["calib", "calib_val"],
        default="calib",
        help="calibration set, e.g. calib or calib_val"
    )

    # text-type dataset flag: 0/1 integer
    p.add_argument(
        "--text",
        type=int,
        choices=[0,1],
        default=0,
        help="if 1, the dataset is text, if 0, it is image/else (integer)"
    )

    # unlearning baseline: simple string ("all" runs every baseline)
    p.add_argument(
        "--baseline",
        type=str,
        choices=["BADT", "UNSIR", "all"],
        default="BADT",
        help="unlearning baseline, e.g. UNSIR or BADT"
    )

    # --- BADT Hyperparameters ---
    group_badt = p.add_argument_group('BADT Hyperparameters')
    group_badt.add_argument('--kl-temp', type=float, default=1, help="Temperature of the KL-divergence.")
    group_badt.add_argument('--badt-lr', type=float, default=1e-4, help="Learning rate for BADT.")
    group_badt.add_argument('--badt-epochs', type=int, default=1, help="Number of epochs for BADT.")

    # --- UNSIR Hyperparameters ---
    group_unsir = p.add_argument_group('UNSIR Hyperparameters')
    group_unsir.add_argument('--noise-epochs', type=int, default=250, help="Number of noise-optimisation epochs for UNSIR.")
    group_unsir.add_argument('--impair-epochs', type=int, default=1, help="Number of impair-step epochs for UNSIR.")
    group_unsir.add_argument('--repair-epochs', type=int, default=1, help="Number of repair-step epochs for UNSIR.")
    group_unsir.add_argument('--unsir-lr', type=float, default=1e-4, help="Learning rate for UNSIR.")

    # ==========================================================================
    # PARSE ARGUMENTS AND BUILD THE PER-BASELINE HYPERPARAMETER DICTIONARY
    # ==========================================================================
    args = p.parse_args()

    # argparse exposes "--kl-temp" etc. as args.kl_temp; regroup them per baseline.
    params_badt = {
        'kl_temp': args.kl_temp,
        'lr': args.badt_lr,
        'epochs': args.badt_epochs,
    }

    params_unsir = {
        'noise_epochs': args.noise_epochs,
        'impair_epochs': args.impair_epochs,
        'repair_epochs': args.repair_epochs,
        'lr': args.unsir_lr,
    }

    parameters = {
        'BADT': params_badt,
        'UNSIR': params_unsir,
    }

    main(args.scenarios, args.seeds, args.lr, args.cp_calib, args.dataset, args.model_type, args.batch_size, args.mode, args.baseline, args.text, parameters)

    # -----------------------------------------------------------------------------
    # Example usage (quote the list arguments so the shell passes them verbatim):
    # python <this_script>.py --scenarios "[[[100],[0.1],20]]" --seeds "[0]" --lr 0.001 --dataset cifar100 --model_type resnet18 --batch_size 256 --mode label --cp_calib calib --baseline BADT