import random
import os
import json
import argparse
import numpy as np
import pandas as pd 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import datasets
import models
import method

task = 'random'

# Command-line interface. Options this parser does not declare are collected
# separately by parse_known_args() and treated as method-specific hyper-params.
parser = argparse.ArgumentParser()
_cli_options = (
    # param for MU
    ("-seed", dict(type=int, default=0, help="seed for runs")),
    ("-net", dict(type=str, required=True, help="net type")),
    ("-forget_perc", dict(type=float, default=0.1)),
    ("-dataset", dict(
        type=str,
        required=True,
        nargs="?",
        choices=["Cifar10", "Cifar20", "Cifar100", "PinsFaceRecognition", "TinyImagenet", "Svhn"],
        help="dataset to train on",
    )),
    # hyper-param for MU
    ("-method", dict(type=str, required=True)),
    ("-b", dict(type=int, default=64)),
    ("-lr", dict(type=float, default=0.1)),
    ("-epochs", dict(type=int, default=1)),
)
for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)

# hyper-param for each method: every token argparse did not recognise
args, hyper_params_args = parser.parse_known_args()
# Turn the leftover CLI tokens into a {name: value} dict.
# - Tokens are read as "-name value" pairs; values stay strings, and casting is
#   left to the unlearning method.
# - A "-name=value" token is accepted as one complete pair. The previous
#   strictly-positional loop swallowed the next token as its value and
#   misaligned every pair after it.
# - A trailing name with no value maps to None.
# NOTE(review): argparse sees these tokens first. A single-dash name that begins
# with a registered one-letter option (e.g. "-b...") may be consumed or rejected
# by argparse instead of landing here — confirm hyper-param names avoid that.
hyper_params = {}
_tokens = iter(hyper_params_args)
for _token in _tokens:
    key = _token.lstrip('-')
    if '=' in key:
        key, value = key.split('=', 1)
    else:
        value = next(_tokens, None)
    hyper_params[key] = value
print(f"\033[93m Hyper-Params for {args.method}:\033[0m\n")
print(json.dumps(hyper_params, indent=4))

# Snapshot of the user-supplied hyper-params plus lr/epochs. It is taken before
# runtime objects (model, loaders, ...) are merged into hyper_params further
# down, and is what gets stringified into the results CSV.
ori_hyper = {**hyper_params, 'lr': args.lr, 'epochs': args.epochs}

# Seed every RNG source the script touches with the same value.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(args.seed)


# Output class count and input resolution for the selected dataset.
# CIFAR variants carry the class count in their name (e.g. "Cifar100" -> 100)
# and use the 32px default; the table covers the remaining choices.
_dataset_specs = {
    'PinsFaceRecognition': (105, 64),
    'TinyImagenet': (200, 64),
    'Svhn': (10, 32),
}
img_size = 32
if args.dataset.startswith('Cifar'):
    classes = int(args.dataset[5:])
elif args.dataset in _dataset_specs:
    classes, img_size = _dataset_specs[args.dataset]
# ViT gets 224px inputs regardless of dataset.
if args.net == "ViT":
    img_size = 224

# Checkpoint to start from: the 'retrain' method loads a seed-specific file,
# every other method loads the seed-independent one.
if args.method == 'retrain':
    weight_path = f'ckp/{args.net}-{args.dataset}-retrain-{args.seed}.pth'
else:
    weight_path = f'ckp/{args.net}-{args.dataset}.pth'

# Destination for the model's state_dict after the method has run.
save_dir = f'tmp_save/utility/{task}/{args.net}-{args.dataset}-{args.forget_perc}'
save_path = f'{save_dir}/{args.method}-{args.seed}.pt'
os.makedirs(save_dir, exist_ok=True)

# Instantiate the architecture named by -net on the GPU, then load the
# checkpoint selected above into it.
net = getattr(models, args.net)(num_classes=classes).cuda()
net.load_state_dict(torch.load(weight_path))

# PinsFaceRecognition has its own directory; every other dataset class is
# pointed at the shared 'data/' root.
data_root = 'data/pins_face_recognition' if args.dataset == 'PinsFaceRecognition' else 'data/'
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not os.path.exists(data_root):
    raise FileNotFoundError(f"Data root '{data_root}' does not exist")
# Only 'scar' gets a value from datasets.get_surrogate; other methods get None.
hyper_params['retain_sur'] = datasets.get_surrogate(args, hyper_params, num_classes = classes, original_model=net) if args.method == 'scar' else None
trainset = getattr(datasets, args.dataset)(root=data_root, download=True, train=True, unlearning=True, img_size=img_size)
validset = getattr(datasets, args.dataset)(root=data_root, download=True, train=False, unlearning=True, img_size=img_size)
train_dl = DataLoader(trainset,  batch_size=args.b,shuffle=True)
valid_dl = DataLoader(validset, batch_size=args.b, shuffle=False)


