from typing import List

import torch
from torch import nn

from options import AETrainConfig


class ReconstructionLossBase(nn.Module):
    """Common interface for the reconstruction losses in this module.

    A concrete loss receives the reconstructed tensors together with the
    original tensors they should match, and returns a loss tensor. Any extra
    loss-specific inputs (for example ``min_var`` / ``max_var`` or ``step`` in
    the subclasses below) are passed as keyword arguments.
    """

    def forward(
        self,
        reconstructed_weights: List[torch.Tensor],
        original_weights: List[torch.Tensor],
        **kwargs,
    ) -> torch.Tensor:
        """Compute the loss; every concrete subclass must override this."""
        raise NotImplementedError()
    
class WeightedMSEReconstructionLoss(ReconstructionLossBase):
    """Squared error weighted by in-tensor distance from the mean and by tensor variance.

    Each element's squared error is multiplied by two factors:

    * in-cluster: ``exp(reconstruction_weighted_loss_base * |w - mean(w)| / d)``,
      where ``d`` is the largest distance of the tensor's min or max from its mean;
    * between-cluster: ``exp(var_weighted_base * (1 - v))``, where ``v`` is the
      tensor's population variance min-max normalised with the ``min_var`` /
      ``max_var`` keyword arguments.

    The weighted errors are summed over all tensors and divided by the total
    number of elements.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        self.var_weighted_base = cfg.var_weighted_base

        print(f"WeightedMSEReconstructionLoss: in cluster: {self.reconstruction_weighted_loss_base}, between cluster: {self.var_weighted_base}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared reconstruction error.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights`` (assumed to have matching shapes — not checked).
            original_weights: Target tensors.
            **kwargs: Must contain ``min_var`` and ``max_var`` (scalars or tensors)
                used to normalise each tensor's variance.

        Raises:
            AssertionError: If ``min_var`` or ``max_var`` is missing.
        """
        assert kwargs.get("min_var") is not None, "min_var is not provided"
        assert kwargs.get("max_var") is not None, "max_var is not provided"

        device = reconstructed_weights[0].device
        # as_tensor avoids the copy-construct warning when a tensor is passed in.
        min_var = torch.as_tensor(kwargs["min_var"], device=device)
        max_var = torch.as_tensor(kwargs["max_var"], device=device)
        # Equal min/max variance would make the normalisation 0/0 (NaN loss);
        # fall back to a unit range so the between-cluster factor stays finite.
        var_range = max_var - min_var
        var_range = torch.where(var_range != 0, var_range, torch.ones_like(var_range))

        loss = torch.tensor(0.).to(device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor (e.g. a zero bias or a norm weight of ones) has zero
            # spread, and dividing would give 0/0 = NaN. Every deviation is 0 then,
            # so dividing by 1 yields the intended normalised distance of 0.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            var = torch.var(original_weight, unbiased=False)
            var = (var - min_var) / var_range
            weighted_var_scale = ((1 - var) * self.var_weighted_base).exp()

            loss += (diff * weighted_scale * weighted_var_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
class WeightedMSEReconstructionLoss_clamp(ReconstructionLossBase):
    """``WeightedMSEReconstructionLoss`` with quantile-clamped in-cluster weights.

    Instead of the tensor's true min/max, the ``clamp_pos`` and ``1 - clamp_pos``
    quantiles are used as the extremes. This applies both when measuring the spread
    and when measuring each element's distance from the mean, which limits the
    influence of a few outliers. ``clamp_pos`` is presumably >= 0.5 (otherwise the
    clamp bounds are inverted) — TODO confirm.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        self.var_weighted_base = cfg.var_weighted_base
        self.clamp_pos = cfg.clamp_pos

        print(f"WeightedMSEReconstructionLoss_Clamp: in cluster: {self.reconstruction_weighted_loss_base}, between cluster: {self.var_weighted_base}, clamp: {self.clamp_pos}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared reconstruction error.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Must contain ``min_var`` and ``max_var`` for variance
                normalisation.

        Raises:
            AssertionError: If ``min_var`` or ``max_var`` is missing.
        """
        assert kwargs.get("min_var") is not None, "min_var is not provided"
        assert kwargs.get("max_var") is not None, "max_var is not provided"

        device = reconstructed_weights[0].device
        min_var = torch.as_tensor(kwargs["min_var"], device=device)
        max_var = torch.as_tensor(kwargs["max_var"], device=device)
        # Guard against a zero variance range, which would make the loss NaN.
        var_range = max_var - min_var
        var_range = torch.where(var_range != 0, var_range, torch.ones_like(var_range))

        loss = torch.tensor(0.).to(device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            # Without ``dim`` torch.quantile operates on the flattened tensor.
            weight_max = torch.quantile(original_weight, self.clamp_pos)
            weight_min = torch.quantile(original_weight, 1 - self.clamp_pos)
            weight_mean = original_weight.mean()

            clamped_original_weight = torch.clamp(original_weight, min=weight_min, max=weight_max)

            weight_distance_max = torch.maximum((weight_max - weight_mean).abs(),
                                                (weight_min - weight_mean).abs())
            # Zero spread means both quantiles equal the mean, so every clamped
            # deviation is 0; divide by 1 instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (clamped_original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            var = torch.var(original_weight, unbiased=False)
            var = (var - min_var) / var_range
            weighted_var_scale = ((1 - var) * self.var_weighted_base).exp()

            loss += (diff * weighted_scale * weighted_var_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
class WeightedMSEReconstructionLoss_clustwise(ReconstructionLossBase):
    """``WeightedMSEReconstructionLoss`` with per-row in-cluster statistics.

    The mean, min and max used for the in-cluster weight are computed per row
    (reductions over ``dim=1``), so the tensors must be at least 2-D. The
    between-cluster variance factor still uses the whole tensor.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        self.var_weighted_base = cfg.var_weighted_base

        print(f"WeightedMSEReconstructionLoss_clustwise: in cluster: {self.reconstruction_weighted_loss_base}, between cluster: {self.var_weighted_base}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared reconstruction error.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors (at least 2-D).
            **kwargs: Must contain ``min_var`` and ``max_var``.

        Raises:
            AssertionError: If ``min_var`` or ``max_var`` is missing.
        """
        assert kwargs.get("min_var") is not None, "min_var is not provided"
        assert kwargs.get("max_var") is not None, "max_var is not provided"

        device = reconstructed_weights[0].device
        min_var = torch.as_tensor(kwargs["min_var"], device=device)
        max_var = torch.as_tensor(kwargs["max_var"], device=device)
        # Guard against a zero variance range, which would make the loss NaN.
        var_range = max_var - min_var
        var_range = torch.where(var_range != 0, var_range, torch.ones_like(var_range))

        loss = torch.tensor(0.).to(device)
        num_parameters = 0

        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            # Element-wise squared error.
            diff = (original_weight - reconstructed_weight) ** 2

            # Per-row (dim=1) mean and extremes; keepdim so they broadcast back.
            original_channel_mean = original_weight.mean(dim=1, keepdim=True)
            original_channel_min = original_weight.min(dim=1, keepdim=True).values
            original_channel_max = original_weight.max(dim=1, keepdim=True).values

            weight_distance_max = torch.maximum((original_channel_max - original_channel_mean).abs(),
                                                (original_channel_min - original_channel_mean).abs())
            # A constant row has zero spread and zero deviations; divide by 1 there
            # instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))

            weighted_scale = (original_weight - original_channel_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            var = torch.var(original_weight, unbiased=False)
            var = (var - min_var) / var_range
            weighted_var_scale = ((1 - var) * self.var_weighted_base).exp()

            loss += (diff * weighted_scale * weighted_var_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss

class WeightedMSEReconstructionLoss_clustwise_clamp(ReconstructionLossBase):
    """Per-row, quantile-clamped variant of ``WeightedMSEReconstructionLoss``.

    For each row (reductions over ``dim=1``), the ``clamp_pos`` / ``1 - clamp_pos``
    quantiles replace the row max/min. They are used both for the spread and to
    clamp the values before measuring their distance from the row mean. The
    between-cluster variance factor uses the whole tensor.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        self.var_weighted_base = cfg.var_weighted_base
        self.clamp_pos = cfg.clamp_pos

        print(f"WeightedMSEReconstructionLoss_clustwise_clamp: in cluster: {self.reconstruction_weighted_loss_base}, between cluster: {self.var_weighted_base}, clamp: {self.clamp_pos}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared reconstruction error.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors (at least 2-D).
            **kwargs: Must contain ``min_var`` and ``max_var``.

        Raises:
            AssertionError: If ``min_var`` or ``max_var`` is missing.
        """
        assert kwargs.get("min_var") is not None, "min_var is not provided"
        assert kwargs.get("max_var") is not None, "max_var is not provided"

        device = reconstructed_weights[0].device
        min_var = torch.as_tensor(kwargs["min_var"], device=device)
        max_var = torch.as_tensor(kwargs["max_var"], device=device)
        # Guard against a zero variance range, which would make the loss NaN.
        var_range = max_var - min_var
        var_range = torch.where(var_range != 0, var_range, torch.ones_like(var_range))

        loss = torch.tensor(0.).to(device)
        num_parameters = 0

        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            # Element-wise squared error.
            diff = (original_weight - reconstructed_weight) ** 2

            # Per-row mean and quantile-based extremes; keepdim so they broadcast.
            original_channel_mean = original_weight.mean(dim=1, keepdim=True)
            original_channel_min = torch.quantile(original_weight, 1 - self.clamp_pos, dim=1, keepdim=True)
            original_channel_max = torch.quantile(original_weight, self.clamp_pos, dim=1, keepdim=True)

            weight_distance_max = torch.maximum((original_channel_max - original_channel_mean).abs(),
                                                (original_channel_min - original_channel_mean).abs())
            # Zero spread means both quantiles equal the row mean, so every clamped
            # deviation is 0; divide by 1 there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))

            clamped_original_weight = torch.clamp(original_weight, min=original_channel_min, max=original_channel_max)

            weighted_scale = (clamped_original_weight - original_channel_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            var = torch.var(original_weight, unbiased=False)
            var = (var - min_var) / var_range
            weighted_var_scale = ((1 - var) * self.var_weighted_base).exp()

            loss += (diff * weighted_scale * weighted_var_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss

class WeightedMSEReconstructionLoss_in_cluster(ReconstructionLossBase):
    """Squared error weighted only by in-tensor distance from the mean.

    Each element's squared error is multiplied by
    ``exp(reconstruction_weighted_loss_base * |w - mean(w)| / d)``, where ``d`` is the
    largest distance of the tensor's min or max from its mean. There is no variance
    (between-cluster) factor.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base

        print(f"WeightedMSEReconstructionLoss_in_cluster: in cluster: {self.reconstruction_weighted_loss_base}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor has zero spread and zero deviations; divide by 1
            # there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
    
class WeightedMSEReconstructionLoss_in_cluster_clamp(ReconstructionLossBase):
    """In-cluster-only weighted squared error with quantile-clamped extremes.

    Like ``WeightedMSEReconstructionLoss_in_cluster``, except that the
    ``clamp_pos`` / ``1 - clamp_pos`` quantiles of the whole tensor replace its
    max/min. They are used both for the spread and to clamp the values before
    measuring their distance from the mean.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        self.clamp_pos = cfg.clamp_pos

        print(f"WeightedMSEReconstructionLoss_in_cluster_clamp: in cluster: {self.reconstruction_weighted_loss_base}, clamp: {self.clamp_pos}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            # Without ``dim`` torch.quantile operates on the flattened tensor.
            weight_max = torch.quantile(original_weight, self.clamp_pos)
            weight_min = torch.quantile(original_weight, 1 - self.clamp_pos)
            weight_mean = original_weight.mean()

            clamped_original_weight = torch.clamp(original_weight, min=weight_min, max=weight_max)

            weight_distance_max = torch.maximum((weight_max - weight_mean).abs(),
                                                (weight_min - weight_mean).abs())
            # Zero spread implies every clamped deviation is 0; divide by 1 instead
            # of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (clamped_original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
class WeightedMSEReconstructionLoss_in_cluster_clustwise_clamp(ReconstructionLossBase):
    """In-cluster-only weighted squared error with per-row, quantile-clamped statistics.

    For each row (reductions over ``dim=1``), the ``clamp_pos`` / ``1 - clamp_pos``
    quantiles replace the row max/min. They are used both for the spread and to
    clamp the values before measuring their distance from the row mean. There is
    no variance factor.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        self.clamp_pos = cfg.clamp_pos

        print(f"WeightedMSEReconstructionLoss_in_cluster_clustwise_clamp: in cluster: {self.reconstruction_weighted_loss_base}, clamp: {self.clamp_pos}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors (at least 2-D).
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0

        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            # Element-wise squared error.
            diff = (original_weight - reconstructed_weight) ** 2

            # Per-row mean and quantile-based extremes; keepdim so they broadcast.
            original_channel_mean = original_weight.mean(dim=1, keepdim=True)
            original_channel_min = torch.quantile(original_weight, 1 - self.clamp_pos, dim=1, keepdim=True)
            original_channel_max = torch.quantile(original_weight, self.clamp_pos, dim=1, keepdim=True)

            weight_distance_max = torch.maximum((original_channel_max - original_channel_mean).abs(),
                                                (original_channel_min - original_channel_mean).abs())
            # Rows with zero spread have zero clamped deviations; divide by 1 there
            # instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))

            clamped_original_weight = torch.clamp(original_weight, min=original_channel_min, max=original_channel_max)

            weighted_scale = (clamped_original_weight - original_channel_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss

class WeightedMSEReconstructionLoss_between_cluster(ReconstructionLossBase):
    """Squared error weighted only by each tensor's normalised variance.

    Every element of a tensor is multiplied by ``exp(var_weighted_base * (1 - v))``,
    where ``v`` is the tensor's population variance min-max normalised with the
    ``min_var`` / ``max_var`` keyword arguments.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.var_weighted_base = cfg.var_weighted_base

        print(f"WeightedMSEReconstructionLoss_between_cluster: between cluster: {self.var_weighted_base}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the variance-weighted mean squared reconstruction error.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Must contain ``min_var`` and ``max_var``.

        Raises:
            AssertionError: If ``min_var`` or ``max_var`` is missing.
        """
        assert kwargs.get("min_var") is not None, "min_var is not provided"
        assert kwargs.get("max_var") is not None, "max_var is not provided"

        device = reconstructed_weights[0].device
        min_var = torch.as_tensor(kwargs["min_var"], device=device)
        max_var = torch.as_tensor(kwargs["max_var"], device=device)
        # Equal min/max variance would make the normalisation 0/0 (NaN loss);
        # fall back to a unit range.
        var_range = max_var - min_var
        var_range = torch.where(var_range != 0, var_range, torch.ones_like(var_range))

        loss = torch.tensor(0.).to(device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            var = torch.var(original_weight, unbiased=False)
            var = (var - min_var) / var_range
            weighted_var_scale = ((1 - var) * self.var_weighted_base).exp()

            loss += (diff * weighted_var_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss


class WeightedMSEReconstructionLoss_divide_var(ReconstructionLossBase):
    """Squared error normalised by each original tensor's population variance."""

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        print(f"WeightedMSEReconstructionLoss_divide_var")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return sum(diff**2 / var(original)) over all tensors divided by the element count.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            var = torch.var(original_weight, unbiased=False)
            # A constant tensor has zero variance; dividing would yield inf/NaN.
            # Fall back to the raw squared error for such tensors.
            var = torch.where(var > 0, var, torch.ones_like(var))

            loss += (diff / var).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
class WeightedMSEReconstructionLoss_byweight(ReconstructionLossBase):
    """Absolute error weighted by the magnitude of the original value.

    Despite the "MSE" in the name, the per-element term is
    ``|w - w_hat| * |w|`` (not squared); the result is the mean of that term over
    every element of every tensor.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        print(f"WeightedMSEReconstructionLoss_byweight")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the magnitude-weighted mean absolute error (``kwargs`` are ignored)."""
        total = torch.tensor(0.).to(reconstructed_weights[0].device)
        element_count = 0
        for target, approx in zip(original_weights, reconstructed_weights):
            abs_error = (target - approx).abs()
            # In-place accumulation keeps the accumulator's dtype, as before.
            total += (abs_error * target.abs()).sum()
            element_count += target.numel()

        total /= element_count
        return total
    
class WeightedMSEReconstructionLoss_0(ReconstructionLossBase):
    """Squared error weighted by ``base * exp(|w| / max|w|)``.

    The base multiplies the exponential rather than sitting inside it, and the
    magnitude is measured from zero, not from the tensor mean.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_abs = original_weight.abs()
            weight_abs_max = weight_abs.max()
            # An all-zero tensor would give 0/0 = NaN; every |w| is 0 then, so
            # dividing by 1 yields the intended ratio of 0.
            weight_abs_max = torch.where(weight_abs_max > 0, weight_abs_max, torch.ones_like(weight_abs_max))
            weighted_scale = weight_abs / weight_abs_max
            weighted_scale = weighted_scale.exp() * self.reconstruction_weighted_loss_base

            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss

# Compared with WeightedMSEReconstructionLoss_0: the base moves inside the exponent
# (widening the gap between weights) and the tensor mean is subtracted first.
class WeightedMSEReconstructionLoss_1(ReconstructionLossBase):
    """Squared error weighted by ``exp(base * |w - mean(w)| / d)``.

    ``d`` is the largest distance of the tensor's min or max from its mean, so the
    weight ranges from 1 at the mean up to ``exp(base)`` at the furthest element.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        print(f"WeightedMSEReconstructionLoss_1: {self.reconstruction_weighted_loss_base}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor has zero spread and zero deviations; divide by 1
            # there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
    
# # 相比 WeightedMSEReconstructionLoss_1 改为通道级别计算loss
# class WeightedMSEReconstructionLoss_2(ReconstructionLossBase):
#     def __init__(self):
#         super().__init__()

#     def forward(self,
#                 reconstructed_weights: List[torch.Tensor],
#                 original_weights: List[torch.Tensor],
#                 **kwargs) -> torch.Tensor:
        
#         # 确保权重基数存在
#         assert kwargs.get("reconstruction_weighted_loss_base") is not None
        
#         # 初始化损失值
#         loss = torch.tensor(0., device=reconstructed_weights[0].device)
#         num_parameters = 0
        
#         for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
#             # 计算差分平方张量 [1024, 1024]
#             diff = (original_weight - reconstructed_weight) ** 2
            
#             # 按照每个通道（即行）计算均值与最大值
#             original_channel_mean = original_weight.mean(dim=1, keepdim=True)  # [1024, 1]
#             original_channel_max = original_weight.abs().max(dim=1, keepdim=True).values  # [1024, 1]

#             # 计算加权比例，逐通道应用广播
#             weighted_scale = (original_weight - original_channel_mean).abs() / original_channel_max
#             weighted_scale = (weighted_scale * kwargs["reconstruction_weighted_loss_base"]).exp()  # [1024, 1024]

#             # 计算加权差异
#             weighted_diff = diff * weighted_scale
            
#             # 累加损失
#             loss += weighted_diff.sum()
#             num_parameters += original_weight.numel()
        
#         # 归一化损失
#         loss /= num_parameters
#         return loss
    


    
# Compared with WeightedMSEReconstructionLoss_1: logarithmic instead of exponential weighting.
class WeightedMSEReconstructionLoss_4(ReconstructionLossBase):
    """Squared error weighted by ``log(1 + base * |w - mean(w)| / d) + 1``.

    ``d`` is the largest distance of the tensor's min or max from its mean, so the
    weight is 1 at the mean and grows logarithmically with the normalised distance.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor has zero spread and zero deviations; divide by 1
            # there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = weighted_scale * self.reconstruction_weighted_loss_base
            weighted_scale = torch.log(weighted_scale + 1.) + 1.

            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
    
# Compared with WeightedMSEReconstructionLoss_1: linear instead of exponential weighting.
class WeightedMSEReconstructionLoss_5(ReconstructionLossBase):
    """Squared error weighted by ``base * |w - mean(w)| / d``.

    ``d`` is the largest distance of the tensor's min or max from its mean. Note
    that elements exactly at the mean (and hence every element of a constant
    tensor) receive weight 0, since there is no additive offset.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor has zero spread and zero deviations; divide by 1
            # there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = weighted_scale * self.reconstruction_weighted_loss_base

            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    
# Compared with WeightedMSEReconstructionLoss_1 (exponential weighting): adds a
# progressive sine scheduler over [0, pi / 2].
class WeightedMSEReconstructionLoss_6(ReconstructionLossBase):
    """Exponential in-tensor weighting that is phased in over training.

    The weight ``s = exp(base * |w - mean(w)| / d)`` (as in
    ``WeightedMSEReconstructionLoss_1``) is blended with 1 as
    ``(s - 1) * sin(step / epochs * pi / 2) + 1``. At ``step == 0`` this is plain
    MSE, and at ``step == epochs`` it equals ``WeightedMSEReconstructionLoss_1``.
    The sine on [0, pi / 2] rises fastest at the start and flattens towards the end.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.total_epochs = cfg.epochs
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base

    def scheduler(self, step: torch.Tensor) -> torch.Tensor:
        """Return the blend factor ``sin(step / total_epochs * pi / 2)``."""
        # A method rather than a lambda attribute keeps the module picklable.
        return torch.sin(step / self.total_epochs * torch.pi / 2)

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the scheduled, weighted mean squared reconstruction error.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Must contain ``step`` (presumably the current epoch, given
                it is divided by ``cfg.epochs`` — TODO confirm).

        Raises:
            AssertionError: If ``step`` is missing.
        """
        assert kwargs.get("step") is not None
        device = reconstructed_weights[0].device
        step = torch.as_tensor(kwargs["step"], device=device)
        blend = self.scheduler(step)
        loss = torch.tensor(0.).to(device)

        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor has zero spread and zero deviations; divide by 1
            # there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            curr_weight_scale = (weighted_scale - 1) * blend + 1
            loss += (diff * curr_weight_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss
    

# Compared with WeightedMSEReconstructionLoss_6: a shifted sine scheduler over
# [-pi / 2, 0] (plus 1), which starts flat and steepens towards the end.
class WeightedMSEReconstructionLoss_7(ReconstructionLossBase):
    """``WeightedMSEReconstructionLoss_6`` with an ease-in schedule.

    The exponential weight ``s = exp(base * |w - mean(w)| / d)`` is blended with 1
    as ``(s - 1) * (sin(step / epochs * pi / 2 - pi / 2) + 1) + 1``. The factor goes
    from 0 at ``step == 0`` to 1 at ``step == epochs``. Its slope increases over that
    range, which is the mirror image of ``_6``, whose slope decreases.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.total_epochs = cfg.epochs
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base

    def scheduler(self, step: torch.Tensor) -> torch.Tensor:
        """Return the blend factor ``sin(step / total_epochs * pi / 2 - pi / 2) + 1``."""
        # A method rather than a lambda attribute keeps the module picklable.
        return torch.sin(step / self.total_epochs * torch.pi / 2 - torch.pi / 2) + 1

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the scheduled, weighted mean squared reconstruction error.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Must contain ``step`` (presumably the current epoch — TODO confirm).

        Raises:
            AssertionError: If ``step`` is missing.
        """
        assert kwargs.get("step") is not None
        device = reconstructed_weights[0].device
        step = torch.as_tensor(kwargs["step"], device=device)
        blend = self.scheduler(step)
        loss = torch.tensor(0.).to(device)

        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            diff = (original_weight - reconstructed_weight) ** 2

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor has zero spread and zero deviations; divide by 1
            # there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))
            weighted_scale = (original_weight - weight_mean).abs() / weight_distance_max
            # Same exponential weight as _6; this class only changes the schedule.
            # Without this step the base is unused, and the final weights fall in
            # [0, 1], zeroing the loss for elements at the mean.
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            curr_weight_scale = (weighted_scale - 1) * blend + 1
            loss += (diff * curr_weight_scale).sum()
            num_parameters += original_weight.numel()

        loss /= num_parameters
        return loss

# Truncate the weighting of the most extreme values.
class WeightedMSEReconstructionLoss_8(ReconstructionLossBase):
    """Exponential in-tensor weighting with values clamped to the 1st-99th percentile.

    Values are clamped to the tensor's 0.01 / 0.99 quantiles before their distance
    from the mean is measured. The normaliser ``d`` still uses the unclamped min/max,
    so the extreme 1% on each side share the weight of the percentile bound.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) -> torch.Tensor:
        """Return the weighted mean squared error over all elements of all tensors.

        Args:
            reconstructed_weights: Reconstructed tensors, paired positionally with
                ``original_weights``.
            original_weights: Target tensors.
            **kwargs: Ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0

        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            # Element-wise squared error.
            diff = (original_weight - reconstructed_weight) ** 2

            # 99th and 1st percentiles of the flattened tensor.
            upper_quantile = torch.quantile(original_weight.flatten(), 0.99)
            lower_quantile = torch.quantile(original_weight.flatten(), 0.01)

            clamped_original_weight = torch.clamp(original_weight, min=lower_quantile, max=upper_quantile)

            weight_mean = original_weight.mean()
            weight_distance_max = torch.maximum((original_weight.max() - weight_mean).abs(),
                                                (original_weight.min() - weight_mean).abs())
            # A constant tensor has zero spread and zero deviations; divide by 1
            # there instead of producing 0/0 = NaN.
            weight_distance_max = torch.where(weight_distance_max > 0, weight_distance_max,
                                              torch.ones_like(weight_distance_max))

            # Weighting factor from the clamped distance to the mean.
            weighted_scale = (clamped_original_weight - weight_mean).abs() / weight_distance_max
            weighted_scale = (weighted_scale * self.reconstruction_weighted_loss_base).exp()

            # Accumulate the weighted error and the element count.
            loss += (diff * weighted_scale).sum()
            num_parameters += original_weight.numel()

        # Normalise by the total number of elements.
        loss /= num_parameters
        return loss
    
# 根据数量进行反向加权
class WeightedMSEReconstructionLoss_9(ReconstructionLossBase):
    """Weighted MSE with inverse-frequency weights over value bins.

    Each layer's original weights are split into ``num_bins`` equal-width
    bins spanning [min, max]. A bin's raw weight is ``1 / count`` (empty bins
    are treated as count 1), the bin weights are normalised to sum to 1, and
    each element is weighted by ``exp(base * bin_weight)`` where ``base`` is
    ``cfg.reconstruction_weighted_loss_base``.
    """

    def __init__(self, cfg: AETrainConfig, num_bins: int = 100):
        super().__init__()
        self.reconstruction_weighted_loss_base = cfg.reconstruction_weighted_loss_base
        # Number of equal-width value bins per layer (previously hard-coded 100).
        self.num_bins = num_bins
        print(f"WeightedMSEReconstructionLoss_9: {self.reconstruction_weighted_loss_base}")

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) -> torch.Tensor:
        """Return the inverse-frequency weighted reconstruction loss.

        Args:
            reconstructed_weights: reconstructed tensors, paired positionally
                with ``original_weights``.
            original_weights: reference tensors; only these determine the bins
                and weights.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            Scalar tensor: the weighted squared error summed over all layers,
            divided by the total number of elements.
        """
        device = reconstructed_weights[0].device
        loss = torch.tensor(0.).to(device)
        num_parameters = 0
        num_bins = self.num_bins

        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            original_weight_flat = original_weight.flatten()
            min_val = original_weight_flat.min()
            max_val = original_weight_flat.max()

            # num_bins equal-width bins over [min, max]; edges live on the
            # tensor's own device so bucketize sees matching devices.
            bins = torch.linspace(min_val, max_val, num_bins + 1,
                                  device=original_weight_flat.device)
            # right=True maps x to bin j when bins[j] <= x < bins[j + 1]; the
            # maximum falls past the last edge and is clamped into the last bin.
            indices = torch.bucketize(original_weight_flat, bins, right=True) - 1
            indices = torch.clamp(indices, 0, num_bins - 1)

            # Count occupancy from the very indices used for the lookup below,
            # so counts and assignments always agree. torch.histc (used
            # previously) bins with its own arithmetic and widens a zero-width
            # range, so an element could read the weight of a different bin.
            bin_counts = torch.bincount(indices, minlength=num_bins).to(original_weight_flat.dtype)

            # Inverse-frequency bin weights; empty bins are clamped to count 1
            # to avoid division by zero.
            # NOTE(review): those empty-bin weights still enter the normalising
            # sum, shrinking the populated bins' weights -- confirm intended.
            bin_counts = torch.clamp(bin_counts, min=1)
            weights = 1.0 / bin_counts
            weights /= weights.sum()

            # Per-element weight, reshaped back to the layer's shape.
            param_weights = weights[indices]
            param_weights = (param_weights * self.reconstruction_weighted_loss_base).exp()
            param_weights = param_weights.view(original_weight.size())

            diff = (original_weight - reconstructed_weight) ** 2
            loss += (diff * param_weights).sum()
            num_parameters += original_weight.numel()

        # Mean over all elements of all layers.
        loss /= num_parameters
        return loss

class MSEReconstructionLoss(ReconstructionLossBase):
    """Plain MSE, averaged per layer.

    Every (original, reconstructed) pair contributes its own mean squared
    error; the result is the unweighted mean of those per-layer errors, so
    small and large layers count equally.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        criterion = nn.MSELoss()
        total = torch.tensor(0.).to(reconstructed_weights[0].device)
        for target, prediction in zip(original_weights, reconstructed_weights):
            total += criterion(target, prediction)
        # Mean over layers (not over elements).
        return total / len(original_weights)

class MSEReconstructionLoss_1(ReconstructionLossBase):
    """Squared error plus penalties on each layer's extreme values.

    For every layer, adds the summed squared error and the absolute mismatch
    of the layer maximum and of the layer minimum; the total is divided by the
    overall element count.
    NOTE(review): the max/min penalties are divided by the element count too,
    which makes them tiny for large layers -- confirm intended.
    """

    def __init__(self, cfg: AETrainConfig):
        super().__init__()

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        total = torch.tensor(0.).to(reconstructed_weights[0].device)
        element_count = 0
        for target, prediction in zip(original_weights, reconstructed_weights):
            squared_error = ((target - prediction) ** 2).sum()
            max_gap = (target.max() - prediction.max()).abs()
            min_gap = (target.min() - prediction.min()).abs()

            total = total + squared_error + max_gap + min_gap
            element_count += target.numel()

        total /= element_count
        return total

class LWLNMSEReconstructionLoss(ReconstructionLossBase):
    """Layer-wise normalised MSE.

    Each layer's MSE is divided by the population variance of its original
    weights, and the result is the mean over layers. The variance is floored
    at ``eps``, so a zero-variance layer (constant tensor or a single element)
    yields a large finite term instead of ``inf`` / ``NaN``.
    """

    def __init__(self, cfg: AETrainConfig, eps: float = 1e-12):
        super().__init__()
        # Lower bound on the per-layer variance used as the denominator.
        self.eps = eps

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return the mean over layers of MSE / max(var(original), eps).

        Args:
            reconstructed_weights: reconstructed tensors, paired positionally
                with ``original_weights``.
            original_weights: reference tensors; their variance normalises
                each layer's error.
            **kwargs: accepted for interface compatibility and ignored.
        """
        mse_loss = nn.MSELoss()
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            # Without the floor a constant layer divides by 0 -> inf / NaN.
            variance = torch.var(original_weight, unbiased=False).clamp_min(self.eps)
            loss += mse_loss(original_weight, reconstructed_weight) / variance
        loss /= len(original_weights)
        return loss

class L2ReconstructionLoss(ReconstructionLossBase):
    """Root-mean-square reconstruction error over all elements of all layers."""

    def __init__(self, cfg: AETrainConfig):
        super().__init__()

    def forward(self,
                reconstructed_weights: List[torch.Tensor],
                original_weights: List[torch.Tensor],
                **kwargs) \
            -> torch.Tensor:
        """Return sqrt(sum of squared errors / total element count).

        Args:
            reconstructed_weights: reconstructed tensors, paired positionally
                with ``original_weights``.
            original_weights: reference tensors.
            **kwargs: accepted for interface compatibility and ignored.
        """
        loss = torch.tensor(0.).to(reconstructed_weights[0].device)
        num_parameters = 0
        for original_weight, reconstructed_weight in zip(original_weights, reconstructed_weights):
            num_parameters += original_weight.numel()
            loss += torch.sum((original_weight - reconstructed_weight) ** 2)
        loss /= num_parameters
        # d sqrt(x)/dx is infinite at x = 0, so an exact reconstruction would
        # back-propagate inf * 0 = NaN. Flooring at the smallest positive float
        # leaves the value unchanged in practice and makes the gradient 0 there.
        return torch.sqrt(loss.clamp_min(torch.finfo(loss.dtype).tiny))


def load_reconstruction_loss(cfg: AETrainConfig) -> ReconstructionLossBase:
    """Instantiate the reconstruction loss named by ``cfg.reconstruction_loss_type``.

    Args:
        cfg: training config; ``cfg.reconstruction_loss_type`` selects the
            loss class and the whole config is passed to its constructor.

    Returns:
        A new instance of the selected loss module.

    Raises:
        ValueError: if ``cfg.reconstruction_loss_type`` is not a registered name.
    """
    losses = {
        "MSELoss": MSEReconstructionLoss,
        "WeightedMSELoss": WeightedMSEReconstructionLoss,
        "WeightedMSELoss_clustwise": WeightedMSEReconstructionLoss_clustwise,
        "WeightedMSELoss_clustwise_clamp": WeightedMSEReconstructionLoss_clustwise_clamp,
        "WeightedMSELoss_in_cluster": WeightedMSEReconstructionLoss_in_cluster,
        "WeightedMSELoss_in_cluster_clustwise_clamp": WeightedMSEReconstructionLoss_in_cluster_clustwise_clamp,
        "WeightedMSELoss_between_cluster": WeightedMSEReconstructionLoss_between_cluster,
        "WeightedMSELoss_divide_var": WeightedMSEReconstructionLoss_divide_var,
        "WeightedMSELoss_byweight": WeightedMSEReconstructionLoss_byweight,
        "WeightedMSELoss_clamp": WeightedMSEReconstructionLoss_clamp,
        "WeightedMSELoss_in_cluster_clamp": WeightedMSEReconstructionLoss_in_cluster_clamp,
        "L2Loss": L2ReconstructionLoss
    }
    # NOTE(review): WeightedMSEReconstructionLoss_8, WeightedMSEReconstructionLoss_9,
    # MSEReconstructionLoss_1 and LWLNMSEReconstructionLoss are defined in this
    # file but not registered above, so the config cannot select them -- confirm.

    loss_type = cfg.reconstruction_loss_type
    try:
        loss_cls = losses[loss_type]
    except KeyError:
        raise ValueError(
            f"Unknown Reconstruction Loss Type: {loss_type!r} "
            f"(expected one of {sorted(losses)})"
        ) from None
    # Constructed outside the try so a KeyError raised inside a loss
    # constructor is not misreported as an unknown loss type.
    return loss_cls(cfg)
        
