import json
import os.path as osp
import torch
import torch.nn.functional as F

@torch.jit.script
def coords2dist(coords: torch.Tensor) -> torch.Tensor:
    """Return all pairwise Euclidean distances between points.

    For coordinates of shape ``(..., n, d)`` the result has shape
    ``(..., n, n)`` and entry ``[..., a, b]`` is ``||coords[a] - coords[b]||``.
    """
    pairwise_delta = coords.unsqueeze(-2) - coords.unsqueeze(-3)
    return torch.norm(pairwise_delta, dim=-1)

import numpy as np


def find_edges_from_distance_matrix(distance_matrix: torch.Tensor, num_nodes: int, edge_length: float = 1.0):
    """List node pairs whose distance lies in ``(0, edge_length]``.

    Only the strict upper triangle (``row < col``) of the first ``num_nodes``
    rows/columns is scanned, so each undirected edge appears once, ordered
    row-major.

    Args:
        distance_matrix: square matrix indexable as ``distance_matrix[row, col]``.
        num_nodes: number of nodes to consider (converted with ``int``).
        edge_length: inclusive upper distance threshold for an edge.

    Returns:
        List of ``[row, col]`` pairs of Python ints.
    """
    node_count = int(num_nodes)
    return [
        [row, col]
        for row in range(node_count)
        for col in range(row + 1, node_count)
        if 0 < distance_matrix[row, col] <= edge_length
    ]


def find_torsion_chains(num_nodes, edges):
    """Enumerate all three-bond walks ``a-b-c-d`` usable as torsion chains.

    A walk is kept when ``c != a`` and ``d`` differs from both ``a`` and ``b``.
    Every chain is produced in both directions (``[a, b, c, d]`` and
    ``[d, c, b, a]``), ordered by the starting node and then by edge insertion order.

    Args:
        num_nodes: number of nodes (converted with ``int``).
        edges: iterable of ``(u, v)`` node-index pairs, treated as undirected.

    Returns:
        List of ``[a, b, c, d]`` lists.
    """
    neighbours = [[] for _ in range(int(num_nodes))]
    for u, v in edges:
        neighbours[u].append(v)
        neighbours[v].append(u)

    return [
        [a, b, c, d]
        for a, nbrs_of_a in enumerate(neighbours)
        for b in nbrs_of_a
        for c in neighbours[b]
        if c != a
        for d in neighbours[c]
        if d not in (a, b)
    ]

def calculate_torsion_angle(coords, torsion_indices):
    """Compute signed dihedral (torsion) angles for batches of node quadruples.

    Uses the atan2 formulation: both outer bonds are projected onto the plane
    perpendicular to the central bond ``j -> k`` and the signed angle between
    the projections is returned.

    Args:
        coords: ``(batch, num_nodes, 3)`` coordinates.
        torsion_indices: ``(batch, num_torsions, >=4)`` node indices; columns
            0..3 are ``[i, j, k, l]``. Any integer dtype/device is accepted;
            it is converted to ``long`` on ``coords.device``.

    Returns:
        ``(batch, num_torsions)`` tensor of angles in radians, in ``[-pi, pi]``.
    """
    torsion_indices = torsion_indices.to(device=coords.device, dtype=torch.long)

    # Row selector so that coords[batch_indices, idx] gathers per-sample points.
    batch_indices = torch.arange(coords.shape[0], dtype=torch.long, device=coords.device).view(-1, 1)
    p0 = coords[batch_indices, torsion_indices[:, :, 0]]
    p1 = coords[batch_indices, torsion_indices[:, :, 1]]
    p2 = coords[batch_indices, torsion_indices[:, :, 2]]
    p3 = coords[batch_indices, torsion_indices[:, :, 3]]

    # Bond vectors; b0 points from j back to i.
    b0 = p0 - p1
    b1 = p2 - p1
    b2 = p3 - p2

    # Unit vector along the central bond (epsilon guards zero-length bonds).
    b1 = b1 / (torch.norm(b1, dim=-1, keepdim=True) + 1e-8)

    # Components of the outer bonds perpendicular to the central bond.
    v = b0 - torch.sum(b0 * b1, dim=-1, keepdim=True) * b1
    w = b2 - torch.sum(b2 * b1, dim=-1, keepdim=True) * b1

    # Explicit dim=-1: without it torch.cross uses the FIRST axis of size 3,
    # which is wrong whenever batch or num_torsions happens to equal 3.
    x = torch.sum(v * w, dim=-1)
    y = torch.sum(torch.cross(b1, v, dim=-1) * w, dim=-1)

    return torch.atan2(y, x)

