"""SVD softmax layer implementation"""

import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

from efficient_heads.pipeline import GenerationPipeline

# Seed torch's global RNG at import time so runs are reproducible.
# NOTE(review): this is a module-level side effect that applies to every
# importer of this module (it also fixes the draws made by torch.multinomial
# in SVDSoftmaxHead.get_next_token); consider moving it to the entry script.
torch.manual_seed(0)


class SVDSoftmaxHead(nn.Module):
    """
    SVD-softmax layer implementation.

    Implementation of the layer from the paper "SVD-Softmax: Fast Softmax
    Approximation on Large Vocabulary Neural Networks".

    The output projection ``W`` (vocab_size x hidden_size) is factorised once
    as ``W = U diag(S) Vt``. Every logit is first previewed using only the
    leading ``window`` singular directions; the ``top_n`` candidates with the
    largest preview logits are then recomputed using all directions.
    """

    def __init__(
        self,
        weights: torch.Tensor,
        window: int = 256,
        top_n: int = 16000,
        dtype: torch.dtype = torch.bfloat16,
    ) -> None:
        """SVD softmax layer implementation.

        This layer can replace the lm_head layer of a model using the weights
        from ``model.lm_head.weight`` as:

        ```
        model.lm_head = SVDSoftmaxHead(
            weights=model.lm_head.weight,
            window=window,
            top_n=top_n
        )
        ```

        :param weights:
            The weights of the lm_head layer, shaped (vocab_size, hidden_size).
            They are only read; no gradient flows back into them.
        :param window:
            The window size to select from the hidden vector, defaults to 256.
        :param top_n:
            The number of top_n vectors to use from the total vocabulary to
            refine the logits by doing a full dot product, defaults to 16000.
            Values larger than the vocabulary size are clamped at run time.
        :param dtype:
            The datatype to use, defaults to `torch.bfloat16`.
        """
        super().__init__()
        self.window = window
        self.top_n = top_n
        self.dtype = dtype

        device = weights.device
        # The factorisation is a one-off preprocessing step. Detach so the
        # buffers are plain leaf tensors: otherwise (when `weights` is an
        # nn.Parameter) they would keep an autograd graph back to the
        # lm_head parameter alive. The SVD is done in float32 because
        # torch.linalg.svd does not accept bfloat16 inputs.
        U, S, Vt = torch.linalg.svd(
            weights.detach().cpu().to(torch.float), full_matrices=False
        )
        # Same as U @ torch.diag(S), without building the r x r diagonal.
        B = U * S

        self.register_buffer("B", B.to(dtype=dtype, device=device))
        self.register_buffer("Vt", Vt.to(dtype=dtype, device=device))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward pass for the SVD-softmax logits computation.

        :param hidden_states:
            The hidden states from the last decoder block, shaped
            (..., hidden_size) -- typically (batch_size, seq_len, hidden_size).
        :return:
            The output logits computed by SVD-Softmax, shaped
            (..., vocab_size) and in the dtype of ``hidden_states``.
        """
        *leading_dims, hidden_size = hidden_states.shape
        vocab_size = self.B.shape[0]

        # Cast to the buffers' actual dtype (not `self.dtype`) so the head
        # keeps working after e.g. `module.to(torch.float32)`.
        h_flat = hidden_states.reshape(-1, hidden_size).to(self.Vt.dtype)

        # Coordinates of each hidden vector in the right-singular basis:
        # (N, rank).
        h_tilde = h_flat @ self.Vt.T

        # Preview logits from the leading `window` directions: (N, vocab).
        z_tilde = h_tilde[:, : self.window] @ self.B[:, : self.window].T

        # torch.topk raises when k exceeds the dimension size, so clamp.
        k = min(self.top_n, vocab_size)
        _, top_n_indices = torch.topk(z_tilde, k, dim=1)

        # Refine the candidates with all directions:
        # (N, 1, rank) @ (N, rank, k) -> (N, k).
        # NOTE(review): B_top materialises an (N, k, rank) tensor, which can
        # be large when N = batch_size * seq_len is large (e.g. prefill).
        B_top = self.B[top_n_indices]
        z_refined = torch.bmm(
            h_tilde.unsqueeze(1), B_top.transpose(1, 2)
        ).squeeze(1)

        # Overwrite the preview logits of the candidates in place.
        refined_logits = z_tilde.scatter_(1, top_n_indices, z_refined)

        # Use an explicit vocab size so empty batches reshape correctly.
        logits = refined_logits.reshape(*leading_dims, vocab_size)
        return logits.to(dtype=hidden_states.dtype)

    def get_next_token(
        self,
        hidden_states: torch.Tensor,
        do_sample: bool = False,
        temperature: float = 1.0,
    ):
        """
        Generate the next tokens from the hidden states of the model.

        :param hidden_states:
            The hidden states from the model body, shaped
            (batch_size, seq_len, hidden_size); only the last position is used.
        :param do_sample:
            Whether to randomly sample for the next token.
        :param temperature:
            The softmax temperature, applicable only if `do_sample` is ``True``.

        :return:
            The next token index, shaped (batch_size, 1).
        :raises ValueError:
            If ``do_sample`` is ``True`` and ``temperature`` is not positive.
        """
        logits = self.forward(hidden_states=hidden_states)
        last_logits = logits[:, -1, :]

        if do_sample:
            if temperature <= 0:
                raise ValueError(
                    f"temperature must be positive when sampling, "
                    f"got {temperature}"
                )
            # Softmax in float32: bfloat16 probabilities over a large
            # vocabulary lose most of their precision.
            probs = (last_logits.float() / temperature).softmax(dim=-1)
            return torch.multinomial(probs, num_samples=1)
        return last_logits.argmax(dim=-1, keepdim=True)


def get_svd_softmax_pipeline(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    window: int = 256,
    top_n: int = 16000,
    device_map: str = "cuda",
) -> GenerationPipeline:
    """Build a generation pipeline whose LM head is an SVD-softmax head.

    The causal LM is loaded in bfloat16, its ``lm_head`` is replaced by an
    ``SVDSoftmaxHead`` built from the original head weights, and the model
    body plus the new head are wrapped in a ``GenerationPipeline``.

    :param model_id:
        The baseline model id to start from.
    :param window:
        The window size to consider.
    :param top_n:
        The number of top n vectors to use from the vocabulary.
    :param device_map:
        The device to load the model at.
    :return:
        A generation pipeline for SVD-softmax.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device_map
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Build the replacement head from the original weights, then install it.
    svd_head = SVDSoftmaxHead(
        weights=base_model.lm_head.weight, window=window, top_n=top_n
    )
    base_model.lm_head = svd_head

    return GenerationPipeline(base_model.model, svd_head, tokenizer=tokenizer)


def get_svd_softmax_model_and_tokenizer(
    model_id: str, window: int, top_n: int, device=None
):
    """Load a causal LM and its tokenizer, swapping in an SVD-softmax head.

    :param model_id:
        The model id passed to ``from_pretrained`` for both model and
        tokenizer.
    :param window:
        Window size forwarded to ``SVDSoftmaxHead``.
    :param top_n:
        Number of refined candidates forwarded to ``SVDSoftmaxHead``.
    :param device:
        Used as ``device_map`` when loading the model; when ``None``,
        ``"cuda"`` is chosen if available, otherwise ``"cpu"``.
    :return:
        A ``(model, tokenizer)`` tuple whose ``model.lm_head`` is an
        ``SVDSoftmaxHead`` built from the original head weights.
    """
    if device is not None:
        target_device = device
    elif torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=target_device
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    head_weights = model.lm_head.weight
    model.lm_head = SVDSoftmaxHead(
        head_weights, window=window, top_n=top_n
    )
    return model, tokenizer