# Forget/retain split, cached on disk so repeated runs reuse the same split.
os.makedirs(f'data/MU-DATASET/{args.net}', exist_ok = True)
mu_data_path = f'data/MU-DATASET/{args.net}/{task}-{args.dataset}-{args.forget_perc}.pth'

# NOTE(review): the cache path does not include the seed, so every seed reuses
# the split generated by the first run for this net/dataset/forget_perc —
# confirm this is intended.
if os.path.exists(mu_data_path):
    # The cache holds pickled torch.utils.data.Subset objects (what random_split
    # returns), which torch.load refuses under weights_only=True (the default
    # since torch 2.6), so it is loaded with weights_only=False (needs
    # torch >= 1.13). Unpickling can execute arbitrary code: only load cache
    # files this script wrote itself.
    (forget_train, retain_train, forget_valid, retain_valid) = torch.load(mu_data_path, weights_only=False)
    print('\033[33m DATA LOADED .. \033[0m')
else:
    n_forget = int(len(trainset) * args.forget_perc)
    forget_train, retain_train = torch.utils.data.random_split(
        trainset, [n_forget, len(trainset) - n_forget]
    )
    # The forget "valid" set is the forget training subset itself; the retain
    # "valid" set is the full held-out split.
    forget_valid, retain_valid = forget_train, validset
    # NOTE: each Subset references its parent dataset, so the whole dataset
    # object is pickled into this file as well.
    torch.save((forget_train, retain_train, forget_valid, retain_valid), mu_data_path)

forget_valid_dl = DataLoader(forget_valid, args.b, shuffle=True)
retain_valid_dl = DataLoader(retain_valid, args.b, shuffle=True)
forget_train_dl = DataLoader(forget_train, args.b, shuffle=True)
retain_train_dl = DataLoader(retain_train, args.b, shuffle=True)




# These methods additionally receive a new instance of the same architecture,
# placed on the GPU and not loaded from any checkpoint; all others get None.
unlearning_teacher = None
if args.method in ('badteacher', 'euk', 'cfk'):
    unlearning_teacher = getattr(models, args.net)(num_classes=classes).cuda()

# Arguments shared by every unlearning method. They are merged over the
# CLI-parsed hyper-params, so these keys win on any name collision.
base_hyper = dict(
    model=net,
    unlearning_teacher=unlearning_teacher,
    retain_train_dl=retain_train_dl,
    retain_valid_dl=retain_valid_dl,
    forget_train_dl=forget_train_dl,
    forget_valid_dl=forget_valid_dl,
    valid_dl=valid_dl,
    device='cuda',
    lr=args.lr,
    batch_size=args.b,
    epochs=args.epochs,
    num_classes=classes,
    forget_perc=args.forget_perc,
    dataset=args.dataset,
)

hyper_params.update(base_hyper)

# Run the selected unlearning method. It returns a 6-tuple of metrics plus the
# elapsed time.
unlearn_fn = getattr(method, args.method)
metrics, time_elapsed = unlearn_fn(**hyper_params)
(retain_train_acc, retain_valid_acc, forget_train_acc,
 forget_valid_acc, total_valid_acc, mia) = metrics

# Persist `net` as it stands after the call (the method receives it as `model`).
torch.save(net.state_dict(), save_path)

print(f"retain train acc:{retain_train_acc}|valid acc:{retain_valid_acc}||forget train acc:{forget_train_acc}|forget valid acc:{forget_valid_acc}||total test acc:{total_valid_acc}|mia:{mia}")


# Append one row of metrics for this run to a CSV keyed by
# (task, net, dataset, forget_perc). The header is written only when the
# file does not exist yet.
columns = ['Method', 'Seed', "RetainTrainAcc", "RetainValidAcc", "ForgetTrainAcc", "ForgetValidAcc", "TestAcc", "MIA", "MethodTime", "Hyper-params"]
filename = f'results/{task}-{args.net}-{args.dataset}-forget-{args.forget_perc}.csv'
# Nothing else creates results/; without this, the first run's to_csv fails.
os.makedirs(os.path.dirname(filename), exist_ok=True)
new_row = pd.DataFrame({'Method': [args.method],
                        'Seed': [args.seed],
                        "RetainTrainAcc": [retain_train_acc],
                        "RetainValidAcc": [retain_valid_acc],
                        "ForgetTrainAcc": [forget_train_acc],
                        "ForgetValidAcc": [forget_valid_acc],
                        "TestAcc": [total_valid_acc],
                        "MIA": [mia],
                        "MethodTime": [time_elapsed],
                        "Hyper-params":[str(ori_hyper)]},
                       columns=columns)  # pin column order to the header list
new_row.to_csv(filename, mode='a', index=False, header=not os.path.exists(filename))
