import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,0,1"
import torch
import torch.optim as optim
import numpy as np
import copy
import argparse
from torch.utils.data import DataLoader
from ast import literal_eval
from helper_functions import *



def _mean_std(records, metric):
    """Return the mean and population (ddof=0) standard deviation of ``metric`` over ``records``."""
    values = np.array([r[metric] for r in records])
    return values.mean(), values.std()


def _print_original_summary(original_accuracy_list):
    """Print mean ± std of the original (pre-unlearning) model accuracies.

    Each row of ``original_accuracy_list`` is
    ``[acc_S, acc_R, acc_Dr, acc_Df, acc_Vf, acc_Vr]`` for one (scenario, seed);
    statistics are pooled over every row, i.e. over all scenarios and seeds.
    """
    descriptions = [
        "D_train subset S (forgotten classes)",
        "D_train subset R (retained classes)",
        "retained set (D_r)",
        "forgotten set (D_f)",
        "validation set V_f",
        "validation set V_r",
    ]
    last = len(descriptions) - 1
    for i, what in enumerate(descriptions):
        values = [row[i] for row in original_accuracy_list]
        prefix = "\n" if i == 0 else ""
        suffix = "\n" if i == last else ""
        print(f"{prefix}Original model average accuracy on {what}: "
              f"{np.mean(values)*100:.2f}% ± {np.std(values)*100:.2f}%{suffix}")


