import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicSpatioTemporalAttention(nn.Module):
    """Turn one multi-channel CT volume into ``max_time`` single-channel volumes.

    Pipeline: a 3-D CNN extracts 32-channel features whose in-plane size is
    pooled to ``base_size``; a learned embedding of every discrete time step
    is combined with those features by ``FourDAttention4D``; each resulting
    per-time-step feature volume is upsampled x2 in-plane and projected to a
    single channel.

    Args:
        d: Width of the time embedding and of the global feature vector.
        base_size: ``(height, width)`` the CNN pools the in-plane axes to.
        max_time: Number of discrete time steps to generate.
    """

    def __init__(self, d=12, base_size=(128,128), max_time=12):
        super().__init__()
        self.d = d
        self.max_time = max_time
        pooled_h, pooled_w = base_size

        # Discrete time-step index -> d-dimensional embedding.
        self.time_embed = nn.Sequential(
            nn.Embedding(max_time, 32),
            nn.Linear(32, d),
            nn.GELU()
        )

        # Feature extractor; it is applied to a (B, C, D, H, W) tensor, so the
        # adaptive pools keep depth (None) and resize only the in-plane axes.
        self.spatial_cnn = nn.Sequential(
            nn.Conv3d(6, 16, kernel_size=(3,3,3), padding=1),
            TemporalSepConv3D(16, 16),
            nn.InstanceNorm3d(16),
            nn.GELU(),
            nn.AdaptiveMaxPool3d((None, pooled_h, pooled_w)),

            nn.Conv3d(16, 32, kernel_size=(3,3,3), padding=1),
            TemporalSepConv3D(32, 32),
            nn.InstanceNorm3d(32),
            nn.GELU(),
            nn.AdaptiveMaxPool3d((None, pooled_h, pooled_w)),
        )

        # Global descriptor: average over all voxels, then project to d.
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(32,d)
        )

        # Fuses the spatial features with every time-step embedding.
        self.attn_4d = FourDAttention4D(dim=d, max_time=max_time)

        # Per-time-step decoder. It receives tensors laid out as
        # (N, C, H, W, D), so the in-plane x2 upsampling must target the first
        # two spatial axes; a factor of (1, 2, 2) would stretch W and D instead
        # and still pass a same-element-count reshape, silently scrambling
        # the voxels.
        self.time_upsampler = nn.Sequential(
            nn.Conv3d(32, 32, kernel_size=(3,3,3), padding=1),
            nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear'),
            nn.Conv3d(32, 1, kernel_size=1)
        )

    def forward(self, ct):
        """Generate one volume per time step plus a global feature vector.

        Args:
            ct: Tensor of shape (B, 6, H, W, D); the first convolution
                requires 6 input channels.

        Returns:
            A tuple ``(output, global_feat)`` where ``output`` has shape
            (B, max_time, 1, 2 * base_size[0], 2 * base_size[1], D) and
            ``global_feat`` has shape (B, d).

        Raises:
            ValueError: If ``ct`` is not 5-dimensional.
        """
        if ct.dim() != 5:
            raise ValueError(
                f"expected a 5-D input (B, C, H, W, D), got {tuple(ct.shape)}"
            )

        # Embeddings for every time step: (max_time, d).
        steps = torch.arange(self.max_time, device=ct.device)
        time_emb = self.time_embed(steps)

        # The CNN works depth-first, so move D in front of H and W and back.
        feat = self.spatial_cnn(ct.permute(0, 1, 4, 2, 3))  # (B, 32, D, h, w)
        feat = feat.permute(0, 1, 3, 4, 2)                  # (B, 32, h, w, D)

        global_feat = self.global_pool(feat)                # (B, d)

        # Add a singleton time axis; the attention expands it to max_time.
        temporal_feat = self.attn_4d(feat.unsqueeze(1), time_emb)

        # Fold time into the batch axis so the 3-D decoder can process it.
        B, T, C, H, W, D = temporal_feat.shape
        temporal_feat = temporal_feat.reshape(B * T, C, H, W, D)
        output = self.time_upsampler(temporal_feat)         # (B*T, 1, 2h, 2w, D)

        # Restore the time axis using the decoder's real output shape so a
        # layout mismatch can never be hidden by an element-count coincidence.
        output = output.reshape(B, T, *output.shape[1:])
        return output, global_feat

