#!/usr/bin/env python3
"""
Complete training script: mortality_24h_48h and los_prediction_48h classification tasks
Report metrics: AUROC, AUPRC, Accuracy, F1-score
Based on train_mortality_24h_48h.py with aligned training logic
"""

import os
import sys
# Resolve the directory that holds this script and its parent directory
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(BASE_DIR)  # parent of the script directory

# Extend the import search path. A 'datapress' folder located next to this
# script takes priority; otherwise the one under PROJECT_ROOT is used.
# Every entry is inserted at position 0, so the LAST inserted path
# ('datapress/Aligned') ends up first on sys.path.
_datapress_dir = os.path.join(BASE_DIR, 'datapress')
if not os.path.exists(_datapress_dir):
    _datapress_dir = os.path.join(PROJECT_ROOT, 'datapress')
for _search_path in (BASE_DIR, _datapress_dir, os.path.join(_datapress_dir, 'Aligned')):
    sys.path.insert(0, _search_path)
del _datapress_dir, _search_path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)
from tqdm import tqdm
from omegaconf import OmegaConf
import json
import numpy as np
from datetime import datetime
import time
import gc

# Note: MedicalDatasetWithLOS is imported conditionally to support offline-only releases
from datapress.Aligned.medical_dataset_wrapper import MedicalDatasetWrapper, collate_medical_batch
from utils.optim_factory import Muon

# Select the CUDA allocator strategy through PyTorch's configuration variable
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

# Names of the classification tasks handled by this script
TASK_NAMES = [
    'mortality_24h_48h',
    'los_prediction_48h',
]

# Module-level normalization statistics for time-series and static features.
# They start unset; train_multi_task() fills them from training batches and
# validate() reads them when normalizing its inputs.
_ts_normalization_means = _ts_normalization_stds = None
_static_normalization_means = _static_normalization_stds = None

def extract_logits_tensor(obj):
    """
    Find the first torch.Tensor inside a (possibly nested) model output.

    A tensor is returned as-is. For a dict, the well-known logits keys are
    tried first (in a fixed priority order) and then every value in
    iteration order. Lists and tuples are scanned front to back. Any other
    object, or a container holding no tensor, yields None.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, dict):
        # Preferred keys go first, followed by a full scan of all values
        preferred_keys = ('logits', 'output', 'pred', 'classification', 'logit')
        candidates = [obj[name] for name in preferred_keys if name in obj]
        candidates.extend(obj.values())
    elif isinstance(obj, (list, tuple)):
        candidates = obj
    else:
        return None

    for candidate in candidates:
        found = extract_logits_tensor(candidate)
        if found is not None:
            return found
    return None

def _zeroed_classification_metrics(threshold=0.5):
    """Return the all-zero metrics dict reported when metrics cannot be computed."""
    return {
        'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0,
        'auroc': 0.0, 'auprc': 0.0,
        'tn': 0, 'fp': 0, 'fn': 0, 'tp': 0,
        'num_positive': 0, 'num_negative': 0, 'num_total': 0,
        'optimal_threshold': float(threshold),
    }


def calculate_classification_metrics(y_true, y_pred_proba, task_name="task", threshold=None):
    """
    Calculate metrics for a binary classification task (supports one-hot labels).

    Args:
        y_true: Ground-truth labels as a torch.Tensor or array-like. A 2-D
            input with two columns is treated as one-hot and column 1 is
            used; anything else is flattened and cast to int.
        y_pred_proba: Predicted scores as a torch.Tensor or array-like. A 2-D
            input with two columns is treated as per-class scores, otherwise
            it is flattened. Values outside [0, 1] are treated as logits and
            converted with softmax (two-column case) or sigmoid.
        task_name: Name used only in the error message.
        threshold: Decision threshold applied as ``score > threshold``. When
            None, thresholds 0.05, 0.10, ..., 0.50 are searched for the best
            F1 (only if both classes are present); otherwise 0.5 is used.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1', 'auroc', 'auprc',
        'tn', 'fp', 'fn', 'tp', 'num_positive', 'num_negative', 'num_total'
        and 'optimal_threshold'. AUROC/AUPRC are 0.0 when only one class is
        present. Empty input, or a failure while computing the metrics,
        yields the same keys with zeroed values.
    """
    # Convert to numpy arrays
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred_proba, torch.Tensor):
        y_pred_proba = y_pred_proba.detach().cpu().numpy()
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)

    # Handle one-hot encoded labels
    if y_true.ndim == 2 and y_true.shape[1] == 2:
        y_true_binary = y_true[:, 1].astype(int)
    else:
        y_true_binary = y_true.flatten().astype(int)

    # Handle prediction probabilities (values outside [0, 1] are taken as logits)
    if y_pred_proba.ndim == 2 and y_pred_proba.shape[1] == 2:
        if y_pred_proba.size > 0 and (y_pred_proba.max() > 1.0 or y_pred_proba.min() < 0.0):
            y_pred_proba = torch.softmax(torch.from_numpy(y_pred_proba), dim=-1).numpy()
        y_pred_proba_pos = y_pred_proba[:, 1]
    else:
        y_pred_proba_pos = y_pred_proba.flatten()
        if y_pred_proba_pos.size > 0 and (y_pred_proba_pos.max() > 1.0 or y_pred_proba_pos.min() < 0.0):
            y_pred_proba_pos = torch.sigmoid(torch.from_numpy(y_pred_proba_pos)).numpy()

    # Align lengths BEFORE the threshold search so the search sees the same
    # samples as the final metrics (a mismatch previously disabled the search).
    min_len = min(len(y_true_binary), len(y_pred_proba_pos))
    if min_len == 0:
        # Nothing to evaluate: report zeroed metrics instead of failing on
        # reductions over an empty array.
        return _zeroed_classification_metrics(0.5 if threshold is None else threshold)
    y_true_binary = y_true_binary[:min_len]
    y_pred_proba_pos = y_pred_proba_pos[:min_len]

    has_both_classes = len(np.unique(y_true_binary)) > 1

    # Search for optimal threshold (if not specified)
    if threshold is None:
        threshold = 0.5
        if has_both_classes:
            # For class-imbalanced tasks, search between 0.05 and 0.5
            best_f1 = 0.0
            for candidate in np.arange(0.05, 0.51, 0.05):
                candidate_pred = (y_pred_proba_pos > candidate).astype(int)
                try:
                    f1 = f1_score(y_true_binary, candidate_pred, zero_division=0)
                except ValueError:
                    # Skip a candidate sklearn rejects; keep searching the rest
                    continue
                if f1 > best_f1:
                    best_f1 = f1
                    threshold = candidate

    y_pred_binary = (y_pred_proba_pos > threshold).astype(int)

    metrics = {}

    try:
        # Threshold-dependent metrics
        metrics['accuracy'] = float(accuracy_score(y_true_binary, y_pred_binary))
        metrics['precision'] = float(precision_score(y_true_binary, y_pred_binary, zero_division=0))
        metrics['recall'] = float(recall_score(y_true_binary, y_pred_binary, zero_division=0))
        metrics['f1'] = float(f1_score(y_true_binary, y_pred_binary, zero_division=0))

        # Ranking metrics are undefined with a single class; report 0.0 then
        if has_both_classes:
            metrics['auroc'] = float(roc_auc_score(y_true_binary, y_pred_proba_pos))
            metrics['auprc'] = float(average_precision_score(y_true_binary, y_pred_proba_pos))
        else:
            metrics['auroc'] = 0.0
            metrics['auprc'] = 0.0

        # Confusion matrix; fixing labels=[0, 1] keeps it 2x2 even when only
        # one class occurs, so the counts stay meaningful in that case.
        tn, fp, fn, tp = confusion_matrix(y_true_binary, y_pred_binary, labels=[0, 1]).ravel()
        metrics['tn'] = int(tn)
        metrics['fp'] = int(fp)
        metrics['fn'] = int(fn)
        metrics['tp'] = int(tp)

        metrics['num_positive'] = int(y_true_binary.sum())
        metrics['num_negative'] = int(len(y_true_binary) - y_true_binary.sum())
        metrics['num_total'] = int(len(y_true_binary))
        metrics['optimal_threshold'] = float(threshold)

    except Exception as e:
        print(f"Error calculating metrics for {task_name}: {e}")
        import traceback
        traceback.print_exc()
        # Same key set as the success path so callers can index uniformly
        metrics = _zeroed_classification_metrics(threshold)

    return metrics

def _standardize_last_dim(values, means, stds, device):
    """Standardize the last dimension of ``values`` with precomputed statistics.

    ``values`` is returned unchanged when either statistic is None or when the
    statistics hold fewer entries than ``values`` has features.
    """
    if means is None or stds is None:
        return values
    num_features = values.shape[-1]
    if means.shape[0] < num_features or stds.shape[0] < num_features:
        return values
    means = means.to(device)[:num_features]
    stds = stds.to(device)[:num_features]
    # 1-D statistics broadcast over every leading dimension
    return (values - means) / (stds + 1e-6)


def _normalize_time_series(padded_time_series, means, stds, device):
    """Normalize a 4-D value tensor or the value channel of a 5-D (value, mask) tensor."""
    if padded_time_series.dim() == 5:
        # Last axis holds [value, mask]; only the values are normalized
        ts_values = _standardize_last_dim(padded_time_series[..., 0], means, stds, device)
        ts_mask = padded_time_series[..., 1]
        return torch.stack([ts_values, ts_mask], dim=-1)
    if padded_time_series.dim() == 4:
        return _standardize_last_dim(padded_time_series, means, stds, device)
    return padded_time_series


def _normalize_static(static_data, means, stds, device):
    """Normalize static features and restore the tensor's original shape."""
    if static_data is None or static_data.numel() == 0:
        return static_data
    original_shape = static_data.shape
    if static_data.dim() == 3:
        # Only the first entry along dim 1 is normalized; it is then expanded
        # back over dim 1 (same behaviour as before the refactor).
        static_values = _standardize_last_dim(static_data[:, 0, :], means, stds, device)
        return static_values.unsqueeze(1).expand(original_shape)
    if static_data.dim() == 2:
        static_values = static_data
    else:
        static_values = static_data.reshape(static_data.shape[0], -1)
    static_values = _standardize_last_dim(static_values, means, stds, device)
    return static_values.reshape(original_shape)


