"""Flash head implementation for faster efficient language model head."""

import json
import math
import os
from time import time
from typing import List, Literal, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.functional import kl_div

if torch.cuda.is_available():
    from torchpq.clustering import KMeans
else:
    print("CUDA unavailable, skipping torchpq import.")


from transformers import AutoModelForCausalLM, AutoTokenizer

from efficient_heads.pipeline import GenerationPipeline

# File (inside the cache directory) holding the saved centroids and
# per-token cluster assignments written with ``torch.save``.
CACHE_FILENAME = "clustering_cache.pt"
# JSON side-car (inside the cache directory) with the metadata used to decide
# whether a cached clustering can be reused.
CACHE_CONFIG_FILENAME = "clustering_config.json"

from typing import Tuple


def _kmeans_l2_torchpq(
    weight_matrix: torch.Tensor,
    n_clusters: int,
    max_iters: int = 1000,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Cluster the rows of ``weight_matrix`` with torchpq's ``KMeans``.

    Every row is first scaled to unit L2 norm; the clustering itself is run
    with the cosine distance, random initialisation and three restarts.

    :param weight_matrix:
        A 2-D tensor whose rows are the vectors to cluster.
    :param n_clusters:
        The number of clusters to fit.
    :param max_iters:
        Maximum number of k-means iterations, defaults to 1000.
    :returns:
        A tuple ``(centroids, labels)``. ``centroids`` is ``KMeans.centroids``
        cast to the dtype of the normalized input with two leading singleton
        dimensions added; ``labels`` is whatever ``KMeans.fit`` returned.
    """
    row_norms = torch.norm(weight_matrix, p=2, dim=1, keepdim=True)
    unit_rows = weight_matrix / row_norms

    # NOTE(review): `size` is presumably the per-cluster capacity used to get
    # (near) equally sized clusters -- confirm against the torchpq build in use.
    per_cluster = math.ceil(weight_matrix.shape[0] / n_clusters)
    clusterer = KMeans(
        n_clusters=n_clusters,
        n_redo=3,
        max_iter=max_iters,
        distance="cosine",
        init_mode="random",
        size=per_cluster,
    )

    with torch.no_grad():
        # The data is handed over transposed, as contiguous float32.
        labels = clusterer.fit(unit_rows.t().float().contiguous())

    centroids = clusterer.centroids.to(dtype=unit_rows.dtype)
    return centroids[None, None], labels


def _save_cache(
    centroids: torch.Tensor,
    cluster_assignments: torch.Tensor,
    cache_file: str,
    original_shape: Tuple[int, ...],
    metadata_file: str,
):
    """Persist a clustering and its metadata to disk.

    :param centroids:
        The cluster centroids; the number of clusters is read from the last
        dimension (the centroids produced in this file carry leading
        singleton dimensions, so the first dimension is not the cluster
        count).
    :param cluster_assignments:
        The cluster index of every vocabulary entry.
    :param cache_file:
        Destination path for the tensors (written with ``torch.save``).
    :param original_shape:
        Shape of the language model head weight, ``(vocab_size, hidden_size)``.
    :param metadata_file:
        Destination path for the JSON metadata.
    """
    # Tensors are moved to the CPU so the cache can be loaded on any device.
    torch.save(
        {
            "centroids": centroids.cpu(),
            "cluster_assignments": cluster_assignments.cpu(),
        },
        cache_file,
    )

    metadata = {
        "n_clusters": centroids.shape[-1],
        "vocab_size": original_shape[0],
        "hidden_size": original_shape[1],
        "creation_time": time(),
    }
    with open(metadata_file, "w") as f:
        json.dump(metadata, f)


def _get_centroids(
    lm_head: nn.Linear,
    n_clusters: int,
    special_token_ids: List[int],
    cache_dir: str,
    enforce_equal_cluster_sizes: bool,
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    """Load the clustering of ``lm_head`` from cache, or compute and cache it.

    A cached clustering is reused only when its recorded vocabulary size and
    hidden size match ``lm_head`` and the cached centroids hold exactly
    ``n_clusters`` clusters; otherwise the clustering is recomputed.

    :param lm_head:
        The language model head whose weight rows are clustered.
    :param n_clusters:
        The number of clusters for the regular (non-special) tokens.
    :param special_token_ids:
        Token ids that are excluded from the clustering. The token at
        iteration position ``k`` of this collection is assigned its own
        cluster with index ``n_clusters + k``.
    :param cache_dir:
        Directory for the cache files; created if missing.
    :param enforce_equal_cluster_sizes:
        Currently unused by this function.
    :returns:
        A tuple ``(centroids, cluster_assignments)`` on the device of
        ``lm_head.weight``.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, CACHE_FILENAME)
    metadata_file = os.path.join(cache_dir, CACHE_CONFIG_FILENAME)
    original_shape = lm_head.weight.shape
    device = lm_head.weight.device

    # Loading is best effort: any failure falls through to recomputation.
    try:
        if os.path.exists(cache_file) and os.path.exists(metadata_file):
            with open(metadata_file) as f:
                metadata = json.load(f)

            if (
                metadata["vocab_size"] == original_shape[0]
                and metadata["hidden_size"] == original_shape[1]
            ):
                cache_data = torch.load(cache_file)
                centroids = cache_data["centroids"]
                cluster_assignments = cache_data["cluster_assignments"]
                # Validate against the tensor itself rather than the metadata
                # so that caches written with a wrong "n_clusters" entry are
                # still usable. A mismatching cluster count would otherwise
                # surface later as a shape error.
                if centroids.shape[-1] == n_clusters:
                    print("Loading clustering from cache")
                    return (
                        centroids.to(device),
                        cluster_assignments.to(device),
                    )
                print(
                    f"Cached clustering has {centroids.shape[-1]} clusters, "
                    f"expected {n_clusters}; ignoring cache"
                )
    except Exception as e:
        print(f"Error loading cache: {e}")

    print("Computing clustering from scratch")
    # Membership is tested once per vocabulary entry, so use a set; the
    # original collection is still iterated below to keep its ordering.
    special_lookup = set(special_token_ids)
    regular_indices = [
        i for i in range(original_shape[0]) if i not in special_lookup
    ]
    regular_weights = lm_head.weight[regular_indices]

    centroids, regular_assignments = _kmeans_l2_torchpq(
        regular_weights, n_clusters=n_clusters
    )
    cluster_assignments = torch.zeros(
        original_shape[0], dtype=torch.long, device=device
    )

    # Every special token gets a dedicated cluster after the regular ones.
    for offset, token_id in enumerate(special_token_ids):
        cluster_assignments[token_id] = n_clusters + offset

    # Scatter all regular labels in one indexed assignment instead of one
    # Python-level write per vocabulary entry.
    regular_index_tensor = torch.tensor(
        regular_indices, dtype=torch.long, device=device
    )
    cluster_assignments[regular_index_tensor] = torch.as_tensor(
        regular_assignments, dtype=torch.long, device=device
    ).reshape(-1)

    _save_cache(
        centroids=centroids,
        cluster_assignments=cluster_assignments,
        cache_file=cache_file,
        original_shape=original_shape,
        metadata_file=metadata_file,
    )
    return centroids, cluster_assignments