def extract_torsions_batch(batch, input_coords):
    """Compute torsion angles for every sample of a batch and pad to a common length.

    For sample ``s`` the chains are read from ``batch['torsion_indices'][s]``.
    This may be anything ``torch.as_tensor`` accepts that reshapes to
    ``(-1, 4)``, e.g. a list of ``[i, j, k, l]`` lists or a ``(T, 4)`` tensor
    (TODO confirm against the collate function). The chains are evaluated with
    ``calculate_torsion_angle`` on ``input_coords[s][:batch['num_nodes'][s]]``.

    Args:
        batch: mapping with ``'num_nodes'`` (1-D, one count per sample) and
            ``'torsion_indices'`` (indexable per sample).
        input_coords: per-sample coordinates, presumably ``(b, max_nodes, 3)``.

    Returns:
        ``(indices, angles)`` stacked as ``(b, T, 4)`` int64 and ``(b, T)``
        float32 tensors on the coordinates' device. ``T`` is the largest chain
        count in the batch; shorter samples are zero-padded (index
        ``[0, 0, 0, 0]``, angle ``0``).
    """
    num_nodes_batch = batch['num_nodes']
    torsion_chains_batch = batch['torsion_indices']

    sample_indices = []
    sample_angles = []
    for s in range(num_nodes_batch.shape[0]):
        coords = input_coords[s][:int(num_nodes_batch[s])]
        chains = torch.as_tensor(
            torsion_chains_batch[s], dtype=torch.int64, device=coords.device
        ).reshape(-1, 4)
        # The angle helper expects batched (B, T, ...) inputs: add a batch axis of 1.
        angles = calculate_torsion_angle(coords.unsqueeze(0), chains.unsqueeze(0))[0]
        sample_indices.append(chains)
        sample_angles.append(angles.to(torch.float32))

    max_torsion_chains = max((chains.shape[0] for chains in sample_indices), default=0)

    # Pad with explicit zeros rather than evaluating angles on dummy indices.
    padded_indices = []
    padded_angles = []
    for chains, angles in zip(sample_indices, sample_angles):
        padding_size = max_torsion_chains - chains.shape[0]
        padded_indices.append(torch.cat([chains, chains.new_zeros((padding_size, 4))], dim=0))
        padded_angles.append(torch.cat([angles, angles.new_zeros((padding_size,))], dim=0))

    return torch.stack(padded_indices, dim=0), torch.stack(padded_angles, dim=0)



def find_adjacent_triplets(num_nodes, edges):
    """Enumerate all two-bond walks ``(first, centre, last)`` with ``last != first``.

    Each angle is produced in both directions, ordered by the first node and
    then by edge insertion order.

    Args:
        num_nodes: number of nodes (used directly with ``range``).
        edges: iterable of ``(u, v)`` node-index pairs, treated as undirected.

    Returns:
        List of ``(first, centre, last)`` tuples.
    """
    neighbours = [[] for _ in range(num_nodes)]
    for u, v in edges:
        neighbours[u].append(v)
        neighbours[v].append(u)

    return [
        (first, centre, last)
        for first in range(num_nodes)
        for centre in neighbours[first]
        for last in neighbours[centre]
        if last != first
    ]

