import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention
import torch.distributed as dist
import math

class ResidualAttentionBlock(nn.Module):
    """Pre-norm residual block: multi-head attention followed by an MLP.

    Query, key and value each get their own LayerNorm before attention. The
    attention output is added to the *un-normalized* query (residual), and a
    second pre-norm residual MLP (4x expansion, GELU) is applied on top.

    The attention layer is created with ``batch_first=True``, so batched
    inputs are laid out as (batch, sequence, d_model).
    """

    def __init__(self, d_model: int, n_head: int):
        """
        Args:
            d_model: Embedding width shared by query, key and value.
            n_head: Number of attention heads.
        """
        super().__init__()

        # NOTE: creation order is kept stable so parameter initialization and
        # parameter ordering do not change.
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )

        # Separate pre-attention norms for query, key and value.
        self.ln_q = nn.LayerNorm(d_model)
        self.ln_k = nn.LayerNorm(d_model)
        self.ln_v = nn.LayerNorm(d_model)

        # Pre-MLP norm.
        self.ln = nn.LayerNorm(d_model)

    def attention(self, q, k, v, key_padding_mask):
        """Run multi-head attention and return only the attended values.

        Args:
            q, k, v: Already-normalized query, key and value tensors.
            key_padding_mask: Optional mask over key positions; for boolean
                masks ``True`` means "ignore this key". May be ``None``.

        Returns:
            The attention output (attention weights are not computed).

        A sample whose boolean mask hides *every* key has nothing to attend
        to. Depending on the PyTorch version/backend that softmax can come
        out as NaN, and NaN survives any later multiplication by a 0/False
        mask. For such samples the keys are temporarily un-masked so the
        computation stays finite, and the attention output is then replaced
        by zeros (i.e. the residual branch contributes nothing).
        """
        no_valid_keys = None
        if key_padding_mask is not None and key_padding_mask.dtype == torch.bool:
            # Shape (batch, 1) for batched input, (1,) for unbatched input.
            no_valid_keys = key_padding_mask.all(dim=-1, keepdim=True)
            # Un-mask only the rows that were fully masked; others unchanged.
            key_padding_mask = key_padding_mask & ~no_valid_keys

        attended = self.attn(
            q, k, v, need_weights=False, key_padding_mask=key_padding_mask
        )[0]

        if no_valid_keys is not None:
            # Broadcasts over the query-length and embedding dimensions.
            attended = attended.masked_fill(no_valid_keys.unsqueeze(-1), 0.0)
        return attended

    def forward(self, q, k, v, key_padding_mask=None):
        """Apply the attention and MLP residual branches.

        Args:
            q: Query tensor; also the residual stream that is returned updated.
            k: Key tensor.
            v: Value tensor.
            key_padding_mask: Optional key mask forwarded to :meth:`attention`.

        Returns:
            Tensor with the same shape as ``q``.
        """
        x = q  # residual stream uses the un-normalized query
        q, k, v = self.ln_q(q), self.ln_k(k), self.ln_v(v)
        x = x + self.attention(q, k, v, key_padding_mask=key_padding_mask)
        x = x + self.mlp(self.ln(x))
        return x

