import random
import os
import json
import argparse
import datasets
import models
import method
import numpy as np
import pandas as pd 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils import evaluate

# Command-line interface for one step of a sequential unlearning run.
# Options that are not declared here are picked up separately (via
# parse_known_args) and treated as method-specific hyper-parameters.
parser = argparse.ArgumentParser()
# param for MU
parser.add_argument("-seed", type=int, default=0, help="seed for runs")
parser.add_argument("-net", type=str, required=True, help="net type")
# Only these two schedules are handled by the rest of the script, so any other
# value is rejected at parse time instead of failing later with a NameError.
parser.add_argument(
    "-task",
    type=str,
    default="sequential_full",
    choices=["sequential_full", "sequential_sub"],
    help="task type",
)

parser.add_argument("-forget_class", type=int)

# No nargs="?" here: a bare "-dataset" with no value must be a usage error,
# not a silent None that breaks the dataset lookup further down.
parser.add_argument(
    "-dataset",
    type=str,
    required=True,
    choices=["Cifar10", "Cifar20", "Cifar100", "PinsFaceRecognition", "TinyImagenet", "Svhn"],
    help="dataset to train on",
)

# hyper-param for MU
parser.add_argument("-method", type=str, required=True)
parser.add_argument("-b", type=int, default=64)
parser.add_argument("-lr", type=float, default=0.1)
parser.add_argument("-epochs", type=int, default=1)

# Method-specific hyper-parameters: every token the parser does not recognise
# is read as alternating "-name value" pairs. Values stay as raw strings.
args, hyper_params_args = parser.parse_known_args()
extra_flags = hyper_params_args[0::2]
extra_values = hyper_params_args[1::2]
hyper_params = {}
for slot, flag in enumerate(extra_flags):
    # A trailing flag with no token after it is stored with the value None.
    hyper_params[flag.lstrip('-')] = extra_values[slot] if slot < len(extra_values) else None
banner = f"\033[93m Hyper-Params for {args.method}:\033[0m\n"
print(banner)
print(json.dumps(hyper_params, indent=4))

# Snapshot of the user-supplied hyper-parameters plus lr/epochs, taken before
# any runtime objects are merged into hyper_params; this is the value that is
# later stringified into the results CSV.
ori_hyper = {**hyper_params, 'lr': args.lr, 'epochs': args.epochs}

# Seed torch, NumPy and Python's random module with the same value.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(args.seed)


# Number of classes and image size per dataset; image size is 32 unless the
# dataset entry or the network says otherwise.
img_size = 32
non_cifar_specs = {
    'PinsFaceRecognition': (105, 64),
    'TinyImagenet': (200, 64),
    'Svhn': (10, 32),
}
if args.dataset.startswith('Cifar'):
    # The digits following the "Cifar" prefix are taken as the class count.
    classes = int(args.dataset[5:])
elif args.dataset in non_cifar_specs:
    classes, img_size = non_cifar_specs[args.dataset]
# ViT overrides the image size regardless of dataset.
if args.net == "ViT":
    img_size = 224

# Classes are unlearned one after another in a fixed order that depends on the
# task; this run handles exactly one entry of that order (args.forget_class).
if args.task == 'sequential_full':
    overall_forget_classes = list(range(5))
elif args.task == 'sequential_sub':
    overall_forget_classes = [54, 62, 70, 82, 92]
else:
    # Fail with a clear message instead of a NameError on the lookup below.
    raise ValueError(
        f"unsupported task {args.task!r}; expected 'sequential_full' or 'sequential_sub'"
    )

if args.forget_class not in overall_forget_classes:
    # Also covers a missing -forget_class (None).
    raise ValueError(
        f"-forget_class must be one of {overall_forget_classes} for task "
        f"{args.task!r}, got {args.forget_class!r}"
    )

# Position of the current class in the schedule, and the classes handled by
# earlier steps.
forget_index = overall_forget_classes.index(args.forget_class)
forgotten_classes = overall_forget_classes[:forget_index]

# Starting weights:
#   - 'retrain' always loads its own per-class, per-seed checkpoint;
#   - the first step of any other method starts from the base checkpoint;
#   - later steps resume from the model saved by the previous step.
if args.method == 'retrain':
    weight_path = f'ckp/sequential/{args.task}/{args.net}-{args.dataset}-retrain-{args.forget_class}-{args.seed}.pth'
