import random
import os
import json
import argparse
import numpy as np
import pandas as pd 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import datasets
import models
import method

task = 'full'  # scenario tag; becomes part of the save/cache paths built below
parser = argparse.ArgumentParser()
# param for MU
parser.add_argument("-seed", type=int, default=0, help="seed for runs")
parser.add_argument("-net", type=str, required=True, help="net type")
parser.add_argument("-forget_class", type=int, help="index of the class to forget")
parser.add_argument(
    "-dataset",
    type=str,
    required=True,
    # No nargs="?" here: with it, a bare `-dataset` (no value) parsed to None,
    # skipped the `choices` check and crashed later on `args.dataset.startswith`.
    choices=["Cifar10", "Cifar20", "Cifar100", "PinsFaceRecognition", "TinyImagenet", "Svhn"],
    help="dataset to train on",
)

# hyper-param for MU
parser.add_argument("-method", type=str, required=True,
                    help="name of the unlearning function looked up in the `method` module")
parser.add_argument("-b", type=int, default=64, help="batch size")
parser.add_argument("-lr", type=float, default=0.1, help="learning rate passed to the method")
parser.add_argument("-epochs", type=int, default=1, help="epochs passed to the method")

# hyper-param for each method: every token argparse does not recognise is
# returned in `hyper_params_args` and parsed as `-key value` pairs below.
args, hyper_params_args = parser.parse_known_args()
# Method-specific hyper-params arrive as alternating `-key value` tokens.
# Keys lose their leading dashes, values stay raw strings, and a trailing
# key without a value maps to None.
_flag_tokens = hyper_params_args[::2]
_value_tokens = hyper_params_args[1::2] + [None] * (len(hyper_params_args) % 2)
hyper_params = {flag.lstrip('-'): val for flag, val in zip(_flag_tokens, _value_tokens)}

print(f"\033[93m Hyper-Params for {args.method}:\033[0m\n")
print(json.dumps(hyper_params, indent=4))

# Snapshot of the user-supplied hyper-params (later written to the results
# CSV), taken before runtime objects are merged into `hyper_params`.
ori_hyper = dict(hyper_params)
ori_hyper.update(lr=args.lr, epochs=args.epochs)

# Seed the PyTorch, NumPy and Python RNGs with the same value.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(args.seed)


# Resolve the classifier width (`classes`) and input resolution (`img_size`)
# for the selected dataset.
img_size = 32
if args.dataset and args.dataset.startswith('Cifar'):
    classes = int(args.dataset[5:])  # e.g. 'Cifar100' -> 100
elif args.dataset == 'PinsFaceRecognition':
    classes, img_size = 105, 64
elif args.dataset == 'TinyImagenet':
    classes, img_size = 200, 64
elif args.dataset == 'Svhn':
    classes = 10
else:
    # Fail here with a clear message instead of a NameError on `classes` later.
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")
if args.net == "ViT":
    img_size = 224  # ViT is always fed 224x224 inputs, whatever the dataset

# Starting checkpoint: `retrain` loads a checkpoint whose name carries the
# seed; every other method loads the seed-less checkpoint.
if args.method == 'retrain':
    weight_path = f'ckp/{args.net}-{args.dataset}-retrain-{args.seed}.pth'
else:
    weight_path = f'ckp/{args.net}-{args.dataset}.pth'

# Where the unlearned weights are written; the directory is created up front.
save_dir = f'tmp_save/utility/{task}/{args.net}-{args.dataset}-{args.forget_class}'
save_path = f'{save_dir}/{args.method}.pt'
os.makedirs(save_dir, exist_ok=True)

net = getattr(models, args.net)(num_classes=classes).cuda()
net.load_state_dict(torch.load(weight_path))

# PinsFaceRecognition lives in its own folder; every other dataset under data/.
data_root = 'data/pins_face_recognition' if args.dataset == 'PinsFaceRecognition' else 'data/'
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not os.path.exists(data_root):
    raise FileNotFoundError(f"Data root not found: {data_root}")
# Only the `scar` method receives a surrogate (`retain_sur`); others get None.
hyper_params['retain_sur'] = (
    datasets.get_surrogate(args, hyper_params, num_classes=classes, original_model=net)
    if args.method == 'scar'
    else None
)
trainset = getattr(datasets, args.dataset)(root=data_root, download=True, train=True, unlearning=True, img_size=img_size)
validset = getattr(datasets, args.dataset)(root=data_root, download=True, train=False, unlearning=True, img_size=img_size)
train_dl = DataLoader(trainset, batch_size=args.b, shuffle=True)
valid_dl = DataLoader(validset, batch_size=args.b, shuffle=False)