def calculate_angle(coords, triplets):
    """Compute signed bond angles for batches of node triplets.

    Args:
        coords: ``(batch, num_nodes, 3)`` coordinates.
        triplets: ``(batch, num_triplets, >=3)`` node indices; columns 0..2
            are ``[i, j, k]`` with ``j`` the vertex. Any integer dtype/device
            is accepted; it is converted to ``long`` on ``coords.device``.

    Returns:
        ``(batch, num_triplets)`` angles in radians between ``j -> i`` and
        ``j -> k``. The magnitude is in ``[0, pi]``. The sign is negative when
        the z-component of ``(p_i - p_j) x (p_k - p_j)`` is negative, and
        positive otherwise.
    """
    triplets = triplets.to(device=coords.device, dtype=torch.long)

    # Row selector so that coords[batch_indices, idx] gathers per-sample points.
    batch_indices = torch.arange(coords.shape[0], dtype=torch.long, device=coords.device).view(-1, 1)
    p0 = coords[batch_indices, triplets[:, :, 0]]
    p1 = coords[batch_indices, triplets[:, :, 1]]
    p2 = coords[batch_indices, triplets[:, :, 2]]

    BA = p0 - p1
    BC = p2 - p1

    dot_product = torch.sum(BA * BC, dim=-1)
    # Explicit dim=-1: without it torch.cross uses the FIRST axis of size 3,
    # which is wrong whenever batch or num_triplets happens to equal 3.
    cross_product = torch.cross(BA, BC, dim=-1)
    norm_cross_product = torch.norm(cross_product, dim=-1)

    # Sign taken from the z-component of the cross product. A zero component
    # maps to +1: torch.sign would return 0 there and wipe out the magnitude,
    # e.g. reporting a 90-degree angle as 0.
    direction = torch.ones_like(norm_cross_product).masked_fill(cross_product[..., 2] < 0, -1.0)

    return torch.atan2(direction * norm_cross_product, dot_product)

# def calculate_angle(coords, triplet):
#     i, j, k = triplet
#     p0, p1, p2 = coords[i], coords[j], coords[k]

#     BA = p0 - p1
#     BC = p2 - p1

#     dot_product = torch.dot(BA, BC)
#     cross_product = torch.cross(BA, BC)
#     norm_cross = torch.norm(cross_product)

#     angle = torch.atan2(norm_cross, dot_product)

#     return angle

def extract_angles_batch(batch, input_coords):
    """Compute bond angles for every sample of a batch and pad to a common length.

    For sample ``s`` the triplets are read from ``batch['angle_indices'][s]``.
    This may be anything ``torch.as_tensor`` accepts that reshapes to
    ``(-1, 3)``, e.g. a list of ``[i, j, k]`` lists or a ``(T, 3)`` tensor
    (TODO confirm against the collate function). The triplets are evaluated with
    ``calculate_angle`` on ``input_coords[s][:batch['num_nodes'][s]]``.

    Args:
        batch: mapping with ``'num_nodes'`` (1-D, one count per sample) and
            ``'angle_indices'`` (indexable per sample).
        input_coords: per-sample coordinates, presumably ``(b, max_nodes, 3)``.

    Returns:
        ``(indices, angles)`` stacked as ``(b, T, 3)`` int64 and ``(b, T)``
        float32 tensors on the coordinates' device. ``T`` is the largest
        triplet count in the batch; shorter samples are zero-padded (index
        ``[0, 0, 0]``, angle ``0``).
    """
    num_nodes_batch = batch['num_nodes']
    angle_chains_batch = batch['angle_indices']

    sample_indices = []
    sample_angles = []
    for s in range(num_nodes_batch.shape[0]):
        coords = input_coords[s][:int(num_nodes_batch[s])]
        chains = torch.as_tensor(
            angle_chains_batch[s], dtype=torch.int64, device=coords.device
        ).reshape(-1, 3)
        # The angle helper expects batched (B, T, ...) inputs: add a batch axis of 1.
        angles = calculate_angle(coords.unsqueeze(0), chains.unsqueeze(0))[0]
        sample_indices.append(chains)
        sample_angles.append(angles.to(torch.float32))

    max_angle_chains = max((chains.shape[0] for chains in sample_indices), default=0)

    # Pad with explicit zeros rather than evaluating angles on dummy indices.
    padded_indices = []
    padded_angles = []
    for chains, angles in zip(sample_indices, sample_angles):
        padding_size = max_angle_chains - chains.shape[0]
        padded_indices.append(torch.cat([chains, chains.new_zeros((padding_size, 3))], dim=0))
        padded_angles.append(torch.cat([angles, angles.new_zeros((padding_size,))], dim=0))

    return torch.stack(padded_indices, dim=0), torch.stack(padded_angles, dim=0)

