#!/usr/bin/env python
"""
Implicit Ensemble Methods for DINO ViT on Vision Datasets

This script implements multiple ensemble approaches for Vision Transformers
using a shared backbone and unified training / evaluation pipeline.

================================================================================
SUPPORTED METHODS
================================================================================

1. SVF Implicit Ensemble (--method svf)
   - Singular Value Fine-tuning: share U, Vh; learn per-member singular values
   - W_m = U @ diag(s_m) @ Vh
   - Very parameter-efficient: only k singular values per layer per member
   - Multiplicative perturbation of pretrained weight directions

2. LoRA Implicit Ensemble (--method lora)
   - Low-Rank Adaptation: frozen base W; learn per-member low-rank adapters
   - W_m = W_base + B_m @ A_m * (alpha / r)
   - Additive perturbation of frozen weights
   - Only LoRA parameters and classifier head are trained

3. BatchEnsemble (--method batch_ensemble)
   - Reference: "BatchEnsemble: An Alternative Approach to Efficient Ensemble
     and Lifelong Learning" (Wen et al., 2020)
   - W_m = W ⊙ (r_m ⊗ s_m) where:
       * W is shared weight matrix (TRAINABLE, initialized from pretrained)
       * r_m is per-member output scaling vector [out_features]
       * s_m is per-member input scaling vector [in_features]
       * ⊗ is outer product, ⊙ is element-wise multiplication
   - Efficient computation: y_m = (x * s_m) @ W^T * r_m
   - Initialization: r_m, s_m ≈ 1 + noise, so W_m ≈ W initially
   - Key difference from LoRA: multiplicative perturbation and trains base W

4. Deep Ensemble (--method deep_ensemble)
   - Train multiple independent networks with different random seeds
   - Gold standard for uncertainty but expensive (N× parameters, N× FLOPs)

5. MC Dropout (--method mc_dropout)
   - Single network with dropout enabled at test time
   - Multiple stochastic forward passes for uncertainty estimation

6. Single Model (--method single)
   - Standard single-model baseline for comparison

================================================================================
COMPARISON OF IMPLICIT ENSEMBLE METHODS
================================================================================

| Method         | Base Weights      | Perturbation Type | Params / Member / Layer |
|----------------|-------------------|-------------------|-------------------------|
| SVF            | Frozen (U, Vh)    | Multiplicative    | k singular values       |
| LoRA           | Frozen            | Additive          | r × (in + out)          |
| BatchEnsemble  | Trainable         | Multiplicative    | in + out                |

================================================================================
USAGE EXAMPLES
================================================================================

# Single model baseline
python dino_svf_ensemble_flowers.py --method single --dataset flowers102

# SVF Ensemble (4 members, default settings)
python dino_svf_ensemble_flowers.py --method svf --n_members 4

# SVF with custom settings
python dino_svf_ensemble_flowers.py --method svf --n_members 8 --topk 64 --svf_scope attn

# LoRA Ensemble
python dino_svf_ensemble_flowers.py --method lora --n_members 4 --lora_r 16

# BatchEnsemble (4 members)
python dino_svf_ensemble_flowers.py --method batch_ensemble --n_members 4

# BatchEnsemble with custom settings
python dino_svf_ensemble_flowers.py --method batch_ensemble --n_members 8 \
    --be_scope attn --be_init_std 0.01

# Deep Ensemble (expensive but strong baseline)
python dino_svf_ensemble_flowers.py --method deep_ensemble --n_members 4

# MC Dropout
python dino_svf_ensemble_flowers.py --method mc_dropout --mc_samples 10

# Enable logging
python dino_svf_ensemble_flowers.py --method svf --save_log

# Different backbones
python dino_svf_ensemble_flowers.py --method svf --backbone vit_base_patch14_dinov2.lvd142m
python dino_svf_ensemble_flowers.py --method svf --backbone vit_large_patch14_dinov2.lvd142m

# Different datasets
python dino_svf_ensemble_flowers.py --method svf --dataset food101
python dino_svf_ensemble_flowers.py --method svf --dataset dtd
python dino_svf_ensemble_flowers.py --method svf --dataset cifar100

================================================================================
KEY ARGUMENTS
================================================================================

General:
  --method            : svf, lora, batch_ensemble, deep_ensemble, mc_dropout, single
  --dataset           : flowers102, food101, dtd, cifar100, cars, aircraft, pets
                        (default: flowers102)
  --backbone          : Any timm ViT backbone
                        (default: vit_small_patch16_224.dino)
  --mode              : lp or ft (default: ft)
                        lp = linear probing (train head only)
                        ft = fine-tune adapters (SVF / LoRA / BE) + head
  --n_members         : Number of ensemble members (ignored for --method single)
  --epochs            : Training epochs (default: 10)
  --lr                : Learning rate
                        (default: 1e-3 for SVF, 1e-4 for others)
  --batch_size        : Batch size
  --weight_decay      : Weight decay for AdamW
  --warmup_epochs     : Learning-rate warmup epochs
  --grad_clip         : Gradient norm clipping (default: 1.0)
  --seed              : Random seed
  --amp               : Enable mixed-precision training
  --image_size        : Override backbone input resolution
  --save_log          : Save logs and result summary to logs/ directory

Data:
  --data_root         : Root directory for datasets (default: ./data)
                        Datasets are downloaded automatically if missing
  --label_fraction    : Fraction of training data to use
  --val_fraction      : Fraction of training data used for validation

SVF-specific:
  --topk              : Number of singular values to keep (default: all)
  --svf_scope         : attn, mlp, or attn_mlp (default: attn_mlp)
  --svf_init_mean     : Mean initialization for singular values (default: 0.0)
  --svf_init_std      : Initialization noise for singular values (default: 0.005)

LoRA-specific:
  --lora_r            : LoRA rank (default: 8)
  --lora_alpha        : LoRA scaling factor (default: 8.0)
  --lora_dropout      : Dropout applied to LoRA adapters (default: 0.0)
  --lora_scope        : attn, mlp, or attn_mlp (default: attn_mlp)

BatchEnsemble-specific:
  --be_scope          : attn, mlp, or attn_mlp (default: attn_mlp)
  --be_init_std       : Initialization noise for r/s vectors (default: 0.01)

MC Dropout-specific:
  --mc_samples        : Number of MC samples at test time (default: 4)
  --mc_dropout_rate   : Dropout rate (default: 0.05)

Evaluation:
  --cifar100c_eval    : Evaluate on CIFAR-100-C (severity levels 1–5)
  --ood_eval          : OOD detection using MSP (CIFAR-10 and SVHN)

Notes:
  - Method-specific arguments are ignored when not applicable.
  - Implicit ensembles (SVF, LoRA, BatchEnsemble) process all members in a
    single forward pass via batch expansion.
  - Deep ensembles train and evaluate independent models.
"""

import argparse
import random
import math
import copy
import time
import sys
import logging
from datetime import datetime
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.datasets import (
    Flowers102, StanfordCars, FGVCAircraft, Food101,
    CIFAR100, DTD, OxfordIIITPet
)
from torchvision.datasets import CIFAR10, SVHN
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from torchvision import transforms
from PIL import Image
import os

import timm

# Optional imports for CLIP models
try:
    import open_clip
    OPEN_CLIP_AVAILABLE = True
except ImportError:
    OPEN_CLIP_AVAILABLE = False

try:
    from transformers import CLIPModel, CLIPProcessor
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False


# Directory where checkpoints are written. This is a relative path, so it is
# resolved against the current working directory.
CHECKPOINT_DIR = "storage"

def get_checkpoint_path(filename):
    """Return ``CHECKPOINT_DIR/filename``, creating ``CHECKPOINT_DIR`` if needed.

    Note: ``CHECKPOINT_DIR`` is relative ("storage"), not "/storage"; the
    directory is created under the current working directory.

    Args:
        filename: Checkpoint file name (may include sub-paths).

    Returns:
        str: The joined path.
    """
    ckpt_dir = CHECKPOINT_DIR
    os.makedirs(ckpt_dir, exist_ok=True)
    return os.path.join(ckpt_dir, filename)


# Logging Setup
LOG_DIR = "logs"


