"""
Tests LOES layer selection for classification head input, with uniform LoRA backbone.

Experiments per model-dataset:
1. LoRA-All + Last Layer: LoRA on all layers, classify from last layer only
2. LoRA-All + Last-4 Concat: LoRA on all layers, concat last 4 layers with adapters
3. LoRA-All + LOES-4 Concat: LoRA on all layers, concat LOES top-4 layers with adapters

Models: DINOv2-Large (CLS pooling), CLIP-ViT-B/32 (mean pooling)
Datasets: Stanford Cars, DTD

Difference from v1: LOES guides which layers to FUSE for classification, not where to place LoRA adapters.
"""

import os
import csv
import time
import random
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.transforms import v2

import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import (
    AutoImageProcessor, 
    Dinov2Model,
    CLIPVisionModel,
    CLIPProcessor
)
from peft import LoraConfig, get_peft_model


# Global seed: passed to set_seed() before each run and recorded in the results CSV.
SEED = 42
DEVICE = "cuda"
BATCH_SIZE = 128  # training batch size (also used for LOES embedding collection)
TEST_BATCH_SIZE = 128  # batch size for the validation and test loaders
EPOCHS = 15
LEARNING_RATE = 1e-4  # same LR for LoRA params and the classification head
WEIGHT_DECAY = 1e-4
CALIBRATION_PCT = 0.2  # fraction of the training data sampled for LOES layer scoring

# LoRA configuration
LORA_RANK = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.1

# Layer selection
K_LAYERS = 4  # Number of layers to concat
PROJ_DIM = 256  # Adapter projection dimension

# Output
OUTPUT_DIR = Path("lora_loes_results_v2")
RESULTS_CSV = OUTPUT_DIR / "lora_loes_results_v2.csv"

# Model configs, keyed by the short model key used throughout the script.
#   name:         Hugging Face checkpoint id
#   num_layers:   number of transformer layers (hidden_states has num_layers + 1 entries)
#   hidden_dim:   width of each layer's pooled feature
#   pooling:      "cls" takes the first token, anything else mean-pools tokens
#   lora_targets: leaf names of the attention projections intended to receive LoRA
MODELS = {
     "clip-vit-b32": {
        "name": "openai/clip-vit-base-patch32",
        "num_layers": 12,
        "hidden_dim": 768,
        "pooling": "mean",
        "lora_targets": ["q_proj", "v_proj"],  # CLIP uses different naming
    },
    "dinov2-large": {
        "name": "facebook/dinov2-large",
        "num_layers": 24,
        "hidden_dim": 1024,
        "pooling": "cls",
        "lora_targets": ["query", "value"],  # In attention
    }
   
}

# Datasets: Hub id plus the split names to load; "val": None means the
# dataset provides no separate validation split.
DATASETS = [
    {"name": "tanganke/stanford_cars", "train": "train", "val": None, "test": "test"},
    {"name": "randall-lab/dtd", "train": "train", "val": "validation", "test": "test"},
]