def coords2angle(coords):
    """Return the angle at node ``j`` for every ordered node triple ``(i, j, k)``.

    Args:
        coords: ``(b, n, 3)`` batch of point coordinates.

    Returns:
        ``(b, n, n, n)`` tensor. Entry ``[:, i, j, k]`` is the angle in radians
        (in ``[0, pi]``) between the vectors ``j -> i`` and ``j -> k``.
        Degenerate triples (``i == j`` or ``k == j``) give ``atan2(0, 0) == 0``.
    """
    n = coords.size(1)
    # Lay node i along axis 1, node j along axis 2, node k along axis 3.
    point_i = coords[:, :, None, None, :].expand(-1, -1, n, n, -1)
    point_j = coords[:, None, :, None, :].expand(-1, n, -1, n, -1)
    point_k = coords[:, None, None, :, :].expand(-1, n, n, -1, -1)

    ray_ji = point_i - point_j
    ray_jk = point_k - point_j

    cos_part = (ray_ji * ray_jk).sum(dim=-1)
    sin_part = torch.norm(torch.cross(ray_ji, ray_jk, dim=-1), dim=-1)
    return torch.atan2(sin_part, cos_part)

def add_coords_noise(coords, edge_mask, noise_level, noise_smoothing):
    """Perturb coordinates with spatially smoothed Gaussian noise.

    Noise with mean 0 and standard deviation ``noise_level`` is drawn for every
    coordinate. Each node's noise is then replaced by a weighted average of
    all nodes' noise, with weights ``softmax(-dist / noise_smoothing)`` over the
    last axis. Pairs where ``edge_mask`` is zero/False get ``1e9`` added to their
    distance, so they receive essentially no weight.

    Args:
        coords: point coordinates, presumably ``(b, n, 3)`` — TODO confirm.
        edge_mask: mask added onto the pairwise distance matrix, presumably
            ``(b, n, n)``.
        noise_level: standard deviation of the raw noise.
        noise_smoothing: softmax temperature on distances.

    Returns:
        ``coords`` plus the smoothed noise (a new tensor).
    """
    gaussian = coords.new_empty(coords.size()).normal_(0, noise_level)
    distances = coords2dist(coords)
    distances.add_((1 - edge_mask.float()) * 1e9)
    weights = torch.softmax(-distances / noise_smoothing, -1)
    return coords + torch.matmul(weights, gaussian)

def mask_angles(distance_matrix):
    """Build a validity mask for a ``(b, n, n, n)`` angle tensor.

    Entry ``[b, i, j, k]`` refers to the angle at vertex ``j`` formed with
    nodes ``i`` and ``k`` (the layout used by ``coords2angle``). It is True
    only when both arms have positive length: ``distance_matrix[b, i, j] > 0``
    and ``distance_matrix[b, j, k] > 0``. Assuming a zero diagonal, this
    excludes ``i == j`` and ``j == k``.

    Args:
        distance_matrix: ``(b, n, n)`` pairwise distances.

    Returns:
        ``(b, n, n, n)`` bool tensor.
    """
    b, n, _ = distance_matrix.shape

    # [b, i, j, k] -> d[b, i, j]  (unsqueeze at dim 3 broadcasts over k)
    distance_matrix_ij = distance_matrix.unsqueeze(3).expand(b, n, n, n)
    # [b, i, j, k] -> d[b, j, k]  (unsqueeze at dim 1 broadcasts over i)
    distance_matrix_jk = distance_matrix.unsqueeze(1).expand(b, n, n, n)

    # An arm is degenerate when its two ends are the same node (distance <= 0).
    degenerate = (distance_matrix_ij <= 0) | (distance_matrix_jk <= 0)
    return ~degenerate


