import transformers
import dataclasses
import torch.nn.functional
import numpy as np
import json
import os

from .query_format import BasicQueryFormat


class LIModelWrapper:
    def __init__(
        self,
        model_path: str,
        device: str = "auto",
        caching: bool = True,
        cache_dir: str | None = None,
        seed: int = 42,
    ) -> None:
        self.device = device
        self.model = transformers.LlamaForCausalLM.from_pretrained(
            model_path, device_map=device
        )
        self.tokenizer = transformers.LlamaTokenizerFast.from_pretrained(model_path, legacy=True)
        self.model_path = model_path
        self.model_type = str(self.model.config.model_type)

        self.caching = caching
        self.cache_dir = cache_dir
        #self._cached_responses: list[dict[str, str | list | None]] = []
        self.rng = np.random.default_rng(seed=seed)

        if caching:
            self.cache_file_path = self._setup_caching()

    def _setup_caching(self) -> str:
        assert self.cache_dir is not None
        assert self.caching is True
        os.makedirs(self.cache_dir, exist_ok=True)

        if not os.path.exists(os.path.join(self.cache_dir, "candidates_cache.jsonl")):
            f = open(os.path.join(self.cache_dir, "candidates_cache.jsonl"), "w")
            f.close()
        return os.path.join(self.cache_dir, "candidates_cache.jsonl")

    def _get_cached_response(
        self,
        prompt_id: int,
        prompt: str,
        max_new_tokens: int,
        model: str,
    ) -> "LIModelResponse | None":
        if not self.caching:
            return None
        assert self.cache_file_path is not None


        with open(self.cache_file_path, "rt", encoding="utf-8") as f:
            for line in f:
                entry = json.loads(line)
                if (
                    entry["prompt"] == prompt
                    and entry["max_new_tokens"] == max_new_tokens
                    and entry["model"] == model
                ):
                    return LIModelResponse(
                        text=entry["response"],
                        tokens=[],
                        probs=[],
                        model=entry["model"],
                    )
        return None

    def _cache_response(
        self,
        prompt_id: int,
        prompt: str,
        response: "LIModelResponse",
        max_new_tokens: int,
        model: str,
    ) -> None:
        if not self.caching:
            return

        assert self.cache_file_path is not None

        with open(self.cache_file_path, "a", encoding="utf-8") as f:
            entry = {
                "prompt_id": prompt_id,
                "prompt": prompt,
                "response": response.text,
                "tokens": [],
                "probs": [],
                "model": model,
                "max_new_tokens": max_new_tokens,
            }
            f.write(json.dumps(entry) + "\n")


    def run_basic_summary(
        self,
        query: BasicQueryFormat,
        summary_prefix: str = "",
        max_new_tokens: int = 200,
        **kwargs,
    ) -> "LIModelResponse":
        s = query.get_summary_prompt(summary_prefix)
        return self.run_prompt(s, max_new_tokens=max_new_tokens, **kwargs)

    def run_prompt(
        self,
        prompt: str,
        max_new_tokens: int = 200,
        eos_token: str | None = None,
        do_sample: bool = False,
        truncation_strings: list[str] = ["\n\nSkills:"],
        **kwargs,
    ) -> "LIModelResponse":
        if eos_token is None:
            eos_token_id = self.model.config.eos_token_id
        else:
            eos_token_id = self.tokenizer.encode(eos_token)[-1]

        prompt_id = int(self.rng.integers(0, 2**31 - 1))
        #print(f"Prompt ID: {prompt_id}: {prompt}")
        cached_response = self._get_cached_response(
            prompt_id=prompt_id,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            model=self.model_type,
        )
        if cached_response is not None:
            return cached_response

        inputs = self.tokenizer(prompt, return_tensors="pt")
        out = self.model.generate(
            inputs.input_ids.to("cuda"),
            temperature=1.0,
            attention_mask=inputs["attention_mask"].to("cuda"),
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_token_id,
            do_sample=do_sample,
            return_dict_in_generate=True,
            output_scores=True,
            **kwargs,
        )
        assert isinstance(
            out, transformers.generation.utils.GreedySearchDecoderOnlyOutput
        )
        assert out.scores is not None
        seq = out.sequences[0].cpu()[len(inputs.input_ids[0]) :]
        text = self.tokenizer.decode(seq)
        for truncation_string in truncation_strings:
            if truncation_string in text:
                text = text[: text.index(truncation_string)]
                break
        tokens = seq.numpy()
        probs = []

        for score, token in zip(out.scores, tokens):
            probs.append(torch.nn.functional.softmax(score, dim=1)[0, token].item())

        ret = LIModelResponse(
            text=text,
            tokens=tokens.tolist(),  # type: ignore
            probs=probs,
            model=self.model_type,
        )
        self._cache_response(
            prompt_id=prompt_id,
            prompt=prompt,
            response=ret,
            max_new_tokens=max_new_tokens,
            model=self.model_type,
        )
        return ret


@dataclasses.dataclass
class LIModelResponse:
    text: str
    tokens: list[int]
    probs: list[float]
    model: str | None

    def to_dict(self) -> dict[str, str | list | None]:
        return {
            "text": self.text,
            #'model': self.model,
            #'tokens': self.tokens,
            #'probs': self.probs,
        }