def _modality_is_missing(tensor):
    """Return True when ``tensor`` is (near) all-zero or contains NaN/Inf."""
    if tensor.abs().sum().item() < 1e-6:
        return True
    return bool(torch.isnan(tensor).any() or torch.isinf(tensor).any())


def _resolve_align_loss(raw_loss, device, modality_missing):
    """Turn the model's alignment loss into a finite tensor (zero if unusable)."""
    zero = torch.tensor(0.0).to(device)
    if modality_missing:
        return zero
    if isinstance(raw_loss, torch.Tensor):
        loss = raw_loss
    elif isinstance(raw_loss, (int, float)):
        loss = torch.tensor(raw_loss).to(device)
    else:
        return zero
    if torch.isnan(loss) or torch.isinf(loss):
        return zero
    return loss


def _flatten_binary_target(target):
    """Reduce a label tensor ([B], [B, 1], [B, 2] one-hot, or [B, n_med, ...]) to 1-D."""
    if target.dim() >= 3:
        # [B, n_med, ...] format, take first med
        target = target[:, 0, :]
    if target.dim() == 2:
        num_cols = target.shape[1]
        if num_cols == 2:
            target = target.argmax(dim=-1).float()
        elif num_cols == 1:
            target = target.squeeze(-1)
        else:
            target = target[:, 0]
    elif target.dim() > 2:
        target = target.reshape(-1, target.shape[-1])
        if target.shape[-1] == 2:
            target = target.argmax(dim=-1).float()
        else:
            target = target.squeeze(-1)
    return target.reshape(-1)


def _flatten_positive_logits(logits):
    """Reduce a logits tensor to a 1-D tensor of positive-class logits."""
    if logits.dim() >= 3:
        logits = logits[:, 0]
    if logits.dim() == 2:
        last_dim = logits.shape[-1]
        if last_dim == 2:
            logits = logits[:, 1]  # Take positive class logit
        elif last_dim == 1:
            logits = logits.squeeze(-1)
        else:
            logits = logits[:, 0]
    return logits.reshape(-1)


