import os
import sys
import gc
import psutil

def check_system_memory():
    """Report system RAM and decide whether low-memory optimizations apply.

    Prints the available and total memory in GB as reported by
    ``psutil.virtual_memory()``, plus a warning when less than 4 GB is
    available.

    Returns:
        True when less than 4.0 GB is available, otherwise False.
    """
    bytes_per_gb = 1024**3
    vm = psutil.virtual_memory()
    available_gb = vm.available / bytes_per_gb
    total_gb = vm.total / bytes_per_gb

    print(f"System Memory: {available_gb:.1f}GB available / {total_gb:.1f}GB total")

    is_low = available_gb < 4.0
    if is_low:
        print("WARNING: Low system memory detected. Applying optimizations...")
    return is_low

# Probe system memory once at import time; the result gates the low-memory
# code paths used throughout the rest of this module.
low_memory = check_system_memory()

# These variables are set before `import torch` below so they are already in
# the environment when torch is loaded. They are applied on low-memory systems
# and on Windows (os.name == 'nt').
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 presumably makes kernel launches
# synchronous, which would slow training — confirm this is intended.
if low_memory or os.name == 'nt':
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    os.environ['TORCH_USE_CUDA_DSA'] = '1'

# Import torch and choose the compute device. DEVICE becomes "cuda" only when
# a tiny test allocation on the GPU succeeds; otherwise it is "cpu". Any
# failure while importing torch terminates the process with exit code 1.
try:
    import torch
    if torch.cuda.is_available():
        try:
            # Smoke test: allocate and free one element on the GPU.
            test_tensor = torch.zeros(1).cuda()
            del test_tensor
            torch.cuda.empty_cache()
            DEVICE = "cuda"
            print("CUDA available and working")
        except Exception as e:
            print(f"CUDA available but not working: {e}")
            print("Falling back to CPU mode")
            DEVICE = "cpu"
    else:
        DEVICE = "cpu"
        print("CUDA not available, using CPU")
except Exception as e:
    print(f"Error importing torch: {e}")
    print("This might be due to insufficient virtual memory.")
    print("Please try the system-level fixes mentioned in the console.")
    sys.exit(1)

import json
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import List, Dict, Any, Optional

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image

from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer, EvalPrediction, PreTrainedModel, EarlyStoppingCallback
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
import safetensors

from sklearn.metrics import (
    roc_auc_score,
    f1_score,
    recall_score,
    precision_score,
)
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from tabulate import tabulate

import math
import random
import logging
import numpy as np

from dataclasses import dataclass

# Set up logging: INFO level, timestamped records, one named module logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('training_mimic_on_chexpert_optimized')

# ==========================
# GAZE/ALIGNMENT CONFIG TOGGLES (no CLI args)
# ==========================
# Feature switches read via CONFIG.get(...) elsewhere in this file; edit the
# values here to enable or disable a feature.
CONFIG = {
    "USE_GAZE_TOKEN": True,
    "USE_GAZE_WEIGHTED_POOLING": True,
    "USE_SOFT_GATING": True,
    "USE_AGREEMENT_LOSS": True,
    "USE_MULTI_SCALE_SUPERVISION": True,
    "USE_TEMPORAL_CONFIRM_FOCUS": True,
    "USE_ANATOMY_MASKING": True,
    "USE_PHRASE_REGION_GROUNDING": True,
    "SAVE_OVERLAYS": True,
}

# Condition-aware gaze strength (higher = stronger gaze influence)
# Eight entries; presumably in the same order as CONDITIONS below — TODO confirm.
COND_GAZE_WEIGHTS = torch.tensor([
    0.8,
    0.6,
    0.9,
    0.7,
    0.4,
    1.0,
    0.7,
    1.0,
], dtype=torch.float32)

############################
# 1) Fixed label list (8 conditions as requested)
############################
CONDITIONS = [
    'Atelectasis', 'Cardiomegaly', 'Edema',
    'Lung Opacity', 'No Finding', 'Pleural Effusion',
    'Pneumonia', 'Support Devices'
]

# Maps a condition name to its index in CONDITIONS (and in label vectors).
cond2idx = {c: i for i, c in enumerate(CONDITIONS)}

#####################
# === CONSTANTS === #
#####################
# CSV files holding the train/val/test splits.
TRAIN_CSV    = "dataset_splits/train.csv"
VAL_CSV      = "dataset_splits/val.csv"
TEST_CSV     = "dataset_splits/test.csv"

# NPZ file with fixation statistics; loaded below using keys "mu" and "sigma".
FIXSTATS_NPZ = "fixstats.npz"

# Directories with the per-sample images, bbox masks and fixation sequences.
IMG_DIR      = os.path.join("data_dump", "output", "img_png")
BBOX_DIR     = os.path.join("data_dump", "output", "bbox_mask")
FIXSEQ_DIR   = os.path.join("data_dump", "output", "fix_seq")

# Hugging Face model identifiers for the image and text encoders.
VIT_MODEL_NAME = "google/vit-base-patch16-224-in21k"
TEXT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

# Safetensors checkpoint used as the source of CheXpert pre-trained weights.
CHEXPERT_MODEL_PATH = os.path.join("new_scripts", "output", "vit_chexpert_training_output", "model_1.chexpert_finetune_only_feat_image", "model.safetensors")

OUTPUT_DIR   = os.path.join("new_scripts", "output", "0.(full+enhanced_gaze)_training_mimic_on_chexpert_optimized")
EXP_NAME     = "0.(full+enhanced_gaze)_training_mimic_on_chexpert_optimized"

# Memory-optimized hyperparameters
# Low-memory or CPU runs use a smaller batch, no DataLoader workers and no FP16.
if low_memory or DEVICE == "cpu":
    BATCH_SIZE   = 8
    GRAD_ACCUM   = 12
    DATALOADER_WORKERS = 0
    FP16_ENABLED = False
else:
    BATCH_SIZE   = 32
    GRAD_ACCUM   = 6
    DATALOADER_WORKERS = 2
    # This branch is only reached when DEVICE is not "cpu".
    FP16_ENABLED = DEVICE == "cuda"

# Optimizer / schedule hyperparameters.
LR           = 6e-6
WD           = 8e-4
EPOCHS       = 40
EARLY_STOPPING_PATIENCE = 10
WARMUP_STEPS = 175

# Seed every RNG the pipeline draws from. CXRDataset's geometric augmentations
# use the stdlib `random` module (random.random / random.uniform), so it has
# to be seeded alongside torch and NumPy for runs to be reproducible.
SEED         = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer deterministic cuDNN kernels over auto-tuned ones on GPU.
if DEVICE == "cuda":
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Make sure the output directory exists before anything tries to write to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

#############################
# === LOAD & FIX STATS ===  #
#############################
# Load fixation statistics (keys "mu" and "sigma"). On any failure fall back to
# zero mean / unit sigma for 4 features.
try:
    # np.load on an .npz returns a lazily-read archive that holds an open file
    # handle; the context manager closes it once both arrays have been read.
    with np.load(FIXSTATS_NPZ) as stats:
        MU, SIGMA = stats["mu"], stats["sigma"]
    # Small offset keeps SIGMA strictly non-zero.
    SIGMA = SIGMA + 1e-6
except Exception as e:
    logger.error(f"Error loading fixation stats: {e}")
    logger.warning("Using default fixation stats")
    MU, SIGMA = np.zeros(4), np.ones(4)

