import torch
import copy
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Subset, TensorDataset
from tqdm import tqdm
import sys
import os
sys.path.append(os.path.abspath('autodl-tmp/Sequential-Unlearnig-main'))
from utils import EarlyStopping, evaluate_accuracy, clone_and_freeze_model
from models import CustomModel

def l2_pgd_attack(model, x, y, epsilon=0.3, alpha=0.01, iters=40):
    """Perturb a batch ``x`` by iterated gradient-sign ascent inside an L2 ball.

    Each iteration moves the perturbation by ``alpha`` along the sign of the
    gradient of the cross-entropy between ``model(x + delta)`` and ``y`` (so
    the loss with respect to ``y`` is pushed *up*), clamps every element to
    ``[-epsilon, epsilon]`` and then projects each sample's perturbation back
    onto the L2 ball of radius ``epsilon``.

    Args:
        model: Callable mapping a batch of inputs to logits.
        x: Input batch; dimension 0 is treated as the batch dimension.
        y: Class-index labels used in the cross-entropy loss.
        epsilon: Radius of the per-sample L2 ball (also the element-wise clamp).
        alpha: Step size of every iteration.
        iters: Number of ascent steps.

    Returns:
        ``x + delta`` with ``delta`` detached from the autograd graph.
    """
    loss_fn = nn.CrossEntropyLoss()
    delta = torch.zeros_like(x)

    for _ in range(iters):
        delta.requires_grad_(True)
        loss = loss_fn(model(x + delta), y)

        # Differentiate with respect to delta only: unlike loss.backward(),
        # this does not accumulate anything into the model parameters' .grad.
        (grad,) = torch.autograd.grad(loss, delta)

        with torch.no_grad():
            stepped = (delta + alpha * grad.sign()).clamp(-epsilon, epsilon)

            # Project every sample separately; a single norm over the whole
            # batch would make all samples share one epsilon budget.
            per_sample_norm = stepped.reshape(stepped.size(0), -1).norm(p=2, dim=1)

            # clamp_min guards the 0/0 case (zero perturbation), which would
            # otherwise turn the whole result into NaN.
            shrink = (epsilon / per_sample_norm.clamp_min(1e-12)).clamp(max=1.0)
            delta = stepped * shrink.reshape(-1, *([1] * (stepped.dim() - 1)))

    return x + delta.detach()

def generate_adversarial_examples(model, loader, num_class):
    """Run ``l2_pgd_attack`` over every batch of ``loader`` and collect the results.

    For each batch the label passed to the attack (and stored with the
    result) is the original label shifted by one, modulo ``num_class``.
    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, even if the attack raises.

    Args:
        model: Model to attack. Batches are moved to the device of its first
            parameter.
        loader: Iterable yielding ``(x, y)`` tensor batches.
        num_class: Number of classes, used for the label shift.

    Returns:
        A list of ``(x_adv, y_adv)`` pairs, one per sample, as NumPy values.
    """
    # Follow the model's device instead of assuming CUDA is present; fall back
    # to the file's usual "cuda if available" rule for parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Remember the current mode so it can be restored on exit.
    was_training = model.training
    model.eval()

    adv_data = []
    try:
        for x, y in tqdm(loader, desc="Generating adversial examples progress"):
            x, y = x.to(device), y.to(device)

            # Deterministic "next class" label, always different from y
            # when num_class > 1 (it is not random).
            y_adv = (y + 1) % num_class

            # NOTE(review): l2_pgd_attack in this file ascends the loss with
            # respect to the label it is given, i.e. it pushes predictions
            # away from y_adv, yet the samples are then stored with y_adv as
            # their label. Confirm whether a targeted attack was intended.
            x_adv = l2_pgd_attack(model, x, y_adv)
            adv_data.extend(zip(x_adv.cpu().numpy(), y_adv.cpu().numpy()))
    finally:
        model.train(mode=was_training)

    return adv_data

def compute_weight_importances(model, inputs):
    """Return per-parameter importance as ``|d(mean ||f(x)||_2^2) / d(theta)|``.

    The model is evaluated once on ``inputs`` in eval mode, the squared L2
    norm of each output row is averaged over the batch, and the absolute
    gradient of that scalar with respect to every parameter is returned.

    Args:
        model: ``nn.Module`` whose parameters are scored. Its output is
            reduced with ``norm(p=2, dim=1)``, so it must have at least
            two dimensions.
        inputs: A single input batch; moved to the model's device.

    Returns:
        Dict mapping each parameter name to a tensor shaped like that
        parameter. Parameters that are frozen or do not influence the output
        get an all-zero importance.
    """
    named_params = list(model.named_parameters())
    importance = {name: torch.zeros_like(param) for name, param in named_params}

    trainable = [(name, param) for name, param in named_params if param.requires_grad]
    if not trainable:
        return importance

    was_training = model.training
    model.eval()
    try:
        # Use the model's own device rather than guessing "cuda if available".
        inputs = inputs.to(named_params[0][1].device)

        outputs = model(inputs)
        norm_sq = (outputs.norm(p=2, dim=1) ** 2).mean()

        # Kept from the original contract: existing gradients are cleared.
        model.zero_grad()

        # autograd.grad returns the gradients without writing them into
        # param.grad, so nothing leaks into the caller's next optimizer step.
        grads = torch.autograd.grad(
            norm_sq, [param for _, param in trainable], allow_unused=True
        )
    finally:
        model.train(mode=was_training)

    # Single batch, so there is no averaging over batches.
    for (name, _), grad in zip(trainable, grads):
        if grad is not None:
            importance[name] += grad.detach().abs()

    return importance

