"""
LoRA + LOES experiment on DINOv2-Large. Four configurations are compared:
1. LoRA-All: LoRA on all 24 layers (rank 8)
2. LoRA-Last-4: LoRA on last 4 layers only (rank 8)  
3. LoRA-LOES-4: LoRA on top-4 LOES-scoring layers (rank 8)
4. LoRA-LOES-Weighted: LoRA on all layers with rank proportional to LOES score

All trained with frozen backbone + LoRA adapters + linear classifier on last layer.
"""

import os
import csv
import time
import random
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.transforms import v2

import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoImageProcessor, Dinov2Model, Dinov2Config
from peft import LoraConfig, get_peft_model


# Seed handed to set_seed() in main() and recorded in every results-CSV row.
SEED = 42
# Device string passed to every model/tensor `.to(...)`; hard-coded to CUDA.
DEVICE = "cuda"
# Batch size of the training loader and of the LOES calibration loader.
BATCH_SIZE = 128
# Batch size of the evaluation and test loaders.
TEST_BATCH_SIZE = 128
# Number of training epochs per LoRA configuration.
EPOCHS = 15
# AdamW learning rate (shared by the LoRA and classifier parameter groups).
LEARNING_RATE = 1e-4
# AdamW weight decay.
WEIGHT_DECAY = 1e-4
# Fraction of the training set sampled for LOES layer scoring.
CALIBRATION_PCT = 0.2

# DINOv2-Large specs
MODEL_NAME = "facebook/dinov2-large"
# Number of encoder layers the script scores and targets (indices 0..23).
NUM_LAYERS = 24
# Width of the CLS embedding fed to ClassifierHead.
HIDDEN_DIM = 1024

# LoRA configuration
LORA_RANK_DEFAULT = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.1

# Output
OUTPUT_DIR = Path("lora_loes_results")
RESULTS_CSV = OUTPUT_DIR / "lora_loes_results.csv"

