import sys
sys.path.append('..')

import logging
import os
import argparse
import shutil
import json

import torch
from torch import optim, nn

from utils.SVC_MIA import SVC_MIA
from utils.warmup import warmup_lr
from utils.eval import accuracy
from utils.seed import set_seed
from utils.data import get_data, get_transformations, get_dataloaders
from utils.model import get_model

'''Arguments'''
# Command-line configuration. All values are logged and saved to
# `args.json` in the output directory before training starts.
parser = argparse.ArgumentParser()

# Name
parser.add_argument("--name", type=str, default='retrain', help="Name of the experiment")

# Device, Seed, and trials
parser.add_argument("--device", type=int, default=0, help="GPU device id")
parser.add_argument("--seed", type=int, default=0, help="Base seed; trial t uses seed + 5*t")
parser.add_argument("--num_trials", type=int, default=10, help="Number of experiment trials")
parser.add_argument("--forget_data_ratio", type=float, default=0.1, help="Fraction of the training data to be forgotten (e.g. 0.1)")

# Dataset
parser.add_argument("--data_path", type=str, default='data', help="Directory containing the datasets")
# Only these datasets get an input resolution assigned further down, so
# reject anything else up front with a clear argparse error.
parser.add_argument("--dataset", type=str, default='CIFAR10',
                    choices=['CIFAR10', 'CIFAR100', 'TinyImageNet'],
                    help="Specify the dataset [`CIFAR10`, `CIFAR100`, `TinyImageNet`]")

# Model Type
parser.add_argument("--model_type", type=str, default="resnet18", help="[`resnet18`, `vgg16`, `vit`]")

# Training
parser.add_argument("--num_epochs", type=int, default=182, help="Total number of epochs")
parser.add_argument("--num_workers", type=int, default=2, help="Number of workers")
parser.add_argument("--batch_size", type=int, default=256, help="Batch size")
parser.add_argument("--lr", type=float, default=0.1, help="Learning rate")
parser.add_argument("--momentum", type=float, default=0.9, help="Momentum")
parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight decay")
parser.add_argument("--warmup", type=int, default=0, help="Number of LR warmup epochs (0 disables warmup)")
parser.add_argument("--transformation", type=str, default="CHN", help="Augmentation spec string, e.g. `CHN` or `CHJGN`")

# Logging
parser.add_argument("--log_interval", type=int, default=2,
                    help="Interval (in epochs) between evaluations of accuracy and MIA; 0 evaluates only at the last epoch")

args = parser.parse_args()

'''Setting up Log Directory'''
# One directory per hyper-parameter configuration; every trial of the run
# writes its CSV logs and checkpoints underneath it.
output_dir = os.path.join(
    'results',
    args.name,
    args.dataset,
    f'model_{args.model_type}',
    f'lr_{args.lr}',
    f'batch_size_{args.batch_size}',
    f'warmup_{args.warmup}',
    f'epochs_{args.num_epochs}',
    f'seed_{args.seed}',
    f'num_trials_{args.num_trials}',
)
os.makedirs(output_dir, exist_ok=True)

'''Setting up logger'''
# basicConfig attaches a file handler (truncated each run) to the root logger;
# the extra StreamHandler mirrors every message to stdout.
log_path = os.path.join(output_dir, 'log.txt')
print(f"| logging to {log_path}")
logging.basicConfig(level=logging.DEBUG, filename=log_path, filemode="w", format='%(message)s')
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.info(vars(args))
logger.info('output_dir: ' + output_dir)

'''Saving arguments and the python file in log directory'''
argsfile = os.path.join(output_dir, 'args.json')
with open(argsfile, 'w', encoding='UTF-8') as f:
    json.dump(vars(args), f)  # Save the arguments in a json file
shutil.copy(__file__, os.path.join(output_dir, 'script.py'))  # Save this script for reproducibility

'''Set Device'''
# Use the requested GPU; fall back to CPU when CUDA is unavailable so the
# script can still be run (slowly) for debugging. On a GPU machine this is
# identical to always selecting `cuda:<device>`.
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

'''Checkpoints'''
# Final-epoch checkpoints of every trial are saved here (created lazily).
models_folder = os.path.join(output_dir, 'models')

start_trial = 0
orig_seed = args.seed

# Input resolution per supported dataset. Validated once up front: `args.image_size`
# (needed by get_model) is only defined for these datasets, so fail early with a
# clear message instead of an AttributeError partway through the first trial.
IMAGE_SIZES = {'CIFAR10': 32, 'CIFAR100': 32, 'TinyImageNet': 64}
if args.dataset not in IMAGE_SIZES:
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected one of {sorted(IMAGE_SIZES)}"
    )

