# training/train_EXCAP.py
# -*- coding: utf-8 -*-
import os
import gc
import csv
import time
import datetime
import yaml
from contextlib import suppress
from typing import Dict, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Optional profiler (for FLOPs estimation)
import torch.profiler as tprof
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# ========== Dependency imports ==========
from utils.optim_factory import Muon
from training.visualization import MultiModalVisualizer
from training.advanced_loss_functions import MaskedMultiTaskLoss, compute_Maskedmetrics
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, r2_score
)

# ---------------------------
# YAML configuration loading functions
# ---------------------------
def load_config_from_yaml(config_path: str) -> Dict:
    """
    Load configuration from YAML file

    The file is parsed with ``yaml.safe_load`` (no arbitrary object
    construction). An empty file, or one holding only comments, yields ``{}``
    rather than ``None``, so callers can always use ``.get`` on the result.

    Args:
        config_path: YAML configuration file path

    Returns:
        Configuration dictionary

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        ValueError: if the top-level YAML document is not a mapping.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; normalise to an empty dict.
    if config is None:
        config = {}
    elif not isinstance(config, dict):
        raise ValueError(
            f"Configuration file must contain a mapping at the top level, "
            f"got {type(config).__name__}: {config_path}"
        )

    print(f"✅ Loaded configuration from YAML file: {config_path}")
    return config

def merge_config_with_args(config: Dict, args) -> Dict:
    """
    Merge YAML configuration with command line arguments, YAML config takes priority

    Each value is looked up first in the YAML configuration, then as an
    attribute on ``args``, and finally falls back to a hard-coded default.
    A section that is missing, or present but empty (``cache:`` with nothing
    under it, which ``yaml.safe_load`` turns into ``None``), is treated as an
    empty mapping instead of raising ``AttributeError``.

    Side effect: every key/value under the top-level ``environment`` section is
    exported to ``os.environ`` (stringified) and printed.

    Args:
        config: YAML configuration dictionary (``None`` is treated as ``{}``)
        args: Command line arguments object; attributes are read with
            ``getattr`` and a default, so missing attributes are fine

    Returns:
        Merged configuration dictionary
    """
    config = config or {}

    def _section(name: str) -> Dict:
        # `name:` with no body parses to None; absent keys also map to {}.
        return config.get(name) or {}

    merged_config = {}

    # Extract parameters from YAML configuration
    yaml_args = _section('command_args')

    # Basic training parameters
    gpu_id = yaml_args.get('gpu', getattr(args, 'gpu', 4))
    merged_config.update({
        'gpu_id': gpu_id,
        'experiment_name': yaml_args.get('experiment_name', getattr(args, 'experiment_name', 'multimodal_training')),
        'log_dir': yaml_args.get('log_dir', getattr(args, 'log_dir', 'runs/debug')),
        'batch_size': int(yaml_args.get('batch_size', getattr(args, 'batch_size', 32))),
        'num_epochs': int(yaml_args.get('epochs', getattr(args, 'epochs', 10))),
        # 'learning_rate' wins over the shorter 'lr' alias
        'learning_rate': float(yaml_args.get('learning_rate', yaml_args.get('lr', getattr(args, 'lr', 3e-4)))),
        'num_workers': int(yaml_args.get('num_workers', getattr(args, 'num_workers', 4))),
        'freeze_encoders': yaml_args.get('freeze_encoders', getattr(args, 'freeze_encoders', False)),
        'visualize': yaml_args.get('visualize', getattr(args, 'visualize', False)),
        'device': f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu',
    })

    # Model parameters
    merged_config.update({
        'shared_dim': int(yaml_args.get('shared_dim', getattr(args, 'shared_dim', 64))),
        'text_dim': int(yaml_args.get('text_dim', getattr(args, 'text_dim', 64))),
        'image_dim': int(yaml_args.get('image_dim', getattr(args, 'image_dim', 64))),
        'ts_dim': int(yaml_args.get('ts_dim', getattr(args, 'ts_dim', 64))),
        'static_dim': int(yaml_args.get('static_dim', getattr(args, 'static_dim', 5))),
        'num_heads': int(yaml_args.get('num_heads', getattr(args, 'num_heads', 4))),
        'num_layers': int(yaml_args.get('num_layers', getattr(args, 'num_layers', 2))),
        'dropout': float(yaml_args.get('dropout', getattr(args, 'dropout', 0.1))),
    })

    # Task configuration
    merged_config.update({
        'task': yaml_args.get('task', getattr(args, 'task', 'regression')),
        'task_name': yaml_args.get('task_name', getattr(args, 'task_name', 'debug')),
    })

    # Dataset split configuration: `data_split` section > command_args > args
    split_config = _section('data_split')
    merged_config.update({
        'split_ratio': {
            'train': float(split_config.get('train_ratio', yaml_args.get('train_ratio', getattr(args, 'train_ratio', 0.7)))),
            'val': float(split_config.get('val_ratio', yaml_args.get('val_ratio', getattr(args, 'val_ratio', 0.15)))),
            'test': float(split_config.get('test_ratio', yaml_args.get('test_ratio', getattr(args, 'test_ratio', 0.15))))
        },
        'random_seed': int(split_config.get('random_seed', yaml_args.get('random_seed', getattr(args, 'random_seed', 42)))),
        'force_rebuild_split': split_config.get('force_rebuild_split', yaml_args.get('force_rebuild_split', getattr(args, 'force_rebuild_split', False))),
    })

    # Cache configuration: `cache` section > command_args > args
    cache_config = _section('cache')
    merged_config.update({
        'use_cache': cache_config.get('use_cache', yaml_args.get('use_cache', getattr(args, 'use_cache', True))),
        'cache_dir': cache_config.get('cache_dir', yaml_args.get('cache_dir', getattr(args, 'cache_dir', None))),
        'force_rebuild_cache': cache_config.get('force_rebuild_cache', yaml_args.get('force_rebuild_cache', getattr(args, 'force_rebuild_cache', False))),
        'cache_version': cache_config.get('cache_version', yaml_args.get('cache_version', getattr(args, 'cache_version', 'v1.0'))),
    })

    # Preprocessing configuration (only materialised when explicitly enabled)
    preprocessing_config = _section('preprocessing')
    if preprocessing_config.get('enabled', False):
        merged_config['preprocessing_config'] = {
            'missing_value_strategy': preprocessing_config.get('missing_value_strategy', 'zero'),
            'outlier_handling': preprocessing_config.get('outlier_handling', 'winsorize'),
            'normalization_method': preprocessing_config.get('normalization_method', 'robust'),
            'feature_scaling': preprocessing_config.get('feature_scaling', 'minmax')
        }
    else:
        merged_config['preprocessing_config'] = None

    # EAMC configuration
    eamc_config = _section('eamc')
    merged_config['eamc'] = {
        'node_budget': eamc_config.get('node_budget', 8),
        'max_segments': eamc_config.get('max_segments', 8),
        'use_nbc': eamc_config.get('use_nbc', True),
        'desired_threshold': eamc_config.get('desired_threshold', 0.8),
        'segment_len': eamc_config.get('segment_len', 12),
        'max_seq_len': eamc_config.get('max_seq_len', 96),
        'segment_mask_top_k_ratio': eamc_config.get('segment_mask_top_k_ratio', 0.3),
        'trimmer_types': eamc_config.get('trimmer_types', ["Decomposition"]),
    }

    # Dynamic segmentation control parameters (read from the top level of the YAML)
    merged_config.update({
        'use_dynseg': config.get('use_dynseg', True),
        'use_causal': config.get('use_causal', True),
        'frequency_branch_enabled': config.get('frequency_branch_enabled', True),
        'use_fft': config.get('use_fft', True),
        'use_wavelet': config.get('use_wavelet', True),
    })

    # Environment variables: exported process-wide as strings
    environment = _section('environment')
    for key, value in environment.items():
        os.environ[key] = str(value)
        print(f"Set environment variable: {key}={value}")

    return merged_config

def print_config_summary(config: Dict):
    """
    Print a human-readable summary of a merged training configuration.

    Expects the keys produced by ``merge_config_with_args``. The required keys
    are indexed directly; optional sections (split, EAMC, dynamic segmentation)
    fall back to the same defaults the merger uses.

    The device line reports the configured ``device`` (which is ``'cpu'`` when
    CUDA is unavailable). It falls back to ``cuda:<gpu_id>`` only when that key
    is absent.
    """
    print("\n=== Training Configuration Summary ===")
    print(f"📊 Basic Parameters:")
    print(f"  GPU: {config['gpu_id']}")
    print(f"  Experiment name: {config['experiment_name']}")
    print(f"  Batch size: {config['batch_size']}")
    print(f"  Number of epochs: {config['num_epochs']}")
    print(f"  Learning rate: {config['learning_rate']}")
    # Report the device actually selected, not an assumed CUDA device.
    device_str = config.get('device') or f"cuda:{config['gpu_id']}"
    print(f"  Device: {device_str}")

    print(f"\n🏗️ Model Parameters:")
    print(f"  Shared dimension: {config['shared_dim']}")
    print(f"  Text dimension: {config['text_dim']}")
    print(f"  Image dimension: {config['image_dim']}")
    print(f"  Time series dimension: {config['ts_dim']}")
    print(f"  Static data dimension: {config['static_dim']}")
    print(f"  Number of attention heads: {config['num_heads']}")
    print(f"  Number of layers: {config['num_layers']}")
    print(f"  Dropout: {config['dropout']}")

    print(f"\n📋 Task Configuration:")
    print(f"  Task type: {config['task']}")
    print(f"  Task name: {config['task_name']}")

    print(f"\n📊 Dataset Split Configuration:")
    split_ratio = config.get('split_ratio', {'train': 0.7, 'val': 0.15, 'test': 0.15})
    print(f"  Training set ratio: {split_ratio['train']:.1%}")
    print(f"  Validation set ratio: {split_ratio['val']:.1%}")
    print(f"  Test set ratio: {split_ratio['test']:.1%}")
    print(f"  Random seed: {config.get('random_seed', 42)}")
    print(f"  Force rebuild split: {config.get('force_rebuild_split', False)}")

    print(f"\n💾 Cache Configuration:")
    print(f"  Use cache: {config['use_cache']}")
    if config['cache_dir']:
        print(f"  Cache directory: {config['cache_dir']}")
    print(f"  Force rebuild: {config['force_rebuild_cache']}")
    print(f"  Cache version: {config['cache_version']}")

    if config['preprocessing_config']:
        print(f"\n🔧 Preprocessing Configuration:")
        for key, value in config['preprocessing_config'].items():
            print(f"  {key}: {value}")

    # EAMC configuration
    eamc_config = config.get('eamc', {})
    if eamc_config:
        print(f"\n🎯 EAMC Dynamic Segmentation Configuration:")
        print(f"  Node budget: {eamc_config.get('node_budget', 8)}")
        print(f"  Max segments: {eamc_config.get('max_segments', 8)}")
        print(f"  Budget control: {eamc_config.get('use_nbc', True)}")
        print(f"  Segmentation threshold: {eamc_config.get('desired_threshold', 0.8)}")
        print(f"  Segment length: {eamc_config.get('segment_len', 12)}")
        print(f"  Max sequence length: {eamc_config.get('max_seq_len', 96)}")
        print(f"  Retained segment ratio: {eamc_config.get('segment_mask_top_k_ratio', 0.3)}")
        print(f"  Segmenter types: {eamc_config.get('trimmer_types', ['Decomposition'])}")

    # Dynamic segmentation control parameters
    print(f"\n⚙️ Dynamic Segmentation Control:")
    print(f"  Enable dynamic segmentation: {config.get('use_dynseg', True)}")
    print(f"  Enable causal graph: {config.get('use_causal', True)}")
    print(f"  Enable frequency branch: {config.get('frequency_branch_enabled', True)}")
    print(f"  Enable FFT: {config.get('use_fft', True)}")
    print(f"  Enable wavelet: {config.get('use_wavelet', True)}")

    print(f"\n📁 Log directory: {config['log_dir']}")
    print("=" * 50)


def split_dataset_by_patient(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_seed=42):
    """
    Split sample indices into train/val/test so that each patient lands in one split.

    Patients are the distinct values of ``dataset.common_sids``, one entry per
    sample. They are deduplicated in first-occurrence order, which does not
    depend on ``PYTHONHASHSEED``, unlike ``set`` ordering of strings. They are
    then shuffled with ``random_seed``. The first ``train_ratio`` share of
    patients goes to train, the next ``val_ratio`` share goes to val, and the
    remainder goes to test.

    If the three ratios do not sum to 1 they are normalised by their sum, so
    ``test_ratio`` is respected as a relative weight.

    Note: the global ``random`` and ``numpy.random`` generators are seeded with
    ``random_seed``, as before.

    Args:
        dataset: object exposing ``common_sids`` (sequence of hashable patient ids).
        train_ratio, val_ratio, test_ratio: non-negative split weights.
        random_seed: seed for the patient shuffle.

    Returns:
        Tuple ``(train_indices, val_indices, test_indices)`` of sample-index lists.

    Raises:
        ValueError: if a ratio is negative or all ratios are zero.
    """
    import math
    import random
    import numpy as np

    random.seed(random_seed)
    np.random.seed(random_seed)

    for ratio_name, ratio in (('train_ratio', train_ratio), ('val_ratio', val_ratio), ('test_ratio', test_ratio)):
        if ratio < 0:
            raise ValueError(f"{ratio_name} must be non-negative, got {ratio}")
    total_ratio = train_ratio + val_ratio + test_ratio
    if total_ratio <= 0:
        raise ValueError("At least one split ratio must be positive")
    if not math.isclose(total_ratio, 1.0, rel_tol=1e-6):
        train_ratio = train_ratio / total_ratio
        val_ratio = val_ratio / total_ratio

    sample_sids = list(dataset.common_sids)
    # dict.fromkeys keeps first-occurrence order -> reproducible across processes.
    patient_ids = list(dict.fromkeys(sample_sids))
    random.shuffle(patient_ids)

    n_patients = len(patient_ids)
    train_end = int(n_patients * train_ratio)
    val_end = train_end + int(n_patients * val_ratio)

    # patient id -> split slot (0=train, 1=val, 2=test); O(1) lookup per sample.
    split_of = {}
    for pos, pid in enumerate(patient_ids):
        split_of[pid] = 0 if pos < train_end else (1 if pos < val_end else 2)

    train_indices, val_indices, test_indices = [], [], []
    buckets = (train_indices, val_indices, test_indices)
    for i, patient_id in enumerate(sample_sids):
        buckets[split_of[patient_id]].append(i)

    return train_indices, val_indices, test_indices

def create_split_datasets(multi_dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_seed=42):
    """
    Create split datasets

    Computes patient-level index splits via ``split_dataset_by_patient`` and
    wraps each one in a ``torch.utils.data.Subset`` over ``multi_dataset``.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)`` of ``Subset`` views.
    """
    from torch.utils.data import Subset

    idx_train, idx_val, idx_test = split_dataset_by_patient(
        multi_dataset, train_ratio, val_ratio, test_ratio, random_seed
    )
    return tuple(Subset(multi_dataset, idx) for idx in (idx_train, idx_val, idx_test))

def save_split_info(train_indices, val_indices, test_indices, save_path, random_seed=42):
    """
    Persist a train/val/test index split (and its sizes) as pretty-printed JSON.

    Indices may be Python ints or numpy/torch integer scalars, or come as numpy
    arrays. They are converted to plain ``int`` so ``json`` can serialise them.
    The parent directory is created when ``save_path`` has one. A bare
    filename is written to the current working directory.

    Args:
        train_indices, val_indices, test_indices: iterables of sample indices.
        save_path: destination JSON file path.
        random_seed: seed recorded alongside the split.
    """
    import json
    import os

    train_list = [int(i) for i in train_indices]
    val_list = [int(i) for i in val_indices]
    test_list = [int(i) for i in test_indices]

    split_info = {
        'random_seed': random_seed,
        'train_indices': train_list,
        'val_indices': val_list,
        'test_indices': test_list,
        'train_size': len(train_list),
        'val_size': len(val_list),
        'test_size': len(test_list),
        'total_size': len(train_list) + len(val_list) + len(test_list)
    }

    # os.makedirs('') raises, so only create a directory when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(split_info, f, indent=2)
    

def load_split_info(save_path):
    """Read and return the split dictionary previously written to ``save_path`` as JSON."""
    import json

    with open(save_path, 'r') as handle:
        return json.load(handle)

def _get_task_type(model) -> str:
    try:
        if hasattr(model, 'config') and isinstance(model.config, dict):
            return model.config.get('task_type', 'classification')
    except Exception:
        pass
    return 'classification'

def calculate_classification_metrics(predictions, targets) -> Dict[str, float]:
    """
    Compute binary classification metrics.

    Tensor ``predictions`` are treated as raw logits and passed through a
    sigmoid. Non-tensor predictions (numpy arrays, lists, scalars) are used
    as-is, i.e. they are assumed to already be probabilities. Both inputs are
    flattened and truncated to the shorter length. Probabilities > 0.5 count
    as the positive class.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1``, ``auc`` and
        ``ap``. A metric that cannot be computed is reported as 0: empty input,
        a single target class (AUC), no positive targets (AP), or any
        scikit-learn error.
    """
    if isinstance(predictions, torch.Tensor):
        logits = predictions.detach()
        # sigmoid is not implemented for integer/bool tensors and numpy has no
        # bfloat16; promote those to float32 (fp16/fp32/fp64 are kept as-is).
        if not torch.is_floating_point(logits) or logits.dtype == torch.bfloat16:
            logits = logits.float()
        predictions = torch.sigmoid(logits).cpu().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.detach().cpu()
        if targets.dtype == torch.bfloat16:
            targets = targets.float()
        targets = targets.numpy()

    # Accept lists / scalars too; reshape(-1) flattens any rank (0-d -> 1 element).
    predictions = np.asarray(predictions).reshape(-1)
    targets = np.asarray(targets).reshape(-1)

    n = min(len(predictions), len(targets))
    if n == 0:
        return {'accuracy': 0, 'precision': 0, 'recall': 0, 'f1': 0, 'auc': 0, 'ap': 0}
    predictions = predictions[:n]
    targets = targets[:n]
    pred_binary = (predictions > 0.5).astype(np.int32)

    def _safe(compute):
        # Best-effort: any failure while computing one metric yields 0.0.
        try:
            return float(compute())
        except Exception:
            return 0.0

    def _auc():
        if len(np.unique(targets)) < 2:
            raise ValueError("ROC AUC needs both classes present")
        return roc_auc_score(targets, predictions)

    def _ap():
        if (targets == 1).sum() == 0:
            raise ValueError("Average precision needs at least one positive")
        return average_precision_score(targets, predictions)

    return {
        'accuracy': _safe(lambda: accuracy_score(targets, pred_binary)),
        'precision': _safe(lambda: precision_score(targets, pred_binary, zero_division=0)),
        'recall': _safe(lambda: recall_score(targets, pred_binary, zero_division=0)),
        'f1': _safe(lambda: f1_score(targets, pred_binary, zero_division=0)),
        'auc': _safe(_auc),
        'ap': _safe(_ap),
    }

def _to_numpy_float(x):
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().float().numpy()
    elif isinstance(x, dict):
        if 'y_pred' in x:
            x = x['y_pred']
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().float().numpy()
        else:
            first_key = next(iter(x.keys()))
            x = x[first_key]
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().float().numpy()
    return x

def calculate_forecasting_metrics(preds, targs) -> Dict[str, float]:
    """
    Compute point-forecast error metrics between predictions and targets.

    Both inputs pass through ``_to_numpy_float``: tensors become float32
    numpy arrays, and dicts become their ``'y_pred'`` entry or first value.

    - A 4-D ``targs`` with a singleton axis 1 is squeezed.
    - If the last axis of ``targs`` has size 2, it is split into values
      ``[..., 0]`` and a flag ``[..., 1]``. Only positions whose flag is
      non-zero are scored.
    - Both arrays are flattened and truncated to the shorter length.

    NOTE(review): ``validate`` calls the same channel a "missing flag" but
    scores positions where it is 0; this function scores flag != 0. Confirm
    which convention the data uses.

    Returns:
        Dict with ``mse``, ``mae``, ``rmse``, ``mape`` (%), ``smape`` (%) and ``r2``.
        The same keys, all 0, are returned when nothing can be scored. ``r2`` is
        0.0 when fewer than two samples remain, because R^2 is undefined there.
    """
    def _empty_metrics():
        return {'mse': 0, 'mae': 0, 'rmse': 0, 'mape': 0, 'smape': 0, 'r2': 0}

    # asarray lets plain lists / scalars through as well.
    preds = np.asarray(_to_numpy_float(preds))
    targs = np.asarray(_to_numpy_float(targs))

    if targs.ndim == 4 and targs.shape[1] == 1:
        targs = targs.squeeze(1)

    mask = None
    if targs.ndim >= 2 and targs.shape[-1] == 2:
        mask = targs[..., 1] != 0
        targs = targs[..., 0]

    # A 5-D input becomes 4-D only after the value/flag split above.
    if targs.ndim == 4 and targs.shape[1] == 1:
        targs = targs.squeeze(1)
        if mask is not None:
            mask = mask.squeeze(1)

    if preds.size == 0 or targs.size == 0:
        return _empty_metrics()

    n = min(preds.size, targs.size)
    preds = preds.reshape(-1)[:n]
    targs = targs.reshape(-1)[:n]

    if mask is not None:
        valid_mask = mask.reshape(-1)[:n].astype(bool)
        if valid_mask.sum() == 0:
            print("[DEBUG] No valid samples after masking")
            return _empty_metrics()
        preds = preds[valid_mask]
        targs = targs[valid_mask]

    diff = preds - targs
    abs_diff = np.abs(diff)
    mse = float(np.mean(diff ** 2))
    mae = float(np.mean(abs_diff))
    rmse = float(np.sqrt(max(mse, 0.0)))

    eps = 1e-8  # avoids division by zero for zero targets / predictions
    mape = float(np.mean(abs_diff / (np.abs(targs) + eps))) * 100.0
    smape = float(np.mean(2.0 * abs_diff / (np.abs(preds) + np.abs(targs) + eps))) * 100.0
    # R^2 is not defined for a single sample (sklearn warns and returns nan).
    r2 = float(r2_score(targs, preds)) if targs.size >= 2 else 0.0
    return {'mse': mse, 'mae': mae, 'rmse': rmse, 'mape': mape, 'smape': smape, 'r2': r2}

def _unwrap_dataset(obj):
    try:
        from torch.utils.data import Subset
    except Exception:
        Subset = None
    seen = set()
    cur = obj
    while True:
        if cur is None: return None
        if id(cur) in seen: return cur
        seen.add(id(cur))
        if Subset is not None and isinstance(cur, Subset):
            cur = cur.dataset
        elif hasattr(cur, 'dataset'):
            cur = getattr(cur, 'dataset')
        else:
            return cur

def _get_timemmd_target_stats_from_loader(loader) -> Tuple[Optional[float], Optional[float]]:
    """
    Return the ``(mean, std)`` target statistics of the dataset behind ``loader``.

    ``loader.dataset`` is unwrapped through DataLoader/Subset layers with
    ``_unwrap_dataset``, and the innermost object's ``_target_mean`` and
    ``_target_std`` attributes are read.

    Accepted as numbers: Python, numpy, or single-element tensor values.
    Rejected: bools, non-finite values and ``std <= 0``.

    Returns:
        ``(mean, std)`` as floats, or ``(None, None)`` when the stats are
        unavailable or invalid, or when lookup fails.
    """
    import math
    import numbers

    def _as_float(value):
        # Single-element tensors -> Python scalar; np.float32 etc. are numbers.Real.
        if isinstance(value, torch.Tensor):
            if value.numel() != 1:
                return None
            value = value.item()
        if isinstance(value, bool) or not isinstance(value, numbers.Real):
            return None
        value = float(value)
        return value if math.isfinite(value) else None

    try:
        base = _unwrap_dataset(getattr(loader, 'dataset', None))
        mu = _as_float(getattr(base, '_target_mean', None))
        std = _as_float(getattr(base, '_target_std', None))
        if mu is not None and std is not None and std > 0:
            return mu, std
    except Exception:
        # Best-effort lookup: any failure means "no stats available".
        pass
    return None, None

def print_gpu_memory_usage():
    """Print allocated and reserved CUDA memory in GiB; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    gib = 1024 ** 3
    used_gb = torch.cuda.memory_allocated() / gib
    cached_gb = torch.cuda.memory_reserved() / gib
    print(f"[GPU] allocated={used_gb:.2f}GB reserved={cached_gb:.2f}GB")