@torch.no_grad()
def validate(model, val_loader, criterions, device, writer, global_step, task_names=TASK_NAMES):
    """Validation function (multi-task).

    Runs the model over ``val_loader`` in eval mode, accumulates the per-task
    loss and the alignment loss, and computes classification metrics per task
    with calculate_classification_metrics().

    Inputs are normalized with the module-level statistics
    (_ts_normalization_* / _static_normalization_*); when those are unset or
    too short for the feature dimension the data is passed through unchanged.

    Args:
        model: Callable as ``model(images, text_item, packed_ts, static_data,
            task='classification')``; a call without ``task`` is tried if the
            first one raises TypeError. Expected to return either a
            ``(dict, align_loss)`` tuple or a dict containing one entry per
            task (optionally 'clip_loss' / 'align_loss').
        val_loader: Iterable of 6-tuples ``(images, text_item,
            padded_time_series, time_series_lengths, static_data, labels)``.
            Batches whose ``labels`` is not a dict are skipped.
        criterions: Mapping task name -> loss callable ``(logits, target)``.
        device: Device the batch tensors are moved to.
        writer: Object with ``add_scalar`` (e.g. SummaryWriter), or None to
            disable scalar logging.
        global_step: Step value passed to ``writer.add_scalar``.
        task_names: Tasks to evaluate.

    Returns:
        Tuple ``(avg_losses, avg_align_loss, metrics_dict)``. ``avg_losses``
        maps each task to its mean loss over the batches in which that task
        was actually evaluated; ``metrics_dict`` maps each task to its
        metrics dict (zeroed basic metrics when no predictions were collected).
    """
    # The module-level normalization statistics are only read here, so no
    # ``global`` declaration is required.
    model.eval()
    total_losses = {task: 0.0 for task in task_names}
    # Per-task batch counters: a task skipped for a batch (missing label or
    # logits, shape mismatch) must not dilute that task's average loss.
    task_batches = {task: 0 for task in task_names}
    total_align_loss = 0.0
    total_batches = 0
    all_predictions = {task: [] for task in task_names}
    all_targets = {task: [] for task in task_names}

    val_start_time = time.time()

    for batch in val_loader:
        images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch

        # Labels for all tasks must come as a dict; check before doing any
        # device transfer or normalization work for a batch that is skipped.
        if not isinstance(labels, dict):
            continue

        images = images.to(device)
        static_data = static_data.to(device)
        padded_time_series = padded_time_series.to(device)
        time_series_lengths = time_series_lengths.to(device)
        for k in text_item:
            text_item[k] = text_item[k].to(device)

        # Apply the same normalization as used for the training data
        padded_time_series = _normalize_time_series(
            padded_time_series, _ts_normalization_means, _ts_normalization_stds, device
        )
        static_data = _normalize_static(
            static_data, _static_normalization_means, _static_normalization_stds, device
        )

        packed_ts = {
            'ts_data': padded_time_series,
            'seq_lengths': time_series_lengths
        }

        # Detect missing modalities: all zeros or NaN/Inf counts as missing
        has_image_val = True
        if images.numel() > 0 and _modality_is_missing(images):
            images = torch.zeros_like(images)
            has_image_val = False

        has_text_val = True
        if ('input_ids' in text_item and text_item['input_ids'].numel() > 0
                and _modality_is_missing(text_item['input_ids'])):
            for k in text_item:
                if isinstance(text_item[k], torch.Tensor):
                    text_item[k] = torch.zeros_like(text_item[k])
            has_text_val = False

        has_missing_modality_val = not has_image_val or not has_text_val

        # Forward pass. NOTE: a TypeError raised inside the model's forward
        # also triggers the fallback call without the ``task`` keyword.
        try:
            out = model(images, text_item, packed_ts, static_data, task='classification')
        except TypeError:
            out = model(images, text_item, packed_ts, static_data)

        # Process model output
        if isinstance(out, tuple):
            # (dict, align_loss) format
            outputs_dict, raw_align_loss = out
            if not isinstance(outputs_dict, dict):
                continue
        elif isinstance(out, dict):
            # Dict format - must contain logits for every task
            if not all(task in out for task in task_names):
                continue
            outputs_dict = {task: out[task].clone() for task in task_names}
            raw_align_loss = out.get('clip_loss', out.get('align_loss', 0.0))
        else:
            continue

        if len(outputs_dict) == 0:
            continue

        # The alignment loss is zeroed when a modality is missing or the
        # value is NaN/Inf.
        align_loss = _resolve_align_loss(raw_align_loss, device, has_missing_modality_val)

        # Process labels and predictions for each task
        for task_name in task_names:
            target = labels.get(task_name)
            if target is None or task_name not in outputs_dict:
                continue

            logits = outputs_dict[task_name]
            if not isinstance(logits, torch.Tensor):
                continue

            target = _flatten_binary_target(target.float().to(device))
            logits = _flatten_positive_logits(logits)

            # Ensure batch dimension matches: drop surplus logits, skip the
            # task for this batch when there are too few.
            batch_size = target.shape[0]
            if logits.shape[0] < batch_size:
                continue
            logits = logits[:batch_size]

            # Calculate loss
            loss_pred = criterions[task_name](logits, target)
            total_losses[task_name] += loss_pred.item()
            task_batches[task_name] += 1

            all_predictions[task_name].append(logits.detach())
            all_targets[task_name].append(target.detach())

        total_align_loss += align_loss.item()
        total_batches += 1

    # Calculate average losses and metrics
    avg_losses = {
        task: total_losses[task] / task_batches[task] if task_batches[task] > 0 else 0.0
        for task in task_names
    }
    avg_align_loss = total_align_loss / total_batches if total_batches > 0 else 0.0

    metrics_dict = {}
    for task_name in task_names:
        if len(all_predictions[task_name]) > 0:
            all_pred = torch.cat(all_predictions[task_name], dim=0)
            all_tgt = torch.cat(all_targets[task_name], dim=0)

            metrics_dict[task_name] = calculate_classification_metrics(all_tgt, all_pred, task_name, threshold=None)
        else:
            metrics_dict[task_name] = {
                'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0,
                'auroc': 0.0, 'auprc': 0.0
            }

    # Record validation metrics (if writer exists)
    if writer is not None:
        for task_name in task_names:
            metrics = metrics_dict[task_name]
            writer.add_scalar(f'val/{task_name}/loss', avg_losses[task_name], global_step)
            writer.add_scalar(f'val/{task_name}/accuracy', metrics['accuracy'], global_step)
            writer.add_scalar(f'val/{task_name}/f1', metrics['f1'], global_step)
            writer.add_scalar(f'val/{task_name}/auroc', metrics['auroc'], global_step)
            writer.add_scalar(f'val/{task_name}/auprc', metrics['auprc'], global_step)

        writer.add_scalar('val/align_loss', avg_align_loss, global_step)
        writer.add_scalar('val/time', time.time() - val_start_time, global_step)

    print(f"Validation - Align Loss: {avg_align_loss:.4f}")
    for task_name in task_names:
        metrics = metrics_dict[task_name]
        print(f"  {task_name} - Loss: {avg_losses[task_name]:.4f}, Acc: {metrics['accuracy']:.4f}, "
              f"F1: {metrics['f1']:.4f}, AUROC: {metrics['auroc']:.4f}, AUPRC: {metrics['auprc']:.4f}")

    return avg_losses, avg_align_loss, metrics_dict

