"""
English Medical MedCLIP Trainer
Implements four-stage training strategy: warmup, global_alignment, region_learning, fine_tuning
Fixed version: Resolves data type mismatch and Event object serialization issues, integrates optimized ROI quality assessment
"""

import os
import torch
import random
import numpy as np
import time
import datetime
from torch.utils.data import DataLoader, Subset, random_split
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
import warnings
from collections import defaultdict

def move_batch_to_device(batch, device):
    """Return a new dict with every tensor in ``batch`` normalized and moved to ``device``.

    Dtype normalization applied before the transfer:
      * float64 tensors are cast to float32;
      * int64 tensors are cast to float32 unless the key contains ``'ids'``
        (token-id tensors must stay integer).
    Non-tensor values are copied through unchanged.

    Args:
        batch: Mapping of field name to value.
        device: Destination passed to ``Tensor.to``.

    Returns:
        A plain ``dict`` with the same keys as ``batch``.
    """

    def _prepare(key, value):
        # Anything that is not a tensor is passed through untouched.
        if not isinstance(value, torch.Tensor):
            return value

        # The key is only inspected for int64 tensors (short-circuit).
        needs_float32 = value.dtype == torch.float64 or (
            value.dtype == torch.int64 and 'ids' not in key
        )
        if needs_float32:
            value = value.float()
        return value.to(device, non_blocking=True)

    return {key: _prepare(key, value) for key, value in batch.items()}

def safe_tensor_operation(tensor1, tensor2, operation='matmul'):
    """Apply a binary tensor operation after aligning device and dtype.

    ``tensor2`` is moved to ``tensor1``'s device when they differ. When the
    dtypes differ, both operands are cast to float32.

    Args:
        tensor1: Left operand.
        tensor2: Right operand.
        operation: One of ``'matmul'``, ``'add'`` or ``'mul'``.

    Returns:
        The result of the requested operation.

    Raises:
        ValueError: If ``operation`` is not one of the supported names.
    """
    lhs, rhs = tensor1, tensor2

    # Align devices: the left operand decides where the work happens.
    if rhs.device != lhs.device:
        rhs = rhs.to(lhs.device)

    # Align dtypes: any mismatch is resolved by casting both to float32.
    if lhs.dtype != rhs.dtype:
        lhs = lhs.to(torch.float32)
        rhs = rhs.to(torch.float32)

    # Ordered dispatch table, checked with == like the equivalent if/elif chain.
    handlers = (
        ('matmul', lambda a, b: torch.matmul(a, b)),
        ('add', lambda a, b: a + b),
        ('mul', lambda a, b: a * b),
    )
    for name, handler in handlers:
        if operation == name:
            return handler(lhs, rhs)

    raise ValueError(f"Unsupported operation: {operation}")

