# Standard library
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,1,0"
import time
import argparse
from ast import literal_eval

# Numerical & plotting
import numpy as np

# PyTorch core
import torch
import torch.nn as nn
import torch.optim as optim

# PyTorch data utils
from torch.utils.data import DataLoader

# Project-specific utilities
from Machine_Unlearning.Metrics.metrics import *
from helper_functions import *




def main(scenarios, seeds, lr, cp_calib, dataset, model_type, batch_size, mode, baseline, text, parameters):
    print(f"Scenarios: {scenarios}, Baseline: {baseline}")
    print(f"Seeds: {seeds}")
    print(f"Dataset: {dataset}, Model Type: {model_type}, Batch Size: {batch_size}, Mode: {mode}")
    if text == 1:
        trans = True
    else:
        trans = False

    if dataset == "cifar100" or dataset == "imagenet":
        num_classes = 100
    elif dataset == "news":
        num_classes = 4
    elif dataset == "20_newsgroups":
        num_classes = 20
    elif dataset == "ucf101":
        num_classes = 101
    elif "kin400" in dataset:
        num_classes = 400
    # -----------------------------------------------------------------------------
    def load_pretrained_model(model_type, file_name, device):
        """Build a model and restore its weights from ``../../models/{file_name}.pth``.

        Uses ``num_classes`` from the enclosing scope when constructing the model.

        Args:
            model_type: Architecture identifier forwarded to ``create_model``.
            file_name: Checkpoint base name (without the ``.pth`` extension).
            device: Target device; forwarded to ``create_model`` and used both to
                remap the checkpoint tensors and to place the final model.

        Returns:
            The model with the checkpoint's state dict loaded, moved to
            ``device`` and switched to eval mode.
        """
        model = create_model(model_type, device, num_classes=num_classes)
        # map_location remaps the stored tensors onto the requested device, so a
        # checkpoint saved on a GPU can still be restored on a CPU-only machine
        # or on a machine with a different GPU numbering.
        state_dict = torch.load(f"../../models/{file_name}.pth", map_location=device)
        model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()
        return model
    # -----------------------------------------------------------------------------
    ##Nabla Functions ###############################################
    def unlearning(net, retain, forget, validation, test, params, forget_epochs = 27, use_scheduler = True, trans=False):
        """Unlearn the forget set by raising its loss towards the validation loss.

        Every epoch first records evaluation metrics (a loss-based membership
        inference attack on forget-vs-test losses and the test accuracy), then
        takes one optimizer step per forget batch on::

            alpha * relu(val_loss - mean(forget_losses)) ** 2
                + (1 - alpha) * mean(retain_losses)

        ``val_loss`` (and the other moments) are recomputed on ``validation``
        every 5 epochs; ``alpha`` starts at ``split * 5 / 3`` and is decreased
        by ``start_alpha / forget_epochs`` after each epoch.

        Relies on names from the enclosing scope: ``device``, ``seed`` and the
        lists ``mia_metric_scores`` / ``accuracy_metric_scores``, which are
        appended to once per epoch plus once after the final epoch.

        Args:
            net: Model to unlearn; updated in place.
            retain: Loader over the retained data (one batch is drawn per
                forget batch).
            forget: Loader over the data to forget.
            validation: Loader passed to ``compute_moments``.
            test: Loader used for the MIA reference losses and accuracy.
            params: Dict of hyperparameters; reads ``'lr'`` (default 0.001) and
                ``'split'`` (default 0.5).
            forget_epochs: Number of unlearning epochs.
            use_scheduler: Whether to step the linear LR decay after each epoch.
            trans: When true, batch inputs are dicts of tensors that are passed
                to the network as keyword arguments.

        Returns:
            ``(net, total_unlearning_time)``; the time is wall-clock seconds
            spent in the epoch loop (this span includes the per-epoch metric
            evaluation).
        """
        lr = params.get('lr', 0.001)
        split = params.get('split', 0.5)
        start_alpha = split * 5/3  # starting alpha is derived from split
        criterion = nn.CrossEntropyLoss(reduction='none')
        optimizer = optim.AdamW(net.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=0.1, total_iters=forget_epochs
        )

        def to_device(batch_inputs):
            # With `trans`, inputs are a dict of tensors; otherwise a single tensor.
            if trans:
                return {k: v.to(device) for k, v in batch_inputs.items()}
            return batch_inputs.to(device)

        def forward(batch_inputs):
            return net(**batch_inputs) if trans else net(batch_inputs)

        def record_metrics(mia_name):
            # Loss-based MIA (forget vs. test losses) followed by test accuracy;
            # both results are appended to the enclosing-scope metric lists.
            ft_forget_losses = compute_losses(net, forget, trans=trans)
            ft_test_losses = compute_losses(net, test, trans=trans)

            gen = np.random.default_rng(seed)

            # Shuffle the larger set and truncate it to the size of the smaller one.
            if len(ft_test_losses) > len(ft_forget_losses):
                gen.shuffle(ft_test_losses)
                ft_test_losses = ft_test_losses[: len(ft_forget_losses)]
            else:
                gen.shuffle(ft_forget_losses)
                ft_forget_losses = ft_forget_losses[: len(ft_test_losses)]

            # make sure we have a balanced dataset for the MIA
            assert len(ft_test_losses) == len(ft_forget_losses)

            ft_samples_mia = np.concatenate((ft_test_losses, ft_forget_losses)).reshape((-1, 1))
            labels_mia = [0] * len(ft_test_losses) + [1] * len(ft_forget_losses)

            ft_mia_scores = simple_mia(ft_samples_mia, labels_mia)

            print(
                f"The {mia_name} has an accuracy of {ft_mia_scores.mean():.3f} on forgotten vs unseen images"
            )
            mia_metric_scores.append(ft_mia_scores.mean())

            acc = 100.0 * accuracy(net, test, trans=trans)
            print(f"Accuracy on test set: {acc:.1f} ")
            accuracy_metric_scores.append(acc)

        net.eval()
        alpha = start_alpha
        retain_iter = iter(retain)

        total_start_time = time.time()
        for i in range(forget_epochs):
            net.eval()

            if i % 5 == 0:
                print("Computing current moments on test set")
                val_loss, first_test_moment, second_test_moment, _ = compute_moments(net, validation, trans=trans)
                print("Computed moments: "+str(val_loss)+","+str(first_test_moment)+","+str(second_test_moment))

            record_metrics("MIA")

            # The gradient steps below are taken with the network kept in eval mode.
            net.eval()

            print("Forgetting epoch "+str(i))

            step = max(1, (len(retain)//len(forget)))

            if i % step == 0:
                print("Resetting retain iterator...")
                retain_iter = iter(retain)

            print("using alpha: "+str(alpha))
            for c, (inputs, targets) in enumerate(forget):
                net.zero_grad()

                inputs = to_device(inputs)
                targets = targets.to(device)
                out = forward(inputs)

                try:
                    r_inputs, r_targets = next(retain_iter)
                except StopIteration:
                    # Every retain batch has been consumed: start over on the
                    # loader that was passed in (not an enclosing-scope loader).
                    retain_iter = iter(retain)
                    r_inputs, r_targets = next(retain_iter)

                r_inputs = to_device(r_inputs)
                r_targets = r_targets.to(device)
                r_out = forward(r_inputs)

                forget_losses = criterion(out, targets)
                retain_losses = criterion(r_out, r_targets)

                forget_mean = torch.mean(forget_losses)
                retain_mean = torch.mean(retain_losses)
                delta_val_loss = (val_loss - forget_mean)

                if c % 40 == 0:
                    # Diagnostics only: these higher-moment deltas are logged
                    # but do not enter the optimized loss.
                    forget_var = torch.mean((forget_losses-forget_mean)**2)
                    forget_std = forget_var**0.5
                    forget_skew = torch.mean((forget_losses-forget_mean)**3) / (forget_std**3)
                    delta_first_moment = (first_test_moment - forget_var)
                    delta_second_moment = (second_test_moment - forget_skew)
                    print("delta_val_loss: "+str(delta_val_loss.item()))
                    print("delta_first_moment: "+str(delta_first_moment.item()))
                    print("delta_second_moment: "+str(delta_second_moment.item()))

                # The forget term is active only while the forget loss is below
                # the validation loss (relu clamps the difference at zero).
                loss = alpha*(torch.nn.functional.relu(delta_val_loss)**2) + (1-alpha)*retain_mean

                loss.backward()
                optimizer.step()

            # Linear decay of alpha by a fixed amount per epoch.
            alpha = alpha - (start_alpha / forget_epochs)
            if use_scheduler:
                scheduler.step()

        total_unlearning_time = time.time() - total_start_time

        net.eval()
        # Final evaluation uses the same forget loader that was unlearned.
        record_metrics("MIA_loss")

        net.eval()
        return net, total_unlearning_time

    ###
    def alpha_sched(start_a,a,max_ep,ep):
        """Return the next alpha: ``a`` lowered by one linear step of ``start_a / max_ep``.

        ``ep`` is accepted for interface compatibility but is not used by this
        linear schedule.
        """
        per_epoch_decrement = start_a / max_ep
        next_alpha = a - per_epoch_decrement
        return next_alpha

    ###




    ##SCRUB Functions ###############################################
    from SCRUB.thirdparty.repdistiller.helper.util import adjust_learning_rate as sgda_adjust_learning_rate
    from SCRUB.thirdparty.repdistiller.distiller_zoo import DistillKL #, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss
    # from SCRUB.thirdparty.repdistiller.distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss

    from SCRUB.thirdparty.repdistiller.helper.loops import train_distill, validate #, train_distill_hide, train_distill_linear, train_vanilla, train_negrad, train_bcu, train_bcu_distill
    # from SCRUB.thirdparty.repdistiller.helper.pretrain import init

    import copy

    #!mkdir checkpoints

    def scrub(teacher, student, seed, forget_validate_loader, params, trans=False):
        """Run SCRUB-style distillation unlearning on a copy of ``student``.

        For each of ``sgda_epochs`` epochs the student copy is validated on
        ``forget_validate_loader``, trained with ``train_distill`` in
        "maximize" mode on the forget set (only while ``epoch <= msteps``) and
        then in "minimize" mode on the retained set. A checkpoint of the
        student is written to ``checkpoints/`` after every epoch.

        Relies on ``forget_loader`` and ``retained_loader`` from the enclosing
        scope for the training data.

        Args:
            teacher: Reference model; a deep copy is used as distillation teacher.
            student: Model to unlearn; a deep copy is trained, the argument
                itself is not modified here.
            seed: Stored in the argument set and used in the checkpoint name.
            forget_validate_loader: Loader used to track forget-set error.
            params: Dict of hyperparameter overrides (``gamma``, ``beta``,
                ``msteps``, ``sstart``, ``kd_T``, ``sgda_epochs``, ``sgda_lr``,
                ``sgda_weight_decay``, ``sgda_momentum``, ``alpha``,
                ``smoothing``, ``clip``, ``lr_decay_epochs``, ``lr_decay_rate``);
                missing keys fall back to the defaults below.
            trans: Forwarded to ``validate`` / ``train_distill``.

        Returns:
            ``(model_s, model_s_final, total_unlearning_time)`` where
            ``model_s_final`` is a deep copy of the trained student ``model_s``
            and the time is wall-clock seconds spent in the epoch loop.
        """

        class AttributeDict(dict):
            # Dict whose keys are also readable/writable as attributes.
            __getattr__ = dict.__getitem__
            __setattr__ = dict.__setitem__
            __delattr__ = dict.__delitem__
        args = AttributeDict({
        # --- Unlearning Hyperparameters (pulled from grid or defaults) ---
        'gamma': params.get('gamma', 1.0),
        'beta': params.get('beta', 0.0),
        'msteps': params.get('msteps', 3),
        'sstart': params.get('sstart', 10),
        'kd_T': params.get('kd_T', 4.0),

        # --- Optimizer Hyperparameters (pulled from grid or defaults) ---
        'sgda_epochs': params.get('sgda_epochs', 6),
        'sgda_learning_rate': params.get('sgda_lr', 0.0005), # Note the key change
        'sgda_weight_decay': params.get('sgda_weight_decay', 5e-4),
        'sgda_momentum': params.get('sgda_momentum', 0.9),

        # --- Fixed or Non-Tuned Parameters ---
        'optim': 'sgd',
        'distill': 'kd',
        # NOTE(review): model/dataset are hard-coded; in this function they
        # only feed the checkpoint file name.
        'model': "resnet18",
        'dataset': "cifar100",
        'seed': seed,

        # These were in the original list but are less commonly tuned.
        # They are set to their default here but could be added to the grid.
        'alpha': params.get('alpha', 0.5),
        'smoothing': params.get('smoothing', 0.5),
        'clip': params.get('clip', 0.2),
        'lr_decay_epochs': params.get('lr_decay_epochs', [3, 5, 9]),
        'lr_decay_rate': params.get('lr_decay_rate', 0.1),
        })

        print(args)
        model_t = copy.deepcopy(teacher)
        model_s = copy.deepcopy(student)

        # For SGDA smoothing: exponential moving average of the student weights.
        beta = 0.1
        def avg_fn(averaged_model_parameter, model_parameter, num_averaged): return (
            1 - beta) * averaged_model_parameter + beta * model_parameter
        swa_model = torch.optim.swa_utils.AveragedModel(
            model_s, avg_fn=avg_fn)

        module_list = nn.ModuleList([])
        module_list.append(model_s)
        trainable_list = nn.ModuleList([])
        trainable_list.append(model_s)

        criterion_cls = nn.CrossEntropyLoss()
        criterion_div = DistillKL(args.kd_T)
        criterion_kd = DistillKL(args.kd_T)

        criterion_list = nn.ModuleList([])
        criterion_list.append(criterion_cls)    # classification loss
        criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation
        criterion_list.append(criterion_kd)     # other knowledge distillation loss

        acc_fs = []

        # optimizer (only the student is trainable)
        if args.optim == "sgd":
            optimizer = optim.SGD(trainable_list.parameters(),
                                lr=args.sgda_learning_rate,
                                momentum=args.sgda_momentum,
                                weight_decay=args.sgda_weight_decay)
        elif args.optim == "adam":
            optimizer = optim.Adam(trainable_list.parameters(),
                                lr=args.sgda_learning_rate,
                                weight_decay=args.sgda_weight_decay)
        elif args.optim == "rmsp":
            optimizer = optim.RMSprop(trainable_list.parameters(),
                                lr=args.sgda_learning_rate,
                                momentum=args.sgda_momentum,
                                weight_decay=args.sgda_weight_decay)

        # The teacher is appended after the optimizer is built, so its
        # parameters are not optimized.
        module_list.append(model_t)

        if torch.cuda.is_available():
            module_list.cuda()
            criterion_list.cuda()
            import torch.backends.cudnn as cudnn
            cudnn.benchmark = True
            swa_model.cuda()

        # torch.save does not create missing directories, so make sure the
        # checkpoint folder exists before the first epoch finishes.
        os.makedirs("checkpoints", exist_ok=True)
        scrub_name = "checkpoints/scrub_{}_{}_seed{}_step".format(args.model, args.dataset, args.seed)
        total_start_time = time.time()
        for epoch in range(1, args.sgda_epochs + 1):

            # Called for its effect on the optimizer; the returned rate is not needed.
            sgda_adjust_learning_rate(epoch, args, optimizer)

            acc_f, acc5_f, loss_f = validate(forget_validate_loader, model_s, criterion_cls, args, True, trans=trans)
            acc_fs.append(100-acc_f.item())

            maximize_loss = 0
            if epoch <= args.msteps:
                maximize_loss = train_distill(epoch, forget_loader, module_list, swa_model, criterion_list, optimizer, args, "maximize", trans=trans)
            train_acc, train_loss = train_distill(epoch, retained_loader, module_list, swa_model, criterion_list, optimizer, args, "minimize", trans=trans)
            if epoch >= args.sstart:
                swa_model.update_parameters(model_s)

            torch.save(model_s.state_dict(), scrub_name+str(epoch)+".pt")

            print ("maximize loss: {:.2f}\t minimize loss: {:.2f}\t train_acc: {}".format(maximize_loss, train_loss, train_acc))

        total_unlearning_time = time.time() - total_start_time

        acc_f, acc5_f, loss_f = validate(forget_validate_loader, model_s, criterion_cls, args, True, trans=trans)
        acc_fs.append(100-acc_f.item())

        try:
            # NOTE(review): `acc_fvs` is not defined in this function. Unless the
            # enclosing scope provides it, this lookup fails and the fallback
            # below (last index) is always taken.
            selected_idx, _ = min(enumerate(acc_fs), key=lambda x: abs(x[1]-acc_fvs[-1]))
        except Exception:
            # Best-effort selection: fall back to the last entry, but let
            # KeyboardInterrupt / SystemExit propagate.
            selected_idx = len(acc_fs) - 1
        print ("the selected index is {}".format(selected_idx))
        # Reloading the checkpoint at `selected_idx` is disabled; the student
        # as trained after the last epoch is what gets returned.
        model_s_final = copy.deepcopy(model_s)

        return model_s, model_s_final, total_unlearning_time



    ##SSD Functions ###############################################
    #!pip install wandb

    import SSD.src.ssd as ssd


    def ssd_tuning(
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        full_train_dl,
        device,
        params,
        trans=False,
        **kwargs,
    ):
        """Apply SSD parameter dampening to ``model`` and time it.

        Importances are computed on ``forget_train_dl`` and on ``full_train_dl``
        via ``ssd.ParameterPerturber``, then ``modify_weight`` is called with
        both. Only ``params['dampening_constant']`` and
        ``params['selection_weighting']`` are tuned; the remaining SSD settings
        are fixed. The other loader arguments, ``unlearning_teacher`` and
        ``kwargs`` are accepted but not used here.

        Returns:
            ``(model, total_unlearning_time)`` with the model in eval mode and
            the time in wall-clock seconds.
        """
        # Tuned values first, then the settings held constant for this tuning.
        ssd_config = {
            "dampening_constant": params['dampening_constant'],
            "selection_weighting": params['selection_weighting'],
        }
        ssd_config.update({
            "lower_bound": 1,
            "exponent": 1,
            "magnitude_diff": None,
            "min_layer": -1,
            "max_layer": -1,
            "forget_threshold": 1,
        })

        sgd = torch.optim.SGD(model.parameters(), lr=0.1)

        started_at = time.time()
        perturber = ssd.ParameterPerturber(model, sgd, device, ssd_config, trans=trans)

        model = model.eval()

        # Importances on the forget set, then on the full training set D
        # (the latter could be computed at any point before forgetting).
        forget_importances = perturber.calc_importance(forget_train_dl)
        full_importances = perturber.calc_importance(full_train_dl)

        # Dampen the selected parameters.
        perturber.modify_weight(full_importances, forget_importances)
        elapsed = time.time() - started_at
        return model, elapsed



    ##Amnesiac Functions ###############################################
    def epoch_end(model, epoch, result):
        """Print a one-line summary of an epoch.

        Reads the last entry of ``result["lrs"]``, ``result["train_loss"]`` and
        ``result["Loss"]`` (reported as the validation loss). ``model`` is unused.
        """
        template = "Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}"
        last_lr = result["lrs"][-1]
        train_loss = result["train_loss"]
        val_loss = result["Loss"]
        summary = template.format(epoch, last_lr, train_loss, val_loss)
        print(summary)

    def training_step(model, batch, device, trans=False):
        """Return the cross-entropy loss of ``model`` on one ``(inputs, labels)`` batch.

        With ``trans`` set, the inputs are a dict of tensors that is moved to
        ``device`` and passed to the model as keyword arguments; otherwise the
        inputs are a single tensor passed positionally.
        """
        features, labels = batch

        if trans:
            features = {name: tensor.to(device) for name, tensor in features.items()}
            logits = model(**features)  # Generate predictions
        else:
            logits = model(features.to(device))

        return nn.functional.cross_entropy(logits, labels.to(device))  # Calculate loss


    @torch.no_grad()
    def evaluate_amen(model, val_loader, device, trans=False):
        """Put ``model`` in eval mode and aggregate ``validation_step`` over ``val_loader``.

        Runs without gradient tracking and returns whatever
        ``validation_epoch_end`` produces for the collected per-batch results.
        """
        model.eval()
        per_batch_results = []
        for batch in val_loader:
            per_batch_results.append(validation_step(model, batch, device, trans=trans))
        return validation_epoch_end(model, per_batch_results)

    def validation_step(model, batch, device, trans=False):
        """Return ``{"Loss": loss}`` for one ``(inputs, labels)`` batch.

        The loss is the cross-entropy of the model output, detached from the
        autograd graph. With ``trans`` set, the inputs are a dict of tensors
        passed to the model as keyword arguments.
        """
        features, labels = batch

        if trans:
            features = {name: tensor.to(device) for name, tensor in features.items()}
            logits = model(**features)  # Generate predictions
        else:
            logits = model(features.to(device))

        batch_loss = nn.functional.cross_entropy(logits, labels.to(device))  # Calculate loss
        return {"Loss": batch_loss.detach()}

    def validation_epoch_end(model, outputs):
        """Average the per-batch ``"Loss"`` tensors into ``{"Loss": float}``.

        ``model`` is unused. Each batch contributes equally to the mean.
        """
        losses = []
        for step_result in outputs:
            losses.append(step_result["Loss"])
        mean_loss = torch.stack(losses).mean()  # Combine losses
        return {"Loss": mean_loss.item()}

    def get_lr(optimizer):
        """Return the ``'lr'`` of the optimizer's first parameter group.

        Returns ``None`` when the optimizer has no parameter groups.
        """
        return next((group['lr'] for group in optimizer.param_groups), None)

    def fit_one_unlearning_cycle(epochs, model, train_loader, val_loader, lr, device, trans=False):
        """Train ``model`` with Adam for ``epochs`` epochs, validating after each one.

        Each epoch runs ``training_step`` on every batch of ``train_loader``,
        then ``evaluate_amen`` on ``val_loader``; the evaluation result is
        extended with the mean training loss and the per-step learning rates,
        printed via ``epoch_end`` and stored.

        Returns:
            ``(history, total_unlearning_time)``: the list of per-epoch result
            dicts and the wall-clock seconds spent in the epoch loop
            (validation included).
        """
        adam = torch.optim.Adam(model.parameters(), lr=lr)
        history = []

        started_at = time.time()
        for epoch in range(epochs):
            model.train()
            step_losses, step_lrs = [], []

            for batch in train_loader:
                step_loss = training_step(model, batch, device, trans=trans)
                step_loss.backward()
                step_losses.append(step_loss.detach().cpu())

                adam.step()
                adam.zero_grad()

                step_lrs.append(get_lr(adam))

            epoch_result = evaluate_amen(model, val_loader, device, trans=trans)
            epoch_result.update(
                train_loss=torch.stack(step_losses).mean(),
                lrs=step_lrs,
            )

            epoch_end(model, epoch, epoch_result)
            history.append(epoch_result)

        elapsed = time.time() - started_at
        return history, elapsed

    def amnesiac(
        model,
        unlearning_teacher,
        retain_train_dl,
        retain_valid_dl,
        forget_train_dl,
        forget_valid_dl,
        valid_dl,
        num_classes,
        device,
        params,
        trans=False,
        **kwargs,
    ):
        """Amnesiac unlearning: fine-tune on forget samples with random wrong labels.

        Every sample of ``forget_train_dl.dataset`` is relabelled with a class
        drawn uniformly from ``range(num_classes)`` that differs from its
        original label, the retained samples are added unchanged, and ``model``
        is trained on the combined set with ``fit_one_unlearning_cycle``
        (validated on ``retain_valid_dl``).

        Args:
            model: Model to unlearn; trained in place.
            retain_train_dl: Loader whose ``.dataset`` holds the retained samples.
            retain_valid_dl: Loader used for per-epoch validation.
            forget_train_dl: Loader whose ``.dataset`` holds the samples to forget.
            num_classes: Number of classes to draw replacement labels from; with
                fewer than two classes no differing label exists and the
                relabelling loop cannot terminate.
            device: Device forwarded to the training loop.
            params: Dict providing ``'epochs'`` and ``'lr'``.
            trans: Forwarded to the training loop.
            unlearning_teacher, forget_valid_dl, valid_dl, kwargs: Unused.

        Returns:
            ``(model, total_unlearning_time)``.
        """
        # Imported locally so the function does not depend on a star-import
        # happening to expose the `random` module.
        import random

        unlearninglabels = list(range(num_classes))
        unlearning_trainset = []

        for x, clabel in forget_train_dl.dataset:
            # Redraw until the replacement label differs from the true one.
            rnd = random.choice(unlearninglabels)
            while rnd == clabel:
                rnd = random.choice(unlearninglabels)
            unlearning_trainset.append((x, rnd))

        # Retained samples keep their original labels.
        unlearning_trainset.extend((x, y) for x, y in retain_train_dl.dataset)

        unlearning_train_set_dl = DataLoader(
            unlearning_trainset, 128, pin_memory=True, shuffle=True
        )

        epochs = params['epochs']
        lr = params['lr']

        _, total_unlearning_time = fit_one_unlearning_cycle(
            epochs, model, unlearning_train_set_dl, retain_valid_dl, device=device, lr=lr, trans=trans
        )
        return model, total_unlearning_time
    
    def check_if_params_exist(params_to_check: dict, log_filepath: str) -> bool:
        """
        Checks if a given set of parameters already exists in the experiment log file,
        where the parameter dictionary is on the line immediately following a header.

        This function is designed to prevent re-running experiments that have already
        completed. It looks for lines containing "auxiliary result:" and then parses
        the *next* line as a dictionary to check for a parameter match.

        Entries that cannot be parsed, or that do not parse to a dictionary, are
        skipped so that one malformed entry does not hide a later match.

        Args:
            params_to_check: A dictionary of parameters for the current experiment.
            log_filepath: The path to the text file containing experiment results.

        Returns:
            True if the exact parameter combination is found in the log, False otherwise
            (including when the file is missing or cannot be read).
        """
        # If the log file doesn't exist, no params can exist in it.
        if not os.path.exists(log_filepath):
            print("Could not find log file!")
            return False

        try:
            with open(log_filepath, 'r') as f:
                for line in f:
                    # Only header lines are interesting; everything else is skipped.
                    if line.strip() != "auxiliary result:":
                        continue

                    # The dictionary is on the line right after the header.
                    dict_line = next(f, None)
                    if dict_line is None:
                        # The file ends with the header: the log is incomplete.
                        break

                    try:
                        # Safely parse the string into a Python literal.
                        logged_result = literal_eval(dict_line.strip())
                    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                        # Malformed or blank line after the header: move on to
                        # the next "auxiliary result:" header.
                        continue

                    if not isinstance(logged_result, dict):
                        # A valid literal that is not a result dictionary
                        # (e.g. a list) cannot hold parameters.
                        continue

                    # If the key 'params' exists and its value matches our
                    # current parameters, we have found a duplicate.
                    logged_params = logged_result.get('params')
                    if logged_params and logged_params == params_to_check:
                        return True

        except Exception as e:
            print(f"Warning: Could not read or parse log file '{log_filepath}'. Error: {e}")
            # To be safe, we assume the params don't exist if we can't read the file.
            return False

        # If we've checked the whole file and found no match, return False.
        return False

    ##Evaluate Functions ###############################################
    def evaluate_all_metrics(baseline,
                            model,
                            model_unlearned,
                            parameters,
                            retained_loader,
                            forget_loader,
                            calibration_loader,
                            S_loader,
                            R_loader,
                            forget_val_loader,
                            retained_val_loader,
                            calibration_val_loader,
                            test_dataset,
                            calibration_forget_loader,
                            seed,
                            device,
                            number_of_forget,
                            time,
                            alphas,
                            all_stats,
                            c_s=(0,),
                            nonconf_func = "one_minus"):
        """Evaluate the original and unlearned models and log one record per (alpha, c).

        Computes accuracies of both models on the various splits, a membership
        inference attack score for the unlearned model, conformal quantiles for
        each ``alpha`` and CCUCR metrics for each ``c``. Every ``(alpha, c)``
        combination appends one dict to ``all_stats`` and appends the same dict
        to ``../../results/{baseline}_{dataset}_{mode}.txt`` after an
        ``auxiliary result:`` header line.

        Relies on names from the enclosing scope: ``batch_size``, ``trans``,
        ``D_calib``, ``dataset`` and ``mode``.

        Args:
            baseline: Method name; stored as ``'model_name'`` and used in the
                results file name.
            model: Original (pre-unlearning) model.
            model_unlearned: Model after unlearning.
            parameters: Hyperparameters stored under ``'params'``.
            retained_loader, forget_loader: Loaders for D_r / D_f.
            S_loader, R_loader: Training-subset loaders for forgotten / retained
                classes.
            forget_val_loader, retained_val_loader: Validation counterparts.
            test_dataset: Dataset wrapped in a loader for the initial test accuracy.
            calibration_forget_loader: Non-member loader for the MIA.
            seed: Stored in the record and used as the MIA ``random_state``.
            device: Device forwarded to the evaluation helpers.
            number_of_forget: Stored under ``'forgotten'``.
            time: Unlearning time stored under ``'total_unlearn_time'`` (this
                parameter shadows the ``time`` module inside the function).
            alphas: Iterable of conformal miscoverage levels.
            all_stats: List that receives the result dicts (mutated in place).
            c_s: Iterable of ``c`` values passed to ``compute_ccucr``.
            nonconf_func: Non-conformity function name.
            calibration_loader, calibration_val_loader: Unused here.

        Returns:
            ``all_stats`` (the same list object that was passed in).
        """
        # Evaluate initial model performance on entire test set
        initial_test_accuracy = evaluate(model, DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4), device, trans=trans)
        print(f"Initial model test accuracy: {initial_test_accuracy*100:.2f}%")

        acc_vf = evaluate(model, forget_val_loader, device, trans=trans)
        acc_vr = evaluate(model, retained_val_loader, device, trans=trans)
        print(f"Accuracy before on Vf (forgotten classes): {acc_vf*100:.2f}%")
        print(f"Accuracy before on Vr (retained classes): {acc_vr*100:.2f}%")

        # Final evaluations
        unlearned_accuracy_retained = evaluate(model_unlearned, retained_loader, device, trans=trans)
        unlearned_accuracy_forget = evaluate(model_unlearned, forget_loader, device, trans=trans)
        print(f"Unlearned model accuracy on retained set (D_r): {unlearned_accuracy_retained*100:.2f}%")
        print(f"Unlearned model accuracy on forgotten set (D_f): {unlearned_accuracy_forget*100:.2f}%")

        accuracy_S = evaluate(model_unlearned, S_loader, device, trans=trans)
        accuracy_R = evaluate(model_unlearned, R_loader, device, trans=trans)

        print(f"Accuracy after on D_train subset S (forgotten classes): {accuracy_S*100:.2f}%")
        print(f"Accuracy after on D_train subset R (retained classes): {accuracy_R*100:.2f}%")

        accuracy_S_o = evaluate(model, S_loader, device, trans=trans)
        accuracy_R_o = evaluate(model, R_loader, device, trans=trans)

        accuracy_val_for = evaluate(model_unlearned, forget_val_loader, device, trans=trans)
        accuracy_val_ret = evaluate(model_unlearned, retained_val_loader, device, trans=trans)

        print(f"Accuracy before on D_train subset S (forgotten classes): {accuracy_S_o*100:.2f}%")
        print(f"Accuracy before on D_train subset R (retained classes): {accuracy_R_o*100:.2f}%")

        # Compute membership inference attack scores; only the mean accuracy is logged.
        _, mia_mean_acc_un = compute_membership_inference_attack(
                        model=model_unlearned,
                        member_loader=S_loader,
                        nonmember_loader=calibration_forget_loader,
                        device=device,
                        n_splits=10,
                        random_state=seed, trans=trans
                    )

        # The results file is the same for every record written by this call.
        results_path = f"../../results/{baseline}_{dataset}_{mode}.txt"

        for alpha in alphas:
            # Conformal quantiles are calibrated on the enclosing-scope D_calib.
            q_hat_original, _ = conformal_prediction_quantile_and_returnall(model, D_calib, alpha, nonconf_func, trans=trans)
            q_hat_unlearned, _ = conformal_prediction_quantile_and_returnall(model_unlearned, D_calib, alpha, nonconf_func, trans=trans)
            print(f"q_hat (original model): {q_hat_original:.4f}")
            print(f"q_hat (unlearned model): {q_hat_unlearned:.4f}")
            for c in c_s:
                # CCUCR on D_f (forgotten set) and D_r (retained set).
                efn_ret, efn_for = compute_ccucr(
                    model_unlearned, forget_loader, retained_loader,
                    q_hat_unlearned, nonconf_func, c=c, trans=trans
                )

                # CCUCR on the training subsets S (forgotten classes) and R (retained classes).
                efn_r, efn_s = compute_ccucr(
                    model_unlearned, S_loader, R_loader,
                    q_hat_unlearned, nonconf_func, c=c, trans=trans
                )

                # CCUCR on the validation forget / retain splits.
                efn_ret_val, efn_for_val = compute_ccucr(
                        model_unlearned, forget_val_loader, retained_val_loader,
                        q_hat_unlearned, nonconf_func, c=c, trans=trans
                    )

                # Harmonic means: retain-side, forget-side, then both combined.
                H1 = harmonic_mean([efn_ret, efn_r, efn_ret_val])
                H2 = harmonic_mean([efn_for, efn_s, efn_for_val])
                H = harmonic_mean([H1, H2])

                record = {
                    'model_name': baseline,
                    'seed': seed,
                    'params': parameters,
                    'forgotten': number_of_forget,
                    'alpha': alpha,
                    'non-conformity': nonconf_func,
                    'c': c,
                    'total_unlearn_time': time,
                    'initial_test_acc': initial_test_accuracy,
                    'qhat_orig': q_hat_original,
                    'qhat_unlearn': q_hat_unlearned,
                    'acc_S_before': accuracy_S_o,
                    'acc_S_after': accuracy_S,
                    'acc_R_before': accuracy_R_o,
                    'acc_R_after': accuracy_R,
                    'acc_Dr': unlearned_accuracy_retained,
                    'acc_Df': unlearned_accuracy_forget,
                    'acc_val_for': accuracy_val_for,
                    'acc_val_ret': accuracy_val_ret,
                    'mia_score_difference': mia_mean_acc_un,
                    'efn_for':     efn_for,
                    'cover_ret':     efn_ret,
                    'efn_s':       efn_s,
                    'cover_r':       efn_r,
                    'efn_val_for': efn_for_val,
                    'cover_val_ret': efn_ret_val,
                    'H_retain':       H1,
                    'H_forget':       H2,
                    'H': H,
                }
                all_stats.append(record)
                # Append immediately so partial results survive an interrupted run.
                with open(results_path, 'a+') as f:
                    f.write(f"\n\n------------------------------------\nauxiliary result:\n{record}")
        return all_stats
    

    # -----------------------------------------------------------------------------
    # Device Setup
    # -----------------------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPUs!")
    # ----------------------------------------------------------------------------
    with open(f"../../results/{baseline}_{dataset}_{mode}.txt", 'a+') as f:
        f.write(f"\n\n\n\n-----------------###################################-------------------\n")
        f.write(f"Results for scenarios: {scenarios}, over seeds: {seeds}\n")
        f.write(f"Baseline: {baseline}, Dataset: {dataset}, Devices: {num_gpus},\nModel Type: {model_type}, Mode: {mode},\nCP Calibration: {cp_calib}\n")
    
    if baseline == "all":
        baselines = ["nabla_tau", "SCRUB", "SSD", "AMN"]
    else:
        baselines = [baseline,]

    for method in baselines:
        params = parameters[method]
        # a place to stash your per‐seed results
        all_stats = []

        seed_data = 0
        rng_data = set_seed(seed_data)
        for scenario in scenarios:
            c_s, alphas, forgot_set = scenario
            print(f"Working on scenario: {scenario}")
            # call your function and get back a dict
            loaders = load_dataset_and_transform(dataset, forgot_set, rng_data, mode, batch_size, 1)

            # now unpack each entry into its own variable
            train_loader               = loaders["train_loader"]
            calibration_loader         = loaders["calibration_loader"]
            forget_loader              = loaders["forget_loader"]
            retained_loader            = loaders["retained_loader"]
            calibration_val_loader     = loaders["calibration_val_loader"]
            forget_val_loader          = loaders["forget_val_loader"]
            retained_val_loader        = loaders["retained_val_loader"]
            S_loader                   = loaders["S_loader"]
            R_loader                   = loaders["R_loader"]
            calibration_forget_loader  = loaders["calibration_forget_loader"]
            calibration_retain_loader  = loaders["calibration_retain_loader"]
            D_calib_val                = loaders["D_calib_val"]
            D_f_val                    = loaders["D_f_val"]
            D_r_val                    = loaders["D_r_val"]
            D_calib                    = loaders["D_calib"]
            D_f                        = loaders["D_f"]
            D_r                        = loaders["D_r"]
            test_dataset               = loaders["test_dataset"]
            R_small_loader             = loaders["R_small_loader"]
            print(f"Loaded dataset with {len(train_loader.dataset)} training samples, "
                    f"{len(calibration_loader.dataset)} calibration samples, "
                    f"{len(forget_loader.dataset)} forget samples, "
                    f"{len(retained_loader.dataset)} retained samples, "
                    f"{len(calibration_val_loader.dataset)} calibration validation samples, "
                    f"{len(forget_val_loader.dataset)} forget validation samples, "
                    f"{len(retained_val_loader.dataset)} retained validation samples, "
                    f"{len(S_loader.dataset)} S samples, "
                    f"{len(R_loader.dataset)} R samples.")
            # -----------------------------------------------------------------------------
            # Set up the model and training parameters
            # -----------------------------------------------------------------------------
            print(f"{len(train_loader.dataset)/1000}k training samples, ")

            for seed in seeds:
                rng = set_seed(seed)
                random.seed(seed)
                model_filename = f"{dataset}_{model_type}_base_{int(len(train_loader.dataset)/1000)}k_seed{seed_data}"
                
                # -----------------------------------------------------------------------------
                # Initial Training on D_train (all classes)
                # -----------------------------------------------------------------------------
                if os.path.exists(f"../../models/{model_filename}.pth"):
                    original_model = load_pretrained_model(model_type, model_filename, device)
                    model = load_pretrained_model(model_type, model_filename, device)
                else:
                    original_model, train_time = build_base_model(model_type, model_filename, train_loader, num_classes, device, trans=trans)
                    print(f"Initial training time: {train_time:.2f} seconds")
                    model = load_pretrained_model(model_type, model_filename, device)
                
                if method == "nabla_tau":
                    original_model = load_pretrained_model(model_type, model_filename, device)
                    results = {}
                    torch.cuda.empty_cache()        
                    
                    accuracy_metric_scores = []
                    mia_metric_scores = []
                    print(len(train_loader))
                    print(len(retained_loader))
                    print(len(forget_loader))
                    
                    forget_epochs = int((len(retained_loader) / (len(forget_loader)*2)) * 6)
                    print(forget_epochs)
                    model_unlearned, unlearn_time = unlearning(original_model, retained_loader, forget_loader, forget_val_loader, calibration_forget_loader, params, forget_epochs = forget_epochs, use_scheduler = True, trans=trans)

                elif method == "SCRUB":
                    original_model = load_pretrained_model(model_type, model_filename, device)
                    teacher = load_pretrained_model(model_type, model_filename, device)
                    student = load_pretrained_model(model_type, model_filename, device)
                    model_s, model_s_final, unlearn_time = scrub(teacher, student, seed, forget_val_loader, params, trans=trans)
                    model_unlearned = model_s_final

                elif method == "SSD":
                    original_model = load_pretrained_model(model_type, model_filename, device)
                    unlearning_teacher = create_model(model_type, device, num_classes=num_classes)
                    kwargs = {
                        "model": original_model,
                        "unlearning_teacher": unlearning_teacher,
                        "retain_train_dl": retained_loader,
                        "retain_valid_dl": retained_val_loader,
                        "forget_train_dl": forget_loader,
                        "forget_valid_dl": forget_val_loader,
                        "full_train_dl": train_loader,
                        "valid_dl": calibration_val_loader,
                        "num_classes": num_classes,
                        "dataset_name": dataset,
                        "device": device,
                        "model_name": model_type,
                        "params": params,
                        "num_classes": num_classes,
                        "trans": trans,
                    }
                    model_unlearned, unlearn_time = ssd_tuning(**kwargs)

                elif method == "AMN":
                    original_model = load_pretrained_model(model_type, model_filename, device)
                    unlearning_teacher = create_model(model_type, device, num_classes=num_classes)
                    kwargs = {
                        "model": original_model,
                        "unlearning_teacher": unlearning_teacher,
                        "retain_train_dl": retained_loader,
                        "retain_valid_dl": retained_val_loader,
                        "forget_train_dl": forget_loader,
                        "forget_valid_dl": forget_val_loader,
                        "full_train_dl": train_loader,
                        "valid_dl": calibration_val_loader,
                        "num_classes": num_classes,
                        "dataset_name": dataset,
                        "device": device,
                        "model_name": model_type,
                        "params": params,
                        "num_classes": num_classes,
                        "trans": trans,
                    }
                    model_unlearned, unlearn_time = amnesiac(**kwargs)

                all_stats = evaluate_all_metrics(baseline,
                                        model,
                                        model_unlearned,
                                        params,
                                        retained_loader,
                                        forget_loader,
                                        calibration_loader,
                                        S_loader,
                                        R_loader,
                                        forget_val_loader,
                                        retained_val_loader,
                                        calibration_val_loader,
                                        test_dataset,
                                        calibration_forget_loader,
                                        seed,
                                        device,
                                        forgot_set,
                                        unlearn_time,
                                        alphas,
                                        all_stats,
                                        c_s= c_s,
                                        nonconf_func = "one_minus")



    # #######################################
    # #######################################

        from collections import defaultdict
        from uncertainties import ufloat

        # pick the keys on which you want to group (everything except seed and the actual metrics)
        group_keys = ['model_name', 'forgotten','alpha','c', 'non-conformity']

        # bucket stats by those keys
        buckets = defaultdict(list)
        for st in all_stats:
            key = tuple((len(st[k]) if isinstance(st[k], list) else st[k]) for k in group_keys)
            buckets[key].append(st)
        print(f"The buckets dict is: {buckets}")
        # now compute averages and stds
        with open(f"../../results/{baseline}_{dataset}_{mode}.txt", 'a+') as f:
            for key, recs in buckets.items():
                # collect each metric across this bucket's records into a numpy array
                unlearn_time     = np.array([r['total_unlearn_time']     for r in recs])
                mia_scores_un    = np.array([r['mia_score_difference']   for r in recs])
                acc_S_aft_vec    = np.array([r['acc_S_after']          for r in recs])
                acc_R_aft_vec    = np.array([r['acc_R_after']          for r in recs])
                acc_Dr_vec       = np.array([r['acc_Dr']               for r in recs])
                acc_Df_vec       = np.array([r['acc_Df']               for r in recs])
                acc_val_for_vec  = np.array([r['acc_val_for']          for r in recs])
                acc_val_ret_vec  = np.array([r['acc_val_ret']          for r in recs])
                efn_for_arr      = np.array([r['efn_for']              for r in recs])
                efn_ret_arr      = np.array([r['cover_ret']            for r in recs])
                efn_s_arr        = np.array([r['efn_s']                for r in recs])
                efn_r_arr        = np.array([r['cover_r']              for r in recs])
                efn_val_for_arr  = np.array([r['efn_val_for']          for r in recs])
                efn_val_ret_arr  = np.array([r['cover_val_ret']        for r in recs])
                # cr_for_arr      = np.array([r['cr_for']               for r in recs])
                # cr_ret_arr      = np.array([r['cr_ret']               for r in recs])
                # cr_s_arr        = np.array([r['cr_s']                 for r in recs])
                # cr_r_arr        = np.array([r['cr_r']                 for r in recs])
                # cr_val_for_arr  = np.array([r['cr_val_for']           for r in recs])
                # cr_val_ret_arr  = np.array([r['cr_val_ret']           for r in recs])

                # compute means and stds
                mean_time    = unlearn_time.mean();        std_time    = unlearn_time.std()
                mean_mia_scores_un = mia_scores_un.mean(); std_mia_scores_un = mia_scores_un.std()
                mean_Sa, std_Sa = acc_S_aft_vec.mean(), acc_S_aft_vec.std()
                mean_Ra, std_Ra = acc_R_aft_vec.mean(), acc_R_aft_vec.std()
                mean_Dr, std_Dr = acc_Dr_vec.mean(),     acc_Dr_vec.std()
                mean_Df, std_Df = acc_Df_vec.mean(),     acc_Df_vec.std()
                mean_val_for, std_val_for = acc_val_for_vec.mean(), acc_val_for_vec.std()
                mean_val_ret, std_val_ret = acc_val_ret_vec.mean(), acc_val_ret_vec.std()
                mean_efn_for, std_efn_for = efn_for_arr.mean(),     efn_for_arr.std()
                mean_efn_ret, std_efn_ret = efn_ret_arr.mean(),     efn_ret_arr.std()
                mean_efn_s,   std_efn_s   = efn_s_arr.mean(),       efn_s_arr.std()
                mean_efn_r,   std_efn_r   = efn_r_arr.mean(),       efn_r_arr.std()
                mean_efn_val_for, std_efn_val_for = efn_val_for_arr.mean(), efn_val_for_arr.std()
                mean_efn_val_ret, std_efn_val_ret = efn_val_ret_arr.mean(), efn_val_ret_arr.std()
                # mean_cr_for, std_cr_for = cr_for_arr.mean(),     cr_for_arr.std()
                # mean_cr_ret, std_cr_ret = cr_ret_arr.mean(),     cr_ret_arr.std()
                # mean_cr_s,   std_cr_s   = cr_s_arr.mean(),       cr_s_arr.std()
                # mean_cr_r,   std_cr_r   = cr_r_arr.mean(),       cr_r_arr.std()
                # mean_cr_val_for, std_cr_val_for = cr_val_for_arr.mean(), cr_val_for_arr.std()
                # mean_cr_val_ret, std_cr_val_ret = cr_val_ret_arr.mean(), cr_val_ret_arr.std()

                # harmonic mean H over the six efn metrics
                e1 = ufloat(mean_efn_for,     std_efn_for)
                e2 = ufloat(mean_efn_ret,     std_efn_ret)
                e3 = ufloat(mean_efn_s,       std_efn_s)
                e4 = ufloat(mean_efn_r,       std_efn_r)
                e5 = ufloat(mean_efn_val_for, std_efn_val_for)
                e6 = ufloat(mean_efn_val_ret, std_efn_val_ret)
                try:
                    H = 6 / (1/e1 + 1/e2 + 1/e3 + 1/e4 + 1/e5 + 1/e6)
                except ZeroDivisionError:
                    H = ufloat(0.0, 0.0)

                # --- write to file as before ---
                f.write("\n" + "#"*50 + "\n")
                params_str = ", ".join(f"{k}={v}" for k, v in zip(group_keys, key))
                f.write(f"{params_str}\n\n")
                f.write(f"Unlearn cycle time:        {mean_time:.2f} ± {std_time:.2f}s\n\n")
                f.write(f"MIA score difference:      {mean_mia_scores_un*100:.2f}% ± {std_mia_scores_un*100:.2f}%\n\n")
                f.write(f"A_Tr after:                {mean_Ra*100:.2f}% ± {std_Ra*100:.2f}%\n")
                f.write(f"A_Tf after:                {mean_Sa*100:.2f}% ± {std_Sa*100:.2f}%\n\n")
                f.write(f"A_Dr:                      {mean_Dr*100:.2f}% ± {std_Dr*100:.2f}%\n")
                f.write(f"A_Df:                      {mean_Df*100:.2f}% ± {std_Df*100:.2f}%\n\n")
                f.write(f"A_Vr:                      {mean_val_ret*100:.2f}% ± {std_val_ret*100:.2f}%\n")
                f.write(f"A_Vf:                      {mean_val_for*100:.2f}% ± {std_val_for*100:.2f}%\n\n")
                f.write(f"frakC_Dr:                  {mean_efn_ret:.2f} ± {std_efn_ret:.2f}\n")
                f.write(f"frakN_Df:                  {mean_efn_for:.2f} ± {std_efn_for:.2f}\n")
                f.write(f"frakC_Tr:                  {mean_efn_r:.2f} ± {std_efn_r:.2f}\n")
                f.write(f"frakN_Tf:                  {mean_efn_s:.2f} ± {std_efn_s:.2f}\n")
                f.write(f"frakC_Vr:                  {mean_efn_val_ret:.2f} ± {std_efn_val_ret:.2f}\n")
                f.write(f"frakN_Vf:                  {mean_efn_val_for:.2f} ± {std_efn_val_for:.2f}\n")
                f.write(f"H:                         {H.nominal_value:.2f} ± {H.std_dev:.2f}\n\n\n")

                # f.write(f"CR_Dr:                     {mean_cr_ret:.2f} ± {std_cr_ret:.2f}\n")
                # f.write(f"CR_Df:                     {mean_cr_for:.2f} ± {std_cr_for:.2f}\n")
                # f.write(f"CR_Tr:                     {mean_cr_r:.2f} ± {std_cr_r:.2f}\n")
                # f.write(f"CR_Tf:                     {mean_cr_s:.2f} ± {std_cr_s:.2f}\n")
                # f.write(f"CR_Vr:                     {mean_cr_val_ret:.2f} ± {std_cr_val_ret:.2f}\n")
                # f.write(f"CR_Vf:                     {mean_cr_val_for:.2f} ± {std_cr_val_for:.2f}\n")

            f.write(f"\n\n-----------------###################################-------------------\n\n\n\n\n\n")
    




def build_arg_parser():
    """Build the command-line parser for the unlearning experiment runner.

    The parser exposes the general experiment options (scenarios, seeds,
    dataset, model, batch size, mode, calibration set, baseline, text flag)
    followed by one argument group per unlearning baseline.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    p = argparse.ArgumentParser(
        description="Run experiment with various scenarios and seeds"
    )

    # scenarios: list of [ [ints...], [floats...], int ] tuples
    p.add_argument(
        "--scenarios",
        type=literal_eval,
        default=[[[100], [0.1], 20]],
        help=(
            "list of scenarios, e.g. "
            "[[[c_s],[alphas],number_of_classes_or_clusters_or_points], [[...]]]"
        ),
    )

    # seeds: list of ints
    p.add_argument(
        "--seeds",
        type=literal_eval,
        default=[0],
        help="list of integer seeds, e.g. [0,1,2]",
    )

    # dataset: simple string
    # NOTE(review): main() also maps "ucf101" and "kin400*" to class counts,
    # but they are not selectable here -- confirm whether that is intended.
    p.add_argument(
        "--dataset",
        type=str,
        choices=["cifar100", "imagenet", "20_newsgroups", "news"],
        default="cifar100",
        help="dataset name, e.g. cifar100 or imagenet",
    )

    # model type: simple string
    p.add_argument(
        "--model_type",
        type=str,
        choices=["resnet18", "resnet18_imagenet", "vit", "berta_distill"],
        default="resnet18",
        # The example must be one of the accepted choices.
        help="model architecture, e.g. resnet18 or vit",
    )

    # batch size: integer
    p.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="batch size for training, unlearning, and evaluation (integer)",
    )

    # mode: simple string
    p.add_argument(
        "--mode",
        type=str,
        choices=[
            "label", "pca", "cluster",
            "instance-label", "instance-pca", "instance-cluster",
            "random", "instance-random",
        ],
        default="label",
        help="operating mode, e.g. label or pca",
    )

    # calibration set name: simple string
    p.add_argument(
        "--cp_calib",
        type=str,
        choices=["calib", "calib_val"],
        default="calib",
        help="calibration set, e.g. calib or calib_val",
    )

    # unlearning baseline: simple string ("all" runs every baseline in turn)
    p.add_argument(
        "--baseline",
        type=str,
        choices=["nabla_tau", "SSD", "SCRUB", "AMN", "all"],
        default="nabla_tau",
        help="unlearning baseline, e.g. nabla_tau or AMN",
    )

    # text type dataset: 0/1 integer flag
    p.add_argument(
        "--text",
        type=int,
        default=0,
        help="if 1, the dataset is text, if 0, it is image/else (integer)",
    )

    # --- nabla_tau Hyperparameters ---
    group_nabla = p.add_argument_group('nabla_tau Hyperparameters')
    group_nabla.add_argument('--nabla-tau-lr', type=float, default=1e-4, help="Learning rate for nabla_tau.")
    group_nabla.add_argument('--nabla-tau-split', type=float, default=0.5, help="Split ratio for nabla_tau.")

    # --- SCRUB Hyperparameters ---
    group_scrub = p.add_argument_group('SCRUB Hyperparameters')
    group_scrub.add_argument('--sgda-lr', type=float, default=5e-4, help="Learning rate for SCRUB's unlearning process.")
    group_scrub.add_argument('--sgda-epochs', type=int, default=10, help="Number of epochs for SCRUB's unlearning.")
    group_scrub.add_argument('--gamma', type=float, default=0.5, help="Trade-off for SCRUB's maximization loss.")
    group_scrub.add_argument('--beta', type=float, default=0.1, help="Trade-off for SCRUB's minimization loss.")
    group_scrub.add_argument('--msteps', type=int, default=1, help="Initial epochs for maximizing forget set loss in SCRUB.")
    group_scrub.add_argument('--sstart', type=int, default=5, help="Epoch to start SWA in SCRUB.")
    group_scrub.add_argument('--kd-T', type=int, default=2, help="Temperature for Knowledge Distillation in SCRUB.")

    # --- SSD Hyperparameters ---
    group_ssd = p.add_argument_group('SSD Hyperparameters')
    group_ssd.add_argument('--dampening-constant', type=float, default=0.001, help="Dampening constant for SSD.")
    group_ssd.add_argument('--selection-weighting', type=int, default=10, help="Selection weighting for SSD.")

    # --- AMN Hyperparameters ---
    group_amn = p.add_argument_group('AMN Hyperparameters')
    group_amn.add_argument('--amn-epochs', type=int, default=1, help="Number of epochs for AMN.")
    group_amn.add_argument('--amn-lr', type=float, default=1e-2, help="Learning rate for AMN.")

    return p


def collect_method_parameters(args):
    """Regroup the flat argparse namespace into per-baseline hyperparameters.

    Args:
        args: namespace produced by ``build_arg_parser().parse_args()``.

    Returns:
        dict: maps each baseline name ('nabla_tau', 'SCRUB', 'SSD', 'AMN') to
        the dict of hyperparameters for that baseline. main() indexes this
        mapping with the method name.
    """
    return {
        'nabla_tau': {
            'lr': args.nabla_tau_lr,
            'split': args.nabla_tau_split,
        },
        'SCRUB': {
            'sgda_lr': args.sgda_lr,
            'sgda_epochs': args.sgda_epochs,
            'gamma': args.gamma,
            'beta': args.beta,
            'msteps': args.msteps,
            'sstart': args.sstart,
            # argparse turns "--kd-T" into the attribute "kd_T" (case kept).
            'kd_T': args.kd_T,
        },
        'SSD': {
            'dampening_constant': args.dampening_constant,
            'selection_weighting': args.selection_weighting,
        },
        'AMN': {
            'epochs': args.amn_epochs,
            'lr': args.amn_lr,
        },
    }


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    parameters = collect_method_parameters(args)

    # No "--lr" option is registered, so the former `args.lr` raised
    # AttributeError before main() could start. The nabla_tau learning rate is
    # passed for main()'s `lr` parameter instead.
    # NOTE(review): the visible parts of main() never read `lr` (each method
    # takes its rate from `parameters`) -- confirm this is the intended value.
    main(
        args.scenarios,
        args.seeds,
        args.nabla_tau_lr,
        args.cp_calib,
        args.dataset,
        args.model_type,
        args.batch_size,
        args.mode,
        args.baseline,
        args.text,
        parameters,
    )

    # -----------------------------------------------------------------------------
    #example usage:
    # python nabla_amn_general.py --scenarios [[[100],[0.1],20]] --seeds [0,1,2,3,4,5] --dataset cifar100 --model_type resnet18 --batch_size 256 --mode label --cp_calib calib --baseline nabla_tau --nabla-tau-lr 0.0001 --nabla-tau-split 0.5
    # python nabla_amn_general.py --scenarios [[[100],[0.1],20]] --seeds [0,1,2,3,4,5] --dataset cifar100 --model_type resnet18 --batch_size 256 --mode label --cp_calib calib --baseline SCRUB --sgda-lr 0.0001 --sgda-epochs 20 --gamma 1.0 --beta 0.1 --msteps 3 --sstart 10 --kd-T 4
    # python nabla_amn_general.py --scenarios [[[100],[0.1],20]] --seeds [0,1,2,3,4,5] --dataset cifar100 --model_type resnet18 --batch_size 256 --mode label --cp_calib calib --baseline SSD --dampening-constant 0.1 --selection-weighting 50
    # python nabla_amn_general.py --scenarios [[[100],[0.1],20]] --seeds [0,1,2,3,4,5] --dataset cifar100 --model_type resnet18 --batch_size 256 --mode label --cp_calib calib --baseline AMN --amn-epochs 5 --amn-lr 0.001