from pathlib import Path
from typing import Optional

import lightning as L
import torch

from generate import generate
from lit_llama import Tokenizer
from lit_llama.lora import lora as lora_model
from lit_llama.utils import EmptyInitOnDevice, lazy_load
from scripts.prepare_alpaca import generate_prompt
from scripts.prepare_alpaca_no_prompt import generate_no_prompt

# LoRA hyperparameters (rank, alpha, dropout). They must match the values the
# LoRA model in ``FinetunedLora`` is constructed with.
lora_r: int = 8
lora_alpha: int = 16
lora_dropout: float = 0.05

# Process-wide setting: allow float32 matrix multiplications to use faster,
# reduced-precision internal formats where the hardware supports them.
torch.set_float32_matmul_precision("high")


class FinetunedAdapter:
    """Inference wrapper for a LLaMA model carrying fine-tuned adapter weights.

    Construction builds the model on a single Fabric device, loads the
    pretrained checkpoint followed by the adapter checkpoint, and loads the
    tokenizer. ``generate`` then turns an instruction/input pair into a decoded
    response string.
    """

    # Bound at class scope so the adapter-specific model classes are reachable
    # as ``self.LLaMA`` / ``self.LLaMAConfig``.
    from lit_llama.adapter import LLaMA, LLaMAConfig

    def __init__(
        self,
        adapter_path: Optional[Path] = None,
        pretrained_path: Optional[Path] = None,
        tokenizer_path: Optional[Path] = None,
        quantize: Optional[str] = None,
    ) -> None:
        """Build the model, load both checkpoints and the tokenizer.

        Args:
            adapter_path: Checkpoint with the fine-tuned adapter weights.
                Defaults to ``out/adapter/alpaca/lit-llama-adapter-finetuned.pth``.
            pretrained_path: Checkpoint with the pretrained LLaMA weights.
                Defaults to ``./checkpoints/lit-llama/7B/lit-llama.pth``.
            tokenizer_path: Tokenizer model file.
                Defaults to ``./checkpoints/lit-llama/tokenizer.model``.
            quantize: Quantization mode forwarded to ``EmptyInitOnDevice``
                (``None`` disables quantization).

        Raises:
            AssertionError: If any of the three paths is not an existing file.
        """
        # ``Path(...)`` also accepts plain strings, so callers may pass either.
        adapter_path = Path(
            adapter_path or "out/adapter/alpaca/lit-llama-adapter-finetuned.pth"
        )
        pretrained_path = Path(
            pretrained_path or "./checkpoints/lit-llama/7B/lit-llama.pth"
        )
        tokenizer_path = Path(
            tokenizer_path or "./checkpoints/lit-llama/tokenizer.model"
        )

        assert adapter_path.is_file(), adapter_path
        assert pretrained_path.is_file(), pretrained_path
        assert tokenizer_path.is_file(), tokenizer_path

        self.fabric = L.Fabric(devices=1)
        # Prefer bfloat16 on CUDA devices that support it; otherwise float32.
        dtype = (
            torch.bfloat16
            if self.fabric.device.type == "cuda" and torch.cuda.is_bf16_supported()
            else torch.float32
        )

        with EmptyInitOnDevice(
            device=self.fabric.device, dtype=dtype, quantization_mode=quantize
        ):
            self.model = self.LLaMA(self.LLaMAConfig())

        # strict=False on both loads: neither checkpoint is required to cover
        # every parameter of the model on its own.
        # 1. Load the pretrained weights
        pretrained_checkpoint = lazy_load(pretrained_path)
        self.model.load_state_dict(pretrained_checkpoint, strict=False)

        # 2. Load the fine-tuned adapter weights
        adapter_checkpoint = lazy_load(adapter_path)
        self.model.load_state_dict(adapter_checkpoint, strict=False)

        self.model.eval()
        self.model = self.fabric.setup_module(self.model)

        self.tokenizer = Tokenizer(tokenizer_path)

    def load_adapter(self, adapter_path: Path) -> None:
        """Load another adapter checkpoint into the already-built model.

        Args:
            adapter_path: Checkpoint with adapter weights to load
                (non-strictly) on top of the current model state.

        Raises:
            AssertionError: If ``adapter_path`` is not an existing file.
        """
        adapter_path = Path(adapter_path)
        assert adapter_path.is_file(), adapter_path

        adapter_checkpoint = lazy_load(adapter_path)
        self.model.load_state_dict(adapter_checkpoint, strict=False)

    def generate(
        self,
        instruction: str = "",
        input_text: str = "",
        max_new_tokens: int = 100,
        top_k: int = 200,
        temperature: float = 0.8,
        use_instruction: bool = True,
    ) -> str:
        """Generate a response for an instruction and/or input text.

        Args:
            instruction: Instruction passed to ``generate_prompt``. Must be
                empty when ``use_instruction`` is False.
            input_text: Input passed to ``generate_prompt``, or the sole text
                passed to ``generate_no_prompt`` when ``use_instruction`` is
                False (in which case it is required).
            max_new_tokens: Number of tokens to generate at most.
            top_k: Sample only among the ``top_k`` most probable tokens.
            temperature: Sampling temperature forwarded to ``generate``.
            use_instruction: Build the prompt with ``generate_prompt`` (True)
                or with ``generate_no_prompt`` (False).

        Returns:
            The decoded text following the first ``### Response:`` marker,
            stripped of surrounding whitespace.

        Raises:
            AssertionError: If ``use_instruction`` is False and ``input_text``
                is empty or ``instruction`` is not.
            IndexError: If the decoded text contains no ``### Response:``.
        """
        if use_instruction:
            sample = {"instruction": instruction, "input": input_text}
            prompt = generate_prompt(sample)
        else:
            assert (
                input_text
            ), "input_text must be provided if use_instruction is False."
            assert (
                not instruction
            ), "instruction must be empty if use_instruction is False."
            prompt = generate_no_prompt(input_text)

        encoded = self.tokenizer.encode(
            prompt, bos=True, eos=False, device=self.model.device
        )

        output = generate(
            self.model,
            idx=encoded,
            # The context limit is the model's block size, not the number of
            # new tokens: passing max_new_tokens here would crop any prompt
            # longer than max_new_tokens down to its tail.
            max_seq_length=self.model.config.block_size,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )

        output = self.tokenizer.decode(output)
        # The decoded text starts with the prompt; keep only the answer part.
        output = output.split("### Response:")[1].strip()
        return output