#########################
# === DATASET DEF ===   #
#########################
class CXRDataset(Dataset):
    """Dataset yielding image, bbox mask, fixation sequence, transcript and labels.

    Rows without a ``bbox_path``, ``fixations_path`` or ``transcript_path`` are
    dropped at construction. Each item is a dict with the keys ``img``,
    ``bbox``, ``fix_seq``, ``fix_mask``, ``transcript``, ``labels`` and
    ``skip_gaze``. Loading failures for any modality are logged and replaced
    with an empty placeholder so a single bad file does not abort training.

    Geometric augmentation (random resized crop, horizontal flip, small
    rotation, small translation) is applied only for the "train" split when
    the module-level ``low_memory`` flag is False, and is applied consistently
    to the image, the mask and the fixation coordinates.
    """

    def __init__(self, df: pd.DataFrame, split: str):
        """Filter ``df`` to complete rows and build the transforms.

        Args:
            df: Table with at least the columns ``image_path``, ``bbox_path``,
                ``fixations_path``, ``transcript_path`` and ``cond_vec``.
            split: Split name; only ``"train"`` enables augmentation.
        """
        self.df = df.reset_index(drop=True)
        self.split = split

        # Keep only rows for which every auxiliary modality is referenced.
        self.df = self.df[
            self.df["bbox_path"].notna() &
            self.df["fixations_path"].notna() &
            self.df["transcript_path"].notna()
        ].reset_index(drop=True)

        self.use_geo_aug = split == "train" and not low_memory

        self.photometric_tf = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1) if self.use_geo_aug else None
        # NOTE: constructed but not applied anywhere in this class.
        self.random_erasing = transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)) if self.use_geo_aug else None

        self.to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
        ])

        # Fallback mask conversion, used only if TF.to_tensor fails on the mask.
        self.mask_tf = transforms.Compose([
            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.df)

    def _load_image(self, row):
        """Load the RGB image for ``row``; return a blank 224x224 image on failure."""
        try:
            img_path = os.path.join(IMG_DIR, Path(row["image_path"]).stem + ".png")
            return Image.open(img_path).convert("RGB")
        except Exception as e:
            logger.error(f"Error loading image {row['image_path']}: {e}")
            return Image.new("RGB", (224, 224))

    def _load_bbox(self, row, size):
        """Load the grayscale bbox mask; return a blank mask of ``size`` on failure."""
        try:
            bbox_path = os.path.join(BBOX_DIR, Path(row["bbox_path"]).stem + ".png")
            return Image.open(bbox_path).convert("L")
        except Exception as e:
            logger.error(f"Error loading mask {row['bbox_path']}: {e}")
            return Image.new("L", (size[0], size[1]))

    def _load_fixations(self, row):
        """Load the fixation sequence and validity mask from the row's npz file.

        Non-finite values are zeroed and values clamped to [-10, 10]. On
        failure returns an all-zero (128, 4) sequence with an all-False mask.
        """
        try:
            fix_path = os.path.join(FIXSEQ_DIR, Path(row["fixations_path"]).stem + ".npz")
            # The context manager closes the archive even if a key is missing
            # or a conversion raises.
            with np.load(fix_path) as arr:
                seq = arr["seq"].astype(np.float32)
                mask = arr["mask"].astype(bool)

            seq = np.nan_to_num(seq, nan=0.0, posinf=0.0, neginf=0.0)
            seq = torch.from_numpy(seq).float().clamp(-10.0, 10.0)
            mask = torch.from_numpy(mask)
            return seq, mask
        except Exception as e:
            logger.error(f"Error loading fixation sequence {row['fixations_path']}: {e}")
            return torch.zeros(128, 4), torch.zeros(128, dtype=torch.bool)

    def _load_transcript(self, row):
        """Read the ``"transcript"`` field of the row's JSON file ("" if unavailable)."""
        try:
            transcript_path = row["transcript_path"]
            if pd.isna(transcript_path) or not os.path.exists(transcript_path):
                return ""
            with open(transcript_path, 'r', encoding='utf-8') as f:
                transcript_data = json.load(f)
            return transcript_data.get("transcript", "")
        except Exception as e:
            logger.error(f"Error loading transcript {row.get('transcript_path', 'N/A')}: {e}")
            return ""

    def __getitem__(self, idx):
        """Return the sample dict for row ``idx`` (see class docstring for keys)."""
        row = self.df.iloc[idx]

        img_raw = self._load_image(row)
        bbox_raw = self._load_bbox(row, img_raw.size)
        seq, mask = self._load_fixations(row)
        transcript_text = self._load_transcript(row)

        labels = torch.tensor(row["cond_vec"], dtype=torch.float32)

        img_pil, bbox_pil, seq, mask = self._apply_geo_consistent_transforms(img_raw, bbox_raw, seq, mask)

        # Fewer than three in-frame fixations: flag the sample so gaze can be skipped.
        skip_gaze = bool(mask.sum().item() < 3)

        if self.photometric_tf is not None:
            img_pil = self.photometric_tf(img_pil)

        img = self.to_tensor_norm(img_pil)

        try:
            bbox = TF.to_tensor(bbox_pil)
        except Exception:
            bbox = self.mask_tf(bbox_pil)

        if low_memory and idx % 100 == 0:
            gc.collect()

        return {
            "img": img,
            "bbox": bbox,
            "fix_seq": seq,
            "fix_mask": mask,
            "transcript": transcript_text,
            "labels": labels,
            "skip_gaze": torch.tensor(1 if skip_gaze else 0, dtype=torch.int8)
        }

    def _apply_geo_consistent_transforms(self, img: Image.Image, bbox_img: Image.Image, seq: torch.Tensor, fix_mask: torch.Tensor):
        """Apply identical geometric transforms to image, mask and fixation coordinates.

        Columns 0 and 1 of ``seq`` are treated as x/y in [-1, 1], mapped to
        pixel indices, pushed through the same resize/crop/flip/rotate/translate
        as the image, and mapped back to [-1, 1] on the 224x224 output.
        Fixations that land outside the output frame are cleared in the mask.

        Returns:
            ``(image, bbox, seq_out, fix_mask)`` with 224x224 PIL images.
        """
        try:
            W0, H0 = img.size
        except Exception:
            W0, H0 = 224, 224

        if isinstance(seq, torch.Tensor) and seq.ndim == 2 and seq.size(1) >= 2:
            x_norm = seq[:, 0].clone()
            y_norm = seq[:, 1].clone()
        else:
            x_norm = torch.zeros_like(fix_mask, dtype=torch.float32)
            y_norm = torch.zeros_like(fix_mask, dtype=torch.float32)

        # [-1, 1] -> pixel indices of the original image.
        x_pix = (x_norm * 0.5 + 0.5) * (W0 - 1)
        y_pix = (y_norm * 0.5 + 0.5) * (H0 - 1)

        # Every sample is first brought to 256x256.
        image = TF.resize(img, [256, 256])
        bbox = TF.resize(bbox_img, [256, 256], interpolation=transforms.InterpolationMode.NEAREST)
        x_pix = x_pix * (256.0 / W0)
        y_pix = y_pix * (256.0 / H0)

        if self.use_geo_aug:
            i, j, h, w = transforms.RandomResizedCrop.get_params(image, scale=(0.85, 1.0), ratio=(0.9, 1.1))
            image = TF.resized_crop(image, i, j, h, w, size=[224, 224])
            bbox = TF.resized_crop(bbox, i, j, h, w, size=[224, 224], interpolation=transforms.InterpolationMode.NEAREST)
            x_pix = (x_pix - j) * (224.0 / w)
            y_pix = (y_pix - i) * (224.0 / h)

            if random.random() < 0.5:
                image = TF.hflip(image)
                bbox = TF.hflip(bbox)
                x_pix = (224.0 - 1.0) - x_pix

            angle = random.uniform(-5.0, 5.0)
            if abs(angle) > 1e-3:
                image = TF.rotate(image, angle, interpolation=transforms.InterpolationMode.BILINEAR, expand=False)
                bbox = TF.rotate(bbox, angle, interpolation=transforms.InterpolationMode.NEAREST, expand=False)
                # TF.rotate turns the picture counter-clockwise. In image
                # coordinates (y pointing down) a point therefore moves as
                #   x' =  x*cos + y*sin,   y' = -x*sin + y*cos
                # about the image centre, which in pixel-index coordinates is
                # (224 - 1) / 2 = 111.5 (consistent with the hflip above).
                rad = math.radians(angle)
                cos_t, sin_t = math.cos(rad), math.sin(rad)
                center = (224.0 - 1.0) / 2.0
                x_shift = x_pix - center
                y_shift = y_pix - center
                x_pix = x_shift * cos_t + y_shift * sin_t + center
                y_pix = -x_shift * sin_t + y_shift * cos_t + center

            tx = (random.uniform(-0.05, 0.05)) * 224.0
            ty = (random.uniform(-0.05, 0.05)) * 224.0
            # The image can only be shifted by whole pixels; move the fixations
            # by exactly the same integer offsets so they stay aligned.
            tx_px, ty_px = int(tx), int(ty)
            if tx_px != 0 or ty_px != 0:
                image = TF.affine(image, angle=0.0, translate=(tx_px, ty_px), scale=1.0, shear=[0.0, 0.0])
                bbox = TF.affine(bbox, angle=0.0, translate=(tx_px, ty_px), scale=1.0, shear=[0.0, 0.0])
                x_pix = x_pix + tx_px
                y_pix = y_pix + ty_px
        else:
            image = TF.resize(image, [224, 224])
            bbox = TF.resize(bbox, [224, 224], interpolation=transforms.InterpolationMode.NEAREST)
            x_pix = x_pix * (224.0 / 256.0)
            y_pix = y_pix * (224.0 / 256.0)

        # Fixations pushed outside the 224x224 frame are no longer valid.
        valid = (x_pix >= 0.0) & (x_pix <= 223.0) & (y_pix >= 0.0) & (y_pix <= 223.0)
        if not isinstance(fix_mask, torch.Tensor) or fix_mask.numel() != valid.numel():
            fix_mask = valid.clone()
        else:
            fix_mask = fix_mask & valid

        # Pixel indices -> [-1, 1] on the output image.
        x_norm_new = ((x_pix / 223.0) - 0.5) * 2.0
        y_norm_new = ((y_pix / 223.0) - 0.5) * 2.0

        if isinstance(seq, torch.Tensor) and seq.ndim == 2 and seq.size(1) >= 4:
            seq_out = seq.clone()
        else:
            # Malformed input: rebuild a 4-column sequence carrying only x/y.
            seq_out = torch.zeros(x_norm_new.numel(), 4, dtype=torch.float32)
        seq_out[:, 0] = x_norm_new
        seq_out[:, 1] = y_norm_new

        return image, bbox, seq_out, fix_mask

###############################
# === LOAD & SPLIT DATA ===   #
###############################
# Read the train, validation and test split tables, in that order.
train_df, val_df, test_df = (
    pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV, TEST_CSV)
)

def make_cond_vec(cs: str):
    """Convert a condition specification into a multi-hot label vector.

    Args:
        cs: Either a list of condition names, or a string of names separated
            by ``'|'`` (preferred when present) or ``','``. Anything else
            (NaN, None, other types) is treated as "no conditions".

    Returns:
        A list of 0/1 ints of length ``len(CONDITIONS)``; names that are not
        in ``cond2idx`` are ignored.
    """
    # The type is checked first: calling pd.isna() on a list returns an array,
    # and using that array in a boolean context raises for multi-element
    # lists, so lists must never reach a scalar NaN test.
    if isinstance(cs, list):
        conds = cs
    elif isinstance(cs, str):
        # Splitting on a separator that does not occur yields the whole
        # string, which covers the single-condition and empty-string cases.
        separator = '|' if '|' in cs else ','
        conds = [c.strip() for c in cs.split(separator)]
    else:
        conds = []

    vec = [0] * len(CONDITIONS)
    for c in conds:
        if c in cond2idx:
            vec[cond2idx[c]] = 1
    return vec

# Attach a multi-hot label vector (one slot per entry of CONDITIONS) to every
# row, derived from the "condition" column of each split.
train_df["cond_vec"] = train_df["condition"].apply(make_cond_vec)
val_df["cond_vec"] = val_df["condition"].apply(make_cond_vec)
test_df["cond_vec"] = test_df["condition"].apply(make_cond_vec)

# One dataset per split; CXRDataset enables augmentation only for "train".
train_dataset = CXRDataset(train_df, split="train")

val_dataset = CXRDataset(val_df, split="val")

test_dataset = CXRDataset(test_df, split="test")

#################################
# Custom Data Collator for multimodal batching #
#################################
@dataclass
class MultiModalDataCollator:
    """Collate per-sample dicts into a single batch dict.

    Tensors under ``img``, ``bbox`` and ``labels`` are stacked. Fixation
    sequences are zero-padded (and their masks False-padded) at the end to
    the longest sequence in the batch before stacking. Transcripts are kept
    as a plain list of strings. ``skip_gaze`` is stacked only when the first
    sample carries it.
    """

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = {
            "img": torch.stack([sample["img"] for sample in features]),
            "bbox": torch.stack([sample["bbox"] for sample in features]),
        }

        longest = max(sample["fix_seq"].size(0) for sample in features)
        padded_seqs, padded_masks = [], []
        for sample in features:
            seq, mask = sample["fix_seq"], sample["fix_mask"]
            missing = longest - seq.size(0)
            if missing > 0:
                # Pad at the end; padded steps are marked invalid in the mask.
                seq = torch.cat([seq, torch.zeros(missing, seq.size(1))], dim=0)
                mask = torch.cat([mask, torch.zeros(missing, dtype=torch.bool)], dim=0)
            padded_seqs.append(seq)
            padded_masks.append(mask)

        batch["fix_seq"] = torch.stack(padded_seqs)
        batch["fix_mask"] = torch.stack(padded_masks)
        batch["transcript"] = [sample["transcript"] for sample in features]
        batch["labels"] = torch.stack([sample["labels"] for sample in features])

        if "skip_gaze" in features[0]:
            batch["skip_gaze"] = torch.stack([sample["skip_gaze"] for sample in features])

        return batch