def set_seed(seed):
    """Seed every RNG the experiment touches and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global generator, and the torch
    CPU and CUDA generators, then turns on deterministic cuDNN kernels and
    turns off cuDNN benchmark autotuning.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def count_parameters(model, only_trainable=True):
    """Return the total number of scalar parameters in ``model``.

    With ``only_trainable`` (the default), parameters whose ``requires_grad``
    is False (e.g. a frozen backbone) are left out of the count.
    """
    params = model.parameters()
    if only_trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)


def format_params(num_params):
    """Render a parameter count compactly.

    Counts of at least one million get an ``M`` suffix, counts of at least
    one thousand a ``K`` suffix (both with two decimals); smaller values are
    returned as ``str(num_params)``.
    """
    for threshold, suffix in ((1e6, "M"), (1e3, "K")):
        if num_params >= threshold:
            return f"{num_params / threshold:.2f}{suffix}"
    return str(num_params)


def get_timestamp():
    """Return the current local time as ``YYYY-MM-DD HH:MM:SS``."""
    now = datetime.now()
    return now.strftime("%Y-%m-%d %H:%M:%S")



class ImageDataset(torch.utils.data.Dataset):
    """Expose a Hugging Face image-classification split as a torch ``Dataset``.

    Items are ``(image_tensor, label)`` pairs. The image is converted to RGB,
    resized (bicubic) and center-cropped to ``img_size``, scaled to [0, 1]
    and normalized with ``mean`` / ``std``.

    Args:
        dataset_name: Hub dataset id passed to ``load_dataset``.
        split: split name to load (e.g. ``"train"``).
        img_size: output side length in pixels.
        mean, std: per-channel normalization statistics; the defaults are the
            values this script has always used (ImageNet statistics).
            NOTE(review): CLIP checkpoints are usually fed CLIP-specific
            statistics (see ``CLIPProcessor``'s image_mean / image_std) —
            confirm and pass those for CLIP backbones.

    Raises:
        KeyError: if no known image column ("image"/"img") or label column
            ("label"/"fine_label") exists.
    """

    def __init__(self, dataset_name, split, img_size=224,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
        f = self.ds.features

        if "image" in f:
            self.img_key = "image"
        elif "img" in f:
            self.img_key = "img"
        else:
            raise KeyError(f"No image column found. Available: {list(f.keys())}")

        if "label" in f:
            self.label_key = "label"
        elif "fine_label" in f:
            self.label_key = "fine_label"
        else:
            raise KeyError(f"No label column found. Available: {list(f.keys())}")

        self.tf = v2.Compose([
            v2.Resize(img_size, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(img_size),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=list(mean), std=list(std)),
        ])

        self._num_classes = self._infer_num_classes()

    def _infer_num_classes(self):
        """Return how many logits a classifier needs for this split's labels."""
        feature = self.ds.features[self.label_key]
        # A ClassLabel feature declares the full label vocabulary, with no data
        # scan and including classes that happen to be absent from this split.
        declared = getattr(feature, "num_classes", None)
        if declared:
            return int(declared)
        # Otherwise read only the label column. Iterating rows would decode
        # every image just to look at its label. Use max + 1 (not the number
        # of distinct ids) so every label is a valid cross-entropy target.
        labels = self.ds[self.label_key]
        if len(labels) == 0:
            return 0
        return int(max(labels)) + 1

    def __getitem__(self, i):
        x = self.ds[i]
        y = int(x[self.label_key])
        return self.tf(x[self.img_key].convert("RGB")), y

    def __len__(self):
        return len(self.ds)

    @property
    def num_classes(self):
        """Number of classes (logit width) inferred at construction time."""
        return self._num_classes


def compute_isotropy(X, eps=1e-9):
    """Score how evenly variance is spread over feature directions.

    Computes the biased covariance of the row-centered ``X`` and returns
    ``mean(eigs) / (std(eigs) + eps)`` of its (non-negative-clamped)
    eigenvalues. Higher means more isotropic.

    Args:
        X: 2-D tensor with one sample per row.
        eps: guards the division when all eigenvalues are equal.

    Returns:
        A Python float. Falls back to ``1.0`` if the eigensolver raises
        (e.g. it fails to converge).
    """
    Xc = X - X.mean(0, keepdim=True)
    cov = (Xc.t() @ Xc) / Xc.shape[0]
    try:
        eigs = torch.linalg.eigvalsh(cov)
    except RuntimeError:
        # torch.linalg.LinAlgError subclasses RuntimeError. Catch only solver
        # failures rather than everything (a bare except also swallowed
        # KeyboardInterrupt / SystemExit).
        return 1.0
    eigs = eigs.real.clamp(min=0.0)
    return (eigs.mean() / (eigs.std(unbiased=False) + eps)).item()


def closed_form_ridge(X, Y, reg=1e-3):
    """Fit ridge regression ``Y ~ X @ W + b`` in closed form.

    Both sides are mean-centered first, so the L2 penalty ``reg`` applies to
    ``W`` only and the intercept ``b`` is unpenalized.

    Returns:
        ``(W, b)``. For 2-D ``X`` (n, d) and ``Y`` (n, c), ``W`` is (d, c)
        and ``b`` is (c,).
    """
    x_mean = X.mean(0, keepdim=True)
    y_mean = Y.mean(0, keepdim=True)
    x_centered = X - x_mean
    y_centered = Y - y_mean

    gram = x_centered.t() @ x_centered
    penalty = reg * torch.eye(X.shape[1], device=X.device)
    W = torch.linalg.solve(gram + penalty, x_centered.t() @ y_centered)

    b = (y_mean - x_mean @ W).squeeze(0)
    return W, b


def pool_hidden_state(hidden_state, pooling="cls"):
    """Reduce a (batch, tokens, dim) hidden state to (batch, dim).

    ``"cls"`` selects the first token; any other value averages over all
    tokens.
    """
    if pooling != "cls":
        return hidden_state.mean(dim=1)
    return hidden_state[:, 0, :]


def collect_embeddings_for_loes(model, dataset, n_samples, batch_size, num_layers, pooling, device):
    """Gather pooled per-layer features on a random calibration subset.

    Draws up to ``n_samples`` distinct indices with Python's ``random``
    module, runs ``model`` in eval mode without gradients, and pools the
    output of each of the ``num_layers`` transformer layers
    (``hidden_states[1:]``; index 0 is the embedding output and is skipped).

    Returns:
        ``(embeddings, labels)``: a list with one CPU tensor per layer (rows
        are the sampled images, in loader order) and the concatenated labels.
    """
    model.eval()

    sample_size = min(n_samples, len(dataset))
    chosen = random.sample(range(len(dataset)), sample_size)
    calib_loader = DataLoader(
        Subset(dataset, chosen), batch_size,
        shuffle=False, num_workers=4, pin_memory=True
    )

    per_layer = [[] for _ in range(num_layers)]
    label_chunks = []

    with torch.no_grad():
        for images, targets in tqdm(calib_loader, desc="Collecting embeddings for LOES"):
            images = images.to(device)
            hidden = model(pixel_values=images, output_hidden_states=True).hidden_states

            for layer_idx, bucket in enumerate(per_layer):
                # hidden[0] is the embedding output; transformer layer i sits at i + 1.
                bucket.append(pool_hidden_state(hidden[layer_idx + 1], pooling).cpu())

            label_chunks.append(targets)

    embeddings = [torch.cat(chunks, dim=0) for chunks in per_layer]
    return embeddings, torch.cat(label_chunks)


def compute_loes_scores(embeddings, labels, alpha=1.0):
    """Score each layer's embeddings; lower is better.

    For every layer: ``score = ridge_mse + alpha * (1 - isotropy)``, where
    ``ridge_mse`` is the in-sample mean squared error of a closed-form ridge
    fit from the features to one-hot labels, and ``isotropy`` comes from
    ``compute_isotropy``.

    Args:
        embeddings: list of per-layer feature tensors (one row per sample).
        labels: integer class ids aligned with the rows (``F.one_hot``
            requires an int64 tensor).
        alpha: weight of the isotropy term.

    Returns:
        One dict per layer, in input order, with keys ``layer_idx``,
        ``loes_score``, ``ridge_loss`` and ``isotropy``.
    """
    n_classes = int(labels.max()) + 1
    targets = F.one_hot(labels, n_classes).float()

    def _score_layer(layer_idx, feats):
        feats = feats.float()
        W, b = closed_form_ridge(feats, targets)
        mse = ((feats @ W + b - targets) ** 2).mean().item()
        iso = compute_isotropy(feats)
        return {
            'layer_idx': layer_idx,
            'loes_score': mse + alpha * (1 - iso),
            'ridge_loss': mse,
            'isotropy': iso,
        }

    return [_score_layer(idx, feats) for idx, feats in enumerate(embeddings)]


def select_top_k_layers(layer_scores, k):
    """Return the ``layer_idx`` of the ``k`` entries with the lowest ``loes_score``.

    Results are ordered best-first. Ties keep their input order because the
    sort is stable. If ``k`` exceeds the number of entries, every layer is
    returned.
    """
    ranked = sorted(layer_scores, key=lambda entry: entry['loes_score'])
    winners = ranked[:k]
    return [entry['layer_idx'] for entry in winners]


class GeometricLoss(nn.Module):
    """GeoReg: a geometric regularizer on a batch of features.

    The returned value is ``weight * (iso_loss - log(area + 1e-6))``, where
    ``iso_loss`` is the variance of the feature-covariance eigenvalues
    (large when variance is concentrated in few directions) and ``area`` is
    the area of the triangle spanned by three randomly chosen class
    centroids. Minimizing it therefore favors a flatter spectrum and more
    spread-out centroids.

    A 0-d zero tensor is returned when the loss is disabled
    (``weight <= 0``), when the batch has fewer than three distinct classes,
    or when the eigensolver fails.
    """
    def __init__(self, weight=0.1):
        super().__init__()
        self.weight = weight

    def forward(self, feats, labels):
        """Compute the regularizer.

        Args:
            feats: features with one row per sample (``torch.cov(feats.T)``
                treats rows as observations).
            labels: class id per row of ``feats``.
        """
        if self.weight <= 0:
            return torch.tensor(0.0, device=feats.device)

        classes = torch.unique(labels)
        # A triangle needs three distinct class centroids. This also guarantees
        # that ``centroids`` below has at least three rows.
        if len(classes) < 3:
            return torch.tensor(0.0, device=feats.device)

        centroids = torch.stack([feats[labels == c].mean(0) for c in classes])

        # Triangle area from a random triplet of centroids. The Gram-determinant
        # form works in any dimension; the clamp keeps sqrt differentiable
        # when the points are (near-)collinear.
        idx = torch.randperm(len(centroids))[:3]
        a, b, c = centroids[idx[0]], centroids[idx[1]], centroids[idx[2]]
        ab, ac = a - b, a - c
        area = 0.5 * torch.sqrt((ab.pow(2).sum() * ac.pow(2).sum() - (ab * ac).sum().pow(2)).clamp(min=1e-6))

        # Isotropy term from the (ridge-stabilized) feature covariance spectrum.
        cov = torch.cov(feats.T) + 1e-4 * torch.eye(feats.shape[1], device=feats.device)
        try:
            eigs = torch.linalg.eigvalsh(cov)
        except RuntimeError:
            # torch.linalg.LinAlgError subclasses RuntimeError. Skip the
            # regularizer for this batch instead of swallowing every exception.
            return torch.tensor(0.0, device=feats.device)
        iso_loss = eigs.real.clamp(min=1e-6).var()

        return self.weight * (iso_loss - torch.log(area + 1e-6))


def create_model_with_lora(model_key, device="cuda"):
    """Load a pretrained backbone and wrap it with LoRA on every layer.

    The attention projections to adapt are read from
    ``MODELS[model_key]["lora_targets"]`` (falling back to query/value
    names for the model family) and expanded to fully-qualified module
    paths for each of the ``num_layers`` layers.

    Args:
        model_key: key into ``MODELS``; must contain "dinov2" or "clip".
        device: device the wrapped model is moved to.

    Returns:
        ``(peft_model, model_config)``.

    Raises:
        KeyError: if ``model_key`` is not in ``MODELS``.
        ValueError: if the key names an unsupported model family.
    """
    config = MODELS[model_key]
    model_name = config["name"]
    num_layers = config["num_layers"]

    print(f"  Loading {model_name}...")

    if "dinov2" in model_key:
        base_model = Dinov2Model.from_pretrained(model_name)
        layer_prefix = "encoder.layer.{i}.attention.attention"
        default_leaves = ("query", "value")
    elif "clip" in model_key:
        base_model = CLIPVisionModel.from_pretrained(model_name)
        layer_prefix = "encoder.layers.{i}.self_attn"
        default_leaves = ("q_proj", "v_proj")
    else:
        raise ValueError(f"Unknown model key: {model_key}")

    leaves = config.get("lora_targets", default_leaves)
    # Fully-qualified paths restrict LoRA to exactly these projections in
    # every layer, ordered layer by layer.
    target_modules = [
        f"{layer_prefix.format(i=i)}.{leaf}"
        for i in range(num_layers)
        for leaf in leaves
    ]

    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=target_modules,
        bias="none",
    )

    model = get_peft_model(base_model, lora_config)
    model = model.to(device)

    return model, config


