import warnings
import torch
warnings.filterwarnings("ignore", category=FutureWarning)

# 计算扩散项（Diffusion）
D_0 = 0.2  # 标准扩散系数（m^2/s）
n = 0.5  # 扩散指数
T_0 = 273.15  # K，参考温度

def compute_diffusion(x, temp_k):
    """
    统一把 x／temp_k 变成标准形状，再计算扩散项：
      x: Tensor, 原始可能是
         - [seq, batch, n_loc, d_model]
         - [batch, seq, n_loc, d_model]
         - [1, batch, seq, n_loc, d_model]
      temp_k: Tensor, 原始可能是
         - [batch, seq]
         - [seq, batch]
    最终输出和 laplacian 形状一致。
    """
    # —— 1. 处理 x 维度 ——
    # 如果多了最外层的 “1” 批次维度，就先 squeeze 掉
    if x.dim() == 5 and x.shape[0] == 1:
        x = x.squeeze(0)

    # 如果 x 是 [seq, batch, n_loc, d_model]，就 permute 成 [batch, seq, ...]
    if x.dim() == 4 and x.shape[0] != temp_k.shape[0]:
        # 假设 x.shape == [seq, batch, ...]
        x = x.permute(1, 0, 2, 3)

    # 现在 x 应该是 [batch, seq, n_loc, d_model]
    batch, seq, n_loc, d_model = x.shape

    # —— 2. 处理 temp_k 维度 ——
    # 如果 temp_k 是 [seq, batch]，就转成 [batch, seq]
    if temp_k.dim() == 2 and temp_k.shape[0] == seq and temp_k.shape[1] == batch:
        temp_k = temp_k.permute(1, 0)

    # 现在 temp_k 应该是 [batch, seq]
    assert temp_k.dim() == 2 and temp_k.shape == (batch, seq), \
           f"temp_k shape must be [batch, seq], got {tuple(temp_k.shape)}"

    # —— 3. 计算 D，并扩展到 [batch, seq, n_loc, d_model] ——
    D = D_0 * (temp_k / T_0) ** n       # [batch, seq]
    D = D.unsqueeze(-1).unsqueeze(-1)   # [batch, seq, 1, 1]

    # —— 4. 计算拉普拉斯算子 ——
    laplacian = (
        torch.roll(x, shifts=1, dims=2) +
        torch.roll(x, shifts=-1, dims=2) +
        torch.roll(x, shifts=1, dims=3) +
        torch.roll(x, shifts=-1, dims=3) -
        4 * x
    )                                   # [batch, seq, n_loc, d_model]

    # —— 5. 返回 diffusion ——
    return D * laplacian                # 广播后也是 [batch, seq, n_loc, d_model]



# 计算对流项（Advection）
def compute_advection(x, velocity_field):
    # 计算梯度
    grad_x = torch.diff(x, dim=3, prepend=x[..., :1])  # x 方向的梯度
    grad_y = torch.diff(x, dim=2, prepend=x[..., :1, :])  # y 方向的梯度

    # 提取速度分量
    v_x = velocity_field[:, 0, :, :].unsqueeze(1)  # 东向分量 u
    v_y = velocity_field[:, 1, :, :].unsqueeze(1)  # 北向分量 v

    # 计算对流项（advection）
    advection = - (v_x * grad_x + v_y * grad_y)

    return advection