#################################
# Custom Trainer for Contiguous Saving #
#################################
class ContiguousSavingTrainer(Trainer):
    """Trainer for the multimodal model with contiguous-tensor checkpointing.

    Overrides loss computation and prediction so the model receives the
    multimodal batch as keyword arguments, runs periodic memory housekeeping,
    and saves non-``PreTrainedModel`` models with every tensor made contiguous
    before serialization.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Number of compute_loss calls so far; drives the housekeeping cadence.
        self.step_count = 0

    @staticmethod
    def _forward_multimodal(model, inputs):
        """Call ``model`` with the multimodal batch fields as keyword arguments."""
        return model(
            img=inputs["img"],
            bbox=inputs["bbox"],
            fix_seq=inputs["fix_seq"],
            fix_mask=inputs["fix_mask"],
            transcript=inputs["transcript"],
            labels=inputs["labels"],
            skip_gaze=inputs.get("skip_gaze", None)
        )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run a forward pass and return the model's ``"loss"`` output.

        Returns ``(loss, outputs)`` when ``return_outputs`` is True, otherwise
        just ``loss``. ``num_items_in_batch`` is accepted but not used.
        """
        self.step_count += 1

        # Housekeeping cadence: report every 5 calls, deep clean every 10.
        if self.step_count % 5 == 0:
            monitor_memory()

        if self.step_count % 10 == 0:
            aggressive_memory_cleanup()

        # Advance the model's step counter when it exposes one as a tensor.
        if hasattr(model, 'training_step') and isinstance(model.training_step, torch.Tensor):
            model.training_step += 1

        outputs = self._forward_multimodal(model, inputs)
        loss = outputs["loss"]

        if self.step_count % 3 == 0:
            gc.collect()

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluate one batch without gradients.

        Returns ``(loss, None, None)`` when ``prediction_loss_only`` is set,
        otherwise ``(loss, logits, labels)``. ``ignore_keys`` is not used.
        """
        model.eval()

        aggressive_memory_cleanup()

        with torch.no_grad():
            outputs = self._forward_multimodal(model, inputs)
            loss = outputs["loss"]
            logits = outputs["logits"]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if prediction_loss_only:
            return (loss, None, None)

        return (loss, logits, inputs["labels"])

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Write a checkpoint to ``output_dir`` (defaults to ``args.output_dir``).

        ``PreTrainedModel`` instances are delegated to the parent class. Any
        other model has its state dict made contiguous and written as
        safetensors or a torch pickle depending on ``args.save_safetensors``;
        the tokenizer/processing class (if any) and the training arguments are
        saved next to it.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        print(f"ContiguousSavingTrainer: Saving model checkpoint to {output_dir}")

        aggressive_memory_cleanup()

        if isinstance(self.model, PreTrainedModel):
            super()._save(output_dir, state_dict=state_dict)
        else:
            current_state_dict = self.model.state_dict() if state_dict is None else state_dict

            # Non-contiguous tensors (e.g. views) are copied into contiguous
            # storage before serialization; non-tensor entries pass through.
            contiguous_state_dict = {
                key: value.contiguous() if isinstance(value, torch.Tensor) else value
                for key, value in current_state_dict.items()
            }

            if self.args.save_safetensors:
                # Import the submodule explicitly: a bare `import safetensors`
                # does not by itself load `safetensors.torch`.
                from safetensors.torch import save_file
                save_file(
                    contiguous_state_dict,
                    os.path.join(output_dir, SAFE_WEIGHTS_NAME)
                )
            else:
                torch.save(
                    contiguous_state_dict,
                    os.path.join(output_dir, WEIGHTS_NAME)
                )

            # Newer Transformers releases expose the tokenizer as
            # `processing_class`; older ones only have `tokenizer`.
            processor = getattr(self, "processing_class", None)
            if processor is None:
                processor = getattr(self, "tokenizer", None)
            if processor is not None:
                processor.save_pretrained(output_dir)

            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        aggressive_memory_cleanup()
        monitor_memory()

#########################
# === MODEL DEF ===     #
#########################
class MultiModalMIMICModel(nn.Module):
    def __init__(self, num_conditions: int, contrastive_temperature: float = 0.1, loss_type="asymmetric"):
        """Build the image, bbox, fixation and text branches plus fusion heads.

        The module-level ``low_memory`` flag selects between a compact
        512-dimensional configuration and the full 768-dimensional one; the
        gaze-guided attention modules and per-condition heads exist only in
        the full configuration.

        Args:
            num_conditions: Number of output labels.
            contrastive_temperature: Default temperature for ``info_nce_loss``.
            loss_type: Stored on the instance as ``self.loss_type``; the
                actual criterion is installed later via ``set_loss_function``.
        """
        super().__init__()
        logger.info(f"Initializing MultiModalMIMICModel with {num_conditions} conditions using transfer learning from CheXpert.")
        self.contrastive_temperature = contrastive_temperature
        self.loss_type = loss_type

        self._keys_to_ignore_on_save = None

        # --- Image branch: pretrained ViT backbone + projection ---
        self.vit = AutoModel.from_pretrained(VIT_MODEL_NAME, output_attentions=True)
        img_feature_dim = self.vit.config.hidden_size
        logger.info(f"Loaded ViT base model. Feature dimension: {img_feature_dim}")

        if low_memory:
            # Simpler architecture for low memory
            self.img_proj = nn.Sequential(
                nn.Linear(img_feature_dim, 512),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.1),
                nn.Linear(512, 512)
            )
            hidden_dim = 512
        else:
            self.img_proj = nn.Sequential(
                nn.Linear(img_feature_dim, 768),
                nn.LayerNorm(768),
                nn.GELU(),
                nn.Dropout(0.15),
                nn.Linear(768, 768),
                nn.LayerNorm(768),
                nn.Dropout(0.1)
            )
            hidden_dim = 768

        # --- Bounding-box mask branch: small CNN pooled to 2x2, then projection ---
        if low_memory:
            self.bbox_encoder = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((2, 2))
            )
            self.bbox_proj = nn.Sequential(
                nn.Linear(128, 512),  # 32 channels * 2*2 = 128 -> 512
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.1)
            )
        else:
            self.bbox_encoder = nn.Sequential(
                nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((2, 2))
            )
            self.bbox_proj = nn.Sequential(
                nn.Linear(512, 768),  # 128 channels * 2*2 = 512 -> 768
                nn.LayerNorm(768),
                nn.GELU(),
                nn.Dropout(0.15),
                nn.Linear(768, 768),
                nn.LayerNorm(768),
                nn.Dropout(0.1)
            )

        # --- Fixation branch: per-step embedding + bidirectional GRU + projection ---
        if low_memory:
            self.fix_emb = nn.Linear(4, 64)
            # GRU dropout only acts between stacked layers, so with a single
            # layer it must be 0; a non-zero value is never applied and only
            # makes PyTorch emit a warning.
            self.fix_gru = nn.GRU(64, 128, num_layers=1, batch_first=True, bidirectional=True, dropout=0.0)
            self.fix_proj = nn.Sequential(
                nn.Linear(256, 512),  # 2 * 128 = 256 -> 512
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.1)
            )
        else:
            self.fix_emb = nn.Linear(4, 128)
            self.fix_gru = nn.GRU(128, 384, num_layers=2, batch_first=True, bidirectional=True, dropout=0.15)
            self.fix_proj = nn.Sequential(
                nn.Linear(1536, 768),
                nn.LayerNorm(768),
                nn.GELU(),
                nn.Dropout(0.15),
                nn.Linear(768, 768),
                nn.LayerNorm(768),
                nn.Dropout(0.1)
            )

        # --- Text branch: pretrained clinical text encoder + projection ---
        self.tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
        self.text_encoder = AutoModel.from_pretrained(TEXT_MODEL_NAME)
        text_feature_dim = self.text_encoder.config.hidden_size

        if low_memory:
            self.text_proj = nn.Sequential(
                nn.Linear(text_feature_dim, 512),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.1)
            )
        else:
            self.text_proj = nn.Sequential(
                nn.Linear(text_feature_dim, 768),
                nn.LayerNorm(768),
                nn.GELU(),
                nn.Dropout(0.15),
                nn.Linear(768, 768),
                nn.LayerNorm(768),
                nn.Dropout(0.1)
            )

        # === ENHANCED GAZE-GUIDED ARCHITECTURE (Priority 2) ===
        # These modules are created only in the full (non-low-memory) setup.
        if not low_memory:
            self.vit_attention_head = nn.MultiheadAttention(
                embed_dim=img_feature_dim,
                num_heads=8,
                dropout=0.1,
                batch_first=True
            )
            self.attention_projection = nn.Sequential(
                nn.Linear(img_feature_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 1),
                nn.Sigmoid()
            )
            # hidden_dim is 768 in this branch.
            self.gaze_token_proj = nn.Linear(hidden_dim, img_feature_dim)
            self.gaze_token_attn = nn.MultiheadAttention(img_feature_dim, 8, dropout=0.1, batch_first=True)
            self.gaze_token_norm = nn.LayerNorm(img_feature_dim)

            self.gaze_text_projector = nn.Sequential(
                nn.Linear(768, 512),
                nn.LayerNorm(512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.LayerNorm(256)
            )

            # One head per spatial resolution, each emitting scale*scale values in (0, 1).
            self.spatial_attention_multi_scale = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(hidden_dim, 512),
                    nn.ReLU(),
                    nn.Linear(512, scale * scale),
                    nn.Sigmoid()
                ) for scale in [14, 28, 56]
            ])

        # --- Fusion of the four modality embeddings and classification heads ---
        if low_memory:
            self.fusion = nn.Sequential(
                nn.Linear(hidden_dim * 4, hidden_dim * 2),  # 512*4 -> 1024
                nn.LayerNorm(hidden_dim * 2),
                nn.GELU(),
                nn.Dropout(0.15),
                nn.Linear(hidden_dim * 2, hidden_dim)  # 1024 -> 512
            )

            # Simple classifier for low memory
            self.global_classifier = nn.Sequential(
                nn.Linear(hidden_dim, 256),
                nn.LayerNorm(256),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(256, num_conditions)
            )
            self.condition_specific_heads = None
            self.fusion_attention = None
        else:
            self.fusion_attention = nn.MultiheadAttention(768, 8, dropout=0.1, batch_first=True)
            self.fusion_norm = nn.LayerNorm(768)

            self.fusion = nn.Sequential(
                nn.Linear(768 * 4, 1536),
                nn.LayerNorm(1536),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(1536, 768),
                nn.LayerNorm(768),
                nn.GELU(),
                nn.Dropout(0.15)
            )

            # One single-logit head per condition, alongside the shared classifier.
            self.condition_specific_heads = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(768, 256),
                    nn.LayerNorm(256),
                    nn.GELU(),
                    nn.Dropout(0.3),
                    nn.Linear(256, 1)
                ) for _ in range(num_conditions)
            ])

            self.global_classifier = nn.Sequential(
                nn.Linear(768, 512),
                nn.LayerNorm(512),
                nn.GELU(),
                nn.Dropout(0.3),
                nn.Linear(512, num_conditions)
            )

        # Buffers are saved with the state dict but are not trained.
        self.register_buffer("training_step", torch.zeros(1, dtype=torch.long))
        self.register_buffer("gaze_curriculum_max", torch.tensor(1.0))

        self._initialize_weights()

        # Installed later through set_loss_function().
        self.loss_fn = None

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def set_loss_function(self, pos_weights, loss_type="focal"):
        """Install the training criterion on ``self.loss_fn``.

        Args:
            pos_weights: Per-class positive weights. Used by the focal and
                BCE criteria; the asymmetric loss does not take them.
            loss_type: ``"focal"``, ``"asymmetric"``; any other value selects
                ``nn.BCEWithLogitsLoss``.
        """
        if loss_type == "focal":
            self.loss_fn = FocalLoss(alpha=1.10, gamma=2.15, pos_weight=pos_weights, label_smoothing=0.08)
            details = "with minimal tweaks (alpha=1.10, gamma=2.15, smoothing=0.08)"
        elif loss_type == "asymmetric":
            self.loss_fn = AsymmetricLoss(
                gamma_neg=4,
                gamma_pos=1,
                clip=0.05,
                disable_torch_grad_focal_loss=True
            )
            details = "(gamma_neg=4, gamma_pos=1, clip=0.05)"
        else:
            self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
            details = "(falling back to BCEWithLogitsLoss with pos_weight)"

        # Report the parameters of the criterion that was actually installed.
        logger.info(f"Set loss function: {loss_type} {details}")

    def load_chexpert_weights(self, chexpert_model_path: str):
        """Copy selected CheXpert-pretrained tensors into this model.

        Only keys starting with ``vit.`` are taken, plus ``img_proj.`` keys
        when the module-level ``low_memory`` flag is False. Loading is
        non-strict. A missing file is logged and skipped; any other error is
        logged and the model continues without the pre-trained weights.

        Args:
            chexpert_model_path: Path to a safetensors checkpoint.
        """
        logger.info(f"Loading CheXpert pre-trained weights from {chexpert_model_path}")
        try:
            if not os.path.exists(chexpert_model_path):
                logger.warning(f"CheXpert model not found at {chexpert_model_path}. Skipping weight loading.")
                return

            checkpoint = safetensors.torch.load_file(chexpert_model_path, device="cpu")

            # str.startswith accepts a tuple of allowed prefixes.
            wanted_prefixes = ("vit.",) if low_memory else ("vit.", "img_proj.")
            selected = {
                name: tensor
                for name, tensor in checkpoint.items()
                if name.startswith(wanted_prefixes)
            }

            load_result = self.load_state_dict(selected, strict=False)
            missing_keys, unexpected_keys = load_result
            logger.info(f"Loaded CheXpert weights. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}")

        except Exception as e:
            logger.error(f"Error loading CheXpert weights: {e}")
            logger.warning("Continuing without pre-trained weights")

    def info_nce_loss(self, features1, features2, temperature=None, eps=1e-6):
        temp = temperature if temperature is not None else self.contrastive_temperature
        
        if features1.shape[0] <= 1 or features2.shape[0] <= 1:
            return torch.tensor(0.0, device=features1.device, requires_grad=True)

        features1 = F.normalize(features1, p=2, dim=1)
        features2 = F.normalize(features2, p=2, dim=1)
        similarity_matrix = torch.matmul(features1, features2.T)
        similarity_matrix = torch.clamp(similarity_matrix, -1.0 + eps, 1.0 - eps)
        similarity_matrix = similarity_matrix / temp
        batch_size = features1.shape[0]
        labels = torch.arange(batch_size, device=features1.device)
        loss = F.cross_entropy(similarity_matrix, labels)
        return loss

    def generate_gaze_heatmap(self, fix_seq, fix_mask, image_size=(224, 224)):
        """Render one gaze heatmap per sample from its valid fixations.

        Every valid fixation adds a Gaussian blob centred on the fixation,
        scaled by an attention weight built from duration, pupil area and a
        temporal factor. Each non-empty map is then smoothed with a 3x3
        binomial kernel and rescaled so that its maximum is 1.

        Args:
            fix_seq: Fixation tensor of shape (batch, seq, features). Features
                0-3 are read as x, y, duration and pupil area. All four are
                mapped with ``v * 0.5 + 0.5``, so they are presumably
                normalised to [-1, 1] -- TODO confirm against the dataset.
            fix_mask: Mask used to index ``fix_seq[b]`` and select the valid
                fixations (assumed boolean, shape (batch, seq)).
            image_size: ``(height, width)`` of the generated maps.

        Returns:
            Tensor of shape (batch, 1, height, width) on ``fix_seq.device``.
            Samples without any valid fixation keep an all-zero map.
        """
        batch_size, _, _ = fix_seq.shape
        height, width = image_size
        device = fix_seq.device

        heatmaps = torch.zeros(batch_size, 1, height, width, device=device)

        # The pixel grid and the base blob width do not depend on the sample,
        # so they are built once instead of once per batch element.
        y_grid, x_grid = torch.meshgrid(
            torch.arange(height, device=device, dtype=torch.float32),
            torch.arange(width, device=device, dtype=torch.float32),
            indexing='ij'
        )
        base_sigma = min(height, width) * 0.08

        for b in range(batch_size):
            valid_mask = fix_mask[b]
            if not valid_mask.any():
                continue

            valid_fixations = fix_seq[b][valid_mask]

            x_norm = valid_fixations[:, 0]
            y_norm = valid_fixations[:, 1]
            duration = valid_fixations[:, 2]
            pupil_area = valid_fixations[:, 3]

            # Map normalised coordinates to pixel positions inside the image.
            x_pixels = torch.clamp((x_norm * 0.5 + 0.5) * (width - 1), 0, width - 1)
            y_pixels = torch.clamp((y_norm * 0.5 + 0.5) * (height - 1), 0, height - 1)

            duration_weight = duration * 0.5 + 0.5
            pupil_weight = pupil_area * 0.5 + 0.5

            t_idx = torch.arange(len(valid_fixations), device=device, dtype=torch.float32)
            if CONFIG.get("USE_TEMPORAL_CONFIRM_FOCUS", True):
                # Later fixations weigh more than earlier ones.
                temporal_weights = torch.exp(0.06 * (t_idx - t_idx.mean()))
            else:
                # Earlier fixations weigh more than later ones.
                temporal_weights = torch.exp(-0.1 * t_idx)

            attention_weights = torch.clamp(
                duration_weight * pupil_weight * temporal_weights, 0.05, 1.0
            )
            # The clamp keeps every weight >= 0.05, so the sum is strictly
            # positive and the normalisation needs no zero guard.
            attention_weights = attention_weights / attention_weights.sum()

            # Heavier fixations get wider blobs.
            sigmas = base_sigma * (0.5 + 1.5 * attention_weights)

            for cx, cy, weight, sigma in zip(x_pixels, y_pixels, attention_weights, sigmas):
                dist_sq = (x_grid - cx) ** 2 + (y_grid - cy) ** 2
                heatmaps[b, 0] += weight * torch.exp(-dist_sq / (2 * sigma ** 2))

        # 3x3 binomial smoothing kernel, shared by all samples.
        kernel = torch.tensor(
            [[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]],
            device=device, dtype=heatmaps.dtype
        )
        kernel = (kernel / kernel.sum()).view(1, 1, 3, 3)

        for b in range(batch_size):
            hmap = heatmaps[b, 0]
            if hmap.max() > 0:
                try:
                    hmap_padded = F.pad(hmap[None, None], (1, 1, 1, 1), mode='reflect')
                    hmap = F.conv2d(hmap_padded, kernel, padding=0)[0, 0]
                except Exception as e:
                    # Smoothing is best-effort (reflect padding fails for maps
                    # with a side of length 1); keep the raw map but say so.
                    logger.warning(f"Gaze heatmap smoothing skipped: {e}")

                heatmaps[b, 0] = hmap / hmap.max()

        return heatmaps

    def gaze_attention_loss(self, model_attention, gaze_heatmap, fix_seq, fix_mask, bbox_mask=None, class_weights: Optional[torch.Tensor]=None):
        """Compare the model's spatial attention with the gaze heatmap.

        The per-sample loss mixes four terms (MSE, KL divergence between the
        softmax-normalised maps, 1 - Pearson correlation, and the normalised
        distance between the centres of mass) and is then weighted by a
        fixation-quality score derived from ``fix_seq``.

        Args:
            model_attention: Predicted attention map, (batch, channels, H, W).
            gaze_heatmap: Target heatmap; bilinearly resized to the spatial
                size of ``model_attention`` when the shapes differ.
            fix_seq: Fixation tensor (batch, seq, features); features 0-3 are
                read as x, y, duration and pupil area.
            fix_mask: Mask selecting the valid fixations of each sample
                (assumed boolean).
            bbox_mask: Optional 4-D mask; when anatomy masking is enabled in
                ``CONFIG`` the loss of each sample is scaled by the mean of
                the resized mask.
            class_weights: Optional tensor whose mean scales the whole loss.

        Returns:
            Scalar tensor: the sum of the weighted per-sample losses divided
            by the number of samples with a positive weight, or a zero tensor
            (``requires_grad=True``) when no sample has a positive weight.
        """
        batch_size = model_attention.size(0)

        if model_attention.shape != gaze_heatmap.shape:
            gaze_heatmap = F.interpolate(gaze_heatmap, size=model_attention.shape[2:], mode='bilinear', align_corners=False)

        sample_weights = []
        fixation_quality_scores = []

        for b in range(batch_size):
            valid_mask = fix_mask[b]
            if not valid_mask.any():
                # Must be tensors (not Python floats) so torch.stack below
                # accepts batches that contain samples without fixations.
                sample_weights.append(fix_seq.new_zeros(()))
                fixation_quality_scores.append(fix_seq.new_zeros(()))
                continue

            num_fixations = valid_mask.sum().float()
            valid_fixations = fix_seq[b][valid_mask]

            avg_duration = valid_fixations[:, 2].mean()
            avg_pupil = valid_fixations[:, 3].mean()

            if len(valid_fixations) > 1:
                duration_std = valid_fixations[:, 2].std()
                coord_std = valid_fixations[:, :2].std(dim=0).mean()
                spatial_spread = torch.clamp(coord_std, 0.1, 1.0)
            else:
                # A single fixation has no spread; use tensor defaults because
                # torch.clamp below does not accept a Python float.
                duration_std = fix_seq.new_zeros(())
                spatial_spread = fix_seq.new_tensor(0.5)

            temporal_quality = (avg_duration * 0.5 + 0.5) * (1.0 - torch.clamp(duration_std, 0, 0.5))
            physiological_quality = avg_pupil * 0.5 + 0.5
            spatial_quality = spatial_spread

            quality_score = (temporal_quality + physiological_quality + spatial_quality) / 3.0

            # The weight grows with the fixation count and saturates at 1.0.
            sample_weight = torch.sqrt(num_fixations / 8.0) * quality_score
            sample_weight = torch.clamp(sample_weight, 0.05, 1.0)

            sample_weights.append(sample_weight)
            fixation_quality_scores.append(quality_score)

        sample_weights = torch.stack(sample_weights).to(model_attention.device)
        quality_scores = torch.stack(fixation_quality_scores).to(model_attention.device)

        model_flat = model_attention.reshape(batch_size, -1)
        gaze_flat = gaze_heatmap.reshape(batch_size, -1)

        mse_loss = F.mse_loss(model_flat, gaze_flat, reduction='none').mean(dim=1)

        eps = 1e-8
        model_prob = F.softmax(model_flat, dim=1) + eps
        gaze_prob = F.softmax(gaze_flat, dim=1) + eps
        kl_items = (gaze_prob * (gaze_prob.log() - model_prob.log())).sum(dim=1)
        kl_loss = kl_items / model_flat.size(1)

        def pearson_correlation_loss(x, y):
            # 1 - Pearson correlation per row; eps guards constant maps.
            x_centered = x - x.mean(dim=1, keepdim=True)
            y_centered = y - y.mean(dim=1, keepdim=True)

            numerator = (x_centered * y_centered).sum(dim=1)
            x_std = torch.sqrt((x_centered ** 2).sum(dim=1) + eps)
            y_std = torch.sqrt((y_centered ** 2).sum(dim=1) + eps)

            correlation = numerator / (x_std * y_std + eps)
            return 1.0 - correlation

        correlation_loss = pearson_correlation_loss(model_flat, gaze_flat)

        def center_of_mass_loss(pred_map, target_map):
            # Distance between the centres of mass, as a fraction of the map
            # diagonal. Returns one value per sample, shape (batch,).
            h, w = pred_map.size(2), pred_map.size(3)

            y_coords = torch.arange(h, device=pred_map.device, dtype=torch.float32).view(1, 1, h, 1)
            x_coords = torch.arange(w, device=pred_map.device, dtype=torch.float32).view(1, 1, 1, w)

            # Masses stay (batch, channels). Squeezing them to (batch,) made
            # the divisions below broadcast to (batch, batch) and mixed the
            # centres of different samples.
            pred_mass = pred_map.sum(dim=(2, 3)) + eps
            target_mass = target_map.sum(dim=(2, 3)) + eps

            pred_center_y = (pred_map * y_coords).sum(dim=(2, 3)) / pred_mass
            pred_center_x = (pred_map * x_coords).sum(dim=(2, 3)) / pred_mass

            target_center_y = (target_map * y_coords).sum(dim=(2, 3)) / target_mass
            target_center_x = (target_map * x_coords).sum(dim=(2, 3)) / target_mass

            center_dist = torch.sqrt((pred_center_y - target_center_y) ** 2 +
                                   (pred_center_x - target_center_x) ** 2 + eps)

            diagonal = torch.sqrt(torch.tensor(h ** 2 + w ** 2, device=pred_map.device, dtype=torch.float32))
            return (center_dist / diagonal).mean(dim=1)

        com_loss = center_of_mass_loss(model_attention, gaze_heatmap)

        if CONFIG.get("USE_AGREEMENT_LOSS", True):
            with torch.no_grad():
                # Fraction of pixels that are above their own map's mean in
                # both the prediction and the target.
                pred_bin = (model_attention > model_attention.mean(dim=(2,3), keepdim=True)).float()
                target_bin = (gaze_heatmap > gaze_heatmap.mean(dim=(2,3), keepdim=True)).float()
                agree = (pred_bin * target_bin).view(batch_size, -1)
                agree_ratio = (agree.sum(dim=1) / (agree.size(1) + 1e-6)).clamp(min=1e-3)
            mse_loss = mse_loss * agree_ratio
            kl_loss = kl_loss * agree_ratio
            correlation_loss = correlation_loss * agree_ratio
            com_loss = com_loss * agree_ratio

        if CONFIG.get("USE_ANATOMY_MASKING", True) and bbox_mask is not None:
            with torch.no_grad():
                mask_flat = F.interpolate(bbox_mask, size=model_attention.shape[2:], mode='nearest').reshape(batch_size, -1)
                keep_ratio = (mask_flat.mean(dim=1)).clamp(min=1e-3)
            mse_loss = mse_loss * keep_ratio
            kl_loss = kl_loss * keep_ratio
            correlation_loss = correlation_loss * keep_ratio
            com_loss = com_loss * keep_ratio

        if class_weights is not None:
            cw = class_weights.mean()
        else:
            cw = 1.0
        # NOTE(review): mse_loss is already a per-sample mean over pixels; this
        # second division by the pixel count is kept from the original
        # weighting -- confirm it is intended.
        mse_loss = mse_loss / model_flat.size(1)
        base_loss = cw * (0.4 * mse_loss + 0.3 * kl_loss + 0.2 * correlation_loss + 0.1 * com_loss)

        quality_weighted_loss = base_loss * sample_weights * (0.5 + 0.5 * quality_scores)

        valid_samples = (sample_weights > 0).sum()
        if valid_samples > 0:
            return quality_weighted_loss.sum() / valid_samples
        return torch.tensor(0.0, device=model_attention.device, requires_grad=True)

    def generate_spatial_attention(self, fused_features, image_size):
        """Project a feature vector to a dense (batch, 1, H, W) attention map.

        The projection head is created on first use and stored as
        ``self.spatial_attention_head``; its input width and output size are
        fixed by the arguments of that first call. The low-memory variant uses
        a 256-unit hidden layer without dropout, the regular one 512 units
        with dropout.

        Args:
            fused_features: Tensor of shape (batch, hidden_dim).
            image_size: ``(H, W)`` of the attention map to produce.

        Returns:
            Tensor of shape (batch, 1, H, W) with sigmoid-activated scores.
        """
        n_samples = fused_features.size(0)
        feature_dim = fused_features.size(1)
        map_h, map_w = image_size

        if not hasattr(self, 'spatial_attention_head'):
            n_pixels = map_h * map_w
            if low_memory:
                layers = [
                    nn.Linear(feature_dim, 256),
                    nn.ReLU(),
                    nn.Linear(256, n_pixels),
                    nn.Sigmoid(),
                ]
            else:
                layers = [
                    nn.Linear(feature_dim, 512),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                    nn.Linear(512, n_pixels),
                    nn.Sigmoid(),
                ]
            self.spatial_attention_head = nn.Sequential(*layers).to(fused_features.device)

        scores = self.spatial_attention_head(fused_features)
        return scores.view(n_samples, 1, map_h, map_w)

    def attention_rollout(self, attentions: List[torch.Tensor]):
        """Chain head-averaged attention matrices across layers.

        Each entry of ``attentions`` is averaged over dim 1 (the head axis)
        and the resulting matrices are multiplied layer by layer, later
        layers on the left. Runs without gradient tracking.

        Returns:
            The product of the head-averaged matrices, or ``None`` when
            ``attentions`` is ``None`` or empty.
        """
        if attentions is None or len(attentions) == 0:
            return None

        with torch.no_grad():
            layers = iter(attentions)
            rollout = next(layers).mean(dim=1)
            for layer_attention in layers:
                rollout = layer_attention.mean(dim=1) @ rollout
        return rollout

    def _multi_scale_attention_targets(self, vit_attns: Optional[List[torch.Tensor]]):
        """Build attention targets at the native patch grid and at 28x28.

        The CLS-to-patch row of the attention rollout is reshaped to a square
        patch grid, upsampled bilinearly to 28x28, and both maps are divided
        by their per-sample maximum.

        Args:
            vit_attns: Per-layer attention tensors accepted by
                ``attention_rollout``, or ``None``.

        Returns:
            ``(native_target, target_28)`` with shapes (batch, 1, ps, ps) and
            (batch, 1, 28, 28), or ``(None, None)`` when no attentions are
            given, the rollout has no patch tokens, or the number of patch
            tokens is not a perfect square.
        """
        if vit_attns is None or len(vit_attns) == 0:
            return None, None
        rollout = self.attention_rollout(vit_attns)
        if rollout is None or rollout.size(1) <= 1:
            return None, None

        # Row 0 is the CLS token; columns 1.. are the patch tokens.
        cls_to_patches = rollout[:, 0, 1:]
        B, NP = cls_to_patches.size(0), cls_to_patches.size(1)

        ps = int(round(NP ** 0.5))
        if ps * ps != NP:
            # The former fallback to a 14x14 grid could never succeed: a
            # non-square patch count is never 196, so the reshape always
            # raised. Report "no targets" like the other failure paths.
            return None, None

        target_native = cls_to_patches.reshape(B, 1, ps, ps)
        target28 = F.interpolate(target_native, size=(28, 28), mode='bilinear', align_corners=False)
        t_native = target_native / (target_native.amax(dim=(2, 3), keepdim=True) + 1e-8)
        t28 = target28 / (target28.amax(dim=(2, 3), keepdim=True) + 1e-8)
        return t_native, t28

    def contrastive_gaze_text_loss(self, gaze_features, text_features, fix_mask, transcript_texts, temperature=0.07):
        """Symmetric contrastive loss between gaze and transcript embeddings.

        Only samples that have at least one valid fixation and a transcript
        longer than 10 characters (after stripping) take part. Matching
        gaze/text pairs are the positives; every other pairing in the batch
        is a negative.

        Returns:
            Mean of the gaze-to-text and text-to-gaze cross-entropy losses,
            or a zero tensor (``requires_grad=True``) when fewer than two
            samples qualify.
        """
        kept_indices = []
        for idx in range(gaze_features.size(0)):
            any_fixation = fix_mask[idx].any()
            usable_text = (
                idx < len(transcript_texts)
                and transcript_texts[idx] is not None
                and len(transcript_texts[idx].strip()) > 10
            )
            if any_fixation and usable_text:
                kept_indices.append(idx)

        n_kept = len(kept_indices)
        if n_kept < 2:
            return torch.tensor(0.0, device=gaze_features.device, requires_grad=True)

        gaze_selected = torch.stack([gaze_features[idx] for idx in kept_indices])
        text_selected = torch.stack([text_features[idx] for idx in kept_indices])

        # Use the shared projector when available, otherwise fall back to
        # plain L2 normalisation of the raw features.
        if not low_memory and hasattr(self, 'gaze_text_projector'):
            gaze_embedded = self.gaze_text_projector(gaze_selected)
            text_embedded = self.gaze_text_projector(text_selected)
        else:
            gaze_embedded = F.normalize(gaze_selected, dim=1)
            text_embedded = F.normalize(text_selected, dim=1)

        gaze_embedded = F.normalize(gaze_embedded, dim=1)
        text_embedded = F.normalize(text_embedded, dim=1)

        logits = torch.matmul(gaze_embedded, text_embedded.T) / temperature
        targets = torch.arange(n_kept, device=gaze_features.device)

        gaze_to_text = F.cross_entropy(logits, targets)
        text_to_gaze = F.cross_entropy(logits.T, targets)
        return (gaze_to_text + text_to_gaze) / 2.0

    def enhanced_spatial_attention_extraction(self, img_features, img_hidden_states=None, vit_attns: Optional[List[torch.Tensor]] = None):
        """Derive a (batch, 1, 224, 224) spatial attention map for the image.

        Sources are tried in this order:
          1. the CLS-to-patch row of the attention rollout over ``vit_attns``;
          2. the diagonal of ``self.vit_attention_head`` self-attention over
             the patch tokens (only when there are exactly 196 of them);
          3. the learned head of ``generate_spatial_attention``.

        In low-memory mode, or when the model has no ``vit_attention_head``,
        option 3 is used directly. Any exception raised while trying the
        sources is logged and also falls back to option 3.
        """
        batch_size = img_features.size(0)
        map_size = (224, 224)

        if low_memory or not hasattr(self, 'vit_attention_head'):
            return self.generate_spatial_attention(img_features, map_size)

        try:
            # Source 1: attention rollout of the ViT layers.
            if vit_attns is not None and len(vit_attns) > 0:
                rollout = self.attention_rollout(vit_attns)
                if rollout is not None and rollout.size(1) > 1:
                    cls_to_patches = rollout[:, 0, 1:]
                    B, NP = cls_to_patches.size(0), cls_to_patches.size(1)
                    ps = int(np.sqrt(NP)) if int(np.sqrt(NP))**2 == NP else 14
                    rollout_grid = cls_to_patches.view(B, 1, ps, ps)
                    return F.interpolate(rollout_grid, size=map_size, mode='bilinear', align_corners=False)

            # Source 2: self-attention over the patch tokens. The result stays
            # None whenever one of the preconditions does not hold.
            spatial_attention = None
            if img_hidden_states is not None and img_hidden_states.size(1) > 1:
                patch_tokens = img_hidden_states[:, 1:, :]
                num_patches = patch_tokens.size(1)

                if num_patches == 196:
                    _, attention_weights = self.vit_attention_head(
                        patch_tokens, patch_tokens, patch_tokens
                    )
                    n_dims = attention_weights.dim()

                    if n_dims >= 3:
                        if n_dims == 4:
                            # Average over dim 1 before taking the diagonal.
                            head_mean = attention_weights.mean(dim=1)
                            spatial_weights = torch.diagonal(head_mean, dim1=-2, dim2=-1)
                        elif n_dims == 3:
                            spatial_weights = torch.diagonal(attention_weights, dim1=-2, dim2=-1)
                        else:
                            spatial_weights = attention_weights

                        if spatial_weights.size(1) == num_patches:
                            side = int(np.sqrt(num_patches))
                            patch_grid = spatial_weights.view(batch_size, 1, side, side)
                            spatial_attention = F.interpolate(
                                patch_grid, size=map_size, mode='bilinear', align_corners=False
                            )

            # Source 3: learned head, used when source 2 was not applicable.
            if spatial_attention is None:
                spatial_attention = self.generate_spatial_attention(img_features, map_size)

        except Exception as e:
            logger.warning(f"Error in enhanced spatial attention extraction: {e}")
            spatial_attention = self.generate_spatial_attention(img_features, map_size)

        return spatial_attention

    def forward(self, img, bbox, fix_seq, fix_mask, transcript, labels=None, contrastive_weight=0.1, gaze_weight=0.3, skip_gaze=None):
        """Fuse image, bounding-box, gaze and transcript inputs into logits.

        Args:
            img: Image batch passed to the ViT as ``pixel_values``.
            bbox: Input of ``self.bbox_encoder``; also used as the anatomy
                mask of the gaze loss when ``USE_ANATOMY_MASKING`` is enabled.
            fix_seq: Fixation sequences, (batch, seq, features).
            fix_mask: Validity mask of the fixations, (batch, seq).
            transcript: List of strings; any other type yields zero text
                features.
            labels: Optional targets; when given, the losses are computed.
            contrastive_weight: Weight of the image/gaze/text InfoNCE term
                (unused in low-memory mode).
            gaze_weight: Base weight of the gaze-attention term, scaled by
                ``_gaze_schedule()``.
            skip_gaze: Optional tensor; when its sum is positive the gaze
                loss of the whole batch is set to zero.

        Returns:
            Dict with ``logits``, ``attn_map``, the detached loss components
            (``None`` when ``labels`` is ``None``), placeholder entries
            ``zi``/``zg``/``zb``/``zt`` (always ``None``) and, when labels
            were given, the total ``loss``.
        """
        batch_size = img.size(0)

        if DEVICE == "cuda" and hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()

        # ---- Image branch ----
        img_outputs = self.vit(pixel_values=img)
        if hasattr(img_outputs, 'pooler_output') and img_outputs.pooler_output is not None:
            img_feat = img_outputs.pooler_output
        else:
            img_feat = img_outputs.last_hidden_state[:, 0, :]

        # Token-level states and attentions are only kept outside low-memory mode.
        img_hidden_states = img_outputs.last_hidden_state if not low_memory else None
        vit_attns = list(img_outputs.attentions) if (not low_memory and hasattr(img_outputs, 'attentions') and img_outputs.attentions is not None) else None
        img_proj = self.img_proj(img_feat)

        del img_outputs, img_feat
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

        # ---- Bounding-box branch ----
        bbox_feat = self.bbox_encoder(bbox).flatten(1)
        bbox_proj = self.bbox_proj(bbox_feat)
        del bbox_feat

        # ---- Fixation branch ----
        fix_lens = fix_mask.sum(dim=1).cpu()

        if torch.all(fix_lens == 0):
            # No sample has fixations: feed zeros of the width fix_proj expects.
            if low_memory:
                fix_feat_gru = torch.zeros(batch_size, 256, device=fix_seq.device)
            else:
                fix_feat_gru = torch.zeros(batch_size, 1536, device=fix_seq.device)
        else:
            fix_emb = self.fix_emb(fix_seq)

            # Lengths are clamped to 1 because packing rejects empty sequences.
            packed_seq = nn.utils.rnn.pack_padded_sequence(
                fix_emb,
                fix_lens.clamp(min=1),
                batch_first=True,
                enforce_sorted=False
            )

            _, h_n = self.fix_gru(packed_seq)
            fix_feat_gru = h_n.transpose(0, 1).reshape(batch_size, -1)

            del fix_emb, packed_seq, h_n

        fix_proj = self.fix_proj(fix_feat_gru)
        del fix_feat_gru

        if DEVICE == "cuda":
            torch.cuda.empty_cache()

        # ---- Transcript branch ----
        if isinstance(transcript, list):
            try:
                max_length = 256 if low_memory else 512
                transcript_tokens = self.tokenizer(
                    transcript,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_length
                ).to(img.device)

                text_outputs = self.text_encoder(**transcript_tokens)
                text_feat = text_outputs.last_hidden_state[:, 0, :]
                text_proj = self.text_proj(text_feat)

                del transcript_tokens, text_outputs, text_feat

            except Exception as e:
                logger.warning(f"Error processing transcript: {e}")
                hidden_dim = 512 if low_memory else 768
                text_proj = torch.zeros(batch_size, hidden_dim, device=img.device)
        else:
            hidden_dim = 512 if low_memory else 768
            text_proj = torch.zeros(batch_size, hidden_dim, device=img.device)

        if DEVICE == "cuda":
            torch.cuda.empty_cache()

        # ---- Gaze-aware image feature ----
        if not low_memory and img_hidden_states is not None:
            spatial_attn_rollout = self.enhanced_spatial_attention_extraction(img_proj, img_hidden_states, vit_attns)
            if spatial_attn_rollout is not None:
                # Pool the patch tokens with the attention map as weights.
                patch_weights = F.interpolate(spatial_attn_rollout, size=(14, 14), mode='bilinear', align_corners=False).flatten(2)
                patch_weights = patch_weights.squeeze(1)
                pw = patch_weights / (patch_weights.sum(dim=1, keepdim=True) + 1e-8)
                patch_tokens = img_hidden_states[:, 1:, :]
                pooled_gaze_img = (pw.unsqueeze(-1) * patch_tokens).sum(dim=1)
            else:
                pooled_gaze_img = img_proj

            if CONFIG.get("USE_GAZE_WEIGHTED_POOLING", True):
                # The gate projection is kept on the module. Building a new
                # nn.Linear on every call made the gate a fresh random
                # function of its input each time.
                mix_gate_proj = self._get_gate_layer('gaze_pool_mix_gate', pooled_gaze_img.size(1), img.device)
                mix_gate = torch.sigmoid(mix_gate_proj(pooled_gaze_img))
                mix_gate = mix_gate.clamp(0.2, 0.8)
                img_proj_mix = mix_gate * pooled_gaze_img + (1.0 - mix_gate) * img_proj
            else:
                img_proj_mix = img_proj

            if CONFIG.get("USE_GAZE_TOKEN", True):
                # Prepend a gaze token to the patch tokens and let it attend.
                gaze_token = self.gaze_token_proj(fix_proj)
                tokens = torch.cat([gaze_token.unsqueeze(1), img_hidden_states[:, 1:, :]], dim=1)
                attended, _ = self.gaze_token_attn(tokens, tokens, tokens)
                attended = self.gaze_token_norm(attended)
                gaze_token_out = attended[:, 0, :]
                fusion_gate_proj = self._get_gate_layer('gaze_token_fusion_gate', gaze_token_out.size(1), img.device)
                fusion_gate = torch.sigmoid(fusion_gate_proj(gaze_token_out)).clamp(0.0, 0.4)
                img_proj_aug = img_proj_mix + fusion_gate * gaze_token_out
            else:
                img_proj_aug = img_proj_mix
        else:
            img_proj_aug = img_proj

        # The former training-only "soft gate" step was removed: it wrote
        # x * (1 - g) + x * g (i.e. x) back into img_hidden_states in place,
        # after a redundant attention extraction. Besides being a no-op, an
        # in-place write to a tensor that is part of the autograd graph can
        # invalidate tensors saved for the backward pass.

        # ---- Fusion ----
        combined = torch.cat([img_proj_aug, bbox_proj, fix_proj, text_proj], dim=1)
        fused = self.fusion(combined)
        del combined

        if not low_memory and self.fusion_attention is not None:
            stacked_features = torch.stack([img_proj, bbox_proj, fix_proj, text_proj], dim=1)
            attended_features, attn_weights = self.fusion_attention(stacked_features, stacked_features, stacked_features)
            attended_combined = attended_features.mean(dim=1)
            final_fused = self.fusion_norm(fused + attended_combined)

            del stacked_features, attended_features, attn_weights, attended_combined
        else:
            final_fused = fused

        del fused

        # ---- Classification heads ----
        global_logits = self.global_classifier(final_fused)

        if not low_memory and self.condition_specific_heads is not None:
            condition_logits = torch.cat(
                [head(final_fused) for head in self.condition_specific_heads], dim=1
            )

            # Blend the shared classifier with the per-condition heads.
            alpha = 0.7
            logits_output = alpha * global_logits + (1 - alpha) * condition_logits
            del condition_logits
        else:
            logits_output = global_logits

        # ---- Attention map returned to the caller and used by the gaze loss ----
        if not low_memory and img_hidden_states is not None:
            spatial_attention = self.enhanced_spatial_attention_extraction(final_fused, img_hidden_states, vit_attns)
        else:
            spatial_attention = self.generate_spatial_attention(final_fused, img.shape[2:])

        attn_map_output = spatial_attention

        if img_hidden_states is not None:
            del img_hidden_states

        # ---- Losses ----
        loss = None
        loss_cls = torch.tensor(0.0, device=img.device)
        loss_contrastive = torch.tensor(0.0, device=img.device)
        loss_gaze = torch.tensor(0.0, device=img.device)
        loss_gaze_text_contrastive = torch.tensor(0.0, device=img.device)

        if labels is not None:
            if self.loss_fn is None:
                self.loss_fn = nn.BCEWithLogitsLoss()
                logger.warning("Using fallback BCEWithLogitsLoss!")

            loss_cls = self.loss_fn(logits_output, labels)

            # InfoNCE needs at least two samples to have negatives.
            if img_proj.shape[0] > 1 and not low_memory:
                loss_img_fix = self.info_nce_loss(img_proj, fix_proj)
                loss_img_text = self.info_nce_loss(img_proj, text_proj)
                loss_contrastive = (loss_img_fix + loss_img_text) / 2.0
            else:
                loss_contrastive = torch.tensor(0.0, device=img.device)

            try:
                gaze_heatmap = self.generate_gaze_heatmap(fix_seq, fix_mask, image_size=(224, 224))

                with torch.no_grad():
                    # Confidence = 1 - normalised entropy of the heatmap, so
                    # concentrated gaze maps weigh more than diffuse ones.
                    gh = gaze_heatmap.view(gaze_heatmap.size(0), -1)
                    gprob = gh / (gh.sum(dim=1, keepdim=True) + 1e-8)
                    gentropy = -(gprob * (gprob + 1e-8).log()).sum(dim=1) / math.log(gprob.size(1) + 1e-8)
                    gaze_conf = (1.0 - gentropy).clamp(0.1, 1.0)
                bbox_mask = bbox if CONFIG.get("USE_ANATOMY_MASKING", True) else None
                if skip_gaze is not None and skip_gaze.sum() > 0:
                    loss_gaze_raw = torch.tensor(0.0, device=img.device)
                else:
                    loss_gaze_raw = self.gaze_attention_loss(
                        spatial_attention, gaze_heatmap, fix_seq, fix_mask, bbox_mask=bbox_mask, class_weights=COND_GAZE_WEIGHTS.to(img.device)
                    )
                loss_gaze = (gaze_conf.mean() * loss_gaze_raw)

            except Exception as e:
                logger.warning(f"Error computing gaze loss: {e}")
                loss_gaze = torch.tensor(0.0, device=img.device)

            loss_gaze_text_contrastive = torch.tensor(0.0, device=img.device)

            if low_memory:
                loss = loss_cls + gaze_weight * self._gaze_schedule() * loss_gaze
            else:
                # The gaze/text contrastive term is currently disabled.
                gaze_text_weight = 0.0
                enhanced_gaze_weight = gaze_weight * 1.2 * self._gaze_schedule()

                loss = (loss_cls +
                       contrastive_weight * loss_contrastive +
                       enhanced_gaze_weight * loss_gaze +
                       gaze_text_weight * loss_gaze_text_contrastive)

        del img_proj, img_proj_aug, bbox_proj, fix_proj, text_proj, final_fused

        if DEVICE == "cuda":
            torch.cuda.empty_cache()

        output = {
            "logits": logits_output,
            "zi": None,
            "zg": None,
            "zb": None,
            "zt": None,
            "attn_map": attn_map_output,
            "loss_cls": loss_cls.detach() if loss is not None else None,
            "loss_con": loss_contrastive.detach() if loss is not None else None,
            "loss_gaze": loss_gaze.detach() if loss is not None else None,
            "loss_gaze_text": loss_gaze_text_contrastive.detach() if loss is not None else None
        }
        if loss is not None:
            output["loss"] = loss

        return output

    def _get_gate_layer(self, name, in_features, device):
        """Return the 1-unit gate projection stored under ``name``, creating it on first use.

        NOTE(review): a layer created here after the optimizer was built is
        presumably not part of its parameter groups -- consider constructing
        these gates in ``__init__`` instead.
        """
        layer = getattr(self, name, None)
        if layer is None or layer.in_features != in_features:
            layer = nn.Linear(in_features, 1, bias=True)
            setattr(self, name, layer)
        return layer.to(device)

    def _gaze_schedule(self):
        """Return the logistic warm-up factor applied to the gaze loss weight.

        The factor is ``1 / (1 + exp(-0.01 * (step - 500)))`` where ``step``
        is read from ``self.training_step`` (0 when the attribute is absent),
        so it reaches 0.5 at step 500 and approaches 1 afterwards.
        """
        current_step = 0.0
        if hasattr(self, 'training_step'):
            current_step = float(self.training_step.item())

        steepness, midpoint = 0.01, 500.0
        exponent = -steepness * (current_step - midpoint)
        return float(1.0 / (1.0 + math.exp(exponent)))

############################
# Advanced Loss Functions #
############################
class FocalLoss(nn.Module):
    """Binary focal loss on logits with label smoothing and optional pos_weight.

    Targets are smoothed towards 0.5, the element-wise BCE is scaled by
    ``alpha * (1 - p_t) ** gamma`` and the mean over all elements is returned.
    """

    def __init__(self, alpha=1, gamma=2, pos_weight=None, label_smoothing=0.1):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.pos_weight = pos_weight
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """Return the mean focal loss of ``inputs`` (logits) against ``targets``."""
        smoothing = self.label_smoothing
        soft_targets = targets * (1 - smoothing) + 0.5 * smoothing

        # pos_weight=None is the functional's default, so one call covers
        # both the weighted and the unweighted case.
        element_bce = F.binary_cross_entropy_with_logits(
            inputs, soft_targets, pos_weight=self.pos_weight, reduction='none'
        )

        # Probability assigned to the (smoothed) true class of each element.
        prob_positive = torch.sigmoid(inputs)
        prob_true = torch.where(soft_targets >= 0.5, prob_positive, 1 - prob_positive)

        # Down-weight elements that are already classified confidently.
        modulating_factor = (1 - prob_true) ** self.gamma

        return (self.alpha * modulating_factor * element_bce).mean()

class AsymmetricLoss(nn.Module):
    """Asymmetric loss for multi-label classification on logits.

    Positives and negatives get separate focusing exponents (``gamma_pos``,
    ``gamma_neg``), and the negative probabilities are shifted by ``clip`` so
    that very easy negatives contribute nothing.

    Args:
        gamma_neg: Focusing exponent for negative targets.
        gamma_pos: Focusing exponent for positive targets.
        clip: Margin added to the negative probability (capped at 1);
            ``None`` or 0 disables it.
        eps: Lower bound applied before taking logarithms.
        disable_torch_grad_focal_loss: When true, the focusing weight is
            computed without gradient tracking, i.e. treated as a constant.
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """Return the summed asymmetric loss of logits ``x`` against targets ``y``."""
        # Probabilities of the positive and the negative class.
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric clipping of the negative probabilities.
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Plain cross-entropy terms.
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric focusing.
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            # Use a scoped grad mode so the caller's mode is restored on exit.
            # Calling set_grad_enabled(True) unconditionally used to switch
            # autograd back on inside a caller's torch.no_grad() block.
            track_grad = torch.is_grad_enabled() and not self.disable_torch_grad_focal_loss
            with torch.set_grad_enabled(track_grad):
                pt = xs_pos * y + xs_neg * (1 - y)
                one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
                one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            loss = loss * one_sided_w

        return -loss.sum()