class CrossAttentionBlock(nn.Module):
    """One round of self-attention and bidirectional cross-attention.

    Each modality ("2D" and "3D" feature sequences) first attends to itself,
    then the 2D stream attends to the 3D stream, and finally the 3D stream
    attends to the freshly updated 2D stream. After every step the result is
    multiplied by the modality's mask so masked-out positions are zeroed.

    The masks are inverted with ``~`` and passed as ``key_padding_mask``, so
    they are expected to be boolean with ``True`` marking a valid position.
    """

    def __init__(self, d_model, n_head):
        """
        Args:
            d_model: Feature width of both streams.
            n_head: Number of attention heads in every sub-block.
        """
        super().__init__()
        self.self_attn_block_2D = ResidualAttentionBlock(d_model, n_head)
        self.self_attn_block_3D = ResidualAttentionBlock(d_model, n_head)
        self.cross_attn_block_2Dto3D = ResidualAttentionBlock(d_model, n_head)
        self.cross_attn_block_3Dto2D = ResidualAttentionBlock(d_model, n_head)

    def forward(self, features_2d, features_3d, mask_2d, mask_3d):
        """Update both feature streams.

        Args:
            features_2d: 2D-stream features.
            features_3d: 3D-stream features.
            mask_2d: Validity mask for ``features_2d`` (``True`` = keep).
            mask_3d: Validity mask for ``features_3d`` (``True`` = keep).

        Returns:
            Tuple ``(features_2d, features_3d)`` of updated features.
        """
        # Attention wants "True = ignore", the opposite of the validity masks.
        pad_2d = ~mask_2d
        pad_3d = ~mask_3d
        # Trailing singleton axis so the masks broadcast over the feature dim.
        keep_2d = mask_2d.unsqueeze(-1)
        keep_3d = mask_3d.unsqueeze(-1)

        # Intra-modality self-attention.
        attended_2d = self.self_attn_block_2D(
            features_2d, features_2d, features_2d, key_padding_mask=pad_2d
        )
        features_2d = attended_2d * keep_2d

        attended_3d = self.self_attn_block_3D(
            features_3d, features_3d, features_3d, key_padding_mask=pad_3d
        )
        features_3d = attended_3d * keep_3d

        # 2D queries read from 3D keys/values ...
        attended_2d = self.cross_attn_block_3Dto2D(
            features_2d, features_3d, features_3d, key_padding_mask=pad_3d
        )
        features_2d = attended_2d * keep_2d

        # ... then 3D queries read from the already-updated 2D stream.
        attended_3d = self.cross_attn_block_2Dto3D(
            features_3d, features_2d, features_2d, key_padding_mask=pad_2d
        )
        features_3d = attended_3d * keep_3d

        return features_2d, features_3d

class BlendingModule(nn.Module):
    """Project two graph-encoder outputs to a shared width and blend them.

    The forward pass reads the ``'moleculestm'`` and ``'unimol'`` entries of
    its inputs, linearly projects each to ``hidden_dim``, runs a stack of
    :class:`CrossAttentionBlock` layers over the pair, and concatenates the
    two sequences along the sequence axis.
    """

    def __init__(self, hidden_dim, num_heads, num_layers, dims):
        """
        Args:
            hidden_dim: Shared feature width after projection.
            num_heads: Attention heads per cross-attention block.
            num_layers: Number of stacked cross-attention blocks.
            dims: Mapping from encoder name to that encoder's feature width;
                one input projection is created per entry.
        """
        super().__init__()

        self.cross_attn_blocks = nn.ModuleList(
            [CrossAttentionBlock(hidden_dim, num_heads) for _ in range(num_layers)]
        )

        # Per-encoder projection into the shared hidden width.
        self.mlp = nn.ModuleDict()
        for encoder_name, in_dim in dims.items():
            self.mlp[encoder_name] = nn.Sequential(nn.Linear(in_dim, hidden_dim))

    def forward(self, graph_embeds, graph_masks):
        """Blend the two encoder outputs.

        Args:
            graph_embeds: Mapping with ``'moleculestm'`` and ``'unimol'``
                feature tensors.
            graph_masks: Mapping with the matching validity masks.

        Returns:
            Tuple ``(features, masks, graph_rep_indices)``: both sequences
            concatenated along dim 1, their masks concatenated the same way,
            and the start offset of each segment inside the concatenation
            (``0`` for the first, the first mask's length for the second).
        """
        mask_2d = graph_masks['moleculestm']
        mask_3d = graph_masks['unimol']

        # Project, then zero out masked positions.
        projected_2d = self.mlp['moleculestm'](graph_embeds['moleculestm'])
        projected_3d = self.mlp['unimol'](graph_embeds['unimol'])
        features_2d = projected_2d * mask_2d.unsqueeze(-1)
        features_3d = projected_3d * mask_3d.unsqueeze(-1)

        for block in self.cross_attn_blocks:
            features_2d, features_3d = block(features_2d, features_3d, mask_2d, mask_3d)

        graph_rep_indices = [0, mask_2d.size(1)]
        features = torch.cat([features_2d, features_3d], dim=1)
        masks = torch.cat([mask_2d, mask_3d], dim=1)
        return features, masks, graph_rep_indices

import re # 为使用正则表达式而添加

# One force component: optional sign, digits with an optional fraction (or a
# bare leading-dot fraction), and an optional exponent that must carry digits.
_FORCE_COMPONENT = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"

