"""
English medical MedCLIP model
Based on BiomedCLIP, supporting single ROI processing and four-stage training strategy
Fixed version: Resolves data type mismatch issues and integrates optimized ROI loss calculation
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from models.biomedclip_adapter import (
    build_biomedclip_vision_encoder,
    build_biomedclip_text_encoder,
    build_single_roi_processor
)
from models.losses import (
    InfoNCELoss,
    SingleROIContrastiveLoss,
    Fixed5NegativeLoss,
    ROIQualityLoss,
    StageAwareLoss,
    RECOMMENDED_LOSS_CONFIG
)

def safe_matmul(a, b):
    """Multiply two tensors after reconciling their device and dtype.

    ``b`` is moved onto ``a``'s device when the two differ. When the dtypes
    differ, both operands are converted with ``.float()`` (float32) before
    the product; operands that already share a dtype are multiplied as-is.

    Args:
        a: Left operand; the product is computed on its device.
        b: Right operand.

    Returns:
        The result of ``torch.matmul`` on the reconciled operands.
    """
    lhs = a
    rhs = b if b.device == a.device else b.to(a.device)

    if lhs.dtype != rhs.dtype:
        # Mixed dtypes: bring both sides to float32 before multiplying.
        lhs, rhs = lhs.float(), rhs.float()

    return torch.matmul(lhs, rhs)

def ensure_float32(tensor):
    """Return ``tensor`` cast to float32 when it is a ``torch.Tensor``.

    Any value that is not a ``torch.Tensor`` is handed back unchanged.
    """
    return tensor.float() if isinstance(tensor, torch.Tensor) else tensor

class EnglishMedCLIP(nn.Module):
    """
    English medical CLIP model, specifically handling new data formats and training strategies
    Fixed version: Ensures consistent data types are used for all calculations
    """

    def __init__(self, config):
        """Assemble encoders, projection heads, ROI classifier and losses.

        Args:
            config: Mapping-like object read through ``.get``. Keys used here:
                ``'vision'``, ``'text'`` and ``'roi_processor'`` (each passed
                to its builder, default ``{}``), ``'projection_dim'``
                (default 512) and ``'loss'`` (default ``{}``), which may hold
                ``'temperature'`` (default 0.07) and ``'stage_config'``
                (default ``RECOMMENDED_LOSS_CONFIG``).
        """
        super().__init__()
        self.config = config

        # BiomedCLIP components, each built from its own config section
        self.vision_encoder = build_biomedclip_vision_encoder(config.get('vision', {}))
        self.text_encoder = build_biomedclip_text_encoder(config.get('text', {}))
        self.roi_processor = build_single_roi_processor(config.get('roi_processor', {}))

        # Feature widths reported by the encoders, plus the shared target width
        self.vision_dim = self.vision_encoder.embed_dim
        self.text_dim = self.text_encoder.embed_dim
        self.projection_dim = config.get('projection_dim', 512)

        # One projection head per modality, both mapping into projection_dim
        self.vision_projection = self._build_projection_head(self.vision_dim, self.projection_dim)
        self.text_projection = self._build_projection_head(self.text_dim, self.projection_dim)

        # Stand-alone ROI quality classifier taking projection_dim-wide input.
        # Output classes: abnormal_region, abnormal_global, normal_global, no_roi
        quality_layers = [
            nn.Linear(self.projection_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 4),
        ]
        self.roi_quality_classifier = nn.Sequential(*quality_layers)

        # Explicitly cast every head to float32
        for head_name in ('vision_projection', 'text_projection', 'roi_quality_classifier'):
            setattr(self, head_name, getattr(self, head_name).float())

        # Loss configuration
        loss_config = config.get('loss', {})
        self.temperature = loss_config.get('temperature', 0.07)

        # Stage-aware loss, configured per training stage
        self.stage_aware_loss = StageAwareLoss(
            loss_config.get('stage_config', RECOMMENDED_LOSS_CONFIG)
        )

        # Individual loss terms; the contrastive ones share one temperature
        temperature = self.temperature
        self.global_contrastive_loss = InfoNCELoss(temperature)
        self.roi_contrastive_loss = SingleROIContrastiveLoss(temperature)
        self.negative_loss = Fixed5NegativeLoss(temperature)
        self.quality_loss = ROIQualityLoss()

        # Only the stage label is recorded here; requires_grad flags are not
        # touched until set_training_stage() is called.
        self.current_stage = 'warmup'

        # Xavier-initialise the heads created above
        self._init_weights()

        print(f"EnglishMedCLIP initialized:")
        print(f"  Vision dim: {self.vision_dim}")
        print(f"  Text dim: {self.text_dim}")
        print(f"  Projection dim: {self.projection_dim}")
        print(f"  Temperature: {self.temperature}")

    def _build_projection_head(self, input_dim, output_dim):
        """Build projection head ensuring float32 usage"""
        projection = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(output_dim, output_dim)
        )
        return projection.float()

    def _init_weights(self):
        """Initialize projection layer weights"""
        for module in [self.vision_projection, self.text_projection, self.roi_quality_classifier]:
            for layer in module:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.zeros_(layer.bias)

    def set_training_stage(self, stage_name):
        """Set training stage"""
        valid_stages = ['warmup', 'global_alignment', 'region_learning', 'fine_tuning']
        if stage_name not in valid_stages:
            raise ValueError(f"Invalid stage: {stage_name}. Must be one of {valid_stages}")

        self.current_stage = stage_name
        print(f"Set training stage to: {stage_name}")

        # Adjust model parameters based on stage
        self._adjust_parameters_for_stage(stage_name)

    def _adjust_parameters_for_stage(self, stage_name):
        """Adjust parameter trainability based on training stage"""
        if stage_name == 'warmup':
            # LoRA warmup stage: only train LoRA parameters and projection layers
            self._freeze_base_models()
            self._unfreeze_lora_and_projections()

        elif stage_name == 'global_alignment':
            # Global alignment stage: unfreeze part of vision encoder
            self._freeze_base_models()
            self._unfreeze_lora_and_projections()
            self.vision_encoder.unfreeze_last_n_blocks(2)

        elif stage_name == 'region_learning':
            # Region learning stage: unfreeze more layers
            self._freeze_base_models()
            self._unfreeze_lora_and_projections()
            self.vision_encoder.unfreeze_last_n_blocks(4)
            # Unfreeze ROI processor
            for param in self.roi_processor.parameters():
                param.requires_grad = True

        elif stage_name == 'fine_tuning':
            # End-to-end fine-tuning: unfreeze all parameters
            for param in self.parameters():
                param.requires_grad = True

    def _freeze_base_models(self):
        """Freeze base model parameters"""
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def _unfreeze_lora_and_projections(self):
        """Unfreeze LoRA and projection layer parameters"""
        # Unfreeze LoRA parameters (if exist)
        for name, param in self.named_parameters():
            if 'lora' in name.lower():
                param.requires_grad = True

        # Unfreeze projection layers
        for param in self.vision_projection.parameters():
            param.requires_grad = True
        for param in self.text_projection.parameters():
            param.requires_grad = True
        for param in self.roi_quality_classifier.parameters():
            param.requires_grad = True

    def encode_image(self, images, return_raw=False):
        """Encode images and return the features after ``ensure_float32``.

        The whole computation runs inside ``torch.cuda.amp.autocast()``. The
        input and each intermediate result are passed through
        ``ensure_float32``.

        Args:
            images: Image batch handed to ``self.vision_encoder``.
            return_raw: When true, skip ``self.vision_projection`` and return
                the encoder output.

        Returns:
            The vision-encoder features if ``return_raw`` is set, otherwise
            the projected features.
        """
        # NOTE(review): recent PyTorch releases deprecate
        # torch.cuda.amp.autocast in favour of torch.amp.autocast('cuda') --
        # confirm the minimum supported version before switching.
        with torch.cuda.amp.autocast():
            features = ensure_float32(self.vision_encoder(ensure_float32(images)))
            if not return_raw:
                features = ensure_float32(self.vision_projection(features))
            return features

    def encode_text(self, input_ids, attention_mask, return_raw=False):
        """Encode tokenised text and return features after ``ensure_float32``.

        The whole computation runs inside ``torch.cuda.amp.autocast()``. Both
        inputs are converted with ``.long()`` before reaching the encoder.

        Args:
            input_ids: Token ids; must support ``.long()``.
            attention_mask: Attention mask; must support ``.long()``.
            return_raw: When true, skip ``self.text_projection`` and return
                the encoder output.

        Returns:
            The text-encoder features if ``return_raw`` is set, otherwise the
            projected features.
        """
        # NOTE(review): recent PyTorch releases deprecate
        # torch.cuda.amp.autocast in favour of torch.amp.autocast('cuda') --
        # confirm the minimum supported version before switching.
        with torch.cuda.amp.autocast():
            token_ids = input_ids.long()
            token_mask = attention_mask.long()

            features = ensure_float32(self.text_encoder(token_ids, token_mask))
            if not return_raw:
                features = ensure_float32(self.text_projection(features))
            return features

    def forward(self, batch):
        """
        Forward propagation
        Args:
            batch: Data batch from EnglishMedicalDataset
        """
        try:
            # Encode main image
            global_img_feat = self.encode_image(batch['image'])  # [B, D]
            global_img_feat = ensure_float32(global_img_feat)

            # Encode ROI image
            roi_img_feat = self.encode_image(batch['roi'])  # [B, D]
            roi_img_feat = ensure_float32(roi_img_feat)

            # Get ROI weights
            roi_weights = batch['roi_weight']  # [B]
            roi_weights = ensure_float32(roi_weights)
            roi_types = batch['roi_type']  # [B] list of strings

            # ROI quality assessment
            quality_scores = self.roi_quality_classifier(roi_img_feat)  # [B, 4]
            quality_scores = ensure_float32(quality_scores)

            # Fusion of global and ROI features
            fused_img_feat, _, attention_weights = self.roi_processor(
                global_img_feat, roi_img_feat, roi_weights
            )
            fused_img_feat = ensure_float32(fused_img_feat)
            attention_weights = ensure_float32(attention_weights)

            # Encode text features
            # 1. Report text
            report_feat = self.encode_text(batch['report_ids'], batch['report_mask'])  # [B, D]
            report_feat = ensure_float32(report_feat)

            # 2. Region description text
            region_feat = self.encode_text(batch['region_ids'], batch['region_mask'])  # [B, D]