###############################
# Cross-Modal Attention Module #
###############################
class CrossModalAttention(nn.Module):
    """Attention block with residual connections and layer norm applied after each sub-layer.

    The query attends to ``key``/``value`` through multi-head attention, and
    the result passes through a feed-forward network whose hidden layer is
    four times ``embed_dim`` wide. Inputs are batch-first.
    """

    def __init__(self, embed_dim=768, num_heads=8, dropout=0.1):
        super().__init__()
        ffn_width = embed_dim * 4
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward_layers = [
            nn.Linear(embed_dim, ffn_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_width, embed_dim),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*feed_forward_layers)

    def forward(self, query, key, value, attn_mask=None):
        """Return the attended representation, same shape as ``query``."""
        # Attention sub-layer: residual on the query, then norm.
        attended = self.multihead_attn(query, key, value, attn_mask=attn_mask)[0]
        hidden = self.norm1(query + attended)

        # Feed-forward sub-layer: residual on the normalised state, then norm.
        return self.norm2(hidden + self.ffn(hidden))

def _stable_sigmoid(values):
    """Element-wise logistic function evaluated in float64 without overflow."""
    values = np.asarray(values, dtype=np.float64)
    # exp(-|x|) lies in (0, 1], so neither branch below can overflow.
    decay = np.exp(-np.abs(values))
    return np.where(values >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))


