#!/usr/bin/env python3
"""
BigBird convergence-rate experiment with A-sweep capability.

This version adds the ability to sweep over different values of the attention
parameter A (like the BERT A-sweep script), while keeping the multihead BigBird
attention architecture.

Features:
- A-sweep: Test convergence at multiple A values (A_real and scaled versions)
- Improved "sexy" plots with confidence bands and better styling
- Sanity checks with overlay plots for different A values
- Weighted least squares fitting with full covariance estimation

Usage:
    python run_convergence_experiment.py --N 3 --layer_id 0 --A_sweep
    python run_convergence_experiment.py --N 3 --layer_id 0 --A_sweep --sanity_plot
"""

from transformers import AutoTokenizer, BigBirdModel, BigBirdTokenizer, AutoConfig, AutoModelForMaskedLM
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
import argparse
import datetime as dt
import gc
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Union, Any
import matplotlib.pyplot as plt
import matplotlib as mpl


# Chunked attention mean/cov implementation
try:
    from sample_complexity.utils.attention import (
        chunked_attention_memory_efficient,
        chunked_attention_cpu_offload,
    )
    from sample_complexity.utils.text_loaders import (
        load_text,
        load_for_experiment,
        generate_synthetic_embeddings,
        get_experiment_configs,
        get_available_sources,
        get_available_languages,
        TEXT_SOURCES,
        MULTILINGUAL_SOURCES,
        SHUFFLE_STRATEGIES,
    )
except ImportError:
    # Fallback for running from experiments/bigbird/ directory
    import sys
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
    from utils.attention import (
        chunked_attention_memory_efficient,
        chunked_attention_cpu_offload,
    )
    from utils.text_loaders import (
        load_text,
        load_for_experiment,
        generate_synthetic_embeddings,
        get_experiment_configs,
        get_available_sources,
        get_available_languages,
        TEXT_SOURCES,
        MULTILINGUAL_SOURCES,
        SHUFFLE_STRATEGIES,
    )

# Matplotlib look-and-feel: pick the first whitegrid-like style this
# matplotlib version ships (style names changed across releases), then pin
# colours and font sizes explicitly.
preferred_styles = [
    "seaborn-v0_8-whitegrid",
    "seaborn-whitegrid",
    "seaborn-v0_8",
    "seaborn",
    "ggplot",
    "default",
]

for s in preferred_styles:
    if s not in plt.style.available:
        continue
    plt.style.use(s)
    break

mpl.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'axes.edgecolor': '#333333',
    'axes.labelcolor': '#333333',
    'xtick.color': '#333333',
    'ytick.color': '#333333',
    'text.color': '#333333',
    'font.family': 'sans-serif',
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'legend.fontsize': 10,
    'figure.dpi': 120,
})

# Seed torch's global RNG (used e.g. by sample()) so runs are reproducible.
torch.manual_seed(0)

# ============== MODEL CONFIGURATIONS ==============
# Supported BigBird-RoBERTa checkpoints and their attention geometry.
# Both variants use 64-dim heads: attention_head_size == hidden_size // num_attention_heads.
MODEL_CONFIGS = {
    "base": dict(
        model_id="google/bigbird-roberta-base",
        hidden_size=768,
        num_attention_heads=12,
        attention_head_size=64,
    ),
    "large": dict(
        model_id="google/bigbird-roberta-large",
        hidden_size=1024,
        num_attention_heads=16,
        attention_head_size=64,
    ),
}

# Model size used unless overridden from the command line.
DEFAULT_MODEL_SIZE = "base"

# ============== DEVICE SETUP ==============
# Module-wide switches, toggled via set_force_cpu / set_mc_on_gpu / set_limits_on_gpu.
_FORCE_CPU = _MC_ON_GPU = _LIMITS_ON_GPU = False


def cuda_available() -> bool:
    """Return True when CUDA may be used for main compute (not forced onto CPU)."""
    if _FORCE_CPU:
        return False
    return torch.cuda.is_available()


def set_force_cpu(force: bool = True) -> None:
    """Pin (or unpin) all compute in this module to the CPU.

    Stores the flag in ``_FORCE_CPU`` and rebinds the module-level ``device``
    to CPU when forced, otherwise to CUDA if available.
    NOTE(review): MODEL_INPUT_DEVICE is not updated here.
    """
    global _FORCE_CPU, device
    _FORCE_CPU = bool(force)
    use_cuda = (not _FORCE_CPU) and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")


def set_mc_on_gpu(enable: bool = True) -> None:
    """Record whether the Monte Carlo loop may use the GPU even when main compute is on CPU.

    Only sets the ``_MC_ON_GPU`` flag; mc_cuda_available() reads it.
    """
    global _MC_ON_GPU
    _MC_ON_GPU = True if enable else False


def mc_cuda_available() -> bool:
    """Return True if the MC loop is allowed on GPU and a CUDA device exists."""
    if not _MC_ON_GPU:
        return False
    return torch.cuda.is_available()


def set_limits_on_gpu(enable: bool = True) -> None:
    """Record whether full-attention limit computations may use GPU matmuls (via CPU offload).

    Only sets the ``_LIMITS_ON_GPU`` flag; limits_cuda_available() reads it.
    """
    global _LIMITS_ON_GPU
    _LIMITS_ON_GPU = True if enable else False


def limits_cuda_available() -> bool:
    """Return True if limit computations are allowed on GPU and a CUDA device exists."""
    if not _LIMITS_ON_GPU:
        return False
    return torch.cuda.is_available()


def _is_cuda_oom(err: RuntimeError) -> bool:
    msg = str(err)
    return ("out of memory" in msg) or ("CUBLAS_STATUS_ALLOC_FAILED" in msg) or ("CUDA error" in msg)


# Module-wide default compute device (set_force_cpu() may rebind it later).
device = torch.device("cuda") if cuda_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(
        f"GPU: {torch.cuda.get_device_name(0)}\n"
        f"Flash SDP enabled: {torch.backends.cuda.flash_sdp_enabled()}\n"
        f"Mem-efficient SDP enabled: {torch.backends.cuda.mem_efficient_sdp_enabled()}"
    )

# Device that input tensors are sent to; load_model() overwrites it.
MODEL_INPUT_DEVICE = device


def _normalize_device(dev) -> torch.device:
    if isinstance(dev, torch.device):
        return dev
    if isinstance(dev, int):
        return torch.device(f"cuda:{dev}")
    return torch.device(str(dev))


def _resolve_model_input_device(model, fallback: torch.device) -> torch.device:
    if hasattr(model, "hf_device_map") and isinstance(model.hf_device_map, dict):
        for key in ("embeddings", "model.embeddings"):
            if key in model.hf_device_map:
                return _normalize_device(model.hf_device_map[key])
        try:
            first = next(iter(model.hf_device_map.values()))
            return _normalize_device(first)
        except StopIteration:
            pass
    try:
        return next(model.parameters()).device
    except StopIteration:
        return fallback

# ============== COLOR PALETTE FOR PLOTS ==============
COLORS = dict(
    primary='#2E86AB',      # deep blue
    secondary='#A23B72',    # magenta
    accent='#F18F01',       # orange
    success='#C73E1D',      # red (key name kept as-is for existing callers)
    dark='#1B1B1E',         # near-black
    real_A='#2E86AB',       # colour for the real (measured) A value
    sweep=['#F18F01', '#A23B72', '#C73E1D', '#6B4C9A', '#28965A'],  # A-sweep palette
)


# ============== A-SWEEP UTILITY ==============

def make_A_targets(A_real: float, A_cap: float = 1e4, n_points: int = 5) -> np.ndarray:
    """
    Build a log-spaced grid of attention-parameter values starting at A_real.

    Args:
        A_real: Measured A for the head. It is floored at 1e-12 so the log grid
            is defined.
        A_cap: Requested upper end. If it is below 1.1 * A_real it is raised to
            that value, so the grid always increases.
        n_points: Number of grid points, both endpoints included.

    Returns:
        float64 array of n_points values from A_real up to the effective cap.
    """
    lo = float(max(A_real, 1e-12))
    hi = float(max(A_cap, 1.1 * lo))
    return np.geomspace(lo, hi, n_points).astype(float)


# ============== POSITION INTERPOLATION ==============

def interpolate_pos_embeddings(model, new_max_length):
    """
    Extend learned absolute position embeddings to ``new_max_length`` rows.

    The existing table is resampled along the position axis by linear
    interpolation with ``align_corners=True``, so the first and last rows are
    preserved exactly. This lets BigBird handle sequences longer than its
    pretrained table (4096 positions for the checkpoints in MODEL_CONFIGS).

    Args:
        model: Module exposing ``model.embeddings.position_embeddings``
            (an ``nn.Embedding``) and ``model.config.max_position_embeddings``.
        new_max_length: Target number of positions.

    Returns:
        The same ``model`` object, modified in place and returned for chaining.
        It is left untouched when ``new_max_length`` does not exceed the
        current table size.
    """
    embeddings = model.embeddings
    old_weight = embeddings.position_embeddings.weight.detach()
    old_length, _ = old_weight.shape
    if new_max_length <= old_length:
        return model

    # F.interpolate resamples the last axis of a (batch, channels, length)
    # tensor, so lay the table out as (1, embed_dim, old_length).
    resampled = F.interpolate(
        old_weight.T.unsqueeze(0),
        size=new_max_length,
        mode='linear',
        align_corners=True,
    )
    new_weight = resampled.squeeze(0).T.contiguous()

    embeddings.position_embeddings = torch.nn.Embedding.from_pretrained(
        new_weight,
        freeze=False,
    )

    # Recreate the helper index buffers at the new length on the same device
    # as the embeddings. They are registered non-persistent (as transformers'
    # BigBirdEmbeddings does), so they do not appear as unexpected keys in
    # state_dict().
    buffer_device = old_weight.device
    if hasattr(embeddings, "token_type_ids"):
        embeddings.register_buffer(
            "token_type_ids",
            torch.zeros(1, new_max_length, dtype=torch.long, device=buffer_device),
            persistent=False,
        )
    if hasattr(embeddings, "position_ids"):
        embeddings.register_buffer(
            "position_ids",
            torch.arange(new_max_length, device=buffer_device).unsqueeze(0),
            persistent=False,
        )

    model.config.max_position_embeddings = new_max_length
    print(f"Extended positions: {old_length} -> {new_max_length}")
    return model


# ============== MODEL LOADING ==============

def get_model_config(model_size: str = "base") -> dict:
    """Look up the MODEL_CONFIGS entry for ``model_size``.

    Args:
        model_size: "base" or "large".

    Returns:
        Dict with model_id, hidden_size, num_attention_heads and
        attention_head_size. This is the stored dict itself, not a copy.

    Raises:
        ValueError: if ``model_size`` is not a key of MODEL_CONFIGS.
    """
    if model_size in MODEL_CONFIGS:
        return MODEL_CONFIGS[model_size]
    raise ValueError(f"Unknown model size: {model_size}. Choose from {list(MODEL_CONFIGS.keys())}")


# def load_model(model_name="BigBird", attention_type='original_full', max_length=None,
#                model_size="base", device_map: Optional[Union[str, Dict[str, Any]]] = None,
#                torch_dtype: Optional[torch.dtype] = None):
#     """Load model from HuggingFace in inference mode (supports multi-GPU sharding).

#     Args:
#         model_name: Model family (currently only "BigBird" supported)
#         attention_type: Attention type ('original_full' or 'block_sparse')
#         max_length: Maximum sequence length (if > 4096, will interpolate position embeddings)
#         model_size: "base" (768 hidden, 12 heads) or "large" (1024 hidden, 16 heads)
#         device_map: Optional device_map for multi-GPU sharding:
#             - None: standard single-device load
#             - "auto": shard across all visible GPUs
#             - dict: explicit device map
#         torch_dtype: Optional dtype for model weights (e.g., torch.float16)
#     """
#     if model_name != "BigBird":
#         raise Exception("Other model than BigBird not supported yet")

#     config_dict = get_model_config(model_size)
#     model_id = config_dict["model_id"]
#     print(f"Loading {model_id} (hidden_size={config_dict['hidden_size']}, heads={config_dict['num_attention_heads']})")

#     # Add _no_split_modules to BigBird classes to enable device_map="auto"
#     from transformers import BigBirdForMaskedLM
#     if not hasattr(BigBirdForMaskedLM, "_no_split_modules"):
#         BigBirdForMaskedLM._no_split_modules = ["BigBirdLayer"]
#         print("Added _no_split_modules=['BigBirdLayer'] to BigBirdForMaskedLM")

#     load_kwargs = {}
#     if device_map is not None:
#         load_kwargs["device_map"] = device_map
#         print(f"Using device_map={device_map}")
#     if torch_dtype is not None:
#         load_kwargs["torch_dtype"] = torch_dtype

#     bigbird = BigBirdForMaskedLM.from_pretrained(model_id, **load_kwargs)
#     bigbird.eval()
#     bigbird.set_attention_type(attention_type)

#     # Interpolate position embeddings if needed
#     if max_length is not None and max_length > 4096:
#         print(f"Interpolating position embeddings to max_length={max_length}")
#         bigbird = interpolate_pos_embeddings(bigbird, max_length)