class LastLayerClassifier(nn.Module):
    """Linear head on a single pooled feature vector.

    Applies LayerNorm -> Dropout -> Linear to features whose last dimension
    is ``input_dim`` and produces ``num_classes`` logits.
    """
    def __init__(self, input_dim, num_classes, dropout=0.2):
        super().__init__()
        norm = nn.LayerNorm(input_dim)
        drop = nn.Dropout(dropout)
        head = nn.Linear(input_dim, num_classes)
        self.classifier = nn.Sequential(norm, drop, head)

    def forward(self, x):
        """Return class logits for the pooled features ``x``."""
        logits = self.classifier(x)
        return logits


class MultiLayerConcatClassifier(nn.Module):
    """Fuse pooled features from several layers and classify the concatenation.

    Each selected layer has its own adapter (LayerNorm -> Linear to
    ``proj_dim`` -> GELU). The adapted vectors are concatenated
    (``proj_dim * num_layers`` wide) and fed to LayerNorm -> Dropout ->
    Linear.
    """
    def __init__(self, input_dim, num_classes, num_layers, proj_dim=256, dropout=0.2):
        super().__init__()

        # Per-layer adapters
        self.adapters = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(input_dim),
                nn.Linear(input_dim, proj_dim),
                nn.GELU()
            )
            for _ in range(num_layers)
        ])

        # Final classifier on concatenated features
        self.classifier = nn.Sequential(
            nn.LayerNorm(proj_dim * num_layers),
            nn.Dropout(dropout),
            nn.Linear(proj_dim * num_layers, num_classes)
        )

    def forward(self, layer_features):
        """Classify from per-layer features.

        Args:
            layer_features: one tensor per selected layer (last dim
                ``input_dim``), in the order the adapters correspond to.

        Returns:
            ``(logits, concat)``. ``concat`` is the concatenated adapter
            output, returned so a feature regularizer (GeoReg) can use it.

        Raises:
            ValueError: if the number of feature tensors does not match the
                number of adapters.
        """
        layer_features = list(layer_features)
        if len(layer_features) != len(self.adapters):
            raise ValueError(
                f"Expected {len(self.adapters)} layer feature tensors, "
                f"got {len(layer_features)}"
            )
        adapted = [adapter(feat) for adapter, feat in zip(self.adapters, layer_features)]
        concat = torch.cat(adapted, dim=-1)
        return self.classifier(concat), concat


