
from __future__ import annotations

import dataclasses
import math
import random
from typing import Dict, Iterable, List, Sequence, Tuple

import numpy as np
import torch
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.linear_model import LinearRegression


@dataclasses.dataclass
class DatasetConfig:
    """Metadata describing a classification dataset.

    Attributes
    ----------
    hf_name:
        Name of the dataset on the HuggingFace hub.  Many tasks are stored
        under a single umbrella (e.g. ``glue``), so a subset may also be
        required.
    subset:
        Optional subset name for composite datasets.  For example,
        ``glue/sst2`` uses ``hf_name='glue'`` and ``subset='sst2'``.
    text_fields:
        One or more field names that contain the raw text.  A single field
        may be given as a plain string (``"sentence"``); ``__post_init__``
        wraps it in a one-element list so it is never treated as a sequence
        of one-character field names.  Any other sequence is copied into a
        list.  When several fields are given (e.g. question and answer),
        their contents are joined with spaces.
    label_field:
        Field name containing the integer label.
    label_names:
        Optional list of human-readable label names.  If ``None``, the
        loader derives the names from the dataset's feature metadata.
    split_train:
        Name of the split used for sampling demonstrations and computing
        covariance matrices (e.g. ``'train'``).
    split_test:
        Name of the split used for evaluating ICL performance.  For some
        benchmarks (e.g. SST-2) the official ``test`` split lacks labels;
        ``validation`` should be used instead.
    """

    hf_name: str
    subset: str | None
    text_fields: Sequence[str]
    label_field: str
    label_names: List[str] | None = None
    split_train: str = "train"
    split_test: str = "test"

    def __post_init__(self) -> None:
        # A bare ``str`` is itself a ``Sequence[str]``.  Without this guard it
        # would be iterated character by character.  Normalising to a list
        # also lets callers write ``text_fields + [label_field]`` even when a
        # tuple was supplied.
        if isinstance(self.text_fields, str):
            self.text_fields = [self.text_fields]
        else:
            self.text_fields = list(self.text_fields)


def load_and_prepare_dataset(cfg: DatasetConfig) -> Tuple[Dataset, Dataset, List[str]]:
    """Load a classification dataset and reduce it to ``text``/``label`` columns.

    Parameters
    ----------
    cfg:
        A ``DatasetConfig`` containing the dataset name, split names and
        field definitions.

    Returns
    -------
    train_ds:
        HuggingFace ``Dataset`` for ``cfg.split_train`` with exactly two
        columns: ``'text'`` (the configured text fields joined by spaces)
        and ``'label'`` (copied from ``cfg.label_field``).
    test_ds:
        Same structure for ``cfg.split_test``.
    label_names:
        List of human-readable class names indexed by integer label.

    Raises
    ------
    ValueError
        If label names have to be inferred but neither split contains any
        label values.

    Notes
    -----
    Each split is requested through ``load_dataset(..., split=...)``, so the
    split names may also be HuggingFace slicing expressions such as
    ``'train[:90%]'``.  Label names are resolved in this order:
    ``cfg.label_names`` if given; the ``names`` attribute of the label
    feature (e.g. a ``ClassLabel``) of the training split, then of the test
    split; finally generic ``class_0 ... class_<max label>`` names.
    """
    ds_kwargs = {} if cfg.subset is None else {"name": cfg.subset}
    # Loading with ``split=`` rather than indexing the full DatasetDict is what
    # makes slicing expressions work.  ``load_dataset`` reuses its local cache,
    # so the second call does not download again.
    train_raw = load_dataset(cfg.hf_name, split=cfg.split_train, **ds_kwargs)
    test_raw = load_dataset(cfg.hf_name, split=cfg.split_test, **ds_kwargs)

    # A bare string is also a Sequence[str]; wrap it so it is not iterated
    # character by character.  A list copy also works for tuple inputs.
    if isinstance(cfg.text_fields, str):
        text_fields = [cfg.text_fields]
    else:
        text_fields = list(cfg.text_fields)

    if cfg.label_names is not None:
        label_names: List[str] = list(cfg.label_names)
    else:
        label_names = _infer_label_names(train_raw, test_raw, cfg.label_field)

    def to_text_label(example):
        # Text fields absent from an example are skipped rather than raising.
        text = " ".join(str(example[f]) for f in text_fields if f in example)
        return {"text": text, "label": example[cfg.label_field]}

    # Drop every original column.  ``Dataset.map`` removes ``remove_columns``
    # before merging the function's output, so ``text``/``label`` are kept
    # even when the source already had columns with those names.
    train_ds = train_raw.map(to_text_label, remove_columns=train_raw.column_names)
    test_ds = test_raw.map(to_text_label, remove_columns=test_raw.column_names)

    return train_ds, test_ds, label_names