class TeeLogger:
    """
    File-like object that duplicates everything written to it into both the
    stream that was ``sys.stdout`` at construction time and a log file.

    Intended to be installed as ``sys.stdout`` so that ``print`` output is
    mirrored into ``filepath``. The log file is opened in write mode, so an
    existing file at the same path is overwritten.
    """
    def __init__(self, filepath):
        self.terminal = sys.stdout
        parent = os.path.dirname(filepath)
        os.makedirs(parent or ".", exist_ok=True)
        # Explicit UTF-8: logged text includes non-ASCII characters (e.g. the
        # emoji in the config header), which would raise UnicodeEncodeError
        # on platforms whose default locale encoding is not UTF-8.
        self.log = open(filepath, "w", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        # Flush eagerly so the log is complete even if the process crashes.
        self.log.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file.

        If this logger is currently installed as ``sys.stdout``, the original
        stream is restored first so later prints do not hit a closed file.
        """
        if sys.stdout is self:
            sys.stdout = self.terminal
        self.log.close()


def generate_experiment_name(args):
    """
    Build a descriptive experiment name from the parsed arguments.

    Format::

        {dataset}_{method}_{backbone_short}_M{n_members}_s{seed}_{method_specific}
        _lr{lr}_{mode}_ep{epochs}_bs{batch_size}_{YYYYmmdd_HHMMSS}

    ``method_specific`` is:
      - svf:            topk{topk} (or topkAll), sc-{svf_scope without "_"}
      - lora:           r{lora_r}, a{int(lora_alpha)}, sc-{lora_scope without "_"}
      - batch_ensemble: sc-{be_scope without "_"}
      - mc_dropout:     drop{mc_dropout_rate}, samp{mc_samples}
      - others:         nothing

    ``n_members`` is reported as 1 for non-ensemble methods. The trailing
    second-resolution timestamp makes generated names unique across runs.

    Examples:
    - flowers102_svf_small_p14_d2_M4_s42_topkAll_sc-attnmlp_lr1e-3_ft_ep20_bs32_20231215_143052
    - flowers102_lora_small_p14_d2_M4_s42_r16_a32_sc-attnmlp_lr1e-4_ft_ep20_bs32_20231215_143052

    Args:
        args: Parsed argument namespace.

    Returns:
        str: The experiment name.
    """
    # Ordered substring rewrites that shorten timm model names, e.g.
    # "vit_small_patch14_dinov2.lvd142m" -> "small_p14_d2". Order matters:
    # each rewrite operates on the output of the previous ones.
    abbreviations = (
        ("vit_", ""), ("_224", ""), ("_518", ""),
        ("patch", "p"), (".dino", "_dino"),
        (".lvd142m", ""), (".lvd1689m", ""),
        ("dinov2", "d2"), ("dinov3", "d3"),
    )
    backbone_short = args.backbone
    for old, new in abbreviations:
        backbone_short = backbone_short.replace(old, new)

    ensemble_methods = ("svf", "lora", "deep_ensemble", "batch_ensemble")
    n_members = args.n_members if args.method in ensemble_methods else 1

    parts = [
        args.dataset,
        args.method,
        backbone_short[:25],  # limit length
        f"M{n_members}",
        f"s{args.seed}",
    ]

    # Method-specific parameters
    if args.method == "svf":
        parts.append(f"topk{args.topk}" if args.topk is not None else "topkAll")
        parts.append(f"sc-{args.svf_scope.replace('_', '')}")
    elif args.method == "lora":
        parts.append(f"r{args.lora_r}")
        parts.append(f"a{int(args.lora_alpha)}")
        parts.append(f"sc-{args.lora_scope.replace('_', '')}")
    elif args.method == "batch_ensemble":
        # Same scope tag as SVF/LoRA so runs with different scopes are distinguishable.
        parts.append(f"sc-{args.be_scope.replace('_', '')}")
    elif args.method == "mc_dropout":
        parts.append(f"drop{args.mc_dropout_rate}")
        parts.append(f"samp{args.mc_samples}")

    # Learning rate, compacted ("1e-04" -> "1e-4"). args.lr may still be None
    # if the method-dependent default has not been resolved yet; formatting
    # None with ".0e" would raise TypeError.
    if args.lr is None:
        parts.append("lrdefault")
    else:
        parts.append(f"lr{args.lr:.0e}".replace("e-0", "e-"))

    parts.append(args.mode)
    parts.append(f"ep{args.epochs}")
    parts.append(f"bs{args.batch_size}")
    parts.append(datetime.now().strftime("%Y%m%d_%H%M%S"))

    return "_".join(parts)


def setup_logging(args):
    """
    Mirror stdout into a per-experiment log file and print a config header.

    The log path is ``{args.log_dir}/{exp_name}.txt``, where ``exp_name`` is
    ``args.exp_name`` if given and ``generate_experiment_name(args)``
    otherwise. The file is opened in write mode, so an existing file with the
    same name is overwritten. Auto-generated names carry a second-resolution
    timestamp, so in practice this only happens when a custom ``--exp_name``
    is reused.

    Side effects:
        - Rebinds the module-level ``LOG_DIR`` to ``args.log_dir``.
        - Replaces ``sys.stdout`` with a ``TeeLogger``.

    Returns:
        tuple: ``(log_filepath, tee_logger, exp_name)``. Call
        ``tee_logger.close()`` when done.
    """
    global LOG_DIR
    LOG_DIR = args.log_dir
    os.makedirs(LOG_DIR, exist_ok=True)

    exp_name = generate_experiment_name(args) if args.exp_name is None else args.exp_name
    log_filepath = os.path.join(LOG_DIR, f"{exp_name}.txt")

    # From here on every print goes to both the console and the log file.
    tee_logger = TeeLogger(log_filepath)
    sys.stdout = tee_logger

    banner = "=" * 100
    rule = "-" * 50
    header_lines = (
        banner,
        f"EXPERIMENT LOG: {exp_name}",
        banner,
        f"Log file: {log_filepath}",
        f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        banner,
        "\n📋 CONFIGURATION:",
        rule,
    )
    for line in header_lines:
        print(line)
    for key, value in sorted(vars(args).items()):
        print(f"  {key}: {value}")
    print(rule + "\n")

    return log_filepath, tee_logger, exp_name


def save_results_summary(args, results_dict, log_filepath):
    """
    Save a summary of the config and results to a JSON file next to the log.

    The summary path is the log path with its extension replaced by
    ``_summary.json`` (``logs/run.txt`` -> ``logs/run_summary.json``). A log
    path without an extension therefore never collides with the log itself.

    Values are converted recursively to JSON-native types:
      - numpy scalars of any dtype (including ``np.bool_``) via ``.item()``
      - numpy arrays via ``.tolist()``
      - torch tensors via ``.detach().cpu().tolist()``; ``detach`` is needed
        for tensors that require grad
      - tuples become lists

    Args:
        args: Argument namespace; attributes starting with ``_`` are omitted.
        results_dict: Dictionary containing results.
        log_filepath: Path to the log file (used to derive the summary path
            and the experiment name).

    Returns:
        str: Path of the written summary file.
    """
    import json

    log_root, _ = os.path.splitext(log_filepath)

    summary = {
        "experiment_name": os.path.basename(log_root),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "config": {k: v for k, v in vars(args).items() if not k.startswith("_")},
        "results": results_dict,
    }

    def to_serializable(obj):
        """Recursively convert containers, arrays, tensors and numpy scalars."""
        if isinstance(obj, dict):
            return {k: to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [to_serializable(v) for v in obj]
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return obj

    summary = to_serializable(summary)

    summary_filepath = log_root + "_summary.json"
    with open(summary_filepath, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print(f"\n📄 Results summary saved to: {summary_filepath}")
    return summary_filepath


# Custom Dataset for Stanford Cars from Kaggle (manual download)
class StanfordCarsLocal(Dataset):
    """
    Stanford Cars dataset loaded from local Kaggle download.

    Expected structure in data_root/stanford_cars/:
        train/
            class_folder_1/
                image1.jpg
                image2.jpg
            class_folder_2/
                ...
        test/
            class_folder_1/
                ...

    Class indices are assigned to the sorted class-folder names. Non-directory
    entries in the split folder (e.g. ``.DS_Store``) are ignored, so they can
    neither shift labels nor inflate the class count. Images in each class
    folder are collected in sorted order, so sample indices (and any seeded,
    index-based split built on them) do not depend on filesystem listing order.

    Download from: https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset
    """
    _IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, split="train", transform=None):
        self.root = os.path.join(root, "stanford_cars", split)
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}

        if not os.path.exists(self.root):
            raise RuntimeError(
                f"Stanford Cars not found at {self.root}. "
                f"Please download from Kaggle and extract to {os.path.join(root, 'stanford_cars')}/\n"
                f"Expected structure: stanford_cars/train/<class_folders>/ and stanford_cars/test/<class_folders>/"
            )

        # Only sub-directories are classes.
        classes = sorted(
            entry for entry in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, entry))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        self.classes = classes

        # Collect all samples (sorted for reproducible ordering).
        for class_name in classes:
            class_dir = os.path.join(self.root, class_name)
            label = self.class_to_idx[class_name]
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(self._IMG_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, img_name), label))

        print(f"Loaded Stanford Cars {split}: {len(self.samples)} images, {len(classes)} classes")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # The context manager closes the file handle promptly; convert()
        # returns a new, fully loaded image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label


class CIFAR100C(Dataset):
    """
    CIFAR-100-C corrupted images at a single severity level.

    ``root_dir`` must contain ``labels.npy`` plus one ``<corruption>.npy`` file
    per corruption type. Each corruption array is split into 5 equal
    contiguous blocks, and block ``severity - 1`` is taken (severity 1 first).
    The same index range is taken from ``labels.npy``. The selected blocks of
    all corruptions, in filename order, are concatenated into one dataset.
    NOTE(review): this assumes the public CIFAR-100-C layout
    (e.g. 50000 = 5 x 10000 images per file) — confirm for other sources.

    Args:
        root_dir: Directory with the extracted ``.npy`` files.
        severity: Corruption severity, an int in 1..5.
        transform: Optional transform applied to the PIL image.

    Raises:
        ValueError: If ``severity`` is outside 1..5 or no corruption files
            are found.
    """

    NUM_SEVERITIES = 5

    def __init__(self, root_dir, severity=1, transform=None):
        if not 1 <= severity <= self.NUM_SEVERITIES:
            # Out-of-range severities would otherwise silently produce an
            # empty (or wrong) slice.
            raise ValueError(
                f"severity must be in [1, {self.NUM_SEVERITIES}], got {severity}"
            )
        self.severity = severity
        self.transform = transform

        labels = np.load(os.path.join(root_dir, 'labels.npy'))

        corruptions = sorted(
            f for f in os.listdir(root_dir) if f.endswith('.npy') and f != 'labels.npy'
        )
        if not corruptions:
            raise ValueError(f"No corruption .npy files found in {root_dir}")

        images, targets = [], []
        for fname in corruptions:
            # mmap_mode="r": only the selected severity slice is read from
            # disk instead of materialising every full corruption array.
            arr = np.load(os.path.join(root_dir, fname), mmap_mode="r")
            n = arr.shape[0] // self.NUM_SEVERITIES
            start, end = (severity - 1) * n, severity * n
            images.append(arr[start:end])
            targets.append(labels[start:end])

        # concatenate copies the slices into memory; asarray drops any memmap
        # subclass so the result is a plain ndarray.
        self.images = np.asarray(np.concatenate(images, axis=0))
        self.targets = np.asarray(np.concatenate(targets, axis=0))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]  # HWC uint8 array
        target = int(self.targets[idx])

        # Convert to PIL so torchvision transforms work as-is
        img = Image.fromarray(img)

        if self.transform:
            img = self.transform(img)

        target = torch.tensor(target, dtype=torch.int64)
        return img, target


# Available Datasets
#
# Registry of supported datasets, keyed by the value accepted by --dataset.
# Per-entry fields:
#   description   : human-readable name (printed by print_available_datasets)
#   num_classes   : number of target classes
#   train_split / val_split / test_split :
#                   split identifiers for the dataset class; val_split is None
#                   when the dataset has no official validation split.
#                   NOTE(review): how these are passed to dataset_class
#                   (e.g. CIFAR100 takes a boolean `train`, not a split name)
#                   is handled by loader code outside this view — confirm.
#   dataset_class : torchvision dataset class, or the local StanfordCarsLocal
#   has_val       : whether an official validation split exists; presumably
#                   the loader carves one out of train via --val_fraction when
#                   False — TODO confirm
#   notes         : free-text notes (printed by print_available_datasets)
# Dict insertion order is the order in which print_available_datasets lists them.
AVAILABLE_DATASETS = {
    "flowers102": {
        "description": "Oxford Flowers 102",
        "num_classes": 102,
        "train_split": "train",
        "val_split": "val",
        "test_split": "test",
        "dataset_class": Flowers102,
        "has_val": True,
        "notes": "Fine-grained flower classification. Small dataset (1020 train).",
    },
    # Not auto-downloadable: StanfordCarsLocal raises RuntimeError if the
    # Kaggle extract is missing under <data_root>/stanford_cars/.
    "cars": {
        "description": "Stanford Cars (Local/Kaggle)",
        "num_classes": 196,
        "train_split": "train",
        "val_split": None,
        "test_split": "test",
        "dataset_class": StanfordCarsLocal,
        "has_val": False,
        "notes": "Fine-grained car classification. Download from Kaggle to ./data/stanford_cars/",
    },
    "aircraft": {
        "description": "FGVC Aircraft",
        "num_classes": 100,
        "train_split": "train",
        "val_split": "val",
        "test_split": "test",
        "dataset_class": FGVCAircraft,
        "has_val": True,
        "notes": "Fine-grained aircraft classification. 100 aircraft variants.",
    },
    "food101": {
        "description": "Food-101",
        "num_classes": 101,
        "train_split": "train",
        "val_split": None,
        "test_split": "test",
        "dataset_class": Food101,
        "has_val": False,
        "notes": "Food classification. 101 food categories, 1000 images per class.",
    },
    "cifar100": {
        "description": "CIFAR-100",
        "num_classes": 100,
        "train_split": "train",
        "val_split": None,
        "test_split": "test",
        "dataset_class": CIFAR100,
        "has_val": False,
        "notes": "100 classes, 32x32 images. Standard benchmark (will be resized).",
    },
    "dtd": {
        "description": "Describable Textures (DTD)",
        "num_classes": 47,
        "train_split": "train",
        "val_split": "val",
        "test_split": "test",
        "dataset_class": DTD,
        "has_val": True,
        "notes": "Texture classification. 47 texture categories.",
    },
    # Pets trains on the combined "trainval" split.
    "pets": {
        "description": "Oxford-IIIT Pets",
        "num_classes": 37,
        "train_split": "trainval",
        "val_split": None,
        "test_split": "test",
        "dataset_class": OxfordIIITPet,
        "has_val": False,
        "notes": "Fine-grained pet classification. 37 cat/dog breeds.",
    },
}


def print_available_datasets():
    """Print every entry of ``AVAILABLE_DATASETS`` to stdout.

    Each entry shows its --dataset name, description, class count and notes,
    in registry order, framed by 80-character rules.
    """
    border = "=" * 80
    print("\n" + border)
    print("📋 AVAILABLE DATASETS")
    print(border)

    for dataset_name, spec in AVAILABLE_DATASETS.items():
        entry_lines = (
            f"\n  --dataset {dataset_name}",
            f"      {spec['description']} | {spec['num_classes']} classes",
            f"      {spec['notes']}",
        )
        for line in entry_lines:
            print(line)

    print("\n" + border + "\n")


# Available DINO/DINOv2/DINOv3 Models in timm
#
# Registry of suggested backbones, keyed by timm model name (the value passed
# to --backbone). The registry is informational: --backbone is a free-form
# string, so any timm ViT name may be used even if it is not listed here.
# Per-entry fields:
#   description, params, notes                   : human-readable metadata
#   embed_dim, num_heads, num_layers, patch_size : architecture summary
#   input_size : native square input resolution in pixels. NOTE(review):
#                presumably used when --image_size is not given — confirm
#                against the transform/model-building code.
AVAILABLE_BACKBONES = {
    # DINO v1 Models
    "vit_small_patch16_224.dino": {
        "description": "DINO ViT-Small/16",
        "params": "21.7M",
        "embed_dim": 384,
        "num_heads": 6,
        "num_layers": 12,
        "patch_size": 16,
        "input_size": 224,
        "notes": "Good balance of speed and accuracy. Default choice.",
    },
    "vit_small_patch8_224.dino": {
        "description": "DINO ViT-Small/8",
        "params": "21.7M",
        "embed_dim": 384,
        "num_heads": 6,
        "num_layers": 12,
        "patch_size": 8,
        "input_size": 224,
        "notes": "Same params as /16 but 4x more patches. Better for dense tasks.",
    },
    "vit_base_patch16_224.dino": {
        "description": "DINO ViT-Base/16",
        "params": "85.8M",
        "embed_dim": 768,
        "num_heads": 12,
        "num_layers": 12,
        "patch_size": 16,
        "input_size": 224,
        "notes": "Larger model, better features, 4x params of Small.",
    },
    "vit_base_patch8_224.dino": {
        "description": "DINO ViT-Base/8",
        "params": "85.8M",
        "embed_dim": 768,
        "num_heads": 12,
        "num_layers": 12,
        "patch_size": 8,
        "input_size": 224,
        "notes": "Base model with fine-grained patches. Memory intensive.",
    },
    # DINOv2 Models
    "vit_small_patch14_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Small/14",
        "params": "22M",
        "embed_dim": 384,
        "num_heads": 6,
        "num_layers": 12,
        "patch_size": 14,
        "input_size": 518,
        "notes": "DINOv2 small. Distilled from ViT-g. Great features.",
    },
    "vit_base_patch14_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Base/14",
        "params": "86M",
        "embed_dim": 768,
        "num_heads": 12,
        "num_layers": 12,
        "patch_size": 14,
        "input_size": 518,
        "notes": "DINOv2 base. Distilled from ViT-g. Recommended.",
    },
    "vit_large_patch14_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Large/14",
        "params": "300M",
        "embed_dim": 1024,
        "num_heads": 16,
        "num_layers": 24,
        "patch_size": 14,
        "input_size": 518,
        "notes": "DINOv2 large. Distilled from ViT-g. High quality.",
    },
    "vit_giant_patch14_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Giant/14",
        "params": "1.1B",
        "embed_dim": 1536,
        "num_heads": 24,
        "num_layers": 40,
        "patch_size": 14,
        "input_size": 518,
        "notes": "DINOv2 giant. Teacher model. Best features, very large.",
    },
    # DINOv2 with Registers
    "vit_small_patch14_reg4_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Small/14 + 4 registers",
        "params": "22M",
        "embed_dim": 384,
        "num_heads": 6,
        "num_layers": 12,
        "patch_size": 14,
        "input_size": 518,
        "notes": "With register tokens. Cleaner attention maps.",
    },
    "vit_base_patch14_reg4_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Base/14 + 4 registers",
        "params": "86M",
        "embed_dim": 768,
        "num_heads": 12,
        "num_layers": 12,
        "patch_size": 14,
        "input_size": 518,
        "notes": "With register tokens. Recommended for dense tasks.",
    },
    "vit_large_patch14_reg4_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Large/14 + 4 registers",
        "params": "300M",
        "embed_dim": 1024,
        "num_heads": 16,
        "num_layers": 24,
        "patch_size": 14,
        "input_size": 518,
        "notes": "With register tokens. High quality dense features.",
    },
    "vit_giant_patch14_reg4_dinov2.lvd142m": {
        "description": "DINOv2 ViT-Giant/14 + 4 registers",
        "params": "1.1B",
        "embed_dim": 1536,
        "num_heads": 24,
        "num_layers": 40,
        "patch_size": 14,
        "input_size": 518,
        "notes": "With register tokens. Best quality, very large.",
    },
    # DINOv3 Models (LVD-1689M)
    "vit_small_patch16_dinov3.lvd1689m": {
        "description": "DINOv3 ViT-Small/16 (LVD-1689M)",
        "params": "21.6M",
        "embed_dim": 384,
        "num_heads": 6,
        "num_layers": 12,
        "patch_size": 16,
        "input_size": 256,  # model card default
        "notes": "DINOv3 small distilled from ViT-7B. Use this (non-qkvb) for timm baselines.",
    },
}


def print_available_backbones(backbones=None):
    """
    Print the known backbones grouped by model family.

    Each backbone is assigned to exactly one section, checked in this order:
    DINOv3 ("dinov3" in name), DINOv2 with registers ("reg4"), DINOv2
    ("dinov2"), DINO v1 (".dino"). Names matching none of these are listed
    under an "Other" section, which is printed only when non-empty, so no
    entry is silently hidden.

    Args:
        backbones: Mapping of timm model name -> spec dict. The keys used are
            ``description``, ``params``, ``patch_size``, ``embed_dim`` and
            ``notes``. Defaults to ``AVAILABLE_BACKBONES``.
    """
    if backbones is None:
        backbones = AVAILABLE_BACKBONES

    def family(name):
        """Return the section key for a backbone name (first match wins)."""
        if "dinov3" in name:
            return "dinov3"
        if "reg4" in name:
            return "dinov2_reg"
        if "dinov2" in name:
            return "dinov2"
        if ".dino" in name:
            return "dino"
        return "other"

    # Display order of sections.
    sections = (
        ("dino", "\n DINO v1 (Original, ImageNet-1K):"),
        ("dinov2", "\n DINOv2 (Improved, LVD-142M):"),
        ("dinov2_reg", "\n DINOv2 with Registers:"),
        ("dinov3", "\n DINOv3 (LVD-1689M):"),
        ("other", "\n Other:"),
    )
    grouped = {key: [] for key, _ in sections}
    for name, info in backbones.items():
        grouped[family(name)].append((name, info))

    print("\n" + "=" * 100)
    print("AVAILABLE BACKBONES")
    print("=" * 100)

    for key, title in sections:
        entries = grouped[key]
        if key == "other" and not entries:
            continue
        print(title)
        print("-" * 100)
        for name, info in entries:
            print(f"  {name}")
            print(f"      {info['description']} | {info['params']} params | "
                  f"patch={info['patch_size']} | embed={info['embed_dim']}")
            print(f"      {info['notes']}")

    print("\n" + "=" * 100 + "\n")


# Utilities
def parse_args(argv=None):
    """
    Build the command-line parser and parse arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, which is argparse's standard behaviour.

    Returns:
        argparse.Namespace: The parsed arguments.

    Exits (via ``parser.error``, status 2) on values that cannot work:
    ``--n_members`` or ``--mc_samples`` < 1, ``--label_fraction`` outside
    (0, 1], or ``--val_fraction`` outside [0, 1).
    """
    parser = argparse.ArgumentParser(
        # Passed as ``description``: the first positional argument of
        # ArgumentParser is ``prog``, which would replace the program name in
        # the usage line.
        description="DINO SVF/LoRA/BatchEnsemble Implicit Ensemble",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example usage:
  # List all available backbones and datasets
  python dino_svf_ensemble_flowers.py --list_backbones
  python dino_svf_ensemble_flowers.py --list_datasets

  # Single model baseline
  python dino_svf_ensemble_flowers.py --method single --dataset flowers102

  # SVF Implicit Ensemble - 4 members
  python dino_svf_ensemble_flowers.py --method svf --n_members 4 --dataset flowers102

  # LoRA Implicit Ensemble - 4 members
  python dino_svf_ensemble_flowers.py --method lora --n_members 4 --lora_r 16 --dataset flowers102

  # BatchEnsemble - 4 members
  python dino_svf_ensemble_flowers.py --method batch_ensemble --n_members 4 --dataset flowers102

  # Deep Ensemble - 4 independent networks
  python dino_svf_ensemble_flowers.py --method deep_ensemble --n_members 4 --dataset cars

  # MC Dropout
  python dino_svf_ensemble_flowers.py --method mc_dropout --mc_samples 10 --dataset aircraft

Available datasets: flowers102, cars, aircraft, food101, cifar100, dtd, pets (use --list_datasets for details)
Available backbones: Use --list_backbones for full list
"""
    )

    # Data
    parser.add_argument("--data_root", type=str, default="./data")
    parser.add_argument(
        "--dataset",
        type=str,
        default="flowers102",
        choices=["flowers102", "cars", "aircraft", "food101", "cifar100", "dtd", "pets"],
        help="Dataset to use (default: flowers102)",
    )
    parser.add_argument("--list_datasets", action="store_true")
    parser.add_argument("--val_fraction", type=float, default=0.1)
    parser.add_argument("--label_fraction", type=float, default=1.0)

    # Model
    parser.add_argument(
        "--backbone",
        type=str,
        default="vit_small_patch16_224.dino",
        help="timm model name (use --list_backbones to see all options)",
    )
    parser.add_argument("--list_backbones", action="store_true")
    parser.add_argument("--image_size", type=int, default=None)
    parser.add_argument(
        "--init_mode",
        type=str,
        default="pretrained",
        choices=["pretrained", "random"],
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="ft",
        choices=["lp", "ft"],
        help="lp: linear probe (head only), ft: fine-tune",
    )

    # Ensemble method
    parser.add_argument(
        "--method",
        type=str,
        default="svf",
        choices=["svf", "lora", "batch_ensemble", "deep_ensemble", "mc_dropout", "single"],
        help="Method: svf, lora, batch_ensemble, deep_ensemble, mc_dropout, or single",
    )

    # SVF ensemble config
    parser.add_argument("--n_members", type=int, default=4)
    parser.add_argument("--topk", type=int, default=None)
    parser.add_argument(
        "--svf_scope",
        type=str,
        choices=["attn", "mlp", "attn_mlp"],
        default="attn_mlp",
        help="Which layers to apply SVF: attn (attention only), mlp (MLP only), attn_mlp (both)",
    )
    parser.add_argument("--svf_init_mean", type=float, default=0.0)
    parser.add_argument("--svf_init_std", type=float, default=0.005)
    parser.add_argument("--head_init_mean", type=float, default=0.0)
    parser.add_argument("--head_init_std", type=float, default=0.01)

    # LoRA ensemble config
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=float, default=8.0, help="LoRA alpha scaling")
    parser.add_argument("--lora_dropout", type=float, default=0.0, help="LoRA dropout")
    parser.add_argument("--lora_init_std", type=float, default=0.1, help="LoRA init std")
    parser.add_argument(
        "--lora_scope",
        type=str,
        choices=["attn", "mlp", "attn_mlp"],
        default="attn_mlp",
        help="Which layers to apply LoRA: attn (attention only), mlp (MLP only), attn_mlp (both)",
    )

    # BatchEnsemble config
    parser.add_argument("--be_init_std", type=float, default=0.01, help="BatchEnsemble r/s init std")
    parser.add_argument(
        "--be_scope",
        type=str,
        choices=["attn", "mlp", "attn_mlp"],
        default="attn_mlp",
        help="Which layers to apply BatchEnsemble: attn (attention only), mlp (MLP only), attn_mlp (both)",
    )

    # MC Dropout config
    parser.add_argument("--mc_samples", type=int, default=4)
    parser.add_argument("--mc_dropout_rate", type=float, default=0.05)

    # MLP head options
    parser.add_argument("--use_mlp_head", action="store_true")
    parser.add_argument("--mlp_hidden_dim", type=int, default=None)
    parser.add_argument("--mlp_dropout", type=float, default=0.1)

    # Training
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--warmup_epochs", type=int, default=5)
    parser.add_argument(
        "--lr",
        type=float,
        default=None,
        help="Learning rate (default: 1e-3 for svf, 1e-4 for others; 1e-3 also works well for lora)",
    )
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--grad_clip", type=float, default=1.0)

    # Misc
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--verbose", action="store_true")

    # Logging
    parser.add_argument(
        "--save_log",
        action="store_true",
        help="Save logs to file (default: False)",
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        default="logs",
        help="Directory to save log files (default: logs)",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default=None,
        help="Custom experiment name (default: auto-generated)",
    )
    # OOD eval
    parser.add_argument(
        "--ood_eval",
        action="store_true",
        help="If set, run OOD detection on (CIFAR-10, SVHN) using the trained model. ID = in-distribution test set.",
    )
    parser.add_argument(
        "--ood_batch_size",
        type=int,
        default=None,
        help="Optional batch size for OOD loaders (defaults to --batch_size).",
    )
    parser.add_argument(
        "--ood_num_workers",
        type=int,
        default=None,
        help="Optional num_workers for OOD loaders (defaults to --num_workers).",
    )
    # CIFAR-100-C eval (optional, evaluation only)
    parser.add_argument(
        "--cifar100c_eval",
        action="store_true",
        help="If set, evaluate the trained model on CIFAR-100-C for severities 1..5 (no training).",
    )
    parser.add_argument(
        "--cifar100c_root",
        type=str,
        default="data/CIFAR100C/",
        help="Path to extracted CIFAR-100-C .npy files (must contain labels.npy and corruption .npy files).",
    )

    args = parser.parse_args(argv)

    # Reject values that cannot produce a valid run, failing fast with a
    # usage message instead of crashing deep inside training.
    if args.n_members < 1:
        parser.error(f"--n_members must be >= 1, got {args.n_members}")
    if args.mc_samples < 1:
        parser.error(f"--mc_samples must be >= 1, got {args.mc_samples}")
    if not 0.0 < args.label_fraction <= 1.0:
        parser.error(f"--label_fraction must be in (0, 1], got {args.label_fraction}")
    if not 0.0 <= args.val_fraction < 1.0:
        parser.error(f"--val_fraction must be in [0, 1), got {args.val_fraction}")

    return args


def set_seed(seed):
    """Seed every RNG used by the pipeline with the same ``seed``.

    Covers Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and the generators of all CUDA devices (in that order).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def get_gpu_memory_mb():
    """Return memory currently held by PyTorch's CUDA allocator, in MB.

    Returns 0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / 1024 / 1024


def get_gpu_max_memory_mb():
    """Return the peak CUDA allocator usage since the last reset, in MB.

    Returns 0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / 1024 / 1024


def reset_memory_stats():
    """Reset CUDA peak-memory counters and release cached allocator blocks.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()


