import random
import os
import json
import argparse
import numpy as np
import pandas as pd 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import datasets
import models
import method
from argparse import Namespace
import wandb

# Unlearning scenario tag; it becomes part of the cached forget/retain split filename.
task = 'full'

parser = argparse.ArgumentParser()

# --- Unlearning setup -------------------------------------------------------
parser.add_argument("-seed", help="seed for runs", type=int, default=0)
parser.add_argument("-net", help="net type", type=str, required=True)
parser.add_argument("-forget_class", type=int)
parser.add_argument(
    "-dataset",
    help="dataset to train on",
    type=str,
    nargs="?",
    required=True,
    choices=[
        "Cifar10",
        "Cifar20",
        "Cifar100",
        "PinsFaceRecognition",
        "TinyImagenet",
        "Svhn",
    ],
)

# --- Unlearning hyper-parameters -------------------------------------------
parser.add_argument("-method", type=str, required=True)
parser.add_argument("-b", type=int, default=64)  # DataLoader batch size
parser.add_argument("-lr", type=float, default=0.1)
parser.add_argument("-epochs", type=int, default=1)

# Unrecognised flags do not abort parsing; they are collected (unused) here.
args, hyper_params_args = parser.parse_known_args()

# Module-level snapshot of the CLI values. `project_name` names the wandb
# project. Inside unlearn() this object is replaced by wandb.config.
config = Namespace(
    project_name=f'SWEEP-{args.method}-{args.dataset}-{args.net}-{args.forget_class}',
    net=args.net,
    dataset=args.dataset,
    b=args.b,
    method=args.method,
    epoch=args.epochs,  # NOTE: singular 'epoch', unlike the sweep's 'epochs' key
    forget_class=args.forget_class,
    lr=args.lr,
    seed=args.seed,
)

def _fixed_dampening_params(net_name):
    """Return ``(dampening_constant, selection_weighting)`` for ``net_name``.

    These two hyper-parameters are held fixed per architecture; they are not
    sampled by the sweep.
    """
    if net_name == "ViT":
        return 1, 3.5
    if net_name == 'VGG':
        return 4, 10
    return 1, 10


def _dataset_spec(dataset_name, net_name):
    """Return ``(num_classes, img_size)`` for a dataset / network pair.

    Images default to 32px. PinsFaceRecognition and TinyImagenet use 64px, and
    ViT always uses 224px regardless of dataset. For ``CifarN`` the class count
    is parsed from the name.

    Raises:
        ValueError: if ``dataset_name`` is not a recognised dataset.
    """
    img_size = 32
    if dataset_name.startswith('Cifar'):
        classes = int(dataset_name[5:])
    elif dataset_name == 'PinsFaceRecognition':
        classes, img_size = 105, 64
    elif dataset_name == 'TinyImagenet':
        classes, img_size = 200, 64
    elif dataset_name == 'Svhn':
        classes = 10
    else:
        raise ValueError(f"Unknown dataset {dataset_name!r}")
    if net_name == "ViT":
        img_size = 224
    return classes, img_size


def _load_or_build_mu_splits(config, trainset, validset):
    """Return ``(forget_train, retain_train, forget_valid, retain_valid)``.

    The four splits are cached at
    ``data/MU-DATASET/<net>/<task>-<dataset>-<forget_class>.pth``. Later
    trials reuse that file instead of rebuilding the splits.
    """
    os.makedirs(f'data/MU-DATASET/{config.net}', exist_ok = True)
    mu_data_path = f'data/MU-DATASET/{config.net}/{task}-{config.dataset}-{config.forget_class}.pth'

    if os.path.exists(mu_data_path):
        print('\033[33m DATA LOADED .. \033[0m')
        # NOTE(review): the cache presumably holds pickled dataset objects, not
        # plain tensors. On PyTorch versions where torch.load defaults to
        # weights_only=True this load may be rejected. Confirm against the
        # installed torch version.
        forget_train, retain_train, forget_valid, retain_valid = torch.load(mu_data_path)
    else:
        forget_train, retain_train = datasets.build_retain_forget_sets(trainset, config.forget_class)
        forget_valid, retain_valid = datasets.build_retain_forget_sets(validset, config.forget_class)
        torch.save((forget_train, retain_train, forget_valid, retain_valid), mu_data_path)
    return forget_train, retain_train, forget_valid, retain_valid


