

import torch
import torch.nn as nn
import math

# --- 1. 定义基础组件: Timestep Embedding ---
# 这是标准的正弦位置编码，让模型理解 t=0.1 和 t=0.9 的区别
class SinusoidalPosEmb(nn.Module):
    """Standard sinusoidal embedding of scalar positions / diffusion timesteps.

    Maps a 1-D tensor of ``B`` scalars to a ``[B, dim]`` tensor. The first
    ``dim // 2`` columns are ``sin(x * f_k)`` and the next ``dim // 2`` columns
    are ``cos(x * f_k)``. The frequencies ``f_k = 10000 ** (-k / (dim // 2 - 1))``
    are spaced geometrically from 1 down to 1/10000. When ``dim`` is odd, a
    trailing zero column is appended so the output always has exactly ``dim``
    features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: 1-D tensor of shape ``[B]``. The ``x[:, None]`` indexing
                requires exactly one dimension.

        Returns:
            Tensor of shape ``[B, dim]``.
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids ZeroDivisionError when half_dim == 1 (dim 2 or 3);
        # in that case the single frequency is 1.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -scale)
        args = x[:, None] * freqs[None, :]
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad odd dims so layers sized by `dim` (e.g. nn.Linear(dim, ...)) still match.
            emb = nn.functional.pad(emb, (0, 1))
        return emb

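# --- 2. Core component: ResNet block with adaptive scale/shift modulation ("AdaLN"-style) ---
# A residual block whose features are modulated per position by a condition-derived scale and shift (FiLM-style, not a gate); the core of the generative model.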
class ResnetBlock(nn.Module):
    """Residual 1-D conv block whose hidden features are modulated by a condition.

    The block has two ``conv -> GroupNorm -> SiLU`` stages. Between them, the
    features are multiplied by ``(1 + scale)`` and offset by ``shift``. Both
    are predicted per position from ``cond_emb`` by a SiLU followed by a 1x1
    convolution. The input is added back through a 1x1 convolution when
    ``dim != dim_out``, and unchanged otherwise.

    Args:
        dim: input channel count.
        dim_out: output channel count; must be divisible by 8 (GroupNorm with 8 groups).
        cond_emb_dim: channel count of the condition tensor.
        dropout: dropout probability applied before the second convolution.
    """

    def __init__(self, dim, dim_out, cond_emb_dim, dropout=0.1):
        super().__init__()
        num_groups = 8

        # SiLU + 1x1 conv -> 2 * dim_out channels, later split into [scale | shift].
        self.mlp = nn.Sequential(nn.SiLU(), nn.Conv1d(cond_emb_dim, 2 * dim_out, 1))

        self.block1 = nn.Sequential(
            nn.Conv1d(dim, dim_out, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups, dim_out),
            nn.SiLU(),
        )
        self.block2 = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim_out, dim_out, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups, dim_out),
            nn.SiLU(),
        )

        if dim == dim_out:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv1d(dim, dim_out, kernel_size=1)

    def forward(self, x, cond_emb):
        """Apply the conditioned residual block.

        Args:
            x: ``[B, dim, L]`` feature map.
            cond_emb: ``[B, cond_emb_dim, L]`` condition. A length-1 last axis
                also works, because scale/shift then broadcast over ``L``.

        Returns:
            ``[B, dim_out, L]`` tensor.
        """
        features = self.block1(x)
        scale, shift = torch.chunk(self.mlp(cond_emb), 2, dim=1)
        # Position-wise modulation, applied after block1's activation.
        features = features * (1 + scale) + shift
        return self.block2(features) + self.res_conv(x)
  

class TemporalForceNet(nn.Module):
    """Conditional 1-D network over a short window of states.

    ``forward`` performs these steps:
    1. Keep the last 12 steps of the inputs.
    2. Project them to ``hidden_dim`` channels and add a learned per-position
       embedding.
    3. Build a condition by gating a physics embedding against a timestep
       embedding.
    4. Run two conditioned ``ResnetBlock``s
       (``hidden_dim -> hidden_dim -> 2 * hidden_dim``).
    5. Apply one residual self-attention layer followed by LayerNorm.
    6. Project back to ``2 * state_dim`` channels.

    Args:
        state_dim: input and output feature count is ``2 * state_dim``.
        phys_dim: feature count of the physics condition.
        hidden_dim: base channel width; must be a multiple of 8 (GroupNorm with 8 groups).
        max_len: number of learned positional embeddings; must be >= the
            window length used in ``forward``.
    """

    def __init__(self, state_dim, phys_dim, hidden_dim=64, max_len=96):
        super().__init__()

        # Timestep -> sinusoidal features -> 2-layer MLP, giving [B, hidden_dim].
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.SiLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
        )

        # Learned per-position embedding, sliced to the actual window length in forward.
        self.horizon_emb = nn.Parameter(torch.randn(1, max_len, hidden_dim))

        # Domain-specific input layers (replaced by update_for_new_domain).
        self.input_proj = nn.Conv1d(state_dim * 2, hidden_dim, 1)
        self.phys_adapter = nn.Sequential(
            nn.Conv1d(phys_dim, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 1)
        )

        # Conditioned residual blocks. The second one doubles the width; no
        # temporal downsampling happens despite the name.
        self.downs = nn.ModuleList([
            ResnetBlock(hidden_dim, hidden_dim, hidden_dim),
            ResnetBlock(hidden_dim, hidden_dim * 2, hidden_dim),
        ])

        self.attention = nn.MultiheadAttention(embed_dim=hidden_dim*2, num_heads=4, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_dim*2)

        # Output projection; its last conv is domain-specific (replaced by update_for_new_domain).
        self.output_proj = nn.Sequential(
            nn.GroupNorm(8, hidden_dim*2),
            nn.SiLU(),
            nn.Conv1d(hidden_dim*2, state_dim * 2, 1)
        )

    def update_for_new_domain(self, new_state_dim, new_phys_dim):
        """Swap the domain-specific input/output layers for a new dataset.

        Three layers are re-created, randomly initialised, with the new
        channel counts:
        - ``input_proj``;
        - the first conv of ``phys_adapter``;
        - the last conv of ``output_proj``.

        Every hidden-width layer (time MLP, residual blocks, attention) keeps
        its weights. Each new layer is placed on the device and dtype of the
        layer it replaces, so the method also works after ``model.to(...)``.

        NOTE(review): an optimizer built before this call does not track the
        new parameters; presumably it must be rebuilt afterwards.

        Args:
            new_state_dim: new state dimension (input/output have 2x this many channels).
            new_phys_dim: new physics-condition feature count.
        """
        # Keep hidden_dim unchanged: it defines the shared backbone.
        hidden_dim = self.input_proj.out_channels

        def _like(old, new):
            # Match the replaced layer's placement; nn.Conv1d defaults to CPU/float32.
            ref = old.weight
            return new.to(device=ref.device, dtype=ref.dtype)

        self.input_proj = _like(self.input_proj, nn.Conv1d(new_state_dim * 2, hidden_dim, 1))
        self.phys_adapter[0] = _like(self.phys_adapter[0], nn.Conv1d(new_phys_dim, hidden_dim, 1))
        self.output_proj[-1] = _like(self.output_proj[-1], nn.Conv1d(hidden_dim * 2, new_state_dim * 2, 1))

        print(f"Model adapted to new domain: state_dim={new_state_dim}")

    def forward(self, x_v, t, phys_cond):
        """Run the network on the most recent window of states.

        Args:
            x_v: ``[B, T, 2 * state_dim]`` tensor; only the last 12 steps are used.
            t: ``[B]`` tensor of timesteps, embedded by ``time_mlp``.
            phys_cond: ``[B, T', phys_dim]`` physics features; only the last 12
                steps are used. Presumably ``T' == T`` so that the per-position
                conditions line up with ``x_v`` — TODO confirm with callers.

        Returns:
            ``[B, L, 2 * state_dim]`` tensor with ``L = min(T, 12)``.
        """
        # Hard-coded window: only the most recent 12 steps are processed.
        x_v = x_v[:, -12:, :]
        phys_cond = phys_cond[:, -12:, :]
        seq_len = x_v.shape[1]

        # Conv1d expects [B, C, L], so transpose in and out.
        h = self.input_proj(x_v.transpose(1, 2)).transpose(1, 2)  # [B, L, hidden]
        h = h + self.horizon_emb[:, :seq_len, :]

        t_emb = self.time_mlp(t).unsqueeze(1)  # [B, 1, hidden]
        p_emb = self.phys_adapter(phys_cond.transpose(1, 2)).transpose(1, 2)  # [B, L, hidden]
        # Element-wise convex mix of the physics and time embeddings; the
        # mixing weight is derived from the timestep embedding itself.
        gate = torch.sigmoid(t_emb)
        cond_emb = (gate * p_emb) + ((1 - gate) * t_emb)

        h_res = h.transpose(1, 2)  # [B, hidden, L] for Conv1d
        cond_res = cond_emb.transpose(1, 2)
        for block in self.downs:
            h_res = block(h_res, cond_res)
        h = h_res.transpose(1, 2)  # [B, L, 2 * hidden]

        attn_out, _ = self.attention(h, h, h)
        h = self.attn_norm(h + attn_out)

        out = self.output_proj(h.transpose(1, 2))
        return out.transpose(1, 2)
    