class SinCosLoss:
    """Masked squared error between ``[sin, cos]`` encodings of angles.

    Elements whose loss is already below ``0.001`` are removed from both the
    numerator and the normalising count. The result is therefore the mean over
    the masked elements that still have a non-negligible error.
    """

    def __init__(self):
        pass

    def __call__(self, sin_cos_pred, sin_cos_true, mask, reduce=True):
        """
        sin_cos_pred: predictions whose last axis holds [sin, cos] (indices 0 and 1)
        sin_cos_true: targets with the same layout
        mask: weights with the shape of ``sin_cos_pred[..., 0]`` (the prediction
              shape without its trailing sin/cos axis)
        reduce: if True return one scalar averaged over the whole batch; if False
                return a ``(batch,)`` tensor with one mean per sample
        """
        batch = sin_cos_pred.size(0)

        per_elem = (sin_cos_pred[..., 0] - sin_cos_true[..., 0]) ** 2 \
            + (sin_cos_pred[..., 1] - sin_cos_true[..., 1]) ** 2

        # Drop near-zero errors from both numerator and denominator.
        keep = mask * (per_elem >= 0.001).float()

        per_elem = per_elem.contiguous().view(batch, -1)
        keep = keep.to(per_elem.dtype).view(batch, -1)

        # Small constant keeps the division finite when nothing is kept.
        counts = keep.sum(dim=1) + 1e-6
        weighted = per_elem * keep

        if reduce:
            result = weighted.sum() / counts.sum()
        else:
            result = weighted.sum(dim=1) / counts

        if torch.isnan(result).any() or torch.isinf(result).any():
            print("Error: NaN or Inf detected in the loss computation.")

        return result

def discrete_dist(dist, num_bins, range_bins):
    """Map continuous distances to integer bin indices in ``[0, num_bins - 1]``.

    Distances are scaled so that ``range_bins`` lands on the last bin, then
    truncated toward zero (``.long()``) and clamped to the valid index range.
    """
    scale = (num_bins - 1) / range_bins
    scaled = dist * scale
    return scaled.long().clamp(0, num_bins - 1)

def discrete_buckets(angle, num_bins, max_angle=2*3.14159):
    """Discretise angles into ``num_bins`` equal-width buckets (class indices).

    angle: input angles, any shape ``[...]``
    num_bins: number of buckets (classes)
    max_angle: period of the angle, in the same unit as ``angle``; the default
               ``2*3.14159`` is approximately ``2*pi`` radians

    Angles are first wrapped into ``[0, max_angle)`` with ``%``. The bucket
    index is then truncated and clamped to ``[0, num_bins - 1]``.
    """
    width = max_angle / num_bins
    wrapped = angle % max_angle
    return (wrapped / width).long().clamp(0, num_bins - 1)