def print_model_analysis(model, device):
    """
    Print parameter counts, per-encoder sizes and a rough memory estimate for ``model``.

    Args:
        model: ``torch.nn.Module``. When it has an ``ff`` sub-module, the sizes
            of ``ff.image_encoder``, ``ff.text_encoder``, ``ff.ts_encoder`` and
            ``ff.static_encoder`` are listed if present.
        device: unused; kept for call compatibility.

    The memory section is printed only when CUDA is available. It is a coarse
    heuristic: 2x the parameter bytes for forward and 3x for backward. It is
    not a measurement.
    """
    print(f"\n=== Model Analysis ===")

    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    trainable_params = sum(p.numel() for p in params if p.requires_grad)
    frozen_params = total_params - trainable_params
    # Parameter footprint in MB using each tensor's real element size
    # (4 bytes for fp32, 2 for fp16/bf16, ...).
    param_size = sum(p.numel() * p.element_size() for p in params) / 1024**2

    print(f"📊 Parameter Statistics:")
    print(f"  Total parameters:     {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print(f"  Frozen parameters:    {frozen_params:,}")
    print(f"  Parameter size:       ~{param_size:.1f}MB")

    if hasattr(model, 'ff'):
        print(f"\n🔧 Module Parameters:")
        ff = model.ff
        modules = ['image_encoder', 'text_encoder', 'ts_encoder', 'static_encoder']
        for module_name in modules:
            if hasattr(ff, module_name):
                module = getattr(ff, module_name)
                module_params = sum(p.numel() for p in module.parameters())
                module_trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
                print(f"  {module_name}: {module_params:,} parameters ({module_trainable:,} trainable)")

    if torch.cuda.is_available():
        print(f"\n💾 Memory Requirement Estimation:")
        # Forward propagation memory (model parameters + activations), heuristic 2x
        forward_memory = param_size * 2
        print(f"  Forward propagation: ~{forward_memory:.1f}MB")

        # Backward adds gradients on top of forward, heuristic 3x
        backward_memory = param_size * 3
        print(f"  Backward propagation: ~{backward_memory:.1f}MB")

        total_memory = backward_memory / 1024  # MB -> GB
        print(f"  Total requirement: ~{total_memory:.1f}GB")

        current_allocated = torch.cuda.memory_allocated() / 1024**3
        current_reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"  Current usage: {current_allocated:.1f}GB (allocated) / {current_reserved:.1f}GB (reserved)")