elif forget_index == 0:
    weight_path = f'ckp/{args.net}-{args.dataset}.pth'
else:
    weight_path = f'tmp_save/sequential/{args.task}/{args.method}-{args.net}-{args.dataset}-{forgotten_classes[-1]}.pt'
os.makedirs(f'tmp_save/sequential/{args.task}/', exist_ok=True)
save_path = f'tmp_save/sequential/{args.task}/{args.method}-{args.net}-{args.dataset}-{args.forget_class}.pt'



# Build the requested architecture, move it to the GPU, then restore the
# checkpoint selected above.
model_cls = getattr(models, args.net)
net = model_cls(num_classes=classes).cuda()
checkpoint_state = torch.load(weight_path)
net.load_state_dict(checkpoint_state)

# PinsFaceRecognition lives in its own sub-directory; every other dataset uses
# the shared data/ root.
if args.dataset == 'PinsFaceRecognition':
    data_root = 'data/pins_face_recognition'
else:
    data_root = 'data/'

# Explicit check instead of `assert`: assertions are removed under `python -O`
# and carry no message about which path is missing.
if not os.path.exists(data_root):
    raise FileNotFoundError(f"dataset root {data_root!r} does not exist")

# Only the 'scar' method gets a surrogate retain set; all other methods
# receive retain_sur=None.
if args.method == 'scar':
    hyper_params['retain_sur'] = datasets.get_surrogate(
        args, hyper_params, num_classes=classes, original_model=net
    )
else:
    hyper_params['retain_sur'] = None

# Full train / validation sets and their loaders (only the train loader shuffles).
dataset_cls = getattr(datasets, args.dataset)
trainset = dataset_cls(root=data_root, download=True, train=True, unlearning=True, img_size=img_size)
validset = dataset_cls(root=data_root, download=True, train=False, unlearning=True, img_size=img_size)
train_dl = DataLoader(trainset, batch_size=args.b, shuffle=True)
valid_dl = DataLoader(validset, batch_size=args.b, shuffle=False)


# The forget / retain / already-forgotten splits (train and valid) are cached
# on disk, keyed by network, task, dataset and forget class.
os.makedirs(f'data/MU-DATASET/{args.net}', exist_ok = True)
mu_data_path = f'data/MU-DATASET/{args.net}/{args.task}-{args.dataset}-{args.forget_class}.pth'

if not os.path.exists(mu_data_path):
    # Cache miss: build the splits for both sets, then persist them together.
    is_subclass_task = 'sub' in args.task
    forget_train, retain_train, forgotten_train = datasets.build_sequential_retain_forget_sets(
        trainset, args.forget_class, forgotten_classes, subclass=is_subclass_task
    )
    forget_valid, retain_valid, forgotten_valid = datasets.build_sequential_retain_forget_sets(
        validset, args.forget_class, forgotten_classes, subclass=is_subclass_task
    )
    cached_splits = (forget_train, retain_train, forgotten_train, forget_valid, retain_valid, forgotten_valid)
    torch.save(cached_splits, mu_data_path)
else:
    # NOTE(review): the cached tuple holds whatever
    # build_sequential_retain_forget_sets returned; if those are not plain
    # tensors, newer PyTorch releases may need weights_only=False here — confirm.
    (forget_train, retain_train, forgotten_train,
     forget_valid, retain_valid, forgotten_valid) = torch.load(mu_data_path)
    print('\033[33m DATA LOADED .. \033[0m')
        
def _shuffled_loader(split):
    """Wrap *split* in a shuffling DataLoader that uses the CLI batch size."""
    return DataLoader(split, args.b, shuffle=True)


forget_train_dl = _shuffled_loader(forget_train)
retain_train_dl = _shuffled_loader(retain_train)

forget_valid_dl = _shuffled_loader(forget_valid)
retain_valid_dl = _shuffled_loader(retain_valid)

# On the first step nothing has been forgotten yet, so the "forgotten" loaders
# are simply the current forget loaders.
if forget_index > 0:
    forgotten_valid_dl = _shuffled_loader(forgotten_valid)
    forgotten_train_dl = _shuffled_loader(forgotten_train)