#     # Determine model input device
#     if device_map is None:
#         bigbird.to(device)
#         model_input_device = device
#     else:
#         model_input_device = next(iter(bigbird.parameters())).device

#     print(f"Model loaded. Input device: {model_input_device}")

#     global MODEL_INPUT_DEVICE
#     MODEL_INPUT_DEVICE = model_input_device
#     return bigbird



def load_model(
    model_name: str = "BigBird",
    attention_type: str = "block_sparse",
    max_length: Optional[int] = None,
    model_size: str = "base",
    device_map: Optional[Union[str, Dict[str, Any]]] = None,
    device: Optional[torch.device] = None,
    torch_dtype: Optional[torch.dtype] = None,
):
    """
    Load a BigBird encoder from HuggingFace in inference (eval) mode.

    The checkpoint is always loaded on CPU first, so position embeddings can be
    interpolated (when max_length > 4096) before weights are placed on device(s).

    Args:
        model_name: Model family; only "BigBird" is supported.
        attention_type: Passed to ``set_attention_type``
            ('block_sparse' or 'original_full').
        max_length: If > 4096, position embeddings are linearly interpolated
            to this length.
        model_size: Key into MODEL_CONFIGS ("base" or "large").
        device_map:
          - None: normal single-device load onto ``device``
          - "auto": shard across all visible GPUs (Accelerate infer_auto_device_map)
          - dict: explicit device map passed to Accelerate's dispatch_model
        device: Target for the single-device path. Defaults to cuda:0 when
            cuda_available() (which honours set_force_cpu), otherwise CPU.
        torch_dtype: Optional dtype for the weights (e.g. torch.float16).

    Returns:
        The loaded BigBirdModel. Also sets the module-level MODEL_INPUT_DEVICE
        to the device that should receive input tensors.

    Raises:
        ValueError: if model_name is not "BigBird" (a subclass of Exception,
            so existing ``except Exception`` handlers still catch it).
    """
    if model_name != "BigBird":
        raise ValueError("Other model than BigBird not supported yet")

    if device is None:
        # Use cuda_available() rather than torch.cuda directly so that
        # set_force_cpu() is respected.
        device = torch.device("cuda:0" if cuda_available() else "cpu")

    config_dict = get_model_config(model_size)
    model_id = config_dict["model_id"]
    print(f"Loading {model_id} (hidden_size={config_dict['hidden_size']}, heads={config_dict['num_attention_heads']})")

    offline_mode = os.environ.get("HF_HUB_OFFLINE", "0") == "1"

    # Always load on CPU first (allows position interpolation before sharding)
    print("Loading model on CPU...")
    try:
        bigbird = BigBirdModel.from_pretrained(model_id, torch_dtype=torch_dtype)
    except Exception as e:
        # Best-effort retry from an alternate hub revision with safetensors.
        print(f"First load attempt failed: {e}")
        fallback_kwargs = {"revision": "refs/pr/2", "use_safetensors": True}
        if not offline_mode:
            fallback_kwargs["force_download"] = True
        print(f"Retrying with fallback_kwargs: {fallback_kwargs}")
        bigbird = BigBirdModel.from_pretrained(model_id, torch_dtype=torch_dtype, **fallback_kwargs)

    bigbird.eval()
    bigbird.set_attention_type(attention_type)

    # Interpolate position embeddings on CPU if needed
    if max_length is not None and max_length > 4096:
        print(f"Interpolating position embeddings to max_length={max_length}")
        bigbird = interpolate_pos_embeddings(bigbird, max_length)

    if device_map is None:
        bigbird.to(device)
        model_input_device = device
        print(f"Model loaded on single device: {model_input_device}")
    else:
        # Multi-GPU: dispatch across GPUs using Accelerate
        from accelerate import infer_auto_device_map, dispatch_model

        print(f"Sharding model across GPUs (device_map={device_map})...")

        if device_map == "auto":
            device_map_computed = infer_auto_device_map(
                bigbird,
                no_split_module_classes=["BigBirdLayer"],
            )
        else:
            device_map_computed = device_map

        print(f"Computed device_map: {device_map_computed}")
        bigbird = dispatch_model(bigbird, device_map=device_map_computed)
        # Prefer the device that holds the embeddings (inputs go there first).
        model_input_device = _resolve_model_input_device(bigbird, fallback=device)
        print(f"Model sharded across GPUs. Input device: {model_input_device}")

    global MODEL_INPUT_DEVICE
    MODEL_INPUT_DEVICE = model_input_device
    return bigbird


def load_tokenizer(model_name="BigBird", model_size="base"):
    """Load the tokenizer for the given model family and size.

    Args:
        model_name: Model family (only "BigBird" is supported).
        model_size: "base" or "large" (see MODEL_CONFIGS).

    Returns:
        A tokenizer from ``AutoTokenizer``. If that raises, falls back to
        ``BigBirdTokenizer``.

    Raises:
        Exception: for any model_name other than "BigBird".
    """
    if model_name != "BigBird":
        raise Exception("Other model than BigBird not supported yet")

    model_id = get_model_config(model_size)["model_id"]
    try:
        return AutoTokenizer.from_pretrained(model_id)
    except Exception:
        # Auto resolution failed; use the dedicated BigBird tokenizer class.
        return BigBirdTokenizer.from_pretrained(model_id)


def tokenize(text, model_name="BigBird", max_length=4096, model_size="base",
             device_override: Optional[torch.device] = None):
    """Tokenize ``text`` (truncated to ``max_length``) and move every tensor to one device.

    The target device is ``device_override`` when given, otherwise
    MODEL_INPUT_DEVICE (set by load_model). The tokenizer is re-created on
    every call.

    Returns:
        Dict mapping tokenizer output names (e.g. input_ids, attention_mask)
        to tensors on the target device.
    """
    tokenizer = load_tokenizer(model_name, model_size=model_size)
    encoded = tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt")
    target = MODEL_INPUT_DEVICE if device_override is None else device_override
    return {name: tensor.to(target) for name, tensor in encoded.items()}


# ============== TEXT LOADING ==============

