import os
import torch
import copy
import argparse
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
from utils import EarlyStopping, load_datasets, create_pseudo_labels
from models import CustomModel
from baseline.aaai import aaai_baseline
from baseline.influence import influence_baseline
from baseline.sisa import sisa_baseline
from baseline.ours import sequential_unlearning
from baseline.original import original_baseline
from baseline.retrain import retrain_baseline
from utils import split_indices


def pre_origin_model(train_dataset, batch_size, epochs, num_classes,
                     model_type, save_path, pretrained_tag=True, early_stopping=True):
    """Train the original model and save its weights to ``<save_path>/model.pth``.

    The labels yielded by the data loader are discarded; the loss is computed
    against the tensor returned by ``create_pseudo_labels(num_classes,
    batch_size)``, sliced to the size of the current batch.

    Args:
        train_dataset: Dataset handed to ``DataLoader``; each item must unpack
            into ``(inputs, label)``.
        batch_size: Mini-batch size, also forwarded to ``create_pseudo_labels``.
        epochs: Maximum number of training epochs.
        num_classes: Number of output classes passed to ``CustomModel``.
        model_type: Model name forwarded to ``CustomModel`` as ``model_name``.
        save_path: Directory that receives ``model.pth``; created if missing.
        pretrained_tag: Forwarded to ``CustomModel`` as ``pretrained``.
        early_stopping: When true, an ``EarlyStopping`` monitor is fed the
            epoch loss and training stops once it sets ``early_stop``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Keep the monitor under its own name rather than rebinding the boolean
    # `early_stopping` parameter to an object.
    stopper = None
    if early_stopping:
        stopper = EarlyStopping(patience=10, verbose=True, delta=4)

    model = CustomModel(model_name=model_type, num_classes=num_classes, pretrained=pretrained_tag).to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    loss_fn = nn.CrossEntropyLoss()

    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    pseudo_labels = create_pseudo_labels(num_classes, batch_size).to(device)

    model.train()
    for epoch in range(epochs):
        epoch_loss = torch.tensor(0.0).to(device)
        for inputs, _ in loader:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # CrossEntropyLoss defaults to reduction='mean'; the slice handles a
            # final batch that is smaller than batch_size.
            batch_loss_mean = loss_fn(outputs, pseudo_labels[:outputs.size(0)])
            # Detach before accumulating so the running total does not chain
            # the autograd history of every batch in the epoch.
            epoch_loss += batch_loss_mean.detach() * len(inputs)
            batch_loss_mean.backward()
            optimizer.step()
        print(f'Prepare original model----------epoch: {epoch}---------total epoch loss: {epoch_loss}')

        if stopper is not None:
            # early stop
            stopper(epoch_loss.item())
            if stopper.early_stop:
                print(f'Early stopping at epoch: {epoch}')
                break

    # save model; the target directory may not exist yet, so create it first
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    model_path = os.path.join(save_path, 'model.pth')
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")


def args_parser():
    """Parse the command-line options of the sequential-unlearning script.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()`` with the
        attributes ``dataset``, ``model_type``, ``T``, ``eta``, ``alpha``,
        ``batch_size``, ``epochs`` and ``save_path``.
    """
    # One (flag, add_argument keyword arguments) entry per option, in the
    # order they are registered with the parser.
    option_table = (
        ('--dataset', {
            'type': str,
            'default': 'cifar10',
            'choices': ['cifar10', 'cifar100', 'celeba', 'mini-fashion'],
            'help': 'dataset name (default: cifar10)',
        }),
        ('--model_type', {
            'type': str,
            'default': 'resnet18',
            'choices': ['resnet18', 'vit', 'vgg'],
            'help': 'model type (default: resnet18)',
        }),
        ('--T', {
            'type': int,
            'default': 10,
            'help': 'Number of time points (default: 10)',
        }),
        ('--eta', {
            'type': float,
            'default': 0.01,
            'help': 'Learning rate for sub-model (default: 0.01)',
        }),
        ('--alpha', {
            'type': float,
            'default': 0.1,
            'help': 'Coefficient for Hessian vector product (default: 0.1)',
        }),
        ('--batch_size', {
            'type': int,
            'default': 128,
            'help': 'Batch size (default: 128)',
        }),
        ('--epochs', {
            'type': int,
            'default': 5,
            'help': 'Number of epochs for sub-model training (default: 5)',
        }),
        ('--save_path', {
            'type': str,
            'default': 'origin-model/no-pre-train',
            'help': 'Save path of original model',
        }),
    )

    cli = argparse.ArgumentParser(description='Sequential Unlearning')
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: load the dataset, split its indices into args.T
    # subsets, make sure the original model has been trained and saved, then
    # run every method listed in `going_list`.
    args = args_parser()
    train_dataset, _, _, num_classes = load_datasets(args.dataset)
    indices = np.arange(len(train_dataset))

    indices, subset_indexs = split_indices(indices, args.T, split_type='class',
                                           num_classes=num_classes, train_dataset=train_dataset, cover_flag=False)

    print(f'Current status------dataset name: {args.dataset}-------model name: {args.model_type}')

    pretrain_tag = False
    early_stopping_tag = True

    # Check for the weight file that pre_origin_model writes rather than for
    # the directory, so an empty or partially created directory still
    # triggers training.
    original_model_file = os.path.join(args.save_path, 'model.pth')
    if not os.path.exists(original_model_file):
        os.makedirs(args.save_path or '.', exist_ok=True)
        # pre_origin_model takes (train_dataset, batch_size, epochs,
        # num_classes, model_type, save_path, ...); it has no parameters for
        # the index arrays, so they must not be passed here.
        pre_origin_model(train_dataset, args.batch_size, args.epochs, num_classes,
                         args.model_type, args.save_path, pretrained_tag=pretrain_tag,
                         early_stopping=early_stopping_tag)
        print('Original model has saved !')
    else:
        print('Original model existing, exit process !')

    # Methods to run; remove a name to skip it. Execution order follows the
    # `if` statements below, not the order of this list.
    going_list = ['original', 'retrain', 'ours', 'sisa', 'aaai', 'influence']

    if 'original' in going_list:
        original_baseline(train_dataset, indices, subset_indexs, args.T, args.batch_size,
                          num_classes, args.model_type, args.save_path)
    if 'retrain' in going_list:
        retrain_baseline(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size, args.epochs, num_classes,
                         args.model_type, pretrained_tag=pretrain_tag, early_stopping=early_stopping_tag)
    if 'sisa' in going_list:
        sisa_baseline(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size, args.epochs,
                      num_classes, args.model_type, pretrained_tag=pretrain_tag, early_stopping=early_stopping_tag)
    if 'aaai' in going_list:
        aaai_baseline(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size, args.epochs,
                      num_classes, args.model_type, args.save_path, early_stopping=early_stopping_tag)
    if 'influence' in going_list:
        # NOTE(review): unlike the other baselines, this call does not pass
        # subset_indexs -- confirm against the influence_baseline signature.
        influence_baseline(train_dataset, indices, args.T, args.eta, args.batch_size, args.epochs,
                           num_classes, args.model_type, args.save_path)
    if 'ours' in going_list:
        ablation_tag = 'all'
        sequential_unlearning(train_dataset, indices, subset_indexs, args.T, args.eta, args.batch_size,
                              args.epochs, num_classes, args.model_type, args.save_path, type='real-world',
                              early_stopping=early_stopping_tag, ablation=ablation_tag)