# Datasets to run
# Each entry names a Hugging Face dataset and its split names. A "val" of
# None makes run_dataset_experiments() evaluate on the "test" split instead.
DATASETS = [
    {"name": "tanganke/stanford_cars", "train": "train", "val": None, "test": "test"},
    {"name": "tanganke/sun397", "train": "train", "val": None, "test": "test"},
]

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed: Value applied to every random number generator.
    """
    # Apply the same seed to each library's generator, in a fixed order.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Request deterministic cuDNN kernels and disable the autotuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def count_parameters(model, only_trainable=True):
    """Return the number of scalar parameters held by ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        only_trainable: When true, parameters with ``requires_grad`` switched
            off are left out of the count.
    """
    total = 0
    for param in model.parameters():
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total


def format_params(num_params):
    """Render a parameter count as ``x.xxM``, ``x.xxK`` or the plain number."""
    # Largest unit first so a million-scale count never falls into the K branch.
    for threshold, suffix in ((1e6, "M"), (1e3, "K")):
        if num_params >= threshold:
            return f"{num_params / threshold:.2f}{suffix}"
    return str(num_params)


class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset wrapping one split of a Hugging Face image dataset.

    Each item is ``(image_tensor, label)``: the image is converted to RGB,
    resized, center-cropped to ``img_size``, scaled to float32 and normalized
    with the mean/std constants given below; the label is cast to ``int``.

    Raises:
        KeyError: If the split has neither an ``image``/``img`` column or
            neither a ``label``/``fine_label`` column.
    """

    def __init__(self, dataset_name, split, img_size=224):
        # NOTE(review): trust_remote_code=True allows the dataset repository to
        # run its own loading script — only use with trusted dataset sources.
        self.ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
        f = self.ds.features

        # Find image and label columns
        if "image" in f:
            self.img_key = "image"
        elif "img" in f:
            self.img_key = "img"
        else:
            raise KeyError(f"No image column found. Available: {list(f.keys())}")

        if "label" in f:
            self.label_key = "label"
        elif "fine_label" in f:
            self.label_key = "fine_label"
        else:
            raise KeyError(f"No label column found. Available: {list(f.keys())}")

        self.tf = v2.Compose([
            v2.Resize(img_size),
            v2.CenterCrop(img_size),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Lazily computed by the num_classes property.
        self._num_classes = None

    def __getitem__(self, i):
        """Return ``(transformed_image, int_label)`` for row ``i``."""
        x = self.ds[i]
        y = int(x[self.label_key])
        return self.tf(x[self.img_key].convert("RGB")), y

    def __len__(self):
        """Return the number of rows in the wrapped split."""
        return len(self.ds)

    @property
    def num_classes(self):
        """Number of classes, computed once and cached.

        Uses the label feature's ``num_classes`` attribute when it has one;
        otherwise counts the distinct values of the label column.
        """
        if self._num_classes is None:
            feat = self.ds.features[self.label_key]
            if hasattr(feat, 'num_classes'):
                self._num_classes = feat.num_classes
            else:
                # Query only the label column. Iterating full rows would
                # decode every image just to read its label.
                self._num_classes = len(self.ds.unique(self.label_key))
        return self._num_classes


def compute_isotropy(X, eps=1e-9):
    """Compute an isotropy score for a matrix of embeddings.

    The score is ``mean(eigs) / (std(eigs) + eps)`` where ``eigs`` are the
    eigenvalues (clamped at zero) of the covariance of the row-centered ``X``.

    Args:
        X: 2-D tensor with one embedding per row.
        eps: Small constant keeping the denominator non-zero.

    Returns:
        The score as a Python float, or ``1.0`` if the eigenvalue
        decomposition fails.
    """
    Xc = X - X.mean(0, keepdim=True)
    cov = (Xc.t() @ Xc) / Xc.shape[0]
    try:
        eigs = torch.linalg.eigvalsh(cov).real.clamp(min=0.0)
    except RuntimeError:
        # torch.linalg.LinAlgError derives from RuntimeError. Catching only
        # this keeps the best-effort fallback without swallowing
        # KeyboardInterrupt or unrelated programming errors.
        return 1.0
    return (eigs.mean() / (eigs.std(unbiased=False) + eps)).item()


def closed_form_ridge(X, Y, reg=1e-3):
    """Solve ridge regression of ``Y`` on ``X`` in closed form.

    Both inputs are mean-centered before solving
    ``(Xc^T Xc + reg * I) W = Xc^T Yc``; the bias is then recovered from the
    column means.

    Args:
        X: Feature matrix, one sample per row.
        Y: Target matrix, one sample per row.
        reg: Ridge penalty added to the diagonal of the Gram matrix.

    Returns:
        ``(W, b)`` such that predictions are ``X @ W + b``.
    """
    x_mean = X.mean(0, keepdim=True)
    y_mean = Y.mean(0, keepdim=True)
    x_centered = X - x_mean
    y_centered = Y - y_mean

    gram = x_centered.t() @ x_centered
    penalty = reg * torch.eye(X.shape[1], device=X.device)
    W = torch.linalg.solve(gram + penalty, x_centered.t() @ y_centered)

    # Intercept that maps the feature means onto the target means.
    b = (y_mean - x_mean @ W).squeeze(0)
    return W, b


def collect_embeddings_for_loes(model, dataset, n_samples, batch_size, device):
    """Gather per-layer CLS embeddings on a random subset of ``dataset``.

    Args:
        model: Model called as ``model(pixel_values=..., output_hidden_states=True)``.
        dataset: Map-style dataset yielding ``(image, label)`` pairs.
        n_samples: Number of items to sample (capped at ``len(dataset)``).
        batch_size: Batch size of the internal loader.
        device: Device the image batches are moved to.

    Returns:
        ``(embeddings, labels)``: ``embeddings`` is a list of ``NUM_LAYERS`` CPU
        tensors (one per layer, rows aligned with ``labels``) and ``labels`` is
        the concatenation of the label batches.
    """
    model.eval()

    # Sample indices without replacement using Python's global RNG.
    chosen = random.sample(range(len(dataset)), min(n_samples, len(dataset)))
    subset_loader = DataLoader(
        Subset(dataset, chosen), batch_size, shuffle=False, num_workers=4
    )

    per_layer = [[] for _ in range(NUM_LAYERS)]
    label_batches = []

    with torch.no_grad():
        for images, targets in tqdm(subset_loader, desc="Collecting embeddings for LOES"):
            outputs = model(pixel_values=images.to(device), output_hidden_states=True)
            states = outputs.hidden_states

            # Index 0 of hidden_states is skipped (the original author notes it
            # is the embedding output), so layer i reads states[i + 1].
            for state_idx, bucket in enumerate(per_layer, start=1):
                bucket.append(states[state_idx][:, 0, :].cpu())  # token 0 = CLS

            label_batches.append(targets)

    embeddings = [torch.cat(chunks, dim=0) for chunks in per_layer]
    labels = torch.cat(label_batches)
    return embeddings, labels


def compute_loes_scores(embeddings, labels, alpha=1.0):
    """Score every layer's embeddings with the LOES criterion.

    For each layer a ridge regressor is fitted onto one-hot labels; the score
    is ``ridge_mse + alpha * (1 - isotropy)``. Lower scores are treated as
    better by the callers.

    Args:
        embeddings: Sequence of per-layer embedding matrices.
        labels: Integer class labels; the class count is ``labels.max() + 1``.
        alpha: Weight of the isotropy term.

    Returns:
        One dict per layer, in input order, with keys ``layer_idx``,
        ``loes_score``, ``ridge_loss`` and ``isotropy``.
    """
    n_classes = int(labels.max()) + 1
    targets = F.one_hot(labels, n_classes).float()

    results = []
    for idx, layer_feats in enumerate(embeddings):
        feats = layer_feats.float()

        # Linear probe fitted in closed form, evaluated on the same samples.
        W, b = closed_form_ridge(feats, targets)
        residual = (feats @ W + b) - targets
        mse = (residual ** 2).mean().item()

        isotropy = compute_isotropy(feats)

        results.append(dict(
            layer_idx=idx,
            loes_score=mse + alpha * (1 - isotropy),
            ridge_loss=mse,
            isotropy=isotropy,
        ))

    return results


def select_top_k_layers(layer_scores, k):
    """Return the ``layer_idx`` of the ``k`` entries with the lowest LOES score.

    Ties keep their input order (the sort is stable); the input is not mutated.
    """
    ranked = list(layer_scores)
    ranked.sort(key=lambda entry: entry['loes_score'])
    return [entry['layer_idx'] for entry in ranked[:k]]


def compute_weighted_ranks(layer_scores):
    """Assign a LoRA rank to each layer from its position in the LOES ranking.

    Layers are sorted by ascending ``loes_score`` (lower is better), then:

    - positions 0-5   -> rank 16
    - positions 6-17  -> rank 8
    - positions 18+   -> rank 4

    With 24 layers the ranks sum to 216, versus 192 for a uniform rank 8.

    Returns:
        Dict mapping ``layer_idx`` to its rank.
    """
    def _rank_for(position):
        # Tiers are checked from the worst-scoring end upwards.
        if position >= 18:
            return 4
        if position >= 6:
            return 8
        return 16

    ordered = sorted(layer_scores, key=lambda entry: entry['loes_score'])
    return {
        entry['layer_idx']: _rank_for(position)
        for position, entry in enumerate(ordered)
    }


def get_target_modules_for_layers(layer_indices):
    """Build the attention query/value module names for the given layers.

    For every index the names
    ``encoder.layer.{idx}.attention.attention.query`` and ``...value`` are
    emitted, in that order.
    """
    return [
        f"encoder.layer.{idx}.attention.attention.{projection}"
        for idx in layer_indices
        for projection in ("query", "value")
    ]


def create_dinov2_with_lora(
    layer_indices=None,  # If None, apply to all layers
    rank=8,
    rank_pattern=None,  # Dict mapping layer_idx -> rank (for weighted)
    device="cuda"
):
    """Load the pretrained backbone and wrap it with LoRA adapters.

    Adapters target the attention query and value projections of the selected
    encoder layers.

    Args:
        layer_indices: Layer indices that receive LoRA. ``None`` selects every
            layer in ``range(NUM_LAYERS)``.
        rank: LoRA rank used unless overridden by ``rank_pattern``.
        rank_pattern: Optional dict ``layer_idx -> rank``. When given, it is
            expanded to per-module entries and passed to the config together
            with an ``alpha_pattern`` that sets ``LORA_ALPHA`` for each of them.
        device: Device the wrapped model is moved to.

    Returns:
        The PEFT-wrapped model on ``device``.
    """
    print(f"Loading {MODEL_NAME}...")
    backbone = Dinov2Model.from_pretrained(MODEL_NAME)

    selected_layers = list(range(NUM_LAYERS)) if layer_indices is None else layer_indices

    # Settings shared by the uniform-rank and variable-rank configurations.
    config_kwargs = dict(
        r=rank,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=get_target_modules_for_layers(selected_layers),
        bias="none",
    )

    if rank_pattern is not None:
        # NOTE(review): how PEFT matches rank_pattern keys against module names
        # depends on the installed version — confirm the per-layer ranks are
        # actually applied (e.g. via print_trainable_parameters()).
        per_module_rank = {
            f"encoder.layer.{layer_idx}.attention.attention.{projection}": layer_rank
            for layer_idx, layer_rank in rank_pattern.items()
            for projection in ("query", "value")
        }
        config_kwargs["rank_pattern"] = per_module_rank
        config_kwargs["alpha_pattern"] = dict.fromkeys(per_module_rank, LORA_ALPHA)

    peft_model = get_peft_model(backbone, LoraConfig(**config_kwargs))
    return peft_model.to(device)

class ClassifierHead(nn.Module):
    """Classification head: LayerNorm, then Dropout, then a Linear projection.

    Args:
        input_dim: Size of the incoming feature vector.
        num_classes: Number of output logits.
        dropout: Dropout probability applied before the linear layer.
    """

    def __init__(self, input_dim, num_classes, dropout=0.2):
        super().__init__()
        stages = (
            nn.LayerNorm(input_dim),
            nn.Dropout(dropout),
            nn.Linear(input_dim, num_classes),
        )
        # Kept as a Sequential named `classifier` so state_dict keys are stable.
        self.classifier = nn.Sequential(*stages)

    def forward(self, x):
        """Return the class logits for features ``x``."""
        logits = self.classifier(x)
        return logits



def train_one_epoch(model, classifier, train_loader, optimizer, scheduler, device):
    """Run one optimisation pass over ``train_loader``.

    The scheduler is stepped after every batch, not once per epoch.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen.
    """
    model.train()
    classifier.train()

    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc="Training")
    for images, targets in progress:
        images = images.to(device)
        targets = targets.to(device)
        n_batch = images.size(0)

        # CLS token (position 0) of the final hidden state feeds the classifier.
        features = model(pixel_values=images).last_hidden_state[:, 0, :]
        logits = classifier(features)
        loss = F.cross_entropy(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Weight the batch loss by its size so the epoch mean is per-sample.
        batch_loss = loss.item()
        loss_sum += batch_loss * n_batch
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += n_batch

        progress.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{100*n_correct/n_seen:.2f}%'
        })

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, classifier, loader, device):
    """Return top-1 accuracy of ``model`` + ``classifier`` over ``loader``.

    Both modules are put into eval mode and gradients are disabled.
    """
    model.eval()
    classifier.eval()

    hits = 0
    seen = 0

    with torch.no_grad():
        for images, targets in tqdm(loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            # Classify from the CLS token of the last hidden state.
            features = model(pixel_values=images).last_hidden_state[:, 0, :]
            predictions = classifier(features).argmax(1)

            hits += (predictions == targets).sum().item()
            seen += images.size(0)

    return hits / seen


def run_training(
    model, 
    classifier, 
    train_loader, 
    eval_loader, 
    test_loader,
    epochs,
    device
):
    """Train the LoRA model and classifier, then score the best checkpoint.

    After every epoch the pair is evaluated on ``eval_loader``; the weights
    with the highest accuracy so far are snapshotted on CPU. After training
    the snapshot is restored and evaluated on ``test_loader`` (or on
    ``eval_loader`` when ``test_loader`` is ``None``).

    Args:
        model: Backbone with LoRA adapters.
        classifier: Head applied to the CLS token.
        train_loader: Loader of training batches.
        eval_loader: Loader used for per-epoch model selection.
        test_loader: Loader for the final evaluation, or ``None``.
        epochs: Number of training epochs.
        device: Device batches and restored weights are moved to.

    Returns:
        Dict with ``best_val_acc``, ``test_acc``, ``lora_params``,
        ``classifier_params``, ``total_trainable_params`` and
        ``training_time_sec``.
    """
    # Count trainable parameters
    lora_params = count_parameters(model, only_trainable=True)
    classifier_params = count_parameters(classifier, only_trainable=True)
    total_trainable = lora_params + classifier_params
    
    print(f"  LoRA params: {format_params(lora_params)}")
    print(f"  Classifier params: {format_params(classifier_params)}")
    print(f"  Total trainable: {format_params(total_trainable)}")
    
    # Optimizer
    optimizer = torch.optim.AdamW([
        {'params': model.parameters(), 'lr': LEARNING_RATE},
        {'params': classifier.parameters(), 'lr': LEARNING_RATE}
    ], weight_decay=WEIGHT_DECAY)
    
    # The scheduler is stepped once per batch, so the cosine period is the
    # total number of optimisation steps.
    scheduler = CosineAnnealingLR(
        optimizer, 
        T_max=len(train_loader) * epochs,
        eta_min=1e-6
    )
    
    # Training
    best_acc = 0.0
    best_state = None
    
    start_time = time.time()
    
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        
        train_loss, train_acc = train_one_epoch(
            model, classifier, train_loader, optimizer, scheduler, device
        )
        
        val_acc = evaluate(model, classifier, eval_loader, device)
        
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {100*train_acc:.2f}%")
        print(f"  Val Acc: {100*val_acc:.2f}%")
        
        if val_acc > best_acc:
            best_acc = val_acc
            # Snapshot on CPU so the copy does not occupy GPU memory.
            best_state = {
                'model': {k: v.cpu().clone() for k, v in model.state_dict().items()},
                'classifier': {k: v.cpu().clone() for k, v in classifier.state_dict().items()}
            }
    
    training_time = time.time() - start_time
    
    # Final test evaluation
    # best_state stays None if accuracy never rose above 0.0; the last-epoch
    # weights are then evaluated as they are.
    if best_state is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_state['model'].items()})
        classifier.load_state_dict({k: v.to(device) for k, v in best_state['classifier'].items()})
    
    # Compare against None explicitly: truth-testing a DataLoader calls its
    # __len__, which treats an empty loader as "missing" and raises for
    # datasets without a length.
    final_loader = test_loader if test_loader is not None else eval_loader
    test_acc = evaluate(model, classifier, final_loader, device)
    
    print(f"\n  Best Val Acc: {100*best_acc:.2f}%")
    print(f"  Test Acc: {100*test_acc:.2f}%")
    print(f"  Training Time: {training_time/60:.1f} min")
    
    return {
        'best_val_acc': best_acc,
        'test_acc': test_acc,
        'lora_params': lora_params,
        'classifier_params': classifier_params,
        'total_trainable_params': total_trainable,
        'training_time_sec': training_time
    }