def get_flash_head_parameters(
    lm_head: nn.Module,
    tokenizer: AutoTokenizer,
    n_clusters: int,
    special_token_types: Optional[List[str]] = None,
    cache_dir: str = "./ivf_cache_8192",
    verbose: bool = True,
    enforce_equal_cluster_sizes: bool = True,
) -> dict:
    """Get parameters for the FlashHead layer.

    :param lm_head:
        The language model head to replace.
    :param tokenizer:
        The tokenizer.
    :param n_clusters:
        The number of clusters.
    :param special_token_types:
        Names of tokenizer attributes (e.g. ``"eos_token"``) whose
        ``<name>_id`` should get a dedicated cluster, defaults to [].
    :param cache_dir:
        A cache directory, defaults to './ivf_cache_8192'
    :param verbose:
        Whether to be verbose with the clusterings.
    :param enforce_equal_cluster_sizes:
        If ``True``, rows of the vocabulary map that are shorter than the
        largest cluster are padded by repeating the cluster's first token,
        otherwise the padding value ``-1`` is kept.

    :return:
        A dictionary with the keys ``"centroids"`` (a
        ``(hidden_size, n_clusters + n_special)`` tensor),
        ``"vocab_maps_tensor"`` (one row of token ids per cluster) and
        ``"enforce_equal_cluster_sizes"``.
    """
    if special_token_types is None:
        special_token_types = []

    special_token_ids = set()

    for token_type in special_token_types:
        if getattr(tokenizer, token_type, None) is None:
            continue
        # A tokenizer without the matching `<type>_id` attribute is skipped
        # instead of raising AttributeError.
        token_id = getattr(tokenizer, f"{token_type}_id", None)
        if token_id is not None:
            special_token_ids.add(token_id)

    centroids, cluster_assignments = _get_centroids(
        lm_head=lm_head,
        n_clusters=n_clusters,
        special_token_ids=special_token_ids,
        cache_dir=cache_dir,
        enforce_equal_cluster_sizes=enforce_equal_cluster_sizes,
    )

    total_clusters = n_clusters + len(special_token_ids)
    original_shape = lm_head.weight.shape

    if verbose:
        # Show at most the first 20 clusters, never more than exist.
        for cluster_id in range(min(20, total_clusters)):
            indices = torch.where(cluster_assignments == cluster_id)[0]
            sample_indices = indices[:100].tolist()
            print(f"\nCluster {cluster_id}:")
            print(
                f"Size: {len(indices)} tokens ({(len(indices)/original_shape[0])*100:.2f}% of vocab)"
            )
            if tokenizer:
                tokens = [
                    tokenizer.decode([idx], skip_special_tokens=False)
                    for idx in sample_indices
                ]
                print(f"Sample tokens: {tokens}")
        print("\nSpecial Token Clusters:")
        for token_id in special_token_ids:
            cluster_id = cluster_assignments[token_id].item()
            token = tokenizer.decode([token_id], skip_special_tokens=False)
            print(f"Token: {token}, Cluster: {cluster_id}")

    # Token ids belonging to each cluster, in ascending token-id order.
    cluster_to_vocab_maps = [
        torch.where(cluster_assignments == i)[0] for i in range(total_clusters)
    ]

    combined_centroids = torch.zeros(
        (original_shape[1], total_clusters),
        device=lm_head.weight.device,
        dtype=lm_head.weight.dtype,
    )

    # Drop the two leading singleton dimensions added by the clustering.
    centroids_reshaped = centroids.squeeze(0).squeeze(0)
    combined_centroids[:, :n_clusters] = centroids_reshaped

    # The centroid of a special-token cluster is the token's own weight row.
    # The set is iterated in the same order as in `_get_centroids`, so the
    # cluster indices line up with the assignments.
    special_cluster_start = n_clusters
    for idx, token_id in enumerate(special_token_ids):
        cluster_id = special_cluster_start + idx
        combined_centroids[:, cluster_id] = lm_head.weight[token_id]

    max_len = max(m.shape[0] for m in cluster_to_vocab_maps)
    vocab_maps_tensor = torch.full(
        (len(cluster_to_vocab_maps), max_len), -1, device=lm_head.weight.device
    )
    for i, m in enumerate(cluster_to_vocab_maps):
        length = m.shape[0]
        vocab_maps_tensor[i, :length] = m
        if enforce_equal_cluster_sizes:
            # Pad with a duplicate of the first token so every row holds
            # valid token ids. NOTE: an empty cluster has no first token and
            # raises IndexError here.
            vocab_maps_tensor[i, length:] = m[0]

    return {
        "centroids": combined_centroids,
        "vocab_maps_tensor": vocab_maps_tensor,
        "enforce_equal_cluster_sizes": enforce_equal_cluster_sizes,
    }


