import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class CurvatureLoss(nn.Module):
    """MSE loss between a curvature-derived boundary score and label change points.

    Args:
        q: Span (in time steps) of the left/right difference vectors whose
            turning angle defines the curvature at each step. Must be >= 2 so
            the difference kernel has two distinct endpoints.
        w: Half-width of the moving-average window applied to the raw
            curvature; the window length is ``2 * w``. Must be >= 1.

    Raises:
        ValueError: If ``q < 2`` or ``w < 1``.
    """

    def __init__(self, q, w=10):
        super().__init__()
        if q < 2:
            raise ValueError(f"q must be >= 2, got {q}")
        if w < 1:
            raise ValueError(f"w must be >= 1, got {w}")
        self.q = q
        self.w = w

    @staticmethod
    def _min_max_scale(t, dim=None, eps=1e-6):
        """Rescale ``t`` towards [0, 1]: over all elements if ``dim`` is None, else per slice."""
        if dim is None:
            lo, hi = t.min(), t.max()
        else:
            lo = t.amin(dim=dim, keepdim=True)
            hi = t.amax(dim=dim, keepdim=True)
        # eps keeps a constant input at 0 instead of producing 0/0 = NaN.
        return (t - lo) / (hi - lo + eps)

    def curvature_estimation(self, embs, q=10, device=0, w=10):
        """Estimate per-step trajectory curvature for a batch of sequences.

        At each time step ``t`` two difference vectors are formed along the
        last axis: ``left = x[t] - x[t-q+1]`` and ``right = x[t+q-1] - x[t]``.
        Sequence edges are replicate-padded. Curvature is the angle between
        them divided by the sum of their lengths.

        Args:
            embs: Tensor or array-like of shape ``[B, D, T]`` (batch, embedding
                dim, time). The depthwise convolutions run over the last axis.
                Tensors are used as-is, so gradients flow back to them. Other
                inputs are converted with ``torch.as_tensor``. Integer inputs
                are cast to float32.
            q: Ignored; ``self.q`` is always used. Kept for signature compatibility.
            device: Used only when ``embs`` is not already a tensor.
            w: Ignored; ``self.w`` is always used. Kept for signature compatibility.

        Returns:
            Tuple ``(curv, curv_reciprocal, movavg)``, each of shape ``[B, T]``:

            * ``curv``: ``max(curvature) - curvature``, min-max scaled over the
              whole batch.
            * ``curv_reciprocal``: ``(|left| + |right|) / angle``, min-max
              scaled over the whole batch.
            * ``movavg``: curvature smoothed with a length-``2*w`` moving
              average, inverted and min-max scaled per sequence.

        Raises:
            ValueError: If ``embs`` is not 3-dimensional.
        """
        if isinstance(embs, torch.Tensor):
            # torch.tensor(t) would clone *and detach* t, cutting the autograd
            # graph so the loss could never train the embeddings.
            x = embs
        else:
            x = torch.as_tensor(embs, device=device)
        if not x.is_floating_point():
            x = x.float()
        if x.dim() != 3:
            raise ValueError(f"embs must be 3-D [B, D, T], got shape {tuple(x.shape)}")

        q, w = self.q, self.w
        emb_dim = x.shape[1]

        # Depthwise difference kernel: out[t] = in[t + q - 1] - in[t].
        # Built from x so dtype and device always match the input.
        kernel = x.new_zeros(emb_dim, 1, q)
        kernel[:, :, 0] = -1
        kernel[:, :, -1] = 1

        # Left padding yields backward differences; right padding yields forward ones.
        cv_left = F.conv1d(F.pad(x, (q - 1, 0), mode='replicate'), kernel, groups=emb_dim)
        cv_right = F.conv1d(F.pad(x, (0, q - 1), mode='replicate'), kernel, groups=emb_dim)
        cv_left = cv_left.transpose(1, 2)  # [B, T, D]
        cv_right = cv_right.transpose(1, 2)  # [B, T, D]

        norm_sum = cv_left.norm(dim=2) + cv_right.norm(dim=2)  # [B, T]
        cos_sim = F.cosine_similarity(cv_left, cv_right, dim=2)
        # Clamp away from +/-1, where acos has an infinite derivative.
        angle = torch.acos(cos_sim.clamp(-1 + 1e-6, 1 - 1e-6))  # [B, T]
        curvature = angle / (norm_sum + 1e-6)  # [B, T]

        # Centered moving average of length 2*w (w-1 steps left, w right).
        curvature_pad = F.pad(curvature.unsqueeze(1), (w - 1, w), mode='replicate')
        avg_kernel = x.new_full((1, 1, 2 * w), 1.0 / (2 * w))
        movavg = F.conv1d(curvature_pad, avg_kernel).squeeze(1)  # [B, T]
        movavg = movavg.amax(dim=1, keepdim=True) - movavg
        movavg = self._min_max_scale(movavg, dim=1)

        # Invert so that, per the original author's intent, change points score
        # higher than in-segment steps.
        # NOTE(review): curv and curv_reciprocal are scaled over the whole
        # batch, while movavg is scaled per sequence. Confirm this is intended.
        curv = self._min_max_scale(curvature.max() - curvature)
        curv_reciprocal = self._min_max_scale(norm_sum / (angle + 1e-6))

        return curv, curv_reciprocal, movavg

    def forward(self, embs, labels):
        """Return the MSE between the predicted boundary score and label change points.

        Args:
            embs: Embedding sequences, passed unchanged to
                ``curvature_estimation``, which treats them as ``[B, D, T]``.
                NOTE(review): an earlier docstring described this as
                ``[B, T, D]``. Such input would take curvature along the
                embedding axis instead. Confirm the layout with callers.
            labels: Per-step labels of shape ``[B, T]`` (tensor or array-like).

        Returns:
            Scalar tensor: unweighted MSE between the normalised
            ``curv_reciprocal`` score and a 0/1 target. The target is 1 at
            every step ``t >= 1`` where ``labels[:, t] != labels[:, t - 1]``.

        Raises:
            ValueError: If ``labels`` does not match the ``[B, T]`` score shape.
        """
        _, pred_curv, _ = self.curvature_estimation(embs)  # [B, T]
        labels = torch.as_tensor(labels, device=pred_curv.device)
        if labels.shape != pred_curv.shape:
            raise ValueError(
                f"labels shape {tuple(labels.shape)} does not match predicted "
                f"curvature shape {tuple(pred_curv.shape)}"
            )

        # Step t is a boundary when its label differs from step t - 1.
        target = torch.zeros_like(pred_curv)
        target[:, 1:] = (labels[:, 1:] != labels[:, :-1]).to(target)

        return F.mse_loss(pred_curv, target)