import torch
import copy
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import sys
import os

sys.path.append(os.path.abspath('fxh/seq_unlearn'))
from utils import EarlyStopping, create_pseudo_labels, compute_epoch_grad, compute_grad, \
    compute_epoch_hessian_vector_product, compute_epoch_hessian, evaluate_accuracy, clone_and_freeze_model
from models import CustomModel
import time


def _ablation_coefficients(ablation, t, alpha):
    """Return the aggregation weights ``(a1, a2, a3)`` for step ``t``.

    ``a1`` weights the previous aggregated parameters, ``a2`` the freshly
    trained sub-model and ``a3`` the two correction terms. ``ablation``
    selects which terms are kept: 'all' keeps every term, while
    'only-one-and-two', 'only-two-and-three' and 'only-two' zero out the
    remaining ones.

    Raises:
        ValueError: if ``ablation`` is not one of the modes above.
    """
    if ablation == 'all':
        return t / (t + 1), 1.0 / (t + 1), alpha / (1 + t)
    if ablation == 'only-one-and-two':
        return t / (t + 1), 1.0 / (t + 1), 0.0
    if ablation == 'only-two-and-three':
        return 0.0, 1.0, alpha / (1 + t)
    if ablation == 'only-two':
        return 0.0, 1.0, 0.0
    raise ValueError(
        f"Unknown ablation {ablation!r}; expected one of "
        f"'all', 'only-one-and-two', 'only-two-and-three', 'only-two'")