# On-disk cache of the forget/retain splits, keyed by net, task, dataset and
# forget class.
mu_data_dir = f'data/MU-DATASET/{args.net}'
os.makedirs(mu_data_dir, exist_ok=True)
mu_data_path = f'{mu_data_dir}/{task}-{args.dataset}-{args.forget_class}.pth'

if not os.path.exists(mu_data_path):
    # First run for this configuration: build the splits and persist them.
    forget_train, retain_train = datasets.build_retain_forget_sets(trainset, args.forget_class)
    forget_valid, retain_valid = datasets.build_retain_forget_sets(validset, args.forget_class)
    torch.save((forget_train, retain_train, forget_valid, retain_valid), mu_data_path)
else:
    # NOTE(review): this unpickles arbitrary Python objects (the split datasets).
    # Only load caches this script wrote itself. Newer PyTorch releases default
    # torch.load to weights_only=True, which may reject this file -- confirm
    # against the installed torch version.
    forget_train, retain_train, forget_valid, retain_valid = torch.load(mu_data_path)
    print('\033[33m DATA LOADED .. \033[0m')

# One shuffling loader per split.
forget_valid_dl, retain_valid_dl, forget_train_dl, retain_train_dl = (
    DataLoader(split, args.b, shuffle=True)
    for split in (forget_valid, retain_valid, forget_train, retain_train)
)




# These methods receive an extra network of the same architecture. It is
# constructed here without loading any checkpoint; every other method gets None.
unlearning_teacher = (
    getattr(models, args.net)(num_classes=classes).cuda()
    if args.method in ('badteacher', 'euk', 'cfk')
    else None
)

# Arguments handed to every unlearning method. Merged last, so they override
# any same-named method-specific hyper-param from the command line.
base_hyper = dict(
    model=net,
    unlearning_teacher=unlearning_teacher,
    retain_train_dl=retain_train_dl,
    retain_valid_dl=retain_valid_dl,
    forget_train_dl=forget_train_dl,
    forget_valid_dl=forget_valid_dl,
    valid_dl=valid_dl,
    device='cuda',
    lr=args.lr,
    batch_size=args.b,
    epochs=args.epochs,
    num_classes=classes,
    forget_class=args.forget_class,
    dataset=args.dataset,
)
hyper_params.update(base_hyper)

# Optional: wrap the call below in
# torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
# to restrict scaled-dot-product attention to the math kernel.
unlearn_fn = getattr(method, args.method)
metrics, time_elapsed = unlearn_fn(**hyper_params)
(retain_train_acc, retain_valid_acc, forget_train_acc,
 forget_valid_acc, total_valid_acc, mia) = metrics

# Only positive seeds persist the weights; `net` is presumably updated in
# place by the method -- TODO confirm for each method.
if args.seed > 0:
    torch.save(net.state_dict(), save_path)

print(f"retain train acc:{retain_train_acc}|valid acc:{retain_valid_acc}||forget train acc:{forget_train_acc}|forget valid acc:{forget_valid_acc}||total test acc:{total_valid_acc}|mia:{mia}")


# Append this run's metrics as one row of the per-(net, dataset, forget class)
# results CSV.
columns = ['Method', 'Seed', "RetainTrainAcc", "RetainValidAcc", "ForgetTrainAcc", "ForgetValidAcc", "TestAcc", "MIA", "MethodTime", "Hyper-params"]
filename = f'results/fullclass-{args.net}-{args.dataset}-forget-{args.forget_class}.csv'
# Create results/ on a fresh checkout; previously to_csv failed here. The old
# unused `pd.read_csv` of the existing file is gone: it wasted I/O and crashed
# on an empty or corrupt CSV.
os.makedirs(os.path.dirname(filename), exist_ok=True)
new_row = pd.DataFrame({'Method': [args.method],
                        'Seed': [args.seed],
                        "RetainTrainAcc": [retain_train_acc],
                        "RetainValidAcc": [retain_valid_acc],
                        "ForgetTrainAcc": [forget_train_acc],
                        "ForgetValidAcc": [forget_valid_acc],
                        "TestAcc": [total_valid_acc],
                        "MIA": [mia],
                        "MethodTime": [time_elapsed],
                        "Hyper-params": [str(ori_hyper)]},
                       columns=columns)
# Write the header only when the file is being created; later runs append rows.
new_row.to_csv(filename, mode='a', index=False, header=not os.path.exists(filename))