def normalize_importances(importances):
    """Min-max scale a tensor into ``[0, 1]``.

    Args:
        importances: Tensor of importance scores.

    Returns:
        ``(importances - min) / (max - min)``. When all entries are equal
        (including single-element tensors) the result is all zeros instead
        of NaN; an empty tensor is returned as an unchanged copy.
    """
    if importances.numel() == 0:
        return importances.clone()

    min_val = torch.min(importances)
    span = torch.max(importances) - min_val

    # A zero range would give 0/0 = NaN; dividing by 1 instead leaves the
    # (all-zero) numerator untouched.
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    return (importances - min_val) / safe_span

def measure_weight_importance(model, inputs):
    """Return ``1 - normalized importance`` for every parameter of ``model``.

    Raw scores come from ``compute_weight_importances(model, inputs)``; each
    parameter's scores are rescaled independently with
    ``normalize_importances`` and then inverted.

    Args:
        model: Model whose parameters are scored.
        inputs: Input batch forwarded to ``compute_weight_importances``.

    Returns:
        Dict mapping parameter name to the inverted, normalized score tensor.
    """
    raw_scores = compute_weight_importances(model, inputs)
    return {
        param_name: 1 - normalize_importances(scores)
        for param_name, scores in raw_scores.items()
    }


def aaai_baseline(train_dataset, indices, subset_indexs, T, eta, batch_size,
                  epochs, num_classes, model_type, resume_model_path, early_stopping=True):
    """Run ``T`` sequential unlearning rounds and print per-round accuracies.

    In round ``t`` the forget set is ``subset_indexs[t-1]``. The model is
    updated with the sum of three terms: the negated cross-entropy on the
    forget batch, the cross-entropy on adversarial examples generated from
    the forget set, and an importance-weighted squared distance to the
    previously saved parameters. After the update, accuracy is printed for
    the current forget set, the remaining set and all earlier forget sets.

    Args:
        train_dataset: Dataset indexed by the values in ``indices``.
        indices: All sample indices under consideration.
        subset_indexs: Sequence whose ``t-1``-th entry holds the indices to
            forget in round ``t``.
        T: Number of rounds.
        eta: SGD learning rate.
        batch_size: Batch size for the forget, adversarial and remain loaders.
        epochs: Upper bound of the epoch loop (see the NOTE(review) below:
            only the first epoch currently runs).
        num_classes: Number of classes of the classifier.
        model_type: Forwarded to ``CustomModel`` as ``model_name``.
        resume_model_path: Forwarded to ``CustomModel`` as ``model_path``.
        early_stopping: When true, an ``EarlyStopping(patience=10,
            verbose=True, delta=4)`` tracker is created per round and fed the
            epoch loss.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If a parameter of the current model and of the saved
            previous model differ in shape.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_early_stopping = early_stopping

    baseline_model = CustomModel(model_name=model_type, num_classes=num_classes, pretrained=False, model_path=resume_model_path).to(device)
    loss_fn = nn.CrossEntropyLoss()
    baseline_optimizer = optim.SGD(baseline_model.parameters(), lr=eta, momentum=0.9, weight_decay=5e-4)

    prev_F_t_1_indices = []

    # Frozen snapshot of the parameters that the regulariser pulls towards.
    prev_model = clone_and_freeze_model(baseline_model)

    for t in range(1, T+1):
        # Re-assert train mode every round: evaluate_accuracy is defined
        # elsewhere and may leave the model in eval mode after the previous
        # round's evaluation.
        baseline_model.train()

        Ft_indices = subset_indexs[t-1]
        Rt_indices = np.setdiff1d(indices, [*prev_F_t_1_indices, *Ft_indices])

        Ft_loader = DataLoader(Subset(train_dataset, Ft_indices), batch_size=batch_size, shuffle=True)

        # Separate name so the boolean parameter is not shadowed by the tracker.
        stopper = EarlyStopping(patience=10, verbose=True, delta=4) if use_early_stopping else None

        for epoch in range(epochs):
            # NOTE(review): preserved from the original code — only the first
            # epoch is ever executed, so `epochs` and early stopping have no
            # practical effect. Looks like a debugging cap; confirm intent.
            if epoch >= 1:
                break

            total_epoch_loss = torch.tensor(0.0).to(device)
            for batch_idx, (inputs, targets) in enumerate(tqdm(Ft_loader, desc="Training progress")):
                # NOTE(review): preserved from the original code — only the
                # first batch of the epoch is processed. Confirm intent.
                if batch_idx >= 1:
                    break

                inputs, targets = inputs.to(device), targets.to(device)

                # Unlearning term: negated cross-entropy on the forget batch.
                outputs = baseline_model(inputs)
                loss_ul_mean = -1 * loss_fn(outputs, targets)

                # Cross-entropy term on adversarial examples of the forget set.
                adv_data = generate_adversarial_examples(baseline_model, Ft_loader, num_classes)
                # Stack into one ndarray first; building a tensor from a list
                # of ndarrays converts element by element and is very slow.
                adv_examples = torch.as_tensor(np.stack([x for x, _ in adv_data]), dtype=torch.float32).to(device)
                adv_labels = torch.as_tensor(np.asarray([y for _, y in adv_data]), dtype=torch.long).to(device)
                adv_loader = DataLoader(TensorDataset(adv_examples, adv_labels), batch_size=batch_size, shuffle=True)

                loss_ce_total = torch.tensor(0.0).to(device)
                ce_samples = 0
                for x_adv, y_adv in adv_loader:
                    x_adv, y_adv = x_adv.to(device), y_adv.to(device)
                    # Weight each batch mean by its size to get a sample mean.
                    loss_ce_total = loss_ce_total + loss_fn(baseline_model(x_adv), y_adv) * x_adv.size(0)
                    ce_samples += x_adv.size(0)
                loss_ce_mean = loss_ce_total / ce_samples

                # Regulariser: importance-weighted squared distance to the
                # previous snapshot, summed over all parameters.
                importance_weights = measure_weight_importance(baseline_model, inputs)
                loss_reg_total = torch.tensor(0.0).to(device)
                for (name, param), (_, prev_param) in zip(baseline_model.named_parameters(), prev_model.named_parameters()):
                    if param.shape != prev_param.shape:
                        raise ValueError("Parameters shapes do not match between the two models.")
                    weight = importance_weights[name].to(device)
                    loss_reg_total = loss_reg_total + torch.sum((param - prev_param) ** 2 * weight)

                batch_loss = loss_ul_mean + loss_ce_mean + loss_reg_total

                # Clear gradients immediately before backward: the adversarial
                # generation and importance helpers above may have written into
                # param.grad, which must not leak into this optimizer step.
                baseline_optimizer.zero_grad()
                batch_loss.backward()
                baseline_optimizer.step()

                print(f'Loss of batch is: {batch_loss}, unlearning loss is: {loss_ul_mean}, ce loss is: {loss_ce_mean}, reg loss is: {loss_reg_total}')
                # Detach so the running total does not keep each batch's
                # autograd graph alive.
                total_epoch_loss += batch_loss.detach()

            # Refresh the snapshot used by the regulariser.
            prev_model = clone_and_freeze_model(baseline_model)

            print(f'AAAI baseline model R{t}----------epoch: {epoch}---------total epoch loss: {total_epoch_loss}')

            if stopper is not None:
                stopper(total_epoch_loss.item())
                if stopper.early_stop:
                    print(f'Early stopping at epoch: {epoch}')
                    break

        # Evaluate on the remain set, the current forget set and all earlier
        # forget sets.
        print(f'indices length: {len(indices)}, Ft_indices length:{len(Ft_indices)},  Rt_indices length: {len( Rt_indices)}, prev_F_t_1_indices length: {len(prev_F_t_1_indices)}')
        if len(Rt_indices) == 0:
            Acc_Rt = 0
        else:
            Rt_loader = DataLoader(Subset(train_dataset, Rt_indices), batch_size=batch_size)
            Acc_Rt = evaluate_accuracy(baseline_model, Rt_loader)
        Acc_Ft = evaluate_accuracy(baseline_model, Ft_loader)

        if t == 1:
            # No earlier forget set exists in the first round.
            Acc_F_t_1 = 0.0
        else:
            Acc_F_t_1 = evaluate_accuracy(baseline_model, DataLoader(Subset(train_dataset, prev_F_t_1_indices), batch_size=128, shuffle=False))

        prev_F_t_1_indices.extend(Ft_indices)

        print(f"Time {t}: Acc_Ft: {Acc_Ft:.4f}, Acc_Rt: {Acc_Rt:.4f}, Acc_F_t-1: {Acc_F_t_1:.4f}")

    return 0