def _infer_label_names(train_raw: Dataset, test_raw: Dataset, label_field: str) -> List[str]:
    """Derive class names from split feature metadata, else generate ``class_i`` names."""
    for split in (train_raw, test_raw):
        # ``ClassLabel`` features expose ``names``; other feature types (or a
        # missing feature) yield ``None`` here instead of raising.
        names = getattr(split.features.get(label_field), "names", None)
        if names is not None:
            return list(names)
    labels = list(train_raw[label_field]) + list(test_raw[label_field])
    if not labels:
        raise ValueError(f"cannot infer label names: column {label_field!r} has no values")
    return [f"class_{i}" for i in range(int(max(labels)) + 1)]


def compute_covariance_statistics(
    encoder_name: str,
    texts: Iterable[str],
    max_samples: int | None = None,
    batch_size: int = 32,
    device: str | torch.device | None = None,
    seed: int = 0,
    q: float = 0.1,
) -> Dict[str, float]:
    """Compute key spectral statistics of the sentence-embedding covariance matrix.

    Given a collection of text strings, this function embeds them with a
    ``sentence_transformers`` encoder, estimates the sample covariance matrix
    and derives the quantities used in the theoretical bound:
    - the smallest eigenvalue ``λ_min``;
    - a quantile eigenvalue ``λ_q``;
    - the spectral norm ``||Σ||_2``;
    - the Frobenius norm ``||Σ||_F``;
    - the trace;
    - the effective rank ``r_eff = (tr Σ)^2 / ||Σ||_F^2``.

    Parameters
    ----------
    encoder_name:
        Name of the sentence-transformer model to load, e.g.
        ``'all-mpnet-base-v2'``, ``'all-MiniLM-L6-v2'`` or
        ``'paraphrase-MiniLM-L12-v2'``.
    texts:
        Iterable over input strings whose embeddings are used to estimate
        the covariance.  Typically these are training examples.
    max_samples:
        If not ``None`` and there are more texts than this, a random subset
        of this size (without replacement) is used.
    batch_size:
        Number of texts per encoder forward pass.
    device:
        ``'cpu'`` or ``'cuda'``.  If ``None``, ``'cuda'`` is used when
        available, otherwise ``'cpu'``.
    seed:
        Random seed for reproducible sub-sampling.
    q:
        Quantile level in ``[0, 1]`` used to pick ``λ_q`` from the
        ascending eigenvalues: index ``max(0, floor(q * d) - 1)``.

    Returns
    -------
    stats:
        Dictionary with keys ``'lambda_min'``, ``'lambda_q'``, ``'norm_2'``,
        ``'norm_fro'``, ``'trace'``, ``'r_eff'`` and ``'dimension'`` (the
        embedding dimension ``d``).

    Raises
    ------
    ValueError
        If ``q`` is outside ``[0, 1]`` or no texts remain after sub-sampling.

    Notes
    -----
    The sample covariance is ``Σ_hat = (1/n) Xᵀ X``, where ``X`` is the
    mean-centred ``(n, d)`` embedding matrix.  Centring matters because
    otherwise the mean embedding would dominate the second-moment matrix.
    Embeddings are cast to float64 before centring.  Eigenvalues come from
    ``np.linalg.eigvalsh`` (ascending), with tiny negative round-off
    clipped to zero.
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be in [0, 1]")

    rng = np.random.default_rng(seed)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Optionally sub-sample to bound the encoding cost.
    text_list = list(texts)
    if max_samples is not None and len(text_list) > max_samples:
        indices = rng.choice(len(text_list), size=max_samples, replace=False)
        text_list = [text_list[i] for i in indices]
    if not text_list:
        raise ValueError("at least one text is required to estimate a covariance matrix")

    # The encoder batches internally; it returns one embedding row per text,
    # in input order.  Normalisation (if any) is left to the encoder's own
    # configuration: covariance is NOT scale invariant, so this choice does
    # affect the statistics.
    encoder = SentenceTransformer(encoder_name, device=device)
    embeddings = encoder.encode(
        text_list,
        batch_size=batch_size,
        convert_to_numpy=True,
        show_progress_bar=False,
    )

    # Accumulate in float64 for numerically stable eigenvalues; memory is
    # O(d^2) for the covariance itself.
    X = np.asarray(embeddings, dtype=np.float64)
    X_centered = X - X.mean(axis=0, keepdims=True)
    n_samples, d_dim = X_centered.shape
    cov_mat = (X_centered.T @ X_centered) / float(n_samples)

    # eigvalsh returns ascending eigenvalues of the symmetric PSD matrix.
    # Clipping is monotone, so the array stays sorted.
    eigvals = np.clip(np.linalg.eigvalsh(cov_mat), a_min=0.0, a_max=None)

    lambda_min = float(eigvals[0])
    q_index = max(0, int(np.floor(q * d_dim)) - 1)
    lambda_q = float(eigvals[q_index])
    norm_2 = float(eigvals[-1])
    norm_fro = float(np.linalg.norm(cov_mat, "fro"))

    # r_eff = (trace)^2 / ||Σ||_F^2 measures how many directions carry
    # significant variance.
    trace_val = float(np.sum(eigvals))
    r_eff = (trace_val ** 2) / (norm_fro ** 2) if norm_fro > 0 else 0.0

    return {
        "lambda_min": lambda_min,
        "lambda_q": lambda_q,
        "norm_2": norm_2,
        "norm_fro": norm_fro,
        "trace": trace_val,
        "r_eff": r_eff,
        "dimension": d_dim,
    }


def predict_required_k(
    stats: Dict[str, float],
    delta: float = 0.05,
    xi: float = 0.1,
    constant: float = 1.0,
    use_calibration: bool = False,
    alpha: float = 1.0,
    q: float = 0.1,
) -> Tuple[float, float]:
    """Compute theoretical and calibrated predictions for the number of demonstrations.

    Both estimates evaluate the simplified sample-size bound

        K = C2 * ||Σ||^2 * r_eff * log(2/ξ) / (λ - δ)^2

    The theoretical estimate ``K_star`` uses ``λ = λ_min`` and the spectral
    norm ``||Σ||_2``.  The calibrated estimate ``K_cal`` uses ``λ = λ_q`` and
    the Frobenius norm ``||Σ||_F``, and is then scaled by ``alpha``.

    Parameters
    ----------
    stats:
        Spectral statistics with keys ``'lambda_min'``, ``'norm_2'`` and
        ``'r_eff'``; ``'lambda_q'`` and ``'norm_fro'`` are also required
        when ``use_calibration`` is ``True``.
    delta:
        Stability threshold δ.  A finite estimate requires ``λ > δ``.
    xi:
        Failure probability ξ in the open interval (0, 1); enters through
        ``log(2/ξ)``.  Smaller values produce larger predictions.
    constant:
        Constant C₂ from the matrix concentration bound; may be raised for
        extra safety.
    use_calibration:
        When ``True``, ``K_cal`` is computed from ``λ_q`` and ``||Σ||_F``;
        otherwise ``K_cal`` equals ``K_star``.
    alpha:
        Global scaling factor α applied to the calibrated estimate.
    q:
        Not used in the computation: ``λ_q`` is always read from
        ``stats['lambda_q']``.  To use a different quantile, recompute that
        statistic at the desired level.  Kept for backward compatibility.

    Returns
    -------
    K_star:
        Prediction from the theoretical bound.
    K_cal:
        Calibrated prediction, or ``K_star`` when ``use_calibration`` is
        ``False``.

    Raises
    ------
    ValueError
        If ``xi`` is not in (0, 1).

    Notes
    -----
    When ``λ - δ <= 0`` the corresponding estimate is ``inf`` (no finite
    number of demonstrations guarantees stability).  An infinite calibrated
    estimate stays ``inf`` rather than being multiplied by ``alpha``, which
    would produce ``nan`` for ``alpha == 0``.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not 0.0 < xi < 1.0:
        raise ValueError("xi must be in (0,1)")
    log_factor = math.log(2.0 / xi)
    r_eff = stats["r_eff"]

    def _compute_bound(eigval: float, norm_val: float) -> float:
        gap = eigval - delta
        if gap <= 0:
            return float("inf")
        return constant * (norm_val ** 2) * r_eff * log_factor / (gap ** 2)

    k_star = _compute_bound(stats["lambda_min"], stats["norm_2"])
    if not use_calibration:
        return k_star, k_star

    k_cal = _compute_bound(stats["lambda_q"], stats["norm_fro"])
    if math.isfinite(k_cal):
        k_cal = alpha * k_cal
    return k_star, k_cal