def _read_structural_from_model(model):
    V = E = dbar = None
    IRR = BCR = None
    
    li = getattr(model, 'last_eamc_info', None)
    if isinstance(li, dict):
        if isinstance(li.get('segment_centers_idx'), torch.Tensor):
            V = int(li['segment_centers_idx'].shape[1])
        elif isinstance(li.get('segment_mask'), torch.Tensor):
            try:
                V = int(li['segment_mask'].sum(dim=1).float().mean().item())
            except Exception:
                V = None
        elif isinstance(li.get('S'), torch.Tensor):
            try:
                V = int(li['S'].mean().item())
            except Exception:
                V = None
        
        IRR = li.get('IRR', None)
        BCR = li.get('BCR', None)

    gi = getattr(model, 'last_graph_info', None)
    if isinstance(gi, dict):
        V = gi.get('V', V)  
        E = gi.get('E', E)  
        dbar = gi.get('dbar', dbar)  
    
    return V, E, dbar, IRR, BCR

def _maybe_profile_flops(model, batch, device, enabled=False):
    if not enabled or 'tprof' not in globals() or tprof is None:
        return None
    images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
    images = images.to(device, non_blocking=True); static_data = static_data.to(device, non_blocking=True)
    if padded_time_series.dim() == 5:
        b, n, t, f1, f2 = padded_time_series.shape
        padded_time_series = padded_time_series.view(b, n, t, f1 * f2)
    padded_time_series = padded_time_series.to(device, non_blocking=True)
    time_series_lengths = time_series_lengths.to(device, non_blocking=True)
    for k in text_item:
        text_item[k] = text_item[k].to(device, non_blocking=True)
    packed_ts = {'ts_data': padded_time_series, 'seq_lengths': time_series_lengths}
    try:
        with tprof.profile(
            activities=[tprof.ProfilerActivity.CPU] + ([tprof.ProfilerActivity.CUDA] if torch.cuda.is_available() else []),
            record_shapes=True, with_flops=True, profile_memory=False
        ) as prof:
            with torch.no_grad():
                _ = model(images, text_item, packed_ts, static_data, task=_get_task_type(model))
        total = 0
        for evt in prof.key_averages():
            if hasattr(evt, 'flops') and evt.flops is not None:
                total += int(evt.flops)
        return int(total) if total > 0 else None
    except Exception:
        return None

# ---------------------------
# Parameter freezing and statistics
# ---------------------------
def freeze_modules_and_report(model, module_names=('text_encoder', 'image_encoder'), title="Freeze encoders"):
    """
    Freeze the parameters of the specified modules (requires_grad=False) and print:
      - Total number of parameters
      - Number of frozen parameters and ratio
      - Number of trainable parameters and ratio
      - Number of frozen parameters for each module

    Each entry of ``module_names`` is an attribute name on ``model``. Dotted
    paths such as ``'ff.text_encoder'`` are also accepted, so encoders nested
    inside a sub-module can be frozen. Names that cannot be resolved are
    reported and skipped.

    The per-module figure is the parameter count of that module, including
    parameters that were already frozen.
    """
    def _resolve(root, dotted_name):
        # Walk "a.b.c" one attribute at a time; None if any hop is missing.
        node = root
        for part in dotted_name.split('.'):
            node = getattr(node, part, None)
            if node is None:
                return None
        return node

    total_params = sum(p.numel() for p in model.parameters())

    per_module_frozen = {}
    for name in module_names:
        mod = _resolve(model, name)
        if mod is None:
            print(f"[Freeze][Warn] Module not found: {name}")
            continue
        frozen_this = 0
        for p in mod.parameters():
            if p.requires_grad:
                p.requires_grad = False
            frozen_this += p.numel()
        per_module_frozen[name] = frozen_this

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = total_params - trainable_params

    def _pct(x):
        return 0.0 if total_params == 0 else (x / total_params * 100.0)

    print(f"\n=== {title} ===")
    for k, v in per_module_frozen.items():
        print(f"  - {k}: frozen {v:,} parameters (approx {_pct(v):.2f}%)")
    print("📊 Parameter Statistics:")
    print(f"  Total parameters:       {total_params:,}")
    print(f"  Frozen parameters:     {frozen_params:,} ({_pct(frozen_params):.2f}%)")
    print(f"  Trainable parameters:   {trainable_params:,} ({_pct(trainable_params):.2f}%)\n")

