# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Literal, Optional
import re
import random
import lightning as L
import torch
import torch._dynamo.config
import torch._inductor.config
from lightning.fabric.plugins import BitsandbytesPrecision
# Support running without installing as a package: append the parent of this file's directory
# (presumably the repository root) to sys.path so the `lit_gpt` imports below resolve.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(os.fspath(wd))

from lit_gpt import GPT, Config, Tokenizer
from lit_gpt.utils import check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint, num_parameters


def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    """Draw one index from each probability row of ``probs``.

    Eagerly this is ``torch.multinomial(probs, num_samples=1)``. While being traced by ``torch.compile``
    an equivalent, CUDAGraph-friendly formulation is used instead.
    """
    if not torch._dynamo.is_compiling():
        return torch.multinomial(probs, num_samples=1)
    # Exponential race: with E_i ~ Exp(1), argmax(p_i / E_i) selects index i with probability
    # proportional to p_i, using only elementwise ops and an argmax.
    exp_noise = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / exp_noise, dim=-1, keepdim=True)


def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
    """Choose the next token id from the logits at the last position.

    Args:
        logits: Model output indexed as ``logits[0, -1]``, so only the last position of the first batch
            element is used (presumably shaped (batch, seq, vocab) — confirm against the model).
        temperature: Softmax temperature. Values > 0 sample stochastically; otherwise the argmax is taken.
        top_k: If given, every logit outside the ``top_k`` largest is replaced by ``-inf`` first.

    Returns:
        Tensor holding the chosen token id (``keepdim``, so shape ``(1,)`` when ``logits[0, -1]`` is 1-D).
    """
    last_logits = logits[0, -1]
    if top_k is not None:
        k = min(top_k, last_logits.size(-1))
        kept_values, kept_indices = torch.topk(last_logits, k)
        masked = torch.full_like(last_logits, float("-inf"))
        last_logits = masked.scatter_(-1, kept_indices, kept_values)
    if temperature > 0.0:
        scaled = last_logits / temperature
        return multinomial_num_samples_1(torch.nn.functional.softmax(scaled, dim=-1))
    return torch.argmax(last_logits, dim=-1, keepdim=True)


def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """Run one forward pass and sample the following token.

    Args:
        model: Called as ``model(x, input_pos=input_pos)``; its output logits are passed to ``sample``.
        input_pos: Positions of the tokens in ``x``, forwarded to the model.
        x: Input token ids; the sampled token is cast back to ``x.dtype``.
        **kwargs: Sampling options forwarded to ``sample`` (``temperature``, ``top_k``).

    Returns:
        The sampled token id(s), with the same dtype as ``x``.
    """
    logits = model(x, input_pos=input_pos)
    # Named `token` (not `next`) to avoid shadowing the builtin.
    token = sample(logits, **kwargs)
    return token.to(dtype=x.dtype)