# A full "atom <idx>: force [fx, fy, fz]" entry, tolerant of extra whitespace.
# Compiled once at import time rather than on every call.
_FORCE_ENTRY_PATTERN = re.compile(
    r"atom\s+(\d+)\s*:\s*force\s*"
    r"\[\s*"
    + _FORCE_COMPONENT + r"\s*,\s*"
    + _FORCE_COMPONENT + r"\s*,\s*"
    + _FORCE_COMPONENT + r"\s*\]"
)


def symbolic_to_numeric_translation(symbolic_force_field: str, num_atoms: int, device: torch.device) -> torch.Tensor:
    """Parse a textual force field into a numeric force tensor.

    This corresponds to the T module of the CCU framework.

    Parsing is best-effort: text that does not match the expected entry
    format is skipped, entries whose atom index lies outside
    ``[0, num_atoms)`` are ignored, and when an atom appears more than once
    the last entry wins. Atoms that are never mentioned keep a zero force.

    Components may be written as plain integers, decimals (``1.``, ``.5``,
    ``1.5``) or in scientific notation (``1e-3``, ``2.5E+4``), each with an
    optional leading ``+`` or ``-``.

    The result is built from parsed Python floats, so it carries no autograd
    history.

    Args:
        symbolic_force_field: LLM-generated string. Expected format:
            ``"atom 0: force [f_x, f_y, f_z]; atom 1: force [f_x, f_y, f_z]; ..."``
        num_atoms: Total number of atoms; sets the first dimension of the
            result.
        device: Device on which the tensor is created.

    Returns:
        Numeric force tensor F_num,t of shape ``[num_atoms, 3]``.
    """
    forces = torch.zeros(num_atoms, 3, device=device)

    # atom index -> (fx, fy, fz); later entries overwrite earlier ones.
    parsed = {}
    for match in _FORCE_ENTRY_PATTERN.finditer(symbolic_force_field):
        try:
            atom_idx = int(match.group(1))
            vector = (
                float(match.group(2)),
                float(match.group(3)),
                float(match.group(4)),
            )
        except ValueError:
            # Defensive: the pattern only admits well-formed numbers, but a
            # pathological token (e.g. an absurdly long digit string) may
            # still be rejected by the conversion. Skip it.
            continue
        if 0 <= atom_idx < num_atoms:
            parsed[atom_idx] = vector

    if parsed:
        # Single batched write instead of one tiny tensor per matched atom.
        # Indices are unique (dict keys), so the assignment is deterministic.
        rows = torch.tensor(list(parsed), dtype=torch.long, device=device)
        values = torch.tensor(list(parsed.values()), dtype=forces.dtype, device=device)
        forces[rows] = values

    return forces


class DifferentiableCoordinateUpdater(nn.Module):
    """Single explicit Euler step on atomic coordinates.

    This corresponds to the coordinate-update step of the CCU framework:
    ``X_{t+1} = X_t + eta * F_num,t``.
    """

    def __init__(self, initial_step_size: float = 0.01, is_learnable: bool = False):
        """
        Args:
            initial_step_size: Initial value of the step size (eta).
            is_learnable: When True the step size is registered as a
                trainable parameter; otherwise it is stored as a buffer.
        """
        super().__init__()
        eta = torch.tensor(initial_step_size)
        if not is_learnable:
            self.register_buffer('step_size', eta)
        else:
            self.step_size = nn.Parameter(eta)

    def forward(self, coords: torch.Tensor, forces: torch.Tensor) -> torch.Tensor:
        """Move the coordinates one step along the given forces.

        The update uses only differentiable tensor operations, so gradients
        propagate to whichever of ``coords``, ``forces`` and the step size
        require them.

        Args:
            coords: Current atomic coordinates X_t, shape ``[num_atoms, 3]``.
            forces: Numeric forces F_num,t on the atoms, shape
                ``[num_atoms, 3]``.

        Returns:
            Updated atomic coordinates X_{t+1}, shape ``[num_atoms, 3]``.
        """
        eta = self.step_size
        if isinstance(eta, nn.Parameter):
            # A learnable step size is clamped at zero so it can never
            # become negative; the fixed buffer is used as-is.
            eta = F.relu(eta)
        return coords + eta * forces