"""MIDX head implementation"""

import os
import subprocess
import sys
from typing import Optional


def setup_midx_repo(
    repo_url: str = "https://github.com/XuHwang/MIDX_Journal",
    clone_path: str = "/tmp/MIDX_Journal",
) -> None:
    """Clone the MIDX repository (if needed) and make it importable.

    ``git clone`` runs only when ``clone_path`` does not exist yet. Afterwards
    the *parent* directory of ``clone_path`` is prepended to ``sys.path``, so
    the checkout can be imported as a package named after its directory
    (e.g. ``MIDX_Journal``). Calling this again is harmless: the path is only
    added once.

    :param repo_url:
        The Git URL of the MIDX repository, defaults to
        "https://github.com/XuHwang/MIDX_Journal".
    :param clone_path:
        The local path to clone the repo into, defaults to "/tmp/MIDX_Journal".
        Relative paths and trailing path separators are accepted.
    :raises subprocess.CalledProcessError:
        If ``git clone`` exits with a non-zero status.
    """
    # Normalise first. With a trailing separator ("/tmp/MIDX_Journal/"),
    # os.path.dirname would return the checkout itself rather than its parent.
    # For a bare relative name it would return "" (the CWD entry).
    clone_path = os.path.abspath(clone_path)

    if not os.path.exists(clone_path):
        print(f"[INFO] Cloning MIDX repository to {clone_path}")
        subprocess.run(["git", "clone", repo_url, clone_path], check=True)
    else:
        print(f"[INFO] MIDX repository already exists at {clone_path}")

    # The parent directory goes on sys.path so the checkout directory itself
    # is importable as a top-level package.
    path_to_add = os.path.dirname(clone_path)

    if path_to_add not in sys.path:
        sys.path.insert(0, path_to_add)
        print(f"[INFO] Added {path_to_add} to sys.path")


# Clone MIDX_Journal (only if it is not already on disk) and put its parent
# directory on sys.path. This has to run before the
# `from MIDX_Journal.src.sampler.midx import ...` statement below, which is why
# the call sits between import statements. Note: this is a module-level side
# effect — merely importing this module may run `git clone` and mutate sys.path.
setup_midx_repo()

import torch

# Now we can import it
from MIDX_Journal.src.sampler.midx import MIDXSamplerPop, MIDXSamplerUniform
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

from efficient_heads.pipeline import GenerationPipeline


class MIDXHead(nn.Module):
    """
    Wraps the paper’s *single-cell / uniform-inside* sampler so that
    GenerationPipeline can call:

        next_tok = midx_head.get_next_token(last_hidden_state)

    Notes
    -----
    * Only product-quantisation (two halves) is implemented because that’s
      what the official code does.
    * Centroids and bucket statistics are built once from `lm_head.weight`.
    * ``forward`` still computes full logits (``h @ W^T``) as a fallback;
      only ``get_next_token`` goes through the sampler.
    """

    def __init__(
        self,
        lm_head: nn.Linear,
        n_codewords: int = 32,
        pop_mode: Optional[int] = None,
    ):
        """
        :param lm_head:
            The model's output projection. Its ``weight`` (``(V, H)`` for an
            ``nn.Linear``) is shared, not copied. It is used both to build
            the sampler and for the full-logit ``forward`` fallback.
        :param n_codewords:
            Passed to the sampler as ``num_clusters``.
        :param pop_mode:
            When not ``None``, ``MIDXSamplerPop`` is used instead of
            ``MIDXSamplerUniform``.
        """
        super().__init__()
        # Shared Parameter. Its dtype/device are whatever the model was loaded
        # with (bfloat16 when built by this file's factory functions).
        self.weight = lm_head.weight
        vocab_size = self.weight.shape[0]

        # 1) build on CPU in float32
        SamplerCls = (
            MIDXSamplerPop if pop_mode is not None else MIDXSamplerUniform
        )
        # NOTE(review): the class is chosen with `is not None`, but num_items
        # uses truthiness. So pop_mode=0 selects MIDXSamplerPop with
        # num_items=V. Kept as-is; confirm the intended item count.
        self.sampler = SamplerCls(
            num_items=(vocab_size - 1 if pop_mode else vocab_size),
            num_clusters=n_codewords,
        )
        self.sampler.update(self.weight.detach().float().cpu())

        # 2) now move to the weight's device and cast to the weight's dtype
        self._move_and_cast_sampler_(self.weight.device, self.weight.dtype)

    def _move_and_cast_sampler_(self, device, dtype):
        """
        Move every tensor reachable from the sampler's attributes to `device`,
        and cast floating-point tensors to `dtype`. Non-floating tensors
        (e.g. integer index tensors) keep their original dtype.

        Nested lists, tuples (including namedtuples) and dicts are traversed:

        * Dicts are updated in place, so their concrete type (e.g.
          ``OrderedDict``) and identity are preserved.
        * ``nn.Parameter`` objects remain ``nn.Parameter`` after conversion.
        """

        def _to_(obj):
            if isinstance(obj, torch.Tensor):
                converted = obj.to(device)
                if converted.dtype.is_floating_point:
                    converted = converted.to(dtype)
                # Tensor.to() on a Parameter returns a plain Tensor whenever it
                # has to copy; re-wrap so module code still sees a Parameter.
                if isinstance(obj, nn.Parameter) and not isinstance(
                    converted, nn.Parameter
                ):
                    converted = nn.Parameter(
                        converted.detach(), requires_grad=obj.requires_grad
                    )
                return converted
            if isinstance(obj, tuple) and hasattr(obj, "_fields"):
                # namedtuple constructors take the fields positionally,
                # not a single iterable.
                return type(obj)(*(_to_(x) for x in obj))
            if isinstance(obj, (list, tuple)):
                return type(obj)(_to_(x) for x in obj)
            if isinstance(obj, dict):
                # Mutate in place: rebuilding with a dict literal would turn
                # OrderedDict/defaultdict/etc. into a plain dict.
                for key in list(obj):
                    obj[key] = _to_(obj[key])
                return obj
            return obj  # leave non-tensor objects unchanged

        # Snapshot the items: setattr may touch the same __dict__ we iterate.
        for name, value in list(vars(self.sampler).items()):
            setattr(self.sampler, name, _to_(value))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Full-softmax fallback: dense logits ``h @ W^T`` in the weight dtype."""
        return h.to(self.weight.dtype) @ self.weight.t()

    @torch.no_grad()
    def get_next_token(
        self,
        hidden_states: torch.Tensor,
        *,
        do_sample: bool = False,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        Pick the next token for each sequence via the MIDX sampler.

        :param hidden_states:
            Hidden states indexed as ``(B, T, H)``. Only the last position
            ``[:, -1, :]`` is used, cast to the weight dtype.
        :param do_sample:
            Accepted for interface compatibility; not used by this method.
        :param temperature:
            Accepted for interface compatibility; not used by this method.
        :return:
            The first value returned by ``self.sampler(q, num_neg=1)``.
            Presumably token ids of shape ``(B, 1)`` — TODO confirm against
            the sampler implementation.
        """
        q = hidden_states[:, -1, :].to(self.weight.dtype)  # (B, H)
        tokens, _ = self.sampler(q, num_neg=1)
        return tokens