def train_multi_task(
    model, train_loader, val_loader=None, num_epochs=10, lr=1e-4, 
    device='cuda', log_dir="runs/classify", experiment_name=None, 
    pos_weights=None, task_names=TASK_NAMES, task_weights=None
):
    """
    Train multi-task (mortality_24h_48h and los_prediction_48h)
    """
    # Declare global variables (at the beginning of the function)
    global _ts_normalization_means, _ts_normalization_stds, _static_normalization_means, _static_normalization_stds
    
    ts_all_values = []
    static_all_values = []
    max_samples_for_stats = 5000  # Use at most 5000 samples to compute statistics
    sample_count = 0
    
    for batch_idx, batch in enumerate(train_loader):
        if sample_count >= max_samples_for_stats:
            break
        try:
            _, _, padded_time_series, _, static_data, _ = batch
            # Extract time series data values
            if padded_time_series.dim() == 5:
                ts_values = padded_time_series[:, :, :, :, 0]  # [B, N_med, T, F]
            elif padded_time_series.dim() == 4:
                ts_values = padded_time_series
            else:
                continue
            
            # Flatten all dimensions except feature dimension
            ts_values_flat = ts_values.view(-1, ts_values.shape[-1])
            ts_all_values.append(ts_values_flat.cpu())
            
            # Extract static data values
            if static_data is not None and static_data.numel() > 0:
                if static_data.dim() == 3:
                    static_values = static_data[:, 0, :]  # [B, D_static]
                elif static_data.dim() == 2:
                    static_values = static_data
                else:
                    static_values = static_data.view(static_data.shape[0], -1)
                static_all_values.append(static_values.cpu())
            
            sample_count += ts_values.shape[0]
        except Exception:
            continue
    
    global _ts_normalization_means, _ts_normalization_stds, _static_normalization_means, _static_normalization_stds
    
    # Calculate time series feature statistics
    if len(ts_all_values) > 0:
        # Merge all batch data
        ts_all = torch.cat(ts_all_values, dim=0)  # [N_total, F]
        # Calculate mean and std for each feature (using more robust method)
        _ts_normalization_means = ts_all.mean(dim=0)  # [F]
        _ts_normalization_stds = ts_all.std(dim=0)  # [F]
        
        for f in range(ts_all.shape[1]):
            feature_data = ts_all[:, f]
            if _ts_normalization_stds[f] > 100.0:
                median = feature_data.median()
                mad = (feature_data - median).abs().median()
                # MAD to std: std ≈ 1.4826 * MAD
                robust_std = mad * 1.4826
                if robust_std > 1e-6:
                    _ts_normalization_means[f] = median
                    _ts_normalization_stds[f] = robust_std
        
        _ts_normalization_stds = torch.clamp(_ts_normalization_stds, min=1e-6)
    else:
        _ts_normalization_means = None
        _ts_normalization_stds = None
    
    if len(static_all_values) > 0:
        static_all = torch.cat(static_all_values, dim=0)  # [N_total, D_static]
        _static_normalization_means = static_all.mean(dim=0)  # [D_static]
        _static_normalization_stds = static_all.std(dim=0)  # [D_static]
        
        for f in range(static_all.shape[1]):
            feature_data = static_all[:, f]
            if _static_normalization_stds[f] > 100.0:
                median = feature_data.median()
                mad = (feature_data - median).abs().median()
                robust_std = mad * 1.4826
                if robust_std > 1e-6:
                    _static_normalization_means[f] = median
                    _static_normalization_stds[f] = robust_std
        
        _static_normalization_stds = torch.clamp(_static_normalization_stds, min=1e-6)
    else:
        _static_normalization_means = None
        _static_normalization_stds = None
    
    from torch.utils.data import DataLoader
    train_dataset = train_loader.dataset
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_loader.batch_size,
        shuffle=True,
        num_workers=train_loader.num_workers,
        pin_memory=train_loader.pin_memory,
        collate_fn=train_loader.collate_fn,
        persistent_workers=train_loader.persistent_workers if hasattr(train_loader, 'persistent_workers') else False,
        prefetch_factor=train_loader.prefetch_factor if hasattr(train_loader, 'prefetch_factor') else 2
    )
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
    
    model = model.to(device)
    effective_lr = lr * 0.3  
    
    param_groups = []
    
    ts_encoder_params = []
    other_params = []
    
    for name, param in model.named_parameters():
        if 'ts_encoder' in name:
            ts_encoder_params.append(param)
        else:
            other_params.append(param)
    
    if len(ts_encoder_params) > 0:
        param_groups.append({
            'params': ts_encoder_params,
            'lr': effective_lr * 0.1,  
            'name': 'ts_encoder'
        })
        print(f"[INFO] TimeSeriesEncoder parameters: {len(ts_encoder_params)} params, lr={effective_lr * 0.1:.2e}")
    
    if len(other_params) > 0:
        param_groups.append({
            'params': other_params,
            'lr': effective_lr,
            'name': 'other'
        })
        print(f"[INFO] Other parameters: {len(other_params)} params, lr={effective_lr:.2e}")
    
    if len(param_groups) > 0:
        optimizer = Muon(param_groups, lr=effective_lr, momentum=0.9, weight_decay=0.01)
    else:
        optimizer = Muon(model.parameters(), lr=effective_lr, momentum=0.9, weight_decay=0.01)
    
    print(f"[INFO] Using effective learning rate: {effective_lr:.2e} (reduced from {lr:.2e} for stability)")
    if len(ts_encoder_params) > 0:
        print(f"[INFO] TimeSeriesEncoder learning rate: {effective_lr * 0.1:.2e} (10x smaller for gradient stability)")
    
    if pos_weights is None:
        pos_weights = {}
        pos_weights['mortality_24h_48h'] = 0.02  
        pos_weights['los_prediction_48h'] = 4.0  
    
    criterions = {}
    if task_weights is None:
        task_weights = {}
        for task_name in task_names:
            pos_weight = pos_weights.get(task_name, 1.0)
            if pos_weight < 0.1:  
                task_weights[task_name] = 0.5  
            elif pos_weight > 3.0:  
                task_weights[task_name] = 1.5  
            else:
                task_weights[task_name] = 1.0  
    
    for task_name in task_names:
        pos_weight = pos_weights.get(task_name, 1.0)
        pos_weight_tensor = torch.tensor(pos_weight, device=device)
        criterions[task_name] = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor, reduction='mean')
        print(f"{task_name}: Using BCEWithLogitsLoss with pos_weight={pos_weight:.4f}, task_weight={task_weights.get(task_name, 1.0):.4f}, reduction='mean'")
    
    
    from torch.optim.lr_scheduler import CosineAnnealingLR
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr * 0.01)
    
    if experiment_name is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        experiment_name = f"multi_task_{timestamp}"
    
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, experiment_name)
    writer = SummaryWriter(log_dir=log_path)
    
    config_dict = {
        'tasks': task_names,
        'num_epochs': num_epochs,
        'learning_rate': lr,
        'device': str(device),
        'pos_weights': pos_weights,
        'model_params': sum(p.numel() for p in model.parameters()),
        'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad)
    }
    
    for key, value in config_dict.items():
        writer.add_text(f'config/{key}', str(value), 0)
    
    global_step = 0
    best_val_loss = {task: float('inf') for task in task_names}
    best_val_metrics = {task: {} for task in task_names}
    
    print(f"Starting training: {experiment_name}")
    print(f"Log directory: {log_path}")
    print(f"Tasks: {task_names}")
    
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        model.train()
        
        epoch_losses = {task: 0.0 for task in task_names}
        epoch_align_loss = 0.0
        epoch_batches = 0
        epoch_start_time = time.time()
        
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        
        for batch_idx, batch in enumerate(pbar):
            try:
                if batch is None:
                    continue
            except RuntimeError as e:
                if "DataLoader worker" in str(e) or "killed by signal" in str(e):
                    continue
                else:
                    raise
            
            images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
            
    
            images = images.to(device, non_blocking=True)
            static_data = static_data.to(device, non_blocking=True)
            padded_time_series = padded_time_series.to(device, non_blocking=True)
            time_series_lengths = time_series_lengths.to(device, non_blocking=True)
            for k in text_item:
                text_item[k] = text_item[k].to(device, non_blocking=True)
            

            if padded_time_series.dim() == 5:
                ts_values = padded_time_series[:, :, :, :, 0]  # [B, N_med, T, F]
                ts_mask = padded_time_series[:, :, :, :, 1]
                B, N_med, T, F = ts_values.shape
                ts_values_flat = ts_values.view(-1, F)  # [B*N_med*T, F]

                if _ts_normalization_means is not None and _ts_normalization_stds is not None:
                    ts_means = _ts_normalization_means.to(device)  # [F]
                    ts_stds = _ts_normalization_stds.to(device)  # [F]
                    if ts_means.shape[0] >= F and ts_stds.shape[0] >= F:
                        ts_means = ts_means[:F]
                        ts_stds = ts_stds[:F]
                        ts_values_flat = (ts_values_flat - ts_means.unsqueeze(0)) / (ts_stds.unsqueeze(0) + 1e-6)
                    else:
                        ts_means_batch = ts_values_flat.mean(dim=0, keepdim=True)  # [1, F]
                        ts_stds_batch = ts_values_flat.std(dim=0, keepdim=True) + 1e-6  # [1, F]
                        ts_values_flat = (ts_values_flat - ts_means_batch) / ts_stds_batch
                else:
                    ts_means_batch = ts_values_flat.mean(dim=0, keepdim=True)  # [1, F]
                    ts_stds_batch = ts_values_flat.std(dim=0, keepdim=True) + 1e-6  # [1, F]
                    ts_values_flat = (ts_values_flat - ts_means_batch) / ts_stds_batch
                ts_values = ts_values_flat.view(B, N_med, T, F)
                padded_time_series = torch.stack([ts_values, ts_mask], dim=-1)
            elif padded_time_series.dim() == 4:
                B, N_med, T, F = padded_time_series.shape
                ts_values_flat = padded_time_series.view(-1, F)  # [B*N_med*T, F]
                if _ts_normalization_means is not None and _ts_normalization_stds is not None:
                    ts_means = _ts_normalization_means.to(device)  # [F]
                    ts_stds = _ts_normalization_stds.to(device)  # [F]
                    if ts_means.shape[0] >= F and ts_stds.shape[0] >= F:
                        ts_means = ts_means[:F]
                        ts_stds = ts_stds[:F]
                        ts_values_flat = (ts_values_flat - ts_means.unsqueeze(0)) / (ts_stds.unsqueeze(0) + 1e-6)
                    else:
                        ts_means_batch = ts_values_flat.mean(dim=0, keepdim=True)  # [1, F]
                        ts_stds_batch = ts_values_flat.std(dim=0, keepdim=True) + 1e-6  # [1, F]
                        ts_values_flat = (ts_values_flat - ts_means_batch) / ts_stds_batch
                else:
                    ts_means_batch = ts_values_flat.mean(dim=0, keepdim=True)  # [1, F]
                    ts_stds_batch = ts_values_flat.std(dim=0, keepdim=True) + 1e-6  # [1, F]
                    ts_values_flat = (ts_values_flat - ts_means_batch) / ts_stds_batch
                padded_time_series = ts_values_flat.view(B, N_med, T, F)
            
            if static_data is not None and static_data.numel() > 0:
                original_static_shape = static_data.shape
                if static_data.dim() == 3:
                    static_values = static_data[:, 0, :]  # [B, D_static]
                    static_reshaped = True
                elif static_data.dim() == 2:
                    static_values = static_data  # [B, D_static]
                    static_reshaped = False
                else:
                    static_values = static_data.view(static_data.shape[0], -1)
                    static_reshaped = False
                
                D_static = static_values.shape[-1]
                if _static_normalization_means is not None and _static_normalization_stds is not None:
                    static_means = _static_normalization_means.to(device)  # [D_static]
                    static_stds = _static_normalization_stds.to(device)  # [D_static]
                    if static_means.shape[0] >= D_static and static_stds.shape[0] >= D_static:
                        static_means = static_means[:D_static]
                        static_stds = static_stds[:D_static]
                        static_values = (static_values - static_means.unsqueeze(0)) / (static_stds.unsqueeze(0) + 1e-6)
                    else:
                        static_means_batch = static_values.mean(dim=0, keepdim=True)  # [1, D_static]
                        static_stds_batch = static_values.std(dim=0, keepdim=True) + 1e-6  # [1, D_static]
                        static_values = (static_values - static_means_batch) / static_stds_batch
                else:
                    static_means_batch = static_values.mean(dim=0, keepdim=True)  # [1, D_static]
                    static_stds_batch = static_values.std(dim=0, keepdim=True) + 1e-6  # [1, D_static]
                    static_values = (static_values - static_means_batch) / static_stds_batch
                
                if static_reshaped:
                    static_data = static_values.unsqueeze(1).expand(original_static_shape)
                else:
                    static_data = static_values.view(original_static_shape)
            
            packed_ts = {
                'ts_data': padded_time_series,
                'seq_lengths': time_series_lengths
            }
            
            
            use_text_in_model = True
            use_image_in_model = True
            if hasattr(model, 'base_model') and hasattr(model.base_model, 'ff') and hasattr(model.base_model.ff, 'use_text'):
                use_text_in_model = model.base_model.ff.use_text
                use_image_in_model = model.base_model.ff.use_image
            
            # Check images: if all zeros or contains NaN/Inf, consider missing
            has_image = True
            if not use_image_in_model:
                has_image = False
            elif images.numel() > 0:
                image_sum = images.abs().sum()
                has_nan_inf = torch.isnan(images).any() or torch.isinf(images).any()
                if has_nan_inf:
                    images = torch.randn_like(images) * 0.01
                    has_image = False
                elif image_sum.item() < 1e-6:
                    has_image = False
                    pass
            
            # Check text: if input_ids all zeros or contains NaN/Inf, consider missing
            has_text = True
            if 'input_ids' in text_item and text_item['input_ids'].numel() > 0:
                text_sum = text_item['input_ids'].abs().sum()
                has_nan_inf = torch.isnan(text_item['input_ids']).any() or torch.isinf(text_item['input_ids']).any()
                if has_nan_inf:
                    for k in text_item:
                        if isinstance(text_item[k], torch.Tensor):
                            if torch.isnan(text_item[k]).any() or torch.isinf(text_item[k]).any():
                                text_item[k] = torch.zeros_like(text_item[k])
                    has_text = False
                elif text_sum.item() < 1e-6:
                    has_text = False
                    pass
            
            # Get labels for all tasks
            if not isinstance(labels, dict):
                continue
            
            optimizer.zero_grad(set_to_none=True)
            
            has_missing_modality = not has_image or not has_text
            
            align_loss = torch.tensor(0.0).to(device)
            outputs_dict = {}  # Store logits for each task
            
            image_encoder_requires_grad = {}
            if not has_image and hasattr(model, 'base_model'):
                for name, param in model.base_model.named_parameters():
                    if 'image_encoder' in name:
                        image_encoder_requires_grad[name] = param.requires_grad
                        param.requires_grad = False
            
            text_encoder_requires_grad = {}
            if not has_text and hasattr(model, 'base_model'):
                for name, param in model.base_model.named_parameters():
                    if 'text_encoder' in name:
                        text_encoder_requires_grad[name] = param.requires_grad
                        param.requires_grad = False

            try:
                out = model(images, text_item, packed_ts, static_data, task='classification')
            except TypeError:
                out = model(images, text_item, packed_ts, static_data)
            if isinstance(out, tuple):
                outputs_dict, align_loss = out
                if not isinstance(outputs_dict, dict):
                    continue
                align_loss = align_loss if isinstance(align_loss, torch.Tensor) else torch.tensor(align_loss).to(device) if isinstance(align_loss, (int, float)) else torch.tensor(0.0).to(device)
                
                if has_missing_modality:
                    align_loss = torch.tensor(0.0).to(device).detach()
                    if batch_idx == 0 and epoch == 0:
                        print(f"[INFO] Skipping align_loss due to missing modalities")
                elif torch.isnan(align_loss) or torch.isinf(align_loss):
                    align_loss = torch.tensor(0.0).to(device).detach()
                    if batch_idx == 0 and epoch == 0:
                        print(f"[WARNING] align_loss is NaN/Inf, setting to 0")
                
                has_nan_in_outputs = False
                for task_name in task_names:
                    if task_name in outputs_dict:
                        logits = outputs_dict[task_name]
                        if isinstance(logits, torch.Tensor):
                            if torch.isnan(logits).any() or torch.isinf(logits).any():
                                print(f"[ERROR] {task_name} logits contain NaN/Inf in model output, skipping batch")
                                has_nan_in_outputs = True
                                break
                
                if has_nan_in_outputs:
                    continue
            elif isinstance(out, dict):
                if all(task in out for task in task_names):
                    outputs_dict = {task: out[task] for task in task_names}
                    
                    has_nan_in_outputs = False
                    for task_name in task_names:
                        logits = outputs_dict[task_name]
                        if isinstance(logits, torch.Tensor):
                            if torch.isnan(logits).any() or torch.isinf(logits).any():
                                has_nan_in_outputs = True
                                break
                    
                    if has_nan_in_outputs:
                        continue
                else:
                    logits = extract_logits_tensor(out)
                    if logits is not None:
                        if isinstance(logits, torch.Tensor):
                            if torch.isnan(logits).any() or torch.isinf(logits).any():
                                continue
                        for task in task_names:
                            outputs_dict[task] = logits
                    else:
                        continue
                align_loss = out.get('clip_loss', out.get('align_loss', torch.tensor(0.0).to(device)))
                if not isinstance(align_loss, torch.Tensor):
                    align_loss = torch.tensor(align_loss).to(device) if isinstance(align_loss, (int, float)) else torch.tensor(0.0).to(device)
                
                if has_missing_modality:
                    align_loss = torch.tensor(0.0).to(device).detach()
                    if batch_idx == 0 and epoch == 0:
                        print(f"[INFO] Skipping align_loss due to missing modalities")
                elif torch.isnan(align_loss) or torch.isinf(align_loss):
                    align_loss = torch.tensor(0.0).to(device).detach()
                    if batch_idx == 0 and epoch == 0:
                        print(f"[WARNING] align_loss is NaN/Inf, setting to 0")
            else:
                if isinstance(out, torch.Tensor):
                    for task in task_names:
                        outputs_dict[task] = out
                else:
                    continue
            
            if len(outputs_dict) == 0:
                continue
            
            if image_encoder_requires_grad:
                for name, param in model.base_model.named_parameters():
                    if name in image_encoder_requires_grad:
                        param.requires_grad = image_encoder_requires_grad[name]
            
            if text_encoder_requires_grad:
                for name, param in model.base_model.named_parameters():
                    if name in text_encoder_requires_grad:
                        param.requires_grad = text_encoder_requires_grad[name]
            
            if torch.isnan(align_loss) or torch.isinf(align_loss):
                align_loss = torch.tensor(0.0).to(device).detach()
                if batch_idx == 0 and epoch == 0:
                    print(f"[WARNING] align_loss is NaN/Inf, setting to 0")
            
            if has_missing_modality:
                pass
            elif align_loss.item() > 100:
                align_loss = align_loss.clamp(max=10.0)  
            
            if has_missing_modality or align_loss.requires_grad == False:
                align_loss_weight = 0.0
            else:
                align_loss_weight = 0.05  
            
            total_loss = align_loss * align_loss_weight
            task_losses = {}
            
            for task_name in task_names:
                target = labels.get(task_name)
                if target is None:
                    continue
                
                target = target.float().to(device, non_blocking=True)
                
                if target.dim() >= 3:
                    target = target[:, 0, :]

                if target.dim() == 2:
                    if target.shape[1] == 2:
                        target = target.argmax(dim=-1).float()
                    elif target.shape[1] == 1:
                        target = target.squeeze(-1)
                    else:
                        target = target[:, 0]
                elif target.dim() > 2:
                    target = target.view(-1, target.shape[-1])
                    if target.shape[-1] == 2:
                        target = target.argmax(dim=-1).float()
                    else:
                        target = target.squeeze(-1)
                
                target = target.view(-1)
                
                if task_name not in outputs_dict:
                    continue
                
                task_logits = outputs_dict[task_name]
                if not isinstance(task_logits, torch.Tensor):
                    continue
                if task_logits.dim() >= 3:
                    task_logits = task_logits[:, 0]
                if task_logits.dim() == 1:
                    pass
                elif task_logits.dim() == 2:
                    if task_logits.shape[-1] == 2:
                        task_logits = task_logits[:, 1]
                    elif task_logits.shape[-1] == 1:
                        task_logits = task_logits.squeeze(-1)
                    else:
                        task_logits = task_logits[:, 0]
                
                task_logits = task_logits.view(-1)
                
                batch_size = target.shape[0]
                if task_logits.shape[0] != batch_size:
                    if task_logits.shape[0] > batch_size:
                        task_logits = task_logits[:batch_size]
                    else:
                        continue
                
                loss = criterions[task_name](task_logits, target)
                
                
                task_weight = task_weights.get(task_name, 1.0)
                weighted_loss = loss * task_weight
                

                
                task_losses[task_name] = loss  
                total_loss = total_loss + weighted_loss  
            
            if len(task_losses) == 0:
                continue
            
            total_loss.backward()
            
            has_nan_inf_grad = False
            for name, param in model.named_parameters():
                if param.grad is not None:
                    if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                        has_nan_inf_grad = True
                        break
            
            if torch.isnan(total_loss) or torch.isinf(total_loss):
                optimizer.zero_grad(set_to_none=True)
                continue
            
            if has_nan_inf_grad:
                optimizer.zero_grad(set_to_none=True)
                continue
            
            if not has_image:
                for name, param in model.named_parameters():
                    if 'image_encoder' in name and param.grad is not None:
                        param.grad.zero_()
            
            if not has_text or not use_text_in_model:
                for name, param in model.named_parameters():
                    if 'text_encoder' in name and param.grad is not None:
                        param.grad.zero_()
            
            try:
                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2
                total_norm = total_norm ** (1. / 2)
            except Exception:
                pass
            
            optimizer.step()
            
            if batch_idx > 0 and batch_idx % 10 == 0:
                torch.cuda.empty_cache()
                gc.collect()
            
            for task_name, loss in task_losses.items():
                epoch_losses[task_name] += loss.item()
            epoch_align_loss += align_loss.item()
            epoch_batches += 1
            
            pbar.set_postfix({
                **{f'{t}_Loss': f"{task_losses.get(t, 0):.4f}" for t in task_names},
                'Align': f"{align_loss.item():.4f}",
                'Step': global_step
            })
            
            if (batch_idx + 1) % 10 == 0:
                for task_name, loss in task_losses.items():
                    writer.add_scalar(f'train/step_{task_name}_loss', loss.item(), global_step)
                writer.add_scalar('train/step_align_loss', align_loss.item(), global_step)
                writer.add_scalar('train/step_total_loss', total_loss.item(), global_step)
                
                current_lr = optimizer.param_groups[0]['lr']
                writer.add_scalar('train/learning_rate', current_lr, global_step)
            
            global_step += 1
        
        avg_epoch_losses = {task: epoch_losses[task] / epoch_batches if epoch_batches > 0 else 0.0 
                           for task in task_names}
        avg_epoch_align_loss = epoch_align_loss / epoch_batches if epoch_batches > 0 else 0.0
        epoch_time = time.time() - epoch_start_time
        
        for task_name in task_names:
            writer.add_scalar(f'train/epoch_{task_name}_loss', avg_epoch_losses[task_name], epoch)
        writer.add_scalar('train/epoch_align_loss', avg_epoch_align_loss, epoch)
        writer.add_scalar('train/epoch_time', epoch_time, epoch)
        
        print(f"Epoch {epoch+1} - Avg Losses: {avg_epoch_losses}, Avg Align Loss: {avg_epoch_align_loss:.4f}, Time: {epoch_time:.2f}s")

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        writer.add_scalar('train/learning_rate_epoch', current_lr, epoch)
        
        if val_loader is not None:
            val_losses, val_align_loss, val_metrics = validate(model, val_loader, criterions, device, writer, global_step, task_names)
            
            total_val_loss = sum(val_losses.values())
            total_best_val_loss = sum(best_val_loss.values())
            
            if total_val_loss < total_best_val_loss:
                for task_name in task_names:
                    best_val_loss[task_name] = val_losses[task_name]
                    best_val_metrics[task_name] = val_metrics[task_name]
                
                model_save_path = os.path.join(log_path, 'best_model.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_losses': val_losses,
                    'val_metrics': val_metrics,
                    'config': config_dict
                }, model_save_path)
                print(f"New best model saved! Total Val Loss: {total_val_loss:.4f}")
            
            if epoch % 5 == 0 or epoch == num_epochs - 1:
                metrics_file = os.path.join(log_path, f'metrics_epoch_{epoch}.json')
                with open(metrics_file, 'w') as f:
                    json.dump({
                        'epoch': epoch,
                        'val_metrics': val_metrics,
                        'best_val_metrics': best_val_metrics,
                        'timestamp': datetime.now().isoformat()
                    }, f, indent=2)
    
    final_model_path = os.path.join(log_path, 'final_model.pth')
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config_dict,
        'best_val_metrics': best_val_metrics
    }, final_model_path)
    
    print(f"\nTraining completed!")
    print(f"Best validation losses: {best_val_loss}")
    print(f"Best validation metrics:")
    for task_name in task_names:
        print(f"  {task_name}:")
        for key, value in best_val_metrics[task_name].items():
            if isinstance(value, (int, float)):
                print(f"    {key}: {value:.4f}")
    
    writer.close()

    final_report = {
        'experiment_name': experiment_name,
        'tasks': task_names,
        'best_val_losses': {k: float(v) for k, v in best_val_loss.items()},
        'best_val_metrics': {k: {mk: float(mv) if isinstance(mv, (int, float)) else mv 
                                 for mk, mv in v.items()} 
                            for k, v in best_val_metrics.items()},
        'config': config_dict,
        'timestamp': datetime.now().isoformat()
    }
    
    report_path = os.path.join(log_path, 'final_report.json')
    with open(report_path, 'w') as f:
        json.dump(final_report, f, indent=2)
    
    print(f"\nFinal report saved to: {report_path}")
    print(f"Logs saved to: {log_path}")
    
    return best_val_metrics, log_path