class FlashHead(nn.Module):
    """An implementation of the Flash Head.

    Instead of scoring the whole vocabulary, the head first picks
    ``n_probes`` clusters of tokens by comparing the hidden state against the
    cluster centroids, and then only evaluates the rows of the original
    language model head that belong to those clusters.
    """

    def __init__(
        self,
        lm_head: nn.Linear,
        centroids: torch.Tensor,
        vocab_maps_tensor: torch.Tensor,
        n_probes: int,
        enforce_equal_cluster_sizes: bool,
        forward_type: Literal[
            "partial_logits",
            "monte_carlo_full_logits",
            "approximated_full_logits",
        ],
        always_pick_largest_clusters: bool = False,
        always_pick_smallest_clusters: bool = False,
    ):
        """
        FlashHead, a drop-in replacement for the classification head.

        :param lm_head:
            The original classification head.
        :param centroids:
            The cluster centroids to use, one column per cluster.
        :param vocab_maps_tensor:
            A mapping between cluster centroid index and token index. Row
            ``c`` lists the token ids of cluster ``c``; ``-1`` marks padding.
        :param n_probes:
            Number of probes to use.
        :param enforce_equal_cluster_sizes:
            Whether every row of `vocab_maps_tensor` is fully populated with
            valid token ids (no ``-1`` padding), which enables the
            fixed-size gather path.
        :param forward_type:
            What type of logits to produce in the forward pass. Options:
            partial_logits, monte_carlo_full_logits, approximated_full_logits.
            The value is stored on the module, but `forward` currently always
            returns partial logits.
        :param always_pick_largest_clusters:
            Whether to always select the clusters of the largest size as part of the probes.
            Note that this will result in poor output quality and is intended to be used for
            generating the upper-bound for on-device performance.
        :param always_pick_smallest_clusters:
            Whether to always select the clusters of the smallest size as part of the probes.
            Note that this will result in poor output quality and is intended to be used for
            generating the lower-bound for on-device performance.
        :raises ValueError:
            If both `always_pick_largest_clusters` and
            `always_pick_smallest_clusters` are set.
        """
        super().__init__()
        if always_pick_largest_clusters and always_pick_smallest_clusters:
            raise ValueError(
                "always_pick_largest_clusters and always_pick_smallest_clusters are "
                "mutually exclusive."
            )

        self.original_shape = lm_head.weight.shape
        self.device = lm_head.weight.device
        self.original_lm_head = lm_head
        self.n_probes = n_probes
        self.forward_type = forward_type
        self.enforce_equal_cluster_sizes = enforce_equal_cluster_sizes
        self.vocab_maps_tensor = vocab_maps_tensor
        # Number of valid (non-padding) token ids in every cluster row.
        self.vocab_maps_lengths = (vocab_maps_tensor != -1).sum(dim=1)

        self.centroids = centroids.contiguous()
        # Unit-norm centroids, stored as (n_clusters, hidden) so they can be
        # used directly as the weight of a linear layer.
        self.pre_normalized_centroids = (
            (centroids / centroids.norm(dim=0, keepdim=True)).t().contiguous()
        )
        self.cluster_linear = nn.Linear(
            self.pre_normalized_centroids.shape[1],
            self.pre_normalized_centroids.shape[0],
            bias=False,
        )
        self.cluster_linear.weight = nn.Parameter(
            self.pre_normalized_centroids
        )

        # Column positions of `vocab_maps_tensor`, used to mask out padding.
        self.row_indices = torch.arange(
            vocab_maps_tensor.shape[1], device=lm_head.weight.device
        )[None, :]

        self.always_pick_largest_clusters = always_pick_largest_clusters
        self.always_pick_smallest_clusters = always_pick_smallest_clusters
        if self.always_pick_smallest_clusters:
            self.smallest_clusters = (
                torch.sort(self.vocab_maps_lengths, descending=False)[1][
                    :n_probes
                ]
                .unsqueeze(0)
                .unsqueeze(0)
            )
        if self.always_pick_largest_clusters:
            self.largest_clusters = (
                torch.sort(self.vocab_maps_lengths, descending=True)[1][
                    :n_probes
                ]
                .unsqueeze(0)
                .unsqueeze(0)
            )

        # Pre-allocated (1, 1) result that `get_next_token` overwrites and
        # returns on the equal-cluster-size path.
        self.output_buffer = torch.tensor(
            [[0.0]],
            device=self.pre_normalized_centroids.device,
            dtype=torch.int64,
        )

    def _get_cluster_probs(self, hidden_states, temperature=1.0):
        """Return the softmax over clusters of ``hidden_states @ centroids``.

        :param hidden_states:
            The output embedding of the model body.
        :param temperature:
            The temperature the similarities are divided by before softmax.
        :returns:
            Cluster probabilities with the cluster axis last.
        """
        # Note that for probabilities, we use the "unnormalized" hidden states and cluster centroids,
        # as opposed to top-clusters where we use the normalized. This is due to the scale of logits being
        # suitable for softmax, which is not the case with normalization.
        similarities = torch.nn.functional.linear(
            hidden_states, self.centroids.t(), bias=None
        )
        return torch.softmax(similarities / temperature, dim=-1)

    def get_top_clusters(
        self,
        hidden_states: torch.Tensor,
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Return the clusters most likely to be associated with the true token of `hidden_states`.

        :param hidden_states:
            The output embedding of the model body.
        :param do_sample:
            If ``False`` return the top-k clusters according to similarity. If ``True``,
            sample k clusters without replacement according to probabilities generated from the
            similarities.
        :param temperature:
            The temperature for the softmax operator, only relevant when `do_sample` is ``True``.
        :returns:
            A tensor containing the indices of the clusters.
        """
        if do_sample:
            probs = self._get_cluster_probs(
                hidden_states=hidden_states, temperature=temperature
            )
            B, T, num_clusters = probs.shape
            sampled_indices = torch.multinomial(
                probs.view(-1, num_clusters), self.n_probes, replacement=False
            )
            top_clusters = sampled_indices.view(B, T, self.n_probes)
        else:
            similarities = self.cluster_linear(hidden_states)
            _, top_clusters = torch.topk(similarities, k=self.n_probes, dim=-1)

        # The fixed selections exist for performance bounds, so the regular
        # selection above is still computed before being overridden.
        if self.always_pick_smallest_clusters:
            return self.smallest_clusters
        if self.always_pick_largest_clusters:
            return self.largest_clusters
        return top_clusters

    def _get_cluster_logits(
        self,
        hidden_states: torch.Tensor,
        top_clusters: torch.Tensor,
        use_identical_tiebreak: bool,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute the logits of all tokens in the selected clusters.

        :param hidden_states:
            The output embedding of the model body.
        :param top_clusters:
            The selected clusters; only a single batch element and a single
            position are supported.
        :param use_identical_tiebreak:
            If ``True`` the candidate tokens are scored in ascending token-id
            order, and the permutation back to the original candidate order
            is returned as the second element.
        :returns:
            A tuple ``(logits, mapping)``. ``mapping[k]`` is the position in
            the unsorted candidate list of the token whose logit is at
            position ``k``; it is ``None`` without tiebreak sorting.
        :raises NotImplementedError:
            If `top_clusters` has more than one batch element or position.
        """
        if top_clusters.shape[1] > 1 or top_clusters.shape[0] > 1:
            raise NotImplementedError

        cluster_indices = top_clusters[0, 0]
        if self.enforce_equal_cluster_sizes:
            # Rows carry no padding, so they can be flattened as they are.
            maps = self.vocab_maps_tensor.index_select(0, cluster_indices)
            indices = maps.flatten()
        else:
            lengths = self.vocab_maps_lengths[cluster_indices]
            maps = self.vocab_maps_tensor[cluster_indices]
            mask = self.row_indices < lengths[:, None]
            indices = maps[mask]

        mapping = None
        if use_identical_tiebreak:
            ordering = indices.sort()
            indices = ordering.values
            mapping = ordering.indices

        candidate_weights = self.original_lm_head.weight.index_select(
            0, indices
        )
        logits = torch.nn.functional.linear(
            hidden_states, candidate_weights, bias=None
        )
        return logits, mapping

    def _get_full_logits_approximation(
        self, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Return full logits adjusted by an approximate inclusion probability.

        Each token's logit is shifted by ``log(min(1, n_probes * p))`` where
        ``p`` is the probability of the token's cluster.

        :param hidden_states:
            The output embedding of the model body, with batch and sequence
            as the two leading dimensions.
        :returns:
            The adjusted logits over the whole vocabulary.
        """
        batch_size, seq_len = hidden_states.shape[:2]
        vocab_size = self.original_lm_head.weight.shape[0]
        num_clusters = self.centroids.shape[-1]

        # Centroids are stored as (hidden, clusters); `_get_cluster_probs`
        # applies the transpose that `F.linear` needs.
        cluster_probs = self._get_cluster_probs(hidden_states)
        all_logits = self.original_lm_head(hidden_states)

        inclusion = torch.zeros(
            (batch_size, seq_len, vocab_size),
            device=self.device,
            dtype=hidden_states.dtype,
        )
        for c in range(num_clusters):
            # Rows may repeat a token id as padding; every token must
            # receive its cluster probability exactly once.
            tokens = torch.unique(
                self.vocab_maps_tensor[c, : self.vocab_maps_lengths[c]]
            )
            inclusion[:, :, tokens] += cluster_probs[:, :, c].unsqueeze(-1)

        adjustment = torch.clamp(self.n_probes * inclusion, max=1.0)
        return all_logits + torch.log(adjustment)

    def _get_full_logits_monte_carlo(
        self, hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """Estimate full log-probabilities by averaging sampled probe sets.

        :param hidden_states:
            The output embedding of the model body.
        :returns:
            The logarithm of the token probabilities averaged over all
            Monte Carlo samples.
        """
        n_montecarlo_samples = 400
        prob_sum = None
        for _ in range(n_montecarlo_samples):
            # NOTE(review): this used to call a method that does not exist on
            # this class; `_get_partial_logits_of_probes` has the matching
            # signature and is presumably its current name -- confirm.
            logits = self._get_partial_logits_of_probes(
                hidden_states, do_sample=True
            )
            probs = torch.softmax(logits, dim=-1)
            prob_sum = probs if prob_sum is None else prob_sum + probs

        return torch.log(prob_sum / n_montecarlo_samples)

    def _get_partial_logits_of_probes(
        self, hidden_states: torch.Tensor, do_sample: bool = False
    ) -> torch.Tensor:
        """Return full-size logits where only probed tokens keep their value.

        Tokens outside the selected clusters are set to the minimum over all
        true logits.

        :param hidden_states:
            The output embedding of the model body, with batch and sequence
            as the two leading dimensions.
        :param do_sample:
            Whether the clusters are sampled instead of taking the top-k.
        :returns:
            Logits over the whole vocabulary.
        """
        batch_size, seq_len = hidden_states.shape[:2]
        top_clusters = self.get_top_clusters(
            hidden_states, do_sample=do_sample
        )

        all_logits = self.original_lm_head(hidden_states)

        min_value = all_logits.min().item()
        full_logits = torch.full(
            (batch_size, seq_len, self.original_lm_head.weight.shape[0]),
            min_value,
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )

        for b in range(batch_size):
            for j in range(seq_len):
                clusters = top_clusters[b, j]
                vocab_maps = self.vocab_maps_tensor[clusters]
                lengths = self.vocab_maps_lengths[clusters]
                # Keep only the valid (non-padding) entries of every row.
                mask = self.row_indices < lengths.unsqueeze(1)
                vocab_indices = vocab_maps[mask]
                full_logits[b, j, vocab_indices] = all_logits[
                    b, j, vocab_indices
                ]

        return full_logits

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return partial logits for `hidden_states`.

        Slow and only used for evaluation. Use `get_next_token` for inference.
        """
        return self._get_partial_logits_of_probes(hidden_states.contiguous())

    def get_next_token(
        self,
        hidden_states: torch.Tensor,
        do_sample: bool = False,
        temperature: float = 1.0,
        use_identical_tiebreak: bool = True,
    ) -> torch.Tensor:
        """
        Return the next token, given `hidden_states`.

        :param hidden_states:
            The output of the model body.
        :param do_sample:
            Whether to sample the next token according to probabilities,
            or simply return the most probable.
        :param temperature:
            The temperature to use in the softmax
            (both the softmax in cluster probabilities and for the
            softmax in token probabilities).
            Only relevant when `do_sample` is ``True``.
        :param use_identical_tiebreak:
            Whether to reorder the logits so that when two logits are the same,
            the new head will use the same tiebreak as the original.
        :returns:
            The next predicted token, represented as an index. On the
            equal-cluster-size path this is the module's shared
            ``output_buffer``, which is overwritten by the next call.
        """
        profile = False  # flip by hand to print per-stage CUDA timings

        def start_timer():
            if not profile:
                return None
            started = torch.cuda.Event(enable_timing=True)
            started.record()
            return started

        def stop_timer(started, label):
            if started is None:
                return
            finished = torch.cuda.Event(enable_timing=True)
            finished.record()
            torch.cuda.synchronize()
            print(label, started.elapsed_time(finished) * 1000)

        timer = start_timer()
        top_clusters = self.get_top_clusters(
            hidden_states,
            do_sample=do_sample,
            temperature=temperature,
        )
        stop_timer(timer, "Get top clusters ")

        timer = start_timer()
        cluster_logits, mapping = self._get_cluster_logits(
            hidden_states, top_clusters, use_identical_tiebreak
        )
        stop_timer(timer, "Get logits ")

        timer = start_timer()
        last_logits = cluster_logits[:, -1, :]
        if do_sample:
            probs = (last_logits / temperature).softmax(dim=-1)
            cluster_token_idx = torch.multinomial(probs, num_samples=1)
        else:
            cluster_token_idx = last_logits.argmax(dim=-1, keepdim=True)
        if mapping is not None:
            # The logits are in sorted-token order; translate the chosen
            # position back to the unsorted candidate order for both the
            # greedy and the sampled choice.
            cluster_token_idx = mapping[cluster_token_idx]

        if self.enforce_equal_cluster_sizes:
            # All rows have the same width, so the flat position splits into
            # (probe number, column) by integer division.
            cap = self.row_indices.shape[1]
            cluster_index = cluster_token_idx // cap
            relative_pos = cluster_token_idx % cap
            vocab_index = self.vocab_maps_tensor[
                top_clusters[0, 0, cluster_index], relative_pos
            ]
            self.output_buffer[0][0] = vocab_index
            stop_timer(timer, "Final ")
            return self.output_buffer

        cluster_position = cluster_token_idx.item()
        candidate_clusters = top_clusters[0, 0]

        cluster_lengths = self.vocab_maps_lengths[candidate_clusters]

        # Find the probe whose cumulative token range contains the position.
        cumulative_lengths = torch.cumsum(cluster_lengths, dim=0)
        cluster_index = torch.searchsorted(
            cumulative_lengths, cluster_position + 1
        )

        if cluster_index == 0:
            relative_position = cluster_position
        else:
            relative_position = (
                cluster_position - cumulative_lengths[cluster_index - 1]
            )
        vocab_index = self.vocab_maps_tensor[
            candidate_clusters[cluster_index], relative_position
        ]
        next_token = torch.tensor([[vocab_index]], device=hidden_states.device)
        return next_token


def monte_carlo_cluster_inclusion(
    flash_head: FlashHead, hidden_states: torch.Tensor, n_trials: int = 100000
) -> np.array:
    """
    Simulate the cluster inclusion probability P(C ∈ S) with Monte Carlo.

    Only the first position of the first batch element of `hidden_states`
    is used; clusters are drawn for it `n_trials` times by sampling.

    :param flash_head:
        The FlashHead to analyze.
    :param hidden_states:
        The output of the model body.
    :param n_trials:
        The number of trials to run.
    :returns:
        For every cluster, the fraction of trials in which it was drawn.
        Note that this is unnormalized.
    """
    first_state = hidden_states[0:1, 0:1]
    hit_counts = torch.zeros(
        flash_head.centroids.shape[-1], device=hidden_states.device
    )

    for _trial in range(n_trials):
        drawn = flash_head.get_top_clusters(first_state, do_sample=True)[0, 0]
        hit_counts[drawn] += 1

    inclusion_rate = hit_counts / n_trials
    return inclusion_rate.cpu().numpy()


def analyse_cluster_inclusion(
    flash_head: FlashHead, hidden_states: torch.Tensor, n_trials: int = 100000
) -> None:
    """
    Analyze the cluster inclusion probability P(C ∈ S) with Monte Carlo.

    Prints, per cluster, the simulated inclusion frequency next to four
    closed-form estimates, followed by one summary line per distribution with
    its `kl_div` value against the simulation, its mean and its standard
    deviation. All distributions except the raw cluster probabilities are
    normalized to sum to one before being compared.

    :param flash_head:
        The FlashHead to analyze.
    :param hidden_states:
        The output of the model body.
    :param n_trials:
        The number of trials to run.
    """

    def _normalized(values: torch.Tensor) -> np.ndarray:
        # Move to the CPU as float32 and rescale so the entries sum to one.
        array = values.cpu().float().numpy()
        return array / array.sum()

    empirical_probs = monte_carlo_cluster_inclusion(
        flash_head, hidden_states, n_trials=n_trials
    )
    empirical_probs /= empirical_probs.sum()

    with torch.no_grad():
        cluster_probs = flash_head._get_cluster_probs(
            hidden_states[0:1, 0:1],
        )[0, 0]

    n_probes = flash_head.n_probes
    estimates = [
        (
            "Theoretical Upper",
            _normalized(torch.clamp(n_probes * cluster_probs, max=1.0)),
        ),
        (
            "Theoretical Lower",
            _normalized(1 - (1 - cluster_probs) ** n_probes),
        ),
        (
            "Theoretical Upper (too much)",
            _normalized(n_probes * cluster_probs),
        ),
        ("Theoretical Cluster", cluster_probs.cpu().float().numpy()),
    ]

    for i, (emp, *theoretical) in enumerate(
        zip(empirical_probs, *(values for _, values in estimates))
    ):
        details = ", ".join(
            f"{label} = {value:.8f}"
            for (label, _), value in zip(estimates, theoretical)
        )
        print(f"Cluster {i}: Empirical Inclusion = {emp:.8f}, {details}")

    # NOTE: clusters that were never drawn have an empirical probability of
    # zero, whose logarithm is -inf.
    log_empirical = torch.log(torch.Tensor(empirical_probs))
    for label, values in [("Empirical", empirical_probs)] + estimates:
        print(
            f"{label}: ",
            kl_div(log_empirical, torch.Tensor(values)).item(),
            " ",
            values.mean(),
            " ",
            values.std(),
        )


def get_flash_head_pipeline(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    cache_dir: str = "Llama-3.2-1B-Instruct-cluster-8192/",
    n_probes: int = 256,
    n_clusters: int = 8192,
    device_map: str = "cuda",
) -> GenerationPipeline:
    """Build a generation pipeline whose language model head is a FlashHead.

    The model is loaded in bfloat16, its ``lm_head`` is clustered (or the
    clustering is read from `cache_dir`) and then replaced by a `FlashHead`.

    :param model_id:
        The HuggingFace model, defaults to "meta-llama/Llama-3.2-1B-Instruct".
    :param cache_dir:
        The cache directory for the clustering, defaults to
        'Llama-3.2-1B-Instruct-cluster-8192/'
    :param n_probes:
        The number of probes to use, defaults to 256.
    :param n_clusters:
        The number of clusters to use, defaults to 8192.
    :param device_map:
        The device to load the model at.
    :return:
        A model generation pipeline for FlashHead.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    original_head = model.lm_head
    head_parameters = get_flash_head_parameters(
        lm_head=original_head,
        tokenizer=tokenizer,
        n_clusters=n_clusters,
        cache_dir=cache_dir,
    )
    flash_head = FlashHead(
        original_head,
        n_probes=n_probes,
        forward_type="partial_logits",
        **head_parameters,
    )

    # Manual switch: quantization is disabled unless this is flipped by hand.
    quantize = False
    if quantize:
        from snippets.llm.huggingface.quantize import (
            quantize_head,
            quantize_model,
        )

        flash_head.cluster_linear = quantize_head(flash_head.cluster_linear)
        model.model = quantize_model(model.model)

    model.lm_head = flash_head

    return GenerationPipeline(
        model.model,
        model.lm_head,
        tokenizer=tokenizer,
        mode="flash_head",
    )


def get_flash_head_model_and_tokenizer(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    cache_dir: Optional[str] = None,
    n_probes: int = 256,
    n_clusters: int = 8192,
    forward_type: str = "partial_logits",
    device=None,
):
    """Get a new model and tokenizer with FlashHead as the lm_head.

    :param model_id:
        The HuggingFace model, defaults to "meta-llama/Llama-3.2-1B-Instruct".
    :param cache_dir:
        The cache directory for the clustering. When ``None``,
        "./flash_head_cache" is used.
    :param n_probes:
        The number of probes to use, defaults to 256.
    :param n_clusters:
        The number of clusters to use, defaults to 8192.
    :param forward_type:
        Passed on to `FlashHead` as its `forward_type`.
    :param device:
        The device to load the model at. When ``None``, "cuda" is used if
        available and "cpu" otherwise.
    :return:
        A tuple ``(model, tokenizer)`` where ``model.lm_head`` is the
        `FlashHead`.
    """
    if cache_dir is None:
        cache_dir = "./flash_head_cache"
        print(f"Saving cache dir to {cache_dir}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device
    )

    flash_head = FlashHead(
        model.lm_head,
        n_probes=n_probes,
        forward_type=forward_type,
        **get_flash_head_parameters(
            lm_head=model.lm_head,
            tokenizer=tokenizer,
            n_clusters=n_clusters,
            cache_dir=cache_dir,
        ),
    )
    model.lm_head = flash_head

    return model, tokenizer


def get_spherical_k_means_model_and_tokenizer(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    cache_dir: str = None,
    n_clusters: int = 8192,
    forward_type: str = "partial_logits",
    device=None,
):
    """Get Spherical-K means model and tokenizer.

    This is the FlashHead model restricted to a single probe; every other
    argument is forwarded unchanged to `get_flash_head_model_and_tokenizer`.
    """
    single_probe = 1  # Fixed for Spherical K Means
    model_and_tokenizer = get_flash_head_model_and_tokenizer(
        model_id=model_id,
        cache_dir=cache_dir,
        n_probes=single_probe,
        n_clusters=n_clusters,
        forward_type=forward_type,
        device=device,
    )
    return model_and_tokenizer
