import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
import torch
import numpy as np
import argparse
import copy
from torch.utils.data import DataLoader
from ast import literal_eval
from helper_functions import *



def main(scenarios, seeds, cp_calib, dataset, model_type, batch_size, mode, text):
    """Evaluate the "RT" models for every scenario/seed and log the results.

    For each scenario the data splits are obtained from ``OLD_load_data``.  For
    each seed the function then makes sure that the original model and the
    ``retrain_*`` model exist under ``../models`` (a checkpoint is loaded when
    present, otherwise the model is built with ``build_base_model``; the
    ``retrain_*`` model is built on the ``Tr_loader`` split).  For every
    ``alpha`` and ``c`` of the scenario it computes the conformal quantile,
    accuracies, the membership-inference score and the ``compute_ccucr``
    metrics of the ``retrain_*`` model.  One record per
    (seed, alpha, c) and a final mean/std summary per configuration are
    appended to ``../results/RT_{dataset}.txt``.

    Args:
        scenarios: Iterable of ``(c_s, alphas, forgot_set)`` triples: the
            ``c`` values and ``alpha`` values to sweep, and the forget
            specification handed to ``OLD_load_data``.
        seeds: Iterable of seeds; each is passed to ``set_seed``,
            ``random.seed`` and used as ``random_state`` of the
            membership-inference attack.
        cp_calib: ``"calib"`` calibrates the conformal quantile on the
            ``cal_test_subset`` split; any other value uses
            ``cal_unlearn_subset``.
        dataset: Dataset name; selects the number of classes and the results
            file name.
        model_type: Architecture name forwarded to ``create_model`` /
            ``build_base_model``.
        batch_size: Batch size forwarded to ``OLD_load_data`` and used for the
            test ``DataLoader``.
        mode: Forgetting mode forwarded to ``OLD_load_data`` and embedded in
            the ``retrain_*`` checkpoint name.
        text: ``1`` passes ``trans=True`` to the helper functions, anything
            else passes ``trans=False``.

    Returns:
        None. All results are printed and appended to the results file.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    import random
    from collections import defaultdict

    print(f"\nRunning RT with the following parameters:")
    print(f"Scenarios: {scenarios}")
    print(f"Seeds: {seeds}")
    print(f"Dataset: {dataset}, Model Type: {model_type}, Batch Size: {batch_size}, Mode: {mode}")
    trans = text == 1

    if dataset == "cifar100" or dataset == "imagenet":
        # NOTE(review): "imagenet" is given 100 classes here - presumably a
        # 100-class subset is used; confirm against the data loader.
        num_classes = 100
    elif dataset == "news":
        num_classes = 4
    elif dataset == "20_newsgroups":
        num_classes = 20
    elif dataset == "ucf101":
        num_classes = 101
    elif "kin400" in dataset:
        num_classes = 400
    else:
        # Without this branch an unknown name only surfaced later as a
        # NameError on `num_classes`.
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # Imported before any training/evaluation so that a missing package fails
    # immediately instead of right before the final summary is written.
    from uncertainties import ufloat

    results_path = f"../results/RT_{dataset}.txt"

    # -----------------------------------------------------------------------------
    def load_pretrained_model(model_type, file_name, device):
        """Create a model and load ``../models/{file_name}.pth`` into it (eval mode)."""
        model = create_model(model_type, device, num_classes=num_classes)
        # map_location lets a checkpoint that was saved on a GPU be loaded on
        # whatever device is available now (e.g. a CPU-only machine).
        model.load_state_dict(torch.load(f"../models/{file_name}.pth", map_location=device))
        model = model.to(device)
        model.eval()
        return model

    def mean_std(records, field):
        """Return (mean, population std) of ``field`` over ``records``."""
        values = np.array([rec[field] for rec in records])
        return values.mean(), values.std()

    with open(results_path, 'a+') as f:
        f.write(f"\n\n\n-----------------###################################-------------------\n")
        f.write(f"Results for scenarios: {scenarios}, over seeds: {seeds}\n")
        f.write(f"Dataset: {dataset}, Model Type: {model_type}, Mode: {mode}, CP Calibration: {cp_calib}\n")

    # One record per (scenario, seed, alpha, c); aggregated at the end.
    all_stats = []

    # -----------------------------------------------------------------------------
    # Device Setup
    # -----------------------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPUs!")

    # The data split always uses seed 0, independently of the run seeds.
    seed_data = 0
    rng_data = set_seed(seed_data)

    # Fixed settings that are only recorded alongside the results.
    nonconf_func = "one_minus"
    optimizing = "RT"

    for scenario in scenarios:
        c_s, alphas, forgot_set = scenario
        print(f"Working on scenario: {scenario}")
        loaders = OLD_load_data(dataset, forgot_set, rng_data, mode, batch_size=batch_size, n_work=2)

        # Only the entries that are actually used below are unpacked.
        train_loader               = loaders["train_loader"]
        calibration_loader         = loaders["cal_test_loader"]
        forget_loader              = loaders["Df_loader"]
        retained_loader            = loaders["Dr_loader"]
        calibration_val_loader     = loaders["cal_unlearn_loader"]
        forget_val_loader          = loaders["Vf_loader"]
        retained_val_loader        = loaders["Vr_loader"]
        S_loader                   = loaders["Tf_loader"]
        R_loader                   = loaders["Tr_loader"]
        calibration_forget_loader  = loaders["cal_test_forget_loader"]
        D_calib_val                = loaders["cal_unlearn_subset"]
        D_calib                    = loaders["cal_test_subset"]
        test_dataset               = loaders["test_subset"]

        if train_loader is None:
            length_train = 0
        else:
            length_train = len(train_loader.dataset)

        print(f"Loaded dataset with {length_train} training samples, "
                f"{len(calibration_loader.dataset)} calibration samples, "
                f"{len(forget_loader.dataset)} forget samples, "
                f"{len(retained_loader.dataset)} retained samples, "
                f"{len(calibration_val_loader.dataset)} calibration validation samples, "
                f"{len(forget_val_loader.dataset)} forget validation samples, "
                f"{len(retained_val_loader.dataset)} retained validation samples, "
                f"{len(S_loader.dataset)} S samples, "
                f"{len(R_loader.dataset)} R samples.")

        # Split on which the conformal quantile is calibrated.
        cp_set = D_calib if cp_calib == "calib" else D_calib_val

        print(f"{length_train/1000}k training samples, ")

        for seed in seeds:
            set_seed(seed)
            random.seed(seed)
            # The original model is shared by all run seeds (named after seed_data).
            model_filename = f"original_{dataset}_{model_type}_{int(length_train/1000)}k_seed{seed_data}"

            if os.path.exists(f"../models/{model_filename}.pth"):
                model = load_pretrained_model(model_type, model_filename, device)
                total_unlearn_time = 0.0
            else:
                model, total_unlearn_time = build_base_model(model_type, model_filename, train_loader,  calibration_loader, device, num_classes, trans=trans)
                print(f"Total learning time is: {total_unlearn_time}")
            # -----------------------------------------------------------------------------

            # NOTE(review): `original_model` is currently unused because the
            # baseline evaluations below are disabled.  The load is kept on
            # purpose: constructing a model presumably draws from the global
            # RNG, so removing it could change the retrained model obtained
            # for a given seed.
            original_model = load_pretrained_model(model_type, model_filename, device)

            # Baseline (original model) metrics are disabled and recorded as 0;
            # the commented calls show how to re-enable each one.
            initial_test_accuracy = 0 #evaluate(original_model, DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2), device, trans=trans)
            print(f"Initial model test accuracy: {initial_test_accuracy*100:.2f}%")

            accuracy_S_o = 0 #evaluate(original_model, S_loader, device, trans=trans)
            accuracy_R_o = 0 #evaluate(model, R_loader, device, trans=trans)
            print(f"Accuracy before on D_train subset S (forgotten classes): {accuracy_S_o*100:.2f}%")
            print(f"Accuracy before on D_train subset R (retained classes): {accuracy_R_o*100:.2f}%")

            accuracy_retained_o = 0 #evaluate(original_model, retained_loader, device, trans=trans)
            accuracy_forget_o = 0 #evaluate(original_model, forget_loader, device, trans=trans)
            print(f"Original model accuracy on retained set (D_r): {accuracy_retained_o*100:.2f}%")
            print(f"Original model accuracy on forgotten set (D_f): {accuracy_forget_o*100:.2f}%")

            accuracy_validation_for_o = 0 #evaluate(original_model, forget_val_loader, device, trans=trans)
            accuracy_validation_ret_o = 0 #evaluate(original_model, retained_val_loader, device, trans=trans)
            print(f"Original model accuracy on validation set V_f: {accuracy_validation_for_o*100:.2f}%")
            print(f"Original model accuracy on validation set V_r: {accuracy_validation_ret_o*100:.2f}%")

            with open(results_path, 'a+') as f:
                f.write(f"\n\n------------------------------------\n")
                f.write(f"Scenario: {scenario}, Seed: {seed}\n")
                f.write(f"Initial model test accuracy: {initial_test_accuracy*100:.2f}%\n")
                f.write(f"Accuracy before on D_train subset S (forgotten classes): {accuracy_S_o*100:.2f}%\n")
                f.write(f"Accuracy before on D_train subset R (retained classes): {accuracy_R_o*100:.2f}%\n")
                f.write(f"Original model accuracy on retained set (D_r): {accuracy_retained_o*100:.2f}%\n")
                f.write(f"Original model accuracy on forgotten set (D_f): {accuracy_forget_o*100:.2f}%\n")
                f.write(f"Original model accuracy on validation set V_f: {accuracy_validation_for_o*100:.2f}%\n")
                f.write(f"Original model accuracy on validation set V_r: {accuracy_validation_ret_o*100:.2f}%\n")

            model_filename = f"retrain_{dataset}_{model_type}_{mode}_{forgot_set}_seed{seed}"

            # A cached checkpoint reports 0.0s, so the timing in the summary is
            # only meaningful when the model was actually built in this run.
            if os.path.exists(f"../models/{model_filename}.pth"):
                model_unlearned = load_pretrained_model(model_type, model_filename, device)
                total_unlearn_time = 0.0
            else:
                model_unlearned, total_unlearn_time = build_base_model(model_type, model_filename, R_loader,  calibration_loader, device, num_classes, trans=trans)
                print(f"Total unlearning time is: {total_unlearn_time}")

            for alpha in alphas:
                print(f"→ RT  with α={alpha} Working on {forgot_set} fogotten classes.")
                # -----------------------------------------------------------------------------
                # CP-based Machine Unlearning
                # -----------------------------------------------------------------------------

                # Quantile score of the unlearned model on the calibration split.
                # Disabled: q_hat_original, _ = conformal_prediction_quantile_and_returnall(original_model, cp_set, alpha, nonconf_func, trans=trans)
                q_hat_unlearned, _ = conformal_prediction_quantile_and_returnall(model_unlearned, cp_set, alpha, nonconf_func, trans=trans)
                print(f"q_hat (unlearned model): {q_hat_unlearned:.4f}")

                # Final evaluations of the unlearned model on every split.
                unlearned_test_accuracy = evaluate(model_unlearned, DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2), device, trans=trans)
                print(f"Unlearned model test accuracy: {unlearned_test_accuracy*100:.2f}%")

                unlearned_accuracy_retained = evaluate(model_unlearned, retained_loader, device, trans=trans)
                unlearned_accuracy_forget = evaluate(model_unlearned, forget_loader, device, trans=trans)
                print(f"Unlearned model accuracy on retained set (D_r): {unlearned_accuracy_retained*100:.2f}%")
                print(f"Unlearned model accuracy on forgotten set (D_f): {unlearned_accuracy_forget*100:.2f}%")

                accuracy_S = evaluate(model_unlearned, S_loader, device, trans=trans)
                accuracy_R = evaluate(model_unlearned, R_loader, device, trans=trans)

                print(f"Accuracy after on D_train subset S (forgotten classes): {accuracy_S*100:.2f}%")
                print(f"Accuracy after on D_train subset R (retained classes): {accuracy_R*100:.2f}%")

                accuracy_val_for = evaluate(model_unlearned, forget_val_loader, device, trans=trans)
                accuracy_val_ret = evaluate(model_unlearned, retained_val_loader, device, trans=trans)
                print(f"Accuracy on validation set V_f: {accuracy_val_for*100:.2f}%")
                print(f"Accuracy on validation set V_r: {accuracy_val_ret*100:.2f}%")

                # Membership inference attack; only the mean accuracy is recorded.
                _, mia_mean_acc_un = compute_membership_inference_attack(
                                model=model_unlearned,
                                member_loader=S_loader,
                                nonmember_loader=calibration_forget_loader,
                                device=device,
                                n_splits=10,
                                random_state=seed, trans=trans
                            )

                for c in c_s:
                    print(f"→ RT  with α={alpha}, c={c}, Working on {forgot_set} forgotten classes.")
                    # CCUCR on D_f (forgotten set) and D_r (retained set).
                    efn_ret, efn_for, avg_D = compute_ccucr(
                        model_unlearned, forget_loader, retained_loader,
                        q_hat_unlearned, nonconf_func, c=c, trans=trans
                    )

                    # CCUCR on the training subsets S (Tf_loader) and R (Tr_loader).
                    efn_r, efn_s, avg_T = compute_ccucr(
                        model_unlearned, S_loader, R_loader,
                        q_hat_unlearned, nonconf_func, c=c, trans=trans
                    )

                    # CCUCR on the validation splits V_f and V_r.
                    efn_ret_val, efn_for_val, avg_V = compute_ccucr(
                        model_unlearned, forget_val_loader, retained_val_loader,
                        q_hat_unlearned, nonconf_func, c=c, trans=trans
                    )

                    # Disabled:
                    # cr_ret, cr_for = compute_cr(model_unlearned, forget_loader, retained_loader, q_hat_unlearned, nonconf_func, trans=trans)
                    # cr_r, cr_s = compute_cr(model_unlearned, S_loader, R_loader, q_hat_unlearned, nonconf_func, trans=trans)
                    # cr_ret_val, cr_for_val = compute_cr(model_unlearned, forget_val_loader, retained_val_loader, q_hat_unlearned, nonconf_func, trans=trans)

                    # Harmonic means: retained side, forgotten side, and both combined.
                    H1 = harmonic_mean([efn_ret, efn_r, efn_ret_val])
                    H2 = harmonic_mean([efn_for, efn_s, efn_for_val])
                    H = harmonic_mean([H1, H2])

                    # Key order matters: the dict is written verbatim to the results file.
                    all_stats.append({
                        'seed': seed,
                        'forgotten': forgot_set,
                        'alpha': alpha,
                        'non-conformity': nonconf_func,
                        'loss_function': optimizing,
                        'c': c,
                        'total_unlearn_time': total_unlearn_time,
                        'unlearned_test_acc': unlearned_test_accuracy,
                        'initial_test_acc': initial_test_accuracy,
                        # 'qhat_orig': q_hat_original,
                        'qhat_unlearn': q_hat_unlearned,
                        'acc_S_before': accuracy_S_o,
                        'acc_S_after': accuracy_S,
                        'acc_R_before': accuracy_R_o,
                        'acc_R_after': accuracy_R,
                        'acc_Dr': unlearned_accuracy_retained,
                        'acc_Df': unlearned_accuracy_forget,
                        'acc_val_for': accuracy_val_for,
                        'acc_val_ret': accuracy_val_ret,
                        'mia_score_difference': mia_mean_acc_un,
                        'efn_for': efn_for,
                        'cover_ret': efn_ret,
                        'efn_s': efn_s,
                        'cover_r': efn_r,
                        'efn_val_for': efn_for_val,
                        'cover_val_ret': efn_ret_val,
                        'H_retain': H1,
                        'H_forget': H2,
                        'H': H,
                        # 'cr_for': cr_for,
                        # 'cr_ret': cr_ret,
                        # 'cr_s': cr_s,
                        # 'cr_r': cr_r,
                        # 'cr_val_for': cr_for_val,
                        # 'cr_val_ret': cr_ret_val,
                        'avg_D': avg_D,
                        'avg_T': avg_T,
                        'avg_V': avg_V
                    })
                    with open(results_path, 'a+') as f:
                        f.write(f"\n\n------------------------------------\nauxiliary result:\n{all_stats[-1]}")

    # Group the records by everything except the seed and the metrics, so that
    # each bucket holds the runs of one configuration across seeds.
    group_keys = ['forgotten', 'alpha','non-conformity', 'loss_function', 'c']

    buckets = defaultdict(list)
    for st in all_stats:
        # A list-valued entry is not hashable; its length is used in the key instead.
        key = tuple((len(st[k]) if isinstance(st[k], list) else st[k])
                    for k in group_keys)
        buckets[key].append(st)

    with open(results_path, 'a+') as f:
        f.write(f"\n\n\n\n-----------------###################################-------------------\n")
        f.write(f"Results for scenarios: {scenarios}\n")
        f.write(f"Dataset: {dataset}, Model Type: {model_type}, Mode: {mode}\n")

        for key, recs in buckets.items():
            # Mean and (population) standard deviation of each metric over the bucket.
            mean_time, std_time = mean_std(recs, 'total_unlearn_time')
            mean_mia_scores_un, std_mia_scores_un = mean_std(recs, 'mia_score_difference')
            mean_Sa, std_Sa = mean_std(recs, 'acc_S_after')
            mean_Ra, std_Ra = mean_std(recs, 'acc_R_after')
            mean_Dr, std_Dr = mean_std(recs, 'acc_Dr')
            mean_Df, std_Df = mean_std(recs, 'acc_Df')
            mean_val_for, std_val_for = mean_std(recs, 'acc_val_for')
            mean_val_ret, std_val_ret = mean_std(recs, 'acc_val_ret')
            mean_efn_for, std_efn_for = mean_std(recs, 'efn_for')
            mean_efn_ret, std_efn_ret = mean_std(recs, 'cover_ret')
            mean_efn_s, std_efn_s = mean_std(recs, 'efn_s')
            mean_efn_r, std_efn_r = mean_std(recs, 'cover_r')
            mean_efn_val_for, std_efn_val_for = mean_std(recs, 'efn_val_for')
            mean_efn_val_ret, std_efn_val_ret = mean_std(recs, 'cover_val_ret')

            # Harmonic mean H over the six metrics; ufloat propagates the
            # standard deviations through the formula.
            e1 = ufloat(mean_efn_for,     std_efn_for)
            e2 = ufloat(mean_efn_ret,     std_efn_ret)
            e3 = ufloat(mean_efn_s,       std_efn_s)
            e4 = ufloat(mean_efn_r,       std_efn_r)
            e5 = ufloat(mean_efn_val_for, std_efn_val_for)
            e6 = ufloat(mean_efn_val_ret, std_efn_val_ret)
            try:
                H = 6 / (1/e1 + 1/e2 + 1/e3 + 1/e4 + 1/e5 + 1/e6)
            except ZeroDivisionError:
                # Any zero-valued metric makes the harmonic mean zero.
                H = ufloat(0.0, 0.0)

            f.write("\n" + "#"*50 + "\n")
            params_str = ", ".join(f"{k}={v}" for k, v in zip(group_keys, key))
            f.write(f"{params_str}\n\n")
            f.write(f"Unlearn cycle time:        {mean_time:.2f} ± {std_time:.2f}s\n\n")
            f.write(f"MIA score difference:      {mean_mia_scores_un*100:.2f}% ± {std_mia_scores_un*100:.2f}%\n\n")

            f.write(f"A_Tr after:                {mean_Ra*100:.2f}% ± {std_Ra*100:.2f}%\n")
            f.write(f"A_Tf after:                {mean_Sa*100:.2f}% ± {std_Sa*100:.2f}%\n\n")
            f.write(f"A_Dr:                      {mean_Dr*100:.2f}% ± {std_Dr*100:.2f}%\n")
            f.write(f"A_Df:                      {mean_Df*100:.2f}% ± {std_Df*100:.2f}%\n\n")
            f.write(f"A_Vr:                      {mean_val_ret*100:.2f}% ± {std_val_ret*100:.2f}%\n")
            f.write(f"A_Vf:                      {mean_val_for*100:.2f}% ± {std_val_for*100:.2f}%\n\n")

            f.write(f"frakC_Dr:                  {mean_efn_ret:.2f} ± {std_efn_ret:.2f}\n")
            f.write(f"frakN_Df:                  {mean_efn_for:.2f} ± {std_efn_for:.2f}\n")
            f.write(f"frakC_Tr:                  {mean_efn_r:.2f} ± {std_efn_r:.2f}\n")
            f.write(f"frakN_Tf:                  {mean_efn_s:.2f} ± {std_efn_s:.2f}\n")
            f.write(f"frakC_Vr:                  {mean_efn_val_ret:.2f} ± {std_efn_val_ret:.2f}\n")
            f.write(f"frakN_Vf:                  {mean_efn_val_for:.2f} ± {std_efn_val_for:.2f}\n")
            f.write(f"H:                     {H.nominal_value:.2f} ± {H.std_dev:.2f}\n")

        f.write(f"\n\n-----------------###################################-------------------\n\n\n\n\n\n")
    

if __name__ == "__main__":
    # Command-line entry point: the options are declared in one table (in the
    # order they appear in --help), parsed, and forwarded to main().
    parser = argparse.ArgumentParser(
        description="Run experiment with various scenarios and seeds"
    )

    option_table = [
        # scenarios: Python literal, a list of [ [ints...], [floats...], int ] entries
        ("--scenarios", {
            "type": literal_eval,
            "default": [[[100], [0.1], 20]],
            "help": "list of scenarios, e.g. [[[c_s],[alphas],number_of_classes_or_clusters_or_points], [[...]]]",
        }),
        # seeds: Python literal, a list of ints
        ("--seeds", {
            "type": literal_eval,
            "default": [0],
            "help": "list of integer seeds, e.g. [0,1,2]",
        }),
        # dataset: plain string restricted to the listed names
        ("--dataset", {
            "type": str,
            "choices": ["cifar100", "imagenet", "20_newsgroups", "news"],
            "default": "cifar100",
            "help": "dataset name, e.g. cifar100 or imagenet",
        }),
        # model type: plain string restricted to the listed architectures
        ("--model_type", {
            "type": str,
            "choices": ["resnet18", "resnet18_imagenet", "vit", "berta_distill"],
            "default": "resnet18",
            "help": "model architecture, e.g. resnet18 or berta_distill",
        }),
        # batch size: integer
        ("--batch_size", {
            "type": int,
            "default": 256,
            "help": "batch size for training, unlearning, and evaluation (integer)",
        }),
        # mode: plain string restricted to the listed modes
        ("--mode", {
            "type": str,
            "choices": ["label", "pca", "cluster", "instance-label", "instance-pca", "instance-cluster", "random"],
            "default": "label",
            "help": "operating mode, e.g. label or pca",
        }),
        # calibration set name: plain string
        ("--cp_calib", {
            "type": str,
            "choices": ["calib", "calib_val"],
            "default": "calib",
            "help": "calibration set, e.g. calib or calib_val",
        }),
        # text-type dataset flag, encoded as 0/1
        ("--text", {
            "type": int,
            "choices": [0, 1],
            "default": 0,
            "help": "if 1, the dataset is text, if 0, it is image/else (integer)",
        }),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    cli_args = parser.parse_args()

    main(
        cli_args.scenarios,
        cli_args.seeds,
        cli_args.cp_calib,
        cli_args.dataset,
        cli_args.model_type,
        cli_args.batch_size,
        cli_args.mode,
        cli_args.text,
    )

    