def extract_layer_features(model, x, layer_indices, pooling, num_layers):
    """Run ``model`` once and pool the outputs of the requested layers.

    ``layer_indices`` are 0-based transformer-layer indices. Each is shifted
    by one because ``hidden_states[0]`` is the embedding output.
    ``num_layers`` is accepted for call-site symmetry but not used here.

    Returns:
        A list of pooled tensors in ``layer_indices`` order.
    """
    hidden_states = model(pixel_values=x, output_hidden_states=True).hidden_states
    return [pool_hidden_state(hidden_states[layer + 1], pooling) for layer in layer_indices]


def train_epoch_last_layer(model, classifier, train_loader, optimizer, scheduler,
                           pooling, device):
    """Run one training epoch using only the final layer's pooled features.

    Both modules are switched to train mode. The loss is plain
    cross-entropy, and ``optimizer`` and ``scheduler`` are each stepped once
    per batch.

    Returns:
        ``(mean_loss, accuracy)`` over every sample seen in the epoch.
    """
    for module in (model, classifier):
        module.train()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc="Training", leave=False)
    for images, targets in progress:
        images, targets = images.to(device), targets.to(device)

        hidden_states = model(pixel_values=images, output_hidden_states=True).hidden_states
        logits = classifier(pool_hidden_state(hidden_states[-1], pooling))
        loss = F.cross_entropy(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        n_batch = images.size(0)
        loss_sum += loss.item() * n_batch
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += n_batch

        progress.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*n_correct/n_seen:.2f}%'})

    return loss_sum / n_seen, n_correct / n_seen