def run_lora_all(train_ds, train_loader, eval_loader, test_loader, num_classes, device):
    """Config A: train with LoRA (default rank) on every encoder layer.

    ``train_ds`` is accepted for signature parity with the other configs but
    is not used here.

    Returns:
        The ``run_training`` result dict extended with ``config``, ``layers``,
        ``rank`` and ``num_lora_layers``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("CONFIG A: LoRA-All (all 24 layers, rank 8)")
    print(rule)

    lora_model = create_dinov2_with_lora(
        layer_indices=None,  # None selects all layers
        rank=LORA_RANK_DEFAULT,
        device=device,
    )
    lora_model.print_trainable_parameters()

    head = ClassifierHead(HIDDEN_DIM, num_classes).to(device)

    results = run_training(
        lora_model, head, train_loader, eval_loader, test_loader, EPOCHS, device
    )
    results.update({
        'config': 'LoRA-All',
        'layers': 'all (0-23)',
        'rank': LORA_RANK_DEFAULT,
        'num_lora_layers': NUM_LAYERS,
    })

    # Drop the references before asking CUDA to release cached blocks.
    del lora_model, head
    torch.cuda.empty_cache()

    return results


def run_lora_last_k(train_ds, train_loader, eval_loader, test_loader, num_classes, k, device):
    """Config B: train with LoRA (default rank) on the last ``k`` layers only.

    ``train_ds`` is accepted for signature parity with the other configs but
    is not used here.

    Returns:
        The ``run_training`` result dict extended with ``config``, ``layers``,
        ``rank`` and ``num_lora_layers``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(f"CONFIG B: LoRA-Last-{k} (last {k} layers, rank 8)")
    print(rule)

    # Highest-indexed k layers, in ascending order.
    tail_layers = list(range(NUM_LAYERS - k, NUM_LAYERS))
    print(f"  Target layers: {tail_layers}")

    lora_model = create_dinov2_with_lora(
        layer_indices=tail_layers,
        rank=LORA_RANK_DEFAULT,
        device=device,
    )
    lora_model.print_trainable_parameters()

    head = ClassifierHead(HIDDEN_DIM, num_classes).to(device)

    results = run_training(
        lora_model, head, train_loader, eval_loader, test_loader, EPOCHS, device
    )
    results.update({
        'config': f'LoRA-Last-{k}',
        'layers': str(tail_layers),
        'rank': LORA_RANK_DEFAULT,
        'num_lora_layers': k,
    })

    # Drop the references before asking CUDA to release cached blocks.
    del lora_model, head
    torch.cuda.empty_cache()

    return results