def compute_metrics(eval_pred: EvalPrediction):
    """Compute multi-label classification metrics from logits and labels.

    Logits are taken from ``eval_pred.predictions`` (the ``'logits'`` entry
    of a dict, the first element of a tuple, or the object itself) and
    converted to probabilities with a sigmoid. A decision threshold is then
    picked per label column from a fixed grid, and micro/macro F1,
    precision and recall, exact-match and Hamming accuracy, and the mean
    per-label ROC AUC are computed.

    NOTE: the thresholds are selected using the very labels the metrics are
    then scored against, so every thresholded metric returned here is an
    optimistic estimate for that evaluation set.

    Args:
        eval_pred: Object exposing ``predictions`` and ``label_ids``.

    Returns:
        dict with keys ``auc``, ``f1``, ``recall``, ``precision``,
        ``macro_f1``, ``macro_recall``, ``macro_precision``, ``accuracy``,
        ``hamming_accuracy`` and ``subset_accuracy``. If computing the
        aggregate metrics raises, the error is logged and every value is 0.0.
    """
    raw_predictions = eval_pred.predictions
    labels_raw = eval_pred.label_ids

    if isinstance(raw_predictions, dict):
        logits = raw_predictions['logits']
    elif isinstance(raw_predictions, tuple):
        logits = raw_predictions[0]
    else:
        logits = raw_predictions

    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    elif not isinstance(logits, np.ndarray):
        logits = np.array(logits)

    if isinstance(labels_raw, torch.Tensor):
        labels = labels_raw.detach().cpu().numpy()
    elif not isinstance(labels_raw, np.ndarray):
        labels = np.array(labels_raw)
    else:
        labels = labels_raw

    # The naive 1 / (1 + exp(-x)) overflows for large-magnitude negative
    # logits (much sooner in half precision); use a stable float64 form.
    probs = _stable_sigmoid(logits)

    num_labels = labels.shape[1]
    candidate_thresholds = np.arange(0.38, 0.78, 0.02)

    optimal_thresholds = []
    for i in range(num_labels):
        best_threshold = 0.5
        best_score = 0.0

        for threshold in candidate_thresholds:
            temp_preds = (probs[:, i] > threshold).astype(int)
            temp_f1 = f1_score(labels[:, i], temp_preds, zero_division=0)
            temp_precision = precision_score(labels[:, i], temp_preds, zero_division=0)

            # Blend a little precision into the score once precision is
            # above 0.3; otherwise rank the threshold on F1 alone.
            if temp_precision > 0.3:
                adjusted_score = 0.85 * temp_f1 + 0.15 * temp_precision
            else:
                adjusted_score = temp_f1

            # Strict '>' keeps the lowest threshold among equal scores.
            if adjusted_score > best_score:
                best_score = adjusted_score
                best_threshold = threshold

        optimal_thresholds.append(best_threshold)

    preds = np.zeros_like(probs)
    for i in range(num_labels):
        preds[:, i] = (probs[:, i] > optimal_thresholds[i]).astype(int)

    # Fixed-0.5 predictions, used only for the comparison log line below.
    preds_05 = (probs > 0.5).astype(int)

    try:
        micro_f1 = f1_score(labels.flatten(), preds.flatten(), zero_division=0)
        micro_rec = recall_score(labels.flatten(), preds.flatten(), zero_division=0)
        micro_prec = precision_score(labels.flatten(), preds.flatten(), zero_division=0)

        macro_f1 = f1_score(labels, preds, average="macro", zero_division=0)
        macro_rec = recall_score(labels, preds, average="macro", zero_division=0)
        macro_prec = precision_score(labels, preds, average="macro", zero_division=0)

        exact_match_acc = np.mean(np.all(labels == preds, axis=1))
        hamming_acc = np.mean(labels == preds)
        subset_acc = exact_match_acc

        if labels.shape == probs.shape:
            # AUC is undefined for a column with a single class; skip those.
            auc_scores = [
                roc_auc_score(labels[:, i], probs[:, i])
                for i in range(num_labels)
                if len(np.unique(labels[:, i])) > 1
            ]
            mean_auc = np.mean(auc_scores) if auc_scores else 0.0
        else:
            mean_auc = 0.0

        micro_f1_05 = f1_score(labels.flatten(), preds_05.flatten(), zero_division=0)

        logger.info(f"Optimal thresholds: {[f'{t:.2f}' for t in optimal_thresholds]}")
        logger.info(f"Micro F1 with optimal thresholds: {micro_f1:.3f} vs Micro F1 with 0.5: {micro_f1_05:.3f}")

    except Exception as e:
        # Deliberate best-effort: a metric failure must not abort training.
        logger.error(f"Error calculating metrics: {e}")
        micro_f1, micro_rec, micro_prec, mean_auc = 0.0, 0.0, 0.0, 0.0
        exact_match_acc, hamming_acc, subset_acc = 0.0, 0.0, 0.0
        macro_f1, macro_rec, macro_prec = 0.0, 0.0, 0.0

    return {
        "auc": mean_auc,
        "f1": micro_f1,
        "recall": micro_rec,
        "precision": micro_prec,
        "macro_f1": macro_f1,
        "macro_recall": macro_rec,
        "macro_precision": macro_prec,
        "accuracy": exact_match_acc,  # Exact match accuracy
        "hamming_accuracy": hamming_acc,  # Per-label accuracy
        "subset_accuracy": subset_acc  # Same as exact match
    }

