"""
Minimal wrapper for a FAISS Hierarchical-NSW (HNSW) index
with the *inner-product preserving transformation* used by
Zhang et al. (2018) – the transformation converts
   score = h·x + b
to a Euclidean distance search.

The module exposes two things:
    build_hnsw(emb, bias, M, ef_construction) → (index, U)
    search_hnsw(index, q, K, ef_search, U) → (ids, sim)

* emb:  (V, D)  torch / np float32
* bias: (V,)    torch / np float32
"""

from typing import Optional, Tuple

import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from efficient_heads.pipeline import GenerationPipeline


def _ippt(emb: np.ndarray, bias: np.ndarray, U: float) -> np.ndarray:
    """
    Inner-Product Preserving Transformation   φ(x,b) ∈ ℝ^{D+2}
    φ(x,b) = [x, b, sqrt(U^2 - ||x||^2 - b^2)]
    """
    norm2 = (emb**2).sum(-1) + bias**2  # (V,)
    pad = np.sqrt(np.clip(U**2 - norm2, 0, None))  # (V,)
    return np.concatenate([emb, bias[:, None], pad[:, None]], axis=1).astype(
        np.float32
    )


def build_hnsw(
    emb: torch.Tensor,
    bias: torch.Tensor,
    M: int = 32,
    ef_construction: int = 200,
) -> Tuple[faiss.IndexHNSW, float]:
    """
    Build an L2 HNSW index over the IPPT-transformed rows of ``emb``/``bias``.

    Args:
        emb: (V, D) torch tensor or numpy array.
        bias: (V,) torch tensor or numpy array.
        M: forwarded to ``faiss.IndexHNSWFlat``.
        ef_construction: written to ``index.hnsw.efConstruction``.

    Returns:
        ``(index, U)`` where ``U`` is the radius used by the transformation;
        the same value must be passed to ``search_hnsw``.

    Raises:
        ValueError: if the inputs are empty or their shapes do not match.
    """
    # torch.as_tensor lets numpy arrays through unchanged; detach() makes
    # .numpy() legal for tensors that still track gradients.
    emb_np = torch.as_tensor(emb).detach().cpu().float().numpy()
    bias_np = torch.as_tensor(bias).detach().cpu().float().numpy()

    if emb_np.ndim != 2 or bias_np.shape != (emb_np.shape[0],):
        raise ValueError(
            "expected emb of shape (V, D) and bias of shape (V,), got "
            f"{emb_np.shape} and {bias_np.shape}"
        )
    if emb_np.shape[0] == 0:
        raise ValueError("cannot build an HNSW index from zero vectors")

    # ---- choose a radius U  -------------------------------------------
    # max ||x||^2 + max b^2 bounds every row's ||x||^2 + b^2 from above; the
    # small epsilon keeps the radicand inside _ippt strictly positive.
    U = float(np.sqrt((emb_np**2).sum(1).max() + (bias_np**2).max())) + 1e-4
    xb = _ippt(emb_np, bias_np, U)  # (V, D+2)
    d = xb.shape[1]

    index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_L2)
    index.hnsw.efConstruction = ef_construction
    index.verbose = False
    index.add(xb)
    return index, U


def search_hnsw(
    index: "faiss.IndexHNSW",
    q: torch.Tensor,
    K: int,
    ef_search: int,
    U: float,
):
    """
    Query the IPPT index and convert the distances back to logits.

    Args:
        index: index built by ``build_hnsw``; its ``hnsw.efSearch`` is
            overwritten with ``ef_search`` on every call.
        q: (B, D) context vectors.
        K: number of neighbours requested per query.
        ef_search: search-time beam width written to ``index.hnsw.efSearch``.
        U: radius returned by ``build_hnsw``.

    Returns:
        ``(ids, sim)`` as CPU tensors of shape (B, K): the ids reported by
        the index and the reconstructed scores ``h·x + b``.
    """
    index.hnsw.efSearch = ef_search
    with torch.no_grad():
        # detach(): a float32 CPU tensor that requires grad would otherwise
        # reach .numpy() unchanged and raise.
        q_np = q.detach().float().cpu().numpy()  # (B,D)
        # embed -> [h, 1, 0]  (IPPT for context)
        pad = np.zeros((q_np.shape[0], 2), dtype=np.float32)
        pad[:, 0] = 1.0
        q_ippt = np.concatenate([q_np, pad], axis=1).astype(np.float32)
        D2, I = index.search(q_ippt, K)  # squared L2 distance
        # convert back to logits.  With q' = [h, 1, 0] and ||φ(x,b)|| = U:
        #   D2 = ||q'||^2 + U^2 - 2 (h·x + b),   ||q'||^2 = ||h||^2 + 1
        #   =>  s = 0.5 * (||h||^2 + 1 + U^2 - D2)
        h_norm2 = (q_np**2).sum(1, keepdims=True)
        sim = 0.5 * (h_norm2 + 1.0 + U**2 - D2)  # (B,K)
        return torch.from_numpy(I), torch.from_numpy(sim)