def run_lora_loes_k(train_ds, train_loader, eval_loader, test_loader, num_classes, k, layer_scores, device):
    """Config C: train with LoRA (default rank) on the ``k`` best LOES layers.

    ``train_ds`` is accepted for signature parity with the other configs but
    is not used here.

    Returns:
        The ``run_training`` result dict extended with ``config``, ``layers``,
        ``rank``, ``num_lora_layers`` and ``loes_selected_layers``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(f"CONFIG C: LoRA-LOES-{k} (top {k} LOES layers, rank 8)")
    print(rule)

    chosen_layers = select_top_k_layers(layer_scores, k)
    print(f"  LOES-selected layers: {chosen_layers}")

    # Report the score entry behind each selected layer.
    for idx in chosen_layers:
        entry = next(s for s in layer_scores if s['layer_idx'] == idx)
        print(f"    Layer {idx}: LOES={entry['loes_score']:.4f}, iso={entry['isotropy']:.4f}")

    lora_model = create_dinov2_with_lora(
        layer_indices=chosen_layers,
        rank=LORA_RANK_DEFAULT,
        device=device,
    )
    lora_model.print_trainable_parameters()

    head = ClassifierHead(HIDDEN_DIM, num_classes).to(device)

    results = run_training(
        lora_model, head, train_loader, eval_loader, test_loader, EPOCHS, device
    )
    results.update({
        'config': f'LoRA-LOES-{k}',
        'layers': str(chosen_layers),
        'rank': LORA_RANK_DEFAULT,
        'num_lora_layers': k,
        'loes_selected_layers': chosen_layers,
    })

    # Drop the references before asking CUDA to release cached blocks.
    del lora_model, head
    torch.cuda.empty_cache()

    return results


def run_lora_loes_weighted(train_ds, train_loader, eval_loader, test_loader, num_classes, layer_scores, device):
    """Config D: LoRA on every layer with a per-layer rank chosen by LOES rank.

    ``train_ds`` is accepted for signature parity with the other configs but
    is not used here.

    Returns:
        The ``run_training`` result dict extended with ``config``, ``layers``,
        ``rank``, ``num_lora_layers`` and ``rank_pattern``.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("CONFIG D: LoRA-LOES-Weighted (all layers, variable rank)")
    print(rule)

    ranks_by_layer = compute_weighted_ranks(layer_scores)

    # Show the assigned rank next to each layer's score, in layer order.
    print("  Rank assignments:")
    for idx in range(NUM_LAYERS):
        entry = next(s for s in layer_scores if s['layer_idx'] == idx)
        print(f"    Layer {idx}: rank={ranks_by_layer[idx]}, LOES={entry['loes_score']:.4f}")

    lora_model = create_dinov2_with_lora(
        layer_indices=list(range(NUM_LAYERS)),
        rank=LORA_RANK_DEFAULT,
        rank_pattern=ranks_by_layer,
        device=device,
    )
    lora_model.print_trainable_parameters()

    head = ClassifierHead(HIDDEN_DIM, num_classes).to(device)

    results = run_training(
        lora_model, head, train_loader, eval_loader, test_loader, EPOCHS, device
    )
    results.update({
        'config': 'LoRA-LOES-Weighted',
        'layers': 'all (variable rank)',
        'rank': 'variable (4/8/16)',
        'num_lora_layers': NUM_LAYERS,
        'rank_pattern': ranks_by_layer,
    })

    # Drop the references before asking CUDA to release cached blocks.
    del lora_model, head
    torch.cuda.empty_cache()

    return results