def train_epoch_multi_layer(model, classifier, train_loader, optimizer, scheduler,
                            layer_indices, pooling, num_layers, geo_loss_fn, device):
    """Run one training epoch for the multi-layer concat classifier.

    The per-batch objective is cross-entropy on the logits plus
    ``geo_loss_fn`` applied to the concatenated adapter features.
    ``optimizer`` and ``scheduler`` are each stepped once per batch.

    Returns:
        ``(mean_loss, accuracy)`` over the epoch; the loss includes the
        regularizer term.
    """
    for module in (model, classifier):
        module.train()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc="Training", leave=False)
    for images, targets in progress:
        images, targets = images.to(device), targets.to(device)

        feats = extract_layer_features(model, images, layer_indices, pooling, num_layers)
        logits, fused = classifier(feats)

        loss = F.cross_entropy(logits, targets) + geo_loss_fn(fused, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        n_batch = images.size(0)
        loss_sum += loss.item() * n_batch
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += n_batch

        progress.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{100*n_correct/n_seen:.2f}%'})

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def evaluate_last_layer(model, classifier, loader, pooling, device):
    """Return the top-1 accuracy of the last-layer classifier over ``loader``.

    Runs without gradients with both modules in eval mode. Raises
    ZeroDivisionError if ``loader`` yields no samples.
    """
    model.eval()
    classifier.eval()

    hits = seen = 0
    for images, targets in tqdm(loader, desc="Evaluating", leave=False):
        images, targets = images.to(device), targets.to(device)
        final_hidden = model(pixel_values=images, output_hidden_states=True).hidden_states[-1]
        predictions = classifier(pool_hidden_state(final_hidden, pooling)).argmax(1)
        hits += (predictions == targets).sum().item()
        seen += images.size(0)

    return hits / seen


@torch.no_grad()
def evaluate_multi_layer(model, classifier, loader, layer_indices, pooling, num_layers, device):
    """Return the top-1 accuracy of the multi-layer concat classifier over ``loader``.

    Runs without gradients with both modules in eval mode. Raises
    ZeroDivisionError if ``loader`` yields no samples.
    """
    model.eval()
    classifier.eval()

    hits = seen = 0
    for images, targets in tqdm(loader, desc="Evaluating", leave=False):
        images, targets = images.to(device), targets.to(device)
        feats = extract_layer_features(model, images, layer_indices, pooling, num_layers)
        logits, _unused = classifier(feats)
        hits += (logits.argmax(1) == targets).sum().item()
        seen += images.size(0)

    return hits / seen

def run_lora_last_layer(model_key, train_loader, eval_loader, test_loader, num_classes, device):
    """Config 1: LoRA on every layer, classifier on the last layer only.

    Trains for ``EPOCHS`` epochs, keeps a CPU copy of backbone and
    classifier weights from the epoch with the highest validation accuracy,
    restores that copy, and then measures test accuracy once.

    Returns:
        Dict with config name, layers used, trainable-parameter counts, best
        validation accuracy, test accuracy and training wall-clock seconds.
    """
    def snapshot(module):
        # Independent CPU copy, so later optimizer steps cannot change it.
        return {name: tensor.cpu().clone() for name, tensor in module.state_dict().items()}

    def restore(module, state):
        module.load_state_dict({name: tensor.to(device) for name, tensor in state.items()})

    separator = "=" * 70
    print("\n" + separator)
    print("CONFIG 1: LoRA-All + Last Layer")
    print(separator)

    model, config = create_model_with_lora(model_key, device)
    model.print_trainable_parameters()
    pooling = config["pooling"]

    classifier = LastLayerClassifier(config["hidden_dim"], num_classes).to(device)

    lora_params = count_parameters(model, only_trainable=True)
    classifier_params = count_parameters(classifier, only_trainable=True)
    total_trainable = lora_params + classifier_params

    for label, count in (("LoRA params", lora_params),
                         ("Classifier params", classifier_params),
                         ("Total trainable", total_trainable)):
        print(f"  {label}: {format_params(count)}")

    param_groups = [
        {'params': module.parameters(), 'lr': LEARNING_RATE}
        for module in (model, classifier)
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)
    # One scheduler step per batch, so the cosine spans every batch of every epoch.
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * EPOCHS, eta_min=1e-6)

    best_val_acc = 0.0
    best_state = None
    start_time = time.time()

    for epoch in range(1, EPOCHS + 1):
        print(f"\n  Epoch {epoch}/{EPOCHS}")

        train_loss, train_acc = train_epoch_last_layer(
            model, classifier, train_loader, optimizer, scheduler, pooling, device
        )
        val_acc = evaluate_last_layer(model, classifier, eval_loader, pooling, device)

        print(f"    Train Loss: {train_loss:.4f}, Train Acc: {100*train_acc:.2f}%")
        print(f"    Val Acc: {100*val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {'model': snapshot(model), 'classifier': snapshot(classifier)}

    training_time = time.time() - start_time

    # Restore the best-validation weights before the single test evaluation.
    if best_state is not None:
        restore(model, best_state['model'])
        restore(classifier, best_state['classifier'])

    test_acc = evaluate_last_layer(model, classifier, test_loader, pooling, device)

    print(f"\n  Best Val Acc: {100*best_val_acc:.2f}%")
    print(f"  Test Acc: {100*test_acc:.2f}%")
    print(f"  Training Time: {training_time/60:.1f} min")

    results = dict(
        config='LoRA-All + Last Layer',
        layers_used=f"[{config['num_layers']-1}]",
        num_layers_fused=1,
        lora_params=lora_params,
        classifier_params=classifier_params,
        total_trainable_params=total_trainable,
        best_val_acc=best_val_acc,
        test_acc=test_acc,
        training_time_sec=training_time,
    )

    del model, classifier
    torch.cuda.empty_cache()

    return results


def run_lora_multi_layer(model_key, train_loader, eval_loader, test_loader, num_classes,
                         layer_indices, config_name, device):
    """Configs 2/3: LoRA on every layer, concat classifier over ``layer_indices``.

    Training adds ``GeometricLoss(weight=0.1)`` on the concatenated adapter
    features. The weights from the best-validation epoch are restored
    (from a CPU copy) before test accuracy is measured once.

    Returns:
        Dict with config name, layers used, trainable-parameter counts, best
        validation accuracy, test accuracy and training wall-clock seconds.
    """
    def snapshot(module):
        # Independent CPU copy, so later optimizer steps cannot change it.
        return {name: tensor.cpu().clone() for name, tensor in module.state_dict().items()}

    def restore(module, state):
        module.load_state_dict({name: tensor.to(device) for name, tensor in state.items()})

    separator = "=" * 70
    print("\n" + separator)
    print(f"CONFIG: {config_name}")
    print(f"  Layers: {layer_indices}")
    print(separator)

    model, config = create_model_with_lora(model_key, device)
    model.print_trainable_parameters()
    pooling = config["pooling"]
    depth = config["num_layers"]

    classifier = MultiLayerConcatClassifier(
        config["hidden_dim"], num_classes, len(layer_indices), PROJ_DIM
    ).to(device)

    geo_loss_fn = GeometricLoss(weight=0.1)

    lora_params = count_parameters(model, only_trainable=True)
    classifier_params = count_parameters(classifier, only_trainable=True)
    total_trainable = lora_params + classifier_params

    for label, count in (("LoRA params", lora_params),
                         ("Classifier params", classifier_params),
                         ("Total trainable", total_trainable)):
        print(f"  {label}: {format_params(count)}")

    param_groups = [
        {'params': module.parameters(), 'lr': LEARNING_RATE}
        for module in (model, classifier)
    ]
    optimizer = torch.optim.AdamW(param_groups, weight_decay=WEIGHT_DECAY)
    # One scheduler step per batch, so the cosine spans every batch of every epoch.
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * EPOCHS, eta_min=1e-6)

    best_val_acc = 0.0
    best_state = None
    start_time = time.time()

    for epoch in range(1, EPOCHS + 1):
        print(f"\n  Epoch {epoch}/{EPOCHS}")

        train_loss, train_acc = train_epoch_multi_layer(
            model, classifier, train_loader, optimizer, scheduler,
            layer_indices, pooling, depth, geo_loss_fn, device
        )
        val_acc = evaluate_multi_layer(
            model, classifier, eval_loader, layer_indices, pooling, depth, device
        )

        print(f"    Train Loss: {train_loss:.4f}, Train Acc: {100*train_acc:.2f}%")
        print(f"    Val Acc: {100*val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {'model': snapshot(model), 'classifier': snapshot(classifier)}

    training_time = time.time() - start_time

    # Restore the best-validation weights before the single test evaluation.
    if best_state is not None:
        restore(model, best_state['model'])
        restore(classifier, best_state['classifier'])

    test_acc = evaluate_multi_layer(
        model, classifier, test_loader, layer_indices, pooling, depth, device
    )

    print(f"\n  Best Val Acc: {100*best_val_acc:.2f}%")
    print(f"  Test Acc: {100*test_acc:.2f}%")
    print(f"  Training Time: {training_time/60:.1f} min")

    results = dict(
        config=config_name,
        layers_used=str(layer_indices),
        num_layers_fused=len(layer_indices),
        lora_params=lora_params,
        classifier_params=classifier_params,
        total_trainable_params=total_trainable,
        best_val_acc=best_val_acc,
        test_acc=test_acc,
        training_time_sec=training_time,
    )

    del model, classifier
    torch.cuda.empty_cache()

    return results


def run_model_dataset_experiments(model_key, dataset_config, device):
    """Run LOES scoring plus all three classifier configs for one model/dataset pair.

    Steps:
      1. Build train / validation / test loaders. If the dataset has no
         validation split, a seeded 10% slice of the training split is held
         out for best-epoch selection. The test split is therefore only
         used for the final reported accuracy.
      2. Score every layer of the frozen pretrained backbone with LOES on a
         calibration sample of the (remaining) training data, and save the
         scores to ``OUTPUT_DIR``.
      3. Train and evaluate Last-Layer, Last-K concat and LOES-K concat,
         re-seeding with ``SEED`` before each.

    Returns:
        The three result dicts, each tagged with ``model`` and ``dataset``.
    """
    model_config = MODELS[model_key]
    dataset_name = dataset_config['name']
    dataset_safe = dataset_name.split('/')[-1]
    num_layers = model_config["num_layers"]

    print("\n" + "#"*70)
    print(f"# MODEL: {model_key} | DATASET: {dataset_safe}")
    print("#"*70)

    # Load dataset
    print(f"\nLoading dataset {dataset_name}...")
    full_train_ds = ImageDataset(dataset_name, dataset_config['train'])
    num_classes = full_train_ds.num_classes

    if dataset_config['val']:
        train_ds = full_train_ds
        eval_ds = ImageDataset(dataset_name, dataset_config['val'])
    else:
        # No official validation split. Selecting the best epoch on the test
        # split would leak test information into the reported test accuracy,
        # so hold out a fixed, seeded slice of the training split instead.
        val_fraction = 0.1
        n_total = len(full_train_ds)
        n_val = max(1, int(n_total * val_fraction))
        perm = torch.randperm(n_total, generator=torch.Generator().manual_seed(SEED)).tolist()
        eval_ds = Subset(full_train_ds, perm[:n_val])
        train_ds = Subset(full_train_ds, perm[n_val:])
        print(f"  No validation split; held out {n_val} training samples for validation")

    print(f"  Train samples: {len(train_ds)}, Classes: {num_classes}")

    train_loader = DataLoader(
        train_ds, BATCH_SIZE, shuffle=True,
        drop_last=True, num_workers=4, pin_memory=True
    )
    eval_loader = DataLoader(eval_ds, TEST_BATCH_SIZE, num_workers=4, pin_memory=True)

    # Test loader
    test_ds = ImageDataset(dataset_name, dataset_config['test'])
    test_loader = DataLoader(test_ds, TEST_BATCH_SIZE, num_workers=4, pin_memory=True)

    print("\n" + "="*70)
    print("PHASE 1: LOES LAYER SCORING")
    print("="*70)

    # Load frozen base model for LOES
    if "dinov2" in model_key:
        base_model = Dinov2Model.from_pretrained(model_config["name"]).to(device)
    else:
        base_model = CLIPVisionModel.from_pretrained(model_config["name"]).to(device)

    base_model.eval()
    for p in base_model.parameters():
        p.requires_grad = False

    n_cal_samples = int(len(train_ds) * CALIBRATION_PCT)
    print(f"  Collecting embeddings from {n_cal_samples} samples...")

    embeddings, labels = collect_embeddings_for_loes(
        base_model, train_ds, n_cal_samples, BATCH_SIZE,
        num_layers, model_config["pooling"], device
    )

    print("  Computing LOES scores...")
    layer_scores = compute_loes_scores(embeddings, labels)

    # Print scores
    print("\n  Layer LOES Scores (lower = better):")
    sorted_scores = sorted(layer_scores, key=lambda x: x['loes_score'])
    for s in sorted_scores:
        print(f"    Layer {s['layer_idx']:2d}: score={s['loes_score']:.4f}, "
              f"loss={s['ridge_loss']:.4f}, iso={s['isotropy']:.4f}")

    # Get layer selections
    loes_layers = select_top_k_layers(layer_scores, K_LAYERS)
    last_k_layers = list(range(num_layers - K_LAYERS, num_layers))

    print(f"\n  LOES-{K_LAYERS} selected: {loes_layers}")
    print(f"  Last-{K_LAYERS} baseline: {last_k_layers}")

    # Save LOES scores (create the output dir in case we are not called via main()).
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    loes_csv = OUTPUT_DIR / f"loes_scores_{model_key}_{dataset_safe}.csv"
    with open(loes_csv, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['layer_idx', 'loes_score', 'ridge_loss', 'isotropy'])
        writer.writeheader()
        writer.writerows(sorted(layer_scores, key=lambda x: x['layer_idx']))
    print(f"  Saved LOES scores to {loes_csv}")

    del base_model, embeddings, labels
    torch.cuda.empty_cache()

    all_results = []

    # Config 1: LoRA-All + Last Layer
    set_seed(SEED)
    result1 = run_lora_last_layer(model_key, train_loader, eval_loader, test_loader, num_classes, device)
    result1['model'] = model_key
    result1['dataset'] = dataset_safe
    all_results.append(result1)

    # Config 2: LoRA-All + Last-K Concat
    set_seed(SEED)
    result2 = run_lora_multi_layer(
        model_key, train_loader, eval_loader, test_loader, num_classes,
        last_k_layers, f"LoRA-All + Last-{K_LAYERS} Concat", device
    )
    result2['model'] = model_key
    result2['dataset'] = dataset_safe
    all_results.append(result2)

    # Config 3: LoRA-All + LOES-K Concat
    set_seed(SEED)
    result3 = run_lora_multi_layer(
        model_key, train_loader, eval_loader, test_loader, num_classes,
        loes_layers, f"LoRA-All + LOES-{K_LAYERS} Concat", device
    )
    result3['model'] = model_key
    result3['dataset'] = dataset_safe
    result3['loes_selected_layers'] = loes_layers
    all_results.append(result3)

    print("\n" + "="*70)
    print(f"SUMMARY: {model_key} on {dataset_safe}")
    print("="*70)
    print(f"{'Config':<30} {'Params':<12} {'Val Acc':<10} {'Test Acc':<10} {'Time':<10}")
    print("-"*70)
    for r in all_results:
        print(f"{r['config']:<30} {format_params(r['total_trainable_params']):<12} "
              f"{100*r['best_val_acc']:.2f}%     {100*r['test_acc']:.2f}%     "
              f"{r['training_time_sec']/60:.1f} min")

    return all_results


def save_results_to_csv(results):
    """Append one CSV row per result dict to ``RESULTS_CSV``.

    Each row combines a run's metrics with the global hyper-parameters
    (seed, batch size, epochs, ...). Keys of a result that are not listed in
    ``fieldnames`` (e.g. ``loes_selected_layers``) are not written.

    The header is written when the file is missing or empty; the parent
    directory is created if needed.
    """
    fieldnames = [
        'timestamp', 'model', 'dataset', 'config', 'layers_used', 'num_layers_fused',
        'lora_params', 'classifier_params', 'total_trainable_params',
        'best_val_acc', 'test_acc', 'training_time_sec', 'training_time_min',
        'seed', 'batch_size', 'epochs', 'learning_rate', 'lora_rank', 'proj_dim'
    ]

    RESULTS_CSV.parent.mkdir(parents=True, exist_ok=True)
    # An existing but empty file (e.g. from an interrupted run) still needs a header.
    write_header = not RESULTS_CSV.exists() or RESULTS_CSV.stat().st_size == 0

    with open(RESULTS_CSV, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()

        for r in results:
            row = {
                'timestamp': get_timestamp(),
                'model': r['model'],
                'dataset': r['dataset'],
                'config': r['config'],
                'layers_used': r['layers_used'],
                'num_layers_fused': r['num_layers_fused'],
                'lora_params': r['lora_params'],
                'classifier_params': r['classifier_params'],
                'total_trainable_params': r['total_trainable_params'],
                'best_val_acc': round(r['best_val_acc'], 6),
                'test_acc': round(r['test_acc'], 6),
                'training_time_sec': round(r['training_time_sec'], 2),
                'training_time_min': round(r['training_time_sec'] / 60, 2),
                'seed': SEED,
                'batch_size': BATCH_SIZE,
                'epochs': EPOCHS,
                'learning_rate': LEARNING_RATE,
                'lora_rank': LORA_RANK,
                'proj_dim': PROJ_DIM
            }
            writer.writerow(row)

    print(f"\nResults appended to {RESULTS_CSV}")


def main():
    """Entry point: run every model/dataset pair and record the results.

    Results for each pair are appended to ``RESULTS_CSV`` as soon as that
    pair finishes. A per-pair test-accuracy summary is printed at the end.
    """
    set_seed(SEED)
    OUTPUT_DIR.mkdir(exist_ok=True)

    rule = "=" * 70
    print(rule)
    print("LoRA + LOES v2 EXPERIMENT")
    print(rule)
    print("Testing: LOES layer selection for classification head")
    for label, value in (
        ("Device", DEVICE),
        ("Batch size", BATCH_SIZE),
        ("Epochs", EPOCHS),
        ("K layers", K_LAYERS),
        ("LoRA rank", LORA_RANK),
        ("Projection dim", PROJ_DIM),
    ):
        print(f"{label}: {value}")
    print(rule)

    # Order: DINOv2 Stanford Cars -> CLIP Stanford Cars -> DINOv2 DTD -> CLIP DTD
    cars, dtd = DATASETS[0], DATASETS[1]
    experiment_order = [
        ("dinov2-large", cars),
        ("clip-vit-b32", cars),
        ("dinov2-large", dtd),
        ("clip-vit-b32", dtd),
    ]

    all_results = []
    for model_key, dataset_config in experiment_order:
        pair_results = run_model_dataset_experiments(model_key, dataset_config, DEVICE)
        all_results.extend(pair_results)
        # Persist after every pair so a later crash does not lose finished work.
        save_results_to_csv(pair_results)

    hashes = "#" * 70
    print("\n" + hashes)
    print("# FINAL SUMMARY")
    print(hashes)

    for model_key in ("dinov2-large", "clip-vit-b32"):
        for ds in DATASETS:
            ds_name = ds['name'].split('/')[-1]
            matching = [
                r for r in all_results
                if r['model'] == model_key and r['dataset'] == ds_name
            ]
            if not matching:
                continue
            print(f"\n{model_key} + {ds_name}:")
            print(f"  {'Config':<30} {'Test Acc':<10}")
            print(f"  {'-'*45}")
            for r in matching:
                print(f"  {r['config']:<30} {100*r['test_acc']:.2f}%")

    print("\n" + rule)
    print("EXPERIMENT COMPLETE!")
    print(f"Results saved to: {RESULTS_CSV}")
    print(rule)


if __name__ == "__main__":
    # Run the full experiment sweep only when executed as a script, not on import.
    main()
