import copy
import torch
from torch import nn
from torch.nn import functional as F
from .Resnet import ResNet50,ResNet18,ResNet101
import math
import torchvision.models as models
from .diffusion import Encoder, Decoder  # 假设你已经实现了这些类
from .quantizer import VectorQuantizerPerDimSinkhorn
import yaml
import lpips
from .SoftQuantizer import PerDimSoftQuantizer
# from .SoftQuantizerEntropy import PerDimSoftQuantizer
# First attempt: disentangle the encoder output directly into 10 dimensions, then reconstruct.
  
 
class ONDLinear(nn.Module):
    """Bias-free linear layer that encodes and decodes with one shared weight.

    ``forward`` projects the input with ``W^T`` and maps the result back with
    ``W`` itself, so a single ``(out_features, in_features)`` matrix serves as
    both encoder and decoder. Weights start in ``[0, 1)`` and can be kept
    strictly positive by calling ``clamp_weights`` (e.g. after an optimizer
    step; the call site is not part of this class).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the projected code.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # The randn values are placeholders only: uniform_ below overwrites
        # every entry with a sample from U[0, 1), so weights start non-negative.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        nn.init.uniform_(self.weight, a=0, b=1)

    def non_negative_transform(self, x):
        """Differentiable map onto the non-negative reals (softplus).

        Uses ``F.softplus`` rather than a hand-written ``log(1 + exp(x))``,
        which overflows to ``inf`` once ``exp(x)`` exceeds the float range.
        """
        return F.softplus(x)

    def forward(self, input):
        """Return ``(H, V)`` with ``H = input @ W^T`` and ``V = H @ W``.

        ``H`` has ``out_features`` as its last dimension; ``V`` is mapped back
        to ``in_features``. No bias term is applied.
        """
        H = input.matmul(self.weight.T)
        V = H.matmul(self.weight)
        return H, V

    def clamp_weights(self):
        """Clamp every weight to at least ``1e-6``, in place and outside autograd."""
        with torch.no_grad():
            self.weight.clamp_(min=1e-6)



# GoldenRoad network definition
class GoldenRoad(nn.Module):
    """Autoencoder that compresses encoder features into 10 latent codes.

    Pipeline implemented by ``forward``: encoder -> global average pooling ->
    per-sample min-shift (makes features strictly positive) -> ``ONDLinear``
    projection from 512 features to 10 codes -> per-dimension soft
    quantization (only when ``epoch >= 1``) -> transposed convolution from
    ``(B, 10, 1, 1)`` to ``(B, 10, 4, 4)`` -> decoder.

    Args:
        batch_size, hidden_size, moving_average_decay, rank: accepted for
            interface compatibility; not read anywhere in this class.
        nmf_lambda: weight of the NMF-style term (feature reconstruction plus
            L1 penalty on the codes).
        orthogonal_lambda: weight of the orthogonality penalty on the
            projection matrix.
        similarity_lambda: stored as ``self.sim_lambda``; not used by
            ``forward``.
        vae_lambda: stored as ``self.vae_lambda``; not used by ``forward``.
        config_path: YAML file providing the encoder/decoder keyword arguments
            under ``autoencoder.params.{encoder,decoder}.params``. ``None``
            selects ``DEFAULT_CONFIG_PATH``.
    """

    # Location used when the caller does not pass ``config_path``.
    DEFAULT_CONFIG_PATH = '/home/star/Projects/g2/gyh/gyh/CDQAE-main/config/3dshpe.yaml'

    def __init__(
        self,
        batch_size=64,
        hidden_size=512,
        moving_average_decay=0.99,
        nmf_lambda=0.1,
        orthogonal_lambda=0.1,
        similarity_lambda=0.1,
        vae_lambda=0.3,
        rank=5,
        config_path=None,
    ):
        super().__init__()
        if config_path is None:
            config_path = self.DEFAULT_CONFIG_PATH
        with open(config_path, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        autoencoder_params = config['autoencoder']['params']
        encoder_params = autoencoder_params['encoder']['params']
        decoder_params = autoencoder_params['decoder']['params']

        # Sub-modules are registered in the same order as before so that
        # parameter ordering and state_dict keys stay unchanged.
        # The LPIPS network is frozen: it only scores reconstructions.
        self.perceptual_loss = lpips.LPIPS(net='vgg', verbose=False).requires_grad_(False)
        self.encoder = Encoder(**encoder_params)
        self.decoder = Decoder(**decoder_params)
        # Transposed convolution expanding (B, 10, 1, 1) to (B, 10, 4, 4).
        # TODO: try a grouped convolution here.
        self.deconv = nn.ConvTranspose2d(in_channels=10, out_channels=10,
                                         kernel_size=4, stride=4, padding=0)
        self.vae_lambda = vae_lambda
        self.nmf_lambda = nmf_lambda
        self.orthogonal_lambda = orthogonal_lambda
        self.sim_lambda = similarity_lambda
        # Requires the pooled encoder output to have exactly 512 features.
        self.nmf = ONDLinear(512, 10)
        self.quantizer = PerDimSoftQuantizer(num_latents=10, num_values_per_latent=20,
                                             temperature=1.0, min_temperature=0.1,
                                             anneal_rate=2e-2)

    def shift_projection(self, tensor):
        """Shift each row so its minimum becomes ``1e-6`` (strictly positive)."""
        row_min = torch.min(tensor, dim=-1, keepdim=True).values
        # The small offset keeps the minimum entry above zero.
        return tensor - row_min + 1e-6

    def orthogonal_loss(self, W):
        """Return ``log1p(||W W^T - I||_F)``, penalising non-orthonormal rows of ``W``."""
        gram = torch.matmul(W, W.T)
        identity = torch.eye(W.size(0), device=W.device)
        return torch.log1p(torch.norm(gram - identity, p='fro'))

    def forward(self, x_unit, epoch):
        """Encode, optionally quantize, decode, and compute the training loss.

        Args:
            x_unit: pair whose first element is a 4-D image batch
                ``(B, C, H, W)``; the second element is ignored.
            epoch: current epoch; quantization (and its loss) is applied only
                when ``epoch >= 1``.

        Returns:
            Tuple ``(latents, total_loss, reconstruction_loss, x_rec)`` where
            ``latents`` has shape ``(B, 10)``, ``reconstruction_loss`` is the
            ``log2(1 + MSE)`` term between the NMF reconstruction and the
            pooled features, and ``x_rec`` is the decoder output.
        """
        x1, _ = x_unit
        b, _, _, _ = x1.shape  # also enforces a 4-D input

        pooled = F.adaptive_avg_pool2d(self.encoder(x1), (1, 1))
        online_features_flat = self.shift_projection(pooled.view(b, -1))
        z, V = self.nmf(online_features_flat)

        quantize = epoch >= 1
        if quantize:
            z_quant, _, loss_quant = self.quantizer(z)
        else:
            z_quant = z
        z_quant = z_quant.unsqueeze(-1).unsqueeze(-1)
        x_rec = self.decoder(self.deconv(z_quant))

        # Image reconstruction terms: L1, L2 and perceptual (LPIPS).
        # NOTE(review): LPIPS usually expects inputs scaled to [-1, 1] —
        # confirm against the data pipeline.
        residual = x_rec - x1
        loss_l1 = residual.abs().mean()
        loss_l2 = residual.pow(2).mean()
        loss_p = self.perceptual_loss(x1, x_rec).mean()

        # NMF-style terms: how well V reproduces the pooled features, plus a
        # small L1 penalty on the codes.
        base_loss = ((V - online_features_flat) ** 2).mean(dim=1).mean()
        reconstruction_loss = torch.log2(1 + base_loss)
        l1_regularization = torch.mean(torch.abs(z)) * 1e-3
        total_nmf_loss = self.nmf_lambda * (reconstruction_loss + l1_regularization)
        RzLoss = self.orthogonal_lambda * self.orthogonal_loss(self.nmf.weight)

        total_loss = loss_l1 + loss_l2 + 2 * loss_p + total_nmf_loss + RzLoss
        if quantize:
            total_loss = total_loss + loss_quant

        # Drop only the two trailing singleton axes added above. A bare
        # ``squeeze()`` would also remove the batch axis when B == 1.
        latents = z_quant.squeeze(-1).squeeze(-1)
        return latents, total_loss, reconstruction_loss, x_rec
    