else:
    forgotten_valid_dl = forget_valid_dl
    forgotten_train_dl = forget_train_dl




# Only the 'badteacher' method gets a second model of the same architecture
# (no checkpoint is loaded into it here); everything else receives None.
unlearning_teacher = None
if args.method == 'badteacher':
    unlearning_teacher = getattr(models, args.net)(num_classes=classes).cuda()

# Arguments passed to every unlearning method, regardless of which one runs.
base_hyper = dict(
    model=net,
    unlearning_teacher=unlearning_teacher,
    retain_train_dl=retain_train_dl,
    retain_valid_dl=retain_valid_dl,
    forget_train_dl=forget_train_dl,
    forget_valid_dl=forget_valid_dl,
    forgotten_train_dl=forgotten_train_dl,
    forgotten_valid_dl=forgotten_valid_dl,
    valid_dl=valid_dl,
    device='cuda',
    lr=args.lr,
    epochs=args.epochs,
    num_classes=classes,
    forget_class=args.forget_class,
    dataset=args.dataset,
    batch_size=args.b,
)

# Shared arguments are merged last, so they win over any extra command-line
# hyper-parameter with the same name.
hyper_params.update(base_hyper)
# with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):

# Look the method up by name in the `method` module and run it; it returns a
# 6-tuple of metrics together with the elapsed time.
unlearn_fn = getattr(method, args.method)
metrics, time_elapsed = unlearn_fn(**hyper_params)
(
    retain_train_acc,
    retain_valid_acc,
    forget_train_acc,
    forget_valid_acc,
    total_valid_acc,
    mia,
) = metrics

# Persist the model so that the next step in the sequence can resume from it.
torch.save(net.state_dict(), save_path)

# Accuracy on the previously forgotten classes. On the first step there are
# none, so the current forget-set accuracies are reported instead.
if forget_index != 0:
    forgotten_train_res = evaluate(net, forgotten_train_dl, 'cuda')
    forgotten_valid_res = evaluate(net, forgotten_valid_dl, 'cuda')
    forgotten_train_acc = forgotten_train_res['Acc']
    forgotten_valid_acc = forgotten_valid_res['Acc']
else:
    forgotten_train_acc = forget_train_acc
    forgotten_valid_acc = forget_valid_acc

summary_line = f"retain train acc:{retain_train_acc}|valid acc:{retain_valid_acc}||forget train acc:{forget_train_acc}|forget valid acc:{forget_valid_acc}|forgotten train acc:{forgotten_train_acc}|forgotten valid acc:{forgotten_valid_acc}|total test acc:{total_valid_acc}|mia:{mia}"
print(summary_line)

# Append one row of results for this run to a per-(task, net, dataset, class)
# CSV file under results/.
columns = ['Method', 'Seed', "RetainTrainAcc", "RetainValidAcc", "ForgetTrainAcc", "ForgetValidAcc",  "ForgottenTrainAcc", "ForgottenValidAcc", "TestAcc", "MIA", "MethodTime", "Hyper-params"]
filename = f'results/{args.task}-{args.net}-{args.dataset}-forget-{args.forget_class}.csv'

# Create the output directory up front: to_csv does not create missing
# directories, and failing here would discard the metrics of a finished run.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# Write the header only when starting a new (or empty) file; later runs append
# bare rows beneath it.
write_header = not os.path.exists(filename) or os.path.getsize(filename) == 0

# `columns` pins the column order so appended rows always line up with the
# header written by the first run.
new_row = pd.DataFrame(
    {
        'Method': [args.method],
        'Seed': [args.seed],
        "RetainTrainAcc": [retain_train_acc],
        "RetainValidAcc": [retain_valid_acc],
        "ForgetTrainAcc": [forget_train_acc],
        "ForgetValidAcc": [forget_valid_acc],
        "ForgottenTrainAcc": [forgotten_train_acc],
        "ForgottenValidAcc": [forgotten_valid_acc],
        "TestAcc": [total_valid_acc],
        "MIA": [mia],
        "MethodTime": [time_elapsed],
        "Hyper-params": [str(ori_hyper)],
    },
    columns=columns,
)
new_row.to_csv(filename, mode='a', index=False, header=write_header)