def prepare_words(word_file):
    """Read ``word_file`` and return its words with digits and most punctuation removed.

    Every character in the class below is replaced by a space: digits, parentheses, square brackets,
    ``.``, en dash, ``-``, ``/``, ``"``, ``?``, ``:``, ``;`` and ``,``. The text is then split on whitespace.

    Args:
        word_file: Path (``str`` or ``Path``) of a UTF-8 encoded text file.

    Returns:
        List of non-empty word strings in file order.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere (e.g. on Windows).
    text = Path(word_file).read_text(encoding="utf-8")
    text = re.sub(r"[0-9\(\)\.\–\-/\"?:;,\]\[]", " ", text)
    # str.split() with no separator splits on whitespace runs and drops empty strings.
    return text.split()

def text_generator(words, fabric, tokenizer, n):
    """Build a prompt of exactly ``n`` tokens from randomly drawn ``words``.

    ``n - n // 2`` words are drawn with replacement, joined by spaces, printed and encoded. Single random
    words are then encoded on their own and appended until at least ``n`` tokens exist. Because the last
    appended word can overshoot the target, the result is truncated to ``n`` tokens.

    NOTE(review): if ``tokenizer.encode`` adds special tokens (e.g. BOS) per call, every appended word
    carries them too — confirm against the tokenizer in use.

    Args:
        words: Non-empty sequence of candidate words (e.g. from ``prepare_words``).
        fabric: Object whose ``device`` is forwarded to ``tokenizer.encode``.
        tokenizer: Object with ``encode(text, device=...)`` returning a tensor of token ids.
        n: Target number of tokens (expected to be >= 0).

    Returns:
        Token-id tensor with exactly ``n`` entries.
    """
    text = " ".join(random.choices(words, k=n - n // 2))
    print(text)
    encoded = tokenizer.encode(text, device=fabric.device)
    while encoded.size(0) < n:
        # `choices(..., k=1)` (not `choice`) keeps the RNG call sequence, so seeded runs reproduce.
        extra = tokenizer.encode(random.choices(words, k=1)[0], device=fabric.device)
        encoded = torch.cat((encoded, extra))
    return encoded[:n]

# Paragraph separator: a newline, optional whitespace (which may itself span lines), then one or more
# newlines. In other words, one or more blank lines, possibly containing spaces or tabs.
_PAR_SPLIT = re.compile(r"\n\s*\n+")

def redpajama_paragraph_prompt(
    fabric,
    tokenizer,
    n=512,
    source_path=None,
    filenames=None,
    glob_pattern="*.jsonl",
    max_tries=200,
):
    """Build an ``n``-token prompt from a random run of paragraphs in a RedPajama JSONL file.

    Each attempt does the following:
    - pick a random file;
    - read one random JSON line and take its ``"text"`` field;
    - split that text on blank lines;
    - join consecutive paragraphs, starting at a random one, until the encoding reaches ``n`` tokens.

    The encoding is then truncated to exactly ``n`` tokens.

    Args:
        fabric: Object whose ``device`` is forwarded to ``tokenizer.encode``.
        tokenizer: Object with ``encode(text, device=...)`` returning a tensor of token ids.
        n: Number of tokens in the returned prompt.
        source_path: Directory (``str`` or ``Path``) containing the JSONL files. Required.
        filenames: Optional file names relative to ``source_path``. When omitted, files are discovered
            with ``glob_pattern``.
        glob_pattern: Pattern passed to ``Path.glob`` when ``filenames`` is not given.
        max_tries: Number of random (file, line) draws before giving up.

    Returns:
        A tensor of exactly ``n`` token ids.

    Raises:
        ValueError: If ``source_path`` is None.
        FileNotFoundError: If no matching regular files exist.
        RuntimeError: If no draw yields a long enough passage within ``max_tries`` attempts.
    """
    if source_path is None:
        raise ValueError("source_path must point to a directory of RedPajama JSONL files")
    source_path = Path(source_path)

    if filenames is not None:
        candidates = (source_path / name for name in filenames)
    else:
        candidates = source_path.glob(glob_pattern)
    # Sorted so that, under a fixed seed, random.choice picks the same file on every run.
    files = sorted(p for p in candidates if p.is_file())
    if not files:
        raise FileNotFoundError(f"No input files found under {source_path}")

    def split_paragraphs(text):
        # Normalize CRLF so the blank-line split also works on Windows-style text.
        text = text.replace("\r\n", "\n")
        return [p.strip() for p in _PAR_SPLIT.split(text) if p.strip()]

    def random_sample_jsonl(path):
        # Seek to a random byte offset, discard the (likely partial) line there, and parse the next
        # complete line, wrapping to the start of the file at EOF. A line is drawn with probability
        # proportional to the length of the line before it; the first line is only reached by wrapping.
        try:
            size = path.stat().st_size
            if size == 0:
                return None
            with path.open("rb") as f:
                f.seek(random.randrange(size))
                f.readline()
                line = f.readline()
                if not line:
                    f.seek(0)
                    line = f.readline()
            row = line.decode("utf-8").strip()
            if not row:
                return None
            obj = json.loads(row)
        except (OSError, ValueError):
            # Best effort: unreadable files, undecodable bytes and malformed JSON just cost one try
            # (UnicodeDecodeError and json.JSONDecodeError are ValueError subclasses).
            return None
        text = obj.get("text") if isinstance(obj, dict) else None
        return text if isinstance(text, str) and text else None

    for _ in range(max_tries):
        text = random_sample_jsonl(random.choice(files))
        if not text:
            continue

        paragraphs = split_paragraphs(text)
        if not paragraphs:
            continue

        start = random.randrange(len(paragraphs))
        pieces = [paragraphs[start]]
        encoded = tokenizer.encode(pieces[0], device=fabric.device)
        # Re-encode the whole joined passage after each added paragraph, rather than concatenating
        # per-paragraph encodings, so the result is the tokenization of the joined text itself.
        for paragraph in paragraphs[start + 1:]:
            if encoded.size(0) >= n:
                break
            pieces.append(paragraph)
            encoded = tokenizer.encode("\n\n".join(pieces), device=fabric.device)

        if encoded.size(0) >= n:
            return encoded[:n]

    raise RuntimeError(
        f"Could not build a {n}-token prompt from files under {source_path} after {max_tries} tries"
    )
        
    
@torch.inference_mode()
def generate(
    model: GPT,
    prompt: torch.Tensor,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    eos_id: Optional[int] = None,
) -> torch.Tensor:
    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        prompt: Tensor of shape (T) with indices of the prompt sequence.
        max_returned_tokens: The maximum number of tokens to return (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        eos_id: If specified, stop generating any more token once the <eos> token is triggered
            (including when it is the very first generated token).

    Returns:
        The prompt followed by the generated tokens, concatenated along dim 0. If ``eos_id`` was produced,
        it is the last element.

    Raises:
        NotImplementedError: If ``model.max_seq_length`` is shorter than ``max_returned_tokens - 1``.
    """
    T = prompt.size(0)
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

    device = prompt.device
    tokens = [prompt]
    input_pos = torch.tensor([T], device=device)
    # Prefill: feed the whole prompt at positions 0..T-1 and sample the first new token.
    # `.clone()` presumably guards against the compiled (CUDA-graph) `next_token` reusing its
    # output buffer on later calls — confirm.
    token = next_token(
        model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k
    ).clone()
    tokens.append(token)
    if token == eos_id:
        # The first sampled token already ends the sequence; don't decode past <eos>.
        return torch.cat(tokens)
    # Decode: one token per step; `input_pos` is advanced in place after each non-eos token.
    for _ in range(2, max_returned_tokens - T + 1):
        token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k).clone()
        tokens.append(token)
        if token == eos_id:
            break
        input_pos = input_pos.add_(1)
    return torch.cat(tokens)