def _write_aggregate_results(all_stats, scenarios, dataset, model_type, mode):
    """Append mean ± std summaries of ``all_stats``, grouped over seeds, to the results file."""
    from collections import defaultdict
    from uncertainties import ufloat

    # Group on every hyper-parameter except the seed, so each bucket holds one record per seed.
    group_keys = ['forgotten', 'lambda_reg', 'gamma', 'alpha', 'beta', 'epochs',
                  'lr', 'delta', 'non-conformity', 'loss_function', 'c']

    buckets = defaultdict(list)
    for st in all_stats:
        # Lists are unhashable, so list-valued entries are keyed by their length.
        key = tuple(len(st[k]) if isinstance(st[k], list) else st[k] for k in group_keys)
        buckets[key].append(st)

    with open(f"../results/CQMU_{dataset}_{mode}.txt", 'a+') as f:
        f.write(f"\n\n\n\n-----------------###################################-------------------\n")
        f.write(f"Results for scenarios: {scenarios}\n")
        f.write(f"Dataset: {dataset}, Model Type: {model_type}, Mode: {mode}\n")

        for key, recs in buckets.items():
            mean_time, std_time = _mean_std(recs, 'total_unlearn_time')
            mean_mia_scores_un, std_mia_scores_un = _mean_std(recs, 'mia_score_difference')
            mean_Sa, std_Sa = _mean_std(recs, 'acc_S_after')
            mean_Ra, std_Ra = _mean_std(recs, 'acc_R_after')
            mean_Dr, std_Dr = _mean_std(recs, 'acc_Dr')
            mean_Df, std_Df = _mean_std(recs, 'acc_Df')
            mean_val_for, std_val_for = _mean_std(recs, 'acc_val_for')
            mean_val_ret, std_val_ret = _mean_std(recs, 'acc_val_ret')
            mean_efn_for, std_efn_for = _mean_std(recs, 'efn_for')
            mean_efn_ret, std_efn_ret = _mean_std(recs, 'cover_ret')
            mean_efn_s, std_efn_s = _mean_std(recs, 'efn_s')
            mean_efn_r, std_efn_r = _mean_std(recs, 'cover_r')
            mean_efn_val_for, std_efn_val_for = _mean_std(recs, 'efn_val_for')
            mean_efn_val_ret, std_efn_val_ret = _mean_std(recs, 'cover_val_ret')
            mean_cr_for, std_cr_for = _mean_std(recs, 'cr_for')
            mean_cr_ret, std_cr_ret = _mean_std(recs, 'cr_ret')
            mean_cr_s, std_cr_s = _mean_std(recs, 'cr_s')
            mean_cr_r, std_cr_r = _mean_std(recs, 'cr_r')
            mean_cr_val_for, std_cr_val_for = _mean_std(recs, 'cr_val_for')
            mean_cr_val_ret, std_cr_val_ret = _mean_std(recs, 'cr_val_ret')
            mean_avg_D, std_avg_D = _mean_std(recs, 'avg_D')
            mean_avg_T, std_avg_T = _mean_std(recs, 'avg_T')
            mean_avg_V, std_avg_V = _mean_std(recs, 'avg_V')

            # Harmonic mean over the six efn/coverage metrics; the stds are propagated via ``uncertainties``.
            # NOTE: this differs from the per-run H (harmonic mean of the retain/forget harmonic means).
            e1 = ufloat(mean_efn_for,     std_efn_for)
            e2 = ufloat(mean_efn_ret,     std_efn_ret)
            e3 = ufloat(mean_efn_s,       std_efn_s)
            e4 = ufloat(mean_efn_r,       std_efn_r)
            e5 = ufloat(mean_efn_val_for, std_efn_val_for)
            e6 = ufloat(mean_efn_val_ret, std_efn_val_ret)
            try:
                H = 6 / (1/e1 + 1/e2 + 1/e3 + 1/e4 + 1/e5 + 1/e6)
            except ZeroDivisionError:
                # Any zero metric drives the harmonic mean to 0.
                H = ufloat(0.0, 0.0)

            f.write("\n" + "#"*50 + "\n")
            params_str = ", ".join(f"{k}={v}" for k, v in zip(group_keys, key))
            f.write(f"{params_str}\n\n")
            f.write(f"Unlearn cycle time:        {mean_time:.2f} ± {std_time:.2f}s\n\n")
            f.write(f"MIA score difference:      {mean_mia_scores_un*100:.2f}% ± {std_mia_scores_un*100:.2f}%\n\n")
            f.write(f"A_Tr after:                {mean_Ra*100:.2f}% ± {std_Ra*100:.2f}%\n")
            f.write(f"A_Tf after:                {mean_Sa*100:.2f}% ± {std_Sa*100:.2f}%\n\n")
            f.write(f"A_Dr:                      {mean_Dr*100:.2f}% ± {std_Dr*100:.2f}%\n")
            f.write(f"A_Df:                      {mean_Df*100:.2f}% ± {std_Df*100:.2f}%\n\n")
            f.write(f"A_Vr:                      {mean_val_ret*100:.2f}% ± {std_val_ret*100:.2f}%\n")
            f.write(f"A_Vf:                      {mean_val_for*100:.2f}% ± {std_val_for*100:.2f}%\n\n")
            f.write(f"frakC_Dr:                  {mean_efn_ret:.2f} ± {std_efn_ret:.2f}\n")
            f.write(f"frakN_Df:                  {mean_efn_for:.2f} ± {std_efn_for:.2f}\n")
            f.write(f"frakC_Tr:                  {mean_efn_r:.2f} ± {std_efn_r:.2f}\n")
            f.write(f"frakN_Tf:                  {mean_efn_s:.2f} ± {std_efn_s:.2f}\n")
            f.write(f"frakC_Vr:                  {mean_efn_val_ret:.2f} ± {std_efn_val_ret:.2f}\n")
            f.write(f"frakN_Vf:                  {mean_efn_val_for:.2f} ± {std_efn_val_for:.2f}\n")
            f.write(f"H:                         {H.nominal_value:.2f} ± {H.std_dev:.2f}\n\n\n")

            f.write(f"CR_Dr:                     {mean_cr_ret:.2f} ± {std_cr_ret:.2f}\n")
            f.write(f"CR_Df:                     {mean_cr_for:.2f} ± {std_cr_for:.2f}\n")
            f.write(f"CR_Tr:                     {mean_cr_r:.2f} ± {std_cr_r:.2f}\n")
            f.write(f"CR_Tf:                     {mean_cr_s:.2f} ± {std_cr_s:.2f}\n")
            f.write(f"CR_Vr:                     {mean_cr_val_ret:.2f} ± {std_cr_val_ret:.2f}\n")
            f.write(f"CR_Vf:                     {mean_cr_val_for:.2f} ± {std_cr_val_for:.2f}\n")

            f.write(f"Avg set_size Df:     {mean_avg_D:.4f} ± {std_avg_D:.4f}\n")
            f.write(f"Avg set_size Tf:     {mean_avg_T:.4f} ± {std_avg_T:.4f}\n")
            f.write(f"Avg set_size Vf:     {mean_avg_V:.4f} ± {std_avg_V:.4f}\n")

        f.write(f"\n\n-----------------###################################-------------------\n\n\n\n\n\n")