def run_dataset_experiments(dataset_config, device):
    """Run LOES scoring and all 4 LoRA configurations on a single dataset.

    Args:
        dataset_config: Dict with keys ``name``, ``train``, ``val`` and
            ``test`` (a dataset identifier and its split names; ``val`` and
            ``test`` may be ``None``).
        device: Device used for the scoring model and for training.

    Returns:
        ``(all_results, layer_scores)``: the four per-config result dicts
        (each tagged with ``dataset``) and the per-layer LOES score dicts.

    Side effects:
        Writes ``loes_scores_<dataset>.csv`` into ``OUTPUT_DIR``, which must
        already exist, and prints progress and a summary table.
    """
    dataset_name = dataset_config['name']
    # Last path component of the dataset id, used in file names and reports.
    dataset_safe = dataset_name.split('/')[-1]
    
    print("\n" + "#"*80)
    print(f"# DATASET: {dataset_safe}")
    print("#"*80)
    
    # Load dataset
    print(f"\nLoading dataset {dataset_name}...")
    train_ds = ImageDataset(dataset_name, dataset_config['train'])
    num_classes = train_ds.num_classes
    print(f"  Train samples: {len(train_ds)}, Classes: {num_classes}")
    
    # drop_last=True: a trailing partial batch is discarded every epoch.
    train_loader = DataLoader(
        train_ds, BATCH_SIZE, shuffle=True, 
        drop_last=True, num_workers=4, pin_memory=True
    )
    
    # Eval loader
    if dataset_config['val']:
        eval_ds = ImageDataset(dataset_name, dataset_config['val'])
        eval_loader = DataLoader(eval_ds, TEST_BATCH_SIZE, num_workers=4, pin_memory=True)
    else:
        # Use test as eval
        # NOTE(review): in this branch the checkpoint chosen by run_training is
        # selected on the test split, so the reported test accuracy is
        # optimistically biased — confirm this is intended.
        eval_ds = ImageDataset(dataset_name, dataset_config['test'])
        eval_loader = DataLoader(eval_ds, TEST_BATCH_SIZE, num_workers=4, pin_memory=True)
    
    # Test loader
    # NOTE(review): when 'val' is None the test split is wrapped a second time
    # here rather than reusing eval_ds.
    if dataset_config['test']:
        test_ds = ImageDataset(dataset_name, dataset_config['test'])
        test_loader = DataLoader(test_ds, TEST_BATCH_SIZE, num_workers=4, pin_memory=True)
    else:
        test_loader = None
    
    print("\n" + "="*80)
    print("PHASE 1: LOES LAYER SCORING")
    print("="*80)
    
    # Load base model for embedding extraction (no LoRA yet)
    base_model = Dinov2Model.from_pretrained(MODEL_NAME).to(device)
    base_model.eval()
    for p in base_model.parameters():
        p.requires_grad = False
    
    # Calibration subset size: a fixed fraction of the training set.
    n_cal_samples = int(len(train_ds) * CALIBRATION_PCT)
    print(f"Collecting embeddings from {n_cal_samples} samples...")
    
    embeddings, labels = collect_embeddings_for_loes(
        base_model, train_ds, n_cal_samples, BATCH_SIZE, device
    )
    
    print("Computing LOES scores...")
    layer_scores = compute_loes_scores(embeddings, labels)
    
    # Print scores
    print("\nLayer LOES Scores (lower = better):")
    sorted_scores = sorted(layer_scores, key=lambda x: x['loes_score'])
    for s in sorted_scores:
        print(f"  Layer {s['layer_idx']:2d}: score={s['loes_score']:.4f}, loss={s['ridge_loss']:.4f}, iso={s['isotropy']:.4f}")
    
    # Save LOES scores
    # The file is opened in 'w' mode, so it is overwritten on every run; rows
    # are written in layer order.
    loes_csv = OUTPUT_DIR / f"loes_scores_{dataset_safe}.csv"
    with open(loes_csv, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['layer_idx', 'loes_score', 'ridge_loss', 'isotropy'])
        writer.writeheader()
        writer.writerows(sorted(layer_scores, key=lambda x: x['layer_idx']))
    print(f"Saved LOES scores to {loes_csv}")
    
    # Cleanup
    # Release the scoring model before the first training model is loaded.
    del base_model, embeddings, labels
    torch.cuda.empty_cache()

    all_results = []
    
    # Config A: LoRA-All
    result_a = run_lora_all(
        train_ds, train_loader, eval_loader, test_loader, num_classes, device
    )
    result_a['dataset'] = dataset_safe
    all_results.append(result_a)
    
    # Config B: LoRA-Last-4
    result_b = run_lora_last_k(
        train_ds, train_loader, eval_loader, test_loader, num_classes, k=4, device=device
    )
    result_b['dataset'] = dataset_safe
    all_results.append(result_b)
    
    # Config C: LoRA-LOES-4
    result_c = run_lora_loes_k(
        train_ds, train_loader, eval_loader, test_loader, num_classes, k=4, 
        layer_scores=layer_scores, device=device
    )
    result_c['dataset'] = dataset_safe
    all_results.append(result_c)
    
    # Config D: LoRA-LOES-Weighted
    result_d = run_lora_loes_weighted(
        train_ds, train_loader, eval_loader, test_loader, num_classes,
        layer_scores=layer_scores, device=device
    )
    result_d['dataset'] = dataset_safe
    all_results.append(result_d)
    
    # Per-dataset summary table of the four configurations.
    print("\n" + "="*80)
    print(f"SUMMARY: {dataset_safe}")
    print("="*80)
    print(f"{'Config':<25} {'Params':<12} {'Val Acc':<10} {'Test Acc':<10} {'Time':<10}")
    print("-"*80)
    for r in all_results:
        print(f"{r['config']:<25} {format_params(r['total_trainable_params']):<12} "
              f"{100*r['best_val_acc']:.2f}%     {100*r['test_acc']:.2f}%     "
              f"{r['training_time_sec']/60:.1f} min")
    
    return all_results, layer_scores