def generate_prompt(
    demonstrations: Sequence[Tuple[str, int]],
    test_text: str,
    label_names: Sequence[str],
) -> str:
    """Build an in-context classification prompt.

    Each demonstration ``(text, label_id)`` is rendered as the text, a
    newline, ``Label: <label_names[label_id]>`` and a trailing newline.
    Blocks are joined with a further newline, so consecutive examples are
    separated by a blank line.  The query comes last and ends with a bare
    ``Label:`` for the model to complete.

    Parameters
    ----------
    demonstrations:
        ``(text, label_id)`` pairs shown before the query, in order.
    test_text:
        The query text to be labelled.
    label_names:
        Label strings, indexed by integer label id.

    Returns
    -------
    str
        The assembled prompt.
    """
    blocks = [
        f"{text}\nLabel: {label_names[label_id]}\n"
        for text, label_id in demonstrations
    ]
    # The query has no label: the model is expected to supply it.
    blocks.append(f"{test_text}\nLabel:")
    return "\n".join(blocks)


def score_candidate_label(
    prompt: str,
    candidate: str,
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    device: str | torch.device,
) -> float:
    """Return the log-likelihood of ``' ' + candidate`` as a continuation of ``prompt``.

    A single forward pass of a causal language model scores the
    continuation tokens.  The token at absolute position ``i`` is scored by
    the logits at position ``i - 1``, i.e. conditioned on everything before
    it.

    Parameters
    ----------
    prompt:
        The in-context prompt without the candidate label.  It should end
        with ``'Label:'``.
    candidate:
        Candidate label string (e.g. ``'positive'``).
    tokenizer:
        Tokenizer associated with the language model.  It is called as
        ``tokenizer(text, add_special_tokens=False)["input_ids"]``.
    model:
        Causal language model (e.g. GPT-2) in evaluation mode.  Its output
        must expose ``.logits``.
    device:
        Device on which to perform the forward pass.

    Returns
    -------
    float
        Sum of the log-probabilities of the continuation tokens, or ``0.0``
        if the continuation encodes to no tokens.

    Raises
    ------
    ValueError
        If ``prompt`` encodes to zero tokens; the first candidate token
        would then have no context to condition on.

    Notes
    -----
    The continuation tokens are taken from the tokenisation of the whole
    string ``prompt + ' ' + candidate`` with the prompt's own tokens
    removed.  They are therefore exactly the tokens that appear in context,
    e.g. GPT-2's space-prefixed ``' positive'`` rather than ``'positive'``.
    Tokenising the bare candidate separately would score different token
    IDs and can misalign multi-token labels.  If tokenising the
    concatenation alters the prompt's own tokens (a merge across the
    boundary), the prompt tokens are kept and ``' ' + candidate`` is
    encoded on its own.
    """
    prompt_ids = list(tokenizer(prompt, add_special_tokens=False)["input_ids"])
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    n_prompt = len(prompt_ids)

    full_ids = list(
        tokenizer(prompt + " " + candidate, add_special_tokens=False)["input_ids"]
    )
    if full_ids[:n_prompt] == prompt_ids:
        cont_ids = full_ids[n_prompt:]
    else:
        cont_ids = list(tokenizer(" " + candidate, add_special_tokens=False)["input_ids"])
    if not cont_ids:
        return 0.0

    input_ids = torch.tensor([prompt_ids + cont_ids], dtype=torch.long, device=device)
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits  # (1, seq_len, vocab)

    # Rows n_prompt-1 .. seq_len-2 predict the continuation tokens.  Only
    # those rows are normalised; ``.float()`` avoids half-precision error.
    cont_logits = logits[0, n_prompt - 1 : -1].float()
    log_probs = torch.log_softmax(cont_logits, dim=-1)
    targets = input_ids[0, n_prompt:]
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return float(token_log_probs.sum().item())