def _append_results_csv(config, metrics, time_elapsed, ori_hyper):
    """Append one result row to the per-(net, dataset, forget_class) sweep CSV.

    ``metrics`` is the 6-tuple returned by the unlearning method. The header is
    written only when the file is first created.
    """
    (retain_train_acc, retain_valid_acc, forget_train_acc,
     forget_valid_acc, total_valid_acc, mia) = metrics

    os.makedirs('results/sweep/', exist_ok=True)
    filename = f'results/sweep/fullclass-{config.net}-{config.dataset}-forget-{config.forget_class}.csv'
    new_row = pd.DataFrame({'Method': [config.method],
                            'Seed': [config.seed],
                            "RetainTrainAcc": [retain_train_acc],
                            "RetainValidAcc": [retain_valid_acc],
                            "ForgetTrainAcc": [forget_train_acc],
                            "ForgetValidAcc": [forget_valid_acc],
                            "TestAcc": [total_valid_acc],
                            "MIA": [mia],
                            "MethodTime": [time_elapsed],
                            "Hyper-params": [str(ori_hyper)]})
    # Append mode: every sweep trial adds one row to the same file.
    new_row.to_csv(filename, mode='a', index=False, header=not os.path.exists(filename))


def unlearn(config=config):
    """Run one sweep trial: unlearn ``config.forget_class`` and record metrics.

    This is the function passed to ``wandb.agent``. The ``config`` argument is
    effectively ignored. It is replaced by ``wandb.config`` right after
    ``wandb.init()``, so each trial uses the values the sweep sampled
    (including ``lr`` and ``lipschitz_weighting``).

    Side effects:
        * logs ``RetainTestAcc`` and ``ForgetValidAcc`` to wandb;
        * caches the forget/retain splits under ``data/MU-DATASET/``;
        * appends one row to a CSV under ``results/sweep/``.
    """
    wandb.init()
    config = wandb.config
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    dampening_constant, selection_weighting = _fixed_dampening_params(config.net)
    hyper_params = {
        "lipschitz_weighting": config.lipschitz_weighting,
        "dampening_constant": dampening_constant,
        "selection_weighting": selection_weighting,
    }
    # Snapshot of the method hyper-parameters (plus lr/epochs) for the results
    # CSV, taken before models and dataloaders are merged into hyper_params.
    ori_hyper = hyper_params.copy()
    ori_hyper['lr'], ori_hyper['epochs'] = config.lr, config.epochs

    classes, img_size = _dataset_spec(config.dataset, config.net)

    # 'retrain' loads a per-seed retrained checkpoint; other methods start from
    # the shared original checkpoint.
    if config.method == 'retrain':
        weight_path = f'ckp/{config.net}-{config.dataset}-retrain-{config.seed}.pth'
    else:
        weight_path = f'ckp/{config.net}-{config.dataset}.pth'
    net = getattr(models, config.net)(num_classes=classes)
    net = net.cuda()
    net.load_state_dict(torch.load(weight_path))

    hyper_params['retain_sur'] = (
        datasets.get_surrogate(config, hyper_params, num_classes=classes, original_model=net)
        if config.method == 'scar' else None
    )

    dataset_cls = getattr(datasets, config.dataset)
    trainset = dataset_cls(root='data/', download=True, train=True, unlearning=True, img_size=img_size)
    validset = dataset_cls(root='data/', download=True, train=False, unlearning=True, img_size=img_size)
    valid_dl = DataLoader(validset, batch_size=config.b, shuffle=False)

    forget_train, retain_train, forget_valid, retain_valid = _load_or_build_mu_splits(
        config, trainset, validset
    )
    forget_valid_dl = DataLoader(forget_valid, config.b, shuffle=True)
    retain_valid_dl = DataLoader(retain_valid, config.b, shuffle=True)
    forget_train_dl = DataLoader(forget_train, config.b, shuffle=True)
    retain_train_dl = DataLoader(retain_train, config.b, shuffle=True)

    # scrub / badteacher also receive a second network of the same
    # architecture; no checkpoint is loaded into it.
    if config.method in ['scrub', 'badteacher']:
        unlearning_teacher = getattr(models, config.net)(num_classes=classes)
        unlearning_teacher = unlearning_teacher.cuda()
    else:
        unlearning_teacher = None

    hyper_params.update({
        "model": net,
        "unlearning_teacher": unlearning_teacher,
        "retain_train_dl": retain_train_dl,
        "retain_valid_dl": retain_valid_dl,
        "forget_train_dl": forget_train_dl,
        "forget_valid_dl": forget_valid_dl,
        "valid_dl": valid_dl,
        "device": 'cuda:0',
        "lr": config.lr,
        "epochs": config.epochs,
        "num_classes": classes,
        "forget_class": config.forget_class,
        "dataset": config.dataset,
    })

    # Each method returns ((6 metrics), elapsed_time).
    metrics, time_elapsed = getattr(method, config.method)(**hyper_params)
    (retain_train_acc, retain_valid_acc, forget_train_acc,
     forget_valid_acc, total_valid_acc, mia) = metrics

    print(f"retain train acc:{retain_train_acc}|valid acc:{retain_valid_acc}||forget train acc:{forget_train_acc}|forget valid acc:{forget_valid_acc}||total test acc:{total_valid_acc}|mia:{mia}")
    # 'RetainTestAcc' is the metric the sweep optimises.
    wandb.log({'RetainTestAcc': retain_valid_acc, 'ForgetValidAcc': forget_valid_acc, })

    _append_results_csv(config, metrics, time_elapsed, ori_hyper)
    
    
    
    
    

