import torch
import copy
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset, TensorDataset
from tqdm import tqdm
import sys
import os
sys.path.append(os.path.abspath('fxh/seq_unlearn'))
from utils import EarlyStopping, evaluate_accuracy, clone_and_freeze_model
from models import CustomModel
import time

def l2_pgd_attack(model, x, y, epsilon=0.3, alpha=0.01, iters=40):
    """
    Perform an L2 Projected Gradient Descent (PGD) attack on an input batch.

    Starting from a zero perturbation, each of the `iters` iterations takes a step of L2 length
    `alpha` along the per-sample normalised gradient of the cross-entropy loss of `model` on
    ``(x + delta, y)``. It then projects every sample's perturbation back onto the L2 ball of
    radius `epsilon`. A positive `alpha` increases the loss on `y` (untargeted attack). A negative
    `alpha` decreases it, pulling the inputs towards being classified as `y` (targeted attack).

    Gradients are taken with ``torch.autograd.grad`` with respect to the perturbation only, so the
    ``.grad`` fields of the model's parameters are not modified. The model's train/eval mode is
    not changed by this function.

    Parameters:
        model (torch.nn.Module): The neural network model to attack.
        x (torch.Tensor): The input batch; dimension 0 is treated as the batch dimension.
        y (torch.Tensor): Labels for `x` used in the cross-entropy loss.
        epsilon (float, optional): Maximum L2 norm of each sample's perturbation. Default is 0.3.
        alpha (float, optional): L2 length of each step (negative for a targeted attack). Default is 0.01.
        iters (int, optional): The number of PGD iterations to perform. Default is 40.

    Returns:
        torch.Tensor: ``x`` plus the final perturbation, detached from the autograd graph.
    """
    x = x.detach()
    delta = torch.zeros_like(x)
    # Shape that broadcasts one scalar per sample over all non-batch dimensions.
    per_sample = (-1,) + (1,) * (x.dim() - 1)
    # Summed loss so per-sample gradients are not scaled by 1/batch_size.
    loss_fn = nn.CrossEntropyLoss(reduction='sum')
    for _ in range(iters):
        delta.requires_grad_(True)
        loss = loss_fn(model(x + delta), y)
        # Differentiate w.r.t. delta only: model parameters' .grad stay untouched.
        (grad,) = torch.autograd.grad(loss, delta, allow_unused=True)
        if grad is None:
            grad = torch.zeros_like(delta)
        with torch.no_grad():
            # clamp_min avoids 0/0 (NaN) when a sample's gradient or perturbation is exactly zero.
            grad_norm = grad.flatten(1).norm(p=2, dim=1).clamp_min(1e-12).view(per_sample)
            delta = delta + alpha * grad / grad_norm
            # Project each sample independently onto the L2 ball of radius epsilon.
            delta_norm = delta.flatten(1).norm(p=2, dim=1).clamp_min(1e-12).view(per_sample)
            delta = delta * (epsilon / delta_norm).clamp(max=1.0)
    return x + delta

def generate_adversarial_examples(model, loader, num_class):
    """
    Generate targeted adversarial examples using an L2 PGD attack for a given model and data loader.

    Each batch ``(x, y)`` from `loader` is moved to the model's device. Every sample gets the
    target label ``(y + 1) % num_class``, a deterministic shift to a different class whenever
    ``num_class > 1``. `l2_pgd_attack` ascends the cross-entropy on the labels it is given, so it
    is called with a negative step size. It therefore descends the cross-entropy on the target
    label, pushing each input towards being classified as that target. The model is put in eval
    mode for the attack, and its original training state is restored afterwards, even on error.

    Parameters:
        model (torch.nn.Module): The neural network model for which adversarial examples are to be generated.
        loader (torch.utils.data.DataLoader): The data loader providing batches of (input data, true labels).
        num_class (int): The total number of classes in the dataset.

    Returns:
        list of tuple: One ``(adversarial example, target label)`` pair per input sample, both as
        NumPy values on the CPU.
    """
    was_training = model.training
    model.eval()
    # Use the model's own device instead of assuming 'cuda' (which is cuda:0).
    device = next(model.parameters()).device

    adv_data = []
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            # Deterministically shift every label to the next class.
            y_adv = (y + 1) % num_class
            # Negative step => gradient descent on CE(y_adv): a targeted attack towards y_adv.
            x_adv = l2_pgd_attack(model, x, y_adv, alpha=-0.01)
            adv_data.extend(zip(x_adv.cpu().numpy(), y_adv.cpu().numpy()))
    finally:
        # Restore the model to the original training state
        model.train(mode=was_training)
    return adv_data