def get_midx_pipeline(
    model_id: str = "meta-llama/Llama-3.2-1B-Instruct",
    n_codewords: int = 32,
    device_map: str = "cuda",
) -> GenerationPipeline:
    """Build a ``GenerationPipeline`` whose LM head is an ``MIDXHead``.

    :param model_id: Hugging Face model identifier to load (bfloat16).
    :param n_codewords: Number of sampler clusters handed to ``MIDXHead``.
    :param device_map: Forwarded verbatim to ``from_pretrained``.
    :return: A pipeline in ``"midx"`` mode that wraps the base transformer
        (``model.model``) together with the replacement head.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=device_map
    )
    tok = AutoTokenizer.from_pretrained(model_id)

    # Swap the dense output projection for the sampling head; MIDXHead keeps
    # a reference to the original weight.
    causal_lm.lm_head = MIDXHead(causal_lm.lm_head, n_codewords=n_codewords)

    backbone = causal_lm.model
    head = causal_lm.lm_head
    return GenerationPipeline(backbone, head, tokenizer=tok, mode="midx")


def get_midx_model_and_tokenizer(model_id, n_codewords: int, device=None):
    """Load a causal LM whose ``lm_head`` is replaced by an ``MIDXHead``.

    :param model_id: Hugging Face model identifier to load.
    :param n_codewords: Number of sampler clusters handed to ``MIDXHead``.
    :param device: Forwarded to ``from_pretrained`` as ``device_map``.
        When ``None``, ``"cuda"`` is used if available, else ``"cpu"``.
    :return: ``(model, tokenizer)``, with the model loaded in bfloat16.
    """
    if device is not None:
        target = device
    elif torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"

    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map=target
    )
    tok = AutoTokenizer.from_pretrained(model_id)
    causal_lm.lm_head = MIDXHead(causal_lm.lm_head, n_codewords=n_codewords)
    return causal_lm, tok


if __name__ == "__main__":
    import gc
    import json

    from efficient_heads.eval_latency import compare_outputs, measure_latency
    from efficient_heads.pipeline import get_standard_pipeline

    # Benchmark the MIDX head at several codebook sizes against the standard
    # pipeline, then write latency/agreement numbers to results_midx.json.
    standard_pipeline = get_standard_pipeline()
    print("\nRunning MIDX pipeline …")
    results = {"midx": {}}
    for K in [16, 32, 64]:
        print(f"\nTesting K={K}")
        pipe = get_midx_pipeline(n_codewords=K)
        results["midx"][str(K)] = {
            "latency": measure_latency(pipe),
            "agreement": compare_outputs(
                standard_pipeline, pipe, prompts=None
            ),
        }
        # Release this pipeline (and its model) before the next iteration
        # loads a fresh one. Otherwise `pipe` is still bound while the next
        # model is created, so two MIDX models are resident at once.
        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    with open("results_midx.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