# Sweep definition. Structural values are pinned to the CLI arguments;
# `lr` and `lipschitz_weighting` are sampled at random. Trials are ranked by
# the 'RetainTestAcc' value that unlearn() logs.
sweep_config = {
    'method': 'random',
    'metric': {'name': 'RetainTestAcc', 'goal': 'maximize'},
    'parameters': {
        'net': {'value': args.net},
        'dataset': {'value': args.dataset},
        'b': {'value': args.b},
        'method': {'value': args.method},
        'forget_class': {'value': args.forget_class},
        'epochs': {'value': args.epochs},
        'seed': {'value': args.seed},
    },
}

# Per-network search ranges: (min_lr, max_lr, min_lip, max_lip).
_SWEEP_RANGES = {
    'ResNet18': (1e-5, 5e-4, 1e-5, 0.5),
    'ViT': (1e-4, 1e-2, 0, 1),
}
if args.net not in _SWEEP_RANGES:
    # Fail loudly instead of hitting a NameError on min_lr further down.
    raise ValueError(
        f"No sweep search ranges defined for -net {args.net!r}; "
        f"expected one of {sorted(_SWEEP_RANGES)}"
    )
min_lr, max_lr, min_lip, max_lip = _SWEEP_RANGES[args.net]

sweep_config['parameters'].update({
    'lr': {
        'distribution': 'log_uniform_values',
        'min': min_lr,
        'max': max_lr,
    },
    'lipschitz_weighting': {
        'distribution': 'uniform',
        'min': min_lip,
        'max': max_lip,
    },
})


# eta = 0.5, sigma = 0.8
# Credentials come from the WANDB_API_KEY environment variable or a previous
# `wandb login`, never from source code.
# NOTE: the API key that used to be hard-coded here was committed in plain
# text and should be revoked.
wandb.login()
sweep_id = wandb.sweep(sweep_config, project = config.project_name)
wandb.agent(sweep_id, function = unlearn, count = 50 )
