import torch
from torch import nn


def compute_fisher_matrix(model, dataloader, criterion, session, device=None):
    """Estimate a diagonal (empirical) Fisher information for ``model``'s parameters.

    For every batch the loss is back-propagated and the element-wise square of
    each parameter gradient is accumulated; the sum is then divided by the
    number of batches actually seen. Note that this squares the *batch*
    gradient (mean loss over the batch), not per-sample gradients.

    Args:
        model: network to evaluate. It is put in eval mode for the pass and its
            previous train/eval mode is restored afterwards; gradients left
            over from the pass are cleared before returning.
        dataloader: iterable of ``(data, targets)`` batches. It does not need
            to support ``len()``.
        criterion: loss function called as ``criterion(output, targets)``.
        session: ``0`` feeds ``data`` to the model directly; any other value
            feeds ``data[0]`` (presumably the first of several views produced
            by the incremental-session loader -- TODO confirm).
        device: where each batch is moved. ``None`` (default) keeps the
            original behaviour of calling ``.cuda()``.

    Returns:
        dict mapping every parameter name to a tensor shaped like that
        parameter. Parameters that never receive a gradient (and every entry,
        if ``dataloader`` is empty) stay all-zero.
    """
    def _to_device(tensor):
        return tensor.cuda() if device is None else tensor.to(device)

    was_training = model.training
    model.eval()
    fisher_matrix = {
        name: torch.zeros_like(param) for name, param in model.named_parameters()
    }
    num_batches = 0
    try:
        for data, targets in dataloader:
            if session != 0:
                data = data[0]
            data, targets = _to_device(data), _to_device(targets)

            model.zero_grad()
            loss = criterion(model(data), targets)
            loss.backward()

            for name, param in model.named_parameters():
                # Parameters not reached by backward have no gradient; their
                # Fisher entry simply stays zero.
                if param.grad is not None:
                    fisher_matrix[name] += param.grad.detach().pow(2)
            num_batches += 1
    finally:
        # Do not leak the Fisher-pass gradients or the eval-mode switch into
        # whatever training the caller does next.
        model.zero_grad()
        model.train(was_training)

    if num_batches:
        for name in fisher_matrix:
            fisher_matrix[name] /= num_batches
    return fisher_matrix


def store_optimal_params(model):
    """Snapshot the current values of all of ``model``'s parameters.

    The copies are detached from autograd. A plain ``param.clone()`` would
    stay connected to the live parameter, so a penalty such as
    ``(param - snapshot) ** 2`` would back-propagate through both terms and
    its gradient would cancel to zero.

    Returns:
        dict mapping parameter name to an independent, non-differentiable
        copy of that parameter's data.
    """
    return {
        name: param.detach().clone() for name, param in model.named_parameters()
    }



def ewc_loss(model, fisher_matrix, optimal_params, lambda_ewc):
    """Elastic Weight Consolidation penalty.

    Computes ``lambda_ewc * sum_i F_i * (p_i - p*_i) ** 2`` over every
    parameter whose name appears in ``fisher_matrix``.

    Args:
        model: network whose current parameters are penalised.
        fisher_matrix: name -> importance tensor, shaped like the parameter.
        optimal_params: name -> anchor values. It must contain every name
            present in ``fisher_matrix``, otherwise ``KeyError`` is raised.
        lambda_ewc: scalar weight of the penalty.

    Returns:
        The scaled penalty (a tensor, or ``0`` if no parameter matched).
    """
    reg_loss = 0
    for name, param in model.named_parameters():
        if name not in fisher_matrix:
            continue
        # Detach the anchors and importances: if a snapshot still carries an
        # autograd link to ``param`` (e.g. a bare ``param.clone()``), the
        # gradient through it would exactly cancel the direct gradient.
        anchor = optimal_params[name].detach()
        importance = fisher_matrix[name].detach()
        reg_loss += (importance * (param - anchor).pow(2)).sum()
    return reg_loss * lambda_ewc