#################
# === Main ===  #
#################
def main():
    """Run the training, evaluation and reporting pipeline.

    In order: derive per-class positive weights from ``train_df`` and hand
    them to the model, load the CheXpert weights, train with
    ``ContiguousSavingTrainer`` (resuming from the newest checkpoint found
    in ``OUTPUT_DIR``, if any), save the final model, evaluate on
    ``test_dataset``, and write a per-condition metrics table, a metrics
    JSON file and a predictions CSV into ``OUTPUT_DIR``.

    If training raises a ``RuntimeError`` whose message contains
    "out of memory", one retry is made with a halved batch size, doubled
    gradient accumulation, no dataloader workers and fp16 disabled. Any
    other ``RuntimeError`` propagates unchanged.
    """
    global DEVICE

    logger.info("Starting ENHANCED GAZE-GUIDED MIMIC training with Priority Implementation")
    logger.info(f"Using device: {DEVICE}")
    logger.info("Priority Features: Enhanced gaze supervision + Contrastive gaze-text alignment + Multi-level attention")
    logger.info("Loss Components: Classification + Traditional Contrastive + Enhanced Gaze + NEW Gaze-Text Contrastive")
    logger.info("Gaze Enhancements: Multi-component loss (MSE+KL+Correlation+CenterMass) + Quality weighting + Temporal decay")

    model = MultiModalMIMICModel(len(CONDITIONS))

    logger.info("Calculating class weights for imbalanced data...")
    train_labels = np.array(list(train_df["cond_vec"].values))
    total_samples = len(train_labels)
    pos_counts = train_labels.sum(axis=0)

    pos_weights = []
    for i in range(len(CONDITIONS)):
        if pos_counts[i] > 0:
            pos_ratio = pos_counts[i] / total_samples
            # Smoothed negative/positive ratio, kept inside [0.6, 16.0].
            pos_weight = (1.0 - pos_ratio) / (pos_ratio + 0.07)
            pos_weight = min(pos_weight, 16.0)
            pos_weight = max(pos_weight, 0.6)
        else:
            pos_weight = 1.0
        pos_weights.append(pos_weight)

    pos_weights = torch.tensor(pos_weights, dtype=torch.float32, device=DEVICE)

    model.set_loss_function(pos_weights, loss_type="focal")

    logger.info("\nClass Distribution and Weights:")
    for i, condition in enumerate(CONDITIONS):
        logger.info(f"{condition}: {int(pos_counts[i])}/{len(train_labels)} samples ({pos_counts[i]/len(train_labels)*100:.1f}%) - weight: {pos_weights[i]:.2f}")

    model.load_chexpert_weights(CHEXPERT_MODEL_PATH)

    data_collator = MultiModalDataCollator()

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        warmup_steps=WARMUP_STEPS,
        learning_rate=LR,
        weight_decay=WD,
        logging_dir=os.path.join(OUTPUT_DIR, "logs"),
        eval_strategy="steps",
        eval_steps=30,
        save_strategy="steps",
        save_steps=30,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        dataloader_num_workers=DATALOADER_WORKERS,
        remove_unused_columns=False,
        report_to="none",
        logging_steps=15,
        max_grad_norm=0.5,
        lr_scheduler_type="cosine",
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.98,
        adam_epsilon=1e-8,
        dataloader_pin_memory=not low_memory,
        fp16=FP16_ENABLED,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        dataloader_persistent_workers=(not low_memory) and DATALOADER_WORKERS > 0,
        eval_accumulation_steps=6,
        save_on_each_node=False,
        prediction_loss_only=False,
        include_inputs_for_metrics=False,
        ignore_data_skip=True,
    )

    callbacks = [EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)]

    def build_trainer():
        # Reads `model` and `training_args` at call time, so the OOM retry
        # below picks up the moved model and the mutated arguments.
        return ContiguousSavingTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=callbacks
        )

    resume_checkpoint = find_latest_checkpoint(OUTPUT_DIR)
    if resume_checkpoint:
        logger.info(f"Found existing checkpoint: {resume_checkpoint}")
        logger.info("Will resume training from this checkpoint...")
    else:
        logger.info("No existing checkpoint found. Starting training from scratch.")

    try:
        trainer = build_trainer()

        logger.info("Starting training...")
        trainer.train(resume_from_checkpoint=resume_checkpoint)

    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise

        logger.error("CUDA out of memory during training!")
        logger.info("Attempting recovery with CPU training...")

        # NOTE(review): the Trainer picks its device from TrainingArguments,
        # so moving the model here may not by itself force CPU execution --
        # confirm, and consider building the arguments with `use_cpu=True`.
        DEVICE = "cpu"
        model = model.to(DEVICE)

        training_args.per_device_train_batch_size = max(1, BATCH_SIZE // 2)
        training_args.per_device_eval_batch_size = max(1, BATCH_SIZE // 2)
        training_args.gradient_accumulation_steps = GRAD_ACCUM * 2
        training_args.dataloader_num_workers = 0
        training_args.fp16 = False
        training_args.dataloader_pin_memory = False
        training_args.dataloader_persistent_workers = False

        trainer = build_trainer()

        # The failed attempt may have written newer checkpoints than the one
        # found before training started; look again so that progress is kept.
        resume_checkpoint = find_latest_checkpoint(OUTPUT_DIR)

        logger.info("Resuming training with CPU...")
        trainer.train(resume_from_checkpoint=resume_checkpoint)

    trainer.save_model(os.path.join(OUTPUT_DIR, f"model_{EXP_NAME}"))
    logger.info(f"Model saved to {os.path.join(OUTPUT_DIR, f'model_{EXP_NAME}')}")

    print("=== Final Test Evaluation ===")
    test_metrics = trainer.evaluate(test_dataset)

    logger.info("Test Results Summary:")
    logger.info(f"F1 Score: {test_metrics['eval_f1']:.3f}")
    logger.info(f"AUC: {test_metrics['eval_auc']:.3f}")
    logger.info(f"Precision: {test_metrics['eval_precision']:.3f}")
    logger.info(f"Recall: {test_metrics['eval_recall']:.3f}")
    logger.info(f"Exact Match Accuracy: {test_metrics['eval_accuracy']:.3f}")
    logger.info(f"Hamming Accuracy: {test_metrics['eval_hamming_accuracy']:.3f}")

    logger.info("Calculating per-condition test metrics...")
    test_predictions = trainer.predict(test_dataset)
    raw_predictions = test_predictions.predictions
    # Accept the same prediction containers that compute_metrics accepts.
    if isinstance(raw_predictions, dict):
        logits = raw_predictions['logits']
    elif isinstance(raw_predictions, tuple):
        logits = raw_predictions[0]
    else:
        logits = raw_predictions
    labels = test_predictions.label_ids

    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    # Overflow-free sigmoid evaluated in float64: exp(-|x|) is always <= 1.
    logits = np.asarray(logits, dtype=np.float64)
    decay = np.exp(-np.abs(logits))
    probs = np.where(logits >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    y_pred = (probs > 0.5).astype(int)

    logger.info("\nTest Set Metrics per Condition:")
    logger.info("-" * 80)
    headers = ["Condition", "Precision", "Recall", "F1-score", "Accuracy", "Support", "AUC"]
    rows = []
    precisions, recalls, f1_values, accuracies, auc_values = [], [], [], [], []

    for i, condition in enumerate(CONDITIONS):
        precision = precision_score(labels[:, i], y_pred[:, i], zero_division=0)
        recall = recall_score(labels[:, i], y_pred[:, i], zero_division=0)
        f1 = f1_score(labels[:, i], y_pred[:, i], zero_division=0)
        accuracy = np.mean(labels[:, i] == y_pred[:, i])
        support = np.sum(labels[:, i])
        try:
            auc = roc_auc_score(labels[:, i], probs[:, i]) if len(np.unique(labels[:, i])) > 1 else np.nan
        except ValueError:
            auc = np.nan

        precisions.append(precision)
        recalls.append(recall)
        f1_values.append(f1)
        accuracies.append(accuracy)
        if not np.isnan(auc):
            auc_values.append(auc)

        rows.append([
            condition,
            f"{precision:.3f}",
            f"{recall:.3f}",
            f"{f1:.3f}",
            f"{accuracy:.3f}",
            f"{int(support)}",
            f"{auc:.3f}" if not np.isnan(auc) else "N/A"
        ])

    # "Macro Avg" is the unweighted mean of the per-condition rows above
    # (all at the 0.5 threshold). The aggregate values in `test_metrics` are
    # micro-averaged at tuned thresholds and would not match the table.
    if rows:
        rows.append([
            "Macro Avg",
            f"{np.mean(precisions):.3f}",
            f"{np.mean(recalls):.3f}",
            f"{np.mean(f1_values):.3f}",
            f"{np.mean(accuracies):.3f}",
            f"{len(test_dataset)} (Total Samples)",
            f"{np.mean(auc_values):.3f}" if auc_values else "N/A"
        ])

    table_str = tabulate(rows, headers=headers, tablefmt="grid")
    logger.info("\n" + table_str)

    metrics_table_path = os.path.join(OUTPUT_DIR, "training_mimic_on_chexpert_metrics_table.txt")
    with open(metrics_table_path, "w", encoding="utf-8") as f:
        f.write("MIMIC Training with CheXpert Transfer Learning - Test Metrics per Condition:\n")
        f.write("-" * 80 + "\n")
        f.write(table_str)

    metrics_json_path = os.path.join(OUTPUT_DIR, "metrics_training_mimic_on_chexpert.json")
    with open(metrics_json_path, "w", encoding="utf-8") as f:
        json.dump(test_metrics, f, indent=4)
    logger.info(f"Overall test metrics saved to {metrics_json_path}")

    logger.info("Creating CSV with real and predicted labels...")
    csv_data = []
    for i in range(logits.shape[0]):
        # Column order: id, all real_*, all pred_*, then all prob_*.
        row = {"unique_id": f"sample_{i}"}
        row.update({f"real_{condition}": int(labels[i][j]) for j, condition in enumerate(CONDITIONS)})
        row.update({f"pred_{condition}": int(y_pred[i][j]) for j, condition in enumerate(CONDITIONS)})
        row.update({f"prob_{condition}": float(probs[i][j]) for j, condition in enumerate(CONDITIONS)})
        csv_data.append(row)

    csv_df = pd.DataFrame(csv_data)
    csv_path = os.path.join(OUTPUT_DIR, "test_predictions_and_labels.csv")
    csv_df.to_csv(csv_path, index=False)
    logger.info(f"Saved predictions and labels CSV to: {csv_path}")

    logger.info("Training complete")

def aggressive_memory_cleanup():
    """Run the garbage collector and, when CUDA is available, flush its cache.

    With CUDA available this also synchronizes the device and resets the
    peak-memory statistics if this torch build provides that function.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    # Guarded lookup: only call the reset function when torch.cuda has it.
    reset_peak_stats = getattr(torch.cuda, 'reset_peak_memory_stats', None)
    if reset_peak_stats is not None:
        reset_peak_stats()

def monitor_memory(*, cuda_threshold_gb=10.0, ram_threshold_gb=20.0):
    """Check GPU and system memory and trigger cleanup when limits are crossed.

    Args:
        cuda_threshold_gb: If CUDA is available and more than this many GiB
            are currently allocated by torch, ``aggressive_memory_cleanup``
            is called.
        ram_threshold_gb: If less than this many GiB of system memory are
            available, a warning is printed and the garbage collector runs.

    The defaults reproduce the previously hard-coded limits.
    """
    gib = 1024**3

    if torch.cuda.is_available():
        allocated_gb = torch.cuda.memory_allocated() / gib
        if allocated_gb > cuda_threshold_gb:
            print("High CUDA memory usage detected - performing cleanup")
            aggressive_memory_cleanup()

    available_gb = psutil.virtual_memory().available / gib
    if available_gb < ram_threshold_gb:
        print(f"Low system memory: {available_gb:.1f}GB available")
        gc.collect()

def find_latest_checkpoint(output_dir):
    """Return the highest-numbered ``checkpoint-<step>`` directory, if any.

    Only directories directly inside ``output_dir`` whose name is exactly
    ``checkpoint-`` followed by ASCII digits are considered. Plain files and
    names such as ``checkpoint-best`` are ignored rather than being treated
    as step 0 and possibly returned as a resume path.

    Args:
        output_dir: Directory to search (string or path-like).

    Returns:
        The path of the checkpoint directory with the largest step number,
        or ``None`` when no such directory exists, including when
        ``output_dir`` itself does not exist.
    """
    import glob

    def get_checkpoint_num(path):
        # Step number for "checkpoint-<digits>", otherwise None.
        prefix, _, suffix = os.path.basename(path).partition('-')
        if prefix == "checkpoint" and suffix.isascii() and suffix.isdigit():
            return int(suffix)
        return None

    # Escape the directory part so characters like '[' in it are not
    # interpreted as glob wildcards.
    checkpoint_pattern = os.path.join(glob.escape(os.fspath(output_dir)), "checkpoint-*")

    numbered = []
    for path in glob.glob(checkpoint_pattern):
        step = get_checkpoint_num(path)
        if step is not None and os.path.isdir(path):
            numbered.append((step, path))

    if not numbered:
        return None

    checkpoint_num, latest_checkpoint = max(numbered, key=lambda item: item[0])

    logger.info(f"Found latest checkpoint: {latest_checkpoint} (step {checkpoint_num})")
    return latest_checkpoint

import os

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