@torch.inference_mode()
def main(
    prompt: str = "What food do llamas eat?",
    *,
    num_samples: int = 1,
    max_new_tokens: int = 1,
    top_k: Optional[int] = 200,
    temperature: float = 1.0,
    checkpoint_dir: Path = Path("checkpoints/EleutherAI/pythia-410m"),
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
    precision: Optional[str] = None,
    compile: bool = False,
    logging_name: Optional[str] = None,
    token_lengths: int = 512,
    seed: Optional[int] = None,
    view_att: bool = False,
    source_path: Path = Path("data/RedPajama-Data-1T-Sample")
) -> None:
    """Generates continuations of a random RedPajama passage with a pre-trained model and tokenizer.

    The prompt is NOT taken from ``prompt``. A random run of paragraphs is sampled from the JSONL files
    under ``source_path`` and cut to ``token_lengths`` tokens. That same encoded prompt is reused for all
    ``num_samples`` generations. Each generation's decoded text and its throughput are printed.

    Args:
        prompt: Not used by this function (kept in the signature); the prompt is sampled from ``source_path``.
        num_samples: The number of text samples to generate.
        max_new_tokens: The number of generation steps to take.
        top_k: The number of top most probable tokens to consider in the sampling process.
        temperature: A value controlling the randomness of the sampling process. Higher values result in more random
            samples.
        checkpoint_dir: The checkpoint directory to load (must contain ``lit_config.json`` and ``lit_model.pth``).
        quantize: Whether to quantize the model and using which method:
            - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes
            - bnb.int8: 8-bit quantization from bitsandbytes
            for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
        precision: Indicates the Fabric precision setting to use. With ``quantize`` it must be one of
            "16-true", "bf16-true" or "32-true".
        compile: Whether to compile the model.
        logging_name: Passed to ``GPT(logging_name=...)``; path to save attn_probs.
        token_lengths: Number of tokens in the random RedPajama prompt.
        seed: Seed passed to ``L.seed_everything``.
        view_att: Passed to ``GPT(view_att=...)``; toggles saving the attn_probs during inference to logging_name.
        source_path: Location of the RedPajama dataset for random paragraph extraction.
    """
    # Seeds the global RNGs, including Python's `random`, which drives the prompt sampling below.
    L.seed_everything(seed)
    precision = precision or get_default_supported_precision(training=False)

    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        # bitsandbytes quantization goes through a precision plugin, so Fabric's own `precision` is
        # cleared; the plugin's compute dtype comes from the requested "*-true" precision.
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision is not supported.")
        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
        plugins = BitsandbytesPrecision(quantize[4:], dtype)
        precision = None

    fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)

    check_valid_checkpoint_dir(checkpoint_dir)

    config = Config.from_json(checkpoint_dir / "lit_config.json")

    checkpoint_path = checkpoint_dir / "lit_model.pth"

    tokenizer = Tokenizer(checkpoint_dir)
    # The prompt is a random RedPajama passage of `token_lengths` tokens; the `prompt` argument is ignored.
    encoded = redpajama_paragraph_prompt(
        fabric=fabric,
        tokenizer=tokenizer,
        n=token_lengths,
        source_path=source_path,
)
    
    print(f'Decoded prompt: {tokenizer.decode(encoded)}')
    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens

    fabric.print(f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", file=sys.stderr)
    t0 = time.perf_counter()
    with fabric.init_module(empty_init=True, ):
        # This GPT variant takes extra instrumentation arguments (`p_mode`, `logging`, `view_att`,
        # `logging_name`); see lit_gpt.GPT for their exact meaning.
        model = GPT(config, p_mode = None, logging = True, view_att = view_att, logging_name = logging_name)
    fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    with fabric.init_tensor():
        # set the max_seq_length to limit the memory usage to what we need
        model.max_seq_length = max_returned_tokens
        # enable the kv cache
        model.set_kv_cache(batch_size=1)
    model.eval()

    if compile:
        torch._dynamo.config.automatic_dynamic_shapes = True
        torch._inductor.config.triton.unique_kernel_names = True
        torch._inductor.config.coordinate_descent_tuning = True
        # Rebind the module-level `next_token`: `generate` looks it up globally, so it picks up the
        # compiled version.
        global next_token
        next_token = torch.compile(next_token, mode="reduce-overhead")

    model = fabric.setup_module(model)

    t0 = time.perf_counter()
    load_checkpoint(fabric, model, checkpoint_path, strict = True)
    fabric.print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
    fabric.print(f'Num Parameters: {num_parameters(model)}')
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
        t = time.perf_counter() - t0
        # Reset every block's KV cache before the next sample reuses the same prompt.
        # NOTE(review): assumes `reset_parameters` clears the cached keys/values — confirm in lit_gpt.
        for block in model.transformer.h:
            block.attn.kv_cache.reset_parameters()
        fabric.print(tokenizer.decode(y))
        # Everything after the prompt, including an <eos> token if one was produced.
        tokens_generated = y.size(0) - prompt_length
        fabric.print(
            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr
        )
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)


if __name__ == "__main__":
    # Allow reduced-precision internal computation (e.g. TF32) for float32 matmuls.
    torch.set_float32_matmul_precision("high")

    # Imported lazily: the CLI parser is only needed when run as a script.
    from jsonargparse import CLI

    CLI(main)