def simple_reg_loss_l1(model, optimal_params, lambda_reg):
    """Penalise drift of parameters away from stored reference values.

    Computes ``lambda_reg * sum_i (p_i - p*_i) ** 2`` over every parameter
    whose name appears in ``optimal_params``.

    NOTE(review): despite the ``_l1`` suffix this is a *squared* (L2)
    distance. It is kept as-is so existing ``lambda_reg`` values stay valid;
    rename it or switch to ``abs()`` only after confirming the intent.

    Returns:
        The scaled penalty (a tensor, or ``0`` if no parameter matched).
    """
    reg_loss = 0
    for name, param in model.named_parameters():
        if name not in optimal_params:
            continue
        # Detach so that an anchor still linked to ``param`` in the autograd
        # graph cannot cancel the penalty's gradient.
        anchor = optimal_params[name].detach()
        reg_loss += (param - anchor).pow(2).sum()
    return reg_loss * lambda_reg


def adjust_gradients(model, optimal_params, max_diff=1.0, lambda_reg=1e-5):
    """Damp gradients in place according to how far parameters have drifted.

    For every parameter named in ``optimal_params`` with a gradient, each
    gradient element is multiplied by
    ``(1 - tanh(|p - p*| / max_diff)) * lambda_reg``. The factor lies in
    ``(0, lambda_reg]``: it is largest where the parameter is still at its
    stored value and shrinks as it moves away.

    NOTE(review): with the default ``lambda_reg=1e-5`` every affected gradient
    is shrunk by at least 1e5x, i.e. such parameters are effectively frozen.
    Confirm this is intended rather than an additive regulariser.

    Args:
        model: network whose ``param.grad`` tensors are modified in place.
        optimal_params: name -> reference values, shaped like the parameter.
        max_diff: positive scale for the drift before ``tanh``. With the
            default 1.0 the factor uses ``tanh(|p - p*|)`` directly.
        lambda_reg: global multiplier applied to the factor.

    Returns:
        dict mapping each matched parameter name to its largest absolute
        drift ``max |p - p*|`` as a Python float. The drift is recorded even
        when the parameter has no gradient.

    Raises:
        ValueError: if ``max_diff`` is not positive.
    """
    if max_diff <= 0:
        raise ValueError(f"max_diff must be positive, got {max_diff}")

    max_param_diff = {}
    # Pure bookkeeping on .grad: keep it out of the autograd graph.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name not in optimal_params:
                continue
            param_diff = torch.abs(param - optimal_params[name])
            max_param_diff[name] = param_diff.max().item()

            # |diff| >= 0, so tanh maps it into [0, 1) and the factor is in (0, 1].
            grad_adjustment_factor = 1 - torch.tanh(param_diff / max_diff)
            if param.grad is not None:
                param.grad *= grad_adjustment_factor * lambda_reg

    return max_param_diff


def freeze_resnet_layers(model):
    """Freeze every ``layerN`` stage of a ResNet-style encoder except the deepest one.

    ``model`` may be wrapped in ``torch.nn.DataParallel``, in which case the
    wrapped module is used. It may either expose an ``encoder`` attribute or
    be the encoder itself.

    Among ``layer1``..``layer4``, every stage that exists except the last one
    found gets ``requires_grad=False``. Anything outside those stages is left
    untouched. If the encoder has no ``layerN`` attributes (e.g. a
    ResNet20-style layout), nothing is frozen.

    Every ``nn.BatchNorm2d`` in the encoder is then switched to eval mode,
    including those in the still-trainable stage. A later ``model.train()``
    call switches them back.
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    encoder = getattr(net, 'encoder', net)

    stages = [
        getattr(encoder, stage_name)
        for stage_name in ('layer1', 'layer2', 'layer3', 'layer4')
        if hasattr(encoder, stage_name)
    ]

    # TODO: networks without layerN attributes (ResNet20-style) need their own
    # freezing rule; for now the empty ``stages`` list freezes nothing.
    frozen_params = (p for stage in stages[:-1] for p in stage.parameters())
    for p in frozen_params:
        p.requires_grad_(False)

    batch_norms = (m for m in encoder.modules() if isinstance(m, nn.BatchNorm2d))
    for bn in batch_norms:
        bn.eval()