def load_long_text(max_length):
    """Concatenate non-blank rows from the start of wikitext-103 (train split).

    Legacy loader. The row count is a heuristic and tokens are not counted
    here: about 2.5 rows per requested token, at least 10,000 rows, capped at
    the split size.

    Args:
        max_length: Target number of tokens, used only for the row heuristic.

    Returns:
        The selected non-blank rows joined with single spaces.
    """
    dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
    estimated_entries = min(max(10000, max_length * 50 // 20), len(dataset))

    print(f"Loading {estimated_entries} entries to fill {max_length} tokens...")
    # Slicing a Dataset returns a dict of column lists, fetched in one batched
    # read, instead of materialising a row dict per index as dataset[i] does.
    raw_texts = dataset[:estimated_entries]["text"]
    return " ".join(text for text in raw_texts if text.strip())


def load_text_flexible(
    max_length: int,
    source: str = "wiki",
    language: str = "en",
    mix_config: Optional[Dict[str, float]] = None,
    mix_strategy: str = "concat",
    shuffle: str = "none",
    experiment_config: Optional[str] = None,
    language_mix_config: Optional[Dict[str, float]] = None,
    repeat_config: Optional[Dict] = None,
    verbose: bool = True
) -> str:
    """
    Flexible text loading with multiple source options.

    This function wraps the text_loaders module to provide easy access
    to different text distributions for convergence experiments.

    Args:
        max_length: Target number of tokens.
        source: Text source. One of:
            - "wiki": Wikipedia (default)
            - "news": news articles
            - "books": BookCorpus
            - "code": GitHub code
            - "scientific": ArXiv papers
            - "openwebtext": web pages
            - "mr_niah": MR-NIAH benchmark haystack
            - "mixed": multiple sources (requires mix_config)
        language: Language code for the wiki/mr_niah sources.
            - "en", "de", "fr", "es", "zh", "ar", "ru", "ja"
            - "mixed": mix several languages (requires language_mix_config)
        mix_config: For source="mixed", a dict of source -> proportion.
            Example: {"wiki": 0.5, "news": 0.5}
        mix_strategy: How mixed sources/languages are combined.
            - "concat": concatenate blocks
            - "interleave": interleave sentences
            - "random": random word sampling
        shuffle: Shuffle level for ablation studies.
            - "none": no shuffling (default)
            - "sentence": shuffle sentences
            - "paragraph": shuffle paragraphs
            - "word": shuffle all words (destroys grammar)
        experiment_config: Name of a predefined experiment config. When set,
            it overrides every other source parameter, including repeat_config.
        language_mix_config: For language="mixed", a dict of lang -> proportion.
            Example: {"en": 0.5, "de": 0.5}
        repeat_config: Optional repetition settings. They are forwarded to
            load_text only if its signature declares ``repeat_config``
            (older loaders do not); otherwise they are ignored and a warning
            is printed. The dict schema is defined by load_text.
        verbose: Print loading progress.

    Returns:
        Text string ready for tokenization.
    """
    if experiment_config is not None:
        return load_for_experiment(max_length, experiment_config, verbose=verbose)

    load_kwargs = dict(
        max_length=max_length,
        source=source,
        language=language,
        mix_config=mix_config,
        mix_strategy=mix_strategy,
        shuffle=shuffle,
        language_mix_config=language_mix_config,
        verbose=verbose,
    )

    import inspect
    try:
        accepts_repeat = "repeat_config" in inspect.signature(load_text).parameters
    except (TypeError, ValueError):
        # inspect.signature raises these for callables it cannot introspect.
        accepts_repeat = False

    if accepts_repeat:
        load_kwargs["repeat_config"] = repeat_config
    elif repeat_config is not None:
        # Do not drop a user-provided setting silently.
        print("Warning: load_text() does not accept repeat_config; ignoring it.")
    return load_text(**load_kwargs)


# ============== LAYER 0 DIRECT EMBEDDINGS ==============

def get_layer0_embeddings_direct(model, tokens):
    """Compute the embedding-layer output without running any attention.

    This mirrors the model's embedding forward: word embeddings (scaled by
    sqrt(hidden) when the embeddings module has ``rescale_embeddings`` set),
    plus token-type and position embeddings, then LayerNorm. Dropout is
    omitted, which matches eval mode.

    Args:
        model: Module with ``model.embeddings`` exposing word_embeddings,
            position_embeddings, token_type_embeddings and LayerNorm.
        tokens: Dict with ``input_ids`` of shape (batch, seq). It may also
            hold ``token_type_ids``, which are used when present; otherwise
            all zeros are used.

    Returns:
        Tensor of shape (batch, seq, hidden).
    """
    emb = model.embeddings
    with torch.no_grad():
        input_ids = tokens['input_ids']
        token_embeds = emb.word_embeddings(input_ids)
        if getattr(emb, "rescale_embeddings", False):
            token_embeds = token_embeds * (token_embeds.shape[-1] ** 0.5)

        # The model forward uses caller-supplied token_type_ids when present.
        token_type_ids = tokens.get("token_type_ids") if hasattr(tokens, "get") else None
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, device=input_ids.device).unsqueeze(0)

        embeddings = token_embeds + emb.token_type_embeddings(token_type_ids)
        embeddings = embeddings + emb.position_embeddings(position_ids)
        embeddings = emb.LayerNorm(embeddings)
    return embeddings


def get_embeddings_chunked(model, tokens, layer_id, chunk_size=4096):
    """Return the hidden states that enter layer ``layer_id``'s self-attention.

    Layer 0 is computed directly from the embedding sub-modules, so no
    attention is run. For deeper layers a full forward pass is made with
    block_sparse attention and the attention input is captured by
    get_embeddings(). If that pass fails with a CUDA OOM-like error (see
    _is_cuda_oom), it is retried with model and tokens on CPU, and the result
    is moved back to the model's original device.

    Args:
        model: BigBird model (must support ``set_attention_type``).
        tokens: Dict of input tensors, e.g. as produced by tokenize().
        layer_id: Encoder layer whose attention input is wanted.
        chunk_size: Unused; kept for backward compatibility with callers.

    Returns:
        Tensor of hidden states captured before attention.

    The model's attention type is restored on every exit path.
    """
    if layer_id == 0:
        return get_layer0_embeddings_direct(model, tokens)

    print(f"  Layer {layer_id} > 0: Using block_sparse for efficient forward pass")
    original_attn_type = model.config.attention_type
    original_device = next(model.parameters()).device

    try:
        model.set_attention_type('block_sparse')
        try:
            return get_embeddings(tokens, model, layer_id)
        except RuntimeError as err:
            if not _is_cuda_oom(err):
                raise
        # Retry outside the except clause so the failed pass's traceback (and
        # the GPU tensors referenced by its frames) is released first.
        print("  CUDA OOM during embedding forward pass. Falling back to CPU for embeddings...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # NOTE(review): for a model sharded with dispatch_model, moving it back
        # to the single original_device collapses the sharding. Confirm this
        # path is only reached by single-device models.
        try:
            model.to("cpu")
            cpu_tokens = {k: v.to("cpu") for k, v in tokens.items()}
            return get_embeddings(cpu_tokens, model, layer_id).to(original_device)
        finally:
            model.to(original_device)
    finally:
        model.set_attention_type(original_attn_type)


# ============== CHUNKED ATTENTION (MEMORY-EFFICIENT) ==============
# ============== FORWARD PASS & HOOKS ==============

def forward_pass(tokens, model):
    """Switch ``model`` to eval mode and run ``model(**tokens)`` once without autograd.

    The model output is discarded. This is used only to trigger forward hooks.
    """
    model.eval()
    with torch.no_grad():
        _ = model(**tokens)


def get_embeddings(tokens, model, layer_id=0):
    """Capture the hidden states fed into layer ``layer_id``'s self-attention module.

    A forward pre-hook is attached to ``model.encoder.layer[layer_id].attention.self``.
    A single forward pass is run, and the hook is always removed afterwards,
    even if the pass raises.

    With block_sparse attention, transformers' BigBirdModel right-pads the
    sequence to a multiple of the block size before the encoder. The captured
    states are therefore trimmed back to the input length, so padding
    positions do not leak into downstream statistics.

    Side effect: rebinds the module-level ``cache`` dict and stores the
    returned tensor under ``cache["in"]``.

    Returns:
        Detached tensor of shape (batch, seq, hidden) for the input sequence.
    """
    global cache
    cache = {}

    def pre_attention_hook(module, args, kwargs):
        # Hidden states may arrive positionally or as a keyword.
        hidden = args[0] if args else kwargs["hidden_states"]
        cache["in"] = hidden.detach()

    attn_self = model.encoder.layer[layer_id].attention.self
    handle = attn_self.register_forward_pre_hook(pre_attention_hook, with_kwargs=True)
    try:
        forward_pass(tokens, model)
    finally:
        handle.remove()

    hidden = cache["in"]
    input_ids = tokens.get("input_ids") if hasattr(tokens, "get") else None
    if input_ids is not None:
        hidden = hidden[:, : input_ids.shape[1]]
        cache["in"] = hidden
    return hidden


# ============== QKV EXTRACTION ==============

def extract_qkv(model, layer_id=0):
    """Return detached (Wq, Wk, Wv, bq, bk, bv) from layer ``layer_id``'s self-attention.

    Weights come from the query/key/value linear modules. A bias entry is None
    when the corresponding projection has no bias.
    """
    attn_self = model.encoder.layer[layer_id].attention.self
    projections = (attn_self.query, attn_self.key, attn_self.value)
    weights = tuple(proj.weight.detach() for proj in projections)
    biases = tuple(None if proj.bias is None else proj.bias.detach() for proj in projections)
    return weights + biases


def get_attention_parameter(model, X, layer_id):
    """Compute the scalar attention parameter A = ||S^{1/2} K^T Q||_2 for one layer.

    S is the (regularised) covariance of X, and K^T Q is the bilinear form
    A_mat = Wk^T Wq.

    Args:
        model: Model whose layer ``layer_id`` provides the Q/K weights.
        X: Hidden states (batch, n, d); a batch size of 1 is assumed.
        layer_id: Encoder layer index.

    Returns:
        (A_real, A_mat): A_real is a Python float; A_mat is the (d, d) matrix Wk^T Wq.
    """
    Wq, Wk, _, _, _, _ = extract_qkv(model, layer_id)
    A_mat = Wk.T @ Wq
    _, cov, _ = stats_full_distribution(X)
    # cov has shape [b, d, d]; drop the batch dim (b == 1 assumed).
    cov = cov.squeeze(0)
    cov_reg = cov + 1e-6 * torch.eye(cov.shape[-1], device=cov.device)
    chol = torch.linalg.cholesky(cov_reg)  # lower-triangular L with cov_reg = L @ L^T
    # L is not S^{1/2}: L = S^{1/2} O for some orthogonal O, so ||L @ A||_2
    # generally differs from ||S^{1/2} A||_2. L^T @ A satisfies
    # (L^T A)^T (L^T A) = A^T S A, giving exactly ||S^{1/2} A||_2.
    A_real = torch.linalg.matrix_norm(chol.mT @ A_mat, ord=2).item()
    return A_real, A_mat


def get_theoretical_rate(model, X, layer_id):
    """Return (rate, A_mat, H) for layer ``layer_id``.

    Per batch element of X:
      - H   = ||S^{1/2} A_mat||_2, where S is the regularised covariance of X
            and A_mat = Wk^T Wq;
      - Psi = lambda_max(S) * H^2;
      - rate = 0.5 / (1 + 32 * Psi).

    Returns:
        rate and H as tensors with one entry per batch element, and the
        (d, d) matrix A_mat.
    """
    Wq, Wk, _, _, _, _ = extract_qkv(model, layer_id)
    A_mat = Wk.T @ Wq
    _, cov, max_eigen_cov = stats_full_distribution(X)
    cov_reg = cov + 1e-6 * torch.eye(cov.shape[-1], device=cov.device).unsqueeze(0)
    chol = torch.linalg.cholesky(cov_reg)  # batched lower L with cov_reg = L @ L^T
    # Use L^T (not L): ||L^T A||_2 == ||S^{1/2} A||_2, whereas ||L A||_2 is not.
    H = torch.linalg.matrix_norm(chol.mT @ A_mat, ord=2, dim=(-2, -1))
    Psi = max_eigen_cov * H**2
    rate = 0.5 / (1 + 32 * Psi)
    return rate, A_mat, H


def get_qkv_layers(model, X, layer_id=0):
    """Project ``X`` through layer ``layer_id``'s query/key/value linear modules.

    Autograd is disabled. Returns the tuple (Q, K, V).
    """
    attn_self = model.encoder.layer[layer_id].attention.self
    with torch.no_grad():
        return tuple(proj(X) for proj in (attn_self.query, attn_self.key, attn_self.value))


def get_qkv_scaled(model, X, layer_id, qk_scale=1.0):
    """Return unscaled (Q, K, V) for the A-sweep.

    ``qk_scale`` is accepted for backward compatibility but ignored. Q and K
    are deliberately not scaled here, to avoid numerical overflow; the
    temperature is instead passed to the attention operator.
    """
    projections = model.encoder.layer[layer_id].attention.self
    with torch.no_grad():
        outputs = [getattr(projections, name)(X) for name in ("query", "key", "value")]
    return tuple(outputs)


# ============== STATISTICS ==============

def stats_full_distribution(X):
    """Return (mean, cov, max eigenvalue of cov) for batched input X of shape [b, n, d].

    The eigenvalues come from ``eigvalsh`` (ascending order), so the last one
    is the largest. The result has one entry per batch element.
    """
    mean, cov = get_mean_cov(X)
    max_eigen = torch.linalg.eigvalsh(cov)[..., -1]
    return mean, cov, max_eigen


def get_mean_cov(input):
    """Return the per-batch mean [b, d] and population covariance [b, d, d] of ``input`` [b, n, d].

    Uses cov = E[x x^T] - mu mu^T, so no centered copy of ``input`` is
    allocated (memory-efficient). This can lose precision when |mu| is large
    relative to the spread.
    """
    num_tokens = input.size(1)
    mu = input.mean(dim=1)
    second_moment = (input.transpose(1, 2) @ input) / num_tokens
    outer = mu.unsqueeze(2) @ mu.unsqueeze(1)
    return mu, second_moment - outer


# ============== ATTENTION OPERATOR ==============

def transpose_for_scores(x, num_attention_heads=12, attention_head_size=64):
    """Split the hidden axis into heads: (batch, seq, heads*size) -> (batch, heads, seq, size).

    Args:
        x: Tensor of shape (batch, seq_length, num_attention_heads * attention_head_size).
        num_attention_heads: Number of heads (12 for base, 16 for large).
        attention_head_size: Per-head width (64 for both base and large).

    Returns:
        A permuted view, not a copy, of ``x``.
    """
    split_shape = (*x.shape[:-1], num_attention_heads, attention_head_size)
    return x.view(split_shape).permute(0, 2, 1, 3)


def attentionOperator(Q, K, V, tiling=False, attention_head_size=64, all_head_size=768, num_attention_heads=12, temperature_scale=1.0):
    """Apply multi-head softmax attention to precomputed Q, K, V (no mask).

    Args:
        Q, K, V: Tensors of shape (batch, seq_len, hidden_size).
        tiling: If True, use F.scaled_dot_product_attention (fused kernels
            such as FlashAttention when available); otherwise use an explicit
            softmax.
        attention_head_size: Per-head width (64 for both base and large).
        all_head_size: Total hidden size (768 for base, 1024 for large).
        num_attention_heads: Number of heads (12 for base, 16 for large).
        temperature_scale: Multiplier on the attention logits; values > 1
            sharpen the distribution. This is equivalent to scaling Q or K
            (and hence A = K^T Q) by temperature_scale. Logits are scaled by
            temperature_scale / sqrt(attention_head_size).

    Returns:
        Tensor of shape (batch, seq_len, all_head_size).
    """
    q_heads, k_heads, v_heads = (
        transpose_for_scores(t, num_attention_heads, attention_head_size) for t in (Q, K, V)
    )
    scale = temperature_scale / math.sqrt(attention_head_size)

    if tiling:
        per_head = F.scaled_dot_product_attention(q_heads, k_heads, v_heads, scale=scale)
    else:
        logits = torch.matmul(q_heads, k_heads.transpose(-1, -2))
        # Shift by the row max *before* scaling so large temperatures cannot
        # overflow. softmax(T * (x - max)) == softmax(T * x) because the shift
        # is a per-row constant.
        logits = logits - logits.max(dim=-1, keepdim=True).values
        probs = nn.functional.softmax(logits * scale, dim=-1)
        per_head = torch.matmul(probs, v_heads)

    merged = per_head.permute(0, 2, 1, 3).contiguous()
    return merged.view(*merged.size()[:-2], all_head_size)


# ============== METRICS ==============

def metric_error(mean, cov, mean_lim, cov_lim):
    """Return (mean_err, cov_err) as Python floats.

    mean_err: L2 distance over the last axis, averaged over any leading (batch) dims.
    cov_err: Frobenius distance over the last two axes, averaged likewise.
    """
    mean_gap = mean - mean_lim
    cov_gap = cov - cov_lim
    mean_err = torch.linalg.vector_norm(mean_gap, ord=2, dim=-1).mean().item()
    cov_err = torch.linalg.matrix_norm(cov_gap, ord='fro', dim=(-2, -1)).mean().item()
    return mean_err, cov_err


# ============== SAMPLING ==============

def sample(X, n, replace):
    """Draw n token positions from X along dim 1 and return ``X[:, idx, :]``.

    replace=True: i.i.d. uniform indices (duplicates possible).
    replace=False: the first n entries of a random permutation (fewer than n
    rows if n > X.size(1)).
    Randomness comes from torch's global RNG.
    """
    num_tokens = X.size(1)
    idx = (
        torch.randint(0, num_tokens, (n,), device=X.device)
        if replace
        else torch.randperm(num_tokens, device=X.device)[:n]
    )
    return X[:, idx, :]


# ============== WEIGHTED LEAST SQUARES ==============

def wls_fit_full(logx: np.ndarray, logy: np.ndarray, var_logy: np.ndarray):
    """
    Weighted least squares fit of logy = intercept + slope * logx.

    Weights are 1 / var_logy, with variances clamped below at 1e-18 so the
    weights stay finite. The parameter covariance is inv(X^T W X) scaled by
    the weighted residual variance sigma2 = sum(w * r^2) / max(n - 2, 1).
    With only two points the fit is exact, so the covariance is all zeros.

    Args:
        logx, logy: 1-D array-likes of equal length (lists are accepted).
        var_logy: Per-point variance of logy (array-like or scalar).

    Returns:
        slope, intercept (Python floats), and cov_beta, the 2x2 covariance
        of beta = [intercept, slope].
    """
    logx = np.asarray(logx, dtype=float)
    logy = np.asarray(logy, dtype=float)
    var_logy = np.maximum(np.asarray(var_logy, dtype=float), 1e-18)

    w = 1.0 / var_logy
    design = np.column_stack((np.ones_like(logx), logx))
    XtW = design.T * w
    normal_matrix = XtW @ design
    beta = np.linalg.solve(normal_matrix, XtW @ logy)

    resid = logy - design @ beta
    dof = max(logx.size - 2, 1)
    sigma2 = (w * resid**2).sum() / dof
    cov = np.linalg.inv(normal_matrix) * sigma2

    intercept = float(beta[0])
    slope = float(beta[1])
    return slope, intercept, cov


# ============== MONTE CARLO ==============

def one_monte_carlo(model, X, n, k, mean_lim, cov_lim, layer_id, replace=True, tiling=False, temperature_scale=1.0,
                    attention_head_size=64, all_head_size=768, num_attention_heads=12, mc_device=None):
    """Run k Monte Carlo trials at sample size n and summarise the attention-statistic errors.

    Each trial samples n tokens from X, projects them with the layer's Q/K/V
    linears, computes the mean/covariance of the attention output, and
    measures the distance to (mean_lim, cov_lim) with metric_error.

    Args:
        model: The BigBird model.
        X: Input embeddings to sample from (along dim 1).
        n: Sample size per trial.
        k: Number of Monte Carlo repetitions.
        mean_lim, cov_lim: Limit statistics for error computation.
        layer_id: Layer index.
        replace: Sample with replacement.
        tiling: If True, use attentionOperator (SDPA) followed by get_mean_cov;
            otherwise use chunked_attention_memory_efficient with use_fp16=True.
        temperature_scale: Attention temperature. To simulate scaling Q and K
            by qk_scale, use temperature_scale = qk_scale^2.
        attention_head_size: Per-head width (64 for both base and large).
        all_head_size: Total hidden size (768 for base, 1024 for large).
        num_attention_heads: Number of heads (12 for base, 16 for large).
        mc_device: If given, each sample is moved there before projection.
            NOTE(review): the model's Q/K/V linears presumably must live on
            this device too; confirm with the caller.

    Returns:
        (mean_err_mean, mean_err_std, cov_err_mean, cov_err_std). The stds use
        ddof=1, so they are NaN when k == 1.
    """
    mean_errs, cov_errs = [], []

    for _ in range(k):
        X_sample = sample(X, n, replace)
        if mc_device is not None and X_sample.device != mc_device:
            X_sample = X_sample.to(mc_device)

        Q, K, V = get_qkv_layers(model, X_sample, layer_id)
        if not tiling:
            mean_est, cov_est = chunked_attention_memory_efficient(
                Q, K, V,
                attention_head_size=attention_head_size,
                all_head_size=all_head_size,
                use_fp16=True,
                temperature_scale=temperature_scale,
            )
        else:
            attn_out = attentionOperator(
                Q, K, V, tiling=tiling, temperature_scale=temperature_scale,
                attention_head_size=attention_head_size, all_head_size=all_head_size,
                num_attention_heads=num_attention_heads,
            )
            mean_est, cov_est = get_mean_cov(attn_out)
            del attn_out  # release the full attention output before the next trial

        trial_mean_err, trial_cov_err = metric_error(mean_est, cov_est, mean_lim, cov_lim)
        mean_errs.append(trial_mean_err)
        cov_errs.append(trial_cov_err)

    mean_arr = np.asarray(mean_errs)
    cov_arr = np.asarray(cov_errs)
    return mean_arr.mean(), mean_arr.std(ddof=1), cov_arr.mean(), cov_arr.std(ddof=1)


def full_monte_carlo(model, X, n_min, n_max, nb_tot, k, mean_lim, cov_lim, layer_id,
                     replace=True, tiling=False, temperature_scale=1.0, desc="MC experiments",
                     attention_head_size=64, all_head_size=768, num_attention_heads=12, mc_device=None):
    """Full MC experiment for one layer/A-value.

    For every sample size on a geometric grid, runs ``one_monte_carlo`` (``k``
    repetitions) and records the mean/std of the mean- and covariance-errors,
    plus the delta-method variance of ``log(mean error)`` used as WLS weights.

    Args:
        model: The BigBird model
        X: Input embeddings
        n_min, n_max: Sample size range
        nb_tot: Number of sample sizes
        k: Monte Carlo repetitions per sample size
        mean_lim, cov_lim: Limit statistics for error computation
        layer_id: Layer index
        replace: Sample with replacement
        tiling: Use FlashAttention
        temperature_scale: Scale factor for attention temperature. To simulate scaling Q and K
            by qk_scale, use temperature_scale = qk_scale^2.
        desc: Progress bar description
        attention_head_size: Size of each attention head (64 for both base and large)
        all_head_size: Total hidden size (768 for base, 1024 for large)
        num_attention_heads: Number of attention heads (12 for base, 16 for large)
        mc_device: Device for the MC runs (``torch.device`` or a string such as
            ``"cuda"``); defaults to ``X.device``. On a CUDA OOM the model is moved
            to CPU and the current and all remaining sample sizes run on CPU.

    Returns:
        Dict of numpy arrays with one entry per sample size: 'n_vals',
        'mean_err_mean', 'mean_err_std', 'cov_err_mean', 'cov_err_std',
        'var_log_mean', 'var_log_cov'.
    """
    # Geometric grid of sample sizes, truncated to integers.
    n_values = np.geomspace(n_min, n_max, nb_tot).astype(int)

    mean_err_means, mean_err_stds = [], []
    cov_err_means, cov_err_stds = [], []
    var_log_mean, var_log_cov = [], []

    # Normalise to a torch.device: the OOM check below reads `.type`, which a
    # plain device string does not have.
    mc_device = torch.device(X.device if mc_device is None else mc_device)

    def _on_device(t, device):
        # Tensors are moved; anything without `.to` is passed through unchanged.
        return t.to(device) if hasattr(t, "to") else t

    def _run_one(n, device, m_lim, c_lim):
        # One batch of k MC repetitions at sample size n on `device`.
        return one_monte_carlo(
            model, X, n, k, m_lim, c_lim, layer_id, replace, tiling, temperature_scale,
            attention_head_size=attention_head_size, all_head_size=all_head_size,
            num_attention_heads=num_attention_heads, mc_device=device
        )

    mean_lim_mc = _on_device(mean_lim, mc_device)
    cov_lim_mc = _on_device(cov_lim, mc_device)

    for n in tqdm(n_values, desc=desc):
        try:
            mean_m, mean_s, cov_m, cov_s = _run_one(n, mc_device, mean_lim_mc, cov_lim_mc)
        except RuntimeError as e:
            if mc_device.type != "cuda" or not _is_cuda_oom(e):
                raise
            print("  CUDA OOM during MC; falling back to CPU for MC.")
            # Move the model off the GPU *before* releasing the cache so the
            # memory held by its weights can actually be returned.
            try:
                model.to("cpu")
            except Exception as move_err:
                # Best-effort: keep going, but surface why the CPU retry may fail.
                print(f"  WARNING: could not move model to CPU ({move_err!r}).")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            mc_device = torch.device("cpu")
            mean_lim_mc = _on_device(mean_lim, mc_device)
            cov_lim_mc = _on_device(cov_lim, mc_device)
            mean_m, mean_s, cov_m, cov_s = _run_one(n, mc_device, mean_lim_mc, cov_lim_mc)

        mean_err_means.append(mean_m)
        mean_err_stds.append(mean_s)
        cov_err_means.append(cov_m)
        cov_err_stds.append(cov_s)

        # Delta method for WLS weights: Var[log m] ~= (SE / m)^2 with SE = std / sqrt(k);
        # the mean is floored so a zero error does not divide by zero.
        se_m = mean_s / np.sqrt(k)
        se_c = cov_s / np.sqrt(k)
        var_log_mean.append((se_m / max(mean_m, 1e-12))**2)
        var_log_cov.append((se_c / max(cov_m, 1e-12))**2)

    return {
        'n_vals': np.array(n_values),
        'mean_err_mean': np.array(mean_err_means),
        'mean_err_std': np.array(mean_err_stds),
        'cov_err_mean': np.array(cov_err_means),
        'cov_err_std': np.array(cov_err_stds),
        'var_log_mean': np.array(var_log_mean),
        'var_log_cov': np.array(var_log_cov),
    }


# ============== SEXY PLOTTING FUNCTIONS ==============

def _plot_convergence_overlay(ax, n_values, err_means, err_stds, slope, intercept,
                               fit_cov, label, color, highlight=False, show_data=True,
                               n_reps: Optional[int] = None):
    """Plot convergence with confidence bands - sexy version.

    Draws three things on ``ax``:

    - the per-n error means with error bars (only if ``show_data``);
    - the fitted power law ``exp(intercept) * n**slope`` as a dashed line;
    - a shaded band of +/-1 fit standard deviation around that line.

    The band treats the fit as ``log(err) = intercept + slope * log(n)``. Here
    ``fit_cov`` is the 2x2 covariance of ``(intercept, slope)``: index 0 is the
    intercept and index 1 is the slope, as read below.

    Args:
        ax: Matplotlib Axes (anything with ``errorbar``/``plot``/``fill_between``).
        n_values: Sample sizes (array-like).
        err_means: Error mean at each sample size.
        err_stds: Error standard deviation across MC repetitions at each sample size.
        slope, intercept: Parameters of the log-log fit.
        fit_cov: 2x2 covariance of ``(intercept, slope)`` (array-like).
        label: Legend label for the data points.
        color: Colour shared by points, fit line and band.
        highlight: Use bigger markers, a thicker line and a stronger band.
        show_data: Draw the data points. When False, the fit line is labelled
            with its slope and SE instead.
        n_reps: Number of MC repetitions behind each entry of ``err_stds``. When
            given (> 0), error bars are the standard error ``std / sqrt(n_reps)``.
            Otherwise they show +/-1 std.
    """
    n_arr = np.asarray(n_values, dtype=float)
    err_means = np.asarray(err_means, dtype=float)
    err_stds = np.asarray(err_stds, dtype=float)
    fit_cov = np.asarray(fit_cov, dtype=float)

    ms = 8 if highlight else 5
    lw = 2.5 if highlight else 1.5
    alpha_pts = 0.95 if highlight else 0.7
    alpha_band = 0.25 if highlight else 0.12
    z = 10 if highlight else 2

    if show_data:
        # Each std is taken across the MC repetitions at a single n, so a
        # standard error needs that repetition count. Dividing by the number of
        # sample sizes (len(err_stds)) would not be a standard error at all.
        if n_reps is not None and n_reps > 0:
            yerr = err_stds / np.sqrt(n_reps)
        else:
            yerr = err_stds
        ax.errorbar(
            n_arr,
            err_means,
            yerr=yerr,
            fmt='o', capsize=4, markersize=ms, capthick=1.5,
            label=f"{label}", color=color, alpha=alpha_pts, zorder=z,
            markeredgecolor='white', markeredgewidth=0.5 if highlight else 0
        )

    # Fit line
    n_fit = np.geomspace(n_arr.min(), n_arr.max(), 150)
    log_n_fit = np.log(n_fit)
    fit_line = np.exp(intercept) * (n_fit ** slope)

    slope_se = np.sqrt(max(fit_cov[1, 1], 0))
    ax.plot(n_fit, fit_line, "--", linewidth=lw, color=color,
            label=f"slope={slope:.3f}$\\pm${slope_se:.3f}" if not show_data else None,
            zorder=z)

    # +/-1 sigma band in log space: Var[intercept + slope * log n].
    fit_var = fit_cov[0, 0] + (log_n_fit**2) * fit_cov[1, 1] + 2.0 * log_n_fit * fit_cov[0, 1]
    fit_std = np.sqrt(np.maximum(fit_var, 0.0))
    upper = np.exp(intercept + slope * log_n_fit + fit_std)
    lower = np.exp(intercept + slope * log_n_fit - fit_std)
    ax.fill_between(n_fit, lower, upper, alpha=alpha_band, color=color, zorder=z-1)


def plot_convergence_sexy(n_values, mean_err_means, mean_err_stds, cov_err_means, cov_err_stds,
                          mean_fit, cov_fit, title_suffix="", save_path=None):
    """Draw side-by-side log-log convergence panels for the mean and covariance errors.

    Left panel: mean error. Right panel: covariance error. Each panel contains:

    - the data with its fitted power law and band, drawn by ``_plot_convergence_overlay``;
    - a dotted ``O(n^{-0.5})`` guide passing through the first data point.

    Args:
        n_values: Sample sizes (array; ``.min()``/``.max()`` and ``[0]`` are used).
        mean_err_means, mean_err_stds: Mean-error statistics per sample size.
        cov_err_means, cov_err_stds: Covariance-error statistics per sample size.
        mean_fit, cov_fit: ``(slope, intercept, cov)`` triples from the log-log fit.
            ``cov[1, 1]`` is read as the slope variance.
        title_suffix: Text appended to each panel title.
        save_path: If truthy, the figure is saved there (150 dpi, white background).

    Returns:
        The matplotlib Figure. It is not closed here.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 5.5))

    # Unpack both fits up front so a malformed tuple fails before anything is drawn.
    (m_slope, m_icpt, m_fcov), (c_slope, c_icpt, c_fcov) = mean_fit, cov_fit

    n_ref = np.geomspace(n_values.min(), n_values.max(), 100)

    # (axis, means, stds, slope, intercept, fit cov, series name, COLORS key, y label, title stem)
    panel_specs = (
        (axes[0], mean_err_means, mean_err_stds, m_slope, m_icpt, m_fcov, "Mean Error", 'primary',
         r'Mean Error $\|\hat{\mu}_n - \mu^*\|_2$', 'Mean Convergence'),
        (axes[1], cov_err_means, cov_err_stds, c_slope, c_icpt, c_fcov, "Cov Error", 'secondary',
         r'Cov Error $\|\hat{\Sigma}_n - \Sigma^*\|_F$', 'Covariance Convergence'),
    )
    slope_ses = [np.sqrt(max(spec[5][1, 1], 0)) for spec in panel_specs]

    for spec, slope_se in zip(panel_specs, slope_ses):
        ax, means, stds, fit_slope, fit_icpt, fit_cov, series, colour_key, y_label, title_stem = spec

        _plot_convergence_overlay(
            ax, n_values, means, stds,
            fit_slope, fit_icpt, fit_cov,
            label=f"{series} (slope={fit_slope:.3f}$\\pm${slope_se:.3f})",
            color=COLORS[colour_key], highlight=True
        )

        # Dotted n^{-1/2} guide anchored at the first measured point.
        ref_line = n_ref ** (-0.5) * means[0] * (n_values[0] ** 0.5)
        ax.plot(n_ref, ref_line, ':', color='gray', alpha=0.7, label=r'$O(n^{-0.5})$ reference')

        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Sample size $n$', fontweight='medium')
        ax.set_ylabel(y_label, fontweight='medium')
        ax.set_title(f'{title_stem} {title_suffix}', fontweight='bold', pad=10)
        ax.legend(loc='upper right', framealpha=0.95, edgecolor='lightgray')
        ax.grid(True, alpha=0.3, which='both')

    plt.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
        print(f"Saved plot to {save_path}")

    return fig


def plot_A_sweep_sanity(results_by_A: List[Dict], A_real: float, layer_id: int,
                        save_dir: str = "results/sanity_plots"):
    """Overlay the convergence curves of every swept A value for a visual sanity check.

    Results are drawn in ascending ``A_target`` order. The entry matching ``A_real``
    (``np.isclose`` with ``rtol=1e-6``) uses ``COLORS['real_A']`` and is highlighted.
    Other entries cycle through ``COLORS['sweep']``. The figure is saved as
    ``<save_dir>/A_sweep_sanity_layer<layer_id>.png``.

    Args:
        results_by_A: One dict per A value with keys 'A_target', 'n_vals',
            'mean_err_mean', 'mean_err_std', 'mean_fit', 'cov_err_mean',
            'cov_err_std', 'cov_fit'.
        A_real: The model's actual A value.
        layer_id: Layer index used in titles and the file name.
        save_dir: Output directory; created if missing.

    Returns:
        The matplotlib Figure. It is not closed here.
    """
    os.makedirs(save_dir, exist_ok=True)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    ax_mean, ax_cov = axes

    # (axis, means key, stds key, fit key) for the two panels.
    series_keys = (
        (ax_mean, 'mean_err_mean', 'mean_err_std', 'mean_fit'),
        (ax_cov, 'cov_err_mean', 'cov_err_std', 'cov_fit'),
    )

    for rank, res in enumerate(sorted(results_by_A, key=lambda r: r['A_target'])):
        a_target = res['A_target']
        real = np.isclose(a_target, A_real, rtol=1e-6)

        if real:
            colour = COLORS['real_A']
            tag = f"A={a_target:.2e} (real)"
        else:
            palette = COLORS['sweep']
            colour = palette[rank % len(palette)]
            tag = f"A={a_target:.2e}"

        for ax, means_key, stds_key, fit_key in series_keys:
            fit = res[fit_key]
            _plot_convergence_overlay(
                ax, res['n_vals'], res[means_key], res[stds_key],
                fit[0], fit[1], fit[2],
                label=tag, color=colour, highlight=real
            )

    panel_text = (
        (ax_mean, r'Mean Error $\|\hat{\mu}_n - \mu^*\|_2$',
         f'A-Sweep Sanity Check: Mean (Layer {layer_id})'),
        (ax_cov, r'Cov Error $\|\hat{\Sigma}_n - \Sigma^*\|_F$',
         f'A-Sweep Sanity Check: Covariance (Layer {layer_id})'),
    )
    for ax, y_label, title in panel_text:
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Sample size $n$', fontweight='medium')
        ax.set_ylabel(y_label, fontweight='medium')
        ax.set_title(title, fontweight='bold')
        ax.legend(loc='upper right', fontsize=9, framealpha=0.95)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    out_path = os.path.join(save_dir, f"A_sweep_sanity_layer{layer_id}.png")
    fig.savefig(out_path, dpi=150, bbox_inches='tight', facecolor='white')
    print(f"Saved A-sweep sanity plot to {out_path}")

    return fig


def plot_rate_vs_A(results_by_A: List[Dict], A_real: float, theoretical_rate: float,
                   max_eigen: float, layer_id: int, save_dir: str = "results"):
    """Plot empirical convergence rate vs A, with theoretical curve.

    Markers:

    - Mean-error rates ``|slope_m|`` are circles, or a square for the entry matching ``A_real``.
    - Covariance-error rates ``|slope_c|`` are triangles, shifted by 5% in A so they do not overlap.

    The theoretical curve is ``0.5 / (1 + 32 * max_eigen * A^2)``, drawn over
    ``[0.5 * min A, 2 * max A]``. The figure is saved as
    ``<save_dir>/rate_vs_A_layer<layer_id>.png``.

    Args:
        results_by_A: One dict per A value with keys 'A_target', 'slope_m_mean',
            'slope_m_se', 'slope_c_mean', 'slope_c_se'.
        A_real: The model's actual A; highlighted and marked with a vertical line.
        theoretical_rate: Accepted for interface compatibility; not used here.
        max_eigen: Largest eigenvalue entering the theoretical curve.
        layer_id: Layer index used in the title and file name.
        save_dir: Output directory; created if missing.

    Returns:
        The matplotlib Figure. It is not closed here.

    Raises:
        ValueError: If ``results_by_A`` is empty.
    """
    if not results_by_A:
        raise ValueError("plot_rate_vs_A: results_by_A is empty; nothing to plot")

    os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Theoretical curve, extending a factor of 2 beyond the swept A range on each side.
    A_targets = [r['A_target'] for r in results_by_A]
    A_range = np.geomspace(min(A_targets) * 0.5, max(A_targets) * 2, 200)
    const = 32 * max_eigen
    rate_curve = 0.5 / (1 + const * A_range**2)

    ax.plot(A_range, rate_curve, '-', color=COLORS['dark'], linewidth=2.5,
            label='Theoretical: $0.5/(1 + 32\\lambda_{max} A^2)$', zorder=1)

    # Empirical points; remember each error-bar top so the y-range can grow if needed.
    bar_tops = []
    cov_in_legend = False
    for res in results_by_A:
        A_tgt = res['A_target']
        is_real = np.isclose(A_tgt, A_real, rtol=1e-6)

        color = COLORS['real_A'] if is_real else COLORS['accent']
        marker = 's' if is_real else 'o'
        ms = 12 if is_real else 8

        rate_m, se_m = abs(res['slope_m_mean']), res['slope_m_se']
        rate_c, se_c = abs(res['slope_c_mean']), res['slope_c_se']

        # Mean rate
        ax.errorbar(A_tgt, rate_m, yerr=se_m,
                    fmt=marker, color=color, markersize=ms, capsize=5, capthick=2,
                    label=f"Mean (A={A_tgt:.2e})" + (" [real]" if is_real else ""),
                    markeredgecolor='white', markeredgewidth=1, zorder=5)

        # Cov rate (slightly offset for visibility); a single legend entry covers all triangles.
        ax.errorbar(A_tgt * 1.05, rate_c, yerr=se_c,
                    fmt='^', color=COLORS['secondary'] if not is_real else COLORS['success'],
                    markersize=ms-2, capsize=5, capthick=2, alpha=0.8,
                    label=None if cov_in_legend else "Cov rate",
                    markeredgecolor='white', markeredgewidth=1, zorder=4)
        cov_in_legend = True

        bar_tops.extend([rate_m + se_m, rate_c + se_c])

    ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Rate = 0.5')
    ax.axvline(x=A_real, color=COLORS['real_A'], linestyle=':', alpha=0.7,
               label=f'$A_{{real}}$ = {A_real:.2e}')

    ax.set_xscale('log')
    ax.set_xlabel('Attention Parameter $A = \\|\\Sigma^{1/2} K^T Q\\|$', fontweight='medium')
    ax.set_ylabel('Convergence Rate $|\\alpha|$', fontweight='medium')
    ax.set_title(f'Rate vs A (Layer {layer_id})', fontweight='bold', pad=10)
    ax.legend(loc='best', fontsize=9, framealpha=0.95)
    ax.grid(True, alpha=0.3)

    # Keep the usual 0-0.6 window, but grow it when a point or its error bar
    # would otherwise be silently clipped off the top of the axes.
    finite_tops = [float(v) for v in bar_tops if np.isfinite(v)]
    y_upper = 0.6
    if finite_tops and max(finite_tops) > y_upper:
        y_upper = 1.05 * max(finite_tops)
    ax.set_ylim(0, y_upper)

    plt.tight_layout()

    save_path = os.path.join(save_dir, f"rate_vs_A_layer{layer_id}.png")
    fig.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
    print(f"Saved rate vs A plot to {save_path}")

    return fig


# ============== POLYNOMIAL FIT ==============

def polynomial_fit(n_values, mean_err_means, mean_err_stds, cov_err_means, cov_err_stds,
                   var_log_mean=None, var_log_cov=None, sanity_checks=True, save_dir=None,
                   k=10):
    """Fit convergence rates with full covariance estimation.

    Fits ``log(err) = intercept + slope * log(n)`` separately for the mean and
    covariance errors with ``wls_fit_full``, using the per-point variances
    ``var_log_*`` as weights. Prints both slopes and optionally builds the
    convergence figure.

    Args:
        n_values: Sample sizes.
        mean_err_means, mean_err_stds: Mean-error mean/std per sample size.
        cov_err_means, cov_err_stds: Cov-error mean/std per sample size.
        var_log_mean, var_log_cov: Variance of ``log(error mean)`` per sample
            size. If None, they are derived from the stds via the delta method
            using ``k``.
        sanity_checks: If True, build the convergence figure and close it. It is
            saved only when ``save_dir`` is given.
        save_dir: Directory for ``convergence.png``.
        k: MC repetitions behind each std. Used only for the fallback variances
            above. Defaults to 10, the value previously hard-coded here.

    Returns:
        ``(slope_m, slope_m_se, slope_c, slope_c_se, mean_fit, cov_fit)``.
        Each ``*_fit`` is the ``(slope, intercept, cov)`` triple from
        ``wls_fit_full``, and each slope SE is ``sqrt(max(cov[1, 1], 0))``.

    Raises:
        ValueError: If a fallback variance is needed and ``k < 1``.
    """
    n_values = np.asarray(n_values)
    mean_err_means = np.asarray(mean_err_means, dtype=float)
    mean_err_stds = np.asarray(mean_err_stds, dtype=float)
    cov_err_means = np.asarray(cov_err_means, dtype=float)
    cov_err_stds = np.asarray(cov_err_stds, dtype=float)
    log_n = np.log(n_values)

    def _delta_log_var(err_means, err_stds):
        # Delta method: Var[log m] ~= (SE / m)^2 with SE = std / sqrt(k);
        # the mean is floored so the ratio stays finite.
        if k < 1:
            raise ValueError(f"k must be >= 1 to derive log-variances, got {k}")
        se = err_stds / np.sqrt(k)
        return (se / np.maximum(err_means, 1e-12))**2

    # Use provided variances or compute from stds
    if var_log_mean is None:
        var_log_mean = _delta_log_var(mean_err_means, mean_err_stds)
    if var_log_cov is None:
        var_log_cov = _delta_log_var(cov_err_means, cov_err_stds)

    slope_m, int_m, cov_m = wls_fit_full(log_n, np.log(np.maximum(mean_err_means, 1e-12)), var_log_mean)
    slope_c, int_c, cov_c = wls_fit_full(log_n, np.log(np.maximum(cov_err_means, 1e-12)), var_log_cov)

    slope_m_se = np.sqrt(max(cov_m[1, 1], 0))
    slope_c_se = np.sqrt(max(cov_c[1, 1], 0))

    print(f"Mean error convergence: O(n^{slope_m:.3f} +/- {slope_m_se:.3f})")
    print(f"Cov error convergence: O(n^{slope_c:.3f} +/- {slope_c_se:.3f})")

    if sanity_checks:
        fig = plot_convergence_sexy(
            n_values, mean_err_means, mean_err_stds, cov_err_means, cov_err_stds,
            mean_fit=(slope_m, int_m, cov_m),
            cov_fit=(slope_c, int_c, cov_c),
            save_path=os.path.join(save_dir, "convergence.png") if save_dir else None
        )
        plt.close(fig)  # Close figure to avoid blocking and free memory

    return slope_m, slope_m_se, slope_c, slope_c_se, (slope_m, int_m, cov_m), (slope_c, int_c, cov_c)


# ============== SAVE FUNCTIONS ==============

def save_layer_data(results, layer_id, output_dir="results/data"):
    """Save MC experiment data for one layer to CSV.

    Writes ``<output_dir>/layer_<layer_id>_results.csv`` without the index column.

    Args:
        results: Anything ``pandas.DataFrame`` accepts, e.g. a dict of
            equal-length arrays such as the one returned by ``full_monte_carlo``.
        layer_id: Layer index used in the file name.
        output_dir: Target directory; created if missing.

    Returns:
        The path of the written CSV. Callers that ignore the return value are unaffected.
    """
    os.makedirs(output_dir, exist_ok=True)

    # os.path.join avoids a doubled separator when output_dir ends with "/"
    # and uses the platform's separator.
    path = os.path.join(output_dir, f"layer_{layer_id}_results.csv")
    pd.DataFrame(results).to_csv(path, index=False)
    print(f"Saved layer {layer_id} data to {path}")
    return path


# ============== RUN EXPERIMENT ==============

def run_experiment(N=1, layer_id=0, n_min=100, n_max=500, nb_tot=15, k=10,
                   replace=False, tiling=None, plot=True, plot_layer=False,
                   A_sweep=False, A_cap=1e4, n_A_points=5, sanity_plot=False,
                   output_dir="results", add_timestamp=True,
                   # Model size parameter
                   model_size="base",
                   device_map: Optional[str] = None,
                   # Text source parameters
                   text_source="wiki", text_language="en",
                   text_mix_config=None, text_mix_strategy="concat",
                   text_shuffle="none", text_experiment_config=None,
                   text_language_mix_config=None,
                   text_repeat_config=None,
                   aggressive_cleanup=False):
    """
    Run full BigBird attention convergence experiment.

    Args:
        N: Multiplier for max_length (max_length = 4096 * N)
        layer_id: Layer to analyze (0-11 for base, 0-23 for large)
        n_min, n_max: Sample size range
        nb_tot: Number of sample sizes (log-spaced)
        k: Monte Carlo repetitions per sample size
        replace: Sample with replacement
        tiling: Use FlashAttention (None = auto-detect GPU)
        plot: Show convergence plots
        plot_layer: Show theoretical vs empirical rate plot
        A_sweep: If True, sweep over multiple A values
        A_cap: Upper bound for A sweep
        n_A_points: Number of A values in sweep
        sanity_plot: Generate sanity check plots for A-sweep
        output_dir: Directory to save results
        device_map: Optional HF device_map for sharded multi-GPU inference
        add_timestamp: Add timestamp to output directory to prevent overwrites
        model_size: "base" (768 hidden, 12 heads) or "large" (1024 hidden, 16 heads)

        Text source parameters:
        text_source: Source of text data ("wiki", "news", "books", "code",
                     "scientific", "openwebtext", "mr_niah", "mixed")
        text_language: Language code for wiki/mr_niah ("en", "de", "fr", "es", "zh", "ar", "ru", "ja")
            or "mixed" for mixing multiple languages
        text_mix_config: For mixed sources, dict of source->proportion
        text_mix_strategy: How to mix sources/languages ("concat", "interleave", "random")
        text_shuffle: Shuffle level ("none", "sentence", "paragraph", "word")
        text_experiment_config: Use predefined config (overrides other text params)
        text_language_mix_config: For language="mixed", dict of lang->proportion
            Example: {"en": 0.5, "de": 0.5}
        aggressive_cleanup: If True, free as much memory as possible after saving outputs
    """
    if tiling is None:
        tiling = cuda_available()

    max_length = 4096 * N

    # Get model configuration
    model_config = get_model_config(model_size)
    attention_head_size = model_config["attention_head_size"]
    all_head_size = model_config["hidden_size"]
    num_attention_heads = model_config["num_attention_heads"]

    # Add timestamp to output directory to prevent overwrites
    if add_timestamp:
        timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(output_dir, f"layer{layer_id}_{timestamp}")

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    # ============== CHECK DEVICE ==============
    print(f"{'='*60}")
    print(f"BigBird Attention Convergence Experiment")
    print(f"{'='*60}")
    print(f"Device: {device}")
    if cuda_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Configuration:")
    print(f"  N = {N} -> max_length = {max_length}")
    print(f"  layer_id = {layer_id}")
    print(f"  n_min = {n_min}, n_max = {n_max}, nb_tot = {nb_tot}, k = {k}")
    print(f"  A_sweep = {A_sweep}" + (f" (A_cap={A_cap:.2e}, n_points={n_A_points})" if A_sweep else ""))
    print(f"  Model = BigBird-RoBERTa-{model_size} (hidden={all_head_size}, heads={num_attention_heads})")
    print(f"  Text source = {text_source} ({text_language})")
    if text_shuffle != "none":
        print(f"  Text shuffle = {text_shuffle}")
    if text_language_mix_config:
        print(f"  Language mix = {text_language_mix_config}")
    if text_repeat_config:
        print(f"  Text repeat = {text_repeat_config}")
    if text_experiment_config:
        print(f"  Using experiment config: {text_experiment_config}")
    print(f"{'='*60}\n")

    # ============== LOAD MODEL & DATA ==============
    print("Loading model...")
    model = load_model(
        max_length=max_length,
        #attention_type='original_full',
        attention_type= "block_sparse",
        model_size=model_size,
        device_map=device_map,
    )

    print("Loading dataset...")
    text = load_text_flexible(
        max_length=max_length,
        source=text_source,
        language=text_language,
        mix_config=text_mix_config,
        mix_strategy=text_mix_strategy,
        shuffle=text_shuffle,
        experiment_config=text_experiment_config,
        language_mix_config=text_language_mix_config,
        repeat_config=text_repeat_config,
        verbose=True
    )

    print("Tokenizing...")
    tokens = tokenize(
        text,
        max_length=max_length,
        model_size=model_size,
        device_override=MODEL_INPUT_DEVICE,
    )
    actual_tokens = tokens['input_ids'].shape[1]
    print(f"Tokenized {actual_tokens} tokens (requested {max_length})")

    if actual_tokens < max_length:
        print(f"WARNING: Only got {actual_tokens}/{max_length} tokens!")
    if n_max > actual_tokens * 0.1:
        print(f"WARNING: n_max={n_max} is >10% of total tokens ({actual_tokens})")

    # ============== GET EMBEDDINGS ==============
    print(f"\nGetting embeddings for layer {layer_id}...")
    if layer_id == 0:
        X = get_layer0_embeddings_direct(model, tokens)
    else:
        X = get_embeddings_chunked(model, tokens, layer_id)
    print(f"Embeddings shape: {X.shape}")

    # ============== COMPUTE A_real AND THEORETICAL QUANTITIES ==============
    print("\nComputing attention parameters...")
    A_real, A_mat = get_attention_parameter(model, X, layer_id)
    theoretical_rate, _, H = get_theoretical_rate(model, X, layer_id)
    _, _, max_eigen = stats_full_distribution(X)

    print(f"A_real (||S^{{1/2}} K^T Q||): {A_real:.4e}")
    print(f"Theoretical rate: {theoretical_rate.item():.6f}")
    print(f"H (horizon): {H.item():.4f}")
    print(f"Max eigenvalue: {max_eigen.item():.4f}")

    # ============== A-SWEEP OR SINGLE RUN ==============
    mc_device = None
    if mc_cuda_available():
        mc_device = torch.device("cuda")

    if A_sweep:
        A_targets = make_A_targets(A_real, A_cap=A_cap, n_points=n_A_points)
        print(f"\nA-sweep targets: {[f'{a:.2e}' for a in A_targets]}")

        results_by_A = []

        for j, A_tgt in enumerate(A_targets):
            gamma = A_tgt / max(A_real, 1e-12)
            # temperature_scale = gamma: scales A by gamma (equivalent to scaling Q OR K by gamma)
            temperature_scale = gamma

            print(f"\n--- A target {j+1}/{len(A_targets)}: A={A_tgt:.2e} (gamma={gamma:.2f}, temperature_scale={temperature_scale:.3f}) ---")

            # Compute limits for this A value using temperature scaling (numerically stable)
            Q, K, V = get_qkv_layers(model, X, layer_id)

            if actual_tokens > 4096:
                # Disable fp16 for very long sequences to avoid numerical accumulation errors
                use_fp16 = actual_tokens < 50000
                if limits_cuda_available():
                    print("  Using GPU offload for FULL attention limits...")
                    try:
                        mean_lim, cov_lim = chunked_attention_cpu_offload(
                            Q, K, V,
                            attention_head_size=attention_head_size,
                            all_head_size=all_head_size,
                            temperature_scale=temperature_scale,
                            compute_device="cuda",
                        )
                    except RuntimeError as e:
                        if _is_cuda_oom(e):
                            print("  CUDA OOM during limits; falling back to CPU.")
                            mean_lim, cov_lim = chunked_attention_memory_efficient(
                                Q, K, V,
                                attention_head_size=attention_head_size,
                                all_head_size=all_head_size,
                                use_fp16=use_fp16,
                                temperature_scale=temperature_scale,
                            )
                        else:
                            raise
                else:
                    mean_lim, cov_lim = chunked_attention_memory_efficient(
                        Q, K, V,
                        attention_head_size=attention_head_size,
                        all_head_size=all_head_size,
                        use_fp16=use_fp16,
                        temperature_scale=temperature_scale,
                    )
            else:
                attn_full = attentionOperator(Q, K, V, tiling=tiling, temperature_scale=temperature_scale,
                                              attention_head_size=attention_head_size,
                                              all_head_size=all_head_size,
                                              num_attention_heads=num_attention_heads)
                mean_lim, cov_lim = get_mean_cov(attn_full)
                del attn_full

            # Check for NaN in limit statistics
            if torch.isnan(mean_lim).any() or torch.isnan(cov_lim).any():
                print(f"WARNING: NaN detected in limit statistics for A={A_tgt:.2e}!")
                print("  Skipping this A value.")
                del Q, K, V
                if cuda_available():
                    torch.cuda.empty_cache()
                continue
            del Q, K, V
            if cuda_available():
                torch.cuda.empty_cache()

            # Run MC with temperature scaling
            if mc_device is not None:
                model.to(mc_device)
            mc_results = full_monte_carlo(
                model, X, n_min, n_max, nb_tot, k,
                mean_lim, cov_lim, layer_id, replace=replace, tiling=tiling,
                temperature_scale=temperature_scale, desc=f"MC (A={A_tgt:.1e})",
                attention_head_size=attention_head_size, all_head_size=all_head_size,
                num_attention_heads=num_attention_heads, mc_device=mc_device
            )

            # Fit
            slope_m, slope_m_se, slope_c, slope_c_se, mean_fit, cov_fit = polynomial_fit(
                mc_results['n_vals'], mc_results['mean_err_mean'], mc_results['mean_err_std'],
                mc_results['cov_err_mean'], mc_results['cov_err_std'],
                mc_results['var_log_mean'], mc_results['var_log_cov'],
                sanity_checks=False
            )

            results_by_A.append({
                'A_target': A_tgt,
                'A_real': A_real,
                'is_real_A': np.isclose(A_tgt, A_real, rtol=1e-6),
                'gamma': gamma,
                'temperature_scale': temperature_scale,
                'slope_m_mean': slope_m,
                'slope_m_se': slope_m_se,
                'slope_c_mean': slope_c,
                'slope_c_se': slope_c_se,
                'n_vals': mc_results['n_vals'],
                'mean_err_mean': mc_results['mean_err_mean'],
                'mean_err_std': mc_results['mean_err_std'],
                'cov_err_mean': mc_results['cov_err_mean'],
                'cov_err_std': mc_results['cov_err_std'],
                'mean_fit': mean_fit,
                'cov_fit': cov_fit,
            })

        # Save results - CSV summary
        save_df = pd.DataFrame([{
            'A_target': r['A_target'],
            'A_real': r['A_real'],
            'is_real_A': r['is_real_A'],
            'gamma': r['gamma'],
            'slope_m_mean': r['slope_m_mean'],
            'slope_m_se': r['slope_m_se'],
            'slope_c_mean': r['slope_c_mean'],
            'slope_c_se': r['slope_c_se'],
        } for r in results_by_A])
        save_df.to_csv(os.path.join(output_dir, f"A_sweep_summary.csv"), index=False)

        # Save hyperparameters to CSV
        hyperparams_df = pd.DataFrame([{
            'N': N,
            'max_length': max_length,
            'actual_tokens': actual_tokens,
            'layer_id': layer_id,
            'n_min': n_min,
            'n_max': n_max,
            'nb_tot': nb_tot,
            'k': k,
            'A_cap': A_cap,
            'n_A_points': n_A_points,
            'A_real': A_real,
            'theoretical_rate': theoretical_rate.item(),
            'H': H.item(),
            'max_eigen': max_eigen.item(),
            'replace': replace,
            'tiling': tiling,
            # Text source parameters
            'text_source': text_source,
            'text_language': text_language,
            'text_shuffle': text_shuffle,
            'text_experiment_config': text_experiment_config,
            'text_mix_config': str(text_mix_config) if text_mix_config else None,
            'text_mix_strategy': text_mix_strategy,
            'text_language_mix_config': str(text_language_mix_config) if text_language_mix_config else None,
            'text_repeat_config': str(text_repeat_config) if text_repeat_config else None,
        }])
        hyperparams_df.to_csv(os.path.join(output_dir, "hyperparameters.csv"), index=False)
        print(f"Saved hyperparameters to {os.path.join(output_dir, 'hyperparameters.csv')}")

        # Save FULL data for re-plotting (pickle format)
        import pickle
        full_data = {
            'results_by_A': results_by_A,
            'A_real': A_real,
            'theoretical_rate': theoretical_rate.item(),
            'max_eigen': max_eigen.item(),
            'H': H.item(),
            'layer_id': layer_id,
            'N': N,
            'max_length': max_length,
            'actual_tokens': actual_tokens,
            'n_min': n_min,
            'n_max': n_max,
            'nb_tot': nb_tot,
            'k': k,
            'A_cap': A_cap,
            'config': {
                'N': N, 'layer_id': layer_id, 'n_min': n_min, 'n_max': n_max,
                'nb_tot': nb_tot, 'k': k, 'A_cap': A_cap, 'n_A_points': n_A_points
            }
        }
        pickle_path = os.path.join(output_dir, f"full_results.pkl")
        with open(pickle_path, 'wb') as f:
            pickle.dump(full_data, f)
        print(f"Saved full data for re-plotting to: {pickle_path}")

        # Plots
        if sanity_plot:
            plot_A_sweep_sanity(results_by_A, A_real, layer_id,
                               save_dir=os.path.join(output_dir, "sanity_plots"))

        if plot_layer:
            plot_rate_vs_A(results_by_A, A_real, theoretical_rate.item(),
                          max_eigen.item(), layer_id, save_dir=output_dir)

        # Also plot the real A result
        real_A_result = next((r for r in results_by_A if r['is_real_A']), results_by_A[0])
        if plot:
            plot_convergence_sexy(
                real_A_result['n_vals'], real_A_result['mean_err_mean'], real_A_result['mean_err_std'],
                real_A_result['cov_err_mean'], real_A_result['cov_err_std'],
                real_A_result['mean_fit'], real_A_result['cov_fit'],
                title_suffix=f"(Layer {layer_id}, A={A_real:.2e})",
                save_path=os.path.join(output_dir, f"convergence_layer{layer_id}.png")
            )

        # ============== CLEANUP MEMORY ==============
        if aggressive_cleanup:
            del results_by_A, save_df, hyperparams_df, full_data
            del real_A_result
        if mc_device is not None:
            model.to("cpu")
        del model, X, tokens
        gc.collect()
        if cuda_available():
            torch.cuda.empty_cache()
        print(f"Memory cleaned up for layer {layer_id}")

        return {
            'results_by_A': results_by_A,
            'A_real': A_real,
            'theoretical_rate': theoretical_rate.item(),
            'max_eigen': max_eigen.item(),
            'layer_id': layer_id,
            'output_dir': output_dir,
        }

    else:
        # ============== SINGLE RUN (no A-sweep) ==============
        print("\nComputing FULL attention limits...")
        Q_full, K_full, V_full = get_qkv_layers(model, X, layer_id)

        if actual_tokens > 4096:
            # Disable fp16 for very long sequences to avoid numerical accumulation errors
            use_fp16 = actual_tokens < 50000
            if not use_fp16:
                print(f"  Using fp32 for {actual_tokens} tokens (fp16 disabled for long sequences)")
            print(f"  Using chunked FULL attention for {actual_tokens} tokens...")
            if limits_cuda_available():
                print("  Using GPU offload for FULL attention limits...")
                try:
                    mean_lim, cov_lim = chunked_attention_cpu_offload(
                        Q_full, K_full, V_full,
                        attention_head_size=attention_head_size,
                        all_head_size=all_head_size,
                        compute_device="cuda",
                    )
                except RuntimeError as e:
                    if _is_cuda_oom(e):
                        print("  CUDA OOM during limits; falling back to CPU.")
                        mean_lim, cov_lim = chunked_attention_memory_efficient(
                            Q_full, K_full, V_full,
                            attention_head_size=attention_head_size,
                            all_head_size=all_head_size,
                            use_fp16=use_fp16,
                        )
                    else:
                        raise
            else:
                mean_lim, cov_lim = chunked_attention_memory_efficient(
                    Q_full, K_full, V_full,
                    attention_head_size=attention_head_size,
                    all_head_size=all_head_size,
                    use_fp16=use_fp16,
                )
        else:
            attn_full = attentionOperator(Q_full, K_full, V_full, tiling=tiling,
                                          attention_head_size=attention_head_size,
                                          all_head_size=all_head_size,
                                          num_attention_heads=num_attention_heads)
            mean_lim, cov_lim = get_mean_cov(attn_full)
            print(f"Full attention output shape: {attn_full.shape}")
            del attn_full

        # Check for NaN in limit statistics
        if torch.isnan(mean_lim).any() or torch.isnan(cov_lim).any():
            print("WARNING: NaN detected in limit statistics!")
            print("  This may be due to numerical issues with very long sequences.")
            print("  Try reducing N or check if embeddings contain extreme values.")
            raise ValueError("NaN detected in limit statistics. Cannot proceed with MC experiments.")

        del Q_full, K_full, V_full
        if cuda_available():
            torch.cuda.empty_cache()

        # ============== MONTE CARLO EXPERIMENTS ==============
        print(f"\nRunning MC experiments...")
        if mc_device is not None:
            model.to(mc_device)
        mc_results = full_monte_carlo(
            model, X, n_min, n_max, nb_tot, k,
            mean_lim, cov_lim, layer_id, replace=replace, tiling=tiling,
            attention_head_size=attention_head_size, all_head_size=all_head_size,
            num_attention_heads=num_attention_heads, mc_device=mc_device
        )

        # ============== FIT & PLOT ==============
        print("\nFitting convergence rates...")
        slope_m, slope_m_se, slope_c, slope_c_se, mean_fit, cov_fit = polynomial_fit(
            mc_results['n_vals'], mc_results['mean_err_mean'], mc_results['mean_err_std'],
            mc_results['cov_err_mean'], mc_results['cov_err_std'],
            mc_results['var_log_mean'], mc_results['var_log_cov'],
            sanity_checks=plot, save_dir=output_dir
        )

        # ============== PLOT LAYER (theoretical vs empirical) ==============
        if plot_layer:
            fig = plt.figure(figsize=(10, 6))
            ax = fig.add_subplot(111)

            const = 32 * max_eigen.item()
            h_range = np.logspace(np.log10(H.item()) - 3, np.log10(H.item()) + 2, 100)
            rate_curve = 0.5 / (1 + const * h_range**2)

            ax.plot(h_range, rate_curve, color=COLORS['dark'], linewidth=2.5,
                   label="Theoretical curve")

            ax.errorbar(H.item(), abs(slope_m), yerr=slope_m_se,
                       fmt='o', color=COLORS['primary'], markersize=12, capsize=6,
                       label=f"Mean (empirical): {slope_m:.3f}$\\pm${slope_m_se:.3f}",
                       markeredgecolor='white', markeredgewidth=1)
            ax.errorbar(H.item(), abs(slope_c), yerr=slope_c_se,
                       fmt='s', color=COLORS['secondary'], markersize=12, capsize=6,
                       label=f"Cov (empirical): {slope_c:.3f}$\\pm${slope_c_se:.3f}",
                       markeredgecolor='white', markeredgewidth=1)

            ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Rate = 0.5')

            ax.legend(loc='best', framealpha=0.95)
            ax.set_xlabel("$H$ (horizon scale)", fontweight='medium')
            ax.set_ylabel("Convergence rate $|\\alpha|$", fontweight='medium')
            ax.set_title(f"Theoretical vs Empirical Rate - Layer {layer_id}", fontweight='bold')
            ax.set_xscale('log')
            ax.grid(True, alpha=0.3)

            plt.tight_layout()
            save_path = os.path.join(output_dir, f"rate_vs_H_layer{layer_id}.png")
            fig.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
            plt.close(fig)  # Close figure to avoid blocking

        print(f"\n{'='*60}")
        print(f"RESULTS:")
        print(f"  Mean error convergence: O(n^{slope_m:.3f} +/- {slope_m_se:.3f})")
        print(f"  Cov error convergence:  O(n^{slope_c:.3f} +/- {slope_c_se:.3f})")
        print(f"  Theoretical rate:       {theoretical_rate.item():.6f}")
        print(f"{'='*60}")

        # Save hyperparameters to CSV
        hyperparams_df = pd.DataFrame([{
            'N': N,
            'max_length': max_length,
            'actual_tokens': actual_tokens,
            'layer_id': layer_id,
            'n_min': n_min,
            'n_max': n_max,
            'nb_tot': nb_tot,
            'k': k,
            'A_real': A_real,
            'theoretical_rate': theoretical_rate.item(),
            'H': H.item(),
            'max_eigen': max_eigen.item(),
            'slope_mean': slope_m,
            'slope_mean_se': slope_m_se,
            'slope_cov': slope_c,
            'slope_cov_se': slope_c_se,
            'replace': replace,
            'tiling': tiling,
            # Model parameters
            'model_size': model_size,
            'hidden_size': all_head_size,
            'num_attention_heads': num_attention_heads,
            'attention_head_size': attention_head_size,
            # Text source parameters
            'text_source': text_source,
            'text_language': text_language,
            'text_shuffle': text_shuffle,
            'text_experiment_config': text_experiment_config,
            'text_mix_config': str(text_mix_config) if text_mix_config else None,
            'text_mix_strategy': text_mix_strategy,
            'text_language_mix_config': str(text_language_mix_config) if text_language_mix_config else None,
            'text_repeat_config': str(text_repeat_config) if text_repeat_config else None,
        }])
        hyperparams_df.to_csv(os.path.join(output_dir, "hyperparameters.csv"), index=False)
        print(f"Saved hyperparameters to {os.path.join(output_dir, 'hyperparameters.csv')}")

        # ============== CLEANUP MEMORY (non-A-sweep) ==============
        # Move results to CPU and free GPU memory before returning
        mean_lim_cpu = mean_lim.cpu()
        cov_lim_cpu = cov_lim.cpu()
        theoretical_rate_val = theoretical_rate.item() if hasattr(theoretical_rate, 'item') else theoretical_rate
        H_val = H.item() if hasattr(H, 'item') else H
        max_eigen_val = max_eigen.item() if hasattr(max_eigen, 'item') else max_eigen

        # Save convergence data for re-plotting (CSV + pickle)
        convergence_df = pd.DataFrame({
            "n_vals": mc_results["n_vals"],
            "mean_err_mean": mc_results["mean_err_mean"],
            "mean_err_std": mc_results["mean_err_std"],
            "cov_err_mean": mc_results["cov_err_mean"],
            "cov_err_std": mc_results["cov_err_std"],
            "var_log_mean": mc_results["var_log_mean"],
            "var_log_cov": mc_results["var_log_cov"],
        })
        convergence_csv = os.path.join(output_dir, "convergence_data.csv")
        convergence_df.to_csv(convergence_csv, index=False)
        print(f"Saved convergence data to {convergence_csv}")

        import pickle
        full_data = {
            "mc_results": mc_results,
            "A_real": A_real,
            "theoretical_rate": theoretical_rate_val,
            "H": H_val,
            "max_eigen": max_eigen_val,
            "layer_id": layer_id,
            "N": N,
            "max_length": max_length,
            "actual_tokens": actual_tokens,
            "n_min": n_min,
            "n_max": n_max,
            "nb_tot": nb_tot,
            "k": k,
            "config": {
                "N": N,
                "layer_id": layer_id,
                "n_min": n_min,
                "n_max": n_max,
                "nb_tot": nb_tot,
                "k": k,
            },
        }
        pickle_path = os.path.join(output_dir, "full_results.pkl")
        with open(pickle_path, "wb") as f:
            pickle.dump(full_data, f)
        print(f"Saved full data for re-plotting to: {pickle_path}")

        # Build return payload before optional aggressive cleanup
        return_payload = {
            'actual_tokens': actual_tokens,
            'n_vals': mc_results['n_vals'],
            'mean_err_mean': mc_results['mean_err_mean'],
            'mean_err_std': mc_results['mean_err_std'],
            'cov_err_mean': mc_results['cov_err_mean'],
            'cov_err_std': mc_results['cov_err_std'],
            'slope_mean': slope_m,
            'slope_mean_std': slope_m_se,
            'slope_cov': slope_c,
            'slope_cov_std': slope_c_se,
            'theoretical_rate': theoretical_rate_val,
            'H': H_val,
            'max_eigen': max_eigen_val,
            'A_real': A_real,
            'layer_id': layer_id,
            'max_length': max_length,
            'mean_lim': mean_lim_cpu,
            'cov_lim': cov_lim_cpu,
            'output_dir': output_dir,
        }

        if aggressive_cleanup:
            del mc_results, hyperparams_df
        if mc_device is not None:
            model.to("cpu")
        del model, X, tokens, mean_lim, cov_lim
        gc.collect()
        if cuda_available():
            torch.cuda.empty_cache()
        print(f"Memory cleaned up for layer {layer_id}")

        return return_payload


# ============== CLI ==============

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser for the convergence experiment and parse arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing an explicit list makes
            the parser usable from notebooks and tests.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="BigBird Attention Convergence Experiment with A-sweep")

    parser.add_argument("--N", type=int, default=8, help="Multiplier for max_length (4096 * N)")
    parser.add_argument("--layer_id", type=int, default=0,
                        help="Layer to analyze (0-11 for base, 0-23 with --bigbird_large)")
    parser.add_argument("--n_min", type=int, default=100, help="Minimum sample size")
    parser.add_argument("--n_max", type=int, default=2000, help="Maximum sample size")
    parser.add_argument("--nb_tot", type=int, default=15, help="Number of sample sizes")
    parser.add_argument("--k", type=int, default=10, help="MC repetitions per sample size")
    parser.add_argument("--replace", action="store_true", help="Sample with replacement")
    parser.add_argument("--no_plot", action="store_true", help="Disable convergence plots")
    parser.add_argument("--plot_layer", action="store_true", help="Show theoretical vs empirical rate plot")
    parser.add_argument("--aggressive_cleanup", action="store_true",
                        help="Free as much memory as possible after saving outputs")

    # A-sweep options
    parser.add_argument("--A_sweep", action="store_true", help="Enable A-sweep experiment")
    parser.add_argument("--A_cap", type=float, default=1e4, help="Upper bound for A sweep")
    parser.add_argument("--n_A_points", type=int, default=5, dest="n_A_points", help="Number of A values in sweep")
    parser.add_argument("--sanity_plot", action="store_true", help="Generate sanity check plots")

    # Model size option
    parser.add_argument("--bigbird_large", action="store_true",
                        help="Use BigBird-RoBERTa-Large (1024 hidden, 16 heads) instead of Base (768 hidden, 12 heads)")
    parser.add_argument("--device_map", type=str, default=None,
                        help="HF device_map for sharded multi-GPU inference (e.g., 'auto')")

    parser.add_argument("--output_dir", type=str, default="results", help="Output directory")
    parser.add_argument("--no_timestamp", action="store_true", help="Don't add timestamp to output dir")
    parser.add_argument("--force_cpu", action="store_true",
                        help="Force CPU for the entire run (disables CUDA even if available)")
    parser.add_argument("--mc_on_gpu", action="store_true",
                        help="Run Monte Carlo on GPU even if main compute is on CPU")
    parser.add_argument("--limits_on_gpu", action="store_true",
                        help="Run full-attention limits with GPU matmuls via CPU offload")

    # Multi-layer options
    parser.add_argument("--all_layers", action="store_true",
                        help="Run all layers sequentially (0-11 for base, 0-23 with --bigbird_large)")
    parser.add_argument("--layers", type=str, default=None,
                        help="Comma-separated list of layers to run (e.g., '1,2,3' or '1-5' or '1-3,7,9-11')")

    # Text source options
    parser.add_argument("--text_source", type=str, default="wiki",
                        choices=["wiki", "news", "books", "code", "scientific", "openwebtext", "mr_niah", "mixed"],
                        help="Text source (default: wiki)")
    parser.add_argument("--text_language", type=str, default="en",
                        choices=["en", "de", "fr", "es", "zh", "ar", "ru", "ja", "mixed"],
                        help="Language for wiki/mr_niah source, or 'mixed' for multi-language (default: en)")
    parser.add_argument("--text_shuffle", type=str, default="none",
                        choices=["none", "sentence", "paragraph", "word"],
                        help="Shuffle level for ablation studies (default: none)")
    parser.add_argument("--text_mix", type=str, default=None,
                        help="Mix config as 'source1:prop1,source2:prop2' (e.g., 'wiki:0.5,news:0.5')")
    parser.add_argument("--text_mix_strategy", type=str, default="concat",
                        choices=["concat", "interleave", "random"],
                        help="How to combine mixed sources/languages (default: concat)")
    parser.add_argument("--text_language_mix", type=str, default=None,
                        help="Language mix config as 'lang1:prop1,lang2:prop2' (e.g., 'en:0.5,de:0.5')")
    parser.add_argument("--text_config", type=str, default=None,
                        help="Use predefined experiment config (overrides other text params)")
    parser.add_argument("--list_text_configs", action="store_true",
                        help="List available predefined text configs and exit")
    parser.add_argument("--list_text_sources", action="store_true",
                        help="List available text sources and languages, then exit")

    return parser.parse_args(argv)


def parse_layer_range(layers_str: str) -> List[int]:
    """Parse a layer specification such as ``'1,2,3'``, ``'1-5'`` or ``'1-3,7,9-11'``.

    Each comma-separated item is either a single layer index or an inclusive
    ``start-end`` range. Whitespace around items and bounds is ignored, empty
    items (e.g. from a trailing comma) are skipped, and duplicates are removed.

    Args:
        layers_str: The specification string (typically the ``--layers`` flag).

    Returns:
        Sorted list of unique layer indices.

    Raises:
        ValueError: If an item is not an integer or an integer range, or if a
            range is reversed (``start > end``), which would otherwise silently
            select no layers.
    """
    layers = set()
    for raw_part in layers_str.split(','):
        part = raw_part.strip()
        if not part:
            continue  # tolerate trailing / doubled commas such as "1,2,"
        if '-' in part:
            start_str, _, end_str = part.partition('-')
            start, end = int(start_str), int(end_str)
            if start > end:
                raise ValueError(
                    f"Invalid layer range '{part}': start ({start}) is greater than end ({end})"
                )
            layers.update(range(start, end + 1))
        else:
            layers.add(int(part))
    return sorted(layers)


# ============== MAIN ==============

def _parse_proportion_config(mix_str: str, what: str) -> Optional[Dict[str, float]]:
    """Parse ``'name1:prop1,name2:prop2'`` into ``{name: float(prop)}`` (shared by the public parsers).

    Args:
        mix_str: The raw config string. Falsy values mean "no config".
        what: Human-readable label used in error messages (e.g. ``"text mix"``).

    Returns:
        Mapping of stripped names to float proportions, or ``None`` when
        ``mix_str`` is empty or contains no entries. A later duplicate name
        overrides an earlier one.

    Raises:
        ValueError: If an entry is not of the form ``<name>:<number>``.
    """
    if not mix_str:
        return None
    config: Dict[str, float] = {}
    for raw_part in mix_str.split(','):
        part = raw_part.strip()
        if not part:
            continue  # tolerate trailing / doubled commas
        name, sep, prop = part.partition(':')
        name = name.strip()
        if not sep or not name:
            raise ValueError(f"Invalid {what} entry '{part}': expected '<name>:<proportion>'")
        try:
            config[name] = float(prop.strip())
        except ValueError:
            raise ValueError(
                f"Invalid proportion in {what} entry '{part}': '{prop.strip()}' is not a number"
            ) from None
    return config or None


def parse_mix_config(mix_str: str) -> Optional[Dict[str, float]]:
    """Parse a source mix string like ``'wiki:0.5,news:0.5'`` into a dict.

    Returns ``None`` for an empty string; raises ``ValueError`` on malformed entries.
    """
    return _parse_proportion_config(mix_str, "text mix")


def parse_language_mix_config(mix_str: str) -> Optional[Dict[str, float]]:
    """Parse a language mix string like ``'en:0.5,de:0.5'`` into a dict.

    Returns ``None`` for an empty string; raises ``ValueError`` on malformed entries.
    """
    return _parse_proportion_config(mix_str, "language mix")


if __name__ == "__main__":
    # Script entry point: parse CLI flags, resolve text/layer configuration,
    # then call run_experiment once per requested layer.
    import traceback

    args = parse_args()

    # Handle informational list commands first. They exit immediately, so the
    # device flags below (and their log lines) do not apply to them.
    # `raise SystemExit` is used instead of the `exit()` builtin, which is
    # injected by the `site` module for the interactive shell only.
    if args.list_text_configs:
        print("\nAvailable predefined text experiment configurations:")
        print("=" * 70)
        for name, config in get_experiment_configs().items():
            print(f"  {name:25s} : {config['description']}")
        print("\nUsage: --text_config <config_name>")
        print("=" * 70)
        raise SystemExit(0)

    if args.list_text_sources:
        print("\nAvailable text sources:")
        print("=" * 70)
        for name, config in TEXT_SOURCES.items():
            print(f"  {name:15s} : {config['description']}")
        print("\nAvailable languages (for wiki source):")
        print("=" * 70)
        for lang, config in MULTILINGUAL_SOURCES.items():
            print(f"  {lang:5s} : {config['description']}")
        print("\nUsage examples:")
        print("  --text_source wiki --text_language de")
        print("  --text_source news")
        print("  --text_source mixed --text_mix 'wiki:0.5,news:0.5'")
        print("  --text_shuffle sentence")
        print("=" * 70)
        raise SystemExit(0)

    # Apply device-placement flags through the module-level set_* switches.
    if args.force_cpu:
        set_force_cpu(True)
        print("Force CPU enabled: all computation will run on CPU.")
    if args.mc_on_gpu:
        set_mc_on_gpu(True)
        print("MC-on-GPU enabled: Monte Carlo will use CUDA if available.")
    if args.limits_on_gpu:
        set_limits_on_gpu(True)
        print("Limits-on-GPU enabled: full-attention limits will use GPU offload.")

    # Determine model size; the layer count bounds valid layer ids below.
    model_size = "large" if args.bigbird_large else "base"
    num_layers = 24 if model_size == "large" else 12
    print(f"Using BigBird-RoBERTa-{model_size.capitalize()} ({num_layers} layers)")

    # Parse mix configs up front so a malformed value gives a clean CLI error
    # instead of a traceback.
    try:
        text_mix_config = parse_mix_config(args.text_mix) if args.text_mix else None
        text_language_mix_config = (
            parse_language_mix_config(args.text_language_mix) if args.text_language_mix else None
        )
    except ValueError as e:
        print(f"ERROR: invalid --text_mix / --text_language_mix value: {e}")
        raise SystemExit(2) from None

    # A source mix implies the 'mixed' text source.
    if text_mix_config and args.text_source != "mixed":
        print("Note: --text_mix provided, setting text_source to 'mixed'")
        args.text_source = "mixed"

    # A language mix implies the 'mixed' language.
    if text_language_mix_config and args.text_language != "mixed":
        print("Note: --text_language_mix provided, setting text_language to 'mixed'")
        args.text_language = "mixed"

    # Determine which layers to run
    if args.all_layers:
        layers_to_run = list(range(num_layers))  # Base: 0-11, Large: 0-23
    elif args.layers:
        try:
            layers_to_run = parse_layer_range(args.layers)
        except ValueError as e:
            print(f"ERROR: invalid --layers value {args.layers!r}: {e}")
            raise SystemExit(2) from None
    else:
        layers_to_run = [args.layer_id]

    # Validate before doing any expensive work. An empty selection would
    # silently do nothing, and ids outside [0, num_layers) are not layers of
    # this model. This includes negatives, which Python indexing would
    # otherwise accept.
    if not layers_to_run:
        print(f"ERROR: no layers selected by --layers {args.layers!r}")
        raise SystemExit(2)
    invalid_layers = [layer for layer in layers_to_run if not 0 <= layer < num_layers]
    if invalid_layers:
        print(f"ERROR: layer(s) {invalid_layers} out of range for "
              f"BigBird-RoBERTa-{model_size.capitalize()} (valid: 0-{num_layers - 1})")
        raise SystemExit(2)

    print(f"Layers to run: {layers_to_run}")

    # Multi-layer runs share one base directory, with one subdirectory per
    # layer. The base directory is timestamped unless --no_timestamp was given.
    multi_layer = len(layers_to_run) > 1
    if multi_layer:
        if args.no_timestamp:
            base_output_dir = args.output_dir
        else:
            run_timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
            base_output_dir = os.path.join(args.output_dir, f"run_{run_timestamp}")
        os.makedirs(base_output_dir, exist_ok=True)
        print(f"Multi-layer run output directory: {base_output_dir}")
    else:
        base_output_dir = args.output_dir

    all_results = []
    failed_layers = []

    for layer_id in layers_to_run:
        print(f"\n{'#'*70}")
        print(f"# RUNNING LAYER {layer_id}")
        print(f"{'#'*70}\n")

        # Per-layer subdirectories of a multi-layer run get no extra timestamp;
        # a single-layer run lets run_experiment add one unless --no_timestamp.
        if multi_layer:
            layer_output_dir = os.path.join(base_output_dir, f"layer{layer_id}")
            add_timestamp = False
        else:
            layer_output_dir = args.output_dir
            add_timestamp = not args.no_timestamp

        layer_failed = False
        try:
            results = run_experiment(
                N=args.N,
                layer_id=layer_id,
                n_min=args.n_min,
                n_max=args.n_max,
                nb_tot=args.nb_tot,
                k=args.k,
                replace=args.replace,
                plot=not args.no_plot,
                plot_layer=args.plot_layer,
                A_sweep=args.A_sweep,
                A_cap=args.A_cap,
                n_A_points=args.n_A_points,
                sanity_plot=args.sanity_plot,
                output_dir=layer_output_dir,
                add_timestamp=add_timestamp,
                # Model size
                model_size=model_size,
                device_map=args.device_map,
                # Text source parameters
                text_source=args.text_source,
                text_language=args.text_language,
                text_mix_config=text_mix_config,
                text_mix_strategy=args.text_mix_strategy,
                text_shuffle=args.text_shuffle,
                text_experiment_config=args.text_config,
                text_language_mix_config=text_language_mix_config,
                aggressive_cleanup=args.aggressive_cleanup,
            )
        except Exception as e:
            # Top-level boundary: report the failure and continue with the
            # remaining layers. The failure is reflected in the exit status below.
            print(f"\nERROR in layer {layer_id}: {e}")
            traceback.print_exc()
            failed_layers.append(layer_id)
            layer_failed = True
        else:
            all_results.append(results)
            print(f"\nLayer {layer_id} completed. Results saved to: {results.get('output_dir', layer_output_dir)}")

        if layer_failed:
            # An aborted layer skips run_experiment's own end-of-run cleanup.
            # This runs after the except clause, once the exception object has
            # been released, so the next layer starts with as much memory
            # reclaimed as possible.
            gc.collect()
            if cuda_available():
                torch.cuda.empty_cache()

    # Summary
    print(f"\n{'='*70}")
    print(f"COMPLETED {len(all_results)}/{len(layers_to_run)} layers")
    if failed_layers:
        print(f"FAILED layers: {failed_layers}")
    if multi_layer:
        print(f"Results directory: {base_output_dir}")
    print(f"{'='*70}")

    # A non-zero exit status lets shell scripts and batch schedulers detect
    # partial failures.
    if failed_layers:
        raise SystemExit(1)