def main(scenarios, seeds, epoch, lr, lam, gamma, delta, cp_calib, dataset, model_type, batch_size, mode, text, vary=None):
    """Run CQMU unlearning experiments and log the results.

    For every scenario ``[c_s, alphas, forgot_set]`` the data is split with
    ``OLD_load_data``. For every seed, a base model is loaded from ``../models``
    or trained with ``build_base_model`` if the checkpoint is missing. For every
    alpha, the base model is unlearned with ``cqmu`` and evaluated: accuracies,
    a membership-inference attack, and CCUCR/CR metrics for every ``c`` in ``c_s``.

    Each run's record is appended to ``../results/CQMU_{dataset}_{mode}.txt`` as
    soon as it is produced. A mean ± std summary grouped over seeds is appended
    at the end.

    Args:
        scenarios: list of ``[c_s, alphas, forgot_set]`` entries.
        seeds: seeds for the unlearning runs. The data split always uses seed 0.
        epoch, lr, lam, gamma, delta: hyper-parameters forwarded to ``cqmu``.
            ``lr`` is also the SGD learning rate.
        cp_calib: ``"calib"`` uses the test-calibration subset for the conformal
            quantile and the unlearn-calibration subset for unlearning. Any
            other value swaps the two.
        dataset: one of ``cifar100``, ``imagenet``, ``news``, ``20_newsgroups``.
        model_type: architecture name forwarded to ``create_model`` /
            ``build_base_model``.
        batch_size: batch size forwarded to ``OLD_load_data`` and the test
            DataLoader.
        mode: forgetting mode forwarded to ``OLD_load_data``.
        text: ``1`` for text datasets (sets ``trans=True`` for the helpers),
            otherwise ``0``.
        vary: forwarded to ``OLD_load_data``. ``True`` means the forgotten
            classes/clusters vary (random selection). ``None`` (the default)
            falls back to the module-level ``vary`` set by the CLI entry point,
            or ``True`` if no such global exists.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    import random

    if vary is None:
        # The CLI entry point sets a module-level ``vary`` instead of passing it in.
        vary = globals().get("vary", True)

    print(f"\nRunning CQMU with the following parameters:")
    print(f"Scenarios: {scenarios}")
    print(f"Seeds: {seeds}")
    print(f"Epochs: {epoch}, Learning Rate: {lr}, Lambda: {lam}, Gamma: {gamma}")
    print(f"Dataset: {dataset}, Model Type: {model_type}, Batch Size: {batch_size}, Mode: {mode}")
    trans = text == 1

    num_classes_by_dataset = {"cifar100": 100, "imagenet": 100, "news": 4, "20_newsgroups": 20}
    if dataset not in num_classes_by_dataset:
        # Fail fast instead of hitting an unbound ``num_classes`` deep inside model loading.
        raise ValueError(f"Unsupported dataset {dataset!r}; expected one of {sorted(num_classes_by_dataset)}")
    num_classes = num_classes_by_dataset[dataset]

    def load_pretrained_model(model_type, file_name, device):
        """Create a ``model_type`` network, load ``../models/{file_name}.pth`` and put it in eval mode."""
        model = create_model(model_type, device, num_classes=num_classes)
        # map_location allows GPU-saved checkpoints to be loaded on a CPU-only machine.
        model.load_state_dict(torch.load(f"../models/{file_name}.pth", map_location=device))
        model = model.to(device)
        model.eval()
        return model

    # -----------------------------------------------------------------------------
    # Device Setup
    # -----------------------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPUs!")

    results_path = f"../results/CQMU_{dataset}_{mode}.txt"
    with open(results_path, 'a+') as f:
        f.write(f"\n\n\n-----------------###################################-------------------\n")
        f.write(f"Results for scenarios: {scenarios}, over seeds: {seeds}, batch_size: {batch_size}\n")
        f.write(f"Dataset: {dataset}, Model Type: {model_type}, Devices: {num_gpus},\nMode: {mode}, CP Calibration: {cp_calib}, Fixed = {not vary}\n")

    # One record per (scenario, seed, alpha, c) run.
    all_stats = []
    # Original-model accuracies, one row per (scenario, seed).
    original_accuracy_list = []

    # The data split is seeded independently of the unlearning seeds.
    seed_data = 0
    rng_data = set_seed(seed_data)

    # Fixed experiment settings (not searched over).
    nonconf_func = "one_minus"
    optimizing = "CQMU"
    norm = "l2_sq"

    for scenario in scenarios:
        c_s, alphas, forgot_set = scenario
        print(f"Working on scenario: {scenario}")
        loaders = OLD_load_data(dataset, forgot_set, rng_data, mode, vary, batch_size, 2)

        train_loader               = loaders["train_loader"]
        calibration_loader         = loaders["cal_test_loader"]
        forget_loader              = loaders["Df_loader"]
        retained_loader            = loaders["Dr_loader"]
        calibration_val_loader     = loaders["cal_unlearn_loader"]
        forget_val_loader          = loaders["Vf_loader"]
        retained_val_loader        = loaders["Vr_loader"]
        S_loader                   = loaders["Tf_loader"]
        R_loader                   = loaders["Tr_loader"]
        calibration_forget_loader  = loaders["cal_test_forget_loader"]
        D_calib_val                = loaders["cal_unlearn_subset"]
        D_calib                    = loaders["cal_test_subset"]
        test_dataset               = loaders["test_subset"]

        length_train = 0 if train_loader is None else len(train_loader.dataset)

        print(f"Loaded dataset with {length_train} training samples, "
              f"{len(calibration_loader.dataset)} calibration samples, "
              f"{len(forget_loader.dataset)} forget samples, "
              f"{len(retained_loader.dataset)} retained samples, "
              f"{len(calibration_val_loader.dataset)} calibration validation samples, "
              f"{len(forget_val_loader.dataset)} forget validation samples, "
              f"{len(retained_val_loader.dataset)} retained validation samples, "
              f"{len(S_loader.dataset)} S samples, "
              f"{len(R_loader.dataset)} R samples.")

        # Choose which calibration subset feeds the conformal quantile vs. the unlearning objective.
        if cp_calib == "calib":
            cp_set, unlearn_set = D_calib, D_calib_val
        else:
            cp_set, unlearn_set = D_calib_val, D_calib

        print(f"{length_train/1000}k training samples, ")
        length_forget = len(forget_loader.dataset)

        for seed in seeds:
            set_seed(seed)
            random.seed(seed)

            # The base model is keyed on seed_data (the split seed), so all unlearning seeds share it.
            model_filename = f"{dataset}_{model_type}_base_{int(length_train/1000)}k_seed{seed_data}"

            # -----------------------------------------------------------------------------
            # Initial Training on D_train (all classes)
            # -----------------------------------------------------------------------------
            if os.path.exists(f"../models/{model_filename}.pth"):
                model = load_pretrained_model(model_type, model_filename, device)
            else:
                model, train_time = build_base_model(model_type, model_filename, train_loader, device, num_classes, trans=trans)
                print(f"Initial training time: {train_time:.2f} seconds")

            original_state_dict = copy.deepcopy(model.state_dict())
            # NOTE(review): assumes build_base_model saved its checkpoint to ../models/{model_filename}.pth.
            original_model = load_pretrained_model(model_type, model_filename, device)

            # Evaluate initial model performance on entire test set
            initial_test_accuracy = evaluate(original_model, DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2), device, trans=trans)
            print(f"Initial model test accuracy: {initial_test_accuracy*100:.2f}%")

            accuracy_S_o = evaluate(original_model, S_loader, device, trans=trans)
            accuracy_R_o = evaluate(original_model, R_loader, device, trans=trans)
            print(f"Accuracy before on D_train subset S (forgotten classes): {accuracy_S_o*100:.2f}%")
            print(f"Accuracy before on D_train subset R (retained classes): {accuracy_R_o*100:.2f}%")

            unlearned_accuracy_retained_o = evaluate(original_model, retained_loader, device, trans=trans)
            unlearned_accuracy_forget_o = evaluate(original_model, forget_loader, device, trans=trans)
            print(f"Original model accuracy on retained set (D_r): {unlearned_accuracy_retained_o*100:.2f}%")
            print(f"Original model accuracy on forgotten set (D_f): {unlearned_accuracy_forget_o*100:.2f}%")

            accuracy_val_for_o = evaluate(original_model, forget_val_loader, device, trans=trans)
            accuracy_val_ret_o = evaluate(original_model, retained_val_loader, device, trans=trans)
            print(f"Original model accuracy on validation set V_f: {accuracy_val_for_o*100:.2f}%")
            print(f"Original model accuracy on validation set V_r: {accuracy_val_ret_o*100:.2f}%")

            original_accuracy_list.append([accuracy_S_o, accuracy_R_o, unlearned_accuracy_retained_o, unlearned_accuracy_forget_o, accuracy_val_for_o, accuracy_val_ret_o])

            for alpha in alphas:
                print(f"→ CQMU   with gamma={gamma}, λ={lam}, α={alpha}, epoch={epoch}, lr={lr}, Working on {forgot_set} fogotten classes.")
                # -----------------------------------------------------------------------------
                # CP-based Machine Unlearning
                # -----------------------------------------------------------------------------
                beta = alpha
                # Every alpha restarts from the pretrained base weights.
                model = load_pretrained_model(model_type, model_filename, device)
                optimizer_un = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)

                unlearn_file_name = f"./checkpoints/{dataset}_{model_type}_CQMU_{mode}_forgot{forgot_set}_seed{seed}_alpha{alpha}_lam{lam}_gamma{gamma}_epoch{epoch}_lr{lr}_Df{int(length_forget)}.pth"

                model_unlearned, total_unlearn_time, q_time_percent = cqmu(
                    model, original_state_dict,
                    forget_loader, retained_loader, unlearn_set,
                    alpha, beta, gamma, lam, epoch, optimizer_un, delta,
                    norm_name=norm, score_func_name=nonconf_func, trans=trans, file_name=unlearn_file_name
                )

                print(f"Total unlearning time is: {total_unlearn_time}")

                # Conformal quantiles of the original and unlearned models on the CP calibration set.
                q_hat_original, _ = conformal_prediction_quantile_and_returnall(original_model, cp_set, alpha, nonconf_func, trans=trans)
                q_hat_unlearned, _ = conformal_prediction_quantile_and_returnall(model_unlearned, cp_set, alpha, nonconf_func, trans=trans)
                print(f"q_hat (original model): {q_hat_original:.4f}")
                print(f"q_hat (unlearned model): {q_hat_unlearned:.4f}")

                # Final evaluations
                unlearned_accuracy_retained = evaluate(model_unlearned, retained_loader, device, trans=trans)
                unlearned_accuracy_forget = evaluate(model_unlearned, forget_loader, device, trans=trans)
                print(f"Unlearned model accuracy on retained set (D_r): {unlearned_accuracy_retained*100:.2f}%")
                print(f"Unlearned model accuracy on forgotten set (D_f): {unlearned_accuracy_forget*100:.2f}%")

                accuracy_S = evaluate(model_unlearned, S_loader, device, trans=trans)
                accuracy_R = evaluate(model_unlearned, R_loader, device, trans=trans)
                print(f"Accuracy after on D_train subset S (forgotten classes): {accuracy_S*100:.2f}%")
                print(f"Accuracy after on D_train subset R (retained classes): {accuracy_R*100:.2f}%")

                accuracy_val_for = evaluate(model_unlearned, forget_val_loader, device, trans=trans)
                accuracy_val_ret = evaluate(model_unlearned, retained_val_loader, device, trans=trans)
                print(f"Accuracy on validation set V_f: {accuracy_val_for*100:.2f}%")
                print(f"Accuracy on validation set V_r: {accuracy_val_ret*100:.2f}%")

                # Membership inference: S (forgotten training data) as members vs. held-out forgotten data.
                _, mia_mean_acc_un = compute_membership_inference_attack(
                    model=model_unlearned,
                    member_loader=S_loader,
                    nonmember_loader=calibration_forget_loader,
                    device=device,
                    n_splits=10,
                    random_state=seed, trans=trans
                )

                for c in c_s:
                    print(f"→ CQMU   with φ={gamma}, λ={lam}, α={alpha}, epoch={epoch}, lr={lr}, c={c}, Working on {forgot_set} forgotten classes.")
                    # CCUCR on D_f (forgotten set) and D_r (retained set).
                    efn_ret, efn_for, avg_D = compute_ccucr(
                        model_unlearned, forget_loader, retained_loader,
                        q_hat_unlearned, nonconf_func, c=c, trans=trans
                    )
                    # CCUCR on T_f (S, forgotten) and T_r (R, retained).
                    efn_r, efn_s, avg_T = compute_ccucr(
                        model_unlearned, S_loader, R_loader,
                        q_hat_unlearned, nonconf_func, c=c, trans=trans
                    )
                    # CCUCR on the validation splits V_f and V_r.
                    efn_ret_val, efn_for_val, avg_V = compute_ccucr(
                        model_unlearned, forget_val_loader, retained_val_loader,
                        q_hat_unlearned, nonconf_func, c=c, trans=trans
                    )

                    cr_ret, cr_for = compute_cr(model_unlearned, forget_loader, retained_loader, q_hat_unlearned, nonconf_func, trans=trans)
                    cr_r, cr_s = compute_cr(model_unlearned, S_loader, R_loader, q_hat_unlearned, nonconf_func, trans=trans)
                    cr_ret_val, cr_for_val = compute_cr(model_unlearned, forget_val_loader, retained_val_loader, q_hat_unlearned, nonconf_func, trans=trans)

                    # Harmonic means: retained-side metrics, forgotten-side metrics, then of the two.
                    H1 = harmonic_mean([efn_ret, efn_r, efn_ret_val])
                    H2 = harmonic_mean([efn_for, efn_s, efn_for_val])
                    H = harmonic_mean([H1, H2])

                    all_stats.append({
                        'seed': seed,
                        'forgotten': forgot_set,
                        'lambda_reg': lam,
                        'alpha': alpha,
                        'beta': beta,
                        'epochs': epoch,
                        'lr': lr,
                        'gamma': gamma,
                        'delta': delta,
                        'non-conformity': nonconf_func,
                        'loss_function': optimizing,
                        'c': c,
                        'total_unlearn_time': total_unlearn_time,
                        'q_time_percent': q_time_percent,
                        'initial_test_acc': initial_test_accuracy,
                        'qhat_unlearn': q_hat_unlearned,
                        'acc_S_before': accuracy_S_o,
                        'acc_S_after': accuracy_S,
                        'acc_R_before': accuracy_R_o,
                        'acc_R_after': accuracy_R,
                        'acc_Dr': unlearned_accuracy_retained,
                        'acc_Df': unlearned_accuracy_forget,
                        'acc_val_for': accuracy_val_for,
                        'acc_val_ret': accuracy_val_ret,
                        'mia_score_difference': mia_mean_acc_un,
                        'efn_for': efn_for,
                        'cover_ret': efn_ret,
                        'efn_s': efn_s,
                        'cover_r': efn_r,
                        'efn_val_for': efn_for_val,
                        'cover_val_ret': efn_ret_val,
                        'H_retain': H1,
                        'H_forget': H2,
                        'H': H,
                        'cr_for': cr_for,
                        'cr_ret': cr_ret,
                        'cr_s': cr_s,
                        'cr_r': cr_r,
                        'cr_val_for': cr_for_val,
                        'cr_val_ret': cr_ret_val,
                        'avg_D': avg_D,
                        'avg_T': avg_T,
                        'avg_V': avg_V,
                    })
                    # Persist each record immediately so partial results survive a crash.
                    with open(results_path, 'a+') as f:
                        f.write(f"\n\n------------------------------------\nauxiliary result:\n{all_stats[-1]}")

    _print_original_summary(original_accuracy_list)
    _write_aggregate_results(all_stats, scenarios, dataset, model_type, mode)
    


def build_arg_parser():
    """Build the command-line parser for this experiment script.

    Returns:
        argparse.ArgumentParser: a parser whose fields map onto ``main``'s
        parameters, plus ``fixed_forgotten`` (0 = random forgotten
        classes/clusters, 1 = fixed).
    """
    p = argparse.ArgumentParser(
        description="Run experiment with various scenarios and seeds"
    )

    # scenarios: list of [ [ints...], [floats...], int ] entries (parsed safely with literal_eval)
    p.add_argument(
        "--scenarios",
        type=literal_eval,
        default=[[[100], [0.1], 20]],
        help=(
            "list of scenarios, e.g. "
            "[[[c_s],[alphas],number_of_classes_or_clusters_or_points], [[...]]]"
        )
    )

    # seeds: list of ints
    p.add_argument(
        "--seeds",
        type=literal_eval,
        default=[0],
        help="list of integer seeds, e.g. [0,1,2]"
    )

    # epoch: single integer
    p.add_argument(
        "--epoch",
        type=int,
        default=8,
        help="number of training epochs (integer)"
    )

    # lr: learning rate float
    p.add_argument(
        "--lr",
        type=float,
        default=0.01,
        help="learning rate (float)"
    )

    # lam: regularization parameter float
    p.add_argument(
        "--lam",
        type=float,
        default=0.001,
        help="regularization lambda (float)"
    )

    # gamma: steepness parameter float
    p.add_argument(
        "--gamma",
        type=float,
        default=10.0,
        help="surrogate steepness parameter (float)"
    )

    # delta: margin buffer float
    p.add_argument(
        "--delta",
        type=float,
        default=0.0001,
        help="margin buffer parameter (float)"
    )

    # dataset: simple string
    p.add_argument(
        "--dataset",
        type=str,
        choices=["cifar100", "imagenet", "20_newsgroups", "news"],
        default="cifar100",
        help="dataset name, e.g. cifar100 or imagenet"
    )

    # model type: simple string
    p.add_argument(
        "--model_type",
        type=str,
        choices=["resnet18", "resnet18_imagenet", "vit", "berta_distill"],
        default="resnet18",
        help="model architecture, e.g. resnet18 or vit"
    )

    # batch size: integer
    p.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="batch size for training, unlearning, and evaluation (integer)"
    )

    # mode: simple string
    p.add_argument(
        "--mode",
        type=str,
        choices=["label", "pca", "cluster", "instance-label", "instance-pca", "instance-cluster", "random", "instance-random"],
        default="label",
        help="operating mode, e.g. label or pca"
    )

    # calibration set name: simple string
    p.add_argument(
        "--cp_calib",
        type=str,
        choices=["calib", "calib_val"],
        default="calib",
        help="calibration set, e.g. calib or calib_val"
    )

    # text type dataset: 0/1 flag
    p.add_argument(
        "--text",
        type=int,
        choices=[0, 1],
        default=0,
        help="if 1, the dataset is text, if 0, it is image/else (integer)"
    )

    # Fixed vs. random forgotten class/cluster choice.
    # type=int, not bool: bool("0") is True, so the flag could never be switched off.
    p.add_argument(
        "--fixed-forgotten",
        type=int,
        choices=[0, 1],
        default=0,
        help="if 1, use fixed forgotten classes/clusters, if 0, random (integer)"
    )

    return p


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    # ``main`` reads this module-level flag: forgotten classes vary unless --fixed-forgotten 1.
    vary = args.fixed_forgotten == 0

    main(args.scenarios, args.seeds, args.epoch, args.lr, args.lam, args.gamma, args.delta, args.cp_calib, args.dataset, args.model_type, args.batch_size, args.mode, args.text)
    # -----------------------------------------------------------------------------
    # Example usage (quote list arguments so the shell does not glob or split them):
    # python CQMU_general.py --scenarios "[[[100],[0.1],20]]" --seeds "[0]" --epoch 20 --lr 0.01 --lam 0.1 --gamma 1.0 --mode label --dataset cifar100 --model_type resnet18 --text 0 --batch_size 256