class FGDHead(nn.Module):
    """
    Fast-Graph Decoder head
    --------------------------------
    • During __init__ we *clone* the original LM-head weight
      (and bias if present) and build the HNSW graph index from them.
    • During inference we:
        1. take the context `h`   … shape (B, 1, D)
        2. call graph search →   ids, sim   with K ≪ V
        3. arg-max / sample over those K candidate scores only.

    FAISS pads a result row with id -1 when it returns fewer than K
    neighbours.  Both entry points neutralise those padded slots so they can
    neither be selected nor overwrite the score of a real candidate.
    """

    def __init__(
        self,
        lm_head: nn.Linear,
        K: int = 256,
        ef_search: int = 100,
        M: int = 32,
        ef_construction: int = 200,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            lm_head: dense output projection whose weight/bias are indexed.
            K: number of candidates fetched per query.
            ef_search: forwarded to ``search_hnsw`` on every call.
            M: forwarded to ``build_hnsw``.
            ef_construction: forwarded to ``build_hnsw``.
            device: device on which results are returned; defaults to the
                device of the cloned weight.
        """
        super().__init__()
        assert isinstance(lm_head, nn.Linear)
        self.original_lm_head = lm_head
        self.weight = (
            lm_head.weight.detach().clone().to(torch.float32)
        )  # (V,D)
        if lm_head.bias is not None:
            self.bias = lm_head.bias.detach().clone().to(torch.float32)
        else:
            # keep the stand-in bias on the same device as the weight so the
            # two registered buffers below do not end up on different devices
            self.bias = torch.zeros(
                self.weight.size(0),
                dtype=torch.float32,
                device=self.weight.device,
            )
        self.K = K
        self.ef_search = ef_search

        # ---- build graph  -------------------------------------------------
        self.index, self.U = build_hnsw(
            self.weight, self.bias, M=M, ef_construction=ef_construction
        )
        # bf16 copies of weight / bias, registered as buffers
        self.register_buffer("weight_bf16", self.weight.to(torch.bfloat16))
        self.register_buffer("bias_bf16", self.bias.to(torch.bfloat16))
        self.device = device or self.weight_bf16.device

    @staticmethod
    def _mask_padded_hits(ids: torch.Tensor, sims: torch.Tensor, dtype: torch.dtype):
        """
        Cast the scores to ``dtype`` and neutralise padded (-1) result slots.

        Returns ``(safe_ids, logits, valid)``: ids clamped to be usable as
        indices, scores with padded slots set to the most negative finite
        value of ``dtype``, and the boolean mask of genuine hits.
        """
        valid = ids >= 0
        logits = sims.to(dtype).masked_fill(~valid, torch.finfo(dtype).min)
        return ids.clamp(min=0), logits, valid

    # ------------------------------------------------------------------ #
    # Helpers used by GenerationPipeline
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def get_next_token(
        self,
        hidden: torch.Tensor,  # (B, 1, D)
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Pick the next token id among the K graph-search candidates.

        • hidden is expected to be the last hidden state (bf16 / fp16).
        • the search runs in fp32; the scores are cast back to hidden.dtype.

        With ``do_sample`` the token is drawn from the soft-max of
        ``scores / temperature`` over the candidates, otherwise the
        best-scoring candidate is returned.  Result shape: (B, 1).
        """
        if hidden.dim() != 3:
            raise ValueError(
                f"hidden must have shape (B, 1, D), got {tuple(hidden.shape)}"
            )
        h = hidden.squeeze(1).to(torch.float32)  # (B,D)

        # ---- graph search (CPU) ----------------------------------------
        ids, sims = search_hnsw(
            self.index, h, self.K, self.ef_search, self.U
        )  # both tensors on CPU fp32
        ids = ids.to(self.device)  # (B,K)
        sims = sims.to(self.device)  # (B,K)

        # ---- candidate logits ------------------------------------------
        # bias is already included in sims, so no extra add.
        ids, logits_slice, _ = self._mask_padded_hits(ids, sims, hidden.dtype)

        if do_sample:
            probs = (logits_slice / temperature).softmax(-1)
            next_tok = torch.multinomial(
                probs, num_samples=1
            )  # (B,1) – indices in slice
            next_tok = ids.gather(1, next_tok)  # vocab ids
        else:
            max_idx_in_slice = logits_slice.argmax(-1, keepdim=True)  # (B,1)
            next_tok = ids.gather(1, max_idx_in_slice)  # vocab ids

        return next_tok  # shape (B,1)

    # ------------------------------------------------------------------ #
    # NOTE: forward() returns vocabulary-sized logits for callers that
    # expect the output shape of a dense head.  The values still come from
    # the graph search, not from a dense mat-mul.
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def forward(
        self,
        hidden: torch.Tensor,  # (B, L, D)   — L can be 1 or larger
    ) -> torch.Tensor:  # -> (B, L, V)
        """
        Returns *dense-shaped* logits for the entire vocabulary.
        Only the K positions returned by the HNSW search contain real values;
        every other position holds the most negative finite value of the
        dtype (not -inf), so its soft-max mass underflows to zero.
        """
        B, L, D = hidden.shape
        h = hidden.reshape(B * L, D).to(torch.float32)  # (B·L, D)

        # ---- ANN search ---------------------------------------------------
        ids, sims = search_hnsw(
            self.index, h, self.K, self.ef_search, self.U
        )  # (B·L, K)
        ids = ids.to(self.device)
        sims = sims.to(self.device)  # fp32 on device

        # ---- sparse logits ------------------------------------------------
        ids, logits_slice, valid = self._mask_padded_hits(
            ids, sims, hidden.dtype
        )  # (B·L, K)
        # Padded slots would otherwise all land on column 0 and could
        # overwrite a real score there.  Point them at the row's best hit
        # (slot 0) carrying the same value, so duplicate writes are harmless.
        ids = torch.where(valid, ids, ids[:, :1])
        logits_slice = torch.where(valid, logits_slice, logits_slice[:, :1])

        V = self.weight.shape[0]
        fill_val = torch.finfo(logits_slice.dtype).min  # lowest finite value
        logits_flat = torch.full(
            (B * L, V), fill_val, device=self.device, dtype=logits_slice.dtype
        )  # (B·L, V)

        logits_flat.scatter_(1, ids, logits_slice)  # fill K positions

        # reshape back to (B, L, V)
        return logits_flat.view(B, L, V)


def get_fgd_pipeline(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    K: int = 256,
    ef_search: int = 100,
    index_M: int = 32,
    ef_construction: int = 200,
) -> GenerationPipeline:
    """
    Load ``model_id`` and wrap it in a GenerationPipeline with an FGD head.

    The causal LM is loaded in bfloat16 with ``device_map="cuda"``, its
    ``lm_head`` is replaced by an ``FGDHead`` built from the original head,
    and the pipeline is assembled from ``model.model``, the new head and the
    tokenizer.

    Args:
        model_id: identifier passed to both ``from_pretrained`` calls.
        K: forwarded to ``FGDHead``.
        ef_search: forwarded to ``FGDHead``.
        index_M: forwarded to ``FGDHead`` as ``M``.
        ef_construction: forwarded to ``FGDHead``.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="cuda"
    )
    tok = AutoTokenizer.from_pretrained(model_id)

    # swap the dense projection for the graph-search head
    head_options = {
        "K": K,
        "ef_search": ef_search,
        "M": index_M,
        "ef_construction": ef_construction,
    }
    causal_lm.lm_head = FGDHead(lm_head=causal_lm.lm_head, **head_options)

    pipeline = GenerationPipeline(
        model_body=causal_lm.model,
        model_head=causal_lm.lm_head,
        tokenizer=tok,
        mode="midx",  # the "midx/flash" branch is re-used for this head
    )
    return pipeline


def get_fgd_model_and_tokenizer(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    K: int = 256,
    ef_search: int = 100,
    index_M: int = 32,
    ef_construction: int = 200,
    device=None,
):
    """
    Load ``model_id`` and return ``(model, tokenizer)`` with an FGD head.

    The causal LM is loaded in bfloat16 with ``device_map=device`` and its
    ``lm_head`` is replaced by an ``FGDHead`` built from the original head.
    ``device`` is only handed to ``from_pretrained``; it is not passed on to
    ``FGDHead``.

    Args:
        model_id: identifier passed to both ``from_pretrained`` calls.
        K: forwarded to ``FGDHead``.
        ef_search: forwarded to ``FGDHead``.
        index_M: forwarded to ``FGDHead`` as ``M``.
        ef_construction: forwarded to ``FGDHead``.
        device: value used for ``device_map``.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device
    )
    tok = AutoTokenizer.from_pretrained(model_id)

    # swap the dense projection for the graph-search head
    dense_head = causal_lm.lm_head
    graph_head = FGDHead(
        lm_head=dense_head,
        K=K,
        ef_search=ef_search,
        M=index_M,
        ef_construction=ef_construction,
    )
    causal_lm.lm_head = graph_head

    return causal_lm, tok
