# models/medclip.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.losses import (
    InfoNCELoss, RegionContrastiveLoss, SelfDistillLoss,
    HardNegativeMiningLoss, AdaptiveTemperatureLoss
)

class MedCLIP(nn.Module):
    def __init__(self, vision_encoder, text_encoder, region_head, 
                 projection_dim=512, temperature=0.07):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.text_encoder = text_encoder   # DualTextEncoder
        self.region_head = region_head
        self.temperature = temperature

        D_img = vision_encoder.embed_dim
        D_txt = self.text_encoder.encoder_en.config.hidden_size
        
        # 投影层
        self.proj_img = nn.Sequential(
            nn.Linear(D_img, projection_dim),
            nn.LayerNorm(projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim)
        )
        
        self.proj_txt = nn.Sequential(
            nn.Linear(D_txt, projection_dim),
            nn.LayerNorm(projection_dim),
            nn.ReLU(), 
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim)
        )
        
        # 为region head添加单独的投影层
        self.proj_region = nn.Sequential(
            nn.Linear(projection_dim, projection_dim),
            nn.LayerNorm(projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(projection_dim, projection_dim)
        )
        
        # 初始化权重
        self._init_weights()

    def _init_weights(self):
        """初始化投影层权重"""
        for module in [self.proj_img, self.proj_txt, self.proj_region]:
            for layer in module:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.zeros_(layer.bias)

    def encode_image(self, images):
        """编码图像获取特征"""
        with torch.cuda.amp.autocast():
            img_feat = self.vision_encoder(images)
            return self.proj_img(img_feat)

    def encode_text(self, input_ids, attention_mask, lang):
        """编码文本获取特征"""
        with torch.cuda.amp.autocast():
            txt_feat = self.text_encoder(input_ids, attention_mask, lang=lang)
            return self.proj_txt(txt_feat)

    def encode_regions(self, rois):
        """编码ROI区域获取特征"""
        # rois: [B, R, C, H, W]
        B, R = rois.shape[:2]
        rois_flat = rois.view(-1, *rois.shape[2:])  # [B*R, C, H, W]
        
        with torch.cuda.amp.autocast():
            # 通过vision encoder
            roi_feats = self.vision_encoder(rois_flat)  # [B*R, D]
            roi_feats = self.proj_img(roi_feats)  # [B*R, D]
            
            # 重塑为 [B, R, D]
            roi_feats = roi_feats.view(B, R, -1)
            
            # 通过region head处理
            roi_feats = self.region_head(roi_feats)  # [B, R, D]
            roi_feats = self.proj_region(roi_feats)  # [B, R, D]
            
        return roi_feats

    def forward(self, batch, stage_config=None):
        """
        前向传播，根据stage_config计算不同的loss
        """
        if stage_config is None:
            stage_config = {'losses': [], 'weights': {}}
        
        img = batch['image']  # [B, 3, H, W]
        
        # 编码全局图像特征
        global_img_feat = self.encode_image(img)  # [B, D]
        
        # 存储所有特征用于loss计算
        features = {'global_img': global_img_feat}
        
        # 编码全局文本特征
        if 'pos_ids_en' in batch:
            global_txt_feat_en = self.encode_text(
                batch['pos_ids_en'], batch['pos_mask_en'], lang='en'
            )
            features['global_txt_en'] = global_txt_feat_en
            
        if 'pos_ids_zh' in batch:
            global_txt_feat_zh = self.encode_text(
                batch['pos_ids_zh'], batch['pos_mask_zh'], lang='zh'
            )
            features['global_txt_zh'] = global_txt_feat_zh

        # 编码短描述文本特征
        if 'short_ids_en' in batch:
            short_txt_feat_en = self.encode_text(
                batch['short_ids_en'], batch['short_mask_en'], lang='en'
            )
            features['short_txt_en'] = short_txt_feat_en
            
        if 'short_ids_zh' in batch:
            short_txt_feat_zh = self.encode_text(
                batch['short_ids_zh'], batch['short_mask_zh'], lang='zh'
            )
            features['short_txt_zh'] = short_txt_feat_zh

        # 如果有ROI数据，编码区域特征
        if 'rois' in batch and batch['rois'].numel() > 0:
            roi_img_feats = self.encode_regions(batch['rois'])  # [B, R, D]
            features['roi_img'] = roi_img_feats
            
            # 编码区域文本特征
            if 'region_ids_en' in batch:
                B, R = batch['region_ids_en'].shape[:2]
                region_ids_flat = batch['region_ids_en'].view(-1, batch['region_ids_en'].shape[-1])
                region_mask_flat = batch['region_mask_en'].view(-1, batch['region_mask_en'].shape[-1])
                
                region_txt_feat_en = self.encode_text(region_ids_flat, region_mask_flat, lang='en')
                region_txt_feat_en = region_txt_feat_en.view(B, R, -1)  # [B, R, D]
                features['roi_txt_en'] = region_txt_feat_en
                
            if 'region_ids_zh' in batch:
                B, R = batch['region_ids_zh'].shape[:2]
                region_ids_flat = batch['region_ids_zh'].view(-1, batch['region_ids_zh'].shape[-1])
                region_mask_flat = batch['region_mask_zh'].view(-1, batch['region_mask_zh'].shape[-1])
                
                region_txt_feat_zh = self.encode_text(region_ids_flat, region_mask_flat, lang='zh')
                region_txt_feat_zh = region_txt_feat_zh.view(B, R, -1)  # [B, R, D]
                features['roi_txt_zh'] = region_txt_feat_zh

        # 编码负样本特征（用于困难负样本挖掘）
        if 'neg_ids_en' in batch:
            B, N = batch['neg_ids_en'].shape[:2]
            neg_ids_flat = batch['neg_ids_en'].view(-1, batch['neg_ids_en'].shape[-1])
            neg_mask_flat = batch['neg_mask_en'].view(-1, batch['neg_mask_en'].shape[-1])
            
            neg_txt_feat_en = self.encode_text(neg_ids_flat, neg_mask_flat, lang='en')
            neg_txt_feat_en = neg_txt_feat_en.view(B, N, -1)  # [B, N, D]
            features['neg_txt_en'] = neg_txt_feat_en
            
        if 'neg_ids_zh' in batch:
            B, N = batch['neg_ids_zh'].shape[:2]
            neg_ids_flat = batch['neg_ids_zh'].view(-1, batch['neg_ids_zh'].shape[-1])
            neg_mask_flat = batch['neg_mask_zh'].view(-1, batch['neg_mask_zh'].shape[-1])
            
            neg_txt_feat_zh = self.encode_text(neg_ids_flat, neg_mask_flat, lang='zh')
            neg_txt_feat_zh = neg_txt_feat_zh.view(B, N, -1)  # [B, N, D]
            features['neg_txt_zh'] = neg_txt_feat_zh

        # 计算损失
        return self._compute_losses(features, stage_config)

    def _compute_losses(self, features, stage_config):
        """计算损失函数"""
        total_loss = torch.tensor(0.0, device=features['global_img'].device, requires_grad=True)
        loss_dict = {}
        
        # 获取损失配置
        enabled_losses = stage_config.get('losses', [])
        loss_weights = stage_config.get('weights', {})
        
        # 全局对比损失 - 英文
        if 'global_en' in enabled_losses and 'global_txt_en' in features:
            if 'neg_txt_en' in features:
                # 使用困难负样本挖掘
                loss_fn = HardNegativeMiningLoss(temperature=self.temperature)
                loss_global_en = loss_fn(
                    features['global_img'], 
                    features['global_txt_en'],
                    features['neg_txt_en']
                )
            else:
                # 使用标准InfoNCE
                loss_fn = InfoNCELoss(temperature=self.temperature)
                loss_global_en = loss_fn(features['global_img'], features['global_txt_en'])
            
            weight = loss_weights.get('global_en_weight', 1.0)
            loss_dict['global_en_loss'] = loss_global_en * weight
            total_loss = total_loss + loss_dict['global_en_loss']
            
        # 全局对比损失 - 中文
        if 'global_zh' in enabled_losses and 'global_txt_zh' in features:
            if 'neg_txt_zh' in features:
                loss_fn = HardNegativeMiningLoss(temperature=self.temperature)
                loss_global_zh = loss_fn(
                    features['global_img'], 
                    features['global_txt_zh'],
                    features['neg_txt_zh']
                )
            else:
                loss_fn = InfoNCELoss(temperature=self.temperature)
                loss_global_zh = loss_fn(features['global_img'], features['global_txt_zh'])
                
            weight = loss_weights.get('global_zh_weight', 1.0)
            loss_dict['global_zh_loss'] = loss_global_zh * weight
            total_loss = total_loss + loss_dict['global_zh_loss']
            
        # 短描述对比损失 - 英文
        if 'short_en' in enabled_losses and 'short_txt_en' in features:
            loss_fn = InfoNCELoss(temperature=self.temperature)
            loss_short_en = loss_fn(features['global_img'], features['short_txt_en'])
            weight = loss_weights.get('short_en_weight', 0.5)
            loss_dict['short_en_loss'] = loss_short_en * weight
            total_loss = total_loss + loss_dict['short_en_loss']
            
        # 短描述对比损失 - 中文
        if 'short_zh' in enabled_losses and 'short_txt_zh' in features:
            loss_fn = InfoNCELoss(temperature=self.temperature)
            loss_short_zh = loss_fn(features['global_img'], features['short_txt_zh'])
            weight = loss_weights.get('short_zh_weight', 0.5)
            loss_dict['short_zh_loss'] = loss_short_zh * weight
            total_loss = total_loss + loss_dict['short_zh_loss']
            
        # 区域对比损失 - 英文
        if ('region_en' in enabled_losses and 
            'roi_img' in features and 'roi_txt_en' in features):
            loss_fn = RegionContrastiveLoss(temperature=self.temperature)
            loss_region_en = loss_fn(features['roi_img'], features['roi_txt_en'])
            weight = loss_weights.get('region_en_weight', 1.0)
            loss_dict['region_en_loss'] = loss_region_en * weight
            total_loss = total_loss + loss_dict['region_en_loss']
            
        # 区域对比损失 - 中文
        if ('region_zh' in enabled_losses and 
            'roi_img' in features and 'roi_txt_zh' in features):
            loss_fn = RegionContrastiveLoss(temperature=self.temperature)
            loss_region_zh = loss_fn(features['roi_img'], features['roi_txt_zh'])
            weight = loss_weights.get('region_zh_weight', 1.0)
            loss_dict['region_zh_loss'] = loss_region_zh * weight
            total_loss = total_loss + loss_dict['region_zh_loss']
            
        # 自蒸馏损失
        if ('self_distill' in enabled_losses and 
            'global_img' in features and 'roi_img' in features):
            loss_fn = SelfDistillLoss()
            loss_distill = loss_fn(features['global_img'], features['roi_img'])
            weight = loss_weights.get('self_distill_weight', 0.1)
            loss_dict['self_distill_loss'] = loss_distill * weight
            total_loss = total_loss + loss_dict['self_distill_loss']

        # 确保总损失有梯度
        if total_loss.item() == 0.0:
            print('[Warning] No valid loss calculated for this batch!')
            # 创建一个小的虚拟损失确保梯度计算
            dummy_loss = 0.0
            for param in self.parameters():
                if param.requires_grad:
                    dummy_loss = dummy_loss + 0.0 * param.sum()
            total_loss = total_loss + dummy_loss
            
        loss_dict['loss'] = total_loss
        
        # 添加损失统计信息
        loss_dict['num_losses'] = len([k for k in loss_dict.keys() if k.endswith('_loss')])
        loss_dict['avg_loss'] = total_loss / max(1, loss_dict['num_losses'])
        
        return loss_dict

    def get_similarity(self, img_feat, txt_feat):
        """计算图像和文本特征的相似度"""
        img_norm = F.normalize(img_feat, dim=-1)
        txt_norm = F.normalize(txt_feat, dim=-1)
        return torch.matmul(img_norm, txt_norm.T)