def sequential_unlearning(train_dataset, indices, subset_indexs, T, eta, alpha, batch_size, epochs, num_classes,
                          model_type, resume_model_path, type='real-world', early_stopping=True, save_interval=1,
                          ablation='all', device="cuda:0", save_path="", test_loader=None):
    """Sequentially unlearn ``T`` forget sets, aggregating a model after each step.

    At step ``t`` (1-based) the forget set is ``subset_indexs[t - 1]``. A fresh
    sub-model loaded from ``resume_model_path`` is trained with SGD on that
    forget set against labels produced by ``create_pseudo_labels``. Its
    parameters are then combined with the previous step's aggregated
    parameters and two correction terms evaluated at a copy of the original
    model::

        theta_t = a1 * theta_{t-1} + a2 * theta_sub + a3 * term_1 + a3 * term_2

    where ``(a1, a2, a3)`` come from ``_ablation_coefficients(ablation, t, alpha)``.

    Args:
        train_dataset: dataset that is wrapped with ``torch.utils.data.Subset``.
        indices: indices of the original training set R0. ``len(indices)``
            scales the correction terms, and the Hessian / Hessian-vector
            products are computed over this set.
        subset_indexs: sequence with at least ``T`` index collections; entry
            ``t - 1`` is the forget set of step ``t``.
        T: number of unlearning steps.
        eta: SGD learning rate for each sub-model.
        alpha: base weight of the correction terms.
        batch_size: batch size for every DataLoader built here.
        epochs: maximum number of sub-model training epochs per step.
        num_classes: forwarded to ``CustomModel`` and ``create_pseudo_labels``.
        model_type: forwarded to ``CustomModel`` as ``model_name``.
        resume_model_path: checkpoint loaded by ``CustomModel`` for both the
            original model and every sub-model.
        type: ``'dream'`` inverts the Hessians returned by
            ``compute_epoch_hessian`` once at step 1 and reuses them. Any other
            value uses ``compute_epoch_hessian_vector_product`` at every step.
            (This parameter shadows the builtin ``type`` but is part of the
            public signature.)
        early_stopping: if true, sub-model training uses
            ``EarlyStopping(patience=5, verbose=True, delta=0.01)`` on the mean
            epoch loss.
        save_interval: a CPU copy of the aggregated model is kept every
            ``save_interval`` steps. Must be >= 1.
        ablation: one of 'all', 'only-one-and-two', 'only-two-and-three',
            'only-two'.
        device: device on which the models and loss tensors are placed.
        save_path: output prefix; the kept models are saved with
            ``torch.save`` to ``save_path + '-ours.pth'``.
        test_loader: optional loader for test accuracy. When ``None``, test
            accuracy is reported as ``nan``.

    Returns:
        The list of CPU model copies that was also written to disk.

    Raises:
        ValueError: if ``ablation`` is unknown or ``save_interval < 1``.
    """
    # Fail fast on bad configuration, before any (expensive) training runs.
    _ablation_coefficients(ablation, 1, alpha)
    if save_interval < 1:
        raise ValueError(f'save_interval must be >= 1, got {save_interval}')

    use_early_stopping = early_stopping

    # initialize
    origin_model = CustomModel(model_name=model_type, num_classes=num_classes, pretrained=False,
                               model_path=resume_model_path).to(device)
    loss_fn = nn.CrossEntropyLoss()

    prev_theta_star = [p.clone().detach() for p in origin_model.parameters()]
    R0_size = len(indices)

    prev_F_t_1_indices = []
    prev_total_loss = torch.tensor(0.0).to(device)
    prev_mean_loss = torch.tensor(0.0).to(device)
    total_F_t_size = 0

    # Whether the loss used for the Hessian is the summed ('sum') or the
    # averaged ('mean') loss; the original authors found 'mean' works better.
    loss_compute_type = 'mean'

    # The loader over the full original set is identical at every step.
    total_loader = DataLoader(Subset(train_dataset, indices), batch_size=batch_size, shuffle=False)

    T_opt_models = []

    for t in range(1, T + 1):
        start = time.time()
        # Wall-clock seconds spent in Hessian-vector products this step; this
        # is subtracted from the reported per-step time.
        hvp_seconds = 0.0

        # Every step linearises around a fresh copy of the ORIGINAL model.
        ori_model = copy.deepcopy(origin_model)
        ori_model.train()

        # get unlearning samples id
        Ft_indices = subset_indexs[t - 1]
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])
        Ft_loader = DataLoader(Subset(train_dataset, Ft_indices), batch_size=batch_size, shuffle=True)
        pseudo_labels = create_pseudo_labels(num_classes, batch_size).to(device)
        Ft_size = len(Ft_indices)
        total_F_t_size += Ft_size

        # initialize sub_model
        sub_model = CustomModel(model_name=model_type, num_classes=num_classes, pretrained=False,
                                model_path=resume_model_path).to(device)
        sub_optimizer = optim.SGD(sub_model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)

        stopper = EarlyStopping(patience=5, verbose=True, delta=0.01) if use_early_stopping else None

        # get optimal model theta_t_f
        sub_model.train()
        for epoch in range(epochs):
            total_epoch_loss = torch.tensor(0.0).to(device)
            samples = 0
            for inputs, _ in Ft_loader:
                inputs = inputs.to(device)
                sub_optimizer.zero_grad()
                outputs = sub_model(inputs)
                # The last batch may be smaller than batch_size, so slice the labels.
                batch_loss_mean = loss_fn(outputs, pseudo_labels[:outputs.size(0)])
                # Detach: this accumulator is only logged / used for early
                # stopping and must not chain every batch's autograd graph.
                total_epoch_loss += batch_loss_mean.detach() * inputs.size(0)
                samples += inputs.size(0)
                batch_loss_mean.backward()
                sub_optimizer.step()
            mean_epoch_loss = total_epoch_loss / samples
            print(
                f'Training optimal sub-model F{t}----------epoch: {epoch}---------total epoch loss: {total_epoch_loss}---------mean epoch loss: {mean_epoch_loss}')

            if stopper is not None:
                stopper(mean_epoch_loss.item())
                if stopper.early_stop:
                    print(f'Early stopping at epoch: {epoch}')
                    break

        theta_T_F = [p.clone().detach() for p in sub_model.parameters()]
        print(f'--------------------Finish training {t}-th sub-model !--------------------')

        # compute grad l_t_f
        grad_L_F_t, total_loss_F_t, _ = compute_epoch_grad(Ft_loader, ori_model, loss_fn, loss_compute_type,
                                                           pseudo_labels)
        ori_model.zero_grad()
        print(f'--------------------Finish computing grad of LFt !--------------------')

        # NOTE(review): for t > 1 the gradient below is taken of a loss built at
        # the previous step (on the previous copy of origin_model, whose
        # parameters were then overwritten in place) with respect to the
        # parameters of the *current* fresh copy. Whether this yields a
        # meaningful gradient depends on compute_grad (e.g. allow_unused /
        # how the loss is stored) -- verify in utils.
        prev_loss = prev_total_loss if loss_compute_type == 'sum' else prev_mean_loss

        if type == 'dream':
            # Ideal variant: invert the Hessian(s) once at step 1 and reuse them.
            if t == 1:
                term_1 = [torch.zeros_like(p) for p in ori_model.parameters()]
                hessians = compute_epoch_hessian(ori_model, total_loader, loss_fn, loss_type=loss_compute_type,
                                                 num_samples=len(indices))
                h0 = [torch.inverse(h) for h in hessians]
            else:
                grad_g_t_1_R = compute_grad(prev_loss, list(ori_model.parameters()))
                term_1 = [(Ft_size / R0_size) * torch.matmul(h, gtr) for h, gtr in zip(h0, grad_g_t_1_R)]
            ori_model.zero_grad()
            term_2 = [(total_F_t_size / R0_size) * torch.matmul(h, lft) for h, lft in zip(h0, grad_L_F_t)]
            ori_model.zero_grad()
        else:
            # Practical variant: Hessian-vector products instead of explicit
            # inverses. samples_compute_hvp may be lowered (e.g. 100) when GPU
            # memory is insufficient to use the whole training set.
            samples_compute_hvp = len(indices)
            if t == 1:
                term_1 = [torch.zeros_like(p) for p in ori_model.parameters()]
            else:
                grad_g_t_1_R = compute_grad(prev_loss, list(ori_model.parameters()))
                ori_model.zero_grad()

                hvp_start = time.time()
                hvp_1 = compute_epoch_hessian_vector_product(ori_model, total_loader, loss_fn, grad_g_t_1_R,
                                                             loss_type=loss_compute_type,
                                                             num_samples=samples_compute_hvp)
                hvp_seconds += time.time() - hvp_start

                term_1 = [(Ft_size / R0_size) * grad for grad in hvp_1]
                ori_model.zero_grad()

            hvp_start = time.time()
            hvp_2 = compute_epoch_hessian_vector_product(ori_model, total_loader, loss_fn, grad_L_F_t,
                                                         loss_type=loss_compute_type, num_samples=samples_compute_hvp)
            hvp_seconds += time.time() - hvp_start

            term_2 = [(total_F_t_size / R0_size) * grad for grad in hvp_2]
            ori_model.zero_grad()
        print(f'--------------------Finish computing grad of tem1 and term2 !--------------------')

        # aggregation the optimal model
        a1, a2, a3 = _ablation_coefficients(ablation, t, alpha)
        with torch.no_grad():
            for param, ps, tf, t1, t2 in zip(list(ori_model.parameters()), prev_theta_star, theta_T_F, term_1, term_2):
                param.copy_(a1 * ps + a2 * tf + a3 * t1 + a3 * t2)

        print(f'--------------------Finish aggregating parameters !--------------------')

        # Reported step time excludes the Hessian-vector-product time (zero in
        # 'dream' mode, which previously crashed here on undefined intervals).
        efficiency = time.time() - start - hvp_seconds

        # save current time optimal model
        if t % save_interval == 0:
            T_opt_models.append(copy.deepcopy(ori_model).cpu())
            print(f'Optimal model of time {t} has saved.')

        # save current optimal model param
        prev_theta_star = [p.clone().detach() for p in ori_model.parameters()]

        # evaluate
        print(
            f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len(Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')
        if len(Rt_indices) == 0:
            Acc_Rt = 0
        else:
            Rt_loader = DataLoader(Subset(train_dataset, Rt_indices), batch_size=batch_size)
            Acc_Rt = evaluate_accuracy(ori_model, Rt_loader)
        Acc_Ft = evaluate_accuracy(ori_model, Ft_loader)

        # Accuracy on all samples forgotten in earlier steps (none at t == 1).
        if t == 1:
            Acc_F_t_1 = 0.0
        else:
            Acc_F_t_1 = evaluate_accuracy(ori_model,
                                          DataLoader(Subset(train_dataset, prev_F_t_1_indices), batch_size=batch_size,
                                                     shuffle=False))

        # update previous time param
        prev_F_t_1_indices.extend(Ft_indices)
        prev_total_loss += total_loss_F_t
        prev_mean_loss = prev_total_loss / (len(prev_F_t_1_indices))

        Acc_test = evaluate_accuracy(ori_model, test_loader) if test_loader is not None else float('nan')

        print(f"Ours time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}, Acc_test: {Acc_test:.4f}, time: {efficiency:.4f}")

    current_model_path = save_path + '-ours' + '.pth'
    torch.save(T_opt_models, current_model_path)
    return T_opt_models
