#!/usr/bin/env python3
"""
Train baseline models (TimesNet, iTransformer, PatchTST) for multi-task classification
Uses the same dataset setup and evaluation metrics as train_mortality_los_complete.py
"""

import os
import sys
import time
import json
import argparse
from datetime import datetime
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Subset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, confusion_matrix
from omegaconf import OmegaConf

# Resolve directories relative to this file:
#   BASE_DIR     -> directory containing this script
#   RELEASE_DIR  -> its parent (the release directory)
#   PROJECT_ROOT -> the release directory's parent
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
RELEASE_DIR = os.path.dirname(BASE_DIR)  # release directory
PROJECT_ROOT = os.path.dirname(RELEASE_DIR)  # project root

# Make local modules importable. Each insert(0, ...) goes to the front of
# sys.path, so RELEASE_DIR (inserted second) is searched before BASE_DIR.
sys.path.insert(0, BASE_DIR)
sys.path.insert(0, RELEASE_DIR)

# Locate `datapress`: prefer the copy bundled with the release, otherwise fall
# back to the project root. The package directory and its `Aligned`
# subdirectory are added so their modules are also importable by bare name.
if os.path.exists(os.path.join(RELEASE_DIR, 'datapress')):
    sys.path.insert(0, os.path.join(RELEASE_DIR, 'datapress'))
    sys.path.insert(0, os.path.join(RELEASE_DIR, 'datapress', 'Aligned'))
else:
    sys.path.insert(0, os.path.join(PROJECT_ROOT, 'datapress'))
    sys.path.insert(0, os.path.join(PROJECT_ROOT, 'datapress', 'Aligned'))
    # Package-qualified imports such as `from datapress.Aligned... import ...`
    # need the *parent* of `datapress` on sys.path. In this branch that parent
    # is PROJECT_ROOT, not RELEASE_DIR, so add it. Appending (not inserting)
    # keeps it lowest-priority, so it never shadows release-local modules.
    if PROJECT_ROOT not in sys.path:
        sys.path.append(PROJECT_ROOT)

# Optional third-party baseline repositories (used only if present under the project root).
ts_lib_path = os.path.join(PROJECT_ROOT, 'sota', 'repo', 'mimic', 'Time-Series-Library')
itransformer_path = os.path.join(PROJECT_ROOT, 'sota', 'repo', 'mimic', 'iTransformer')

# Note: MedicalDatasetWithLOS is imported conditionally to support offline-only releases
from datapress.Aligned.medical_dataset_wrapper import MedicalDatasetWrapper
from datapress.dataloader import collate_fn as collate_medical_batch

# Import baseline models.
# Time-Series-Library and the standalone iTransformer repo both expose model
# modules; move Time-Series-Library to the front of sys.path so its `models`
# package wins. The iTransformer path is removed first in case it was added earlier.
if os.path.exists(ts_lib_path):
    if itransformer_path in sys.path:
        sys.path.remove(itransformer_path)
    if ts_lib_path in sys.path:
        sys.path.remove(ts_lib_path)
    sys.path.insert(0, ts_lib_path)

# Sentinels: a backbone stays None when its import fails below;
# create_baseline_model raises ImportError when asked for a None backbone.
TimesNetModel = None
iTransformerModel = None
PatchTSTModel = None

# Try importing baseline models (each failure is tolerated independently).
try:
    from models.TimesNet import Model as TimesNetModel
except ImportError:
    pass

# iTransformer: try Time-Series-Library's `models.iTransformer` first, then fall
# back to the standalone repo, which uses a `model` (singular) package.
try:
    from models.iTransformer import Model as iTransformerModel
except ImportError:
    if os.path.exists(itransformer_path):
        try:
            sys.path.insert(0, itransformer_path)
            from model.iTransformer import Model as iTransformerModel
        except ImportError:
            pass

try:
    from models.PatchTST import Model as PatchTSTModel
except ImportError:
    pass

# Task keys looked up in both the batch label dict and the model's output dict.
TASK_NAMES = ['mortality_24h_48h', 'los_prediction_48h']

# Global state: per-feature time-series statistics (mean / std) computed from
# the training set and used for standardization. Filled in by
# train_baseline_model; None until then.
_ts_normalization_means = None
_ts_normalization_stds = None

