import os
import torch
import copy
import yaml
import argparse
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
from models import CustomModel
from utils import split_indices, EarlyStopping, load_datasets, plt_param_compare
from baseline.aaai import aaai_baseline
from baseline.influence import influence_baseline
from baseline.sisa import sisa_baseline
from baseline.ours import sequential_unlearning
from baseline.original import original_baseline
from baseline.retrain import retrain_baseline


def pre_origin_model(train_dataset, batch_size, eta, epochs, num_classes,
                     model_type, save_path, pretrained_tag=True, early_stopping=True, device='cuda:0'):
    """Train the original (pre-unlearning) model on the full training set and save its weights.

    Args:
        train_dataset: dataset yielding ``(inputs, targets)`` pairs; wrapped in a shuffled DataLoader.
        batch_size: mini-batch size for the DataLoader.
        eta: SGD learning rate (momentum 0.9, weight decay 5e-4).
        epochs: maximum number of training epochs.
        num_classes: number of output classes, forwarded to ``CustomModel``.
        model_type: backbone name, forwarded to ``CustomModel`` as ``model_name``.
        save_path: file path for the saved ``state_dict``; its parent directory is created if missing.
        pretrained_tag: forwarded to ``CustomModel`` as ``pretrained``.
        early_stopping: when truthy, stop training once ``EarlyStopping`` (patience=5, delta=0.01),
            fed with the mean epoch loss, reports ``early_stop``.
        device: device the model and each batch are moved to.
    """
    # Keep the flag and the stopper object in separate names instead of rebinding the parameter.
    stopper = EarlyStopping(patience=5, verbose=True, delta=0.01) if early_stopping else None

    ori_model = CustomModel(model_name=model_type, num_classes=num_classes, pretrained=pretrained_tag).to(device)

    optimizer = optim.SGD(ori_model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss()  # default reduction='mean'
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    ori_model.train()
    for epoch in range(epochs):
        epoch_loss = torch.tensor(0.0, device=device)
        samples = 0
        for inputs, targets in tqdm(loader, desc="Training original model progress"):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = ori_model(inputs)
            batch_loss_mean = loss_fn(outputs, targets)
            # Detach before accumulating: adding the graph-attached loss would chain every
            # batch's autograd history onto epoch_loss and keep it alive for the whole epoch.
            # Multiplying by the batch size turns the per-batch mean back into a sum.
            epoch_loss += batch_loss_mean.detach() * inputs.size(0)
            samples += inputs.size(0)
            batch_loss_mean.backward()
            optimizer.step()
        mean_epoch_loss = epoch_loss / samples
        print(
            f'Prepare original model----------epoch: {epoch}---------total epoch loss: {epoch_loss}---------mean epoch loss: {mean_epoch_loss}')

        if stopper is not None:
            stopper(mean_epoch_loss.item())
            if stopper.early_stop:
                print(f'Early stopping at epoch: {epoch}')
                break

        # update learning rate
        # scheduler.step()

    # torch.save does not create missing directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Saves the wrapped network's weights (CustomModel's inner ``.model``), not the wrapper's.
    torch.save(ori_model.model.state_dict(), save_path)
    print(f"Model saved to {save_path}")


def args_parser(yaml_file_path):
    """Load experiment settings from a YAML file into an attribute-style namespace.

    Every key listed in ``defaults`` below becomes an attribute. Keys missing from the
    file take the listed default. An empty YAML file means "use all defaults".

    Args:
        yaml_file_path: path to the YAML configuration file.

    Returns:
        ``argparse.Namespace`` with one attribute per setting, in the order listed below
        (so ``vars(args)`` iterates in that order).

    Raises:
        TypeError: if the YAML document is not a mapping.
    """
    with open(yaml_file_path, 'r', encoding='utf-8') as file:
        # safe_load returns None for an empty document; treat it as an empty mapping.
        cfg = yaml.safe_load(file) or {}
    if not isinstance(cfg, dict):
        raise TypeError(f"Config file {yaml_file_path} must contain a mapping, got {type(cfg).__name__}")

    defaults = {
        'dataset': 'cifar10',
        'model_type': 'resnet18',
        'T': 10,
        'eta': 0.01,
        'alpha': 0.1,
        'batch_size': 128,
        'epochs': 5,
        'model_save_path': 'origin_model',
        'fig_save_path': 'figs',
        'split_type': 'class',
        # Counts: ints like the other numeric defaults (previously the strings '0' / '5').
        'num_indices_per_f': 0,
        'num_submodel': 5,
        'pretrain_tag': False,
        'early_stopping_tag': True,
        'ablation_tag': 'all',
        'device': 'cuda:0',
    }
    return argparse.Namespace(**{key: cfg.get(key, default) for key, default in defaults.items()})


def del_cache(model_list):
    """Drop the models held in ``model_list`` and release cached CUDA memory.

    The container is emptied in place, because the caller's list object is what
    keeps the models reachable. Unbinding a loop variable or rebinding a local
    name would leave the caller's references untouched. ``None`` and immutable
    containers (e.g. tuples) are accepted; for those only the cache release runs.

    Args:
        model_list: container of models returned by a baseline (typically a list).
    """
    import gc

    if model_list is not None and hasattr(model_list, 'clear'):
        model_list.clear()
    # Collect any reference cycles so their tensors are actually freed before the
    # allocator cache is released back to the driver.
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    # The config path can be overridden with --config; the default is the previously
    # hard-coded file. parse_known_args keeps unrelated extra CLI arguments ignored.
    cli = argparse.ArgumentParser(description='Run the sequential unlearning baselines.')
    cli.add_argument('--config', default='configs/test.yml', help='path to the YAML experiment config')
    cli_args, _ = cli.parse_known_args()
    yaml_file_path = cli_args.config
    args = args_parser(yaml_file_path)

    # print process id
    pid = os.getpid()
    print('--------------------------------------------------')
    print(f"Current Process ID: {pid}")

    print('----------------Current Setting-------------------')
    for key, value in vars(args).items():
        print(f"{key}: {value}")
    print('--------------------------------------------------')

    train_dataset, _, test_loader, num_classes = load_datasets(args.dataset, args.model_type.lower())
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device(args.device)
    indices = np.arange(len(train_dataset))

    # subset_indexs[i] holds the sample indices removed by the i-th unlearning query.
    random_indices, subset_indexs = split_indices(indices, args.T, split_type=args.split_type,
                                                  num_per_subset=args.num_indices_per_f, num_classes=num_classes, train_dataset=train_dataset,
                                                  max_per=3000)
    print('--------------------------------------------------')
    for i in range(args.T):
        print(f'{i}-th unlearning query has {len(subset_indexs[i])} samples.')
    print('--------------------------------------------------')

    print(f'Current status------dataset name: {args.dataset}-------model name: {args.model_type}')

    pretrain = 'pre' if args.pretrain_tag else 'no-pre'
    current_model_path = os.path.join(args.model_save_path,
                                      pretrain + '-' + args.model_type + '-' + args.dataset + '.pth')
    # Extension-less prefix handed to the baselines as their save_path.
    retrain_model_path = os.path.join(args.model_save_path,
                                      pretrain + '-' + args.model_type + '-' + args.dataset)
    # torch.save does not create missing directories.
    os.makedirs(args.model_save_path, exist_ok=True)
    if not os.path.exists(current_model_path):
        pre_origin_model(train_dataset, args.batch_size, args.eta, args.epochs, num_classes, args.model_type,
                         current_model_path, pretrained_tag=args.pretrain_tag, early_stopping=args.early_stopping_tag,
                         device=device)
        print('Original model has saved------------------------')
    else:
        # The run continues with the existing checkpoint; nothing exits here.
        print(f'Original model exists at {current_model_path}, skip training------------------------')

    # going_list = ['original', 'retrain', 'sisa', 'aaai', 'influence', 'ours']
    going_list = ['original', 'retrain', 'sisa', 'influence', 'ours']
    # going_list = ['retrain','ours']
    # going_list = ['ours']

    # NOTE: only used by the commented-out plt_param_compare calls below, which also need
    # retrain_models to still be alive when re-enabled.
    current_fig_path = os.path.join(args.fig_save_path, pretrain + '-' + args.model_type + '-' + args.dataset)
    save_time_interval = 1

    print('------------------------Begin to unlearn-----------------------')
    if 'original' in going_list:
        print('Begin to evaluate original model------------------------')
        original_baseline(train_dataset, indices, subset_indexs, args.T, args.batch_size,
                          num_classes, args.model_type, current_model_path, device=device, test_loader=test_loader)
        print('Original model complete!')
    if 'retrain' in going_list:
        print('Begin retrain------------------------')
        retrain_models = retrain_baseline(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size,
                                          args.epochs, num_classes,
                                          args.model_type, pretrained_tag=args.pretrain_tag,
                                          early_stopping=args.early_stopping_tag, save_interval=save_time_interval,
                                          device=device, save_path=retrain_model_path, test_loader=test_loader)
        del_cache(retrain_models)
        print('Retrain complete------------------------')
    if 'sisa' in going_list:
        # paper: 2019 machine unlearning
        print('Begin SISA------------------------')
        sisa_baseline(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size, args.epochs,
                      num_classes, args.model_type,
                      split_type=args.split_type, num_model=args.num_submodel, pretrained_tag=args.pretrain_tag,
                      early_stopping=args.early_stopping_tag, device=device, test_loader=test_loader)
        print('SISA complete------------------------')
    if 'influence' in going_list:
        # paper: Remember What You Want to Forget; Algorithms for Machine Unlearning
        print('Begin influence------------------------')
        influence_models = influence_baseline(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size,
                                              args.epochs,
                                              num_classes, args.model_type, current_model_path,
                                              save_interval=save_time_interval, device=device, save_path=retrain_model_path, test_loader=test_loader)
        # plt_param_compare(influence_models, retrain_models, current_fig_path, 'influence')
        del_cache(influence_models)
        print('Influence complete------------------------')
    if 'ours' in going_list:
        # paper: overleaf
        print('Begin ours------------------------')
        ours_models = sequential_unlearning(train_dataset, indices, subset_indexs, args.T, args.eta, args.alpha,
                                            args.batch_size,
                                            args.epochs, num_classes, args.model_type, current_model_path,
                                            type='real-world',
                                            early_stopping=args.early_stopping_tag, save_interval=save_time_interval,
                                            ablation=args.ablation_tag, device=device, save_path=retrain_model_path, test_loader=test_loader)
        # plt_param_compare(ours_models, retrain_models, current_fig_path, 'ours')
        del_cache(ours_models)
        print('Ours sequential unlearning complete------------------------')
    if 'aaai' in going_list:
        # paper: Learning to Unlearn: Instance-wise Unlearning for Pre-trained Classifiers
        print('Begin aaai------------------------')
        aaai_models = aaai_baseline(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size,
                                    args.epochs,
                                    num_classes, args.model_type, current_model_path,
                                    early_stopping=args.early_stopping_tag, save_interval=save_time_interval,
                                    device=device, save_path=retrain_model_path, test_loader=test_loader)
        # compare param
        # plt_param_compare(aaai_models, retrain_models, current_fig_path, 'aaai')
        del_cache(aaai_models)
        print('AAAI complete------------------------')
    print('----------------All baseline complete !--------------------')