for trial in range(start_trial, args.num_trials):

    '''Set Seed'''
    # Each trial gets its own reproducible seed derived from the base seed.
    args.seed = orig_seed + trial*5
    set_seed(args.seed)

    logger.info(f"| Trial {trial+1}/{args.num_trials} - Seed: {args.seed}")

    '''Set Image Size'''
    # Re-assigned every trial (cheap) in case a helper below mutates `args`.
    args.image_size = IMAGE_SIZES[args.dataset]

    '''Get Transformation Distribution'''
    transformations = get_transformations(args)

    '''Get Data'''
    # NOTE(review): data is rebuilt after set_seed, presumably so the
    # retain/forget split follows the trial seed — confirm in utils.data.
    datasets = get_data(args, transformations['to_tensor'])

    '''Get DataLoaders'''
    data_loaders = get_dataloaders(args, datasets)

    '''Get Model'''
    model = get_model(args, args.image_size, datasets['num_classes'])
    model.to(device)

    '''Setup CE Loss'''
    ce_criterion = nn.CrossEntropyLoss()

    '''Setup Optimizer'''
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    '''Setup Scheduler'''
    # Step decay (MultiStepLR default gamma=0.1) at 50% and 75% of training.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[int(0.5*args.num_epochs), int(0.75*args.num_epochs)]
    )

    '''MIA shadow set'''
    # Shadow "member" data for the SVC membership-inference attack: the first
    # retain samples, sized to match the test set (the shadow non-members).
    # It never changes within a trial, so build it once instead of at every
    # evaluation; clamp to the retain size so indices can never run past it.
    num_shadow = min(len(datasets['retain']), len(datasets['test']))
    shadow_train = torch.utils.data.Subset(datasets['retain'], list(range(num_shadow)))
    shadow_train_loader = torch.utils.data.DataLoader(shadow_train, batch_size=args.batch_size, shuffle=False)

    '''Setting up a csv file to log the results'''
    logfile_csv_test = os.path.join(output_dir, f'log_test_{trial}.csv')
    # One handle truncates the file, writes the header and receives a row per
    # evaluation; the context manager closes it even if training raises.
    with open(logfile_csv_test, 'w', encoding='UTF-8') as f_log:
        f_log.write('round, retain_accuracy, unlearn_accuracy, test_accuracy, mia\n')
        f_log.flush()

        ''' Training Iterations'''
        for epoch in range(1, args.num_epochs+1):
            logger.info(f'| Trial {trial+1}/{args.num_trials} - starting epoch: {epoch}')

            # Warmup
            # NOTE(review): epochs are 1-based here, so warmup is applied on
            # epochs 1..warmup-1 only — confirm against warmup_lr's indexing.
            if epoch < args.warmup:
                warmup_lr(epoch, epoch + 1, optimizer, one_epoch_step=len(data_loaders['retain']), args=args)

            # Train on retain data only (retraining-from-scratch baseline)
            model.train()
            for images, targets in data_loaders['retain']:
                images, targets = images.to(device), targets.to(device)

                # Current batch size, kept on args; presumably read by code in
                # utils (e.g. the transformations) — verify before removing.
                args.num_samples = int(images.shape[0])

                images = transformations['train'](images)

                outputs = model(images)

                # CE loss; outputs already live on `device`, so no extra .to().
                loss = ce_criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            scheduler.step()

            # Evaluate every `log_interval` epochs (0 = last epoch only) and
            # always at the final epoch, which also saves the checkpoint.
            periodic = args.log_interval > 0 and epoch % args.log_interval == 0
            if periodic or (epoch == args.num_epochs):
                rt_acc = accuracy(model, data_loaders['retain'], device, show=False, transform=transformations['test'])
                ft_acc = accuracy(model, data_loaders['forget'], device, show=False, transform=transformations['test'])
                # Reported "unlearn accuracy" is 100 minus forget-set accuracy.
                unlearn_acc = 100 - ft_acc
                test_acc = accuracy(model, data_loaders['test'], device, show=False, transform=transformations['test'])

                # MIA
                mia = SVC_MIA(
                    shadow_train=shadow_train_loader,
                    shadow_test=data_loaders['test'],
                    target_train=None,
                    target_test=data_loaders['forget'],
                    model=model,
                    transform=transformations['test'],
                )

                # Log accuracies
                logger.info(f"| Retain accs: {round(rt_acc, 3)}")
                logger.info(f"| Forget accs: {round(unlearn_acc, 3)}")
                logger.info(f"| Test accs: {round(test_acc, 3)}")
                logger.info(f"| MIA-Efficacy: {mia*100:.3f}")

                w_string = f'{epoch},\"{round(rt_acc, 3)}\",\"{round(unlearn_acc, 3)}\",\"{round(test_acc,3)}\",\"{mia*100:.3f}\"\n'
                f_log.write(w_string)
                f_log.flush()

                # Save checkpoint at last epoch
                if epoch == args.num_epochs:
                    os.makedirs(models_folder, exist_ok=True)
                    torch.save({'epoch': epoch,
                                'retain_acc': rt_acc,
                                'unlearn_acc': unlearn_acc,
                                'test_acc': test_acc,
                                'mia': f"{mia*100:.3f}",
                                'state_dict': model.state_dict(),
                                'trial': trial,
                                'seed': args.seed,
                                },
                               os.path.join(models_folder, f'model_t{trial}_epoch{epoch}.pth'))