def validate(model, val_loader, criterion, device, writer, global_step, visualizer=None, task: str = None):
    """
    Evaluate ``model`` on ``val_loader`` under ``torch.no_grad()`` and summarise loss/metrics.

    Args:
        model: called as ``model(images, text_item, packed_ts, static_data, task=task,
            return_ae=False)``. Its output is either a dict, from which ``'logits'``
            and optionally ``'clip_loss'`` are read, or the logits themselves.
        val_loader: iterable of 6-tuples ``(images, text_item, padded_time_series,
            time_series_lengths, static_data, labels)``. ``len(val_loader)`` is
            used for averaging.
        criterion: used only when the logits are a dict holding both ``'y_pred'``
            and ``'m_pred'``. It is called as ``criterion(logits, target)`` and its
            result is unpacked into ``(loss, metrics)``.
        device: device all input tensors are moved to.
        writer: not used in this body (kept for call compatibility).
        global_step: only referenced by the commented-out visualization code.
        visualizer: only referenced by the commented-out visualization code.
        task: ``'classification'`` selects the masked last-step branch; any other
            value selects the forecasting branch. Defaults to ``_get_task_type(model)``.

    Returns:
        ``(avg_loss, avg_align, metrics)``:

        - ``avg_loss`` and ``avg_align`` are sums divided by ``len(val_loader)``,
          with a minimum divisor of 1.
        - ``metrics`` holds the ``calculate_forecasting_metrics`` output.
        - For classification, ``metrics`` is also updated with the
          ``compute_Maskedmetrics`` output.
        - ``metrics`` also carries the structural keys ``V/E/dbar/IRR/BCR``
          (``E/dbar/IRR/BCR`` default to 0) and ``latency``, the wall-clock
          seconds for the whole pass.
    """
    model.eval()
    if task is None:
        task = _get_task_type(model)

    total_loss = 0.0
    total_align = 0.0
    preds_list, targs_list = [], []
    start = time.time()
    # Only referenced by the commented-out visualization block below.
    did_visualize = False
    # NOTE(review): appended to in the tensor-logits classification branch but never read.
    mse_all, mae_all = [], []
    with torch.no_grad():
        for batch in val_loader:
            images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
            images = images.to(device, non_blocking=True)
            static_data = static_data.to(device, non_blocking=True)
            padded_time_series = padded_time_series.to(device, non_blocking=True)
            time_series_lengths = time_series_lengths.to(device, non_blocking=True)
            # text_item tensors are moved to the device in place.
            for k in text_item:
                text_item[k] = text_item[k].to(device, non_blocking=True)
            packed_ts = {'ts_data': padded_time_series, 'seq_lengths': time_series_lengths}

            # Labels: for dict labels prefer 'death_24h', otherwise the first key.
            if isinstance(labels, dict):
                key = 'death_24h' if 'death_24h' in labels else next(iter(labels.keys()))
                target = labels[key].float().to(device, non_blocking=True)
            else:
                target = labels.float().to(device, non_blocking=True)

            # Forward
            out = model(images, text_item, packed_ts, static_data, task=task, return_ae=False)
            logits = out['logits'] if isinstance(out, dict) else out
            # Alignment loss is 0 unless the model returns 'clip_loss' in its output dict.
            align_loss = out.get('clip_loss', torch.tensor(0.0, device=device)) if isinstance(out, dict) else torch.tensor(0.0, device=device)

            # === Classification task (actual regression task for last step prediction) ===
            if task == 'classification':
                # Keep index 0 along dim 1 (presumably the first step/horizon — TODO confirm).
                target = target[:,0]
                # Process target values: extract numerical part and missing flag
                if target.dim() == 3 and target.size(-1) == 2:
                    target_val = target[..., 0]  # [B, F] numerical part
                    target_mask = target[..., 1]  # [B, F] missing flag
                else:
                    target_val = target
                    target_mask = torch.zeros_like(target)
                
                # Check if the output format is MaskedMultiTaskLoss
                if isinstance(logits, dict) and 'y_pred' in logits and 'm_pred' in logits:
                    # NOTE(review): y_pred / m_pred are not used further in this branch.
                    y_pred = logits['y_pred']
                    m_pred = logits['m_pred']
                    
                    loss_pred, loss_metrics = criterion(logits, target)
                    
                    # Compute extended masked metrics
                    with torch.no_grad():
                        masked_metrics = compute_Maskedmetrics(logits, target)
                        batch_mae = masked_metrics['mae']
                        batch_mse = masked_metrics['mse']
                    
                    # Save complete output for validation
                    preds_list.append(logits)
                    targs_list.append(target)
                    vis_predictions = y_pred
                    vis_targets = target
                else:
                    # Standard regression processing
                    if isinstance(logits, torch.Tensor):
                        # Ensure logits and target_val shapes match
                        if logits.dim() > target_val.dim():
                            logits = logits.squeeze(-1)
                        if target_val.dim() > logits.dim():
                            target_val = target_val.squeeze(-1)
                        
                        # Only compute loss at non-missing positions
                        # NOTE(review): positions with target_mask == 1 are zeroed here, yet
                        # `denom` below counts target_mask == 1 positions; one of the two looks
                        # inverted. Confirm the flag convention (calculate_forecasting_metrics
                        # treats flag != 0 as the valid positions).
                        diff = (logits - target_val) * (1 - target_mask)
                        loss_pred = F.smooth_l1_loss(diff, torch.zeros_like(diff), reduction='sum')
                        
                        # Average (avoid NaN with all-zero mask)
                        denom = target_mask.sum().clamp(min=1.0)
                        loss_pred = loss_pred / denom
                        
                        # Compute MAE and MSE
                        with torch.no_grad():
                            # NOTE(review): always true — denom was clamped to >= 1.0 above.
                            if denom > 0:
                                batch_mae = torch.abs(diff).sum() / denom
                                batch_mse = torch.square(diff).sum() / denom
                            else:
                                batch_mae = torch.tensor(0.0, device=device)
                                batch_mse = torch.tensor(0.0, device=device)
                        mse_all.append(batch_mse.item())
                        mae_all.append(batch_mae.item())
                    else:
                        loss_pred = torch.tensor(0.0, device=device)
                        batch_mae = torch.tensor(0.0, device=device)
                        batch_mse = torch.tensor(0.0, device=device)

                    # NOTE(review): for a tensor `logits`, [0] keeps only the first entry along
                    # dim 0, while the whole `target` is kept; the downstream metrics truncate
                    # to the shorter length. Confirm this is intended.
                    preds_list.append(logits[0].detach())
                    targs_list.append(target.squeeze(1).detach())
                    vis_predictions = logits
                    vis_targets = target

            # === Regression/Prediction task ===
            else:
                if isinstance(logits, torch.Tensor) and isinstance(target, torch.Tensor):
                    # Assumes target's last axis is (value, flag) — TODO confirm.
                    target_val = target[..., 0].squeeze(1)   # [B, N]
                    target_mask = target[..., 1].squeeze(1)  # [B, N]
                    # NOTE(review): same numerator/denominator flag inconsistency as above.
                    diff = (logits - target_val) * (1 - target_mask)

                    loss_pred = F.smooth_l1_loss(diff, torch.zeros_like(diff), reduction='sum')
                    denom = target_mask.sum().clamp(min=1.0)
                    loss_pred = loss_pred / denom

                    # NOTE(review): always true — denom was clamped to >= 1.0 above.
                    if denom > 0:
                        batch_mae = torch.abs(diff).sum() / denom
                        batch_mse = torch.square(diff).sum() / denom
                    else:
                        batch_mae, batch_mse = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
                else:
                    loss_pred = torch.tensor(0.0, device=device)
                    batch_mae = torch.tensor(0.0, device=device)
                    batch_mse = torch.tensor(0.0, device=device)

                preds_list.append(logits.detach())
                targs_list.append(target.detach())
                vis_predictions, vis_targets = logits, target

            # === Loss weighting ===
            # ALIGN_WEIGHT is re-read from the environment on every batch (default 0.05).
            align_weight = float(os.environ.get('ALIGN_WEIGHT', '0.05'))
            loss = (loss_pred if isinstance(loss_pred, torch.Tensor) else loss_pred[0]) + align_weight * align_loss
            total_loss += float(loss.item())
            total_align += float(align_loss.item())

            # === Visualization ===
            # if (not did_visualize) and visualizer is not None:
            #     pass
                # try:
                #     vis_batch = (images, text_item, packed_ts, static_data, {'target': vis_targets})
                #     visualizer.create_comprehensive_report(
                #         vis_batch,
                #         predictions=vis_predictions,
                #         targets=vis_targets,
                #         total_loss=float(loss.item()),
                #         align_loss=float(align_loss.item()),
                #         pred_loss=float(loss_pred.item() if isinstance(loss_pred, torch.Tensor) else loss_pred[0].item()),
                #         step=global_step
                #     )
                #     did_visualize = True
                # except Exception as e:
                #     print(f"[Visualizer][Warn] validation visualize failed: {e}")

    # === Summary ===
    avg_loss = total_loss / max(len(val_loader), 1)
    avg_align = total_align / max(len(val_loader), 1)
    
    # Collect predictions into a {'y_pred', 'm_pred'} dict in every case
    # (m_pred is all zeros when the model produced plain tensors).
    if preds_list and isinstance(preds_list[0], dict):
        pred_y_list = [pred['y_pred'] for pred in preds_list if isinstance(pred, dict) and 'y_pred' in pred]
        pred_m_list = [pred['m_pred'] for pred in preds_list if isinstance(pred, dict) and 'm_pred' in pred]
        preds = {'y_pred': torch.cat(pred_y_list, dim=0) if pred_y_list else torch.zeros(1, 1),
                'm_pred': torch.cat(pred_m_list, dim=0) if pred_m_list else torch.zeros(1, 1)}
    else:
        tensor_preds = [pred for pred in preds_list if isinstance(pred, torch.Tensor)]
        if tensor_preds:
            preds = {'y_pred': torch.cat(tensor_preds, dim=0), 'm_pred': torch.zeros_like(torch.cat(tensor_preds, dim=0))}
        else:
            preds = {'y_pred': torch.zeros(1, 1), 'm_pred': torch.zeros(1, 1)}
    
    targs = torch.cat(targs_list, dim=0) if targs_list else torch.zeros(1, 1)

    if task == 'classification':
        metrics_ori = calculate_forecasting_metrics(preds['y_pred'], targs)
        
        # NOTE(review): `preds` is always a dict with 'y_pred' at this point (built above),
        # so this branch is always taken and the `else` — including its loss/metric print —
        # is unreachable.
        if isinstance(preds, dict) and 'y_pred' in preds:
            masked_metrics = compute_Maskedmetrics(preds, targs)
            metrics = metrics_ori.copy()
            metrics.update(masked_metrics)
            
            if 'imputation_mse' in metrics:
                print(f"[VAL] imputation_mse={metrics.get('imputation_mse',0):.4f} "
                      f"imputation_mae={metrics.get('imputation_mae',0):.4f}")
            if 'missing_pred_acc' in metrics:
                print(f"[VAL] missing_pred_acc={metrics.get('missing_pred_acc',0):.4f} "
                      f"missing_pred_auc={metrics.get('missing_pred_auc',0):.4f}")
        else:
            metrics = metrics_ori
            print(f"[VAL] loss={avg_loss:.4f} align={avg_align:.4f} "
                  f"mse={metrics.get('mse',0):.4f} mae={metrics.get('mae',0):.4f} "
                  f"rmse={metrics.get('rmse',0):.4f} r2={metrics.get('r2',0):.4f}")
    else:
        metrics = calculate_forecasting_metrics(preds['y_pred'], targs)
        msg = f"[VAL] loss={avg_loss:.4f} align={avg_align:.4f} mae={metrics['mae']:.4f} rmse={metrics['rmse']:.4f} mape={metrics['mape']:.4f} r2={metrics['r2']:.4f}"
        print(msg)

    # Structural metrics cached on the model by its last forward pass; missing ones
    # (except V) are replaced by 0.
    V, E, dbar, IRR, BCR = _read_structural_from_model(model)
    metrics.update({'V': V, 'E': E, 'dbar': dbar, 'IRR': IRR, 'BCR': BCR})
    for k in ('E', 'dbar', 'IRR', 'BCR'):
        if metrics.get(k, None) is None:
            metrics[k] = 0
            print(f"[VAL][Info] Structural metric '{k}' missing, set to 0.")
    # Wall-clock seconds for the whole validation pass.
    metrics['latency'] = float(time.time() - start)

    return avg_loss, avg_align, metrics