def compute_english_medical_recall(model, dataloader, device, K=(1, 5, 10), max_batches=None):
    """
    Compute image<->report retrieval recall@k over paired batches.

    The model is switched to eval mode and embeddings are gathered under
    ``torch.no_grad()``. A batch that raises, or whose embeddings contain NaN
    or Inf, is reported on stdout and skipped. The i-th collected image is
    treated as the match for the i-th collected report.

    Args:
        model: Object exposing ``eval()``, ``get_image_embeddings(image)`` and
            ``get_text_embeddings(report_ids, report_mask)``.
        dataloader: Iterable of dict batches with keys ``'image'``,
            ``'report_ids'`` and ``'report_mask'``.
        device: Device the batch tensors are moved to.
        K: Recall cut-offs. A cut-off larger than the number of collected
            pairs is left out of the result.
        max_batches: Stop after this many batches. A falsy value (``None`` or
            ``0``) means no limit.

    Returns:
        Dict with keys ``'img2txt_R@{k}'`` and ``'txt2img_R@{k}'`` mapped to
        recall as a float. Empty dict when no valid embeddings were collected
        or the similarity computation fails.
    """
    model.eval()
    image_chunks = []
    text_chunks = []

    with torch.no_grad():
        for batch_idx, raw_batch in enumerate(dataloader):
            if max_batches and batch_idx >= max_batches:
                break

            try:
                on_device = move_batch_to_device(raw_batch, device)

                # Image side first, then the report text side.
                image_vecs = model.get_image_embeddings(on_device['image'])
                text_vecs = model.get_text_embeddings(
                    on_device['report_ids'], on_device['report_mask']
                )

                # Work in float32 regardless of what the model produced.
                image_vecs = image_vecs.float()
                text_vecs = text_vecs.float()

                # Drop the whole batch if either side is numerically invalid,
                # so image/text rows stay paired.
                if torch.isnan(image_vecs).any() or torch.isnan(text_vecs).any():
                    print(f"[Warning] NaN detected in batch {batch_idx}, skipping...")
                    continue

                if torch.isinf(image_vecs).any() or torch.isinf(text_vecs).any():
                    print(f"[Warning] Inf detected in batch {batch_idx}, skipping...")
                    continue

                image_chunks.append(image_vecs)
                text_chunks.append(text_vecs)

            except Exception as e:
                # Best-effort evaluation: report and move on to the next batch.
                print(f"[Warning] Error in batch {batch_idx}: {e}")
                continue

    if not image_chunks:
        print("[Warning] No valid embeddings for recall calculation")
        return {}

    all_images = torch.cat(image_chunks, dim=0).float()
    all_texts = torch.cat(text_chunks, dim=0).float()

    # Truncate to the common row count if the two sides ever disagree.
    if all_images.shape[0] != all_texts.shape[0]:
        pair_count = min(all_images.shape[0], all_texts.shape[0])
        all_images = all_images[:pair_count]
        all_texts = all_texts[:pair_count]

    try:
        # Rows index images, columns index reports.
        sim = safe_tensor_operation(all_images, all_texts.T, 'matmul')
        targets = torch.arange(sim.shape[0], device=sim.device)

        def _recall_at(scores, k):
            # Fraction of rows whose own index appears among their top-k columns.
            top_indices = scores.topk(k, dim=1).indices
            hits = (top_indices == targets.unsqueeze(1)).any(dim=1)
            return hits.float().mean().item()

        recalls = {}
        for k in K:
            if not (k <= sim.shape[0]):
                continue
            recalls[f'img2txt_R@{k}'] = _recall_at(sim, k)
            recalls[f'txt2img_R@{k}'] = _recall_at(sim.t(), k)

        return recalls

    except Exception as e:
        print(f"[Warning] Error computing similarity matrix: {e}")
        return {}