def compute_weight_importances(model, inputs):
    """
    Compute the importance of each model weight from the sensitivity of the model's outputs.

    The model is switched to training mode and does one forward pass on `inputs`. The function
    then differentiates the batch mean of the squared L2 norm of the outputs (norm taken over
    dim 1) with respect to every trainable parameter. The absolute value of that gradient is
    the parameter's importance. The model's previous train/eval mode is restored afterwards.

    ``model.zero_grad()`` is called, so any gradients already stored in the parameters' ``.grad``
    fields are discarded. The importance gradients are obtained with ``torch.autograd.grad`` and
    are not written into ``.grad``, so a caller's next ``loss.backward()`` is not contaminated.

    Parameters:
        model (torch.nn.Module): The neural network model whose weight importances are to be computed.
        inputs (torch.Tensor): The input data on which the model is evaluated; it is moved to the
            device of the model's first parameter.

    Returns:
        dict: Maps each parameter name (from ``model.named_parameters()``) to a tensor of the same
        shape holding the absolute gradient. Frozen parameters (``requires_grad=False``), and
        parameters the output does not depend on, get all-zero importance.
    """
    was_training = model.training
    model.train()
    try:
        named_params = list(model.named_parameters())
        # dictionary to store the gradient of each parameter (zeros by default)
        importance = {name: torch.zeros_like(param) for name, param in named_params}

        device = next(model.parameters()).device
        outputs = model(inputs.to(device))

        # Mean over the batch of the squared L2 norm of each output row
        norm_sq = (outputs.norm(p=2, dim=1) ** 2).mean()

        model.zero_grad()
        trainable = [(name, param) for name, param in named_params if param.requires_grad]
        if trainable:
            grads = torch.autograd.grad(
                norm_sq, [param for _, param in trainable], allow_unused=True
            )
            for (name, _), grad in zip(trainable, grads):
                # grad is None for parameters that did not take part in the forward pass
                if grad is not None:
                    importance[name] = grad.detach().abs()
    finally:
        model.train(mode=was_training)
    return importance

def normalize_importances(importances):
    """
    Min-max scale a tensor of importance scores into the range [0, 1].

    The denominator is the value range ``max - min``. When that range is below 1e-8 (for
    example, all entries are equal), 1e-8 is used instead so the division never hits zero.

    Parameters:
        importances (torch.Tensor): Tensor of scores (must be non-empty).

    Returns:
        torch.Tensor: ``(importances - min) / max(range, 1e-8)``.
    """
    lowest = importances.min()
    spread = importances.max() - lowest
    denominator = 1e-8 if spread < 1e-8 else spread
    return (importances - lowest) / denominator

def measure_weight_importance(model, inputs):
    """
    Measure, normalise and invert the importance of each weight in a model.

    Raw importances come from ``compute_weight_importances(model, inputs)``. Each parameter's
    tensor is min-max normalised on its own with ``normalize_importances`` and then mapped to
    ``1 - normalised``. In the result, values near 0 mark the most important weights of that
    parameter and values near 1 the least important ones.

    Parameters:
        model (torch.nn.Module): The neural network model whose weight importances are to be measured.
        inputs (torch.Tensor): The input data used to evaluate the model and compute weight importances.

    Returns:
        dict: Parameter name -> tensor of inverted, normalised importances (same shape as the parameter).
    """
    raw_scores = compute_weight_importances(model, inputs)
    return {
        param_name: 1 - normalize_importances(score)
        for param_name, score in raw_scores.items()
    }


def _adversarial_ce_loss(model, generative_loader, num_classes, loss_fn, batch_size, device):
    """Return ``(summed CE loss, sample count)`` of `model` on freshly generated adversarial examples.

    Examples and their target labels come from ``generate_adversarial_examples``. When none are
    produced, the loss is a zero tensor and the count is 0.
    """
    adv_data = generate_adversarial_examples(model, generative_loader, num_classes)
    loss_total = torch.zeros((), device=device)
    if not adv_data:
        return loss_total, 0

    # np.stack avoids the very slow tensor construction from a Python list of ndarrays.
    adv_examples = torch.as_tensor(np.stack([x for x, _ in adv_data]), dtype=torch.float32, device=device)
    adv_labels = torch.as_tensor(np.asarray([y for _, y in adv_data]), dtype=torch.long, device=device)
    adv_loader = DataLoader(TensorDataset(adv_examples, adv_labels), batch_size=batch_size, shuffle=True)

    n_samples = 0
    for x_adv, y_adv in adv_loader:
        outputs = model(x_adv)
        # Accumulated losses, weighted by batch size so the total is a per-sample sum
        loss_total = loss_total + loss_fn(outputs, y_adv) * x_adv.size(0)
        n_samples += x_adv.size(0)
    return loss_total, n_samples