# ---------------------------
# Training function (module level)
# ---------------------------
def _masked_loss_and_errors(pred, target_val, target_mask):
    """Return smooth-L1 loss, MAE and MSE averaged over observed entries.

    ``target_mask`` is a missing-value flag: entries where it is 1 are
    excluded by multiplying the residual by ``(1 - target_mask)``.

    Args:
        pred: predictions with the same (or broadcast-compatible) shape as
            ``target_val``.
        target_val: target values.
        target_mask: missing flag with the same shape as ``target_val``.

    Returns:
        ``(loss, mae, mse)`` as 0-dim tensors. ``loss`` carries gradient;
        ``mae``/``mse`` are computed under ``no_grad``. All three are divided
        by the number of observed entries, clamped to at least 1 so that an
        all-missing batch yields 0 instead of NaN.
    """
    observed = 1 - target_mask
    diff = (pred - target_val) * observed
    # The denominator must count the same entries the numerator keeps
    # (observed ones). Counting target_mask (the missing ones) would turn the
    # loss into a plain sum whenever nothing is missing.
    denom = observed.sum().clamp(min=1.0)
    loss = F.smooth_l1_loss(diff, torch.zeros_like(diff), reduction='sum') / denom
    with torch.no_grad():
        mae = diff.abs().sum() / denom
        mse = diff.square().sum() / denom
    return loss, mae, mse