def compute_roi_quality_accuracy_final(model, dataloader, device, max_batches=None):
    """Evaluate binary ROI-quality predictions against strategy-derived labels.

    Each sample's ``roi_type`` string is mapped to a 0/1 quality label
    (perfect-match and data-design strategies are 1, everything else,
    including unknown strategies, is 0). The model's ``quality_scores`` are
    reduced to a binary prediction via argmax over the first two columns and
    compared with those labels. A one-line summary is printed.

    Args:
        model: Callable model; ``model(batch)`` must return a mapping whose
            ``'features'`` entry may contain ``'quality_scores'``.
        dataloader: Iterable of dict batches containing ``'roi_type'``.
        device: Device the batch tensors are moved to.
        max_batches: Stop after this many batches. A falsy value (``None`` or
            ``0``) means no limit.

    Returns:
        ``{'roi_quality_accuracy': float}``; the accuracy is ``0.0`` when no
        sample could be scored.
    """
    model.eval()

    correct = 0
    total = 0  # samples that were actually scored
    strategy_stats = defaultdict(int)  # samples seen, per ROI strategy

    # Optimized strategy quality mapping (based on dataset characteristics)
    strategy_to_quality = {
        # Perfect matches (expected types)
        'normal_perfect_match': 1,
        'abnormal_perfect_match': 1,

        # Data design matches (opposite types, but this is normal)
        'normal_data_design': 1,
        'abnormal_data_design': 1,

        # Similarity matches
        'normal_similarity_match': 0,
        'abnormal_similarity_match': 0,

        # Fallback matches
        'normal_quality_fallback': 0,
        'abnormal_quality_fallback': 0,
        'normal_final_fallback': 0,
        'abnormal_final_fallback': 0,

        # Failure cases
        'no_roi_available': 0,
        'load_failed': 0,
        'no_roi_found': 0
    }

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if max_batches and i >= max_batches:
                break

            try:
                batch = move_batch_to_device(batch, device)
                roi_types = batch['roi_type']

                # Count ROI strategies for every sample seen, even if the
                # batch cannot be scored further down.
                for roi_type in roi_types:
                    strategy_stats[roi_type] += 1

                # Forward pass
                outputs = model(batch)
                features = outputs.get('features', {})

                if 'quality_scores' not in features:
                    continue

                quality_scores = features['quality_scores']

                # Unknown strategies are treated as low quality (label 0).
                true_labels = [strategy_to_quality.get(roi_type, 0) for roi_type in roi_types]

                # Prediction (binary classification)
                if quality_scores.shape[1] >= 2:
                    predicted = torch.argmax(quality_scores[:, :2], dim=1)
                else:
                    # Fewer than two score columns: predict class 0 for all.
                    predicted = torch.zeros(len(roi_types), dtype=torch.long, device=device)

                # Calculate accuracy
                true_labels_tensor = torch.tensor(true_labels, device=device, dtype=torch.long)
                correct += (predicted == true_labels_tensor).sum().item()
                total += len(roi_types)

            except Exception as e:
                # Best-effort evaluation: report and continue with the next batch.
                print(f"ROI evaluation batch {i} error: {e}")
                continue

    accuracy = correct / total if total > 0 else 0.0

    # High-quality ROI ratio is a property of the samples *seen*, so both the
    # numerator and the denominator come from strategy_stats. Dividing by
    # `total` (scored samples only) would inflate the ratio whenever a batch
    # was counted but then skipped or failed.
    seen = sum(strategy_stats.values())
    high_quality_count = sum(
        count
        for strategy, count in strategy_stats.items()
        if strategy_to_quality.get(strategy, 0) == 1
    )
    high_quality_ratio = high_quality_count / seen if seen > 0 else 0

    # Simplified statistical report (only show key information)
    print(f"ROI quality evaluation: {correct}/{total} = {accuracy:.4f} (high-quality ROI ratio: {high_quality_ratio:.1%})")

    return {'roi_quality_accuracy': accuracy}

class EnglishMedicalTrainer:
    """English Medical MedCLIP Trainer"""

    def __init__(self, model, model_manager, dataset, tokenizer, config, device, logger):
        self.model = model.to(device)
        self.model_manager = model_manager
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.config = config
        self.device = device
        self.logger = logger

        # Training configuration
        self.training_config = config.get('training', {})
        self.stages_config = self.training_config.get('stages', {})

        # Debug mode control
        self.debug_mode = config.get('debug_mode', False)

        # Create output directory
        os.makedirs(config['output_dir'], exist_ok=True)

        # Mixed precision training
        self.scaler = GradScaler()

        # Gradient accumulation and clipping
        self.grad_accumulation_steps = self.training_config.get('grad_accumulation_steps', 1)
        self.max_grad_norm = self.training_config.get('max_grad_norm', 1.0)

        # Validation configuration
        self.val_sample_n = self.training_config.get('val_sample_n', 1000)
        self.eval_batch_size = self.training_config.get('eval_batch_size', 32)

        # Data loader parameters
        self.num_workers = self.training_config.get('num_workers', 4)

        # Training history record (fixed: not using Event objects)
        self.training_history = []

        print(f"Trainer initialization completed:")
        print(f"  Dataset size: {len(self.dataset)}")
        print(f"  Validation samples: {self.val_sample_n}")
        print(f"  Gradient accumulation steps: {self.grad_accumulation_steps}")