def _estimate_vit_flops(model, input_shape):
    """Heuristic FLOP estimate used when fvcore cannot count: 2 * params * tokens.

    Assumes a ViT with 16x16 patches on a square input of side ``input_shape[2]``
    plus one class token; the batch dimension is ignored.
    """
    total_params = sum(p.numel() for p in model.parameters())
    seq_len = (input_shape[2] // 16) ** 2 + 1  # approximate token count for a /16 ViT
    return 2 * total_params * seq_len


def _count_flops(model, dummy_input, input_shape):
    """Return the FLOP count of one forward pass (fvcore if installed, else heuristic).

    NOTE(review): fvcore is documented to count one fused multiply-add as one
    flop, while the heuristic counts two per parameter per token, so the two
    paths may differ by roughly 2x -- confirm before comparing numbers that
    were produced in different environments.
    """
    try:
        from fvcore.nn import FlopCountAnalysis
    except ImportError:
        return _estimate_vit_flops(model, input_shape)

    try:
        with torch.no_grad():
            return FlopCountAnalysis(model, dummy_input).total()
    except Exception as e:
        # Best-effort: any counting failure falls back to the same heuristic
        # as the missing-fvcore path (previously a hard-coded 200 tokens).
        print(f"Warning: Could not count FLOPs: {e}")
        return _estimate_vit_flops(model, input_shape)


def _measure_peak_memory_mb(model, dummy_input):
    """Return ``(activation_mb, total_mb)`` of peak CUDA memory for one forward pass.

    ``activation_mb`` is peak minus the allocation already present before the
    pass (weights, the dummy input); ``total_mb`` is the raw peak. Both are 0.0
    without CUDA.
    """
    if not torch.cuda.is_available():
        return 0.0, 0.0

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

    baseline_memory = torch.cuda.memory_allocated()
    with torch.no_grad():
        _ = model(dummy_input)
    torch.cuda.synchronize()

    peak_memory = torch.cuda.max_memory_allocated()
    return (peak_memory - baseline_memory) / 1024 / 1024, peak_memory / 1024 / 1024


def _measure_latency_ms(model, dummy_input, n_warmup, n_runs):
    """Return ``(mean_ms, std_ms)`` wall-clock latency over ``n_runs`` timed passes.

    ``n_warmup`` untimed passes run first; with CUDA, every timed pass is
    bracketed by ``torch.cuda.synchronize()`` so asynchronous kernels are counted.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    times = []
    with torch.no_grad():
        for _ in range(n_warmup):
            _ = model(dummy_input)
        if use_cuda:
            torch.cuda.synchronize()

        for _ in range(n_runs):
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            _ = model(dummy_input)
            if use_cuda:
                torch.cuda.synchronize()
            times.append((time.perf_counter() - start) * 1000)  # seconds -> ms

    return np.mean(times), np.std(times)


def measure_inference_metrics(model, input_shape, device, n_warmup=5, n_runs=20):
    """
    Measure inference metrics: FLOPs, memory, and time for a SINGLE forward pass.

    Measurements are isolated by clearing the CUDA cache / resetting peak stats
    before each phase, running under ``torch.no_grad()`` and synchronizing CUDA
    around timed regions. The model is switched to eval mode and LEFT in eval
    mode on return; callers that continue training must call ``model.train()``.

    Args:
        model: The model to measure.
        input_shape: Input tensor shape, e.g. (1, 3, 224, 224) for a single image.
            ``input_shape[2]`` is used as the image side for the FLOP heuristic.
        device: Device on which the random dummy input is created.
        n_warmup: Number of untimed warmup runs before timing.
        n_runs: Number of timed runs for averaging.

    Returns:
        dict with keys:
            - flops: FLOPs for one forward pass (fvcore count or heuristic)
            - inference_memory_mb: Peak CUDA memory above the pre-pass baseline
              (activations / intermediates); 0.0 without CUDA
            - total_inference_memory_mb: Raw peak CUDA memory (model + activations);
              0.0 without CUDA
            - inference_time_ms: Mean inference time in ms
            - inference_time_std_ms: Std of inference time in ms
    """
    model.eval()

    # Start from a clean allocator state.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    dummy_input = torch.randn(input_shape, device=device)

    flops = _count_flops(model, dummy_input, input_shape)
    inference_memory_mb, total_inference_memory_mb = _measure_peak_memory_mb(model, dummy_input)
    avg_time_ms, std_time_ms = _measure_latency_ms(model, dummy_input, n_warmup, n_runs)

    # Clean up
    del dummy_input
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "flops": flops,
        "inference_memory_mb": inference_memory_mb,  # Memory for activations only
        "total_inference_memory_mb": total_inference_memory_mb,  # Model + activations
        "inference_time_ms": avg_time_ms,
        "inference_time_std_ms": std_time_ms,
    }


def format_flops(flops):
    """Render a FLOP count using the largest fitting prefix (T, G, M) with two decimals.

    Values below one million are printed as a whole number of FLOPs.
    """
    for scale, unit in ((1e12, "TFLOPs"), (1e9, "GFLOPs"), (1e6, "MFLOPs")):
        if flops >= scale:
            return f"{flops / scale:.2f} {unit}"
    return f"{flops:.0f} FLOPs"


def format_time(seconds):
    """Render a duration in seconds as ``"Hh Mm S.Ss"``, ``"Mm S.Ss"`` or ``"S.SSs"``.

    The hour form is used from 3600 s, the minute form from 60 s.
    """
    if seconds < 60:
        return f"{seconds:.2f}s"
    if seconds < 3600:
        minutes, secs = divmod(seconds, 60)
        return f"{int(minutes)}m {secs:.1f}s"
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{int(hours)}h {int(minutes)}m {secs:.1f}s"


def get_input_size_for_backbone(backbone_name: str, override_size: int = None) -> int:
    """Resolve the input resolution for ``backbone_name``.

    Priority: an explicit ``override_size``; then the ``"input_size"`` entry of
    ``AVAILABLE_BACKBONES[backbone_name]``; then 518 for any name containing
    "dinov2"; otherwise 224.
    """
    if override_size is not None:
        return override_size

    if backbone_name in AVAILABLE_BACKBONES:
        return AVAILABLE_BACKBONES[backbone_name]["input_size"]

    default_size = 518 if "dinov2" in backbone_name else 224
    return default_size


def build_transforms(image_size: int = 224):
    """Return ``(train_tf, eval_tf)`` torchvision pipelines for ``image_size`` inputs.

    Both pipelines end with ``ToTensor`` and ImageNet mean/std normalization.
    Training adds a random resized crop, horizontal flip and colour jitter;
    evaluation resizes to ``int(image_size * 256 / 224)`` and center-crops to
    ``image_size``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def _tensor_and_normalize():
        # Fresh instances per pipeline, matching one Normalize per Compose.
        return [
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]

    train_steps = [
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    ]
    eval_steps = [
        transforms.Resize(int(image_size * 256 / 224)),
        transforms.CenterCrop(image_size),
    ]

    train_tf = transforms.Compose(train_steps + _tensor_and_normalize())
    eval_tf = transforms.Compose(eval_steps + _tensor_and_normalize())
    return train_tf, eval_tf


def create_backbone(backbone_name: str, pretrained: bool = True, img_size: int = None):
    """Instantiate a headless timm model.

    The model is created with ``num_classes=0`` (no classifier head);
    ``img_size`` is only forwarded to timm when given.

    Returns:
        ``(backbone, backbone.num_features)``.
    """
    print(f"Loading {backbone_name} via timm")

    model_kwargs = dict(pretrained=pretrained, num_classes=0)
    if img_size is not None:
        model_kwargs["img_size"] = img_size

    model = timm.create_model(backbone_name, **model_kwargs)
    feature_dim = model.num_features
    return model, feature_dim


def create_cifar100c_loaders(args):
    """
    Build one CIFAR-100-C DataLoader per corruption severity (1..5).

    Images use the evaluation transform for the backbone's input size. Batch
    size and worker count come from ``--ood_batch_size`` / ``--ood_num_workers``
    when set, otherwise from ``--batch_size`` / ``--num_workers``.

    Returns:
        dict mapping severity (int) -> DataLoader (unshuffled, pinned memory).

    Raises:
        ValueError: if ``args.cifar100c_root`` is None.
    """
    if args.cifar100c_root is None:
        raise ValueError("--cifar100c_root must be provided when --cifar100c_eval is set.")

    _, eval_tf = build_transforms(get_input_size_for_backbone(args.backbone, args.image_size))

    batch_size = args.batch_size if args.ood_batch_size is None else args.ood_batch_size
    workers = args.num_workers if args.ood_num_workers is None else args.ood_num_workers

    return {
        severity: DataLoader(
            CIFAR100C(root_dir=args.cifar100c_root, severity=severity, transform=eval_tf),
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=True,
        )
        for severity in range(1, 6)
    }

# SVF Layers (Single Model)
class SVFLinear(nn.Module):
    """Single-model SVF layer: W = U @ diag(s) @ Vh.

    The pretrained weight is factorized once with an SVD of a float32 CPU copy.
    U and Vh become frozen buffers; only the singular values ``s`` (and the
    bias, if the base layer has one) are trainable parameters.

    Args:
        base_linear: Pretrained linear layer to factorize.
        topk: If given, keep only the ``topk`` largest singular directions
            (truncated SVD); otherwise keep all ``min(out, in)`` of them.
    """
    def __init__(self, base_linear: nn.Linear, topk: int = None):
        super().__init__()

        device = base_linear.weight.device
        W0 = base_linear.weight.detach().cpu().float()
        U, S, Vh = torch.linalg.svd(W0, full_matrices=False)

        if topk is not None:
            k = min(topk, S.shape[0])
            U, S, Vh = U[:, :k], S[:k], Vh[:k, :]

        # Put the factors back next to the original weight (and bias) so the
        # layer works even when the model was already on GPU when wrapped.
        self.register_buffer("U", U.to(device).contiguous())
        self.register_buffer("Vh", Vh.to(device).contiguous())
        self.s = nn.Parameter(S.to(device).clone().contiguous())

        if base_linear.bias is not None:
            self.bias = nn.Parameter(base_linear.bias.detach().clone().contiguous())
        else:
            self.bias = None

        self.out_features = U.shape[0]
        self.in_features = Vh.shape[1]

    def forward(self, x):
        """Apply ``x @ W.T + bias`` with W rebuilt from the current singular values."""
        # U * s scales each column of U by its singular value == U @ diag(s).
        W = (self.U * self.s.unsqueeze(0)) @ self.Vh
        return F.linear(x, W, self.bias)


# ----------------------------------------------------------------------
# SVF Ensemble Layers
# ----------------------------------------------------------------------

class EnsembleSVFLinear(nn.Module):
    """
    Implicit SVF ensemble layer.

    Each member m has its own singular values s_m (and its own bias copy):
    W_m = U @ diag(s_m) @ Vh

    U / Vh come from an SVD of a float32 CPU copy of the pretrained weight and are
    frozen buffers; they are moved back to the base weight's device so the layer
    works even if the model was already on GPU when wrapped.

    Args:
        base_linear: Pretrained linear layer to factorize.
        topk: Keep only the ``topk`` largest singular directions (None = all).
        n_members: Number of ensemble members M.
        init_mean, init_std: Gaussian noise added to the pretrained singular
            values and bias of every member to break symmetry.

    Input: [B*M, ..., in_features] in SAMPLE-MAJOR order (row b*M + m -> member m)
    Output: [B*M, ..., out_features]
    """
    def __init__(
        self,
        base_linear: nn.Linear,
        topk: int = None,
        n_members: int = 4,
        init_mean: float = 0.0,
        init_std: float = 0.01,
    ):
        super().__init__()

        self.n_members = n_members

        device = base_linear.weight.device
        W0 = base_linear.weight.detach().cpu().float()
        U, S, Vh = torch.linalg.svd(W0, full_matrices=False)

        if topk is not None:
            k = min(topk, S.shape[0])
            U, S, Vh = U[:, :k], S[:k], Vh[:k, :]
        else:
            k = S.shape[0]

        self.k = k

        self.register_buffer("U", U.to(device).contiguous())
        self.register_buffer("Vh", Vh.to(device).contiguous())
        self.register_buffer("s_base", S.to(device).contiguous())

        # Per-member singular values: [M, k]. Noise is drawn on CPU (as before)
        # and then moved, so seeded runs keep identical initial values.
        s_init = S.unsqueeze(0).repeat(n_members, 1)
        noise = torch.randn_like(s_init) * init_std + init_mean
        self.s = nn.Parameter((s_init + noise).to(device).contiguous())

        # Per-member bias: [M, out_dim]
        if base_linear.bias is not None:
            b = base_linear.bias.detach().clone()
            b_init = b.unsqueeze(0).repeat(n_members, 1)
            bias_noise = torch.randn_like(b_init) * init_std + init_mean
            self.bias = nn.Parameter((b_init + bias_noise).contiguous())
        else:
            self.bias = None

        self.out_features = U.shape[0]
        self.in_features = Vh.shape[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply member m's weight to every row b*M + m of ``x``."""
        M = self.n_members
        orig_shape = x.shape
        in_features = orig_shape[-1]
        assert in_features == self.in_features

        B_total = orig_shape[0]
        rest_dims = orig_shape[1:-1]
        assert B_total % M == 0
        B = B_total // M

        # [B*M, ..., in] -> [B, M, T, in] -> [M, B*T, in]. reshape (not view)
        # so non-contiguous inputs are accepted as well.
        x_bmti = x.reshape(B, M, -1, in_features)
        T = x_bmti.shape[2]
        x_mNi = x_bmti.permute(1, 0, 2, 3).reshape(M, B * T, in_features)

        outs = []
        for m in range(M):
            Wm = (self.U * self.s[m].unsqueeze(0)) @ self.Vh
            bm = self.bias[m] if self.bias is not None else None
            outs.append(F.linear(x_mNi[m], Wm, bm))

        # [M, B*T, out] -> [B, M, T, out] -> original sample-major layout
        y_bmto = torch.stack(outs, dim=0).view(M, B, T, self.out_features).permute(1, 0, 2, 3)
        return y_bmto.reshape(B * M, *rest_dims, self.out_features)


# LoRA Ensemble Layers
class EnsembleLoRALinear(nn.Module):
    """
    LoRA-based ensemble layer.

    Each member m has its own low-rank adapter on top of a shared frozen weight:
        y_m = x @ W_base^T + b_base + dropout(x) @ (B_m @ A_m)^T * (alpha / r)
    i.e. W_m = W_base + B_m @ A_m * (alpha / r) when dropout is disabled.

    Where:
    - W_base: Frozen base weights [out_features, in_features] (buffer)
    - A_m: Per-member down-projection [n_members, r, in_features]
    - B_m: Per-member up-projection [n_members, out_features, r]

    Dropout (``lora_dropout > 0``) is applied only to the input of the LoRA
    branch, as in standard LoRA; the frozen base path always sees the clean input.

    Parameters per member: r * (in_features + out_features)

    Input: [B*M, ..., in_features] in SAMPLE-MAJOR order (row b*M + m -> member m)
    Output: [B*M, ..., out_features]
    """

    def __init__(
        self,
        base_linear: nn.Linear,
        n_members: int,
        lora_r: int = 16,
        lora_alpha: float = 32.0,
        lora_dropout: float = 0.0,
        lora_init_std: float = 0.02,
    ):
        super().__init__()
        self.n_members = n_members
        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.scaling = lora_alpha / lora_r
        self.lora_dropout = lora_dropout

        orig_dtype = base_linear.weight.dtype
        device = base_linear.weight.device

        # Store base weights as buffer (frozen)
        self.register_buffer('base_weight', base_linear.weight.data.detach().clone())
        if base_linear.bias is not None:
            self.register_buffer('base_bias', base_linear.bias.data.detach().clone())
        else:
            self.register_buffer('base_bias', None)

        # Standard LoRA init: A ~ N(0, std), B = 0 (so the update starts at 0).
        # Drawn on CPU (same RNG stream as before), then placed next to the base weight.
        lora_A_init = torch.randn(n_members, lora_r, self.in_features, dtype=orig_dtype) * lora_init_std
        lora_B_init = torch.zeros(n_members, self.out_features, lora_r, dtype=orig_dtype)

        self.lora_A = nn.Parameter(lora_A_init.to(device).contiguous())
        self.lora_B = nn.Parameter(lora_B_init.to(device).contiguous())

        # Dropout on the LoRA-branch input only
        if lora_dropout > 0:
            self.dropout = nn.Dropout(p=lora_dropout)
        else:
            self.dropout = nn.Identity()

        # Parameter count info
        params_per_member = lora_r * (self.in_features + self.out_features)
        total_lora_params = params_per_member * n_members
        print(f"[EnsembleLoRALinear] {self.in_features}x{self.out_features}, r={lora_r}, "
              f"params/member={params_per_member:,}, total={total_lora_params:,}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x shape: [B*M, ..., in_features] (sample-major); returns [B*M, ..., out_features].
        """
        M = self.n_members
        orig_shape = x.shape
        in_features = orig_shape[-1]
        assert in_features == self.in_features, "Input dim mismatch"

        B_total = orig_shape[0]
        assert B_total % M == 0, f"Batch {B_total} not divisible by n_members {M}"
        B = B_total // M

        # Frozen base path: identical for every member, computed once on the
        # clean (un-dropped) input.
        base_out = F.linear(x, self.base_weight, self.base_bias)

        # Group rows by member: [B*M, ..., in] -> [B, M, T, in] -> [M, B*T, in]
        x_bmti = x.reshape(B, M, -1, in_features)
        T = x_bmti.shape[2]
        x_mNi = x_bmti.permute(1, 0, 2, 3).reshape(M, B * T, in_features)
        x_mNi = self.dropout(x_mNi)

        # Low-rank update without materializing B_m @ A_m (out x in per member):
        # [M, N, in] @ [M, in, r] -> [M, N, r];  [M, N, r] @ [M, r, out] -> [M, N, out]
        h_mNr = torch.bmm(x_mNi, self.lora_A.transpose(1, 2))
        delta_mNo = torch.bmm(h_mNr, self.lora_B.transpose(1, 2)) * self.scaling

        # Back to sample-major order and the caller's leading dims
        delta = delta_mNo.view(M, B, T, self.out_features).permute(1, 0, 2, 3)
        delta = delta.reshape(*orig_shape[:-1], self.out_features)

        return base_out + delta


# BatchEnsemble Layers
class EnsembleBatchEnsembleLinear(nn.Module):
    """
    BatchEnsemble layer for implicit ensemble.

    W_m = W ⊙ (r_m ⊗ s_m), computed without materializing W_m as
        y_m = ((x * s_m) @ W^T) * r_m + b_m

    The shared weight W is an ``nn.Parameter`` initialized from the pretrained
    layer, so it is trainable (as in the original BatchEnsemble). The rank-1
    factors r_m [out_features] and s_m [in_features] are initialized to 1 + noise,
    and the per-member bias to the pretrained bias + noise. The bias is trainable
    unless ``train_bias=False``.

    Input: [B*M, ..., in_features] in SAMPLE-MAJOR order (row b*M + m -> member m)
    Output: [B*M, ..., out_features]
    """

    def __init__(
        self,
        base_linear: nn.Linear,
        n_members: int = 4,
        init_std: float = 0.02,
        train_bias: bool = True,  # allow disabling bias training easily
    ):
        super().__init__()

        self.n_members = n_members
        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        self.train_bias = train_bias
        orig_dtype = base_linear.weight.dtype
        device = base_linear.weight.device
        self.weight = nn.Parameter(base_linear.weight.detach().clone())

        # Per-member rank-1 scaling factors (trainable). Noise is drawn on CPU
        # (same RNG stream as before) and moved next to the base weight, so the
        # layer also works when the model is already on GPU.
        r_init = torch.ones(n_members, self.out_features, dtype=orig_dtype) + \
                 torch.randn(n_members, self.out_features, dtype=orig_dtype) * init_std
        s_init = torch.ones(n_members, self.in_features, dtype=orig_dtype) + \
                 torch.randn(n_members, self.in_features, dtype=orig_dtype) * init_std

        self.r = nn.Parameter(r_init.to(device).contiguous())
        self.s = nn.Parameter(s_init.to(device).contiguous())

        # Per-member bias (optionally frozen)
        if base_linear.bias is not None:
            bias = base_linear.bias.detach().clone()
            bias_noise = torch.randn(n_members, self.out_features, dtype=orig_dtype) * init_std
            bias_init = bias.unsqueeze(0).repeat(n_members, 1) + bias_noise.to(bias.device)
            self.bias = nn.Parameter(bias_init.contiguous())
            if not train_bias:
                self.bias.requires_grad = False
        else:
            self.bias = None

        params_per_member = self.in_features + self.out_features
        total_be_params = params_per_member * n_members
        shared_params = self.in_features * self.out_features
        print(
            f"[EnsembleBatchEnsemble] {self.in_features}x{self.out_features}, "
            f"params/member={params_per_member:,}, total_member_params={total_be_params:,}, "
            f"shared_params(trainable)={shared_params:,}"
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply member m's rank-1-modulated weight to every row b*M + m of ``x``."""
        M = self.n_members
        orig_shape = x.shape
        in_features = orig_shape[-1]
        assert in_features == self.in_features, f"Input dim mismatch: {in_features} vs {self.in_features}"

        B_total = orig_shape[0]
        assert B_total % M == 0, f"Batch size {B_total} not divisible by n_members {M}"
        B = B_total // M

        # [B*M, ..., in] -> [B, M, T, in]; reshape (not view) accepts non-contiguous input
        x_bmti = x.reshape(B, M, -1, in_features)

        # input scaling, shared weight, output scaling
        y = torch.matmul(x_bmti * self.s.view(1, M, 1, in_features), self.weight.T)
        y = y * self.r.view(1, M, 1, self.out_features)

        if self.bias is not None:
            y = y + self.bias.view(1, M, 1, self.out_features)

        # back to the caller's [B*M, ...] layout
        return y.reshape(*orig_shape[:-1], self.out_features)


# Ensemble Classifier Heads
class EnsembleClassifierHead(nn.Module):
    """Per-member linear classifiers over a sample-major ensemble batch.

    Member m owns ``weight[m]`` ([out_features, in_features], Gaussian init with
    ``head_init_mean`` / ``head_init_std``) and ``bias[m]`` (zero init). Row
    ``b * M + m`` of the input is classified by member m, and the output keeps
    the same row order.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        n_members: int,
        head_init_mean: float = 0.0,
        head_init_std: float = 0.01,
    ):
        super().__init__()
        self.n_members = n_members
        self.in_features = in_features
        self.out_features = out_features

        gaussian = torch.randn(n_members, out_features, in_features) * head_init_std + head_init_mean
        self.weight = nn.Parameter(gaussian.contiguous())
        self.bias = nn.Parameter(torch.zeros(n_members, out_features).contiguous())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_rows, hidden = x.shape
        members = self.n_members
        assert n_rows % members == 0

        # [B*M, H] -> [M, B, H] so each member's rows are contiguous
        grouped = x.view(n_rows // members, members, hidden).transpose(0, 1).contiguous()
        logits = torch.einsum("mbh,mch->mbc", grouped, self.weight) + self.bias.unsqueeze(1)

        # [M, B, C] -> [B*M, C] in sample-major order
        return logits.transpose(0, 1).reshape(n_rows, self.out_features)


class EnsembleMLPClassifierHead(nn.Module):
    """Per-member 3-layer MLP classifiers over a sample-major ensemble batch.

    Each member runs Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout ->
    Linear with its own weights. ``hidden_dim`` defaults to ``in_features``;
    weights use a Gaussian init (``head_init_mean`` / ``head_init_std``) and
    biases start at zero. Row ``b * M + m`` is handled by member m.
    """
    def __init__(
        self,
        in_features: int,
        out_features: int,
        n_members: int,
        hidden_dim: int = None,
        head_init_mean: float = 0.0,
        head_init_std: float = 0.02,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.n_members = n_members
        self.in_features = in_features
        self.out_features = out_features

        hidden_dim = in_features if hidden_dim is None else hidden_dim
        self.hidden_dim = hidden_dim

        def _gaussian(*shape):
            return (torch.randn(*shape) * head_init_std + head_init_mean).contiguous()

        def _zeros(*shape):
            return torch.zeros(*shape).contiguous()

        self.weight1 = nn.Parameter(_gaussian(n_members, hidden_dim, in_features))
        self.bias1 = nn.Parameter(_zeros(n_members, hidden_dim))

        self.weight2 = nn.Parameter(_gaussian(n_members, hidden_dim, hidden_dim))
        self.bias2 = nn.Parameter(_zeros(n_members, hidden_dim))

        self.weight3 = nn.Parameter(_gaussian(n_members, out_features, hidden_dim))
        self.bias3 = nn.Parameter(_zeros(n_members, out_features))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_rows, feat = x.shape
        members = self.n_members
        assert n_rows % members == 0

        # [B*M, H] -> [M, B, H]
        h = x.view(n_rows // members, members, feat).transpose(0, 1).contiguous()

        layers = (
            (self.weight1, self.bias1),
            (self.weight2, self.bias2),
            (self.weight3, self.bias3),
        )
        last = len(layers) - 1
        for depth, (w, b) in enumerate(layers):
            h = torch.einsum("mbi,moi->mbo", h, w) + b.unsqueeze(1)
            if depth != last:  # no activation / dropout after the output layer
                h = self.dropout(F.gelu(h))

        # [M, B, C] -> [B*M, C]
        return h.transpose(0, 1).reshape(n_rows, self.out_features)


# MC Dropout Model
class ViTMCDropout(nn.Module):
    """ViT with MC Dropout for uncertainty estimation.

    A timm backbone (via ``create_backbone``) followed by feature dropout and a
    linear head. Dropout modules inside the transformer blocks are replaced by
    ``nn.Dropout(dropout_rate)`` so ``forward_mc`` can draw stochastic
    predictions by running the network in train mode.

    Args:
        backbone_name: timm model name.
        num_classes: Number of output classes.
        pretrained: Load pretrained backbone weights.
        dropout_rate: Probability used for every injected dropout layer.
        mode: ``"lp"`` trains only the head (linear probe); anything else
            fine-tunes all parameters.
        img_size: Optional input resolution forwarded to timm.
    """
    def __init__(
        self,
        backbone_name: str,
        num_classes: int,
        pretrained: bool = True,
        dropout_rate: float = 0.2,
        mode: str = "ft",
        img_size: int = None,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.mode = mode
        self.dropout_rate = dropout_rate

        print(f"Building MC Dropout model {backbone_name}, pretrained={pretrained}")

        self.backbone, in_features = create_backbone(backbone_name, pretrained, img_size)

        self.dropout = nn.Dropout(dropout_rate)
        self.head = nn.Linear(in_features, num_classes)

        self._inject_dropout(dropout_rate)
        self._setup_trainable_params()

    def _inject_dropout(self, dropout_rate):
        """Inject/replace dropout in attention and MLP blocks."""
        if not hasattr(self.backbone, 'blocks'):
            print("Warning: backbone has no 'blocks' attribute")
            return

        for block in self.backbone.blocks:
            if hasattr(block, 'attn'):
                if hasattr(block.attn, 'attn_drop'):
                    block.attn.attn_drop = nn.Dropout(dropout_rate)
                if hasattr(block.attn, 'proj_drop'):
                    block.attn.proj_drop = nn.Dropout(dropout_rate)

            if hasattr(block, 'mlp'):
                if hasattr(block.mlp, 'drop'):
                    block.mlp.drop = nn.Dropout(dropout_rate)
                if hasattr(block.mlp, 'drop1'):
                    block.mlp.drop1 = nn.Dropout(dropout_rate)
                if hasattr(block.mlp, 'drop2'):
                    block.mlp.drop2 = nn.Dropout(dropout_rate)

            # Replaces the block-level drop_path module with element-wise Dropout.
            # NOTE(review): newer timm Blocks appear to name these drop_path1 /
            # drop_path2, which this check does not match -- confirm against the
            # installed timm version.
            if hasattr(block, 'drop_path'):
                block.drop_path = nn.Dropout(dropout_rate)

    def _setup_trainable_params(self):
        """Linear probe ("lp"): train only the head; otherwise train everything."""
        if self.mode == "lp":
            for p in self.backbone.parameters():
                p.requires_grad = False
            for p in self.head.parameters():
                p.requires_grad = True
        else:
            for p in self.parameters():
                p.requires_grad = True

    def forward(self, x, labels=None):
        """Single pass; returns {"logits"} plus "loss" (cross-entropy) if labels given."""
        features = self.backbone(x)
        features = self.dropout(features)
        logits = self.head(features)

        output = {"logits": logits}

        if labels is not None:
            loss = F.cross_entropy(logits, labels)
            output["loss"] = loss

        return output

    def forward_mc(self, x, n_samples=10, labels=None):
        """Run ``n_samples`` stochastic passes with dropout active (no gradients).

        The module's previous train/eval mode is restored afterwards, so later
        ``forward`` calls are not silently stochastic.

        Returns:
            dict with "logits" (mean over samples, [B, C]), "logits_members"
            ([B, n_samples, C]) and, if labels are given, "loss" (cross-entropy
            of the mean logits).
        """
        was_training = self.training
        self.train()  # enable every dropout layer for the stochastic passes
        try:
            with torch.no_grad():
                all_logits = torch.stack(
                    [self.head(self.dropout(self.backbone(x))) for _ in range(n_samples)],
                    dim=0,
                )
        finally:
            self.train(was_training)

        logits_mean = all_logits.mean(dim=0)

        output = {
            "logits": logits_mean,
            "logits_members": all_logits.permute(1, 0, 2),  # [S, B, C] -> [B, S, C]
        }

        if labels is not None:
            loss = F.cross_entropy(logits_mean, labels)
            output["loss"] = loss

        return output


class ViTBatchEnsemble(nn.Module):
    """
    ViT with BatchEnsemble implicit ensemble.

    BatchEnsemble (Wen et al., 2020) uses rank-1 multiplicative perturbations:
    W_m = W ⊙ (r_m ⊗ s_m)

    In fine-tune mode (``mode == "ft"``) with more than one member, the backbone
    is wrapped via ``wrap_vit_with_batchensemble``. Each input is replicated M
    times (sample-major), and a per-member head (linear or 3-layer MLP)
    produces member logits. The ensemble prediction is their mean.
    With a single member, a plain ``nn.Linear`` head is used.
    """
    def __init__(
        self,
        backbone_name: str,
        num_classes: int,
        pretrained: bool = True,
        n_members: int = 4,
        be_scope: str = "attn_mlp",
        be_init_std: float = 0.02,
        head_init_mean: float = 0.0,
        head_init_std: float = 0.01,
        use_mlp_head: bool = False,
        mlp_hidden_dim: int = None,
        mlp_dropout: float = 0.1,
        mode: str = "ft",
        img_size: int = None,
    ):
        super().__init__()

        self.n_members = n_members
        self.num_classes = num_classes
        self.mode = mode

        print(f"Building BatchEnsemble model {backbone_name}, pretrained={pretrained}")
        self.backbone, feat_dim = create_backbone(backbone_name, pretrained, img_size)

        is_ensemble = n_members > 1

        if mode == "ft" and is_ensemble:
            print(f"Wrapping with BatchEnsemble (n_members={n_members})")
            wrap_vit_with_batchensemble(
                self.backbone,
                be_scope=be_scope,
                n_members=n_members,
                be_init_std=be_init_std,
            )

        if not is_ensemble:
            print("Using single linear head")
            self.head = nn.Linear(feat_dim, num_classes)
        elif use_mlp_head:
            print("Using 3-layer MLP ensemble head")
            self.head = EnsembleMLPClassifierHead(
                in_features=feat_dim,
                out_features=num_classes,
                n_members=n_members,
                hidden_dim=mlp_hidden_dim,
                head_init_mean=head_init_mean,
                head_init_std=head_init_std,
                dropout=mlp_dropout,
            )
        else:
            print("Using linear ensemble head")
            self.head = EnsembleClassifierHead(
                in_features=feat_dim,
                out_features=num_classes,
                n_members=n_members,
                head_init_mean=head_init_mean,
                head_init_std=head_init_std,
            )

        self._setup_trainable_params()

    def _setup_trainable_params(self):
        """Linear probe ("lp"): freeze the backbone; otherwise train backbone + head.

        NOTE(review): in fine-tune mode every backbone parameter is set to
        requires_grad=True, which overrides any layer that froze its own
        parameters (e.g. a BatchEnsemble bias built with train_bias=False) --
        confirm this is intended.
        """
        freeze_backbone = self.mode == "lp"
        for param in self.backbone.parameters():
            param.requires_grad = not freeze_backbone
        for param in self.head.parameters():
            param.requires_grad = True

    def forward(self, x, labels=None):
        """Return mean logits, per-member logits [B, M, C] and optional mean-over-members CE loss."""
        members = self.n_members
        batch = x.size(0)
        is_ensemble = members > 1

        # Sample-major replication: row b*M + m feeds member m.
        inputs = x.repeat_interleave(members, dim=0) if is_ensemble else x
        flat_logits = self.head(self.backbone(inputs))

        if is_ensemble:
            per_member = flat_logits.view(batch, members, self.num_classes)
            mean_logits = per_member.mean(dim=1)
        else:
            per_member = flat_logits.unsqueeze(1)
            mean_logits = flat_logits

        result = {"logits": mean_logits, "logits_members": per_member}
        if labels is None:
            return result

        # Each member is trained on every sample with the same target.
        if is_ensemble:
            loss_logits = per_member.view(batch * members, self.num_classes)
            loss_targets = labels.repeat_interleave(members)
        else:
            loss_logits, loss_targets = flat_logits, labels

        result["loss"] = F.cross_entropy(loss_logits, loss_targets)
        return result


# Standard ViT for Deep Ensemble
class ViTStandard(nn.Module):
    """Single ViT classifier used as one member of a deep ensemble.

    ``mode="lp"`` trains only the linear head; any other mode trains every
    parameter of the model.
    """
    def __init__(
        self,
        backbone_name: str,
        num_classes: int,
        pretrained: bool = True,
        mode: str = "ft",
        img_size: int = None,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.mode = mode

        backbone, feat_dim = create_backbone(backbone_name, pretrained, img_size)
        self.backbone = backbone
        self.head = nn.Linear(feat_dim, num_classes)

        self._setup_trainable_params()

    def _setup_trainable_params(self):
        # Full fine-tuning: everything is trainable.
        if self.mode != "lp":
            for param in self.parameters():
                param.requires_grad = True
            return

        # Linear probe: freeze the backbone, train only the head.
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in self.head.parameters():
            param.requires_grad = True

    def forward(self, x, labels=None):
        """Return ``{"logits": ...}``, plus ``"loss"`` (cross-entropy) when labels are given."""
        logits = self.head(self.backbone(x))
        if labels is None:
            return {"logits": logits}
        return {"logits": logits, "loss": F.cross_entropy(logits, labels)}


# ViT Wrapping Functions
def get_transformer_blocks(model):
    """Return the list of transformer blocks of ``model``, or ``[]`` if none are found.

    Checked in order: ``model.blocks``; ``model.visual.transformer.resblocks``;
    ``model.visual.trunk.blocks``; ``model.vision_model.encoder.layers``.
    A warning is printed when nothing matches.
    """
    if hasattr(model, 'blocks'):
        return list(model.blocks)

    if hasattr(model, 'visual'):
        visual = model.visual
        if hasattr(visual, 'transformer'):
            return list(visual.transformer.resblocks)
        trunk = getattr(visual, 'trunk', None)
        if trunk is not None and hasattr(trunk, 'blocks'):
            return list(trunk.blocks)

    vision_model = getattr(model, 'vision_model', None)
    if vision_model is not None and hasattr(vision_model, 'encoder'):
        return list(vision_model.encoder.layers)

    print("Warning: Could not find transformer blocks in model")
    return []


def wrap_vit_with_svf(model, topk=None, svf_scope="attn_mlp"):
    """Replace selected linear layers of every ViT block with single-model SVF layers.

    Args:
        model: Backbone whose blocks are located via ``get_transformer_blocks``.
        topk: Forwarded to every ``SVFLinear``.
        svf_scope: ``"attn"``, ``"mlp"`` or ``"attn_mlp"``; any other value
            wraps nothing.

    Returns:
        ``model`` itself, modified in place.
    """
    blocks = get_transformer_blocks(model)
    if not blocks:
        return model

    want_attn = svf_scope in ["attn", "attn_mlp"]
    want_mlp = svf_scope in ["mlp", "attn_mlp"]

    def swap(owner, attr):
        # Replace owner.<attr> in place with its SVF-wrapped version.
        setattr(owner, attr, SVFLinear(getattr(owner, attr), topk=topk))

    for block_idx, block in enumerate(blocks):
        tags = []

        attn = None
        if want_attn:
            if hasattr(block, 'attn'):
                attn = block.attn
            elif hasattr(block, 'self_attn'):
                attn = block.self_attn

        if attn is not None:
            if hasattr(attn, 'qkv'):
                swap(attn, 'qkv')
                tags.append("qkv")
            else:
                split_names = [n for n in ('q_proj', 'k_proj', 'v_proj', 'q', 'k', 'v') if hasattr(attn, n)]
                for proj_name in split_names:
                    swap(attn, proj_name)
                if split_names:
                    tags.append("Q/K/V")

            # Only the first matching output projection is wrapped.
            out_name = next((n for n in ('proj', 'out_proj', 'o_proj') if hasattr(attn, n)), None)
            if out_name is not None:
                swap(attn, out_name)
                tags.append(out_name)

        if want_mlp:
            if hasattr(block, 'mlp'):
                mlp = block.mlp
            elif hasattr(block, 'ffn'):
                mlp = block.ffn
            else:
                mlp = None
            if mlp is not None:
                for fc_name in ('fc1', 'fc2', 'c_fc', 'c_proj', 'wi', 'wo'):
                    if hasattr(mlp, fc_name):
                        swap(mlp, fc_name)
                        tags.append(fc_name)

        if tags:
            print(f"[INFO] Block {block_idx}:" + "".join(f" {tag}" for tag in tags))

    return model


def wrap_vit_with_svf_ensemble(
    model,
    topk=None,
    svf_scope="attn_mlp",
    n_members=4,
    svf_init_mean=0.0,
    svf_init_std=0.01,
):
    """Replace selected linear layers of every ViT block with ``EnsembleSVFLinear``.

    Args:
        model: Backbone whose blocks are located via ``get_transformer_blocks``.
        topk: Forwarded to each ``EnsembleSVFLinear``.
        svf_scope: ``"attn"``, ``"mlp"`` or ``"attn_mlp"``; any other value
            wraps nothing.
        n_members: Number of ensemble members per wrapped layer.
        svf_init_mean, svf_init_std: Forwarded as ``init_mean`` / ``init_std``.

    Returns:
        ``model`` itself, modified in place.
    """
    blocks = get_transformer_blocks(model)
    if not blocks:
        return model

    want_attn = svf_scope in ["attn", "attn_mlp"]
    want_mlp = svf_scope in ["mlp", "attn_mlp"]

    def swap(owner, attr):
        # Replace owner.<attr> in place with its ensemble-SVF version.
        wrapped = EnsembleSVFLinear(
            getattr(owner, attr),
            topk=topk,
            n_members=n_members,
            init_mean=svf_init_mean,
            init_std=svf_init_std,
        )
        setattr(owner, attr, wrapped)

    for block_idx, block in enumerate(blocks):
        tags = []

        attn = None
        if want_attn:
            if hasattr(block, 'attn'):
                attn = block.attn
            elif hasattr(block, 'self_attn'):
                attn = block.self_attn

        if attn is not None:
            if hasattr(attn, 'qkv'):
                swap(attn, 'qkv')
                tags.append("qkv")
            else:
                split_names = [n for n in ('q_proj', 'k_proj', 'v_proj', 'q', 'k', 'v') if hasattr(attn, n)]
                for proj_name in split_names:
                    swap(attn, proj_name)
                if split_names:
                    tags.append("Q/K/V")

            # Only the first matching output projection is wrapped.
            out_name = next((n for n in ('proj', 'out_proj', 'o_proj') if hasattr(attn, n)), None)
            if out_name is not None:
                swap(attn, out_name)
                tags.append(out_name)

        if want_mlp:
            if hasattr(block, 'mlp'):
                mlp = block.mlp
            elif hasattr(block, 'ffn'):
                mlp = block.ffn
            else:
                mlp = None
            if mlp is not None:
                for fc_name in ('fc1', 'fc2', 'c_fc', 'c_proj', 'wi', 'wo'):
                    if hasattr(mlp, fc_name):
                        swap(mlp, fc_name)
                        tags.append(fc_name)

        if tags:
            print(f"[INFO] Block {block_idx}:" + "".join(f" {tag}" for tag in tags))

    return model


def wrap_vit_with_lora_ensemble(
    model,
    lora_scope="attn_mlp",
    n_members=4,
    lora_r=16,
    lora_alpha=32.0,
    lora_dropout=0.0,
    lora_init_std=0.02,
):
    """Replace selected linear layers of every ViT block with ``EnsembleLoRALinear``.

    Args:
        model: Backbone whose blocks are located via ``get_transformer_blocks``.
        lora_scope: ``"attn"``, ``"mlp"`` or ``"attn_mlp"``; any other value
            wraps nothing.
        n_members, lora_r, lora_alpha, lora_dropout, lora_init_std: Forwarded
            unchanged to each ``EnsembleLoRALinear``.

    Returns:
        ``model`` itself, modified in place.
    """
    blocks = get_transformer_blocks(model)
    if not blocks:
        print("Warning: No blocks found for LoRA ensemble wrapping")
        return model

    want_attn = lora_scope in ["attn", "attn_mlp"]
    want_mlp = lora_scope in ["mlp", "attn_mlp"]

    def swap(owner, attr):
        # Replace owner.<attr> in place with its ensemble-LoRA version.
        wrapped = EnsembleLoRALinear(
            getattr(owner, attr),
            n_members=n_members,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            lora_init_std=lora_init_std,
        )
        setattr(owner, attr, wrapped)

    for block_idx, block in enumerate(blocks):
        tags = []

        attn = None
        if want_attn:
            if hasattr(block, 'attn'):
                attn = block.attn
            elif hasattr(block, 'self_attn'):
                attn = block.self_attn

        if attn is not None:
            if hasattr(attn, 'qkv'):
                swap(attn, 'qkv')
                tags.append("qkv")
            else:
                split_names = [n for n in ('q_proj', 'k_proj', 'v_proj', 'q', 'k', 'v') if hasattr(attn, n)]
                for proj_name in split_names:
                    swap(attn, proj_name)
                if split_names:
                    tags.append("Q/K/V")

            # Only the first matching output projection is wrapped.
            out_name = next((n for n in ('proj', 'out_proj', 'o_proj') if hasattr(attn, n)), None)
            if out_name is not None:
                swap(attn, out_name)
                tags.append(out_name)

        if want_mlp:
            if hasattr(block, 'mlp'):
                mlp = block.mlp
            elif hasattr(block, 'ffn'):
                mlp = block.ffn
            else:
                mlp = None
            if mlp is not None:
                for fc_name in ('fc1', 'fc2', 'c_fc', 'c_proj', 'wi', 'wo'):
                    if hasattr(mlp, fc_name):
                        swap(mlp, fc_name)
                        tags.append(fc_name)

        if tags:
            print(f"[INFO] Block {block_idx}:" + "".join(f" {tag}" for tag in tags))

    return model


def wrap_vit_with_batchensemble(
    model,
    be_scope="attn_mlp",
    n_members=4,
    be_init_std=0.02,
):
    """
    Wrap a ViT model with BatchEnsemble layers.

    BatchEnsemble uses rank-1 multiplicative perturbations:
    W_m = W ⊙ (r_m ⊗ s_m)

    Unlike LoRA (additive), BatchEnsemble modifies weights multiplicatively
    and trains the shared base weights W as well.

    Args:
        model: Backbone whose blocks are located via ``get_transformer_blocks``.
        be_scope: ``"attn"``, ``"mlp"`` or ``"attn_mlp"``; any other value
            wraps nothing.
        n_members: Number of ensemble members per wrapped layer.
        be_init_std: Forwarded as ``init_std`` to each ``EnsembleBatchEnsembleLinear``.

    Returns:
        ``model`` itself, modified in place.
    """
    blocks = get_transformer_blocks(model)
    if not blocks:
        print("Warning: No blocks found for BatchEnsemble wrapping")
        return model

    want_attn = be_scope in ["attn", "attn_mlp"]
    want_mlp = be_scope in ["mlp", "attn_mlp"]

    def swap(owner, attr):
        # Replace owner.<attr> in place with its BatchEnsemble version.
        wrapped = EnsembleBatchEnsembleLinear(
            getattr(owner, attr),
            n_members=n_members,
            init_std=be_init_std,
        )
        setattr(owner, attr, wrapped)

    for block_idx, block in enumerate(blocks):
        tags = []

        attn = None
        if want_attn:
            if hasattr(block, 'attn'):
                attn = block.attn
            elif hasattr(block, 'self_attn'):
                attn = block.self_attn

        if attn is not None:
            if hasattr(attn, 'qkv'):
                swap(attn, 'qkv')
                tags.append("qkv")
            else:
                split_names = [n for n in ('q_proj', 'k_proj', 'v_proj', 'q', 'k', 'v') if hasattr(attn, n)]
                for proj_name in split_names:
                    swap(attn, proj_name)
                if split_names:
                    tags.append("Q/K/V")

            # Only the first matching output projection is wrapped.
            out_name = next((n for n in ('proj', 'out_proj', 'o_proj') if hasattr(attn, n)), None)
            if out_name is not None:
                swap(attn, out_name)
                tags.append(out_name)

        if want_mlp:
            if hasattr(block, 'mlp'):
                mlp = block.mlp
            elif hasattr(block, 'ffn'):
                mlp = block.ffn
            else:
                mlp = None
            if mlp is not None:
                for fc_name in ('fc1', 'fc2', 'c_fc', 'c_proj', 'wi', 'wo'):
                    if hasattr(mlp, fc_name):
                        swap(mlp, fc_name)
                        tags.append(fc_name)

        if tags:
            print(f"[INFO] Block {block_idx}:" + "".join(f" {tag}" for tag in tags))

    return model


# Full Ensemble Model Wrappers
class ViTSVFImplicitEnsemble(nn.Module):
    """ViT whose linear layers are SVF-wrapped, topped by a (possibly ensemble) head.

    In ``mode="ft"`` the backbone is wrapped with ensemble SVF layers
    (``n_members > 1``) or single SVF layers (``n_members == 1``); only those
    SVF modules and the head are trained. In ``mode="lp"`` the backbone is
    left unwrapped and frozen, and only the head trains.
    """
    def __init__(
        self,
        backbone_name: str,
        num_classes: int,
        pretrained: bool = True,
        n_members: int = 4,
        topk: int = None,
        svf_scope: str = "attn_mlp",
        svf_init_mean: float = 0.0,
        svf_init_std: float = 0.01,
        head_init_mean: float = 0.0,
        head_init_std: float = 0.01,
        use_mlp_head: bool = False,
        mlp_hidden_dim: int = None,
        mlp_dropout: float = 0.1,
        mode: str = "ft",
        img_size: int = None,
    ):
        super().__init__()

        self.n_members = n_members
        self.num_classes = num_classes
        self.mode = mode

        print(f"Building SVF ensemble model {backbone_name}, pretrained={pretrained}")
        self.backbone, in_features = create_backbone(backbone_name, pretrained, img_size)

        if mode == "ft":
            if n_members > 1:
                print(f"Wrapping with SVF ensemble (n_members={n_members})")
                wrap_vit_with_svf_ensemble(
                    self.backbone,
                    topk=topk,
                    svf_scope=svf_scope,
                    n_members=n_members,
                    svf_init_mean=svf_init_mean,
                    svf_init_std=svf_init_std,
                )
            elif n_members == 1:
                print("Wrapping with single SVF")
                wrap_vit_with_svf(self.backbone, topk=topk, svf_scope=svf_scope)

        # Head: ensemble MLP, ensemble linear, or a plain linear layer.
        if n_members > 1 and use_mlp_head:
            print("Using 3-layer MLP ensemble head")
            self.head = EnsembleMLPClassifierHead(
                in_features=in_features,
                out_features=num_classes,
                n_members=n_members,
                hidden_dim=mlp_hidden_dim,
                head_init_mean=head_init_mean,
                head_init_std=head_init_std,
                dropout=mlp_dropout,
            )
        elif n_members > 1:
            print("Using linear ensemble head")
            self.head = EnsembleClassifierHead(
                in_features=in_features,
                out_features=num_classes,
                n_members=n_members,
                head_init_mean=head_init_mean,
                head_init_std=head_init_std,
            )
        else:
            print("Using single linear head")
            self.head = nn.Linear(in_features, num_classes)

        self._setup_trainable_params()

    def _setup_trainable_params(self):
        # The backbone starts frozen in every mode.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Fine-tuning additionally unfreezes the SVF wrappers inside it.
        if self.mode != "lp":
            svf_layers = (
                mod for mod in self.backbone.modules()
                if isinstance(mod, (SVFLinear, EnsembleSVFLinear))
            )
            for layer in svf_layers:
                for param in layer.parameters():
                    param.requires_grad = True

        for param in self.head.parameters():
            param.requires_grad = True

    def forward(self, x, labels=None):
        """Return member-averaged logits, per-member logits and (optionally) the loss.

        With ``n_members > 1`` inputs are repeated per member via
        ``repeat_interleave``; the loss is cross-entropy averaged over every
        (sample, member) pair.
        """
        n_mem = self.n_members
        batch = x.size(0)
        multi = n_mem > 1

        inputs = x.repeat_interleave(n_mem, dim=0) if multi else x
        flat_logits = self.head(self.backbone(inputs))

        if multi:
            per_member = flat_logits.view(batch, n_mem, self.num_classes)
            output = {"logits": per_member.mean(dim=1), "logits_members": per_member}
        else:
            output = {"logits": flat_logits, "logits_members": flat_logits.unsqueeze(1)}

        if labels is None:
            return output

        if multi:
            ce_logits = per_member.view(batch * n_mem, self.num_classes)
            ce_targets = labels.unsqueeze(1).repeat(1, n_mem).view(-1)
        else:
            ce_logits, ce_targets = flat_logits, labels

        output["loss"] = F.cross_entropy(ce_logits, ce_targets)
        return output


class ViTLoRAImplicitEnsemble(nn.Module):
    """ViT with a LoRA implicit ensemble.

    In ``mode="ft"`` with ``n_members > 1`` the backbone's linear layers are
    wrapped with ``EnsembleLoRALinear``; the backbone is frozen and only the
    LoRA factors (``lora_A`` / ``lora_B``) plus the head are trained. In
    ``mode="lp"`` only the head trains.

    Note: this class attaches adapters only when ``n_members > 1``. With
    ``n_members == 1`` and ``mode="ft"`` the backbone stays frozen and
    unadapted, so only the head trains; a warning is printed in that case.
    """
    def __init__(
        self,
        backbone_name: str,
        num_classes: int,
        pretrained: bool = True,
        n_members: int = 4,
        lora_r: int = 16,
        lora_alpha: float = 32.0,
        lora_dropout: float = 0.0,
        lora_init_std: float = 0.02,
        lora_scope: str = "attn_mlp",
        head_init_mean: float = 0.0,
        head_init_std: float = 0.01,
        use_mlp_head: bool = False,
        mlp_hidden_dim: int = None,
        mlp_dropout: float = 0.1,
        mode: str = "ft",
        img_size: int = None,
    ):
        super().__init__()

        self.n_members = n_members
        self.num_classes = num_classes
        self.mode = mode
        self.lora_r = lora_r

        print(f"Building LoRA ensemble model {backbone_name}, pretrained={pretrained}")
        self.backbone, in_features = create_backbone(backbone_name, pretrained, img_size)

        if mode == "ft" and n_members > 1:
            print(f"Wrapping with LoRA ensemble (n_members={n_members}, r={lora_r})")
            wrap_vit_with_lora_ensemble(
                self.backbone,
                lora_scope=lora_scope,
                n_members=n_members,
                lora_r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                lora_init_std=lora_init_std,
            )
        elif mode == "ft":
            # Without adapters, _setup_trainable_params freezes the whole
            # backbone, so "ft" silently behaves like "lp". Say so explicitly.
            print(
                f"Warning: LoRA adapters are only attached when n_members > 1 "
                f"(got n_members={n_members}); only the classifier head will be trained."
            )

        if n_members > 1:
            if use_mlp_head:
                print("Using 3-layer MLP ensemble head")
                self.head = EnsembleMLPClassifierHead(
                    in_features=in_features,
                    out_features=num_classes,
                    n_members=n_members,
                    hidden_dim=mlp_hidden_dim,
                    head_init_mean=head_init_mean,
                    head_init_std=head_init_std,
                    dropout=mlp_dropout,
                )
            else:
                print("Using linear ensemble head")
                self.head = EnsembleClassifierHead(
                    in_features=in_features,
                    out_features=num_classes,
                    n_members=n_members,
                    head_init_mean=head_init_mean,
                    head_init_std=head_init_std,
                )
        else:
            print("Using single linear head")
            self.head = nn.Linear(in_features, num_classes)

        self._setup_trainable_params()

    def _setup_trainable_params(self):
        """Freeze/unfreeze parameters according to ``self.mode``."""
        if self.mode == "lp":
            for p in self.backbone.parameters():
                p.requires_grad = False
            for p in self.head.parameters():
                p.requires_grad = True
            return

        # Freeze all backbone params first
        for p in self.backbone.parameters():
            p.requires_grad = False

        # Unfreeze LoRA parameters
        for module in self.backbone.modules():
            if isinstance(module, EnsembleLoRALinear):
                module.lora_A.requires_grad = True
                module.lora_B.requires_grad = True

        # Unfreeze head
        for p in self.head.parameters():
            p.requires_grad = True

    def forward(self, x, labels=None):
        """Return member-averaged logits, per-member logits and (optionally) the loss.

        With ``n_members > 1`` each input is repeated per member via
        ``repeat_interleave``; the loss is cross-entropy averaged over every
        (sample, member) pair.
        """
        M = self.n_members
        B = x.size(0)

        x_rep = x.repeat_interleave(M, dim=0) if M > 1 else x

        features = self.backbone(x_rep)
        logits_all = self.head(features)

        if M > 1:
            logits_members = logits_all.view(B, M, self.num_classes)
            logits_mean = logits_members.mean(dim=1)
        else:
            logits_members = logits_all.unsqueeze(1)
            logits_mean = logits_all

        output = {"logits": logits_mean, "logits_members": logits_members}

        if labels is not None:
            if M > 1:
                logits_flat = logits_members.view(B * M, self.num_classes)
                labels_rep = labels.unsqueeze(1).repeat(1, M).view(-1)
            else:
                logits_flat = logits_all
                labels_rep = labels

            output["loss"] = F.cross_entropy(logits_flat, labels_rep)

        return output


# Parameter Logging
def log_trainable_parameters(model, verbose=True):
    """Print a summary of total / trainable / frozen parameter counts.

    Trainable tensors owned directly (``recurse=False``) by an SVF, LoRA,
    BatchEnsemble or ensemble-head module are grouped under that module's type
    name. Any other trainable tensor whose name contains ``"head"`` is reported
    as ``"ClassifierHead"``, and everything else as ``"Other"``. When
    ``verbose`` is set and there are at most 50 trainable tensors, each one is
    listed with its shape.

    Args:
        model: Any ``nn.Module``.
        verbose: Whether to print the per-tensor listing.
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    trainable_count = sum(p.numel() for _, p in trainable)
    # A parameter-free model would otherwise raise ZeroDivisionError here.
    trainable_ratio = 100 * trainable_count / total_params if total_params else 0.0

    print("\n" + "=" * 80)
    print("📊 MODEL PARAMETER SUMMARY")
    print("=" * 80)
    print(f"Total parameters:      {total_params:,}")
    print(f"Trainable parameters:  {trainable_count:,}")
    print(f"Frozen parameters:     {total_params - trainable_count:,}")
    print(f"Trainable ratio:       {trainable_ratio:.2f}%")
    print("=" * 80)

    # First match wins, so more specific classes must precede their bases.
    group_types = (
        (EnsembleSVFLinear, "EnsembleSVFLinear"),
        (SVFLinear, "SVFLinear"),
        (EnsembleLoRALinear, "EnsembleLoRALinear"),
        (EnsembleBatchEnsembleLinear, "EnsembleBatchEnsembleLinear"),
        (EnsembleMLPClassifierHead, "EnsembleMLPClassifierHead"),
        (EnsembleClassifierHead, "EnsembleClassifierHead"),
    )

    param_to_group = {}
    for module_name, module in model.named_modules():
        group = next((label for cls, label in group_types if isinstance(module, cls)), None)
        if group is None:
            continue
        for pname, _ in module.named_parameters(recurse=False):
            full_name = f"{module_name}.{pname}" if module_name else pname
            param_to_group[full_name] = group

    groups = defaultdict(list)
    for name, p in trainable:
        if name in param_to_group:
            g = param_to_group[name]
        elif "head" in name:
            g = "ClassifierHead"
        else:
            g = "Other"
        groups[g].append((name, p))

    print("\n📦 TRAINABLE PARAMETER GROUPS")
    print("-" * 80)
    for gname, params in groups.items():
        n = sum(p.numel() for _, p in params)
        print(f"{gname:30s}: {n:,} parameters in {len(params)} tensors")
    print("-" * 80)

    if verbose and len(trainable) <= 50:
        print("\n🔍 DETAILED TRAINABLE PARAMETERS")
        print("-" * 80)
        for name, p in trainable:
            shape_str = str(list(p.shape))
            print(f"{name:60s}  shape={shape_str:20s}  count={p.numel():,}")
        print("-" * 80)


# Data
def maybe_subsample(dataset, fraction, seed):
    """Return a seeded random ``fraction`` of ``dataset`` as a ``Subset``.

    Args:
        dataset: Any sized, indexable dataset.
        fraction: Share of samples to keep. Values ``>= 1`` return ``dataset``
            unchanged.
        seed: Seed for the ``random.Random`` shuffle, so the subset is reproducible.

    Returns:
        ``dataset`` itself, or a ``Subset`` with at least one sample (when the
        dataset is non-empty).

    Raises:
        ValueError: if ``fraction`` is not positive.
    """
    if fraction >= 1.0:
        return dataset
    # A negative k would make indices[:k] drop from the *end*, silently
    # keeping most of the data; reject such fractions explicitly.
    if fraction <= 0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")

    n = len(dataset)
    # Keep at least one sample so a tiny fraction can't yield an empty loader.
    k = min(n, max(1, int(n * fraction)))
    print(f"Using only {k}/{n} training samples ({fraction:.2f} fraction).")

    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)

    return Subset(dataset, indices[:k])


def split_train_val(dataset, val_fraction, seed):
    """Split ``dataset`` into disjoint (train, val) ``Subset``s.

    The indices are shuffled with ``random.Random(seed)``. The last
    ``int(len(dataset) * val_fraction)`` shuffled indices form the validation
    subset; the rest form the training subset.
    """
    total = len(dataset)
    cut = total - int(total * val_fraction)

    order = list(range(total))
    random.Random(seed).shuffle(order)

    return Subset(dataset, order[:cut]), Subset(dataset, order[cut:])


def create_dataloaders(args):
    """Create train / val / test data loaders for ``args.dataset``.

    flowers102, aircraft and dtd use their official validation split. For
    cars, food101, cifar100 and pets, ``split_train_val`` holds out an
    ``args.val_fraction`` slice of the training split. Those held-out indices
    are then read from a second copy of the training split built with the eval
    transform, so validation images are not augmented. The training set is
    finally sub-sampled to ``args.label_fraction``.

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    # Check up front: indexing AVAILABLE_DATASETS with an unknown name would
    # otherwise raise a bare KeyError and never reach the ValueError below.
    if args.dataset not in AVAILABLE_DATASETS:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    image_size = get_input_size_for_backbone(args.backbone, args.image_size)
    print(f"Using image size: {image_size}x{image_size}")

    train_tf, eval_tf = build_transforms(image_size)

    dataset_config = AVAILABLE_DATASETS[args.dataset]
    num_classes = dataset_config["num_classes"]

    print(f"Loading dataset: {dataset_config['description']} ({num_classes} classes)")

    def split_with_eval_view(full_train_ds, make_eval_copy):
        # Hold out val indices from the augmented train split, then serve them
        # through an eval-transform copy of that same split.
        train_subset, val_tmp = split_train_val(full_train_ds, args.val_fraction, args.seed)
        return train_subset, Subset(make_eval_copy(), val_tmp.indices)

    root = args.data_root

    if args.dataset == "flowers102":
        train_ds = Flowers102(root=root, split="train", download=True, transform=train_tf)
        val_ds = Flowers102(root=root, split="val", download=True, transform=eval_tf)
        test_ds = Flowers102(root=root, split="test", download=True, transform=eval_tf)

    elif args.dataset == "cars":
        full_train_ds = StanfordCarsLocal(root=root, split="train", transform=train_tf)
        test_ds = StanfordCarsLocal(root=root, split="test", transform=eval_tf)
        train_ds, val_ds = split_with_eval_view(
            full_train_ds,
            lambda: StanfordCarsLocal(root=root, split="train", transform=eval_tf),
        )

    elif args.dataset == "aircraft":
        train_ds = FGVCAircraft(root=root, split="train", annotation_level="variant", download=True, transform=train_tf)
        val_ds = FGVCAircraft(root=root, split="val", annotation_level="variant", download=True, transform=eval_tf)
        test_ds = FGVCAircraft(root=root, split="test", annotation_level="variant", download=True, transform=eval_tf)

    elif args.dataset == "food101":
        full_train_ds = Food101(root=root, split="train", download=True, transform=train_tf)
        test_ds = Food101(root=root, split="test", download=True, transform=eval_tf)
        train_ds, val_ds = split_with_eval_view(
            full_train_ds,
            lambda: Food101(root=root, split="train", download=False, transform=eval_tf),
        )

    elif args.dataset == "cifar100":
        full_train_ds = CIFAR100(root=root, train=True, download=True, transform=train_tf)
        test_ds = CIFAR100(root=root, train=False, download=True, transform=eval_tf)
        train_ds, val_ds = split_with_eval_view(
            full_train_ds,
            lambda: CIFAR100(root=root, train=True, download=False, transform=eval_tf),
        )

    elif args.dataset == "dtd":
        train_ds = DTD(root=root, split="train", download=True, transform=train_tf)
        val_ds = DTD(root=root, split="val", download=True, transform=eval_tf)
        test_ds = DTD(root=root, split="test", download=True, transform=eval_tf)

    elif args.dataset == "pets":
        full_train_ds = OxfordIIITPet(root=root, split="trainval", download=True, transform=train_tf)
        test_ds = OxfordIIITPet(root=root, split="test", download=True, transform=eval_tf)
        train_ds, val_ds = split_with_eval_view(
            full_train_ds,
            lambda: OxfordIIITPet(root=root, split="trainval", download=False, transform=eval_tf),
        )

    else:
        # Listed in AVAILABLE_DATASETS but not handled here.
        raise ValueError(f"Unknown dataset: {args.dataset}")

    train_ds = maybe_subsample(train_ds, args.label_fraction, args.seed)

    def make_loader(ds, shuffle):
        return DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            pin_memory=True,
        )

    train_loader = make_loader(train_ds, shuffle=True)
    val_loader = make_loader(val_ds, shuffle=False)
    test_loader = make_loader(test_ds, shuffle=False)

    print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

    return train_loader, val_loader, test_loader, num_classes


def create_external_ood_loaders(args):
    """
    External OOD datasets (always):
      - CIFAR-10 test
      - SVHN test

    Uses the SAME eval transform (resize + normalize) as the current backbone/image_size.
    Batch size and worker count fall back to ``args.batch_size`` /
    ``args.num_workers`` when the ``ood_*`` overrides are ``None``.

    Returns:
        dict mapping ``"cifar10"`` and ``"svhn"`` to non-shuffled DataLoaders.
    """
    _, eval_tf = build_transforms(get_input_size_for_backbone(args.backbone, args.image_size))

    batch_size = args.batch_size if args.ood_batch_size is None else args.ood_batch_size
    workers = args.num_workers if args.ood_num_workers is None else args.ood_num_workers

    ood_datasets = {
        "cifar10": CIFAR10(root=args.data_root, train=False, download=True, transform=eval_tf),
        "svhn": SVHN(root=args.data_root, split="test", download=True, transform=eval_tf),
    }

    return {
        name: DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
        for name, ds in ood_datasets.items()
    }

# Training + Evaluation
def train_one_epoch(model, loader, optimizer, device, scaler=None, grad_clip=1.0):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module whose ``forward(imgs, labels=...)`` returns a dict with
            ``"loss"`` and ``"logits"``.
        loader: Iterable of ``(imgs, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Target device for each batch.
        scaler: Optional gradient scaler; when given, the forward pass runs
            under CUDA autocast and the loss is scaled.
        grad_clip: Max global grad-norm; ``<= 0`` disables clipping.

    Returns:
        ``(mean_loss, accuracy)`` over all samples. Accuracy is computed from
        ``output["logits"]`` produced before the optimizer step.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total, correct, loss_sum = 0, 0, 0.0

    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        if scaler is not None:
            # torch.autocast("cuda") replaces the deprecated torch.cuda.amp.autocast().
            with torch.autocast(device_type="cuda"):
                output = model(imgs, labels=labels)
                loss = output["loss"]
            scaler.scale(loss).backward()
            if grad_clip > 0:
                # Clip the true (unscaled) gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(imgs, labels=labels)
            loss = output["loss"]
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        loss_sum += loss.item() * imgs.size(0)
        pred = output["logits"].max(1).indices
        correct += (pred == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    return loss_sum / total, correct / total


@torch.no_grad()
def evaluate(model, loader, device, return_calibration=False, return_logits=False):
    """Evaluate ``model`` on ``loader`` without gradients.

    ``model(imgs, labels=labels)`` must return a dict with ``"loss"`` and
    ``"logits"``.

    Returns:
        * ``(avg_loss, accuracy, logits, labels)`` if ``return_logits`` is set.
          This takes precedence: calibration metrics are NOT computed even
          when ``return_calibration`` is also True.
        * ``(avg_loss, accuracy, nll, brier, ece)`` if only
          ``return_calibration`` is set. These metrics use the softmax of the
          returned logits.
        * ``(avg_loss, accuracy)`` otherwise.
    """
    model.eval()
    keep_outputs = return_calibration or return_logits

    n_seen = 0
    n_correct = 0
    weighted_loss = 0
    logit_chunks, label_chunks, prob_chunks = [], [], []

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        out = model(batch_x, labels=batch_y)
        batch_logits = out["logits"]

        weighted_loss += out["loss"].item() * batch_x.size(0)
        n_correct += (batch_logits.max(1).indices == batch_y).sum().item()
        n_seen += batch_y.size(0)

        if keep_outputs:
            logit_chunks.append(batch_logits.cpu())
            label_chunks.append(batch_y.cpu())
            if return_calibration:
                prob_chunks.append(F.softmax(batch_logits, dim=1).cpu())

    avg_loss = weighted_loss / n_seen
    accuracy = n_correct / n_seen

    if return_logits:
        return avg_loss, accuracy, torch.cat(logit_chunks, dim=0), torch.cat(label_chunks, dim=0)

    if not return_calibration:
        return avg_loss, accuracy

    probs = torch.cat(prob_chunks, dim=0)
    targets = torch.cat(label_chunks, dim=0)

    # Epsilon keeps log() finite for probabilities that underflow to zero.
    nll = F.nll_loss(torch.log(probs + 1e-10), targets, reduction="mean").item()

    one_hot = torch.eye(probs.shape[1])[targets]
    brier = torch.mean(torch.sum((probs - one_hot) ** 2, dim=1)).item()

    ece = compute_ece(probs, targets)

    return avg_loss, accuracy, nll, brier, ece


def compute_ece(probs, labels, n_bins=15):
    """Expected Calibration Error with equal-width confidence bins.

    Each sample is assigned to the bin ``(lower, upper]`` that contains its
    top-class confidence. ECE is the sample-weighted mean, over non-empty bins,
    of ``|mean confidence - accuracy|``.

    Args:
        probs: [N, C] tensor of class probabilities.
        labels: [N] tensor of integer class labels on the same device as ``probs``.
        n_bins: number of equal-width bins spanning [0, 1].

    Returns:
        ECE as a Python float. An empty input (N == 0) yields 0.0. Previously the
        accumulator stayed a plain float there and ``.item()`` raised
        AttributeError.
    """
    if probs.shape[0] == 0:
        return 0.0

    confidences, predictions = torch.max(probs, dim=1)
    accuracies = predictions.eq(labels).float()
    n_total = confidences.numel()

    # Convert the boundaries to Python floats once. This gives the same
    # thresholds as calling `.item()` on each boundary tensor inside the loop.
    boundaries = torch.linspace(0, 1, n_bins + 1).tolist()

    ece = 0.0
    for lower, upper in zip(boundaries[:-1], boundaries[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        n_in_bin = int(in_bin.sum().item())
        if n_in_bin == 0:
            continue
        gap = confidences[in_bin].mean() - accuracies[in_bin].mean()
        ece += abs(gap.item()) * (n_in_bin / n_total)

    return float(ece)


def calibrate_temperature(logits, labels, lr=0.01, max_iter=50):
    temperature = nn.Parameter(torch.ones(1, device=logits.device))
    optimizer = torch.optim.LBFGS([temperature], lr=lr, max_iter=max_iter)

    def eval_loss():
        optimizer.zero_grad()
        scaled_logits = logits / temperature
        loss = F.cross_entropy(scaled_logits, labels)
        loss.backward()
        return loss

    optimizer.step(eval_loss)

    return temperature.item()


def compute_calibrated_metrics(logits, labels, temperature):
    """Score predictions after dividing ``logits`` by ``temperature``.

    Args:
        logits: [N, C] logits.
        labels: [N] integer labels on the same device as ``logits``.
        temperature: scalar divisor applied to the logits.

    Returns:
        dict with ``accuracy``, ``nll``, ``brier``, ``ece`` (all Python floats)
        and the ``temperature`` that was applied.
    """
    probs = F.softmax(logits / temperature, dim=1)
    predicted = probs.max(1).indices

    eye = torch.eye(probs.shape[1], device=probs.device)
    squared_err = (probs - eye[labels]).pow(2)

    return {
        "accuracy": (predicted == labels).float().mean().item(),
        # Epsilon guards log(0) for probabilities that underflow to zero.
        "nll": F.nll_loss(torch.log(probs + 1e-10), labels, reduction="mean").item(),
        "brier": squared_err.sum(dim=1).mean().item(),
        "ece": compute_ece(probs.cpu(), labels.cpu()),
        "temperature": temperature,
    }


def build_warmup_cosine_scheduler(optimizer, epochs, warmup_epochs):
    """Create a per-epoch ``LambdaLR``: linear warmup, then cosine decay.

    During the first ``warmup_epochs`` epochs the LR multiplier ramps linearly
    as ``(epoch + 1) / warmup_epochs``. Afterwards it follows a half cosine from
    1.0 toward 0.0 over the remaining ``epochs - warmup_epochs`` epochs. A run of
    at most one epoch keeps a constant multiplier of 1.0.
    """

    def _multiplier(epoch):
        # Nothing to anneal over: keep the base learning rate.
        if epochs <= 1:
            return 1.0

        in_warmup = warmup_epochs > 0 and epoch < warmup_epochs
        if in_warmup:
            return float(epoch + 1) / float(warmup_epochs)

        decay_span = float(max(1, epochs - warmup_epochs))
        progress = float(epoch - warmup_epochs) / decay_span
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_multiplier)


@torch.no_grad()
def get_logits_for_scoring(model_or_models, imgs, args):
    """Produce [B, C] logits for scoring, dispatching on ``args.method``.

    * ``"mc_dropout"``: ``model.forward_mc(imgs, n_samples=args.mc_samples)``'s
      ``"logits"`` (mean logits over stochastic passes).
    * ``"deep_ensemble"``: ``model_or_models`` is an iterable of members; their
      logits are averaged element-wise.
    * anything else (single / svf / lora / ...): ``model(imgs)["logits"]``.
    """
    method = args.method

    if method == "mc_dropout":
        return model_or_models.forward_mc(imgs, n_samples=args.mc_samples)["logits"]

    if method == "deep_ensemble":
        member_logits = [member(imgs)["logits"] for member in model_or_models]
        return torch.stack(member_logits, dim=0).mean(dim=0)

    return model_or_models(imgs)["logits"]


@torch.no_grad()
def evaluate_ood_msp(model_or_models, id_loader, ood_loader, args):
    """OOD detection with the maximum-softmax-probability (MSP) score.

    Convention ("Option A"):
      - labels: ID = 1 (positive), OOD = 0 (negative)
      - score:  MSP = max softmax probability (higher => more ID-like)

    FPR@95TPR (ID-positive):
      - tau is chosen so that 95% of ID samples have score >= tau
        (the 5th percentile of ID scores, via ``np.quantile``)
      - reported value is the fraction of OOD samples with score >= tau

    NOTE(review): AUPRC here is the trapezoidal area under the PR curve
    (``auc(rec, prec)``), which can differ from average precision.

    Returns:
        dict with ``auroc``, ``auprc`` and ``fpr95`` as Python floats.
    """
    device = args.device
    score_chunks = []
    label_chunks = []

    # ID batches first (label 1), then OOD batches (label 0).
    for source_loader, target in ((id_loader, 1), (ood_loader, 0)):
        for imgs, _ in source_loader:
            imgs = imgs.to(device)
            logits = get_logits_for_scoring(model_or_models, imgs, args)
            msp = F.softmax(logits, dim=1).max(dim=1).values  # ID-likeness
            score_chunks.append(msp.detach().cpu().numpy())
            label_chunks.append(np.full((imgs.size(0),), target, dtype=np.int64))

    scores_np = np.concatenate(score_chunks)
    labels_np = np.concatenate(label_chunks)

    # Threshold-free metrics with ID as the positive class.
    auroc = roc_auc_score(labels_np, scores_np)
    prec, rec, _ = precision_recall_curve(labels_np, scores_np)
    auprc = auc(rec, prec)

    id_scores = scores_np[labels_np == 1]
    ood_scores = scores_np[labels_np == 0]

    target_tpr = 0.95
    # P(score >= tau | ID) = 0.95  =>  tau is the (1 - 0.95) quantile of ID scores.
    if id_scores.size > 0:
        tau = float(np.quantile(id_scores, 1.0 - target_tpr))
    else:
        tau = float("nan")

    # Fraction of OOD samples wrongly accepted as ID at that threshold.
    if ood_scores.size > 0:
        fpr95 = float(np.mean(ood_scores >= tau))
    else:
        fpr95 = float("nan")

    return {
        "auroc": float(auroc),
        "auprc": float(auprc),
        "fpr95": fpr95,  # OOD accepted-as-ID rate at 95% ID recall
    }


@torch.no_grad()
def compute_metrics_from_logits(logits_cpu: torch.Tensor, labels_cpu: torch.Tensor):
    """Uncalibrated metrics computed directly from logits.

    Args:
        logits_cpu: [N, C] logits on CPU.
        labels_cpu: [N] integer labels on CPU (the one-hot matrix below is built
            on CPU, so both inputs must live there).

    Returns:
        dict with ``loss`` (cross-entropy on the logits), ``accuracy``, ``nll``
        (from clamped-log probabilities), ``brier`` and ``ece``.
    """
    probs = F.softmax(logits_cpu, dim=1)
    targets_one_hot = torch.eye(probs.shape[1])[labels_cpu]

    ce_value = F.cross_entropy(logits_cpu, labels_cpu, reduction="mean").item()
    acc_value = probs.argmax(dim=1).eq(labels_cpu).float().mean().item()
    # Epsilon guards log(0) for probabilities that underflow to zero.
    nll_value = F.nll_loss(torch.log(probs + 1e-10), labels_cpu, reduction="mean").item()
    brier_value = (probs - targets_one_hot).pow(2).sum(dim=1).mean().item()
    ece_value = compute_ece(probs, labels_cpu)

    return {
        "loss": ce_value,
        "accuracy": acc_value,
        "nll": nll_value,
        "brier": brier_value,
        "ece": ece_value,
    }

# Training Functions
def train_single_model(model, train_loader, val_loader, optimizer, scheduler, args, scaler, model_idx=None):
    """Train ``model`` for ``args.epochs`` epochs, validating after each epoch.

    Args:
        model: module called as ``model(imgs, labels=labels)`` returning a dict
            with ``"loss"`` and ``"logits"``.
        train_loader / val_loader: iterables of ``(imgs, labels)`` batches.
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per epoch.
        args: namespace providing ``epochs``, ``grad_clip`` and ``device``.
        scaler: a ``torch.cuda.amp.GradScaler`` or ``None``. Mixed precision
            (autocast + loss scaling) is used only when the scaler exists AND
            reports ``is_enabled()``.
        model_idx: optional index used to prefix log lines (e.g. ensemble member).

    Returns:
        The highest validation accuracy seen across epochs. NOTE: the weights
        left in ``model`` are those after the final epoch, not necessarily the
        best-scoring ones.
    """
    prefix = f"[Model {model_idx}] " if model_idx is not None else ""
    grad_clip = args.grad_clip
    # A GradScaler instance is always truthy, so the previous `if scaler:` check
    # enabled autocast even for GradScaler(enabled=False). That ran fp16
    # activations without loss scaling. Gate on the scaler's own flag instead.
    use_amp = scaler is not None and scaler.is_enabled()
    best_val_acc = 0.0

    for epoch in range(args.epochs):
        model.train()
        total, correct, loss_sum = 0, 0, 0

        for imgs, labels in train_loader:
            imgs, labels = imgs.to(args.device), labels.to(args.device)
            optimizer.zero_grad()

            if use_amp:
                with torch.cuda.amp.autocast():
                    output = model(imgs, labels=labels)
                    loss = output["loss"]
                scaler.scale(loss).backward()
                if grad_clip > 0:
                    # Unscale first so the clip threshold applies to true gradients.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                output = model(imgs, labels=labels)
                loss = output["loss"]
                loss.backward()
                if grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            loss_sum += loss.item() * imgs.size(0)
            _, pred = output["logits"].max(1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

        train_loss = loss_sum / total
        train_acc = correct / total

        model.eval()
        val_total, val_correct, val_loss_sum = 0, 0, 0

        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(args.device), labels.to(args.device)
                output = model(imgs, labels=labels)

                val_loss_sum += output["loss"].item() * imgs.size(0)
                _, pred = output["logits"].max(1)
                val_correct += (pred == labels).sum().item()
                val_total += labels.size(0)

        val_loss = val_loss_sum / val_total
        val_acc = val_correct / val_total

        scheduler.step()

        current_lrs = [f"{group['lr']:.2e}" for group in optimizer.param_groups]
        print(
            f"{prefix}Epoch {epoch + 1:3d}/{args.epochs} | "
            f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
            f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
            f"lrs={current_lrs}"
        )

        best_val_acc = max(best_val_acc, val_acc)

    return best_val_acc


def _deep_ensemble_collect_logits(models, loader, device, keep_individual=False):
    """Run every member on ``loader``; return (mean logits, labels, per-member logits or None)."""
    ensemble_chunks, label_chunks = [], []
    member_chunks = [[] for _ in models]

    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)

            batch_logits = [model(imgs)["logits"] for model in models]
            if keep_individual:
                for chunks, logits in zip(member_chunks, batch_logits):
                    chunks.append(logits.cpu())

            # Ensemble prediction = element-wise mean of member logits.
            ensemble_chunks.append(torch.stack(batch_logits, dim=0).mean(dim=0).cpu())
            label_chunks.append(labels.cpu())

    ensemble_logits = torch.cat(ensemble_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)
    per_member = [torch.cat(c, dim=0) for c in member_chunks] if keep_individual else None
    return ensemble_logits, all_labels, per_member


def _deep_ensemble_metrics(logits, labels, one_hot):
    """Accuracy / NLL / Brier / ECE for one set of logits, given a precomputed one-hot target matrix."""
    probs = F.softmax(logits, dim=1)
    _, preds = probs.max(1)
    # Epsilon guards log(0) for probabilities that underflow to zero.
    log_probs = torch.log(probs + 1e-10)
    return {
        "accuracy": (preds == labels).float().mean().item(),
        "nll": F.nll_loss(log_probs, labels, reduction="mean").item(),
        "brier": torch.mean(torch.sum((probs - one_hot) ** 2, dim=1)).item(),
        "ece": compute_ece(probs, labels),
    }


def evaluate_deep_ensemble(models, loader, device, val_loader=None):
    """Evaluate a deep ensemble (mean of member logits) and each member individually.

    Args:
        models: sequence of modules whose forward returns a dict with ``"logits"``.
            Each is switched to ``eval()`` mode.
        loader: iterable of ``(imgs, labels)`` test batches.
        device: device the inputs are moved to.
        val_loader: optional held-out loader. When given, a temperature is fitted
            on the ensemble's validation logits and ``result["calibrated"]`` holds
            the temperature-scaled test metrics.

    Returns:
        dict with ``"ensemble"`` metrics, ``"individual"`` (one metrics dict per
        member), the CPU ``"ensemble_logits"`` and ``"labels"``, and optionally
        ``"calibrated"``.
    """
    for model in models:
        model.eval()

    all_logits_ensemble, all_labels, member_logits = _deep_ensemble_collect_logits(
        models, loader, device, keep_individual=True
    )

    # Loop-invariant: the same one-hot targets serve the ensemble and every member.
    n_classes = all_logits_ensemble.shape[1]
    one_hot = torch.eye(n_classes)[all_labels]

    result = {
        "ensemble": _deep_ensemble_metrics(all_logits_ensemble, all_labels, one_hot),
        "individual": [
            _deep_ensemble_metrics(logits_i, all_labels, one_hot) for logits_i in member_logits
        ],
        "ensemble_logits": all_logits_ensemble,
        "labels": all_labels,
    }

    if val_loader is not None:
        val_logits_ensemble, val_labels, _ = _deep_ensemble_collect_logits(models, val_loader, device)
        optimal_temp = calibrate_temperature(val_logits_ensemble, val_labels)
        result["calibrated"] = compute_calibrated_metrics(all_logits_ensemble, all_labels, optimal_temp)

    return result


def evaluate_mc_dropout(model, loader, device, n_samples, val_loader=None):
    """Evaluate an MC-dropout model against its own deterministic single pass.

    For every batch of ``loader`` two predictions are produced from the SAME
    inputs:
      * ``model.forward_mc(imgs, n_samples=n_samples)["logits"]`` with the model
        in ``train()`` mode (MC-dropout prediction);
      * ``model(imgs)["logits"]`` with the model in ``eval()`` mode (single pass).
    Doing both per batch keeps them aligned with the labels even if ``loader``
    shuffles. The old version iterated ``loader`` twice and scored the second
    pass against the labels from the first.

    Args:
        model: module exposing ``forward_mc`` and a forward returning ``"logits"``.
        loader: iterable of ``(imgs, labels)`` test batches.
        device: device the inputs are moved to.
        n_samples: number of stochastic passes passed to ``forward_mc``.
        val_loader: optional held-out loader. When given, a temperature is fitted
            on its MC logits and ``result["calibrated"]`` is added.

    Returns:
        dict with ``"mc_dropout"`` and ``"single_pass"`` metric dicts
        (accuracy / nll / brier / ece), ``"mc_logits"``, ``"labels"`` and
        optionally ``"calibrated"``.

    Side effects:
        The model is always left in ``eval()`` mode on return. Previously it was
        left in ``train()`` mode whenever ``val_loader`` was given.
    """
    mc_chunks, single_chunks, label_chunks = [], [], []

    try:
        with torch.no_grad():
            for imgs, labels in loader:
                imgs, labels = imgs.to(device), labels.to(device)

                model.train()
                mc_chunks.append(model.forward_mc(imgs, n_samples=n_samples)["logits"].cpu())

                model.eval()
                single_chunks.append(model(imgs)["logits"].cpu())

                label_chunks.append(labels.cpu())
    finally:
        model.eval()

    all_logits_mean = torch.cat(mc_chunks, dim=0)
    all_logits_single = torch.cat(single_chunks, dim=0)
    all_labels = torch.cat(label_chunks, dim=0)

    def compute_metrics(logits, labels):
        probs = F.softmax(logits, dim=1)
        _, preds = probs.max(1)
        accuracy = (preds == labels).float().mean().item()

        # Epsilon guards log(0) for probabilities that underflow to zero.
        log_probs = torch.log(probs + 1e-10)
        nll = F.nll_loss(log_probs, labels, reduction="mean").item()

        n_classes = probs.shape[1]
        one_hot = torch.eye(n_classes)[labels]
        brier = torch.mean(torch.sum((probs - one_hot) ** 2, dim=1)).item()

        ece = compute_ece(probs, labels)

        return {"accuracy": accuracy, "nll": nll, "brier": brier, "ece": ece}

    result = {
        "mc_dropout": compute_metrics(all_logits_mean, all_labels),
        "single_pass": compute_metrics(all_logits_single, all_labels),
        "mc_logits": all_logits_mean,
        "labels": all_labels,
    }

    if val_loader is not None:
        val_chunks, val_label_chunks = [], []

        model.train()
        try:
            with torch.no_grad():
                for imgs, labels in val_loader:
                    imgs, labels = imgs.to(device), labels.to(device)
                    val_chunks.append(model.forward_mc(imgs, n_samples=n_samples)["logits"].cpu())
                    val_label_chunks.append(labels.cpu())
        finally:
            # Do not leak train() mode (active dropout) into later evaluations.
            model.eval()

        val_logits = torch.cat(val_chunks, dim=0)
        val_labels = torch.cat(val_label_chunks, dim=0)

        optimal_temp = calibrate_temperature(val_logits, val_labels)
        result["calibrated"] = compute_calibrated_metrics(all_logits_mean, all_labels, optimal_temp)

    return result


@torch.no_grad()
def evaluate_cifar100c_by_severity(model_or_models, cifar100c_loaders, args):
    """Evaluate on each CIFAR-100-C severity loader separately.

    Args:
        model_or_models: a single model, or a list of members for
            ``args.method == "deep_ensemble"``.
        cifar100c_loaders: mapping severity -> iterable of ``(imgs, labels)``.
        args: namespace providing ``device`` and ``method`` (plus whatever
            ``get_logits_for_scoring`` reads).

    Returns:
        dict mapping each severity key to the metrics dict from
        ``compute_metrics_from_logits``.
    """
    device = args.device

    # MC dropout is left untouched: sampling is driven by forward_mc.
    # Every other method is scored deterministically in eval mode.
    if args.method == "deep_ensemble":
        for member in model_or_models:
            member.eval()
    elif args.method != "mc_dropout":
        model_or_models.eval()

    results = {}
    for severity, sev_loader in cifar100c_loaders.items():
        logit_parts = []
        label_parts = []

        for batch_imgs, batch_labels in sev_loader:
            batch_logits = get_logits_for_scoring(model_or_models, batch_imgs.to(device), args)  # [B, C]
            logit_parts.append(batch_logits.detach().cpu())
            label_parts.append(batch_labels.detach().cpu())

        results[severity] = compute_metrics_from_logits(
            torch.cat(logit_parts, dim=0),
            torch.cat(label_parts, dim=0),
        )

    return results


# Main
def main():
    args = parse_args()

    if args.list_backbones:
        print_available_backbones()
        return

    if args.list_datasets:
        print_available_datasets()
        return

    # Set default learning rate based on method
    if args.lr is None:
        if args.method == "svf":
            args.lr = 1e-3
        else:  # single, lora, deep_ensemble, mc_dropout
            args.lr = 1e-4

    # Setup logging if requested
    tee_logger = None
    log_filepath = None
    exp_name = None
    if args.save_log:
        log_filepath, tee_logger, exp_name = setup_logging(args)
        args.exp_name = exp_name

    set_seed(args.seed)

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, switching to cpu")
        args.device = "cpu"

    print(f"\n{'='*80}")
    print(f"🚀 DINO Ensemble Methods Comparison")
    print(f"{'='*80}")
    print(f"Dataset: {args.dataset}")
    print(f"Method: {args.method}")
    print(f"Backbone: {args.backbone}")
    print(f"Mode: {args.mode}")

    if args.method == "svf":
        print(f"N members: {args.n_members}")
        print(f"SVF scope: {args.svf_scope}")
        print(f"Top-k: {args.topk}")
    elif args.method == "lora":
        print(f"N members: {args.n_members}")
        print(f"LoRA rank: {args.lora_r}")
        print(f"LoRA alpha: {args.lora_alpha}")
        print(f"LoRA scope: {args.lora_scope}")
    elif args.method == "deep_ensemble":
        print(f"N members: {args.n_members}")
    elif args.method == "mc_dropout":
        print(f"MC samples: {args.mc_samples}")
        print(f"Dropout rate: {args.mc_dropout_rate}")

    print(f"Learning rate: {args.lr}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Weight decay: {args.weight_decay}")
    print(f"Device: {args.device}")
    print(f"Seed: {args.seed}")
    if args.save_log:
        print(f"Log file: {log_filepath}")
    print(f"{'='*80}\n")

    train_loader, val_loader, test_loader, num_classes = create_dataloaders(args)

    image_size = get_input_size_for_backbone(args.backbone, args.image_size)

    scaler = torch.cuda.amp.GradScaler(enabled=args.amp and args.device == "cuda")

    # Initialize results dict for logging
    results_dict = {}

    # Single Model Baseline
    if args.method == "single":
        print("\n" + "=" * 80)
        print("📈 TRAINING SINGLE MODEL BASELINE")
        print("=" * 80)

        reset_memory_stats()
        training_start_time = time.time()

        model = ViTStandard(
            backbone_name=args.backbone,
            num_classes=num_classes,
            pretrained=(args.init_mode == "pretrained"),
            mode=args.mode,
            img_size=image_size,
        ).to(args.device)

        log_trainable_parameters(model, verbose=args.verbose)

        if args.mode == "lp":
            optimizer = torch.optim.Adam(model.head.parameters(), lr=args.lr)
        else:
            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        scheduler = build_warmup_cosine_scheduler(optimizer, args.epochs, args.warmup_epochs)

        best_val_acc = train_single_model(
            model, train_loader, val_loader, optimizer, scheduler, args, scaler
        )

        training_time = time.time() - training_start_time
        peak_memory_mb = get_gpu_max_memory_mb()

        ckpt_name = get_checkpoint_path(
            f"{args.exp_name}.pth"
        )
        torch.save(model.state_dict(), ckpt_name)
        print(f"\n✅ Best val accuracy: {best_val_acc:.4f}")
        print(f"Saved checkpoint to: {ckpt_name}")

        # Test evaluation with calibration
        print("\n" + "=" * 80)
        print("🧪 TEST EVALUATION")
        print("=" * 80)

        model.load_state_dict(torch.load(ckpt_name, map_location=args.device))

        model.eval()
        val_logits_list, val_labels_list = [], []
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(args.device), labels.to(args.device)
                output = model(imgs)
                val_logits_list.append(output["logits"].cpu())
                val_labels_list.append(labels.cpu())
        val_logits = torch.cat(val_logits_list, dim=0)
        val_labels_tensor = torch.cat(val_labels_list, dim=0)

        all_logits, all_labels = [], []
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(args.device), labels.to(args.device)
                output = model(imgs)
                all_logits.append(output["logits"].cpu())
                all_labels.append(labels.cpu())

        all_logits = torch.cat(all_logits, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        probs = F.softmax(all_logits, dim=1)
        _, preds = probs.max(1)
        accuracy = (preds == all_labels).float().mean().item()

        log_probs = torch.log(probs + 1e-10)
        nll = F.nll_loss(log_probs, all_labels, reduction="mean").item()

        n_classes = probs.shape[1]
        one_hot = torch.eye(n_classes)[all_labels]
        brier = torch.mean(torch.sum((probs - one_hot) ** 2, dim=1)).item()
        ece = compute_ece(probs, all_labels)

        print(f"\n📊 RESULTS (uncalibrated):")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  NLL:      {nll:.4f}")
        print(f"  Brier:    {brier:.4f}")
        print(f"  ECE:      {ece:.4f}")

        optimal_temp = calibrate_temperature(val_logits, val_labels_tensor)
        calibrated = compute_calibrated_metrics(all_logits, all_labels, optimal_temp)

        print(f"\n📊 RESULTS (temperature scaled, T={optimal_temp:.3f}):")
        print(f"  Accuracy: {calibrated['accuracy']:.4f}")
        print(f"  NLL:      {calibrated['nll']:.4f}")
        print(f"  Brier:    {calibrated['brier']:.4f}")
        print(f"  ECE:      {calibrated['ece']:.4f}")

        # Compute inference metrics (FLOPs, memory, time) - isolated from training
        inference_metrics = measure_inference_metrics(
            model, (1, 3, image_size, image_size), args.device
        )

        print(f"\n⏱️  TIMING & COMPUTE:")
        print(f"  Training time:        {format_time(training_time)}")
        print(f"  Training peak memory: {peak_memory_mb:.1f} MB")
        print(f"  ─────────────────────────────────────")
        print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'])}")
        print(f"  Inference memory:     {inference_metrics['inference_memory_mb']:.1f} MB (activations)")
        print(f"  Inference total mem:  {inference_metrics['total_inference_memory_mb']:.1f} MB (model+act)")
        print(f"  Inference time:       {inference_metrics['inference_time_ms']:.2f} ± {inference_metrics['inference_time_std_ms']:.2f} ms (batch=1)")

        # Save results for logging
        results_dict = {
            "method": "single",
            "best_val_accuracy": best_val_acc,
            "uncalibrated": {
                "accuracy": accuracy,
                "nll": nll,
                "brier": brier,
                "ece": ece,
            },
            "calibrated": calibrated,
            "temperature": optimal_temp,
            "timing": {
                "training_time_seconds": training_time,
                "training_peak_memory_mb": peak_memory_mb,
                "inference_flops": inference_metrics['flops'],
                "inference_memory_mb": inference_metrics['inference_memory_mb'],
                "inference_total_memory_mb": inference_metrics['total_inference_memory_mb'],
                "inference_time_ms": inference_metrics['inference_time_ms'],
                "inference_time_std_ms": inference_metrics['inference_time_std_ms'],
            },
        }

        if args.cifar100c_eval:
            print("\n" + "=" * 80)
            print("🧪 CIFAR-100-C EVALUATION (per severity, evaluation only)")
            print("=" * 80)

            cifar100c_loaders = create_cifar100c_loaders(args)
            c100c_res = evaluate_cifar100c_by_severity(model, cifar100c_loaders, args)

            for sev in [1, 2, 3, 4, 5]:
                m = c100c_res[sev]
                print(
                    f"Severity {sev} | "
                    f"loss={m['loss']:.4f} acc={m['accuracy']:.4f} "
                    f"nll={m['nll']:.4f} brier={m['brier']:.4f} ece={m['ece']:.4f}"
                )

            results_dict["cifar100c"] = c100c_res


        # Optional OOD evaluation: test against CIFAR-10 and SVHN
        # ID = in-distribution test set (test_loader)
        if args.ood_eval:
            print("\n" + "=" * 80)
            print("OOD DETECTION (MSP): ID = current test set, OOD = CIFAR-10 and SVHN")
            print("=" * 80)

            ext_ood = create_external_ood_loaders(args)

            # For deterministic methods, eval mode is appropriate.
            # For MC dropout, forward_mc intentionally uses dropout sampling.
            if args.method != "mc_dropout":
                model.eval()

            ood_results = {}

            # OOD vs CIFAR-10
            res_c10 = evaluate_ood_msp(model, test_loader, ext_ood["cifar10"], args)
            print(f"OOD=CIFAR-10 | AUROC={res_c10['auroc']:.4f} | AUPRC={res_c10['auprc']:.4f} | FPR95={res_c10['fpr95']:.4f}")
            ood_results["cifar10"] = res_c10

            # OOD vs SVHN
            res_svhn = evaluate_ood_msp(model, test_loader, ext_ood["svhn"], args)
            print(f"OOD=SVHN     | AUROC={res_svhn['auroc']:.4f} | AUPRC={res_svhn['auprc']:.4f} | FPR95={res_svhn['fpr95']:.4f}")
            ood_results["svhn"] = res_svhn

            results_dict["ood_msp"] = ood_results


    # SVF Implicit Ensemble
    elif args.method == "svf":
        reset_memory_stats()
        training_start_time = time.time()

        model = ViTSVFImplicitEnsemble(
            backbone_name=args.backbone,
            num_classes=num_classes,
            pretrained=(args.init_mode == "pretrained"),
            n_members=args.n_members,
            topk=args.topk,
            svf_scope=args.svf_scope,
            svf_init_mean=args.svf_init_mean,
            svf_init_std=args.svf_init_std,
            head_init_mean=args.head_init_mean,
            head_init_std=args.head_init_std,
            use_mlp_head=args.use_mlp_head,
            mlp_hidden_dim=args.mlp_hidden_dim,
            mlp_dropout=args.mlp_dropout,
            mode=args.mode,
            img_size=image_size,
        ).to(args.device)

        log_trainable_parameters(model, verbose=args.verbose)

        if args.mode == "lp":
            optimizer = torch.optim.Adam(model.head.parameters(), lr=args.lr)
        else:
            # Collect all trainable parameters (head + SVF params)
            trainable_params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        scheduler = build_warmup_cosine_scheduler(optimizer, args.epochs, args.warmup_epochs)

        best_val_acc = 0.0
        ckpt_name = get_checkpoint_path(f"{args.exp_name}_svf_M{args.n_members}.pth")

        print("\n" + "=" * 80)
        print("📈 TRAINING SVF IMPLICIT ENSEMBLE")
        print("=" * 80)

        for epoch in range(args.epochs):
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, args.device, scaler, args.grad_clip)
            val_loss, val_acc = evaluate(model, val_loader, args.device)
            scheduler.step()

            current_lrs = [f"{group['lr']:.2e}" for group in optimizer.param_groups]
            print(
                f"Epoch {epoch + 1:3d}/{args.epochs} | "
                f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
                f"lrs={current_lrs}"
            )

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), ckpt_name)

        training_time = time.time() - training_start_time
        peak_memory_mb = get_gpu_max_memory_mb()

        print(f"\n Best val accuracy: {best_val_acc:.4f}")
        print(f"Saved checkpoint to: {ckpt_name}")

        # Test evaluation
        print("\n" + "=" * 80)
        print("TEST EVALUATION")
        print("=" * 80)

        model.load_state_dict(torch.load(ckpt_name, map_location=args.device))

        _, _, val_logits, val_labels = evaluate(model, val_loader, args.device, return_logits=True)
        _, _, test_logits, test_labels = evaluate(model, test_loader, args.device, return_logits=True)

        probs = F.softmax(test_logits, dim=1)
        _, preds = probs.max(1)
        accuracy = (preds == test_labels).float().mean().item()
        log_probs = torch.log(probs + 1e-10)
        nll = F.nll_loss(log_probs, test_labels, reduction="mean").item()
        n_classes = probs.shape[1]
        one_hot = torch.eye(n_classes)[test_labels]
        brier = torch.mean(torch.sum((probs - one_hot) ** 2, dim=1)).item()
        ece = compute_ece(probs, test_labels)

        print(f"\n📊 SVF ENSEMBLE RESULTS (uncalibrated):")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  NLL:      {nll:.4f}")
        print(f"  Brier:    {brier:.4f}")
        print(f"  ECE:      {ece:.4f}")

        optimal_temp = calibrate_temperature(val_logits, val_labels)
        calibrated = compute_calibrated_metrics(test_logits, test_labels, optimal_temp)

        print(f"\n📊 SVF ENSEMBLE RESULTS (temperature scaled, T={optimal_temp:.3f}):")
        print(f"  Accuracy: {calibrated['accuracy']:.4f}")
        print(f"  NLL:      {calibrated['nll']:.4f}")
        print(f"  Brier:    {calibrated['brier']:.4f}")
        print(f"  ECE:      {calibrated['ece']:.4f}")

        # Compute inference metrics (FLOPs, memory, time) - isolated from training
        # Note: SVF processes all M members in single forward pass
        inference_metrics = measure_inference_metrics(
            model, (1, 3, image_size, image_size), args.device
        )

        print(f"\n⏱️  TIMING & COMPUTE:")
        print(f"  Training time:        {format_time(training_time)}")
        print(f"  Training peak memory: {peak_memory_mb:.1f} MB")
        print(f"  ─────────────────────────────────────")
        print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'])} (all {args.n_members} members)")
        print(f"  Inference memory:     {inference_metrics['inference_memory_mb']:.1f} MB (activations)")
        print(f"  Inference total mem:  {inference_metrics['total_inference_memory_mb']:.1f} MB (model+act)")
        print(f"  Inference time:       {inference_metrics['inference_time_ms']:.2f} ± {inference_metrics['inference_time_std_ms']:.2f} ms (batch=1, all members)")

        # Save results for logging
        results_dict = {
            "method": "svf",
            "n_members": args.n_members,
            "best_val_accuracy": best_val_acc,
            "uncalibrated": {
                "accuracy": accuracy,
                "nll": nll,
                "brier": brier,
                "ece": ece,
            },
            "calibrated": calibrated,
            "temperature": optimal_temp,
            "timing": {
                "training_time_seconds": training_time,
                "training_peak_memory_mb": peak_memory_mb,
                "inference_flops": inference_metrics['flops'],
                "inference_memory_mb": inference_metrics['inference_memory_mb'],
                "inference_total_memory_mb": inference_metrics['total_inference_memory_mb'],
                "inference_time_ms": inference_metrics['inference_time_ms'],
                "inference_time_std_ms": inference_metrics['inference_time_std_ms'],
            },
        }
        if args.cifar100c_eval:
            print("\n" + "=" * 80)
            print("🧪 CIFAR-100-C EVALUATION (per severity, evaluation only)")
            print("=" * 80)

            cifar100c_loaders = create_cifar100c_loaders(args)
            c100c_res = evaluate_cifar100c_by_severity(model, cifar100c_loaders, args)

            for sev in [1, 2, 3, 4, 5]:
                m = c100c_res[sev]
                print(
                    f"Severity {sev} | "
                    f"loss={m['loss']:.4f} acc={m['accuracy']:.4f} "
                    f"nll={m['nll']:.4f} brier={m['brier']:.4f} ece={m['ece']:.4f}"
                )

            results_dict["cifar100c"] = c100c_res

        # Optional OOD evaluation: test against CIFAR-10 and SVHN
        # ID = in-distribution test set (test_loader)
        if args.ood_eval:
            print("\n" + "=" * 80)
            print("OOD DETECTION (MSP): ID = current test set, OOD = CIFAR-10 and SVHN")
            print("=" * 80)

            ext_ood = create_external_ood_loaders(args)

            # For deterministic methods, eval mode is appropriate.
            # For MC dropout, forward_mc intentionally uses dropout sampling.
            if args.method != "mc_dropout":
                model.eval()

            ood_results = {}

            # OOD vs CIFAR-10
            res_c10 = evaluate_ood_msp(model, test_loader, ext_ood["cifar10"], args)
            print(f"OOD=CIFAR-10 | AUROC={res_c10['auroc']:.4f} | AUPRC={res_c10['auprc']:.4f} | FPR95={res_c10['fpr95']:.4f}")
            ood_results["cifar10"] = res_c10

            # OOD vs SVHN
            res_svhn = evaluate_ood_msp(model, test_loader, ext_ood["svhn"], args)
            print(f"OOD=SVHN     | AUROC={res_svhn['auroc']:.4f} | AUPRC={res_svhn['auprc']:.4f} | FPR95={res_svhn['fpr95']:.4f}")
            ood_results["svhn"] = res_svhn

            results_dict["ood_msp"] = ood_results

    # LoRA Implicit Ensemble
    elif args.method == "lora":
        reset_memory_stats()
        training_start_time = time.time()

        model = ViTLoRAImplicitEnsemble(
            backbone_name=args.backbone,
            num_classes=num_classes,
            pretrained=(args.init_mode == "pretrained"),
            n_members=args.n_members,
            lora_r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            lora_init_std=args.lora_init_std,
            lora_scope=args.lora_scope,
            head_init_mean=args.head_init_mean,
            head_init_std=args.head_init_std,
            use_mlp_head=args.use_mlp_head,
            mlp_hidden_dim=args.mlp_hidden_dim,
            mlp_dropout=args.mlp_dropout,
            mode=args.mode,
            img_size=image_size,
        ).to(args.device)

        log_trainable_parameters(model, verbose=args.verbose)

        if args.mode == "lp":
            optimizer = torch.optim.Adam(model.head.parameters(), lr=args.lr)
        else:
            # Collect all trainable parameters (head + LoRA params)
            trainable_params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        scheduler = build_warmup_cosine_scheduler(optimizer, args.epochs, args.warmup_epochs)

        best_val_acc = 0.0
        ckpt_name = get_checkpoint_path(f"{args.exp_name}_lora_M{args.n_members}_r{args.lora_r}.pth")

        print("\n" + "=" * 80)
        print("TRAINING LoRA IMPLICIT ENSEMBLE")
        print("=" * 80)

        for epoch in range(args.epochs):
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, args.device, scaler, args.grad_clip)
            val_loss, val_acc = evaluate(model, val_loader, args.device)
            scheduler.step()

            current_lrs = [f"{group['lr']:.2e}" for group in optimizer.param_groups]
            print(
                f"Epoch {epoch + 1:3d}/{args.epochs} | "
                f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
                f"lrs={current_lrs}"
            )

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), ckpt_name)

        training_time = time.time() - training_start_time
        peak_memory_mb = get_gpu_max_memory_mb()

        print(f"\n Best val accuracy: {best_val_acc:.4f}")
        print(f"Saved checkpoint to: {ckpt_name}")

        # Test evaluation
        print("\n" + "=" * 80)
        print("TEST EVALUATION")
        print("=" * 80)

        model.load_state_dict(torch.load(ckpt_name, map_location=args.device))

        _, _, val_logits, val_labels = evaluate(model, val_loader, args.device, return_logits=True)
        _, _, test_logits, test_labels = evaluate(model, test_loader, args.device, return_logits=True)

        probs = F.softmax(test_logits, dim=1)
        _, preds = probs.max(1)
        accuracy = (preds == test_labels).float().mean().item()
        log_probs = torch.log(probs + 1e-10)
        nll = F.nll_loss(log_probs, test_labels, reduction="mean").item()
        n_classes = probs.shape[1]
        one_hot = torch.eye(n_classes)[test_labels]
        brier = torch.mean(torch.sum((probs - one_hot) ** 2, dim=1)).item()
        ece = compute_ece(probs, test_labels)

        print(f"\n LoRA ENSEMBLE RESULTS (uncalibrated):")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  NLL:      {nll:.4f}")
        print(f"  Brier:    {brier:.4f}")
        print(f"  ECE:      {ece:.4f}")

        optimal_temp = calibrate_temperature(val_logits, val_labels)
        calibrated = compute_calibrated_metrics(test_logits, test_labels, optimal_temp)

        print(f"\n LoRA ENSEMBLE RESULTS (temperature scaled, T={optimal_temp:.3f}):")
        print(f"  Accuracy: {calibrated['accuracy']:.4f}")
        print(f"  NLL:      {calibrated['nll']:.4f}")
        print(f"  Brier:    {calibrated['brier']:.4f}")
        print(f"  ECE:      {calibrated['ece']:.4f}")

        # Compute inference metrics (FLOPs, memory, time) - isolated from training
        # Note: LoRA processes all M members in single forward pass
        inference_metrics = measure_inference_metrics(
            model, (1, 3, image_size, image_size), args.device
        )

        print(f"\n TIMING & COMPUTE:")
        print(f"  Training time:        {format_time(training_time)}")
        print(f"  Training peak memory: {peak_memory_mb:.1f} MB")
        print(f"  ─────────────────────────────────────")
        print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'])} (all {args.n_members} members)")
        print(f"  Inference memory:     {inference_metrics['inference_memory_mb']:.1f} MB (activations)")
        print(f"  Inference total mem:  {inference_metrics['total_inference_memory_mb']:.1f} MB (model+act)")
        print(f"  Inference time:       {inference_metrics['inference_time_ms']:.2f} ± {inference_metrics['inference_time_std_ms']:.2f} ms (batch=1, all members)")

        # Save results for logging
        results_dict = {
            "method": "lora",
            "n_members": args.n_members,
            "lora_r": args.lora_r,
            "lora_alpha": args.lora_alpha,
            "best_val_accuracy": best_val_acc,
            "uncalibrated": {
                "accuracy": accuracy,
                "nll": nll,
                "brier": brier,
                "ece": ece,
            },
            "calibrated": calibrated,
            "temperature": optimal_temp,
            "timing": {
                "training_time_seconds": training_time,
                "training_peak_memory_mb": peak_memory_mb,
                "inference_flops": inference_metrics['flops'],
                "inference_memory_mb": inference_metrics['inference_memory_mb'],
                "inference_total_memory_mb": inference_metrics['total_inference_memory_mb'],
                "inference_time_ms": inference_metrics['inference_time_ms'],
                "inference_time_std_ms": inference_metrics['inference_time_std_ms'],
            },
        }
        if args.cifar100c_eval:
            print("\n" + "=" * 80)
            print(" CIFAR-100-C EVALUATION (per severity, evaluation only)")
            print("=" * 80)

            cifar100c_loaders = create_cifar100c_loaders(args)
            c100c_res = evaluate_cifar100c_by_severity(model, cifar100c_loaders, args)

            for sev in [1, 2, 3, 4, 5]:
                m = c100c_res[sev]
                print(
                    f"Severity {sev} | "
                    f"loss={m['loss']:.4f} acc={m['accuracy']:.4f} "
                    f"nll={m['nll']:.4f} brier={m['brier']:.4f} ece={m['ece']:.4f}"
                )

            results_dict["cifar100c"] = c100c_res

        # Optional OOD evaluation: test against CIFAR-10 and SVHN
        # ID = in-distribution test set (test_loader)
        if args.ood_eval:
            print("\n" + "=" * 80)
            print(" OOD DETECTION (MSP): ID = current test set, OOD = CIFAR-10 and SVHN")
            print("=" * 80)

            ext_ood = create_external_ood_loaders(args)

            # For deterministic methods, eval mode is appropriate.
            # For MC dropout, forward_mc intentionally uses dropout sampling.
            if args.method != "mc_dropout":
                model.eval()

            ood_results = {}

            # OOD vs CIFAR-10
            res_c10 = evaluate_ood_msp(model, test_loader, ext_ood["cifar10"], args)
            print(f"OOD=CIFAR-10 | AUROC={res_c10['auroc']:.4f} | AUPRC={res_c10['auprc']:.4f} | FPR95={res_c10['fpr95']:.4f}")
            ood_results["cifar10"] = res_c10

            # OOD vs SVHN
            res_svhn = evaluate_ood_msp(model, test_loader, ext_ood["svhn"], args)
            print(f"OOD=SVHN     | AUROC={res_svhn['auroc']:.4f} | AUPRC={res_svhn['auprc']:.4f} | FPR95={res_svhn['fpr95']:.4f}")
            ood_results["svhn"] = res_svhn

            results_dict["ood_msp"] = ood_results

    # Deep Ensemble
    elif args.method == "deep_ensemble":
        print("\n" + "=" * 80)
        print(" TRAINING DEEP ENSEMBLE")
        print("=" * 80)

        reset_memory_stats()
        training_start_time = time.time()

        models = []
        ckpt_names = []

        for m_idx in range(args.n_members):
            member_seed = args.seed + m_idx * 1000
            set_seed(member_seed)

            print(f"\n{'─'*60}")
            print(f"Training member {m_idx + 1}/{args.n_members} (seed={member_seed})")
            print(f"{'─'*60}")

            model = ViTStandard(
                backbone_name=args.backbone,
                num_classes=num_classes,
                pretrained=(args.init_mode == "pretrained"),
                mode=args.mode,
                img_size=image_size,
            ).to(args.device)

            if m_idx == 0:
                log_trainable_parameters(model, verbose=args.verbose)

            if args.mode == "lp":
                optimizer = torch.optim.Adam(model.head.parameters(), lr=args.lr)
            else:
                optimizer = torch.optim.AdamW(
                    model.parameters(),
                    lr=args.lr,
                    weight_decay=args.weight_decay,
                )

            scheduler = build_warmup_cosine_scheduler(optimizer, args.epochs, args.warmup_epochs)

            best_val_acc = train_single_model(
                model, train_loader, val_loader, optimizer, scheduler, args, scaler, model_idx=m_idx + 1
            )

            ckpt_name = get_checkpoint_path(f"{args.exp_name}_deep_ensemble_member{m_idx}.pth")
            torch.save(model.state_dict(), ckpt_name)
            ckpt_names.append(ckpt_name)
            models.append(model)

            print(f"Member {m_idx + 1} best val accuracy: {best_val_acc:.4f}")

        training_time = time.time() - training_start_time
        peak_memory_mb = get_gpu_max_memory_mb()

        # Test evaluation
        print("\n" + "=" * 80)
        print("🧪 TEST EVALUATION")
        print("=" * 80)

        for i, (model, ckpt_name) in enumerate(zip(models, ckpt_names)):
            model.load_state_dict(torch.load(ckpt_name, map_location=args.device))

        results = evaluate_deep_ensemble(models, test_loader, args.device, val_loader=val_loader)

        print("\n ENSEMBLE RESULTS (uncalibrated):")
        print(f"  Accuracy: {results['ensemble']['accuracy']:.4f}")
        print(f"  NLL:      {results['ensemble']['nll']:.4f}")
        print(f"  Brier:    {results['ensemble']['brier']:.4f}")
        print(f"  ECE:      {results['ensemble']['ece']:.4f}")

        if "calibrated" in results:
            print(f"\n ENSEMBLE RESULTS (temperature scaled, T={results['calibrated']['temperature']:.3f}):")
            print(f"  Accuracy: {results['calibrated']['accuracy']:.4f}")
            print(f"  NLL:      {results['calibrated']['nll']:.4f}")
            print(f"  Brier:    {results['calibrated']['brier']:.4f}")
            print(f"  ECE:      {results['calibrated']['ece']:.4f}")

        print("\n INDIVIDUAL MODEL RESULTS:")
        for i, ind_res in enumerate(results['individual']):
            print(f"  Model {i + 1}: Acc={ind_res['accuracy']:.4f}, NLL={ind_res['nll']:.4f}, "
                  f"Brier={ind_res['brier']:.4f}, ECE={ind_res['ece']:.4f}")

        avg_acc = np.mean([r['accuracy'] for r in results['individual']])
        print(f"\n AVERAGE SINGLE MODEL: Acc={avg_acc:.4f}")

        # Compute inference metrics for one member (isolated from training)
        inference_metrics = measure_inference_metrics(
            models[0], (1, 3, image_size, image_size), args.device
        )

        print(f"\n  TIMING & COMPUTE:")
        print(f"  Training time:        {format_time(training_time)} (all {args.n_members} members)")
        print(f"  Training peak memory: {peak_memory_mb:.1f} MB")
        print(f"  ─────────────────────────────────────")
        print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'])} (per member)")
        print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'] * args.n_members)} (all {args.n_members} members)")
        print(f"  Inference memory:     {inference_metrics['inference_memory_mb']:.1f} MB (activations, per member)")
        print(f"  Inference total mem:  {inference_metrics['total_inference_memory_mb']:.1f} MB (model+act, per member)")
        print(f"  Inference time:       {inference_metrics['inference_time_ms']:.2f} ± {inference_metrics['inference_time_std_ms']:.2f} ms (per member, batch=1)")

        # Save results for logging
        results_dict = {
            "method": "deep_ensemble",
            "n_members": args.n_members,
            "ensemble": results["ensemble"],
            "calibrated": results.get("calibrated", {}),
            "individual": results["individual"],
            "avg_individual_accuracy": avg_acc,
            "timing": {
                "training_time_seconds": training_time,
                "training_peak_memory_mb": peak_memory_mb,
                "inference_flops_per_member": inference_metrics['flops'],
                "inference_flops_ensemble": inference_metrics['flops'] * args.n_members,
                "inference_memory_mb_per_member": inference_metrics['inference_memory_mb'],
                "inference_total_memory_mb_per_member": inference_metrics['total_inference_memory_mb'],
                "inference_time_ms_per_member": inference_metrics['inference_time_ms'],
                "inference_time_std_ms": inference_metrics['inference_time_std_ms'],
            },
        }
        if args.cifar100c_eval:
            print("\n" + "=" * 80)
            print(" CIFAR-100-C EVALUATION (per severity, evaluation only)")
            print("=" * 80)

            cifar100c_loaders = create_cifar100c_loaders(args)
            c100c_res = evaluate_cifar100c_by_severity(models, cifar100c_loaders, args)

            for sev in [1, 2, 3, 4, 5]:
                m = c100c_res[sev]
                print(
                    f"Severity {sev} | "
                    f"loss={m['loss']:.4f} acc={m['accuracy']:.4f} "
                    f"nll={m['nll']:.4f} brier={m['brier']:.4f} ece={m['ece']:.4f}"
                )

            results_dict["cifar100c"] = c100c_res

        if args.ood_eval:
            print("\n" + "=" * 80)
            print(" OOD DETECTION (MSP): ID = current test set, OOD = CIFAR-10 and SVHN")
            print("=" * 80)

            ext_ood = create_external_ood_loaders(args)

            for m in models:
                m.eval()

            ood_results = {}

            # OOD vs CIFAR-10
            res_c10 = evaluate_ood_msp(models, test_loader, ext_ood["cifar10"], args)
            print(f"OOD=CIFAR-10 | AUROC={res_c10['auroc']:.4f} | AUPRC={res_c10['auprc']:.4f} | FPR95={res_c10['fpr95']:.4f}")
            ood_results["cifar10"] = res_c10

            # OOD vs SVHN
            res_svhn = evaluate_ood_msp(models, test_loader, ext_ood["svhn"], args)
            print(f"OOD=SVHN     | AUROC={res_svhn['auroc']:.4f} | AUPRC={res_svhn['auprc']:.4f} | FPR95={res_svhn['fpr95']:.4f}")
            ood_results["svhn"] = res_svhn

            results_dict["ood_msp"] = ood_results

    # BatchEnsemble
    elif args.method == "batch_ensemble":
        reset_memory_stats()
        training_start_time = time.time()

        model = ViTBatchEnsemble(
            backbone_name=args.backbone,
            num_classes=num_classes,
            pretrained=(args.init_mode == "pretrained"),
            n_members=args.n_members,
            be_scope=args.be_scope,
            be_init_std=args.be_init_std,
            head_init_mean=args.head_init_mean,
            head_init_std=args.head_init_std,
            use_mlp_head=args.use_mlp_head,
            mlp_hidden_dim=args.mlp_hidden_dim,
            mlp_dropout=args.mlp_dropout,
            mode=args.mode,
            img_size=image_size,
        ).to(args.device)

        log_trainable_parameters(model, verbose=args.verbose)

        if args.mode == "lp":
            optimizer = torch.optim.Adam(model.head.parameters(), lr=args.lr)
        else:
            # Collect all trainable parameters (shared weights + r/s + head)
            trainable_params = [p for p in model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        scheduler = build_warmup_cosine_scheduler(optimizer, args.epochs, args.warmup_epochs)

        best_val_acc = 0.0
        ckpt_name = get_checkpoint_path(f"{args.exp_name}_batch_ensemble_M{args.n_members}.pth")

        print("\n" + "=" * 80)
        print(" TRAINING BATCHENSEMBLE")
        print("=" * 80)

        for epoch in range(args.epochs):
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, args.device, scaler, args.grad_clip)
            val_loss, val_acc = evaluate(model, val_loader, args.device)
            scheduler.step()

            current_lrs = [f"{group['lr']:.2e}" for group in optimizer.param_groups]
            print(
                f"Epoch {epoch + 1:3d}/{args.epochs} | "
                f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} | "
                f"val_loss={val_loss:.4f} val_acc={val_acc:.4f} | "
                f"lrs={current_lrs}"
            )

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), ckpt_name)

        training_time = time.time() - training_start_time
        peak_memory_mb = get_gpu_max_memory_mb()

        print(f"\n Best val accuracy: {best_val_acc:.4f}")

        # Test evaluation
        print("\n" + "=" * 80)
        print("TEST EVALUATION")
        print("=" * 80)

        model.load_state_dict(torch.load(ckpt_name, map_location=args.device))
        model.eval()

        # Collect validation logits for temperature scaling
        val_logits_list, val_labels_list = [], []
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(args.device), labels.to(args.device)
                output = model(imgs)
                val_logits_list.append(output["logits"].cpu())
                val_labels_list.append(labels.cpu())
        val_logits = torch.cat(val_logits_list, dim=0)
        val_labels = torch.cat(val_labels_list, dim=0)

        # Collect test logits
        test_logits_list, test_labels_list = [], []
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(args.device), labels.to(args.device)
                output = model(imgs)
                test_logits_list.append(output["logits"].cpu())
                test_labels_list.append(labels.cpu())
        test_logits = torch.cat(test_logits_list, dim=0)
        test_labels = torch.cat(test_labels_list, dim=0)

        # Compute metrics
        probs = F.softmax(test_logits, dim=1)
        _, preds = probs.max(1)
        accuracy = (preds == test_labels).float().mean().item()
        log_probs = torch.log(probs + 1e-10)
        nll = F.nll_loss(log_probs, test_labels, reduction="mean").item()
        n_classes = probs.shape[1]
        one_hot = torch.eye(n_classes)[test_labels]
        brier = torch.mean(torch.sum((probs - one_hot) ** 2, dim=1)).item()
        ece = compute_ece(probs, test_labels)

        print(f"\n BATCHENSEMBLE RESULTS (uncalibrated):")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  NLL:      {nll:.4f}")
        print(f"  Brier:    {brier:.4f}")
        print(f"  ECE:      {ece:.4f}")

        optimal_temp = calibrate_temperature(val_logits, val_labels)
        calibrated = compute_calibrated_metrics(test_logits, test_labels, optimal_temp)

        print(f"\n BATCHENSEMBLE RESULTS (temperature scaled, T={optimal_temp:.3f}):")
        print(f"  Accuracy: {calibrated['accuracy']:.4f}")
        print(f"  NLL:      {calibrated['nll']:.4f}")
        print(f"  Brier:    {calibrated['brier']:.4f}")
        print(f"  ECE:      {calibrated['ece']:.4f}")

        # Compute inference metrics (FLOPs, memory, time) - isolated from training
        # Note: BatchEnsemble processes all M members in single forward pass
        #inference_metrics = measure_inference_metrics(
        #    model, (1, 3, image_size, image_size), args.device
        #)
        #
        #print(f"\n TIMING & COMPUTE:")
        #print(f"  Training time:        {format_time(training_time)}")
        #print(f"  Training peak memory: {peak_memory_mb:.1f} MB")
        #print(f"  ─────────────────────────────────────")
        #print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'])} (all {args.n_members} members)")
        #print(f"  Model memory:         {inference_metrics['model_memory_mb']:.1f} MB")
        #print(f"  Activation memory:    {inference_metrics['activation_memory_mb']:.1f} MB")
        #print(f"  Inference time:       {inference_metrics['inference_time_ms']:.2f} ± {inference_metrics['inference_time_std_ms']:.2f} ms (batch=1, all members)")

        # dummy inference metrics for logging
        inference_metrics = {
            'flops': 0,
            'model_memory_mb': 0,
            'activation_memory_mb': 0,
            'inference_time_ms': 0,
            'inference_time_std_ms': 0,
        }
        # Save results for logging
        results_dict = {
            "method": "batch_ensemble",
            "n_members": args.n_members,
            "be_scope": args.be_scope,
            "best_val_accuracy": best_val_acc,
            "uncalibrated": {
                "accuracy": accuracy,
                "nll": nll,
                "brier": brier,
                "ece": ece,
            },
            "calibrated": calibrated,
            "temperature": optimal_temp,
            "timing": {
                "training_time_seconds": training_time,
                "training_peak_memory_mb": peak_memory_mb,
                "inference_flops": inference_metrics['flops'],
                "model_memory_mb": inference_metrics['model_memory_mb'],
                "activation_memory_mb": inference_metrics['activation_memory_mb'],
                "inference_time_ms": inference_metrics['inference_time_ms'],
                "inference_time_std_ms": inference_metrics['inference_time_std_ms'],
            },
        }
        if args.cifar100c_eval:
            print("\n" + "=" * 80)
            print("CIFAR-100-C EVALUATION (per severity, evaluation only)")
            print("=" * 80)

            cifar100c_loaders = create_cifar100c_loaders(args)
            c100c_res = evaluate_cifar100c_by_severity(model, cifar100c_loaders, args)

            for sev in [1, 2, 3, 4, 5]:
                m = c100c_res[sev]
                print(
                    f"Severity {sev} | "
                    f"loss={m['loss']:.4f} acc={m['accuracy']:.4f} "
                    f"nll={m['nll']:.4f} brier={m['brier']:.4f} ece={m['ece']:.4f}"
                )

            results_dict["cifar100c"] = c100c_res

        # Optional OOD evaluation: test against CIFAR-10 and SVHN
        # ID = in-distribution test set (test_loader)
        if args.ood_eval:
            print("\n" + "=" * 80)
            print(" OOD DETECTION (MSP): ID = current test set, OOD = CIFAR-10 and SVHN")
            print("=" * 80)

            ext_ood = create_external_ood_loaders(args)

            # For deterministic methods, eval mode is appropriate.
            # For MC dropout, forward_mc intentionally uses dropout sampling.
            if args.method != "mc_dropout":
                model.eval()

            ood_results = {}

            # OOD vs CIFAR-10
            res_c10 = evaluate_ood_msp(model, test_loader, ext_ood["cifar10"], args)
            print(f"OOD=CIFAR-10 | AUROC={res_c10['auroc']:.4f} | AUPRC={res_c10['auprc']:.4f} | FPR95={res_c10['fpr95']:.4f}")
            ood_results["cifar10"] = res_c10

            # OOD vs SVHN
            res_svhn = evaluate_ood_msp(model, test_loader, ext_ood["svhn"], args)
            print(f"OOD=SVHN     | AUROC={res_svhn['auroc']:.4f} | AUPRC={res_svhn['auprc']:.4f} | FPR95={res_svhn['fpr95']:.4f}")
            ood_results["svhn"] = res_svhn

            results_dict["ood_msp"] = ood_results


    # MC Dropout
    elif args.method == "mc_dropout":
        print("\n" + "=" * 80)
        print(" TRAINING MC DROPOUT MODEL")
        print("=" * 80)

        reset_memory_stats()
        training_start_time = time.time()

        model = ViTMCDropout(
            backbone_name=args.backbone,
            num_classes=num_classes,
            pretrained=(args.init_mode == "pretrained"),
            dropout_rate=args.mc_dropout_rate,
            mode=args.mode,
            img_size=image_size,
        ).to(args.device)

        log_trainable_parameters(model, verbose=args.verbose)

        if args.mode == "lp":
            optimizer = torch.optim.Adam(model.head.parameters(), lr=args.lr)
        else:
            optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=args.lr,
                weight_decay=args.weight_decay,
            )

        scheduler = build_warmup_cosine_scheduler(optimizer, args.epochs, args.warmup_epochs)

        best_val_acc = train_single_model(
            model, train_loader, val_loader, optimizer, scheduler, args, scaler
        )

        training_time = time.time() - training_start_time
        peak_memory_mb = get_gpu_max_memory_mb()

        ckpt_name = get_checkpoint_path(f"{args.exp_name}_mc_dropout.pth")
        torch.save(model.state_dict(), ckpt_name)
        print(f"\n✅ Best val accuracy: {best_val_acc:.4f}")

        # Test evaluation
        print("\n" + "=" * 80)
        print(" TEST EVALUATION")
        print("=" * 80)

        model.load_state_dict(torch.load(ckpt_name, map_location=args.device))
        results = evaluate_mc_dropout(model, test_loader, args.device, args.mc_samples, val_loader=val_loader)

        print(f"\n MC DROPOUT RESULTS ({args.mc_samples} samples, uncalibrated):")
        print(f"  Accuracy: {results['mc_dropout']['accuracy']:.4f}")
        print(f"  NLL:      {results['mc_dropout']['nll']:.4f}")
        print(f"  Brier:    {results['mc_dropout']['brier']:.4f}")
        print(f"  ECE:      {results['mc_dropout']['ece']:.4f}")

        if "calibrated" in results:
            print(f"\n MC DROPOUT RESULTS (temperature scaled, T={results['calibrated']['temperature']:.3f}):")
            print(f"  Accuracy: {results['calibrated']['accuracy']:.4f}")
            print(f"  NLL:      {results['calibrated']['nll']:.4f}")
            print(f"  Brier:    {results['calibrated']['brier']:.4f}")
            print(f"  ECE:      {results['calibrated']['ece']:.4f}")

        print(f"\n SINGLE PASS (no MC) RESULTS:")
        print(f"  Accuracy: {results['single_pass']['accuracy']:.4f}")
        print(f"  NLL:      {results['single_pass']['nll']:.4f}")
        print(f"  Brier:    {results['single_pass']['brier']:.4f}")
        print(f"  ECE:      {results['single_pass']['ece']:.4f}")

        # Compute inference metrics (FLOPs, memory, time) - isolated from training
        inference_metrics = measure_inference_metrics(
            model, (1, 3, image_size, image_size), args.device
        )

        print(f"\n TIMING & COMPUTE:")
        print(f"  Training time:        {format_time(training_time)}")
        print(f"  Training peak memory: {peak_memory_mb:.1f} MB")
        print(f"  ─────────────────────────────────────")
        print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'])} (per MC sample)")
        print(f"  Inference FLOPs:      {format_flops(inference_metrics['flops'] * args.mc_samples)} ({args.mc_samples} samples)")
        print(f"  Inference memory:     {inference_metrics['inference_memory_mb']:.1f} MB (activations, per sample)")
        print(f"  Inference total mem:  {inference_metrics['total_inference_memory_mb']:.1f} MB (model+act, per sample)")
        print(f"  Inference time:       {inference_metrics['inference_time_ms']:.2f} ± {inference_metrics['inference_time_std_ms']:.2f} ms (per sample, batch=1)")

        # Save results for logging
        results_dict = {
            "method": "mc_dropout",
            "mc_samples": args.mc_samples,
            "dropout_rate": args.mc_dropout_rate,
            "best_val_accuracy": best_val_acc,
            "mc_dropout": results["mc_dropout"],
            "calibrated": results.get("calibrated", {}),
            "single_pass": results["single_pass"],
            "timing": {
                "training_time_seconds": training_time,
                "training_peak_memory_mb": peak_memory_mb,
                "inference_flops_per_sample": inference_metrics['flops'],
                "inference_flops_total": inference_metrics['flops'] * args.mc_samples,
                "inference_memory_mb_per_sample": inference_metrics['inference_memory_mb'],
                "inference_total_memory_mb_per_sample": inference_metrics['total_inference_memory_mb'],
                "inference_time_ms_per_sample": inference_metrics['inference_time_ms'],
                "inference_time_std_ms": inference_metrics['inference_time_std_ms'],
            },
        }
        if args.cifar100c_eval:
            print("\n" + "=" * 80)
            print(" CIFAR-100-C EVALUATION (per severity, evaluation only)")
            print("=" * 80)

            cifar100c_loaders = create_cifar100c_loaders(args)
            c100c_res = evaluate_cifar100c_by_severity(model, cifar100c_loaders, args)

            for sev in [1, 2, 3, 4, 5]:
                m = c100c_res[sev]
                print(
                    f"Severity {sev} | "
                    f"loss={m['loss']:.4f} acc={m['accuracy']:.4f} "
                    f"nll={m['nll']:.4f} brier={m['brier']:.4f} ece={m['ece']:.4f}"
                )

            results_dict["cifar100c"] = c100c_res

        if args.ood_eval:
            print("\n" + "=" * 80)
            print(" OOD DETECTION (MSP): ID = current test set, OOD = CIFAR-10 and SVHN")
            print("=" * 80)

            ext_ood = create_external_ood_loaders(args)

            ood_results = {}

            # OOD vs CIFAR-10
            res_c10 = evaluate_ood_msp(model, test_loader, ext_ood["cifar10"], args)
            print(f"OOD=CIFAR-10 | AUROC={res_c10['auroc']:.4f} | AUPRC={res_c10['auprc']:.4f} | FPR95={res_c10['fpr95']:.4f}")
            ood_results["cifar10"] = res_c10

            # OOD vs SVHN
            res_svhn = evaluate_ood_msp(model, test_loader, ext_ood["svhn"], args)
            print(f"OOD=SVHN     | AUROC={res_svhn['auroc']:.4f} | AUPRC={res_svhn['auprc']:.4f} | FPR95={res_svhn['fpr95']:.4f}")
            ood_results["svhn"] = res_svhn

            results_dict["ood_msp"] = ood_results

            print("\n" + "=" * 80)
            print(" DONE")
            print("=" * 80)

    # Save results summary if logging is enabled
    if args.save_log and log_filepath and results_dict:
        save_results_summary(args, results_dict, log_filepath)
        print(f"\n📄 Full log saved to: {log_filepath}")

        # Close the tee logger and restore stdout
        if tee_logger:
            sys.stdout = tee_logger.terminal
            tee_logger.close()


if __name__ == "__main__":
    # Script entry point: call main() only when this file is executed
    # directly (e.g. `python <this_file>.py ...`). Importing the module
    # (for example to reuse its model/evaluation helpers) does not run it.
    main()