def extract_logits_tensor(obj):
    """Recursively search a (possibly nested) model output for a tensor.

    A tensor is returned as-is. For a dict, the conventional keys
    ('logits', 'output', 'pred', 'classification', 'logit') are searched first,
    in that order, followed by every value in iteration order. Lists and tuples
    are searched element by element. The first tensor found wins; ``None`` is
    returned when the structure contains no tensor at all.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, dict):
        preferred_keys = ('logits', 'output', 'pred', 'classification', 'logit')
        candidates = [obj[k] for k in preferred_keys if k in obj]
        candidates.extend(obj.values())
    elif isinstance(obj, (list, tuple)):
        candidates = obj
    else:
        return None

    for candidate in candidates:
        found = extract_logits_tensor(candidate)
        if found is not None:
            return found
    return None

def calculate_classification_metrics(y_true, y_pred_proba, task_name="task", threshold=None):
    """Compute binary-classification metrics for a single task.

    Args:
        y_true: Ground-truth labels (tensor, ndarray or array-like). A 2-D
            ``[N, 2]`` input is read as one-hot (column 1 = positive) when rows
            sum to 1, otherwise as scores reduced by argmax; anything else is
            flattened and cast to int.
        y_pred_proba: Model outputs. ``[N, 2]`` inputs are softmaxed when any
            value lies outside [0, 1] and column 1 is used; other shapes are
            flattened and passed through a sigmoid when any value lies outside
            [0, 1] (i.e. they are then assumed to be logits).
        task_name: Used only in log/error messages.
        threshold: Decision threshold on the positive-class probability. If
            ``None`` and both classes are present, the value in
            {0.05, 0.10, ..., 0.50} that maximises F1 on *these* labels is
            chosen (note: this tunes on the data being evaluated). Otherwise 0.5.

    Returns:
        Dict with accuracy, precision, recall, f1, auroc, auprc, confusion
        counts (tn/fp/fn/tp), num_positive/num_negative/num_total and
        optimal_threshold. AUROC/AUPRC are 0.0 when only one class is present;
        all metrics are zero when there are no samples or computation fails.
    """
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred_proba, torch.Tensor):
        y_pred_proba = y_pred_proba.detach().cpu().numpy()
    y_true = np.asarray(y_true)
    y_pred_proba = np.asarray(y_pred_proba)
    if not np.issubdtype(y_pred_proba.dtype, np.floating):
        # torch.sigmoid/softmax below require a floating dtype.
        y_pred_proba = y_pred_proba.astype(np.float64)

    # Labels -> flat 0/1 ints.
    if y_true.ndim == 2 and y_true.shape[1] == 2:
        if np.allclose(y_true.sum(axis=1), 1.0):
            y_true_binary = y_true[:, 1].astype(int)
        else:
            y_true_binary = y_true.argmax(axis=1).astype(int)
    else:
        y_true_binary = y_true.flatten().astype(int)

    # Predictions -> positive-class probabilities. The `.size` guards avoid
    # max()/min() raising on empty arrays.
    if y_pred_proba.ndim == 2 and y_pred_proba.shape[1] == 2:
        if y_pred_proba.size and (y_pred_proba.max() > 1.0 or y_pred_proba.min() < 0.0):
            y_pred_proba = torch.softmax(torch.from_numpy(y_pred_proba), dim=-1).numpy()
        y_pred_proba_pos = y_pred_proba[:, 1]
    else:
        y_pred_proba_pos = y_pred_proba.flatten()
        if y_pred_proba_pos.size and (y_pred_proba_pos.max() > 1.0 or y_pred_proba_pos.min() < 0.0):
            y_pred_proba_pos = torch.sigmoid(torch.from_numpy(y_pred_proba_pos)).numpy()

    # Align lengths BEFORE anything compares labels with predictions; otherwise
    # the threshold search below would fail on every candidate.
    if len(y_true_binary) != len(y_pred_proba_pos):
        print(f"[WARN] {task_name}: {len(y_true_binary)} labels vs "
              f"{len(y_pred_proba_pos)} predictions; truncating to the shorter")
    min_len = min(len(y_true_binary), len(y_pred_proba_pos))
    y_true_binary = y_true_binary[:min_len]
    y_pred_proba_pos = y_pred_proba_pos[:min_len]

    def _zero_metrics(thr):
        # Shape-compatible result used when metrics cannot be computed.
        return {
            'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0,
            'auroc': 0.0, 'auprc': 0.0,
            'tn': 0, 'fp': 0, 'fn': 0, 'tp': 0,
            'num_positive': 0, 'num_negative': 0, 'num_total': 0,
            'optimal_threshold': float(thr),
        }

    if min_len == 0:
        print(f"[WARN] {task_name}: no samples to evaluate")
        return _zero_metrics(0.5 if threshold is None else threshold)

    has_both_classes = np.unique(y_true_binary).size > 1

    if threshold is None and has_both_classes:
        best_threshold, best_f1 = 0.5, 0.0
        # Rounded so the grid holds exact values (no 0.30000000000000004).
        for thresh in np.round(np.arange(0.05, 0.51, 0.05), 2):
            try:
                f1 = f1_score(y_true_binary, (y_pred_proba_pos > thresh).astype(int), zero_division=0)
            except ValueError:
                # e.g. labels outside {0, 1}; the main block below reports it.
                continue
            if f1 > best_f1:
                best_f1, best_threshold = f1, float(thresh)
        threshold = best_threshold
    elif threshold is None:
        threshold = 0.5

    y_pred_binary = (y_pred_proba_pos > threshold).astype(int)

    try:
        metrics = {
            'accuracy': float(accuracy_score(y_true_binary, y_pred_binary)),
            'precision': float(precision_score(y_true_binary, y_pred_binary, zero_division=0)),
            'recall': float(recall_score(y_true_binary, y_pred_binary, zero_division=0)),
            'f1': float(f1_score(y_true_binary, y_pred_binary, zero_division=0)),
        }

        # Ranking metrics are undefined with a single class; report 0.0.
        if has_both_classes:
            metrics['auroc'] = float(roc_auc_score(y_true_binary, y_pred_proba_pos))
            metrics['auprc'] = float(average_precision_score(y_true_binary, y_pred_proba_pos))
        else:
            metrics['auroc'] = 0.0
            metrics['auprc'] = 0.0

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs,
        # so e.g. an all-negative, all-correct batch still reports tn correctly.
        tn, fp, fn, tp = confusion_matrix(y_true_binary, y_pred_binary, labels=[0, 1]).ravel()
        metrics['tn'] = int(tn)
        metrics['fp'] = int(fp)
        metrics['fn'] = int(fn)
        metrics['tp'] = int(tp)

        metrics['num_positive'] = int(y_true_binary.sum())
        metrics['num_negative'] = int(len(y_true_binary) - y_true_binary.sum())
        metrics['num_total'] = int(len(y_true_binary))
        metrics['optimal_threshold'] = float(threshold)

    except Exception as e:
        print(f"Error calculating metrics for {task_name}: {e}")
        import traceback
        traceback.print_exc()
        metrics = _zero_metrics(threshold)

    return metrics


class BaselineModelWrapper(nn.Module):
    """Adapt a forecasting backbone (TimesNet / iTransformer / PatchTST) to multi-task classification.

    Per batch, the backbone's forecast output is pooled to one vector per
    sample, clipped and layer-normalised, projected to ``hidden_dim``,
    optionally fused with static features, and fed to one small MLP head per
    task that emits a single logit.

    NOTE(review): ``_base_output_norm``, ``_feature_proj``, ``_static_proj``,
    ``_final_proj`` and ``_final_norm`` are created lazily on the first forward
    pass, because their input sizes are only known then. Their parameters do
    not exist before that call. An optimizer built from ``model.parameters()``
    beforehand (as ``train_baseline_model`` does) will never update them.
    Likewise, a saved state_dict cannot be loaded into a fresh instance until
    one forward pass has run. Do a dry forward pass before creating the
    optimizer or loading weights.
    """

    # Sequence length TimesNet is configured with in create_baseline_model
    # (127 + pred_len 1 = 128); TimesNet inputs are padded/truncated to it.
    _TIMESNET_SEQ_LEN = 127

    def __init__(self, base_model_name, base_model, task_names, ts_dim=64, seq_len=100, hidden_dim=128):
        super().__init__()
        self.base_model_name = base_model_name
        self.base_model = base_model
        self.task_names = task_names
        self.ts_dim = ts_dim
        self.seq_len = seq_len
        self.hidden_dim = hidden_dim
        # Set after the first backbone failure so the fallback warning prints once.
        self._fallback_warned = False

        self.task_heads = nn.ModuleDict()
        for task_name in task_names:
            task_head = nn.Sequential(
                nn.Linear(hidden_dim, 64),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(64, 1),
            )
            # Small-gain init with zero bias keeps initial logits close to 0.
            output_layer = task_head[-1]
            nn.init.xavier_uniform_(output_layer.weight, gain=0.1)
            nn.init.constant_(output_layer.bias, 0.0)
            self.task_heads[task_name] = task_head

        print(f"Created {base_model_name} wrapper with {len(task_names)} task heads")

    def _run_backbone(self, ts_data):
        """Run the backbone on ``[B, seq_len, F]`` input and pool its output to 2-D.

        Unknown model names use mean pooling over time. If the backbone raises,
        the same mean pooling is used as a fallback and a warning is printed
        once, so a broken backbone does not silently turn into a mean-pooling model.
        """
        batch_size, _, num_features = ts_data.shape
        try:
            if self.base_model_name == 'TimesNet':
                target_len = self._TIMESNET_SEQ_LEN
                cur_len = ts_data.shape[1]
                if cur_len < target_len:
                    pad = torch.zeros(batch_size, target_len - cur_len, num_features,
                                      device=ts_data.device, dtype=ts_data.dtype)
                    x_enc = torch.cat([ts_data, pad], dim=1)
                else:
                    x_enc = ts_data[:, :target_len, :]
                x_dec = torch.zeros(batch_size, 1, num_features, device=ts_data.device, dtype=ts_data.dtype)
                output = self.base_model.forecast(x_enc.float(), None, x_dec.float(), None)
                if output.dim() == 3:
                    # Keep only the last predicted step.
                    return output[:, -1, :]
            elif self.base_model_name in ('iTransformer', 'PatchTST'):
                x_dec = torch.zeros(batch_size, 1, num_features, device=ts_data.device)
                output = self.base_model.forecast(ts_data, None, x_dec, None)
                if output.dim() == 3:  # [B, pred_len, F]
                    # Average over the prediction horizon -> [B, F].
                    return output.mean(dim=1)
            else:
                return ts_data.mean(dim=1)
            return output if output.dim() == 2 else output.view(batch_size, -1)
        except Exception as exc:
            if not getattr(self, '_fallback_warned', False):
                print(f"[WARN] {self.base_model_name} forward failed ({exc!r}); "
                      "falling back to mean pooling over time")
                self._fallback_warned = True
            return ts_data.mean(dim=1)

    def _extract_time_series_features(self, packed_ts, static_data, seq_lengths=None):
        """Turn a batch of time series (plus optional static data) into ``[B, hidden_dim]`` features.

        Args:
            packed_ts: A tensor, or a dict holding it under ``'ts_data'``.
                5-D input keeps index 0 of the last axis; 4-D input keeps
                index 0 of axis 1; the result must be 3-D ``[B, T, F]``.
            static_data: Optional tensor. 4-D input keeps ``[:, 0, :, 0]`` and
                3-D input keeps ``[:, 0, :]``.
            seq_lengths: Accepted for interface compatibility only. The
                backbones are called without time marks, so lengths are unused.

        Raises:
            ValueError: if ``ts_data`` is not 3-, 4- or 5-D.
        """
        ts_data = packed_ts['ts_data'] if isinstance(packed_ts, dict) else packed_ts

        if ts_data.dim() not in (3, 4, 5):
            raise ValueError(f"Unexpected ts_data shape: {ts_data.shape}")
        if ts_data.dim() == 5:
            ts_data = ts_data[..., 0]
        if ts_data.dim() == 4:
            ts_data = ts_data[:, 0, :, :]  # [B, T, F]

        # Zero-pad (at the end) or truncate the time axis to self.seq_len.
        batch_size, num_steps, num_features = ts_data.shape
        if num_steps < self.seq_len:
            pad = torch.zeros(batch_size, self.seq_len - num_steps, num_features, device=ts_data.device)
            ts_data = torch.cat([ts_data, pad], dim=1)
        elif num_steps > self.seq_len:
            ts_data = ts_data[:, :self.seq_len, :]

        output = self._run_backbone(ts_data)

        # Clip outliers to mean +/- 5 std (over the whole batch), then LayerNorm.
        # With a single element std() is NaN, which would poison the clip.
        if not hasattr(self, '_base_output_norm'):
            self._base_output_norm = nn.LayerNorm(output.shape[-1]).to(output.device)
        if output.numel() > 1:
            center, spread = output.mean(), output.std()
            output = torch.clamp(output, min=center - 5 * spread, max=center + 5 * spread)
        output = self._base_output_norm(output)

        # Project to hidden_dim.
        if output.shape[-1] != self.hidden_dim:
            if not hasattr(self, '_feature_proj'):
                self._feature_proj = nn.Linear(output.shape[-1], self.hidden_dim).to(output.device)
                nn.init.xavier_uniform_(self._feature_proj.weight, gain=0.1)
                nn.init.constant_(self._feature_proj.bias, 0.0)
            output = self._feature_proj(output)

        # Fuse static features: project to hidden_dim // 2, concatenate, project back.
        if static_data is not None:
            if static_data.dim() == 4:
                static_data = static_data[:, 0, :, 0]
            elif static_data.dim() == 3:
                static_data = static_data[:, 0, :]

            if not hasattr(self, '_static_proj'):
                self._static_proj = nn.Linear(static_data.shape[-1], self.hidden_dim // 2).to(static_data.device)
                nn.init.xavier_uniform_(self._static_proj.weight, gain=0.1)
                nn.init.constant_(self._static_proj.bias, 0.0)

            output = torch.cat([output, self._static_proj(static_data)], dim=-1)

            if not hasattr(self, '_final_proj'):
                self._final_proj = nn.Linear(self.hidden_dim + self.hidden_dim // 2, self.hidden_dim).to(output.device)
                nn.init.xavier_uniform_(self._final_proj.weight, gain=0.1)
                nn.init.constant_(self._final_proj.bias, 0.0)

            output = self._final_proj(output)

        if not hasattr(self, '_final_norm'):
            self._final_norm = nn.LayerNorm(self.hidden_dim).to(output.device)
        return self._final_norm(output)

    def forward(self, images, text_item, packed_ts, static_data, task=None):
        """Return ``{task_name: logits}``; heads emitting ``[B, 1]`` are squeezed to ``[B]``.

        ``images``, ``text_item`` and ``task`` are accepted to share the
        training loop's call signature and are ignored here.
        """
        seq_lengths = None
        if isinstance(packed_ts, dict) and 'seq_lengths' in packed_ts:
            seq_lengths = packed_ts['seq_lengths']

        features = self._extract_time_series_features(packed_ts, static_data, seq_lengths)

        task_outputs = {}
        for task_name in self.task_names:
            task_logits = self.task_heads[task_name](features)
            if task_logits.dim() == 2 and task_logits.shape[1] == 1:
                task_logits = task_logits.squeeze(-1)
            task_outputs[task_name] = task_logits
        return task_outputs


def create_baseline_model(model_name, ts_dim=64, seq_len=100, device='cuda'):
    """Build a forecasting backbone by name and wrap it with per-task classification heads.

    Args:
        model_name: One of 'TimesNet', 'iTransformer', 'PatchTST'.
        ts_dim: Number of input/output variates (``enc_in`` / ``c_out``).
        seq_len: Input length; the wrapper pads/truncates inputs to it.
            TimesNet itself is always configured with seq_len=127.
        device: Device the returned wrapper is moved to.

    Returns:
        A ``BaselineModelWrapper`` on ``device`` with one head per ``TASK_NAMES`` entry.

    Raises:
        ValueError: for an unknown ``model_name`` (checked before anything is built).
        ImportError: when the requested backbone failed to import at module load.
    """
    if model_name not in ('TimesNet', 'iTransformer', 'PatchTST'):
        raise ValueError(f"Unknown model name: {model_name}")

    class ModelConfig:
        """Attribute bag standing in for the argparse namespace the backbones read."""

        def __init__(self, **overrides):
            self.enc_in = ts_dim
            self.c_out = ts_dim
            self.seq_len = seq_len
            self.pred_len = 1
            self.label_len = 0
            self.task_name = 'long_term_forecast'
            # Some backbone versions read this from the config; unused ones ignore it.
            self.output_attention = False
            for key, value in overrides.items():
                setattr(self, key, value)

    # Per-backbone hyper-parameters layered over the shared defaults above.
    backbone_settings = {
        'TimesNet': dict(
            seq_len=127,  # 127 + pred_len (1) = 128
            d_model=128, e_layers=2, d_ff=256, top_k=2, num_kernels=6,
            embed='timeF', freq='h', dropout=0.1, num_class=1,
        ),
        'iTransformer': dict(
            d_model=128, n_heads=4, e_layers=2, d_ff=256, dropout=0.1, factor=1,
            activation='gelu', embed='timeF', freq='h',
            # NOTE(review): read by the standalone iTransformer repo's model
            # (fallback import path); Time-Series-Library ignores them. Values
            # mirror that repo's CLI defaults — confirm against the vendored copy.
            use_norm=True, class_strategy='projection',
        ),
        'PatchTST': dict(
            d_model=64, n_heads=4, e_layers=3, d_ff=128, dropout=0.2, factor=1,
            activation='gelu', patch_len=16, stride=8, individual=True,
            embed='timeF', freq='h',
        ),
    }

    model_cls = {
        'TimesNet': TimesNetModel,
        'iTransformer': iTransformerModel,
        'PatchTST': PatchTSTModel,
    }[model_name]
    if model_cls is None:
        raise ImportError(f"{model_name} not available")

    base_model = model_cls(ModelConfig(**backbone_settings[model_name]))
    wrapper = BaselineModelWrapper(model_name, base_model, TASK_NAMES, ts_dim, seq_len, hidden_dim=128)
    return wrapper.to(device)


def _val_target_to_vector(target):
    """Reduce a validation label tensor to a flat 1-D vector.

    3-D+ labels keep index 0 of axis 1. ``[B, 2]`` labels are argmax'd (read as
    two-class one-hot/scores), ``[B, 1]`` is squeezed, and wider 2-D labels keep
    column 0.

    NOTE(review): the training loop reduces 2-D targets with ``target[:, 0]``,
    which for ``[B, 2]`` one-hot labels selects a different column than the
    argmax here. Confirm both paths agree with the label format.
    """
    if target.dim() >= 3:
        target = target[:, 0, :]

    if target.dim() == 2:
        if target.shape[1] == 2:
            target = target.argmax(dim=-1).float()
        elif target.shape[1] == 1:
            target = target.squeeze(-1)
        else:
            target = target[:, 0]
    elif target.dim() > 2:
        # reshape (not view): the slice above may be non-contiguous.
        target = target.reshape(-1, target.shape[-1])
        if target.shape[-1] == 2:
            target = target.argmax(dim=-1).float()
        else:
            target = target.squeeze(-1)

    return target.reshape(-1)


def _val_logits_to_vector(task_logits):
    """Reduce per-task model output to one logit per sample (column 1 of a 2-wide output)."""
    if task_logits.dim() >= 3:
        task_logits = task_logits[:, 0]
    if task_logits.dim() == 2:
        if task_logits.shape[-1] == 2:
            task_logits = task_logits[:, 1]
        elif task_logits.shape[-1] == 1:
            task_logits = task_logits.squeeze(-1)
        else:
            task_logits = task_logits[:, 0]
    return task_logits.reshape(-1)


@torch.no_grad()
def validate(model, val_loader, criterions, device, writer, global_step, task_names=TASK_NAMES):
    """Evaluate ``model`` on ``val_loader``; log and return per-task losses and metrics.

    Args:
        model: Called as ``model(images, text_item, packed_ts, static_data)`` and
            expected to return a dict of per-task logits. Left in eval mode.
        val_loader: Yields ``(images, text_item, padded_time_series,
            time_series_lengths, static_data, labels)``. Batches whose
            ``labels`` is not a dict, or whose forward pass raises, are skipped.
        criterions: Mapping task name -> loss callable ``(logits, target)``.
        device: Device the batch tensors are moved to.
        writer: TensorBoard ``SummaryWriter`` receiving ``val/<task>/*`` scalars.
        global_step: Step recorded with every logged scalar.
        task_names: Tasks to evaluate.

    Returns:
        ``(avg_losses, metrics_dict)``. ``avg_losses[task]`` is the mean loss
        over the batches that actually produced a loss for that task (0.0 if
        none). ``metrics_dict[task]`` is the output of
        ``calculate_classification_metrics``, present only for tasks that
        received predictions. Those metrics use an F1-tuned threshold.
    """
    model.eval()
    loss_sums = {task: 0.0 for task in task_names}
    loss_counts = {task: 0 for task in task_names}
    all_predictions = {task: [] for task in task_names}
    all_targets = {task: [] for task in task_names}

    for batch in val_loader:
        images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
        if not isinstance(labels, dict):
            continue

        images = images.to(device)
        static_data = static_data.to(device)
        for k in text_item:
            text_item[k] = text_item[k].to(device)
        packed_ts = {
            'ts_data': padded_time_series.to(device),
            'seq_lengths': time_series_lengths.to(device),
        }

        try:
            outputs_dict = model(images, text_item, packed_ts, static_data)
        except Exception as e:
            print(f"Error in model forward: {e}")
            import traceback
            traceback.print_exc()
            continue

        for task_name in task_names:
            target = labels.get(task_name)
            if target is None or task_name not in outputs_dict:
                continue
            task_logits = outputs_dict[task_name]
            if not isinstance(task_logits, torch.Tensor):
                continue

            target = _val_target_to_vector(target.float().to(device))
            task_logits = _val_logits_to_vector(task_logits)

            # Skip if the model produced fewer logits than labels; trim extras.
            batch_size = target.shape[0]
            if task_logits.shape[0] < batch_size:
                continue
            task_logits = task_logits[:batch_size]

            loss = criterions[task_name](task_logits, target)
            loss_sums[task_name] += loss.item()
            # Per-task count: a task missing from some batches is not diluted.
            loss_counts[task_name] += 1

            all_predictions[task_name].append(task_logits.detach())
            all_targets[task_name].append(target.detach())

    avg_losses = {
        task: loss_sums[task] / loss_counts[task] if loss_counts[task] > 0 else 0.0
        for task in task_names
    }

    metrics_dict = {}
    for task_name in task_names:
        if all_predictions[task_name]:
            all_pred = torch.cat(all_predictions[task_name], dim=0)
            all_tgt = torch.cat(all_targets[task_name], dim=0)
            metrics_dict[task_name] = calculate_classification_metrics(all_tgt, all_pred, task_name, threshold=None)

    for task_name, metrics in metrics_dict.items():
        writer.add_scalar(f'val/{task_name}/loss', avg_losses[task_name], global_step)
        writer.add_scalar(f'val/{task_name}/accuracy', metrics['accuracy'], global_step)
        writer.add_scalar(f'val/{task_name}/f1', metrics['f1'], global_step)
        writer.add_scalar(f'val/{task_name}/auroc', metrics['auroc'], global_step)
        writer.add_scalar(f'val/{task_name}/auprc', metrics['auprc'], global_step)

    return avg_losses, metrics_dict


def train_baseline_model(
    model_name, model, train_loader, val_loader=None, test_loader=None, num_epochs=10, lr=1e-4, 
    device='cuda', log_dir="/ssd/0/wzq/Multi_Med/runs/classify", experiment_name=None, 
    pos_weights=None, task_names=TASK_NAMES
):
    global _ts_normalization_means, _ts_normalization_stds
    
    ts_all_values = []
    max_samples_for_stats = 5000
    sample_count = 0
    
    for batch_idx, batch in enumerate(train_loader):
        if sample_count >= max_samples_for_stats:
            break
        try:
            _, _, padded_time_series, _, _, _ = batch
            if padded_time_series.dim() == 5:
                ts_values = padded_time_series[:, :, :, :, 0]
            elif padded_time_series.dim() == 4:
                ts_values = padded_time_series
            else:
                continue
            
            ts_values_flat = ts_values.view(-1, ts_values.shape[-1])
            ts_all_values.append(ts_values_flat.cpu())
            sample_count += ts_values.shape[0]
        except Exception:
            continue
    
    if len(ts_all_values) > 0:
        ts_all = torch.cat(ts_all_values, dim=0)  # [N_total, F]
        _ts_normalization_means = ts_all.mean(dim=0)  # [F]
        _ts_normalization_stds = ts_all.std(dim=0)  # [F]
        
        for f in range(ts_all.shape[1]):
            feature_data = ts_all[:, f]
            if _ts_normalization_stds[f] > 100.0:
                median = feature_data.median()
                mad = (feature_data - median).abs().median()
                robust_std = mad * 1.4826
                if robust_std > 1e-6:
                    _ts_normalization_means[f] = median
                    _ts_normalization_stds[f] = robust_std
        
        _ts_normalization_stds = torch.clamp(_ts_normalization_stds, min=1e-6)
    else:
        _ts_normalization_means = None
        _ts_normalization_stds = None
    
    from torch.utils.data import DataLoader
    train_dataset = train_loader.dataset
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_loader.batch_size,
        shuffle=True,
        num_workers=train_loader.num_workers,
        pin_memory=train_loader.pin_memory,
        collate_fn=train_loader.collate_fn,
        persistent_workers=train_loader.persistent_workers if hasattr(train_loader, 'persistent_workers') else False,
        prefetch_factor=train_loader.prefetch_factor if hasattr(train_loader, 'prefetch_factor') else 2
    )
    
    model = model.to(device)
    effective_lr = lr * 0.1
    optimizer = torch.optim.AdamW(model.parameters(), lr=effective_lr, weight_decay=0.01)
    print(f"[INFO] Using effective learning rate: {effective_lr:.2e} (reduced from {lr:.2e} for stability)")
    
    if pos_weights is None:
        pos_weights = {}
        pos_weights['mortality_24h_48h'] = 0.02
        pos_weights['los_prediction_48h'] = 4.0
    
    criterions = {}
    for task_name in task_names:
        pos_weight = pos_weights.get(task_name, 1.0)
        pos_weight_tensor = torch.tensor(pos_weight, device=device)
        criterions[task_name] = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor, reduction='mean')
        print(f"{task_name}: Using BCEWithLogitsLoss with pos_weight={pos_weight:.4f}, reduction='mean'")
    
    # scaler = GradScaler()
    
    if experiment_name is None:
        experiment_name = f"{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    
    log_path = os.path.join(log_dir, experiment_name)
    os.makedirs(log_path, exist_ok=True)
    writer = SummaryWriter(log_path)
    
    best_val_loss = {task: float('inf') for task in task_names}
    best_val_metrics = {task: {} for task in task_names}
    
    print(f"\nStarting training: {experiment_name}")
    print(f"Log directory: {log_path}")
    print(f"Tasks: {task_names}")
    
    global_step = 0
    
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        model.train()
        epoch_losses = {task: 0.0 for task in task_names}
        epoch_batches = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        for batch_idx, batch in enumerate(pbar):
            images, text_item, padded_time_series, time_series_lengths, static_data, labels = batch
            images = images.to(device)
            static_data = static_data.to(device)
            padded_time_series = padded_time_series.to(device)
            time_series_lengths = time_series_lengths.to(device)
            for k in text_item:
                text_item[k] = text_item[k].to(device)
            
            if padded_time_series.dim() == 5:
                ts_values = padded_time_series[:, :, :, :, 0]  # [B, N_med, T, F]
                ts_mask = padded_time_series[:, :, :, :, 1]
                B, N_med, T, F = ts_values.shape
                ts_values_flat = ts_values.view(-1, F)
                if _ts_normalization_means is not None and _ts_normalization_stds is not None:
                    ts_means = _ts_normalization_means.to(device)
                    ts_stds = _ts_normalization_stds.to(device)
                    for f in range(F):
                        if f < len(ts_means) and ts_stds[f] > 1e-6:
                            ts_values_flat[:, f] = (ts_values_flat[:, f] - ts_means[f]) / ts_stds[f]
                else:
                    for f in range(F):
                        feature_data = ts_values_flat[:, f]
                        feature_mean = feature_data.mean()
                        feature_std = feature_data.std()
                        if feature_std > 1e-6:
                            ts_values_flat[:, f] = (feature_data - feature_mean) / feature_std
                ts_values = ts_values_flat.view(B, N_med, T, F)
                padded_time_series = torch.stack([ts_values, ts_mask], dim=-1)
            elif padded_time_series.dim() == 4:
                B, N_med, T, F = padded_time_series.shape
                ts_values_flat = padded_time_series.view(-1, F)
                if _ts_normalization_means is not None and _ts_normalization_stds is not None:
                    ts_means = _ts_normalization_means.to(device)
                    ts_stds = _ts_normalization_stds.to(device)
                    for f in range(F):
                        if f < len(ts_means) and ts_stds[f] > 1e-6:
                            ts_values_flat[:, f] = (ts_values_flat[:, f] - ts_means[f]) / ts_stds[f]
                else:
                    for f in range(F):
                        feature_data = ts_values_flat[:, f]
                        feature_mean = feature_data.mean()
                        feature_std = feature_data.std()
                        if feature_std > 1e-6:
                            ts_values_flat[:, f] = (feature_data - feature_mean) / feature_std
                padded_time_series = ts_values_flat.view(B, N_med, T, F)
            
            packed_ts = {
                'ts_data': padded_time_series,
                'seq_lengths': time_series_lengths
            }
            
            optimizer.zero_grad(set_to_none=True)
            
            try:
                outputs_dict = model(images, text_item, packed_ts, static_data)
            except Exception as e:
                print(f"Error in model forward: {e}")
                import traceback
                traceback.print_exc()
                continue
            
            total_loss = 0.0
            task_losses = {}
            
            for task_name in task_names:
                target = labels.get(task_name)
                if target is None or task_name not in outputs_dict:
                    continue
                
                target = target.float().to(device)
                if target.dim() >= 2:
                    target = target[:, 0]
                if target.dim() > 1:
                    if target.shape[-1] > 1:
                        target = target.argmax(dim=-1).float()
                    else:
                        target = target.squeeze(-1)
                target = target.view(-1)
                
                task_logits = outputs_dict[task_name]
                if not isinstance(task_logits, torch.Tensor):
                    continue
                
                if task_logits.dim() >= 3:
                    task_logits = task_logits[:, 0]
                if task_logits.dim() == 2:
                    if task_logits.shape[-1] == 2:
                        task_logits = task_logits[:, 1]
                    elif task_logits.shape[-1] == 1:
                        task_logits = task_logits.squeeze(-1)
                    else:
                        task_logits = task_logits[:, 0]
                task_logits = task_logits.view(-1)
                
                batch_size = target.shape[0]
                if task_logits.shape[0] != batch_size:
                    if task_logits.shape[0] > batch_size:
                        task_logits = task_logits[:batch_size]
                    else:
                        continue
                
                loss = criterions[task_name](task_logits, target)
                
                if torch.isnan(loss) or torch.isinf(loss):
                    continue
                
                task_losses[task_name] = loss
                total_loss = total_loss + loss
            
            if len(task_losses) == 0:
                continue
            
            if torch.isnan(total_loss) or torch.isinf(total_loss):
                optimizer.zero_grad(set_to_none=True)
                continue
            
            total_loss.backward()
            
            has_nan_inf_grad = False
            for name, param in model.named_parameters():
                if param.grad is not None:
                    if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                        has_nan_inf_grad = True
                        break
            
            if has_nan_inf_grad:
                optimizer.zero_grad(set_to_none=True)
                continue
            
            try:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            except RuntimeError as e:
                print(f"[ERROR] Error during gradient clipping: {e}, skipping optimizer step")
                optimizer.zero_grad(set_to_none=True)
                continue
            
            optimizer.step()
            
            for task_name, loss in task_losses.items():
                loss_value = loss.item()
                if not (torch.isnan(loss) or torch.isinf(loss)):
                    epoch_losses[task_name] += loss_value
            epoch_batches += 1
            
            pbar.set_postfix({
                **{f'{t}_Loss': f"{task_losses.get(t, 0):.4f}" for t in task_names},
                'Step': global_step
            })
            
            global_step += 1
        
        avg_losses = {task: epoch_losses[task] / epoch_batches if epoch_batches > 0 else 0.0 
                     for task in task_names}
        
        print(f"Epoch {epoch+1} - Avg Losses: {avg_losses}")
        
        if val_loader is not None:
            val_losses, val_metrics = validate(model, val_loader, criterions, device, writer, global_step, task_names)
            
            print(f"Validation:")
            for task_name in task_names:
                if task_name in val_metrics:
                    metrics = val_metrics[task_name]
                    print(f"  {task_name} - Loss: {val_losses[task_name]:.4f}, "
                          f"Acc: {metrics['accuracy']:.4f}, F1: {metrics['f1']:.4f}, "
                          f"AUROC: {metrics['auroc']:.4f}, AUPRC: {metrics['auprc']:.4f}")
            
            total_val_loss = sum(val_losses.values())
            total_best_loss = sum(best_val_loss.values())
            if total_val_loss < total_best_loss:
                for task_name in task_names:
                    best_val_loss[task_name] = val_losses[task_name]
                    if task_name in val_metrics:
                        best_val_metrics[task_name] = val_metrics[task_name]
                
                model_save_path = os.path.join(log_path, 'best_model.pth')
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_losses': val_losses,
                    'val_metrics': val_metrics
                }, model_save_path)
                print(f"New best model saved! Total Val Loss: {total_val_loss:.4f}")
    
    writer.close()
    
    best_model_path = os.path.join(log_path, 'best_model.pth')
    test_metrics = None
    
    if test_loader is not None and os.path.exists(best_model_path):
        print("\n" + "="*80)
        print("Evaluating on test set...")
        print("="*80)
        
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        
        test_writer = SummaryWriter(os.path.join(log_path, 'test'))
        
        test_losses, test_metrics = validate(model, test_loader, criterions, device, test_writer, global_step, task_names)
        test_writer.close()
        
        print(f"\nTest Set Results:")
        for task_name in task_names:
            if task_name in test_metrics:
                metrics = test_metrics[task_name]
                print(f"\n  {task_name}:")
                print(f"    Loss: {test_losses[task_name]:.4f}")
                print(f"    Accuracy: {metrics['accuracy']:.4f}")
                print(f"    F1: {metrics['f1']:.4f}")
                print(f"    AUROC: {metrics['auroc']:.4f}")
                print(f"    AUPRC: {metrics['auprc']:.4f}")
                print(f"    Precision: {metrics['precision']:.4f}")
                print(f"    Recall: {metrics['recall']:.4f}")
                print(f"    Confusion Matrix: TN={metrics['tn']}, FP={metrics['fp']}, FN={metrics['fn']}, TP={metrics['tp']}")
        
        test_report_path = os.path.join(log_path, 'test_results.json')
        with open(test_report_path, 'w') as f:
            json.dump({
                'test_metrics': {k: {mk: float(mv) if isinstance(mv, (int, float)) else str(mv) 
                                   for mk, mv in v.items()} 
                               for k, v in test_metrics.items()},
                'test_losses': {k: float(v) for k, v in test_losses.items()},
                'timestamp': datetime.now().isoformat()
            }, f, indent=2)
        print(f"\nTest results saved to: {test_report_path}")
    else:
        print("\nSkipping test set evaluation (test_loader is None or best_model.pth not found)")
    
    return best_val_metrics, best_model_path


if __name__ == "__main__":
    # Resolve the release directory from this script's location so that config
    # and data paths do not depend on the current working directory.
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
    RELEASE_DIR = os.path.dirname(SCRIPT_DIR)

    # Scalar types treated as numeric when printing / serializing metrics.
    # numpy scalars (e.g. np.int64 counts) are not int/float subclasses in every
    # case, so they are listed explicitly.
    _NUMERIC_TYPES = (int, float, np.integer, np.floating)

    def _metric_to_json(value):
        """Return *value* as a float when numeric, otherwise as its string form."""
        return float(value) if isinstance(value, _NUMERIC_TYPES) else str(value)

    parser = argparse.ArgumentParser(description='Train baseline models (TimesNet, iTransformer, PatchTST)')
    parser.add_argument('--model', type=str, required=True, choices=['TimesNet', 'iTransformer', 'PatchTST'],
                       help='Model to train')
    parser.add_argument('--gpu', type=str, default='0', help='GPU ID')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--batch_size', type=int, default=256, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers')
    parser.add_argument('--experiment_name', type=str, default=None, help='Experiment name')
    parser.add_argument('--max_samples', type=int, default=80000, help='Maximum number of samples for training (0 for full dataset)')
    parser.add_argument('--max_val_samples', type=int, default=10000, help='Maximum number of samples for validation (0 for full validation set)')
    parser.add_argument('--max_test_samples', type=int, default=10000, help='Maximum number of samples for testing (0 for full test set)')
    parser.add_argument('--log_dir', type=str, default=None, help='TensorBoard log directory')
    parser.add_argument('--config', type=str, default=None, help='Path to config file')

    args = parser.parse_args()

    # Only effective if CUDA has not been initialized yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Set default log_dir if not provided
    if args.log_dir is None:
        args.log_dir = os.path.join(RELEASE_DIR, 'runs', 'classify')

    print("="*80)
    print(f"Training {args.model} Baseline Model")
    print("="*80)

    # ---- Configuration -------------------------------------------------------
    if args.config:
        # Relative config paths are resolved against the release directory.
        config_path = args.config if os.path.isabs(args.config) else os.path.join(RELEASE_DIR, args.config)
    else:
        # Auto-detect: prefer the offline config when the sample data ships with the release.
        offline_data_path = os.path.join(RELEASE_DIR, "data_dir/sample_data/sample_data.pkl")
        offline_config = os.path.join(RELEASE_DIR, "exp/mimic_data/exp_mortality_24h48h_los_offline.yaml")
        default_config = os.path.join(RELEASE_DIR, "exp/mimic_data/exp_mortality_24h48h_los.yaml")

        if os.path.exists(offline_data_path) and os.path.exists(offline_config):
            config_path = offline_config
            print("Offline data detected, using offline configuration")
        else:
            config_path = default_config

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Configuration file not found: {config_path}")
    opt = OmegaConf.load(config_path)

    # ---- Dataset -------------------------------------------------------------
    print("\nCreating dataset...")

    offline_data_path = opt.data.train_val.get('offline_data_path', None)
    if offline_data_path:
        # Resolve relative path relative to release directory
        if not os.path.isabs(offline_data_path):
            offline_data_path = os.path.join(RELEASE_DIR, offline_data_path)

        if not os.path.exists(offline_data_path):
            # Offline data path specified but file doesn't exist
            raise FileNotFoundError(f"Offline data file not found: {offline_data_path}")

        print(f"Using offline sample data: {offline_data_path}")
        from datapress.Aligned.offline_sample_dataset import OfflineSampleDataset
        med_dataset = OfflineSampleDataset(offline_data_path)
        wrapped_dataset = MedicalDatasetWrapper(med_dataset)
        print(f"Offline dataset size: {len(wrapped_dataset)}")
    else:
        # Database-backed dataset; the module may be absent from offline-only releases.
        try:
            from datapress.Aligned.medical_dataset_with_los import MedicalDatasetWithLOS
        except ImportError as err:
            raise ImportError(
                "MedicalDatasetWithLOS not available. This appears to be an offline-only release. "
                "Please use offline data by setting 'offline_data_path' in the configuration file, "
                "or restore the medical_dataset_with_los.py file for database mode."
            ) from err
        # NOTE(review): los_cache_file is relative to the current working
        # directory, unlike the other paths here — confirm this is intended.
        med_dataset = MedicalDatasetWithLOS(
            **opt.data.train_val,
            **opt.data.shared_param,
            los_cache_file="data_dir/cache/los_cache.pkl",
            threshold_hours=48
        )
        print(f"Medical dataset size: {len(med_dataset)}")
        wrapped_dataset = MedicalDatasetWrapper(med_dataset)
        print(f"Wrapped dataset size: {len(wrapped_dataset)}")

    # 80 / 10 / 10 split (val and test share the remainder), seeded for reproducibility.
    total_size = len(wrapped_dataset)
    train_size = int(0.8 * total_size)
    remaining_size = total_size - train_size
    val_size = int(0.5 * remaining_size)
    test_size = remaining_size - val_size

    train_full, val_full, test_full = random_split(
        wrapped_dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    print(f"\nDataset splits: Train={len(train_full)}, Val={len(val_full)}, Test={len(test_full)}")

    def _take_first(dataset, limit, split_name):
        """Return the first *limit* items of *dataset* when 0 < limit < len(dataset).

        random_split already shuffled the indices, so the prefix is a random sample.
        """
        if 0 < limit < len(dataset):
            subset = Subset(dataset, list(range(limit)))
            print(f"Using subset of {split_name} dataset: {len(subset)} samples")
            return subset
        return dataset

    train_dataset = _take_first(train_full, args.max_samples, "training")
    val_dataset = _take_first(val_full, args.max_val_samples, "validation")
    test_dataset = _take_first(test_full, args.max_test_samples, "test")

    def _make_loader(dataset, shuffle):
        """Build a DataLoader with the shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=collate_medical_batch
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    print(f"\nDataLoader sizes: Train={len(train_loader)} batches, Val={len(val_loader)} batches, Test={len(test_loader)} batches")

    # ---- Model ---------------------------------------------------------------
    print(f"\nCreating {args.model} model...")

    # NOTE(review): ts_dim / seq_len are hard-coded; confirm they match the dataset.
    model = create_baseline_model(args.model, ts_dim=97, seq_len=167, device=device)

    # ---- Class-imbalance weights ---------------------------------------------
    print("\nCalculating pos_weights from training data...")
    # Estimate label balance from roughly the first `max_label_samples` labels
    # per task rather than the whole training set.
    max_label_samples = 1000
    label_counts = {task_name: {'pos': 0, 'neg': 0, 'total': 0} for task_name in TASK_NAMES}

    # One pass over the loader serves every task, so the shuffled multi-worker
    # DataLoader is started only once.
    for batch in train_loader:
        _, _, _, _, _, labels = batch
        if isinstance(labels, dict):
            for task_name, counts in label_counts.items():
                if task_name not in labels or counts['total'] >= max_label_samples:
                    continue
                target = labels[task_name].float()
                # Two-column labels: keep the second column (treated as the positive class).
                if target.dim() == 2 and target.shape[1] == 2:
                    target = target[:, 1]
                target = target.view(-1)
                counts['pos'] += int((target == 1).sum().item())
                counts['neg'] += int((target == 0).sum().item())
                counts['total'] += len(target)
        # Stop once every task has enough labels or has not appeared at all;
        # otherwise an absent task would force a scan of the whole training set.
        if all(c['total'] >= max_label_samples or c['total'] == 0 for c in label_counts.values()):
            break

    pos_weights = {}
    for task_name in TASK_NAMES:
        counts = label_counts[task_name]
        pos_count, neg_count, total_count = counts['pos'], counts['neg'], counts['total']
        if total_count > 0 and pos_count > 0:
            pos_weight = neg_count / pos_count
            pos_weights[task_name] = pos_weight
            pos_ratio = pos_count / total_count
            print(f"  {task_name}: Pos={pos_count} ({pos_ratio*100:.2f}%), Neg={neg_count}, pos_weight={pos_weight:.4f}")
        else:
            # Fallback weights when no positive labels were observed.
            # NOTE(review): a pos_weight < 1 down-weights positives; confirm 0.02 is intended.
            pos_weights[task_name] = 0.02 if task_name == 'mortality_24h_48h' else 4.0
            print(f"  {task_name}: Using default pos_weight={pos_weights[task_name]:.4f}")

    # ---- Training ------------------------------------------------------------
    print("\nStarting training...")
    experiment_name = args.experiment_name or f"{args.model.lower()}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    result = train_baseline_model(
        args.model, model, train_loader, val_loader=val_loader, test_loader=test_loader,
        num_epochs=args.epochs, lr=args.lr, device=device,
        log_dir=args.log_dir,
        experiment_name=experiment_name,
        pos_weights=pos_weights,
        task_names=TASK_NAMES
    )

    if isinstance(result, tuple):
        best_val_metrics, best_model_path = result
    else:
        best_val_metrics = result
        best_model_path = None
    best_val_metrics = best_val_metrics or {}

    # train_baseline_model writes best_model.pth and test_results.json into the
    # same run directory, so derive that directory from the returned checkpoint
    # path instead of re-deriving it from log_dir / experiment_name.
    if best_model_path:
        run_dir = os.path.dirname(best_model_path)
    else:
        run_dir = os.path.join(args.log_dir, experiment_name)

    # ---- Summary -------------------------------------------------------------
    print("\n" + "="*80)
    print(f"Final Summary - {args.model}")
    print("="*80)
    print("\nBest Validation Metrics:")
    for task_name in TASK_NAMES:
        if task_name in best_val_metrics:
            print(f"\n  {task_name}:")
            for key, value in best_val_metrics[task_name].items():
                if isinstance(value, _NUMERIC_TYPES):
                    print(f"    {key}: {value:.4f}")

    # Read the test results once; they are both printed and embedded in the report.
    test_results_path = os.path.join(run_dir, 'test_results.json')
    test_results = None
    if os.path.exists(test_results_path):
        with open(test_results_path, 'r', encoding='utf-8') as f:
            test_results = json.load(f)

        print(f"\n{'='*80}")
        print("Test Set Results:")
        print("="*80)
        for task_name in TASK_NAMES:
            if task_name in test_results['test_metrics']:
                metrics = test_results['test_metrics'][task_name]
                print(f"\n  {task_name}:")
                print(f"    Loss: {test_results['test_losses'][task_name]:.4f}")
                print(f"    Accuracy: {metrics['accuracy']:.4f}")
                print(f"    F1: {metrics['f1']:.4f}")
                print(f"    AUROC: {metrics['auroc']:.4f}")
                print(f"    AUPRC: {metrics['auprc']:.4f}")
                print(f"    Precision: {metrics['precision']:.4f}")
                print(f"    Recall: {metrics['recall']:.4f}")

    report_path = os.path.join(run_dir, 'final_report.json')
    report = {
        'model': args.model,
        'experiment_name': experiment_name,
        'config': {
            'max_samples': args.max_samples,
            'max_val_samples': args.max_val_samples,
            'max_test_samples': args.max_test_samples,
            'epochs': args.epochs,
            'batch_size': args.batch_size,
            'lr': args.lr
        },
        'best_validation_metrics': {k: {mk: _metric_to_json(mv) for mk, mv in v.items()}
                                    for k, v in best_val_metrics.items()},
        'timestamp': datetime.now().isoformat()
    }

    if test_results is not None:
        report['test_metrics'] = test_results['test_metrics']
        report['test_losses'] = test_results['test_losses']

    # The run directory may not exist if training never wrote to it.
    os.makedirs(run_dir, exist_ok=True)
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    print(f"\nFinal report saved to: {report_path}")