class TemporalSepConv3D(nn.Module):
    """Factorised 3-D convolution.

    A (1, 3, 3) convolution over the last two axes is followed by a
    (3, 1, 1) convolution along the first of the three trailing axes (the
    axis this module calls "temporal"). Both stages are padded so the
    spatial shape of the input is preserved.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels of both stages.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Stage 1: mixes channels and the two trailing axes only.
        self.spatial_conv = nn.Conv3d(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=(1, 3, 3),
            padding=(0, 1, 1),
        )
        # Stage 2: mixes along the leading axis only.
        self.temporal_conv = nn.Conv3d(
            in_channels=out_c,
            out_channels=out_c,
            kernel_size=(3, 1, 1),
            padding=(1, 0, 0),
        )

    def forward(self, x):
        """Apply the in-plane stage, then the leading-axis stage."""
        return self.temporal_conv(self.spatial_conv(x))

class FourDAttention4D(nn.Module):
    """Condition a feature volume on time embeddings and attend along depth.

    The input volume is replicated once per time step, concatenated
    channel-wise with that step's (spatially broadcast) embedding, and
    projected to queries, keys and values by a 1x1x1 convolution. For each
    sample and time step a D x D attention map over the last axis is
    computed by contracting channels and both in-plane axes.

    Args:
        dim: Width of the time embedding.
        max_time: Number of time steps (rows of the embedding table).
        in_channels: Channel count of the incoming feature volume. Defaults
            to 32, the value that used to be hard-coded.
    """

    def __init__(self, dim=8, max_time=12, in_channels=32):
        super().__init__()
        self.d = dim
        self.t = max_time
        self.in_channels = in_channels
        # Input is features + time embedding; output holds q, k and v stacked.
        self.to_qkv = nn.Conv3d(in_channels + dim, 3 * in_channels, 1)
        # NOTE(review): the scale is derived from the embedding width, while
        # the dot product below contracts C*H*W values; left unchanged to
        # keep existing numerics - confirm this is intended.
        self.attn_scale = dim ** -0.5

    def forward(self, x, t):
        """Compute time-conditioned features.

        Args:
            x: Tensor of shape (B, 1, C, H, W, D) with ``C == in_channels``.
            t: Tensor of shape (max_time, dim) holding the time embeddings.

        Returns:
            Tensor of shape (B, max_time, C, H, W, D).
        """
        B, _, C, H, W, D = x.shape
        T = self.t

        # Drop the singleton time axis, then repeat the volume for every step.
        x_merged = x.reshape(B, C, H, W, D)
        x_rep = (
            x_merged.unsqueeze(1)
            .expand(-1, T, -1, -1, -1, -1)
            .reshape(B * T, C, H, W, D)
        )

        # Broadcast each embedding over batch and all spatial positions.
        t_rep = (
            t.reshape(1, T, self.d, 1, 1, 1)
            .expand(B, -1, -1, H, W, D)
            .reshape(B * T, self.d, H, W, D)
        )

        # [B*T, C + dim, H, W, D] -> [B*T, 3*C, H, W, D]
        qkv = self.to_qkv(torch.cat([x_rep, t_rep], dim=1))
        q, k, v = (
            part.reshape(B, T, C, H, W, D) for part in qkv.chunk(3, dim=1)
        )

        # Attention between depth positions (d attends to q).
        attn = torch.einsum('btchwd,btchwq->btdq', q, k) * self.attn_scale
        attn = F.softmax(attn, dim=-1)

        # Aggregate values along depth with the attention weights.
        ctx = torch.einsum('btdq,btchwq->btchwd', attn, v)
        return ctx

class EnhancedClinicalFusion(nn.Module):
    """Fuse image, reconstructed-sequence and clinical features.

    Args:
        clin_dim: Number of clinical input features.
        d: Shared feature width; must be divisible by 4 because the
            transformer layer uses 4 attention heads.
        time_steps: Accepted for interface compatibility; not used here.
    """

    def __init__(self, clin_dim=39, d=12, time_steps=12):
        super().__init__()
        # Clinical vector -> d-dimensional feature.
        self.clin_encoder = nn.Sequential(
            nn.Linear(clin_dim, 128),
            nn.LayerNorm(128),
            nn.GELU(),
            nn.Linear(128,d)
        )

        # One reconstructed volume (B, 1, H, W, D) -> d-dimensional feature.
        self.recon_encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=(3,3,3), padding=1),
            TemporalSepConv3D(16, 16),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(16, d)
        )

        # Self-attention across time steps. forward() feeds [B, T, d], so the
        # layer must be batch-first; with the default (sequence-first) layout
        # attention would run across batch samples instead of across time.
        self.align_transformer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=4,
            dim_feedforward=256,
            batch_first=True,
        )

        # Learnable per-feature gate between the image and sequence paths.
        self.fusion_gate = nn.Parameter(torch.randn(2, d))

    def forward(self, img_feat, clin_feat, recon_ct, d=None):
        """Fuse the three feature sources and compute the auxiliary losses.

        Args:
            img_feat: [B, d] global feature of the original image.
            clin_feat: [B, clin_dim] clinical data.
            recon_ct: [B, T, 1, H, W, D] reconstructed volumes, one per step.
            d: Feature width used to normalise ``dis_loss``. When omitted it
                is taken from ``img_feat`` so the normalisation always
                matches the configured model width.

        Returns:
            A tuple ``(fused_feat, align_loss, dis_loss)``: the fused [B, d]
            feature, the negated mean cosine similarity between image and
            clinical features, and the Frobenius norm of the difference of
            their second-moment matrices divided by ``d ** 2``.
        """
        if d is None:
            d = img_feat.shape[-1]
        T = recon_ct.shape[1]

        h_clin = self.clin_encoder(clin_feat)  # [B, d]

        # Encode the volumes one time step at a time to bound peak memory.
        h_recon = torch.stack(
            [self.recon_encoder(recon_ct[:, t]) for t in range(T)], dim=1
        )  # [B, T, d]

        # Mix information across time, then average the steps away.
        h_recon = self.align_transformer(h_recon).mean(dim=1)  # [B, d]

        # Alignment loss: minimising it maximises the cosine similarity.
        align_loss = -F.cosine_similarity(img_feat, h_clin, dim=1).mean()

        # Uncentred second-moment (Gram) matrices over the batch, [d, d] each.
        gram_img = torch.mm(img_feat.T, img_feat)
        gram_clin = torch.mm(h_clin.T, h_clin)
        dis_loss = torch.norm(gram_img - gram_clin, p='fro') / (d ** 2)

        # Softmax over the two gated paths; clinical features are added as-is.
        gate = F.softmax(self.fusion_gate, dim=0)
        fused_feat = gate[0] * img_feat + gate[1] * h_recon + h_clin

        return fused_feat, align_loss, dis_loss

class PrognosisHead(nn.Module):
    """Two parallel MLP heads that each emit one value per time step.

    Args:
        d: Width of the fused input feature.
        max_time: Number of outputs produced by each head.
    """

    def __init__(self, d=12, max_time=12):
        super().__init__()
        hidden = 32

        def build_head():
            # Linear -> GELU -> Linear with ``max_time`` outputs.
            return nn.Sequential(
                nn.Linear(d, hidden),
                nn.GELU(),
                nn.Linear(hidden, max_time),
            )

        # Recurrence head first, survival head second (keeps init order).
        self.recur_head = build_head()
        self.surv_head = build_head()

    def forward(self, x):
        """Return ``(recur_logits, surv_logits)``, each of shape [B, max_time]."""
        return self.recur_head(x), self.surv_head(x)

class IntegratedModel(nn.Module):
    """End-to-end model: image branch, clinical fusion and prognosis heads.

    Args:
        clin_dim: Number of clinical input features.
        d: Shared feature width used by all sub-modules.
        time_steps: Number of time steps generated and predicted.
    """

    def __init__(self, clin_dim=48, d=12, time_steps=12):
        super().__init__()
        # Image backbone: produces the volume sequence and a global feature.
        self.img_net = DynamicSpatioTemporalAttention(d=d, max_time=time_steps)

        # Fusion of image, reconstructed-sequence and clinical features.
        self.fusion = EnhancedClinicalFusion(clin_dim=clin_dim, d=d, time_steps=time_steps)

        # Multi-task prediction heads.
        self.head = PrognosisHead(d=d, max_time=time_steps)

    def forward(self, ct, clin_feat):
        """Run the full pipeline.

        Args:
            ct: CT input forwarded unchanged to the image backbone.
            clin_feat: [B, clin_dim] clinical data.

        Returns:
            dict with keys ``recon_ct``, ``recur_logits``, ``surv_logits``,
            ``align_loss`` and ``dis_loss``.
        """
        recon_ct, img_global = self.img_net(ct)  # [B,T,1,H,W,D], [B,d]

        # Pass the real feature width so dis_loss is normalised by the
        # configured d rather than by the fusion module's default of 12.
        fused_feat, align_loss, dis_loss = self.fusion(
            img_global, clin_feat, recon_ct, d=img_global.shape[-1]
        )

        recur_logits, surv_logits = self.head(fused_feat)

        return {
            'recon_ct': recon_ct,
            'recur_logits': recur_logits,
            'surv_logits': surv_logits,
            'align_loss': align_loss,
            'dis_loss':dis_loss
        }
    

# Smoke test: builds the full model and checks the shape of every output.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Test configuration: batch 1, 6 input channels, depth 101, 256x256 slices.
    B, C, D, H, W = 1, 6, 101, 256, 256
    # Clinical feature count, shared feature width, number of time steps.
    clin_dim, d, T = 48, 12, 12

    # Build the complete model with the same T the assertions below use.
    model = IntegratedModel(clin_dim=clin_dim, d=d, time_steps=T).to(device)

    # Random inputs; note the CT layout is depth-last.
    test_ct = torch.randn(B, C, H, W, D).to(device)     # [B, C, H, W, D]
    test_clin = torch.randn(B, clin_dim).to(device)      # [B, clin_dim]

    print("="*40)
    print("测试输入维度:")
    print(f"CT扫描维度: {test_ct.shape}")
    print(f"临床数据维度: {test_clin.shape}")
    print("="*40 + "\n")

    # Forward pass only; no gradients needed for a shape check.
    with torch.no_grad():
        outputs = model(test_ct, test_clin)

    # Report the output shapes. The model returns 'recur_logits' and
    # 'surv_logits', and both heads emit one value per time step.
    print("="*40)
    print("关键模块维度验证:")
    print(f"1. 影像重建输出维度: {outputs['recon_ct'].shape} | 期望: ({B}, {T}, 1, 256, 256, {D})")
    print(f"2. 全局特征维度: {outputs['recur_logits'].shape[0]} | 期望: {B}")
    print(f"3. 复发概率输出维度: {outputs['recur_logits'].shape} | 期望: ({B}, {T})")
    print(f"4. 生存评分输出维度: {outputs['surv_logits'].shape} | 期望: ({B}, {T})")
    print("="*40)

    # GPU memory statistics are only meaningful when CUDA is in use.
    if torch.cuda.is_available():
        current_mem = torch.cuda.memory_allocated()
        print(f"当前显存占用: {current_mem / 1024**2:.2f} MB")
        reserved_mem = torch.cuda.memory_reserved()
        print(f"缓存显存: {reserved_mem / 1024**2:.2f} MB")

    # Automated shape assertions.
    assert outputs['recon_ct'].shape == (B, T, 1, 256, 256, D), "影像重建维度错误"
    assert outputs['recur_logits'].shape == (B, T), "复发预测维度错误"
    assert outputs['surv_logits'].shape == (B, T), "生存评分维度错误"
    assert outputs['align_loss'].dim() == 0, "对齐损失应为标量"
    assert outputs['dis_loss'].dim() == 0, "解纠缠损失应为标量"

    print("\n所有测试通过！模型维度符合设计要求")