if __name__ == "__main__":
    import argparse
    from model.factory import create_multimodal_model
    
    # Command-line interface. Options are declared as an ordered table of
    # (flag, add_argument keyword arguments) and registered in table order.
    parser = argparse.ArgumentParser(description='Train multi-task (mortality_24h_48h and los_prediction_48h)')

    # The default log directory lives under this script's own directory.
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    default_log_dir = os.path.join(BASE_DIR, "runs/classify")

    cli_options = [
        ('--gpu', {'type': str, 'default': '0', 'help': 'GPU ID'}),
        ('--epochs', {'type': int, 'default': 20, 'help': 'Number of epochs'}),
        ('--lr', {'type': float, 'default': 5e-4, 'help': 'Learning rate'}),
        ('--batch_size', {'type': int, 'default': 256, 'help': 'Batch size'}),
        ('--num_workers', {'type': int, 'default': 1,
                           'help': 'Number of workers (set to 0 to disable multiprocessing)'}),
        ('--persistent_workers', {'action': 'store_true', 'default': False,
                                  'help': 'Use persistent workers to avoid worker restart overhead'}),
        ('--prefetch_factor', {'type': int, 'default': 2,
                               'help': 'Number of batches prefetched by each worker'}),
        ('--experiment_name', {'type': str, 'default': None, 'help': 'Experiment name'}),
        ('--max_samples', {'type': int, 'default': 80000,
                           'help': 'Maximum number of samples for training (0 for full dataset)'}),
        ('--max_val_samples', {'type': int, 'default': 10000,
                               'help': 'Maximum number of samples for validation (0 for full validation set)'}),
        ('--max_test_samples', {'type': int, 'default': 10000,
                                'help': 'Maximum number of samples for testing (0 for full test set)'}),
        ('--config', {'type': str, 'default': None, 'help': 'Path to config file'}),
        ('--log_dir', {'type': str, 'default': default_log_dir, 'help': 'TensorBoard log directory'}),
    ]
    for option_flag, option_spec in cli_options:
        parser.add_argument(option_flag, **option_spec)

    args = parser.parse_args()
    
    # Export the requested GPU ID before picking the device: CUDA when torch
    # reports it available, CPU otherwise.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    banner_rule = "=" * 80
    print(banner_rule)
    print("Training Multi-Task Classification (Mortality 24h-48h and LOS Prediction 48h)")
    print(banner_rule)

    # Config resolution:
    #   * no --config given -> prefer the offline config when both the bundled
    #     sample data and the offline config file exist, else the default config;
    #   * --config given    -> absolute paths are used verbatim, relative paths
    #     are resolved against BASE_DIR.
    if not args.config:
        offline_data_path = os.path.join(BASE_DIR, "data_dir/sample_data/sample_data.pkl")
        offline_config = os.path.join(BASE_DIR, "exp/mimic_data/exp_mortality_24h48h_los_offline.yaml")
        default_config = os.path.join(BASE_DIR, "exp/mimic_data/exp_mortality_24h48h_los.yaml")

        offline_available = os.path.exists(offline_data_path) and os.path.exists(offline_config)
        if offline_available:
            config_path = offline_config
            print("Offline data detected, using offline configuration")
        else:
            config_path = default_config
    elif os.path.isabs(args.config):
        config_path = args.config
    else:
        config_path = os.path.join(BASE_DIR, args.config)

    # Fail early with a clear message rather than inside the config loader.
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    opt = OmegaConf.load(config_path)
    
    print("\nCreating dataset...")

    # A truthy 'offline_data_path' in the train/val config selects the offline
    # sample dataset; otherwise the regular (database) dataset is constructed.
    offline_data_path = opt.data.train_val.get('offline_data_path', None)
    if not offline_data_path:
        # The regular dataset module may be absent in offline-only releases, so
        # it is imported lazily and a missing module is reported with guidance.
        try:
            from datapress.Aligned.medical_dataset_with_los import MedicalDatasetWithLOS
        except ImportError:
            raise ImportError(
                "MedicalDatasetWithLOS not available. This appears to be an offline-only release. "
                "Please use offline data by setting 'offline_data_path' in the configuration file."
            )
        med_dataset = MedicalDatasetWithLOS(
            **opt.data.train_val,
            **opt.data.shared_param,
            los_cache_file="data_dir/cache/los_cache.pkl",
            threshold_hours=48
        )
        print(f"Medical dataset size: {len(med_dataset)}")
        wrapped_dataset = MedicalDatasetWrapper(med_dataset)
        print(f"Wrapped dataset size: {len(wrapped_dataset)}")
    else:
        # Relative offline paths are interpreted relative to BASE_DIR.
        if not os.path.isabs(offline_data_path):
            offline_data_path = os.path.join(BASE_DIR, offline_data_path)

        if not os.path.exists(offline_data_path):
            raise FileNotFoundError(f"Offline data file not found: {offline_data_path}")

        print(f"Using offline sample data: {offline_data_path}")
        from datapress.Aligned.offline_sample_dataset import OfflineSampleDataset
        med_dataset = OfflineSampleDataset(offline_data_path)
        wrapped_dataset = MedicalDatasetWrapper(med_dataset)
        print(f"Offline dataset size: {len(wrapped_dataset)}")
    
    # Seeded split (same as the baseline, for fair comparison): 80% train, and
    # the remainder halved into validation and test (test takes any odd sample).
    total_size = len(wrapped_dataset)
    train_size = int(0.8 * total_size)
    remaining_size = total_size - train_size
    val_size = int(0.5 * remaining_size)
    test_size = remaining_size - val_size

    split_rng = torch.Generator().manual_seed(42)
    split_lengths = [train_size, val_size, test_size]
    train_full, val_full, test_full = random_split(wrapped_dataset, split_lengths, generator=split_rng)

    print(f"\nDataset splits: Train={len(train_full)}, Val={len(val_full)}, Test={len(test_full)}")

    # Optional caps (same subset logic as the baseline): a positive limit that
    # is smaller than the split keeps only its first N entries; 0 or an
    # over-large limit keeps the split untouched.
    train_dataset = train_full
    if 0 < args.max_samples < len(train_full):
        train_indices = list(range(args.max_samples))
        train_dataset = Subset(train_full, train_indices)
        print(f"Using subset of training dataset: {len(train_dataset)} samples")

    val_dataset = val_full
    if 0 < args.max_val_samples < len(val_full):
        val_indices = list(range(args.max_val_samples))
        val_dataset = Subset(val_full, val_indices)
        print(f"Using subset of validation dataset: {len(val_dataset)} samples")

    test_dataset = test_full
    if 0 < args.max_test_samples < len(test_full):
        test_indices = list(range(args.max_test_samples))
        test_dataset = Subset(test_full, test_indices)
        print(f"Using subset of test dataset: {len(test_dataset)} samples")
    
    # DataLoader setup shared by the train/val/test loaders (same as the
    # baseline for fair comparison); only `shuffle` differs between them.
    loader_kwargs = {
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
        'pin_memory': True,
        'collate_fn': collate_medical_batch,
    }
    # --persistent_workers and --prefetch_factor used to be parsed but never
    # forwarded, so the flags had no effect. Both options are only valid with
    # worker processes, hence they are passed only when num_workers > 0. With
    # the CLI defaults (prefetch_factor=2, persistent_workers=False) this
    # matches DataLoader's own defaults, so default runs behave as before.
    if args.num_workers > 0:
        loader_kwargs['persistent_workers'] = args.persistent_workers
        loader_kwargs['prefetch_factor'] = args.prefetch_factor

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print(f"\nDataLoader sizes: Train={len(train_loader)} batches, Val={len(val_loader)} batches, Test={len(test_loader)} batches")
    
    print("\nInitializing model...")
    # Settings handed to the base-model factory. Text, image and multimodal
    # inputs are all switched off in this configuration.
    model_config = dict(
        image_dim=128,
        text_dim=128,
        ts_dim=64,
        static_dim=10,
        shared_dim=128,
        out_dim=128,
        out_len=1,
        text_encoder_dim=768,
        enable_ae=True,
        clip_temperature=0.07,
        use_text=False,
        use_image=False,
        use_multimodal=False,
    )

    base_model = create_multimodal_model(model_config)

    # Wrap the base model with one output per task listed in TASK_NAMES.
    from model.multi_task_wrapper import MultiTaskWrapper
    model = MultiTaskWrapper(
        base_model=base_model,
        task_names=TASK_NAMES,
        shared_dim=128,
        out_dim=1
    )

    # Count all parameters and the trainable subset in a single sweep.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        param_count = param.numel()
        total_params += param_count
        if param.requires_grad:
            trainable_params += param_count
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    print("\nCalculating pos_weights from training data...")
    # Class balance is estimated from roughly the first `max_samples` labels of
    # each task. All tasks are tallied in ONE pass over train_loader: iterating
    # the shuffled loader once per task repeats worker start-up and data loading
    # and makes each task's estimate come from a different random sample.
    max_samples = 1000
    label_tallies = {task_name: {'pos': 0, 'neg': 0, 'total': 0} for task_name in TASK_NAMES}
    skipped_label_reads = 0

    try:
        for batch in train_loader:
            # Tasks that already reached the sample cap are no longer counted.
            pending_tasks = [name for name in TASK_NAMES if label_tallies[name]['total'] < max_samples]
            if not pending_tasks:
                break

            try:
                _, _, _, _, _, labels = batch
            except (TypeError, ValueError):
                # Batch is not the expected 6-item collated tuple; best effort: skip it.
                skipped_label_reads += 1
                continue
            if not isinstance(labels, dict):
                continue

            for task_name in pending_tasks:
                if task_name not in labels:
                    continue
                try:
                    target = labels[task_name].float()
                    # Use same label processing as baseline
                    if target.dim() == 2 and target.shape[1] == 2:
                        target = target[:, 1]  # Take positive class directly
                    target = target.view(-1)
                    batch_pos = int((target == 1).sum().item())
                    batch_neg = int((target == 0).sum().item())
                    batch_total = len(target)
                except Exception:
                    # Deliberately best-effort: an unusable label tensor must not
                    # abort the run, but it is counted and reported below.
                    skipped_label_reads += 1
                    continue
                # Commit only after the whole batch was processed successfully,
                # so a failure never leaves a task with partial counts.
                tally = label_tallies[task_name]
                tally['pos'] += batch_pos
                tally['neg'] += batch_neg
                tally['total'] += batch_total
    except RuntimeError as e:
        # A crashed DataLoader worker ends sampling early; keep what was gathered.
        if "DataLoader worker" in str(e) or "killed by signal" in str(e):
            print(f"[WARNING] DataLoader worker failed while sampling labels ({e}); using the counts gathered so far")
        else:
            raise

    if skipped_label_reads:
        print(f"[WARNING] Skipped {skipped_label_reads} unreadable label batch(es) while estimating pos_weights")

    pos_weights = {}
    for task_name in TASK_NAMES:
        tally = label_tallies[task_name]
        pos_count, neg_count, total_count = tally['pos'], tally['neg'], tally['total']

        # Use same pos_weight calculation as baseline
        if total_count > 0 and pos_count > 0:
            pos_weight = neg_count / pos_count
            pos_weights[task_name] = pos_weight
            pos_ratio = pos_count / total_count
            print(f"  {task_name}: Pos={pos_count} ({pos_ratio*100:.2f}%), Neg={neg_count}, pos_weight={pos_weight:.4f}")
        else:
            # No positives observed (or no labels at all): fall back to fixed defaults.
            # NOTE(review): pos_weight is neg/pos above, so a fallback of 0.02
            # down-weights positives; confirm against the baseline that this is
            # intended for mortality_24h_48h.
            if task_name == 'mortality_24h_48h':
                pos_weights[task_name] = 0.02
            else:
                pos_weights[task_name] = 4.0
            print(f"  {task_name}: Using default pos_weight={pos_weights[task_name]:.4f}")
    
    print("\nStarting training...")
    # Per-task loss weight chosen from the task's pos_weight:
    #   pos_weight < 0.1 -> 0.5, pos_weight > 3.0 -> 1.5, anything else -> 1.0.
    task_weights = {}
    for task_name in TASK_NAMES:
        pos_weight = pos_weights.get(task_name, 1.0)
        task_weights[task_name] = 0.5 if pos_weight < 0.1 else (1.5 if pos_weight > 3.0 else 1.0)

    # Train; returns the best validation metrics and the run's log directory.
    training_options = dict(
        val_loader=val_loader,
        num_epochs=args.epochs,
        lr=args.lr,
        device=device,
        log_dir=args.log_dir,
        experiment_name=args.experiment_name,
        pos_weights=pos_weights,
        task_names=TASK_NAMES,
        task_weights=task_weights,
    )
    best_metrics, log_path = train_multi_task(model, train_loader, **training_options)
    
    # Console summary of the best validation metrics, one section per task.
    summary_rule = "=" * 80
    print("\n" + summary_rule)
    print("Training Summary")
    print(summary_rule)
    print(f"Tasks: {TASK_NAMES}")
    print(f"Best Validation Metrics:")
    for task_name in TASK_NAMES:
        print(f"\n  {task_name}:")
        for key, value in best_metrics[task_name].items():
            # Only numeric entries are printed; anything else is skipped.
            if not isinstance(value, (int, float)):
                continue
            print(f"    {key}: {value:.4f}")

    print("\n" + "="*80)
    print("Testing on Test Set")
    print("="*80)

    # Restore the best checkpoint written during training, if there is one;
    # otherwise evaluation proceeds with the model's current weights.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, and this checkpoint also stores metrics and config
    # objects; confirm it still loads on the target PyTorch version.
    best_model_path = os.path.join(log_path, 'best_model.pth')
    if not os.path.exists(best_model_path):
        print(f"Warning: Best model not found at {best_model_path}, using current model state")
    else:
        print(f"Loading best model from: {best_model_path}")
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Best model loaded successfully!")

    # One BCE-with-logits criterion per task, using that task's pos_weight
    # (1.0 when no weight was computed for the task).
    criterions = {
        name: nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor(pos_weights.get(name, 1.0), device=device)
        )
        for name in TASK_NAMES
    }

    print("\nEvaluating on test set...")
    test_losses, test_align_loss, test_metrics = validate(
        model, test_loader, criterions, device,
        writer=None, global_step=0, task_names=TASK_NAMES
    )

    print("\nTest Set Results:")
    print(f"Test - Align Loss: {test_align_loss:.4f}")
    for task_name in TASK_NAMES:
        metrics = test_metrics[task_name]
        print(f"  {task_name} - Loss: {test_losses[task_name]:.4f}, Acc: {metrics['accuracy']:.4f}, "
              f"F1: {metrics['f1']:.4f}, AUROC: {metrics['auroc']:.4f}, AUPRC: {metrics['auprc']:.4f}")
    
    final_report_path = os.path.join(log_path, 'final_report.json')
    if os.path.exists(final_report_path):
        with open(final_report_path, 'r') as f:
            final_report = json.load(f)
    else:
        final_report = {
            'experiment_name': args.experiment_name or 'unknown',
            'tasks': TASK_NAMES,
            'config': {}
        }
    
    final_report['test_losses'] = {k: float(v) for k, v in test_losses.items()}
    final_report['test_align_loss'] = float(test_align_loss)
    final_report['test_metrics'] = {k: {mk: float(mv) if isinstance(mv, (int, float)) else mv 
                                       for mk, mv in v.items()} 
                                   for k, v in test_metrics.items()}
    final_report['best_val_metrics'] = {k: {mk: float(mv) if isinstance(mv, (int, float)) else mv 
                                           for mk, mv in v.items()} 
                                       for k, v in best_metrics.items()}
    final_report['timestamp'] = datetime.now().isoformat()
    
    with open(final_report_path, 'w') as f:
        json.dump(final_report, f, indent=2)
    
    print(f"\nComplete report (validation + test) saved to: {final_report_path}")
    
    print("\n" + "="*80)
    print("Complete Results Summary")
    print("="*80)
    for task_name in TASK_NAMES:
        print(f"\n{task_name}:")
        print("  Validation Set:")
        for key, value in best_metrics[task_name].items():
            if isinstance(value, (int, float)) and key in ['accuracy', 'f1', 'auroc', 'auprc']:
                print(f"    {key}: {value:.4f}")
        print("  Test Set:")
        for key, value in test_metrics[task_name].items():
            if isinstance(value, (int, float)) and key in ['accuracy', 'f1', 'auroc', 'auprc']:
                print(f"    {key}: {value:.4f}")