def save_results_to_csv(all_results):
    """Append one CSV row per result dict to ``RESULTS_CSV``.

    A header row is written only when the file does not exist yet. Each row
    also records a timestamp and the global seed, batch size, epoch count and
    learning rate.
    """
    columns = [
        'timestamp', 'dataset', 'config', 'layers', 'rank', 'num_lora_layers',
        'lora_params', 'classifier_params', 'total_trainable_params',
        'best_val_acc', 'test_acc', 'training_time_sec', 'training_time_min',
        'seed', 'batch_size', 'epochs', 'learning_rate'
    ]
    # Columns copied verbatim from each result dict.
    copied_keys = (
        'dataset', 'config', 'layers', 'rank', 'num_lora_layers',
        'lora_params', 'classifier_params', 'total_trainable_params',
        'best_val_acc', 'test_acc', 'training_time_sec',
    )

    # Must be checked before open(): append mode creates the file.
    needs_header = not RESULTS_CSV.exists()

    with open(RESULTS_CSV, 'a', newline='') as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        if needs_header:
            writer.writeheader()

        for result in all_results:
            row = {'timestamp': datetime.now().isoformat()}
            for key in copied_keys:
                row[key] = result[key]
            row['training_time_min'] = result['training_time_sec'] / 60
            row['seed'] = SEED
            row['batch_size'] = BATCH_SIZE
            row['epochs'] = EPOCHS
            row['learning_rate'] = LEARNING_RATE
            writer.writerow(row)

    print(f"\nResults saved to {RESULTS_CSV}")