class DiscreteAngleLoss:
    """Cross-entropy loss on angles discretised into ``num_bins`` buckets.

    Per element, the loss is the minimum of two cross-entropies:
      * the logits against the true bucket ``t``;
      * the logits rolled right by one position (last logit moved to the
        front) against bucket ``(t - 1) % num_bins``.
    """

    def __init__(self, num_bins, max_angle=2*3.14159):
        self.num_bins = num_bins
        self.max_angle = max_angle

    def __call__(self, angle_pred_logits, angle_true, mask=None, reduce=True):
        """
        angle_pred_logits: bucket logits, shape ``[..., num_bins]``
        angle_true: ground-truth angles, shape ``[...]``
        mask: optional weights, shape ``[...]``; masked-out elements do not count
        reduce: if True return one scalar over the whole batch, otherwise a
                ``(batch,)`` tensor with one value per sample
        """
        nb = self.num_bins
        batch = angle_pred_logits.size(0)

        target = discrete_buckets(angle_true, nb, self.max_angle)

        direct_ce = F.cross_entropy(
            angle_pred_logits.contiguous().view(-1, nb),
            target.contiguous().view(-1),
            reduction='none',
        )

        # NOTE(review): cross-entropy is invariant to permuting logits and target
        # together. So this term equals the cross-entropy of the *unrolled*
        # logits against bucket (t - 2) % num_bins. A consistent cyclic shift
        # would pair this roll with target (t + 1), which would simply reproduce
        # direct_ce. Confirm which behaviour is intended.
        rolled_logits = torch.cat([angle_pred_logits[..., -1:], angle_pred_logits[..., :-1]], dim=-1)
        rolled_target = (target - 1) % nb
        rolled_ce = F.cross_entropy(
            rolled_logits.contiguous().view(-1, nb),
            rolled_target.contiguous().view(-1),
            reduction='none',
        )

        per_elem = torch.min(direct_ce, rolled_ce).view(batch, -1)

        if mask is None:
            return per_elem.mean() if reduce else per_elem.mean(dim=1)

        weights = mask.to(per_elem.dtype).contiguous().view(batch, -1)
        per_elem = per_elem * weights
        if reduce:
            return per_elem.sum() / (weights.sum() + 1e-9)
        return per_elem.sum(dim=1) / (weights.sum(dim=1) + 1e-9)

# class DiscreteAngleLoss:
#     def __init__(self, num_bins, max_angle=2*3.14159):
#         self.num_bins = num_bins
#         self.max_angle = max_angle
    
#     def __call__(self, angle_pred_logits, angle_true, mask=None, reduce=True):
#         """
#         angle_pred_logits: 预测值logits，形状为 [..., num_bins]，表示角度的分桶预测
#         angle_true: 真实的角度值，形状为 [...]
#         mask: 掩码，用于指示哪些数据点应该计入损失计算，形状为 [...]
#         reduce: 是否对批次内样本平均损失，如果为 True 则返回平均损失，否则返回每个样本的损失
#         """
#         bsize = angle_pred_logits.size(0)

#         # 将真实角度值离散化成桶
#         angle_bins = discrete_buckets(angle_true, self.num_bins, self.max_angle)

#         # 计算交叉熵损失
#         angle_xent = F.cross_entropy(angle_pred_logits.contiguous().view(-1, self.num_bins), angle_bins.view(-1), reduction='none')

#         # 重新整理形状
#         angle_xent = angle_xent.view(bsize, -1)

#         if mask is not None:
#             mask = mask.to(angle_xent.dtype).contiguous().view(bsize, -1)
#             angle_xent = angle_xent * mask

#         if reduce:
#             total_xent = angle_xent.sum() / (mask.sum() + 1e-9) if mask is not None else angle_xent.mean()
#         else:
#             total_xent = angle_xent.sum(dim=1) / (mask.sum(dim=1) + 1e-9) if mask is not None else angle_xent.mean(dim=1)

#         return total_xent

class DiscreteDistLoss:
    """Masked cross-entropy between distance-bin logits and binned target distances."""

    def __init__(self, num_bins, range_bins):
        # Number of distance classes and the distance mapped to the last class.
        self.num_bins = num_bins
        self.range_bins = range_bins

    def __call__(self, dist_logits, dist_targ, mask, reduce=True):
        """Return the masked mean cross-entropy.

        Args:
            dist_logits: logits whose last axis has ``num_bins`` entries.
            dist_targ: continuous target distances, binned with ``discrete_dist``.
            mask: weights with one entry per target element.
            reduce: if True one scalar for the whole batch, otherwise a
                ``(batch,)`` tensor of per-sample means.
        """
        batch = dist_logits.size(0)

        targets = discrete_dist(dist_targ, self.num_bins, self.range_bins).view(-1)
        flat_logits = dist_logits.contiguous().view(-1, self.num_bins)
        per_elem = F.cross_entropy(flat_logits, targets, reduction='none').view(batch, -1)

        weights = mask.to(per_elem.dtype).view(batch, -1)
        weighted = per_elem * weights

        if reduce:
            return weighted.sum() / (weights.sum() + 1e-9)
        return weighted.sum(dim=1) / (weights.sum(dim=1) + 1e-9)

