# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import sys
import time
import os
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config

import nvtx

from transformers import AutoTokenizer
from APLinear import APLinear
from LUTGEMMLinear import LUTGEMMLinear

import warnings

from model_predef import Transformer

# Suppress every FutureWarning for the whole process so that deprecation
# notices do not clutter the benchmark output.
# NOTE(review): presumably aimed at deprecation notices from torch APIs used
# below (e.g. torch.backends.cuda.sdp_kernel) -- confirm.
warnings.filterwarnings(
    "ignore", 
    category=FutureWarning
)

def device_sync(device):
    """Wait for outstanding work on ``device`` (a device string) to finish.

    Device strings containing ``"cuda"`` are synchronized through
    ``torch.cuda.synchronize``.  Strings containing ``"cpu"`` or ``"mps"``
    are accepted and need no action here.  Any other string only produces a
    notice on stdout.
    """
    if "cuda" in device:
        torch.cuda.synchronize(device)
        return

    needs_no_sync = any(kind in device for kind in ("cpu", "mps"))
    if not needs_no_sync:
        print(f"device={device} is not yet suppported")


# Process-wide torch.compile / Inductor settings, applied as a side effect of
# importing this module.
# NOTE(review): coordinate_descent_tuning presumably enables extra kernel
# autotuning and unique_kernel_names presumably gives generated Triton kernels
# distinct names (helpful in profiler traces) -- confirm against Inductor docs.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
# Experimental features to reduce compilation times, will be on by default in future
torch._inductor.config.fx_graph_cache = True 
#torch._functorch.config.enable_autograd_cache = True