def _importance_reg_loss(model, prev_model, inputs, device):
    """Importance-weighted squared distance between `model`'s parameters and `prev_model`'s.

    Weights come from ``measure_weight_importance(model, inputs)``. The parameters of `prev_model`
    are detached, so it acts as a constant anchor and receives no gradients.

    Raises:
        ValueError: if a parameter's shape differs between the two models.
    """
    importance_weights = measure_weight_importance(model, inputs)
    prev_params = dict(prev_model.named_parameters())
    loss_total = torch.zeros((), device=device)
    for name, param in model.named_parameters():
        anchor = prev_params[name].detach()
        if param.shape != anchor.shape:
            raise ValueError("Parameters shapes do not match between the two models.")
        diff_square = (param - anchor) ** 2
        loss_total = loss_total + torch.sum(diff_square * importance_weights[name].to(device))
    return loss_total


def aaai_baseline(train_dataset, indices, subset_indexs, T, eta, batch_size, epochs, num_classes, 
                  model_type, resume_model_path, early_stopping=True, save_interval=1, device="cuda:0", save_path="", test_loader=None,
                  loss_compute_type='mean'):
    """
    Run the AAAI sequential-unlearning baseline over T forget sets and return model snapshots.

    At step ``t`` (1..T) the forget set is ``subset_indexs[t-1]``. For `epochs` epochs (or until
    early stopping), each forget batch is trained with SGD (lr=`eta`, momentum 0.9, weight decay
    5e-4) on the sum of three terms:

    * the negated cross-entropy on the forget batch (gradient ascent / unlearning term);
    * the cross-entropy on adversarial examples generated from a random ~10% of the forget set,
      labelled with the targets returned by ``generate_adversarial_examples``. This term is zero
      when that 10% subset is empty;
    * an importance-weighted squared distance to the model from the end of the previous step.

    After each step, the function evaluates accuracy on the current forget set, the remaining
    set (`indices` minus all forget sets so far), the previous forget sets and, if given,
    `test_loader`.

    Parameters:
        train_dataset: Dataset indexed by the integer indices below.
        indices: All candidate training indices.
        subset_indexs: Sequence whose element ``t-1`` holds the forget-set indices of step ``t``.
        T (int): Number of sequential unlearning steps.
        eta (float): SGD learning rate.
        batch_size (int): Batch size for training/evaluation loaders.
        epochs (int): Maximum epochs per step.
        num_classes (int): Number of classes.
        model_type, resume_model_path: Passed to ``CustomModel``.
        early_stopping (bool): Use ``EarlyStopping(patience=5, delta=0.01)`` on the mean epoch loss.
        save_interval (int): Snapshot the model every `save_interval` steps.
        device: Device for training.
        save_path (str): Snapshots are saved with ``torch.save`` to ``save_path + '-aaai.pth'``.
        test_loader: Optional loader for test accuracy; reported as ``nan`` when None.
        loss_compute_type (str): ``'sum'`` combines per-batch summed terms; any other value
            (default ``'mean'``) combines per-sample means.

    Returns:
        list: CPU copies of the model taken every `save_interval` steps.
    """
    use_early_stopping = early_stopping

    # initialize
    baseline_model = CustomModel(model_name = model_type, num_classes=num_classes, pretrained=False, model_path=resume_model_path).to(device)
    loss_fn = nn.CrossEntropyLoss()
    baseline_optimizer = optim.SGD(baseline_model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)

    prev_F_t_1_indices = []

    # Anchor for the regularizer: the model as it was at the end of the previous step.
    prev_model = copy.deepcopy(baseline_model)

    T_opt_models = []

    for t in range(1, T+1):
        # NOTE(review): evaluate_accuracy (defined elsewhere) may leave the model in eval mode;
        # re-enter train mode at the start of every step.
        baseline_model.train()
        start = time.time()
        # get unlearning samples id
        Ft_indices = subset_indexs[t-1]
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])

        Ft_loader = DataLoader(Subset(train_dataset, Ft_indices), batch_size=batch_size, shuffle=True)

        # ~10% of the forget set seeds the adversarial examples; np.asarray also supports list inputs.
        random_adv_indices = torch.randperm(len(Ft_indices))[:int(0.1 * len(Ft_indices))]
        adv_seed_indices = np.asarray(Ft_indices)[random_adv_indices.numpy()]
        generative_loader = DataLoader(Subset(train_dataset, adv_seed_indices), batch_size=batch_size, shuffle=True)

        stopper = EarlyStopping(patience=5, verbose=True, delta=0.01) if use_early_stopping else None

        for epoch in range(epochs):
            epoch_loss = torch.zeros((), device=device)
            batch_count = 0
            for inputs, targets in Ft_loader:
                inputs, targets = inputs.to(device), targets.to(device)

                # compute l_ul loss (gradient ascent on the forget batch)
                outputs = baseline_model(inputs)
                loss_ul_mean = -1 * loss_fn(outputs, targets)
                loss_ul_total = loss_ul_mean * inputs.size(0)

                # compute l_ce loss on adversarial examples (zero if none were generated)
                loss_ce_total, ce_samples = _adversarial_ce_loss(
                    baseline_model, generative_loader, num_classes, loss_fn, batch_size, device
                )
                loss_ce_mean = loss_ce_total / max(ce_samples, 1e-8)

                # compute reg loss
                loss_reg_total = _importance_reg_loss(baseline_model, prev_model, inputs, device)
                loss_reg_mean = loss_reg_total / inputs.size(0)

                if loss_compute_type == 'sum':
                    batch_loss = loss_ul_total + loss_ce_total + loss_reg_total
                    print(f'Loss of batch is: {batch_loss}, unlearning loss is: {loss_ul_total}, ce loss is: {loss_ce_total}, reg loss is: {loss_reg_total}')
                else:
                    batch_loss = loss_ul_mean + loss_ce_mean + loss_reg_mean
                    print(f'Loss of batch is: {batch_loss}, unlearning loss is: {loss_ul_mean}, ce loss is: {loss_ce_mean}, reg loss is: {loss_reg_mean}')

                # The attack / importance passes above may have left gradients in .grad;
                # clear them so the step uses only batch_loss's gradient.
                baseline_optimizer.zero_grad()
                batch_loss.backward()
                baseline_optimizer.step()

                # update loss (detached so the epoch total does not hold autograd history)
                epoch_loss += batch_loss.detach()
                batch_count += 1

            mean_epoch_loss = epoch_loss / max(batch_count, 1)
            print(f'AAAI baseline model R{t}----------epoch: {epoch}---------total epoch loss: {epoch_loss}---------mean epoch loss: {mean_epoch_loss}')

            if stopper is not None:
                stopper(mean_epoch_loss.item())
                if stopper.early_stop:
                    print(f'Early stopping at epoch: {epoch}')
                    break

        end = time.time()

        efficiency = end - start

        # save current time optimal model
        if t % save_interval == 0:
            T_opt_models.append(copy.deepcopy(baseline_model).cpu())
            print(f'Optimal model of time {t} has saved.')

        # update previous model
        prev_model = copy.deepcopy(baseline_model)

        # evaluate
        print(f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len( Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')

        if len(Rt_indices) == 0:
            Acc_Rt = 0
        else:
            Rt_loader = DataLoader(Subset(train_dataset, Rt_indices), batch_size=batch_size)
            Acc_Rt = evaluate_accuracy(baseline_model, Rt_loader)
        Acc_Ft = evaluate_accuracy(baseline_model, Ft_loader)

        if t == 1:
            Acc_F_t_1 = 0.0
        else:
            Acc_F_t_1 = evaluate_accuracy(baseline_model, DataLoader(Subset(train_dataset, prev_F_t_1_indices), batch_size=128, shuffle=False))

        prev_F_t_1_indices.extend(Ft_indices)

        if test_loader is None:
            Acc_test = float('nan')
        else:
            Acc_test = evaluate_accuracy(baseline_model, test_loader)

        print(f"AAAI time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}, Acc_test: {Acc_test:.4f}, time: {efficiency:.4f}")

    current_model_path = save_path + '-aaai' + '.pth'
    torch.save(T_opt_models, current_model_path)

    return T_opt_models