class MaskedL1Loss:
    """Masked L1 loss that ignores elements already within 0.01 of the target."""

    def __init__(self, reduction='mean'):
        # 'sum' -> summed loss; 'mean' -> sum divided by the (weighted) count of
        # significant masked elements; anything else -> element-wise masked loss.
        self.reduction = reduction

    def __call__(self, input, target, mask):
        """Compute the masked L1 loss.

        Args:
            input: predictions.
            target: targets, same shape as ``input``.
            mask: weights/boolean mask broadcastable against ``input``.

        Returns:
            A scalar for ``'sum'``/``'mean'``, otherwise the element-wise
            masked loss tensor.

        Raises:
            ValueError: if ``input`` and ``target`` shapes differ, or if
                ``mask`` selects no elements at all.
        """
        if input.shape != target.shape:
            raise ValueError("Input and target must have the same shape")

        abs_diff = torch.abs(input - target)

        # Ignore differences below 0.01.
        significant_diff_mask = abs_diff >= 0.01
        updated_mask = mask * significant_diff_mask

        masked_loss = abs_diff * updated_mask
        mask_sum = updated_mask.sum()
        has_significant = mask_sum.item() != 0

        # An empty input mask is a caller error. If the mask is non-empty but
        # every selected element is within tolerance, the loss is exactly zero;
        # a near-perfect prediction must not crash training.
        if not has_significant and mask.sum().item() == 0:
            raise ValueError("Sum of mask elements is zero, no elements to compute loss for")

        if self.reduction == 'sum':
            return masked_loss.sum()
        elif self.reduction == 'mean':
            total = masked_loss.sum()
            # Avoid 0/0 when nothing is significant; total is already 0 then.
            return total / mask_sum if has_significant else total
        else:
            return masked_loss


class BinsProcessor:
    """Read binning metadata from ``<path>/meta.json`` and map bin indices to distances."""

    def __init__(self, path, 
                 shift_half=True,
                 zero_diag=True,):
        """
        Args:
            path: dataset directory containing ``meta.json``; ``data_path`` is
                set to ``<path>/data``.
            shift_half: if True, convert bin ``b`` using ``b + 0.5`` (bin centre).
            zero_diag: if True, force the diagonal of the output to zero.

        Raises:
            OSError: if ``meta.json`` cannot be opened.
            KeyError: if ``num_samples``, ``num_bins`` or ``range_bins`` is missing.
        """
        self.path = path
        self.shift_half = shift_half
        self.zero_diag = zero_diag

        self.data_path = osp.join(path, 'data')
        self.meta_path = osp.join(path, 'meta.json')

        # JSON text is UTF-8 by specification; don't rely on the locale default.
        with open(self.meta_path, 'r', encoding='utf-8') as f:
            self.meta = json.load(f)

        self.num_samples = self.meta['num_samples']
        self.num_bins = self.meta['num_bins']
        self.range_bins = self.meta['range_bins']

        # Width of one bin: bin index (num_bins - 1) corresponds to range_bins.
        self.bin_size = self.range_bins / (self.num_bins-1)

    def bins2dist(self, bins):
        """Convert a ``(..., n, n)`` matrix of bin indices into distances.

        Each entry becomes ``(bin + 0.5 if shift_half else bin) * bin_size``.
        The result is added to its own transpose and, with ``zero_diag``, the
        diagonal is zeroed.

        NOTE(review): summing with the transpose is only correct if one
        triangle of ``bins`` holds the data and the other is zero. Even then,
        with ``shift_half`` the +0.5 is also applied to the empty triangle, so
        every off-diagonal distance becomes ``(bin + 1) * bin_size``. Confirm
        the storage format of ``bins``.
        """
        bins = bins.float()
        if self.shift_half:
            bins = bins + 0.5
        dist = bins * self.bin_size
        dist = dist + dist.transpose(-2,-1)
        if self.zero_diag:
            dist = dist * (1 - torch.eye(dist.size(-1),
                                         dtype=dist.dtype,
                                         device=dist.device))
        return dist