class FinetunedLora:
    """Inference wrapper for a LLaMA model carrying fine-tuned LoRA weights.

    Construction builds the model with LoRA layers enabled, loads the
    pretrained checkpoint followed by the LoRA checkpoint, and loads the
    tokenizer. ``generate`` then turns an instruction/input pair into a decoded
    response string.
    """

    # Bound at class scope so the model classes are reachable as
    # ``self.LLaMA`` / ``self.LLaMAConfig``.
    from lit_llama import LLaMA, LLaMAConfig

    def __init__(
        self,
        lora_path: Optional[Path] = None,
        pretrained_path: Optional[Path] = None,
        tokenizer_path: Optional[Path] = None,
        quantize: Optional[str] = None,
        dtype: str = "float32",
    ):
        """Build the LoRA-enabled model, load both checkpoints and the tokenizer.

        Args:
            lora_path: Path to the checkpoint with trained LoRA weights.
                Defaults to ``out/lora/alpaca/lit-llama-lora-finetuned.pth``.
            pretrained_path: The path to the checkpoint with pretrained LLaMA
                weights. Defaults to ``./checkpoints/lit-llama/7B/lit-llama.pth``.
            tokenizer_path: The tokenizer path to load.
                Defaults to ``./checkpoints/lit-llama/tokenizer.model``.
            quantize: Must be ``None``; quantization is not supported here.
            dtype: Name of a ``torch`` dtype attribute (for example
                ``"float32"``) to build the model with.

        Raises:
            AssertionError: If any of the three paths is not an existing file.
            NotImplementedError: If ``quantize`` is not ``None``.
            ValueError: If ``dtype`` does not name a ``torch.dtype``.
        """
        # ``Path(...)`` also accepts plain strings, so callers may pass either.
        lora_path = Path(lora_path or "out/lora/alpaca/lit-llama-lora-finetuned.pth")
        pretrained_path = Path(
            pretrained_path or "./checkpoints/lit-llama/7B/lit-llama.pth"
        )
        tokenizer_path = Path(
            tokenizer_path or "./checkpoints/lit-llama/tokenizer.model"
        )

        assert lora_path.is_file(), lora_path
        assert pretrained_path.is_file(), pretrained_path
        assert tokenizer_path.is_file(), tokenizer_path

        if quantize is not None:
            raise NotImplementedError("Quantization in LoRA is not supported yet")

        fabric = L.Fabric(devices=1)

        # Resolve the dtype name to the actual torch.dtype object.
        dt = getattr(torch, dtype, None)
        if not isinstance(dt, torch.dtype):
            raise ValueError(f"{dtype} is not a valid dtype.")
        dtype = dt

        # The model has to be instantiated inside the ``lora_model`` context;
        # the hyperparameters come from the module-level constants.
        with EmptyInitOnDevice(
            device=fabric.device, dtype=dtype, quantization_mode=quantize
        ), lora_model(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
            self.model = self.LLaMA(self.LLaMAConfig())

        # strict=False on both loads: neither checkpoint is required to cover
        # every parameter of the model on its own.
        # 1. Load the pretrained weights
        pretrained_checkpoint = lazy_load(pretrained_path)
        self.model.load_state_dict(pretrained_checkpoint, strict=False)

        # 2. Load the fine-tuned LoRA weights
        lora_checkpoint = lazy_load(lora_path)
        self.model.load_state_dict(lora_checkpoint, strict=False)

        self.model.eval()
        self.model = fabric.setup_module(self.model)

        self.tokenizer = Tokenizer(tokenizer_path)

    def load_adapter(self, adapter_path: Path) -> None:
        """Load another fine-tuned checkpoint into the already-built model.

        The method keeps the ``load_adapter`` name used by ``FinetunedAdapter``;
        here the checkpoint is loaded (non-strictly) into the LoRA model.

        Args:
            adapter_path: Checkpoint file to load on top of the current state.

        Raises:
            AssertionError: If ``adapter_path`` is not an existing file.
        """
        adapter_path = Path(adapter_path)
        assert adapter_path.is_file(), adapter_path

        adapter_checkpoint = lazy_load(adapter_path)
        self.model.load_state_dict(adapter_checkpoint, strict=False)

    def generate(
        self,
        instruction: str = "",
        input_text: str = "",
        max_new_tokens: int = 100,
        top_k: int = 200,
        temperature: float = 0.8,
        use_instruction: bool = True,
    ) -> str:
        """Generate a response for an instruction and/or input text.

        Args:
            instruction: Instruction passed to ``generate_prompt``. Must be
                empty when ``use_instruction`` is False.
            input_text: Input passed to ``generate_prompt``, or the sole text
                passed to ``generate_no_prompt`` when ``use_instruction`` is
                False (in which case it is required).
            max_new_tokens: Number of tokens to generate at most.
            top_k: Sample only among the ``top_k`` most probable tokens.
            temperature: Sampling temperature forwarded to ``generate``.
            use_instruction: Build the prompt with ``generate_prompt`` (True)
                or with ``generate_no_prompt`` (False).

        Returns:
            The decoded text following the first ``### Response:`` marker,
            stripped of surrounding whitespace.

        Raises:
            AssertionError: If ``use_instruction`` is False and ``input_text``
                is empty or ``instruction`` is not.
            IndexError: If the decoded text contains no ``### Response:``.
        """
        if use_instruction:
            sample = {"instruction": instruction, "input": input_text}
            prompt = generate_prompt(sample)
        else:
            assert (
                input_text
            ), "input_text must be provided if use_instruction is False."
            assert (
                not instruction
            ), "instruction must be empty if use_instruction is False."
            prompt = generate_no_prompt(input_text)

        encoded = self.tokenizer.encode(
            prompt, bos=True, eos=False, device=self.model.device
        )

        output = generate(
            self.model,
            idx=encoded,
            # The context limit is the model's block size, not the number of
            # new tokens: passing max_new_tokens here would crop any prompt
            # longer than max_new_tokens down to its tail.
            max_seq_length=self.model.config.block_size,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )

        output = self.tokenizer.decode(output)
        # The decoded text starts with the prompt; keep only the answer part.
        output = output.split("### Response:")[1].strip()
        return output


class FinetunedAll:
    """Inference wrapper for a fully fine-tuned (or plain pretrained) LLaMA model.

    Construction builds the model for ``model_size`` on a single Fabric device,
    loads one full checkpoint (strictly) and the tokenizer, then seeds the
    random number generators. ``generate`` returns the decoded response text
    for a prompt.
    """

    # Bound at class scope so the model class is reachable as ``self.LLaMA``.
    from lit_llama import LLaMA

    def __init__(
        self,
        checkpoint_path: Optional[Path] = None,
        tokenizer_path: Optional[Path] = None,
        model_size: str = "7B",
        quantize: Optional[str] = None,
        seed: int = 1234,
    ) -> None:
        """Build the model, load the checkpoint and the tokenizer.

        Args:
            checkpoint_path: The checkpoint path to load. Defaults to
                ``./checkpoints/lit-llama/{model_size}/lit-llama.pth``.
            tokenizer_path: The tokenizer path to load. Defaults to
                ``./checkpoints/lit-llama/tokenizer.model``.
            model_size: The model size to load.
            quantize: Whether to quantize the model and using which method:
                ``"llm.int8"``: LLM.int8() mode,
                ``"gptq.int4"``: GPTQ 4-bit mode.
            seed: Seed passed to ``L.seed_everything`` after loading.

        Raises:
            AssertionError: If either path is not an existing file.
        """
        # ``Path(...)`` also accepts plain strings, so callers may pass either.
        checkpoint_path = Path(
            checkpoint_path or f"./checkpoints/lit-llama/{model_size}/lit-llama.pth"
        )
        tokenizer_path = Path(
            tokenizer_path or "./checkpoints/lit-llama/tokenizer.model"
        )
        assert checkpoint_path.is_file(), checkpoint_path
        assert tokenizer_path.is_file(), tokenizer_path

        fabric = L.Fabric(devices=1)
        # Prefer bfloat16 on CUDA devices that support it; otherwise float32.
        dtype = (
            torch.bfloat16
            if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported()
            else torch.float32
        )

        print("Loading model ...")
        with EmptyInitOnDevice(
            device=fabric.device, dtype=dtype, quantization_mode=quantize
        ):
            self.model = self.LLaMA.from_name(model_size)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint)

        self.model.eval()
        self.model = fabric.setup_module(self.model)

        self.tokenizer = Tokenizer(tokenizer_path)

        L.seed_everything(seed)

    def load_checkpoint(self, checkpoint_path: Path) -> None:
        """Replace the model weights with those from another full checkpoint.

        Args:
            checkpoint_path: Checkpoint file to load (strictly) into the model.

        Raises:
            AssertionError: If ``checkpoint_path`` is not an existing file.
        """
        checkpoint_path = Path(checkpoint_path)
        assert checkpoint_path.is_file(), checkpoint_path

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path)
        self.model.load_state_dict(checkpoint)

    def generate(
        self,
        prompt: str = "Hello, my name is",
        max_new_tokens: int = 50,
        top_k: int = 200,
        temperature: float = 0.8,
    ) -> str:
        """Generate a response for ``prompt``.

        Args:
            prompt: Text passed through ``generate_no_prompt`` to build the
                model input.
            max_new_tokens: The number of generation steps to take.
            top_k: The number of top most probable tokens to consider in the
                sampling process.
            temperature: A value controlling the randomness of the sampling
                process. Higher values result in more random samples.

        Returns:
            The decoded text following the first ``\\n\\n### Response:``
            marker, stripped of surrounding whitespace.

        Raises:
            IndexError: If the decoded text does not contain the marker.
        """
        prompt = generate_no_prompt(prompt)
        encoded_prompt = self.tokenizer.encode(
            prompt, bos=True, eos=False, device=self.model.device
        )

        output = self.generate_full(
            model=self.model,
            idx=encoded_prompt,
            max_new_tokens=max_new_tokens,
            max_seq_length=self.model.config.block_size,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )
        output = self.tokenizer.decode(output)
        # The decoded text starts with the prompt; keep only the answer part.
        return output.split("\n\n### Response:")[1].strip()

    @staticmethod
    @torch.no_grad()
    def generate_full(
        model: torch.nn.Module,
        idx: torch.Tensor,
        max_new_tokens: int,
        max_seq_length: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        eos_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

        The implementation of this function is modified from A. Karpathy's nanoGPT.
        Runs under ``torch.no_grad()`` since sampling needs no gradients.

        Args:
            model: The model to use.
            idx: Tensor of shape (T) with indices of the prompt sequence.
            max_new_tokens: The number of new tokens to generate.
            max_seq_length: The maximum sequence length allowed.
            temperature: Scales the predicted logits by 1 / temperature
            top_k: If specified, only sample among the tokens with the k highest probabilities
            eos_id: If specified, stop generating any more token once the <eos> token is triggered

        Returns:
            A 1-D tensor holding the prompt followed by the generated tokens.
            It ends with the ``eos_id`` token when generation stopped early.
        """
        # create an empty tensor of the expected final shape and fill in the current tokens
        T = idx.size(0)
        T_new = T + max_new_tokens
        empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
        empty[:T] = idx
        idx = empty

        # generate max_new_tokens tokens
        for t in range(T, T_new):
            # ignore the not-filled-yet tokens
            idx_cond = idx[:t]
            # If the sequence context is growing too long we must crop it at
            # max_seq_length. The check uses the *current* length ``t``, not
            # the prompt length ``T``: otherwise a prompt that fits would never
            # be cropped even after the sequence grows past the limit.
            if t > max_seq_length:
                idx_cond = idx_cond[-max_seq_length:]

            # forward
            logits = model(idx_cond.view(1, -1))
            logits = logits[0, -1] / temperature

            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[[-1]]] = -float("Inf")

            probs = torch.nn.functional.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # concatenate the new generation
            idx[t] = idx_next

            # if <eos> token is triggered, return the output (stop generation)
            if idx_next == eos_id:
                return idx[: t + 1]  # include the EOS token

        return idx