def evaluate_icl_accuracy(
    train_ds: Dataset,
    test_ds: Dataset,
    label_names: Sequence[str],
    model_name: str,
    k_values: Sequence[int],
    num_test_samples: int = 50,
    num_seeds: int = 1,
    device: str | torch.device | None = None,
    random_seed: int = 0,
) -> Dict[int, float]:
    """Evaluate in-context learning accuracy across different prompt lengths.

    The loop runs over every seed, every one of the first
    ``num_test_samples`` test examples, and every K in ``k_values``.  Each
    iteration does the following:
    - draw K distinct demonstrations from the training pool;
    - build a prompt with ``generate_prompt``;
    - score each label name with ``score_candidate_label``;
    - take the highest-scoring label as the prediction.

    Parameters
    ----------
    train_ds:
        Training examples providing the demonstration pool; each item must
        have ``'text'`` and ``'label'`` entries.
    test_ds:
        Test examples (same fields); only the first ``num_test_samples``
        items are read.
    label_names:
        Label strings indexed by integer label id.  They populate the
        prompts and serve as candidate completions.
    model_name:
        HuggingFace causal language model to load (e.g. ``'gpt2'``).
    k_values:
        Numbers of demonstrations to try (e.g. ``range(1, 33)``).  Any
        iterable is accepted; it is materialised once.  Every K must not
        exceed the training pool size.
    num_test_samples:
        Number of leading test examples to evaluate.
    num_seeds:
        Number of demonstration-sampling seeds (``random_seed``,
        ``random_seed + 1``, ...); results are averaged over them.
    device:
        Device for the language model; defaults to CUDA when available.
    random_seed:
        Base seed for demonstration sampling.

    Returns
    -------
    acc_dict:
        Mapping from each K to the fraction of (seed, test example) pairs
        whose predicted label equals the ground truth.

    Raises
    ------
    ValueError
        If there is nothing to evaluate (no test examples or
        ``num_seeds < 1``).  This is checked before the model is loaded.

    Notes
    -----
    For each (seed, test example) pair the RNG state is saved and rewound
    before every K, so all K start sampling from the same state.  The
    demonstration sets for different K are therefore strongly correlated
    rather than independent.  With CPython's ``random.sample``, the
    smaller set is usually a prefix of the larger one.  Differences between
    K then mostly reflect the added demonstrations.

    The model runs ``num_seeds * num_test_samples * len(k_values) *
    len(label_names)`` forward passes; plan compute accordingly.
    """
    from itertools import islice

    # Materialise once: a one-shot iterable (e.g. a generator) would otherwise
    # be exhausted by the first pass and the evaluation loop would be empty.
    k_list = list(k_values)

    # Convert only the test rows that are actually evaluated.
    test_examples = [
        (ex["text"], int(ex["label"]))
        for ex in islice(test_ds, max(num_test_samples, 0))
    ]
    total_runs = num_seeds * len(test_examples)
    if total_runs <= 0:
        raise ValueError(
            "nothing to evaluate: need at least one test example and num_seeds >= 1"
        )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Some tokenizers (e.g. GPT-2's) define no pad token; reuse EOS.  Prompts
    # are scored one at a time, so this presumably only silences warnings.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    model.eval()

    train_examples = [(ex["text"], int(ex["label"])) for ex in train_ds]
    train_indices = range(len(train_examples))

    acc_sums = {k: 0.0 for k in k_list}
    for seed_offset in range(num_seeds):
        rng = random.Random(random_seed + seed_offset)
        for test_text, true_label in test_examples:
            base_state = rng.getstate()
            for k in k_list:
                # Rewind so every K samples from the same RNG state (see Notes).
                rng.setstate(base_state)
                demonstrations = [train_examples[i] for i in rng.sample(train_indices, k)]
                prompt = generate_prompt(demonstrations, test_text, label_names)
                scores = [
                    score_candidate_label(prompt, cand, tokenizer, model, device)
                    for cand in label_names
                ]
                if int(np.argmax(scores)) == true_label:
                    acc_sums[k] += 1.0

    return {k: acc_sums[k] / total_runs for k in k_list}


def find_knee_point(
    k_values: Sequence[int],
    accuracies: Sequence[float],
    bootstrap_samples: int = 0,
    random_seed: int = 0,
) -> Tuple[int, int, Tuple[int, int]]:
    """Locate the knee of an accuracy curve with a two-segment linear fit.

    The search tries every split index ``t`` with ``1 <= t < N``.  For each,
    an ordinary least-squares line (with intercept) is fitted separately to
    ``(k_values[:t], accuracies[:t])`` and ``(k_values[t:],
    accuracies[t:])``.  The split with the smallest total sum of squared
    errors wins; ties go to the smallest ``t``.  The knee is
    ``k_values[t]``, the first K of the second segment.  Single-point
    segments are allowed and fit exactly.

    Parameters
    ----------
    k_values:
        Monotonically increasing K values (e.g. 1, 2, ..., 32).
    accuracies:
        Accuracy measurements corresponding to ``k_values``.
    bootstrap_samples:
        Number of residual-bootstrap replicates; ``0`` disables
        bootstrapping.  Each replicate resamples, with replacement, the
        residuals of the best two-segment fit.  It adds them back onto that
        fit's fitted values and re-runs the split search.  Because residuals
        are resampled instead of the raw accuracies, every value stays
        attached to its K.  Each replicate is therefore a plausible noisy
        version of the same curve.
    random_seed:
        Seed for bootstrap resampling.

    Returns
    -------
    knee_k:
        The K value at the knee.
    knee_index:
        Index of ``knee_k`` in ``k_values`` (the best split ``t``).
    ci:
        ``(lower_idx, upper_idx)``.  The lower bound is the 2.5th percentile
        of the bootstrap knee indices rounded down; the upper bound is the
        97.5th percentile rounded up.  Equals ``(knee_index, knee_index)``
        when ``bootstrap_samples == 0``.

    Raises
    ------
    ValueError
        If the inputs differ in length or contain fewer than three points.
    """
    x = np.asarray(k_values, dtype=np.float64)
    y = np.asarray(accuracies, dtype=np.float64)
    n = len(x)
    if len(y) != n:
        raise ValueError("k_values and accuracies must have the same length.")
    if n <= 2:
        raise ValueError("At least three points are required to compute a knee.")

    knee_index, fitted = _best_two_segment_split(x, y)
    knee_k = int(x[knee_index])

    if bootstrap_samples > 0:
        rng = np.random.default_rng(random_seed)
        residuals = y - fitted
        boot_indices: List[int] = []
        for _ in range(bootstrap_samples):
            y_boot = fitted + rng.choice(residuals, size=n, replace=True)
            boot_indices.append(_best_two_segment_split(x, y_boot)[0])
        # Round outward so the reported index interval never narrows the
        # percentile interval.
        lower = int(np.floor(np.percentile(boot_indices, 2.5)))
        upper = int(np.ceil(np.percentile(boot_indices, 97.5)))
        ci = (lower, upper)
    else:
        ci = (knee_index, knee_index)

    return knee_k, knee_index, ci


def _fit_line(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Return OLS-with-intercept fitted values of ``y`` on ``x`` (flat line if ``x`` has no spread)."""
    x_c = x - x.mean()
    ssx = float(x_c @ x_c)
    slope = float(x_c @ (y - y.mean())) / ssx if ssx > 0.0 else 0.0
    return y.mean() + slope * x_c


def _best_two_segment_split(x: np.ndarray, y: np.ndarray) -> Tuple[int, np.ndarray]:
    """Return the split ``t`` minimising two-segment SSE, and the fitted values at that split."""
    fits = [
        np.concatenate((_fit_line(x[:t], y[:t]), _fit_line(x[t:], y[t:])))
        for t in range(1, len(x))
    ]
    sse = [float(np.sum((y - fit) ** 2)) for fit in fits]
    best = int(np.argmin(sse))  # first minimum on ties
    return best + 1, fits[best]


def main():
    """Example usage of the experimental pipeline.

    This function demonstrates how to reproduce a subset of the experiments
    described in the ICLR paper.  It defines several datasets, encoders and
    language models, computes the spectral statistics for each dataset‑encoder
    pair, predicts the required number of demonstrations from the theory,
    evaluates empirical accuracy curves for a single language model and
    finally compares the predicted ``K*`` and the observed knee.

    Users are encouraged to adapt the lists ``datasets_cfgs``,
    ``encoder_names`` and ``lm_names`` to match their computational
    resources.  The default settings below use small models and short
    evaluation loops suitable for demonstration purposes.  Add entries to
    ``selected_datasets``, ``encoder_names`` and ``lm_names`` as needed to
    fully replicate the paper.
    """
    # Define datasets.  Only a subset is enabled by default for brevity.
    datasets_cfgs: Dict[str, DatasetConfig] = {
        # SST‑2 sentiment classification.  Validation split is used as test
        # because the official test set lacks labels.
        "sst2": DatasetConfig(
            hf_name="glue",
            subset="sst2",
            text_fields=["sentence"],
            label_field="label",
            label_names=["negative", "positive"],
            split_train="train",
            split_test="validation",
        ),
        # AG News (four classes).  Use official train/test splits.
        "ag_news": DatasetConfig(
            hf_name="ag_news",
            subset=None,
            text_fields=["text"],
            label_field="label",
            label_names=["World", "Sports", "Business", "Sci/Tech"],
            split_train="train",
            split_test="test",
        ),
        # TREC‑6 coarse question classification (six labels).  The HF ``trec``
        # dataset has no ``coarse`` config: the coarse label is the
        # ``coarse_label`` column of the default config.  ``label_names`` must
        # follow that column's ClassLabel order (ABBR, ENTY, DESC, HUM, LOC,
        # NUM); otherwise every label is verbalized with the wrong name.
        "trec": DatasetConfig(
            hf_name="trec",
            subset=None,
            text_fields=["text"],
            label_field="coarse_label",
            label_names=["Abbreviation", "Entity", "Description", "Human", "Location", "Number"],
            split_train="train",
            split_test="test",
        ),
        # Yahoo Answers Topics (ten classes).  Here the label names are
        # derived automatically because they are available via the dataset
        # metadata.  We use a truncated text field (question title only)
        # consistent with common practice.
        "yahoo": DatasetConfig(
            hf_name="yahoo_answers_topics",
            subset=None,
            text_fields=["question_title"],
            label_field="topic",
            label_names=None,
            split_train="train",
            split_test="test",
        ),
    }

    # Choose which datasets to run.  Extend the list to include more.
    selected_datasets = ["sst2", "ag_news"]  # e.g. ["sst2", "ag_news", "trec", "yahoo"]

    # Define sentence encoders.  These correspond to the encoders used in
    # the paper.  You can add ``all-MiniLM-L6-v2`` or ``paraphrase-MiniLM-L12-v2``
    # for additional experiments.
    encoder_names = ["all-mpnet-base-v2"]

    # Define language models.  GPT‑2 small and medium are local and do not
    # require API access.  GPT‑3 cannot be loaded via HuggingFace; it is not
    # included here.  You may replace ``gpt2`` with ``gpt2-medium`` to
    # simulate a larger model.
    lm_names = ["gpt2"]

    # Range of demonstration counts K over which to measure accuracy.
    k_values = list(range(1, 9))  # Use 1..8 for quick demonstration

    # Number of test samples and seeds for accuracy evaluation.  Increase for
    # more stable estimates at the cost of runtime.
    num_test_samples = 20
    num_seeds = 1

    # Delta and xi parameters for the theoretical bound.
    delta = 0.05
    xi = 0.1
    constant = 1.0  # C2 in the theoretical bound (passed to predict_required_k)

    # Loop over datasets.
    results: List[Dict[str, object]] = []
    for ds_name in selected_datasets:
        cfg = datasets_cfgs[ds_name]
        print(f"\nLoading dataset {ds_name}…")
        train_ds, test_ds, label_names = load_and_prepare_dataset(cfg)
        print(f"Loaded {len(train_ds)} training and {len(test_ds)} test examples.")

        for enc_name in encoder_names:
            print(f"\nComputing covariance statistics for encoder {enc_name}…")
            # NOTE(review): assumes load_and_prepare_dataset exposes the joined
            # text under a "text" key on every example — confirm.
            stats = compute_covariance_statistics(
                encoder_name=enc_name,
                texts=[ex["text"] for ex in train_ds],
                max_samples=1000,  # limit for demonstration purposes
                batch_size=32,
            )
            print(f"λ_min: {stats['lambda_min']:.6f}, λ_q: {stats['lambda_q']:.6f}, "
                  f"||Σ||_2: {stats['norm_2']:.6f}, ||Σ||_F: {stats['norm_fro']:.6f}, "
                  f"r_eff: {stats['r_eff']:.2f}")

            # Predict required demonstration counts.
            k_star, k_cal = predict_required_k(
                stats,
                delta=delta,
                xi=xi,
                constant=constant,
                use_calibration=True,
                alpha=1.0,
                q=0.1,
            )
            print(f"Theoretical K*: {k_star:.2f}, Calibrated K_cal: {k_cal:.2f}")

            # Evaluate empirical accuracy curves for each language model.
            for lm_name in lm_names:
                print(f"\nEvaluating ICL accuracy for model {lm_name}…")
                acc_dict = evaluate_icl_accuracy(
                    train_ds=train_ds,
                    test_ds=test_ds,
                    label_names=label_names,
                    model_name=lm_name,
                    k_values=k_values,
                    num_test_samples=num_test_samples,
                    num_seeds=num_seeds,
                    random_seed=42,
                )
                # Order accuracies to match ``k_values`` (already ascending).
                acc_vals = [acc_dict[k] for k in k_values]
                # Find the knee‑point of the accuracy curve.  With
                # bootstrap_samples=0 the interval is degenerate, so it is
                # intentionally unused here.
                knee_k, knee_idx, _ci = find_knee_point(k_values, acc_vals, bootstrap_samples=0)
                print(f"Accuracy curve: {[(k, round(acc, 3)) for k, acc in zip(k_values, acc_vals)]}")
                print(f"Estimated knee‑point K_knee: {knee_k} at index {knee_idx}")

                # Record the results for later analysis.
                results.append({
                    "dataset": ds_name,
                    "encoder": enc_name,
                    "model": lm_name,
                    "k_star": k_star,
                    "k_cal": k_cal,
                    "k_knee": knee_k,
                    "accuracy_curve": acc_dict,
                })

    # Summarise results.
    print("\nSummary of experiments:")
    for res in results:
        ds_name = res["dataset"]
        enc = res["encoder"]
        lm = res["model"]
        k_star = res["k_star"]
        k_cal = res["k_cal"]
        k_knee = res["k_knee"]
        # Guard against a zero knee; non-finite predictions map to an infinite ratio.
        ratio_star = k_star / max(k_knee, 1e-6) if math.isfinite(k_star) else float("inf")
        ratio_cal = k_cal / max(k_knee, 1e-6) if math.isfinite(k_cal) else float("inf")
        print(
            f"Dataset: {ds_name}, Encoder: {enc}, Model: {lm}\n"
            f"  K*: {k_star:.2f}, K_cal: {k_cal:.2f}, Knee: {k_knee}\n"
            f"  Ratio K*/knee: {ratio_star:.2f}, Ratio K_cal/knee: {ratio_cal:.2f}\n"
        )


if __name__ == "__main__":
    # Run the example pipeline only when executed as a script, not on import.
    main()