# Default device string for the helpers and CLI below: CUDA when available,
# otherwise CPU.
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Support running without installing as a package: append the parent of this
# file's directory to sys.path so that modules living there can be imported
# (presumably including the `tokenizer` module imported right below).
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
    """Draw one index per row from ``probs_sort`` along the last dimension.

    Each probability is divided by fresh Exponential(1) noise and the argmax
    of the result is taken.  The returned index tensor keeps the reduced
    dimension (size 1) and has dtype ``torch.int``.
    """
    noise = torch.empty_like(probs_sort).exponential_(1)
    perturbed = probs_sort / noise
    winner = torch.argmax(perturbed, dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    The logits are divided by ``temperature`` (floored at 1e-5 so that a
    temperature of 0 does not divide by zero).  When ``top_k`` is given,
    every logit strictly below the k-th largest one is set to ``-inf`` before
    the softmax, so it receives zero probability.
    """
    scaled = logits / max(temperature, 1e-5)

    if top_k is None:
        return torch.nn.functional.softmax(scaled, dim=-1)

    k = min(top_k, scaled.size(-1))
    top_values, _ = torch.topk(scaled, k)
    # Smallest of the k kept values, broadcastable against `scaled`.
    threshold = top_values.select(-1, -1).unsqueeze(-1)
    masked = torch.where(scaled < threshold, -float("Inf"), scaled)
    return torch.nn.functional.softmax(masked, dim=-1)

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the last position of ``logits``.

    The logits are cast to float32, the slice ``[:, -1]`` is converted to
    probabilities with ``logits_to_probs`` and one index per row is drawn.

    Returns:
        ``(idx_next, probs)``: the sampled indices and the distribution they
        were drawn from.
    """
    last_step = logits.float()[:, -1]
    probs = logits_to_probs(last_step, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs


def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    """Run ``model`` over the prompt ``x`` and sample the token that follows it.

    Only the sampled token is returned; the probabilities computed by
    ``sample`` are dropped.  ``sampling_kwargs`` are forwarded to ``sample``.
    """
    # input_pos: [B, S]
    next_token, _probs = sample(model(x, input_pos), **sampling_kwargs)
    return next_token


def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run a single decode step and sample from its output.

    ``input_pos`` must describe exactly one position (last dimension of size
    1).  Returns the ``(token, probs)`` pair produced by ``sample``.
    """
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    return sample(model(x, input_pos), **sampling_kwargs)

def decode_one_token_inplace(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, next_token: torch.Tensor, next_prob: torch.Tensor, **sampling_kwargs):
    """Run one decode step and write the results into caller-provided buffers.

    ``next_token`` and ``next_prob`` are overwritten in place (their storage
    is reused), which is what the CUDA-graph path in ``decode_n_tokens``
    relies on.  Nothing is returned.
    """
    sampled_token, sampled_prob = decode_one_token(model, x, input_pos, **sampling_kwargs)
    next_token[...] = sampled_token
    next_prob[...] = sampled_prob


def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, use_graph=False, callback=lambda _: _, **sampling_kwargs):
    """Decode ``num_new_tokens`` tokens autoregressively, one model call per token.

    Args:
        model: Model handed to ``decode_one_token``.
        cur_token: Token(s) fed to the first decode step.  When ``use_graph``
            is set, its first dimension and device determine the shape and
            placement of the static probability buffer.
        input_pos: Position tensor for the first step.  It is advanced IN
            PLACE by one after every step, so pass a clone if the caller
            still needs the original value.
        num_new_tokens: Number of decode steps to run.
        use_graph: Capture one decode step into a CUDA graph and replay it
            for every token instead of calling ``decode_one_token`` directly.
        callback: Called with each newly sampled token.
        **sampling_kwargs: Forwarded to ``decode_one_token``.

    Returns:
        ``(new_tokens, new_probs)``: two lists holding one tensor per step.
    """
    # WARNING: DO NOT pass use_graph=True for the first invocation after torch.compile, as torch.compile will only successfully compile with use_graph=False
    # if you're using torch.compile without CUDA graphs, you may set use_graph to True after the first invocation to speed up subsequent calls
    on_cuda = cur_token.is_cuda

    if use_graph:
        # Allocate static input tensors
        static_cur_token = cur_token.clone()
        static_input_pos = input_pos.clone()

        # Set requires_grad=False for inference
        static_cur_token.requires_grad = False
        static_input_pos.requires_grad = False

        # Pre-allocate static output tensors.  The probability buffer follows
        # the batch size and device of the input instead of assuming a single
        # sequence on the default CUDA device.
        static_next_token = torch.empty_like(static_cur_token)
        static_next_prob = torch.empty(
            (cur_token.shape[0], model.config.vocab_size),
            dtype=torch.float32,
            device=cur_token.device,
        )

        graph = torch.cuda.CUDAGraph()
        with nvtx.annotate("Capturing CUDA graph", color='cyan'):
            with torch.cuda.graph(graph):
                decode_one_token_inplace(
                    model, static_cur_token, static_input_pos, static_next_token, static_next_prob, **sampling_kwargs
                )
            torch.cuda.synchronize()

    new_tokens, new_probs = [], []
    with nvtx.annotate("Decoding tokens", color='orange'):
        for i in range(num_new_tokens):
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=True):
                if use_graph:
                    # Update inputs in-place
                    static_cur_token.copy_(cur_token)
                    static_input_pos.copy_(input_pos)
                    graph.replay()
                    # Retrieve outputs from static output tensors
                    next_token = static_next_token.clone()
                    next_prob = static_next_prob.clone()
                else:
                    next_token, next_prob = decode_one_token(
                        model, cur_token, input_pos, **sampling_kwargs
                    )

                # Update position and tokens.  Results are cloned before being
                # stored because the producing buffers may be reused by the
                # next step (static graph outputs / compiled-function outputs).
                input_pos += 1
                new_tokens.append(next_token.clone())
                callback(new_tokens[-1])
                new_probs.append(next_prob.clone())
                cur_token = next_token.clone()
        # Only CUDA work needs an explicit barrier; skipping it on other
        # devices avoids initializing CUDA where the tensors do not use it.
        if on_cuda:
            torch.cuda.synchronize()

    return new_tokens, new_probs


def model_forward(model, x, input_pos):
    """Call ``model`` with ``x`` and ``input_pos`` and return its result unchanged."""
    output = model(x, input_pos)
    return output

def speculative_decode(
    model: Transformer,
    draft_model: Transformer,
    cur_token: torch.Tensor,
    input_pos: int,
    speculate_k: int,
    **sampling_kwargs
) -> torch.Tensor:
    """Propose ``speculate_k`` tokens with ``draft_model`` and verify them with ``model``.

    The draft model decodes ``speculate_k`` tokens sequentially; the target
    model then scores the current token plus all draft tokens in one forward
    pass.  Draft tokens are accepted left to right with probability
    ``min(1, q / p)`` (q: target probability, p: draft probability).  At the
    first rejection a replacement token is sampled from the normalized
    positive part of ``q - p``; if nothing is rejected, one extra token is
    sampled from the target distribution.

    Only a single sequence is supported (``cur_token`` is viewed as one token).

    Args:
        model: Target model.
        draft_model: Draft model used for the sequential proposals.
        cur_token: The most recent token of the sequence.
        input_pos: Integer position of ``cur_token``.
        speculate_k: Number of draft tokens to propose.
        **sampling_kwargs: Forwarded to the sampling helpers.

    Returns:
        1-D tensor holding the accepted draft tokens followed by one token
        sampled from the target model (between 1 and ``speculate_k + 1``
        tokens in total).
    """
    # draft model inference sequentially
    device = cur_token.device
    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=device)
    # decode_n_tokens returns exactly two lists: (tokens, probs).
    draft_tokens, draft_probs = decode_n_tokens(
        draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs
    )

    # Every decode step yields a [1, 1] token and a [1, vocab] distribution;
    # drop the batch dimension so both can be indexed per draft position.
    draft_tokens = torch.cat(draft_tokens).view(-1)
    draft_probs = torch.stack(draft_probs).view(speculate_k, -1)

    # parallel inference on target model using draft tokens
    target_logits = model_forward(
        model,
        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
        torch.arange(input_pos, input_pos + speculate_k + 1, device=device)
    )
    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)

    # q: target prob, p: draft prob
    # q >= p: always accept draft token
    # q < p: q/p prob to accept draft token
    positions = torch.arange(0, speculate_k, device=device)
    p = draft_probs[positions, draft_tokens]
    q = target_probs[positions, draft_tokens]
    accept_draft_prob = (q / p).clamp(max=1.0)
    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()

    if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
        last_token = multinomial_sample_one_no_sync(target_probs[-1])
        # fill last token into draft model
        model_forward(
            draft_model,
            draft_tokens[-1].view(1, -1),
            orig_input_pos + speculate_k,
        )
        return torch.cat([draft_tokens, last_token])

    # First rejected position: resample from the residual distribution
    # max(q - p, 0), renormalized.
    accept_length = rejected_locations[0].item()
    p = draft_probs[accept_length]
    q = target_probs[accept_length]
    new = q - p
    new = torch.where(new > 0, new, 0.0)
    new = new / new.sum()
    next_token = multinomial_sample_one_no_sync(new)
    return torch.cat([draft_tokens[:accept_length], next_token])

@torch.no_grad()
def generate(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    batch_size: int,
    callback = lambda x: x,
    **sampling_kwargs
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.

    The same prompt is replicated for every batch row.  The prompt is fed
    directly to the decode loop (there is no separate prefill step), and
    ``decode_one_token`` asserts a single input position, so in practice the
    prompt must be exactly one token long (e.g. a BOS token).

    Args:
        model: Model to decode with; its caches are (re)initialized here.
        prompt: 1-D tensor of prompt token ids.
        max_new_tokens: Number of tokens to generate after the prompt.
        batch_size: Number of identical sequences to decode in parallel.
        callback: Called with each newly sampled token.
        **sampling_kwargs: Forwarded to the sampling helpers.

    Returns:
        Tensor of shape ``[batch_size, T + max_new_tokens]`` holding the
        prompt followed by the generated tokens.
    """
    T = prompt.size(-1)
    T_new = T + max_new_tokens
    max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
    with torch.device(device):
        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    seq = torch.empty(batch_size, T_new, dtype=dtype, device=device)
    # We are just making the same prompt for every batch
    prompt = prompt.view(1, -1).repeat(batch_size, 1)
    seq[:, :T] = prompt
    input_pos = torch.arange(0, T, device=device, dtype=torch.int32)

    # Start peak-memory tracking afresh for this run; only meaningful (and
    # only attempted) when decoding on a CUDA device.
    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats()
    generated_tokens, _ = decode_n_tokens(
        model, prompt.view(batch_size, -1), input_pos, max_new_tokens,
        callback=callback, **sampling_kwargs
    )

    # Generated tokens go right after the T prompt tokens.
    seq[:, T:] = torch.cat(generated_tokens, dim=-1)

    return seq

def encode_tokens(tokenizer, string, device=default_device):
    """Encode ``string`` with ``tokenizer.encode`` and return the ids as an int32 tensor on ``device``."""
    token_ids = tokenizer.encode(string)
    encoded = torch.tensor(token_ids, dtype=torch.int, device=device)
    return encoded

def encode_bos(tokenizer, device=default_device):
    """Return a one-element int32 tensor on ``device`` holding ``tokenizer.bos_token_id``."""
    bos_id = tokenizer.bos_token_id
    return torch.tensor([bos_id], dtype=torch.int, device=device)

def load_model(model_name, device, quant_algo,  
                bitwidth, n_tb, k_chunk_list, dtype=None, halve_layers=False):
    use_cuda = 'cuda' in device

    checkpoint_path = f"../eval/{quant_algo}_cache/{model_name.split('/')[-1]}-w{bitwidth}"
    model_path = os.path.join(checkpoint_path, "converted_pytorch_model.bin")
    dec_path = os.path.join(checkpoint_path, "cheatsheet.bin")

    linear_kwargs = {"bitwidth": bitwidth}

    match quant_algo:
        case "sqllm":
            linear_class = APLinear
        case "awq":
            linear_class = LUTGEMMLinear
            linear_kwargs['group_size'] = 128

    print("Building model ...", flush=True)
    model = Transformer.from_name(
        name=model_name, dtype=dtype,
        linear_class=linear_class,
        linear_kwargs=linear_kwargs,
        halve_layers=halve_layers,
    )

    print("Loading weights ...", flush=True)
    #model = model.to_empty(device='cpu')
    # checkpoint = torch.load(model_path, mmap=True, weights_only=True)
    # model.load_state_dict(checkpoint, assign=True, strict=False if halve_layers else True)

    print("Dispatching model to device ...", flush=True)
    model=model.to(device=device, dtype=dtype)

    # # load & set up DEC
    # print("Loading residuals ...", flush=True)
    # buffer_size = 8 * 1024
    # dec_data = torch.load(dec_path, weights_only=True)
    # model.load_dec_data(dec_data)
    # model.create_dec_context(n_tb, buffer_size)
    # model.set_dec_config(k_chunk_list)

    # tokenizer = AutoTokenizer.from_pretrained(os.path.dirname(model_path))
    tokenizer = AutoTokenizer.from_pretrained("/work/models/llama-3-8b-hf")
   
    print("Model loaded.", flush=True)
    return model.eval(), tokenizer

def _get_model_size(model):
    model_size = 0
    params = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                [
                    p.numel() * p.dtype.itemsize
                    for p in itertools.chain(child.parameters(), child.buffers())
                ]
            )
            params += sum(
                [
                    p.numel()
                    for p in itertools.chain(child.parameters(), child.buffers())
                ]
            )
            print(name)
            size = sum(p.numel() * p.dtype.itemsize for p in itertools.chain(child.parameters(), child.buffers()))
            print(size/1024/1024/1024)

    return model_size, params

# Opening / closing instruction markers (presumably the Llama-style chat
# delimiters -- confirm).
# NOTE(review): not referenced anywhere else in the visible part of this file.
B_INST, E_INST = "[INST]", "[/INST]"

def main(
    prompt: Union[int, str] = "Hello, my name is",
    interactive: bool = False,
    num_samples: int = 5,
    max_new_tokens: int = 100,
    batch_size: int = 1,
    top_k: int = 200,
    temperature: float = 0.8,
    compile: int = 2,
    compile_prefill: bool = False,
    profile: Optional[Path] = None,
    device=default_device,
    model_name = None,
    quant_algo = None,
    bitwidth = None,
    dtype = None,
    n_tb = None,
    k = None,
    print_result = False
) -> None:
    """Generates text samples based on a pre-trained Transformer model and tokenizer.

    Loads the model, optionally compiles the decode step, runs
    ``num_samples`` timed generations starting from a BOS token, prints
    throughput figures and appends the average tokens/sec to
    ``config/bsel_result.csv`` (prefixed with the first line of
    ``config/bsel_config.txt``).

    Args:
        prompt: Accepted for CLI compatibility; generation currently always
            starts from the tokenizer's BOS token.
        interactive: Accepted for CLI compatibility; not used.
        num_samples: Number of timed generations.
        max_new_tokens: Tokens generated per sample.
        batch_size: Number of identical sequences decoded in parallel.
        top_k: Top-k cutoff used for sampling.
        temperature: Sampling temperature.
        compile: 0 disables ``torch.compile``; 1 compiles with
            ``max-autotune-no-cudagraphs``; any other non-zero value uses
            ``max-autotune``.  When enabled, one extra untimed warm-up
            generation is run first.
        compile_prefill: Also compile ``prefill`` when compiling.
        profile: If set, the last sample runs under ``torch.profiler`` and
            the trace is written to ``f"{profile}.json"``.
        device: Device string to run on.
        model_name, quant_algo, bitwidth, n_tb, k: Forwarded to ``load_model``.
        dtype: ``"float16"``, ``"bfloat16"`` or ``"float32"``; any other value
            is passed through unchanged.
        print_result: Print the decoded text of the first sequence of each sample.
    """
    import contextlib

    print(f"Using device={device}")
    dtype_by_name = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_by_name.get(dtype, dtype)

    t0 = time.time()
    model, tokenizer = load_model(model_name, device, quant_algo,
                                    bitwidth, n_tb, k, dtype)

    device_sync(device=device) # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds", flush=True)

    # encode prompt
    #encoded = encode_tokens(tokenizer, prompt, device=device)
    encoded = encode_bos(tokenizer, device=device)
    prompt_length = encoded.size(-1)

    torch.manual_seed(1234)
    model_size, params = _get_model_size(model)

    if compile:
        global decode_one_token, prefill
        mode = 'max-autotune-no-cudagraphs' if compile == 1 else 'max-autotune'
        decode_one_token = torch.compile(decode_one_token, mode=mode, fullgraph=True, dynamic=False)
        # Uncomment to squeeze more perf out of prefill
        if compile_prefill:
            prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

    aggregate_metrics = {
        'tokens_per_sec': [],
        'accept_counts': [],
    }
    # Iteration -1 is the untimed compilation / warm-up run.
    start = -1 if compile else 0
    callback = lambda x : x

    for i in range(start, num_samples):
        device_sync(device=device)
        t0 = time.perf_counter()

        # Profile only the final sample, and only when a profile path was given.
        if (i != num_samples - 1 or not profile):
            prof = contextlib.nullcontext()
        else:
            torch.profiler._utils._init_for_cuda_graphs()
            prof = torch.profiler.profile()
        with prof:
            y = generate(
                model,
                encoded,
                max_new_tokens,
                batch_size=batch_size,
                callback=callback,
                temperature=temperature,
                top_k=top_k,
            )
        # Persist the collected trace; a nullcontext has nothing to export.
        if hasattr(prof, "export_chrome_trace"):
            prof.export_chrome_trace(f"{profile}.json")
        if i == -1:
            print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds", flush=True)
            continue

        device_sync(device=device)
        time_elapsed = time.perf_counter() - t0

        if print_result:
            print(tokenizer.decode(y[0].tolist()))

        tokens_generated = y.size(-1) - prompt_length
        generated_tokens_sec = tokens_generated / time_elapsed
        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
        print(f"Time for inference {i + 1}: {time_elapsed:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
        total_tokens_sec = y.numel() / time_elapsed
        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
        print(flush=True)
    print("==========")

    average_tokens_sec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()

    print(f"Batch Size: {batch_size}")
    print(f"Prompt Length: {prompt_length}")
    print(f"Generated tokens: {max_new_tokens}")
    print(f"Average tokens/sec: {average_tokens_sec:.2f}")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

    # Read the configuration label first so that a missing config file does
    # not leave a freshly created, empty results file behind.
    with open("config/bsel_config.txt") as rf:
        bsel_config = rf.readline().strip("\n")
    with open("config/bsel_result.csv", "a") as wf:
        wf.write(f"{bsel_config},{average_tokens_sec:.2f}\n")


if __name__ == '__main__':
    # Command-line entry point: parse the benchmark options and hand them to main().
    import argparse
    parser = argparse.ArgumentParser(description='Your CLI description.')

    def int_or_str(x):
        """Return ``x`` as an int when it parses as one, otherwise unchanged."""
        try:
            return int(x)
        except ValueError:
            # Not an integer literal: keep the raw string.
            return x

    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
    parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
    parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
    parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
    parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
    parser.add_argument('--temperature', type=float, default=0.0, help='Temperature for sampling.')
    parser.add_argument('--compile', type=int, default=2, help='Whether to compile the model')
    parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
    parser.add_argument('--device', type=str, default=default_device, help='Device to use')
    parser.add_argument('--model_name', type=str, default=None, help='model_name')
    parser.add_argument('--bitwidth', type=int, default=None, help='bitwidth', choices=[3,4,5,6])
    parser.add_argument('--dtype', type=str, default="float16", help='dtype', choices=["float16", "float32", "bfloat16"])
    parser.add_argument('--quant_algo', type=str, default=None, help='quantization algorithm to use', choices=["sqllm", "awq"])
    parser.add_argument('--n_tb', type=int, default=None, help='n_tb')
    parser.add_argument(
                '--k',
                nargs='+',  # '+' means one or more arguments
                type=int,   # Type conversion to integer
                help='A list of num rows per fetch per chunk [qkv, o, gate, up, down]'
                )
    parser.add_argument('--print_result', action='store_true')

    args = parser.parse_args()
    # Keyword arguments keep this call correct even if main()'s parameter
    # order changes.
    main(
        prompt=args.prompt,
        interactive=args.interactive,
        num_samples=args.num_samples,
        max_new_tokens=args.max_new_tokens,
        batch_size=args.batch_size,
        top_k=args.top_k,
        temperature=args.temperature,
        compile=args.compile,
        compile_prefill=args.compile_prefill,
        profile=args.profile,
        device=args.device,
        model_name=args.model_name,
        quant_algo=args.quant_algo,
        bitwidth=args.bitwidth,
        dtype=args.dtype,
        n_tb=args.n_tb,
        k=args.k,
        print_result=args.print_result,
    )