def main():
    """Seed the RNGs, run every dataset in ``DATASETS`` and print a summary.

    Results are appended to the results CSV after each dataset, so a later
    failure does not lose finished runs.
    """
    set_seed(SEED)

    # Both CSV writers expect the output directory to exist.
    OUTPUT_DIR.mkdir(exist_ok=True)

    rule = "=" * 80
    print(rule)
    print("LoRA + LOES EXPERIMENT")
    print(rule)
    print(f"Model: {MODEL_NAME}")
    print(f"Device: {DEVICE}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Epochs: {EPOCHS}")
    print(f"Seed: {SEED}")
    print(f"Datasets: {[d['name'] for d in DATASETS]}")
    print(rule)

    collected = []

    for dataset_config in DATASETS:
        # The returned layer scores are not needed here.
        dataset_results, _ = run_dataset_experiments(dataset_config, DEVICE)
        collected.extend(dataset_results)

        # Save after each dataset
        save_results_to_csv(dataset_results)

    # Final summary, grouped by dataset.
    hash_rule = "#" * 80
    print("\n" + hash_rule)
    print("# FINAL SUMMARY")
    print(hash_rule)

    for dataset in DATASETS:
        ds_name = dataset['name'].split('/')[-1]
        matching = [r for r in collected if r['dataset'] == ds_name]

        print(f"\n{ds_name}:")
        print(f"  {'Config':<25} {'Params':<12} {'Test Acc':<10}")
        print(f"  {'-'*50}")
        for r in matching:
            print(f"  {r['config']:<25} {format_params(r['total_trainable_params']):<12} {100*r['test_acc']:.2f}%")

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()