def train(model, train_loader, val_loader=None, num_epochs=10, lr=1e-3, device='cuda',
          log_dir="runs", experiment_name=None, visualize: bool = True, task: str = None):
    """Train ``model`` with Muon + AMP, logging to TensorBoard and a per-epoch CSV.

    Each batch must unpack as ``(images, text_item, padded_time_series,
    time_series_lengths, static_data, labels)``. ``text_item`` is a mapping of
    tensors. ``labels`` is either a tensor or a dict of tensors; for a dict,
    ``'death_24h'`` is used when present, otherwise the first entry.

    Loss:
        - If the model returns ``{'y_pred', 'm_pred'}`` logits, the loss is
          ``MaskedMultiTaskLoss``.
        - Otherwise it is a smooth-L1 averaged over observed (non-missing)
          entries.
        - In both cases ``ALIGN_WEIGHT`` times the model's ``clip_loss`` is
          added.

    Environment variables:
        FREEZE_TEXT_IMAGE / FREEZE_ENCODERS: if non-empty, freeze the
            text/image encoders before training.
        MAX_TRAIN_BATCHES: optional int cap on batches per epoch; invalid
            values are ignored with a warning.
        ALIGN_WEIGHT: alignment-loss weight (default 0.05), read once.
        PROFILE_FLOPS: if non-empty, estimate FLOPs on the first validation
            batch each epoch (best effort).

    Args:
        model: multimodal model called as ``model(images, text_item,
            packed_ts, static_data, task=..., return_ae=False)``.
        train_loader: training data loader.
        val_loader: optional validation loader; enables CSV rows and
            best-model checkpoints.
        num_epochs: number of epochs.
        lr: learning rate for ``Muon``.
        device: target device.
        log_dir: root output directory.
        experiment_name: output sub-directory name. A timestamped name is
            generated when None.
        visualize: currently unused.
        task: inferred with ``_get_task_type`` when None. It is used for the
            experiment name, the saved config and the start banner.
            NOTE(review): the training loop then forces
            ``task = 'classification'`` regardless of this argument.

    Side effects:
        Writes ``epoch_metrics.csv``, ``best_model.pth`` (on val-loss
        improvement) and ``final_model.pth`` under ``log_dir/experiment_name``.
    """
    if task is None:
        task = _get_task_type(model)

    MULTI_PATH = False
    # Lightweight performance options
    torch.backends.cudnn.benchmark = True
    model = model.to(device)

    def _empty_perf_stats():
        # Fresh per-epoch timing / memory accumulators.
        return {
            'data_load_times': [],
            'forward_times': [],
            'backward_times': [],
            'optimizer_times': [],
            'total_batch_times': [],
            'gpu_memory_usage': [],
            'batch_sizes': [],
        }

    def _mean(values):
        return sum(values) / len(values)

    performance_stats = _empty_perf_stats()

    if os.environ.get('FREEZE_TEXT_IMAGE', '') or os.environ.get('FREEZE_ENCODERS', ''):
        freeze_modules_and_report(model, ('text_encoder', 'image_encoder'), title="Freeze text/image encoders before training")

    optimizer = Muon(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.01)
    scaler = GradScaler(enabled=torch.cuda.is_available())
    # Log directory
    if experiment_name is None:
        experiment_name = f"{task}_train_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    log_path = os.path.join(log_dir, experiment_name)
    os.makedirs(log_path, exist_ok=True)
    writer = SummaryWriter(log_dir=log_path)

    # Stored inside every checkpoint for provenance.
    config_dict = {
        'num_epochs': num_epochs, 'learning_rate': lr, 'device': str(device), 'task': task,
        'model_params': sum(p.numel() for p in model.parameters()),
        'trainable_params': sum(p.numel() for p in model.parameters() if p.requires_grad),
    }

    print(f"==> Start Training [{task}]: {experiment_name}")
    print(f"Log dir: {log_path}")
    print(f"Total epochs to train: {num_epochs}")
    print_gpu_memory_usage()
    global_step = 0
    best_val = float('inf')
    best_metrics = {}
    # NOTE(review): overrides the `task` argument for the whole loop (model
    # call, loss branch and validation). Kept to preserve existing behaviour.
    task = 'classification'

    # Optional cap on batches per epoch (e.g. for smoke tests). Initialise
    # first so a malformed value can never leave the name undefined.
    max_train_batches = None
    _mb = os.environ.get('MAX_TRAIN_BATCHES', '').strip()
    if _mb:
        try:
            max_train_batches = int(_mb)
        except ValueError:
            print(f"[TRAIN][Warn] Ignoring invalid MAX_TRAIN_BATCHES={_mb!r}")

    # Alignment-loss weight: read once rather than on every batch.
    align_weight = float(os.environ.get('ALIGN_WEIGHT', '0.05'))
    criterion = MaskedMultiTaskLoss()

    for epoch in range(num_epochs):
        model.train()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        current_epoch = epoch + 1
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {current_epoch}/{num_epochs}")

        epoch_loss = 0.0
        epoch_align = 0.0
        epoch_mae = 0.0
        epoch_mse = 0.0
        n_batches = 0
        t0 = time.time()

        for batch_idx, batch in pbar:
            global_step += 1
            batch_start_time = time.time()

            # Periodically release cached allocator blocks.
            if (batch_idx % 10 == 0) and torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Times only the host->device transfer below; the loader's own
            # fetch happens in the `for` statement and is not included.
            data_load_start = time.time()
            images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
            images = images.to(device, non_blocking=True)
            static_data = static_data.to(device, non_blocking=True)
            padded_time_series = padded_time_series.to(device, non_blocking=True)
            time_series_lengths = time_series_lengths.to(device, non_blocking=True)
            for k in text_item:
                text_item[k] = text_item[k].to(device, non_blocking=True)
            packed_ts = {'ts_data': padded_time_series, 'seq_lengths': time_series_lengths}
            data_load_time = time.time() - data_load_start

            if isinstance(labels, dict):
                key = 'death_24h' if 'death_24h' in labels else next(iter(labels.keys()))
                target = labels[key].float().to(device, non_blocking=True)
            else:
                target = labels.float().to(device, non_blocking=True)

            # Wall-clock timing without CUDA synchronisation: kernels may still
            # be running when the timer stops.
            forward_start = time.time()
            optimizer.zero_grad(set_to_none=True)
            with autocast('cuda', enabled=torch.cuda.is_available()):
                out = model(images, text_item, packed_ts, static_data, task=task, return_ae=False)
                if isinstance(out, dict):
                    logits = out['logits']
                    align_loss = out.get('clip_loss', torch.tensor(0.0, device=device))
                else:
                    logits = out
                    align_loss = torch.tensor(0.0, device=device)
            forward_time = time.time() - forward_start

            if task == 'classification':
                target = target[:, 0]
                if target.dim() == 3 and target.size(-1) == 2:
                    target_val = target[..., 0]   # value channel
                    target_mask = target[..., 1]  # missing flag (1 = missing)
                else:
                    target_val = target
                    target_mask = torch.zeros_like(target)

                if isinstance(logits, dict) and 'y_pred' in logits and 'm_pred' in logits:
                    loss_pred, _loss_metrics = criterion(logits, target)
                    with torch.no_grad():
                        masked_metrics = compute_Maskedmetrics(logits, target)
                    batch_mae = masked_metrics['mae']
                    batch_mse = masked_metrics['mse']
                elif isinstance(logits, torch.Tensor):
                    # Align ranks by dropping a trailing axis from the side
                    # that has more dimensions.
                    if logits.dim() > target_val.dim():
                        logits = logits.squeeze(-1)
                    if target_val.dim() > logits.dim():
                        target_val = target_val.squeeze(-1)
                        # Keep the mask aligned with target_val. Otherwise a
                        # [B] residual times a [B, 1] mask broadcasts to [B, B].
                        target_mask = target_mask.squeeze(-1)
                    loss_pred, batch_mae, batch_mse = _masked_loss_and_errors(logits, target_val, target_mask)
                else:
                    loss_pred = torch.tensor(0.0, device=device)
                    batch_mae = torch.tensor(0.0, device=device)
                    batch_mse = torch.tensor(0.0, device=device)

            else:
                # Prediction task (forecasting/regression)
                if isinstance(logits, torch.Tensor) and isinstance(target, torch.Tensor):
                    target_val = target[..., 0].squeeze(1)   # [B, N]
                    target_mask = target[..., 1].squeeze(1)  # [B, N]
                    if MULTI_PATH:
                        logits = logits.view(logits.size(0), -1, logits.size(-1))
                        target_val = target_val.view(target_val.size(0), -1, target_val.size(-1))
                        target_mask = target_mask.view(target_mask.size(0), -1, target_mask.size(-1))
                    else:
                        logits = logits[:, 0]
                        target_val = target_val[:, 0]
                        target_mask = target_mask[:, 0]
                    loss_pred, batch_mae, batch_mse = _masked_loss_and_errors(logits, target_val, target_mask)
                else:
                    loss_pred = torch.tensor(0.0, device=device)
                    batch_mae = torch.tensor(0.0, device=device)
                    batch_mse = torch.tensor(0.0, device=device)

            # The criterion may return a tuple; its first element is the loss.
            loss_main = loss_pred[0] if isinstance(loss_pred, tuple) else loss_pred
            total_loss = loss_main + align_weight * align_loss

            # Backward propagation and optimizer step (timed together).
            backward_start = time.time()
            scaler.scale(total_loss).backward()
            scaler.step(optimizer)
            scaler.update()
            backward_time = time.time() - backward_start

            # Convert to Python floats once and reuse them for accumulation,
            # TensorBoard and the progress bar.
            loss_value = float(loss_main.item())
            align_value = float(align_loss.item())
            mae_value = float(batch_mae)
            mse_value = float(batch_mse)

            epoch_loss += loss_value
            epoch_align += align_value
            epoch_mae += mae_value
            epoch_mse += mse_value
            n_batches += 1

            total_batch_time = time.time() - batch_start_time

            performance_stats['data_load_times'].append(data_load_time)
            performance_stats['forward_times'].append(forward_time)
            performance_stats['backward_times'].append(backward_time)
            performance_stats['total_batch_times'].append(total_batch_time)
            performance_stats['batch_sizes'].append(images.size(0))
            if torch.cuda.is_available():
                performance_stats['gpu_memory_usage'].append(torch.cuda.memory_allocated() / 1024**3)

            if max_train_batches is not None and n_batches >= max_train_batches:
                break

            if (batch_idx + 1) % 10 == 0:
                writer.add_scalar('train/step_loss', loss_value, global_step)
                writer.add_scalar('train/step_align_loss', align_value, global_step)
                writer.add_scalar('train/step_mae', mae_value, global_step)
                writer.add_scalar('train/step_mse', mse_value, global_step)
                writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step)
                if torch.cuda.is_available():
                    writer.add_scalar('system/gpu_memory_allocated', torch.cuda.memory_allocated() / 1024**3, global_step)
                    writer.add_scalar('system/gpu_memory_reserved',  torch.cuda.memory_reserved() / 1024**3,  global_step)

                writer.add_scalar('performance/data_load_time', data_load_time, global_step)
                writer.add_scalar('performance/forward_time', forward_time, global_step)
                writer.add_scalar('performance/backward_time', backward_time, global_step)
                writer.add_scalar('performance/total_batch_time', total_batch_time, global_step)
                writer.add_scalar('performance/batch_size', images.size(0), global_step)

                pbar.set_postfix({
                    'loss': f"{loss_value:.4f}",
                    'align': f"{align_value:.4f}",
                    'mae': f"{mae_value:.4f}",
                    'mse': f"{mse_value:.4f}",
                    'data': f"{data_load_time*1000:.1f}ms",
                    'forward': f"{forward_time*1000:.1f}ms",
                    'backward': f"{backward_time*1000:.1f}ms",
                    'total': f"{total_batch_time*1000:.1f}ms"
                })

        batch_count = max(n_batches, 1)
        avg_loss = epoch_loss / batch_count
        avg_align = epoch_align / batch_count
        avg_mae = epoch_mae / batch_count
        avg_mse = epoch_mse / batch_count
        epoch_time = time.time() - t0

        # Per-epoch performance summary
        if performance_stats['data_load_times']:
            avg_data_load = _mean(performance_stats['data_load_times'])
            avg_forward = _mean(performance_stats['forward_times'])
            avg_backward = _mean(performance_stats['backward_times'])
            avg_batch_time = _mean(performance_stats['total_batch_times'])
            avg_batch_size = _mean(performance_stats['batch_sizes'])

            # Throughput in samples/second, and each stage's share of batch time.
            throughput = avg_batch_size / avg_batch_time if avg_batch_time > 0 else 0
            data_ratio = avg_data_load / avg_batch_time * 100 if avg_batch_time > 0 else 0
            forward_ratio = avg_forward / avg_batch_time * 100 if avg_batch_time > 0 else 0
            backward_ratio = avg_backward / avg_batch_time * 100 if avg_batch_time > 0 else 0

            writer.add_scalar('performance/avg_data_load_time', avg_data_load, current_epoch)
            writer.add_scalar('performance/avg_forward_time', avg_forward, current_epoch)
            writer.add_scalar('performance/avg_backward_time', avg_backward, current_epoch)
            writer.add_scalar('performance/avg_batch_time', avg_batch_time, current_epoch)
            writer.add_scalar('performance/throughput', throughput, current_epoch)
            writer.add_scalar('performance/data_ratio', data_ratio, current_epoch)
            writer.add_scalar('performance/forward_ratio', forward_ratio, current_epoch)
            writer.add_scalar('performance/backward_ratio', backward_ratio, current_epoch)

            if torch.cuda.is_available() and performance_stats['gpu_memory_usage']:
                writer.add_scalar('performance/avg_gpu_memory_gb', _mean(performance_stats['gpu_memory_usage']), current_epoch)
                writer.add_scalar('performance/max_gpu_memory_gb', max(performance_stats['gpu_memory_usage']), current_epoch)

            # Reset for the next epoch.
            performance_stats = _empty_perf_stats()

        writer.add_scalar('train/epoch_loss', avg_loss, current_epoch)
        writer.add_scalar('train/epoch_align_loss', avg_align, current_epoch)
        writer.add_scalar('train/epoch_mae', avg_mae, current_epoch)
        writer.add_scalar('train/epoch_mse', avg_mse, current_epoch)
        writer.add_scalar('train/epoch_time', epoch_time, current_epoch)
        print(f"[EPOCH {current_epoch}] loss={avg_loss:.4f} align={avg_align:.4f} mae={avg_mae:.4f} mse={avg_mse:.4f} time={epoch_time:.1f}s")

        # ===== Validation and CSV writing =====
        if val_loader is not None:
            val_loss, val_align, val_metrics = validate(model, val_loader, criterion, device, writer, global_step, visualizer=None, task=task)
            writer.add_scalar('val/epoch_loss', val_loss, current_epoch)
            writer.add_scalar('val/epoch_align_loss', val_align, current_epoch)
            # Log forecasting metrics to TensorBoard if present
            for _k in ('mse', 'rmse', 'mae', 'mape', 'smape', 'r2'):
                _v = val_metrics.get(_k, None)
                if isinstance(_v, (int, float)):
                    writer.add_scalar(f'val/{_k}', _v, current_epoch)

            # Structural/fidelity metrics
            Vbar   = val_metrics.get('V', None)
            Ebar   = val_metrics.get('E', None)
            dbar   = val_metrics.get('dbar', None)
            IRRbar = val_metrics.get('IRR', None)
            BCRbar = val_metrics.get('BCR', None)

            # Task metrics (None when missing)
            AUC = val_metrics.get('auc', None)
            AP  = val_metrics.get('ap',  None)
            F1  = val_metrics.get('f1',  None)
            MAE = val_metrics.get('mae', val_metrics.get('MAE', None))

            vram_gb = torch.cuda.max_memory_allocated() / 1024**3 if torch.cuda.is_available() else None
            flops_val = None
            # Best-effort FLOPs estimate; any failure just leaves it as None.
            with suppress(Exception):
                if os.environ.get('PROFILE_FLOPS', ''):
                    first_batch = next(iter(val_loader))
                    flops_val = _maybe_profile_flops(model, first_batch, device, enabled=True)

            val_latency = val_metrics.get('latency', None)

            csv_path = os.path.join(log_path, 'epoch_metrics.csv')
            row = {
                'epoch': current_epoch,
                'train_loss': avg_loss,
                'val_loss':   val_loss,
                '|V|':        Vbar,
                '|E|':        Ebar,
                'dbar':       dbar,
                'IRR':        IRRbar,
                'BCR':        BCRbar,
                'FLOPs':      flops_val,
                'VRAM_GB':    vram_gb,
                'Latency_s':  val_latency,
                'AUC':        AUC,
                'AP':         AP,
                'F1':         F1,
                'MAE':        MAE,
            }
            fieldnames = ['epoch', 'train_loss', 'val_loss', '|V|', '|E|', 'dbar', 'IRR', 'BCR', 'FLOPs', 'VRAM_GB', 'Latency_s', 'AUC', 'AP', 'F1', 'MAE']
            file_exists = os.path.exists(csv_path)
            with open(csv_path, 'a', newline='') as f:
                w = csv.DictWriter(f, fieldnames=fieldnames)
                if not file_exists:
                    w.writeheader()
                w.writerow(row)

            if val_loss < best_val:
                best_val = val_loss
                best_metrics = val_metrics
                save_path = os.path.join(log_path, 'best_model.pth')
                torch.save({
                    'epoch': current_epoch,  # 1-based, consistent with logging
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_metrics': val_metrics,
                    'config': config_dict
                }, save_path)
                print(f"==> New best saved @ {save_path} (epoch={current_epoch}, val_loss={val_loss:.4f})")

    final_path = os.path.join(log_path, 'final_model.pth')
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config_dict
    }, final_path)

    writer.add_text('training_summary', f"""
    Training completed:
    - Total epochs: {num_epochs}
    - Best val loss: {best_val:.4f}
    - Best val metrics: {best_metrics}
    - Training completed at epoch: {num_epochs}
    """, 0)

    writer.close()
    print(f"Training completed! Logs: {log_path}")
    print(f"Best val loss: {best_val:.4f}")
    if best_metrics:
        print(f"Best val metrics: {best_metrics}")

def sanity_forward_step(model, data_loader, device):
    """
    Run a single batch through the model to validate shapes and device placement.

    Takes the first batch of ``data_loader``, moves its tensors to ``device``
    and runs one forward pass under ``torch.no_grad()`` with the model in eval
    mode. It then prints the output shape(s). The model's previous
    ``training`` flag is restored afterwards, even if the forward pass raises.

    The batch must unpack as ``(images, text_item, padded_time_series,
    time_series_lengths, static_data, labels)``. A 5-D ``padded_time_series``
    has its last two axes merged: ``(b, n, t, f1, f2) -> (b, n, t, f1*f2)``.

    Args:
        model: module called as ``model(images, text_item, packed_ts,
            static_data, task=..., return_ae=False)``.
        data_loader: iterable of batches; only the first batch is used.
        device: target device.

    Returns:
        The raw model output.

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
    """
    batch = next(iter(data_loader), None)
    if batch is None:
        raise RuntimeError("sanity_forward_step: data_loader yielded no batches")

    images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
    images = images.to(device, non_blocking=True)
    static_data = static_data.to(device, non_blocking=True)
    if padded_time_series.dim() == 5:
        b, n, t, f1, f2 = padded_time_series.shape
        # reshape (not view) also works for non-contiguous inputs.
        padded_time_series = padded_time_series.reshape(b, n, t, f1 * f2)
    padded_time_series = padded_time_series.to(device, non_blocking=True)
    time_series_lengths = time_series_lengths.to(device, non_blocking=True)
    for k in text_item:
        text_item[k] = text_item[k].to(device, non_blocking=True)
    packed_ts = {'ts_data': padded_time_series, 'seq_lengths': time_series_lengths}

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            out = model(images, text_item, packed_ts, static_data, task=_get_task_type(model), return_ae=False)
    finally:
        # Do not leave the model silently switched to eval mode.
        model.train(was_training)

    logits = out['logits'] if isinstance(out, dict) else out
    if isinstance(logits, dict):
        # Multi-head outputs (e.g. {'y_pred', 'm_pred'}) have no single shape.
        shapes = {k: tuple(v.shape) for k, v in logits.items() if isinstance(v, torch.Tensor)}
        print("[SANITY] logits shapes:", shapes)
    else:
        print("[SANITY] logits shape:", tuple(logits.shape))
    return out

def process_labels_for_excap(labels, task_name='next_step'):
    """Select one task's label tensor and drop its second (singleton) axis.

    When ``task_name`` is not a key of ``labels``, a warning is printed and the
    first key of ``labels`` is used instead.

    Args:
        labels: mapping from task name to label tensor. The tensor is
            presumably ``[B, 1, T, 2]`` (per the original annotation); it
            must be 4-D.
        task_name: key to select (default ``'next_step'``).

    Returns:
        A new tensor (clone) equal to ``labels[key].squeeze(1)``.

    Raises:
        ValueError: if the selected tensor is not 4-D (from the shape unpack).
    """
    if task_name not in labels:
        print(f"Warning: {task_name} not found in labels: {list(labels.keys())}")
        available = list(labels.keys())
        task_name = available[0]
        print(f"Using {task_name} instead")

    selected = labels[task_name]
    # The unpack enforces exactly four dimensions; the sizes themselves are unused.
    _batch, _singleton, _steps, _channels = selected.shape
    return selected.squeeze(1).clone()

def main_training(gpu_id=0, exp_name="multimodal_training", batch_size=1, num_epochs=10, 
                 lr=1e-4, log_dir="runs/debug", 
                 shared_dim=64, text_dim=64, image_dim=64, ts_dim=64, 
                 static_dim=5, num_heads=4, num_layers=2, dropout=0.1, 
                 task_name='death_24h', visualize=False, num_workers=0, freeze_encoders=True,
                 use_cache=True, cache_dir=None, force_rebuild_cache=False, cache_version="v1.0",
                 config_file=None, preprocessing_config=None, train_ratio=0.7, val_ratio=0.15, 
                 test_ratio=0.15, random_seed=42, force_rebuild_split=False):

    
    config = None
    if config_file:
        try:
            yaml_config = load_config_from_yaml(config_file)
            
            class Args:
                def __init__(self):
                    self.gpu = gpu_id
                    self.experiment_name = exp_name
                    self.log_dir = log_dir
                    self.batch_size = batch_size
                    self.epochs = num_epochs
                    self.lr = lr
                    self.num_workers = num_workers
                    self.freeze_encoders = freeze_encoders
                    self.visualize = visualize
                    self.shared_dim = shared_dim
                    self.text_dim = text_dim
                    self.image_dim = image_dim
                    self.ts_dim = ts_dim
                    self.static_dim = static_dim
                    self.num_heads = num_heads
                    self.num_layers = num_layers
                    self.dropout = dropout
                    self.task = 'regression'
                    self.task_name = task_name
                    self.use_cache = use_cache
                    self.cache_dir = cache_dir
                    self.force_rebuild_cache = force_rebuild_cache
                    self.cache_version = cache_version
                    self.train_ratio = train_ratio
                    self.val_ratio = val_ratio
                    self.test_ratio = test_ratio
                    self.random_seed = random_seed
                    self.force_rebuild_split = force_rebuild_split
            
            args = Args()
            config = merge_config_with_args(yaml_config, args)
            print_config_summary(config)
            
        except Exception as e:
            print(f"YAML configuration loading failed: {e}")
            config = None
    
    if config is None:
        config = {
            'task_name': task_name,
            'batch_size': batch_size,
            'num_epochs': num_epochs,
            'learning_rate': lr,
            'device': f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu',
            'log_dir': log_dir,
            'experiment_name': exp_name,
            'visualize': visualize,
            'num_workers': num_workers,
            'use_cache': use_cache,
            'cache_dir': cache_dir or "data/cache",
            'force_rebuild_cache': force_rebuild_cache,
            'cache_version': cache_version,
            'preprocessing_config': preprocessing_config,
            'split_ratio': {'train': train_ratio, 'val': val_ratio, 'test': test_ratio},
            'random_seed': random_seed,
            'force_rebuild_split': force_rebuild_split
        }
        
        print(f"Training configuration:")
        for key, value in config.items():
            print(f"  {key}: {value}")
    
    
    from datapress.Aligned.mimiccxr_dataset import MIMICCXRDataset
    from datapress.Aligned.medical_dataset import MedicalDataset
    from datapress.Aligned.multiModelAligned_dataset import MultiModalAlignedDataset
    from datapress.dataloader import create_data_loader
    from model.factory import create_multimodal_model
    from omegaconf import OmegaConf
    from torchvision import transforms
    
    try:
        from datapress.cache_dataset import create_cached_data_loader, CacheDataset
        CACHE_AVAILABLE = True
    except ImportError:
        CACHE_AVAILABLE = False
    
    # Use relative paths - users should configure these in their config file
    base_data_path = os.getcwd()
    index_file = os.path.join(base_data_path, 'data/MIMIC/index.json')
    image_dir = os.path.join(base_data_path, 'data/MIMIC/images')
    reports_dir = os.path.join(base_data_path, 'data/MIMIC/reports')
    
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ])
    
    cxr_dataset = MIMICCXRDataset(
        index_file_path=index_file,
        image_root=image_dir,
        reports_root=reports_dir,
        transform=transform
    )

    opt = OmegaConf.load("exp/mimic_data/exp_mortality_24h48h_los.yaml")
    med_dataset = MedicalDataset(**opt.data.train_val, **opt.data.shared_param)
    aligned_json_path = os.path.join(base_data_path, 'data', 'aligned_subjects.json')
    multi_dataset = MultiModalAlignedDataset(cxr_dataset, med_dataset, sid_json_path=aligned_json_path)
    if len(multi_dataset) == 0:
        return False
    
    split_ratio = config.get('split_ratio', {'train': 0.7, 'val': 0.15, 'test': 0.15})
    random_seed = config.get('random_seed', 42)
    train_ratio = split_ratio['train']
    val_ratio = split_ratio['val']
    test_ratio = split_ratio['test']
    
    config['num_workers'] = 0
    
    if config['use_cache'] and CACHE_AVAILABLE:
        full_cache_loader, full_cache_dataset = create_cached_data_loader(
            multi_dataset,
            batch_size=config['batch_size'],
            cache_dir=config['cache_dir'],
            shuffle=False,
            num_workers=config['num_workers'],
            force_rebuild=config['force_rebuild_cache'],
            cache_version=config['cache_version'],
            preprocessing_config=config.get('preprocessing_config', None)
        )
        
        from torch.utils.data import Subset
        import random
        
        random.seed(random_seed)
        np.random.seed(random_seed)
        
        total_samples = len(full_cache_dataset)
        all_indices = list(range(total_samples))
        random.shuffle(all_indices)
        
        train_end = int(total_samples * train_ratio)
        val_end = train_end + int(total_samples * val_ratio)
        
        train_indices = all_indices[:train_end]
        val_indices = all_indices[train_end:val_end]
        test_indices = all_indices[val_end:]
        
        train_cache_dataset = Subset(full_cache_dataset, train_indices)
        val_cache_dataset = Subset(full_cache_dataset, val_indices)
        test_cache_dataset = Subset(full_cache_dataset, test_indices)
        
        from datapress.cache_dataset import enhanced_collate_fn
        from torch.utils.data import DataLoader
        
        train_loader = DataLoader(
            train_cache_dataset,
            batch_size=config['batch_size'],
            shuffle=True,
            collate_fn=lambda batch: enhanced_collate_fn(batch, preprocessing_config=config.get('preprocessing_config', None)),
            num_workers=config['num_workers'],
            pin_memory=True,
            persistent_workers=True if config['num_workers'] > 0 else False,
            prefetch_factor=2 if config['num_workers'] > 0 else None,
            drop_last=False,
            multiprocessing_context='spawn' if config['num_workers'] > 0 else None
        )
        
        val_loader = DataLoader(
            val_cache_dataset,
            batch_size=config['batch_size'],
            shuffle=False,  
            collate_fn=lambda batch: enhanced_collate_fn(batch, preprocessing_config=config.get('preprocessing_config', None)),
            num_workers=config['num_workers'],
            pin_memory=True,
            persistent_workers=True if config['num_workers'] > 0 else False,
            prefetch_factor=2 if config['num_workers'] > 0 else None,
            drop_last=False,
            multiprocessing_context='spawn' if config['num_workers'] > 0 else None
        )
        
        test_loader = DataLoader(
            test_cache_dataset,
            batch_size=config['batch_size'],
            shuffle=False,
            collate_fn=lambda batch: enhanced_collate_fn(batch, preprocessing_config=config.get('preprocessing_config', None)),
            num_workers=config['num_workers'],
            pin_memory=True,
            persistent_workers=True if config['num_workers'] > 0 else False,
            prefetch_factor=2 if config['num_workers'] > 0 else None,
            drop_last=False,
            multiprocessing_context='spawn' if config['num_workers'] > 0 else None
        )

    model_config = {
        "shared_dim": shared_dim,
        "out_dim": 20,  # 96 classes
        "out_len": 1,
        "text_dim": text_dim,
        "image_dim": image_dim,
        "ts_dim": ts_dim,
        "static_dim": static_dim,
        "num_heads": num_heads,
        "num_layers": num_layers,
        "dropout": dropout,
        
        "eamc": {
            "node_budget": config.get('eamc', {}).get('node_budget', 8),
            "max_segments": config.get('eamc', {}).get('max_segments', 8),
            "use_nbc": config.get('eamc', {}).get('use_nbc', True),
            "desired_threshold": config.get('eamc', {}).get('desired_threshold', 0.8),
            "segment_len": config.get('eamc', {}).get('segment_len', 12),
            "max_seq_len": config.get('eamc', {}).get('max_seq_len', 96),
            "segment_mask_top_k_ratio": config.get('eamc', {}).get('segment_mask_top_k_ratio', 0.3),
            "trimmer_types": config.get('eamc', {}).get('trimmer_types', ["Decomposition"]),
        },
        
        "use_dynseg": config.get('use_dynseg', True),
        "use_causal": config.get('use_causal', True),
        "use_nbc": config.get('eamc', {}).get('use_nbc', True),
        "frequency_branch_enabled": config.get('frequency_branch_enabled', True),
        "use_fft": config.get('use_fft', True),
        "use_wavelet": config.get('use_wavelet', True),
    }
    
    model = create_multimodal_model(model_config)
    
    freeze_encoders = True
    if freeze_encoders:
        try:
            device = torch.device(config['device'])
            model = model.to(device)
            
            dummy_images = torch.randn(1, 1, 3, 224, 224).to(device)
            dummy_text = {
                'input_ids': torch.randint(0, 1000, (1, 128)).to(device),
                'attention_mask': torch.ones(1, 128).to(device)
            }
            dummy_ts = {
                'ts_data': torch.randn(1, 2, 24, 10).to(device),
                'seq_lengths': torch.tensor([24]).to(device)
            }
            dummy_static = torch.randn(1, 5).to(device)
            
            with torch.no_grad():
                _ = model(dummy_images, dummy_text, dummy_ts, dummy_static, task='classification', return_ae=False)
        except Exception as e:
            print(f"Dynamic parameter initialization failed: {e}")
            print("Continuing with freezing operation...")
        frozen_params = 0
        
        if hasattr(model, 'ff') and hasattr(model.ff, 'image_encoder'):
            for param in model.ff.image_encoder.parameters():
                param.requires_grad = False
                frozen_params += param.numel()
        
        if hasattr(model, 'ff') and hasattr(model.ff, 'text_encoder'):
            for param in model.ff.text_encoder.parameters():
                param.requires_grad = False
                frozen_params += param.numel()
        
        if hasattr(model, 'ff') and hasattr(model.ff, 'ts_encoder'):
            for param in model.ff.ts_encoder.parameters():
                param.requires_grad = False
                frozen_params += param.numel()

        if hasattr(model, 'ff') and hasattr(model.ff, 'static_encoder'):
            for param in model.ff.static_encoder.parameters():
                param.requires_grad = False
                frozen_params += param.numel()
        
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in model.parameters())

    else:
        print("Encoders not frozen, all parameters will be trained")
    
    train(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config['num_epochs'],
        lr=config['learning_rate'],
        device=config['device'],
        log_dir=config['log_dir'],
        experiment_name=config['experiment_name'],
        visualize=config['visualize'],
        task='regression'
    )
    
    print(f"Model and logs saved in: {config['log_dir']}")
    return True

def build_arg_parser():
    """Build the command-line parser for the multimodal training entry point.

    Every option keeps the name, type, default and help text of the original
    inline parser, so existing command lines keep working.

    Note: ``--task`` and ``--save_dir`` are parsed but are not forwarded to
    ``main_training`` by the ``__main__`` entry point below.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    # Local import: argparse is only needed when building the CLI.
    import argparse

    parser = argparse.ArgumentParser(description='Train multimodal model')
    parser.add_argument('--gpu', type=int, default=4, help='GPU device ID')
    parser.add_argument('--experiment_name', type=str, default='multimodal_training', help='Experiment name')
    parser.add_argument('--log_dir', type=str, default='runs/mimic', help='Log directory')
    parser.add_argument('--save_dir', type=str, default='runs/checkpoints', help='Model save directory')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=0, help='Data loader worker count')
    parser.add_argument('--epochs', type=int, default=10, help='Training epochs')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    # NOTE(review): the training function above appears to set
    # `freeze_encoders = True` unconditionally, so this flag may have no
    # effect there — confirm.
    parser.add_argument('--freeze_encoders', action='store_true', help='Freeze encoders')
    parser.add_argument('--shared_dim', type=int, default=64, help='Shared dimension')
    parser.add_argument('--text_dim', type=int, default=64, help='Text dimension')
    parser.add_argument('--image_dim', type=int, default=64, help='Image dimension')
    parser.add_argument('--ts_dim', type=int, default=64, help='Time series dimension')
    parser.add_argument('--static_dim', type=int, default=5, help='Static data dimension')
    parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads')
    parser.add_argument('--num_layers', type=int, default=2, help='Number of layers')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--task', type=str, default='regression', help='Task type')
    parser.add_argument('--task_name', type=str, default='debug', help='Specific task name')
    parser.add_argument('--visualize', action='store_true', help='Enable visualization')

    # Cache options. `--use_cache` is store_true with default=True, so it is
    # always True; `--no_cache` is the switch that actually disables caching.
    parser.add_argument('--use_cache', action='store_true', default=True, help='Use cached dataset for training (default enabled)')
    parser.add_argument('--no_cache', action='store_true', help='Disable cached dataset')
    parser.add_argument('--cache_dir', type=str, default=None, help='Cache directory')
    parser.add_argument('--force_rebuild_cache', action='store_true', help='Force rebuild cache')
    parser.add_argument('--cache_version', type=str, default='v1.0', help='Cache version')

    parser.add_argument('--config', type=str, default='configs/cache_debug.yaml', help='config')

    # Dataset split options.
    parser.add_argument('--train_ratio', type=float, default=0.7, help='Training set ratio')
    parser.add_argument('--val_ratio', type=float, default=0.15, help='Validation set ratio')
    parser.add_argument('--test_ratio', type=float, default=0.15, help='Test set ratio')
    parser.add_argument('--random_seed', type=int, default=42, help='Random seed')
    parser.add_argument('--force_rebuild_split', action='store_true', help='Force rebuild split')
    return parser


if __name__ == "__main__":
    # Make the project root (parent of this file's directory) importable.
    # `sys` and `os` are already imported at module level.
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if base_dir not in sys.path:
        sys.path.insert(0, base_dir)

    args = build_arg_parser().parse_args()

    # `--use_cache` is always True (see build_arg_parser), so caching is on
    # unless `--no_cache` is given.
    use_cache = args.use_cache and not args.no_cache

    print(f"Using GPU: {args.gpu}")
    print(f"Experiment name: {args.experiment_name}")
    print(f"Using cached dataset: {use_cache}")
    if use_cache:
        print(f"Cache directory: {args.cache_dir or 'data/cache'}")
        print(f"Force rebuild cache: {args.force_rebuild_cache}")
        print(f"Cache version: {args.cache_version}")

    success = main_training(
        gpu_id=args.gpu,
        exp_name=args.experiment_name,
        batch_size=args.batch_size,
        num_epochs=args.epochs,
        lr=args.lr,
        log_dir=args.log_dir,
        shared_dim=args.shared_dim,
        text_dim=args.text_dim,
        image_dim=args.image_dim,
        ts_dim=args.ts_dim,
        static_dim=args.static_dim,
        num_heads=args.num_heads,
        num_layers=args.num_layers,
        dropout=args.dropout,
        task_name=args.task_name,
        visualize=args.visualize,
        num_workers=args.num_workers,
        freeze_encoders=args.freeze_encoders,
        use_cache=use_cache,
        cache_dir=args.cache_dir,
        force_rebuild_cache=args.force_rebuild_cache,
        cache_version=args.cache_version,
        config_file=args.config,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        random_seed=args.random_seed,
        force_rebuild_split=args.force_rebuild_split
    )

    if success:
        print("\nTraining completed successfully!")
    else:
        print("\nTraining failed!")
        # Non-zero exit status so shell scripts / CI can detect the failure
        # (previously the process exited with 0 even on failure).
        sys.exit(1)