# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.


import pdb
from src.turtlegfx_datagen.utils.convert_inference_output import convert_inference_to_openai_format
import fire
import json
import os
import sys

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from accelerate.utils import is_xpu_available


def _exit_with_message(message):
    """Print *message* and terminate the process with exit status 1."""
    print(message)
    sys.exit(1)


def _load_dialogs(prompt_file):
    """Return the chat dialogs from *prompt_file*, or from piped stdin when it is None.

    A prompt file must contain a JSON list of objects that each carry a
    'message' entry; the list of those entries is returned.  Input piped on
    stdin is parsed as JSON and returned unchanged.  Every failure mode
    (missing file, invalid JSON, malformed entries, no input at all) prints a
    message and exits with status 1.
    """
    if prompt_file is not None:
        if not os.path.exists(prompt_file):
            _exit_with_message(f"Provided prompt file does not exist: {prompt_file}")
        try:
            with open(prompt_file, 'r', encoding="utf-8") as file:
                prompt_data = json.load(file)
        except json.JSONDecodeError as e:
            # Mirror the stdin branch: report the problem instead of dumping a traceback.
            _exit_with_message(f"Could not parse JSON from prompt file {prompt_file}: {e}")
        try:
            return [x['message'] for x in prompt_data]
        except (KeyError, TypeError) as e:
            _exit_with_message(
                f"Prompt file entries must be objects with a 'message' field: {e!r}"
            )

    if not sys.stdin.isatty():
        # read() keeps the payload verbatim; joining readlines() with "\n"
        # would double every line break.
        raw_input_text = sys.stdin.read()
        try:
            return json.loads(raw_input_text)
        except json.JSONDecodeError:
            _exit_with_message(
                "Could not parse JSON from stdin. Please provide a JSON file with the user prompts."
            )

    _exit_with_message("No user prompt provided. Exiting.")


def main(
        model_name,
        peft_model: str = None,
        quantization: str = None,  # Options: 4bit, 8bit
        max_new_tokens=256,  # The maximum numbers of tokens to generate
        min_new_tokens: int = 0,  # The minimum numbers of tokens to generate
        prompt_file: str = None,
        seed: int = 42,  # seed value for reproducibility
        do_sample: bool = True,  # Whether or not to use sampling ; use greedy decoding otherwise.
        use_cache: bool = True,
        # [optional] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
        top_p: float = 1.0,
        # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
        top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
        use_fast_kernels: bool = False,
        # Kept for CLI compatibility; it has no effect on the vLLM engine used here.
        vllm_batch_size: int = 2,
        tensor_parallel_size: int = 2,
        output_path: str = None,  # save the output to a file
        max_model_len: int = 20000,  # context length handed to the vLLM engine
        **kwargs
):
    """Run batched chat completion with vLLM and save the results in OpenAI format.

    The dialogs are read from ``prompt_file`` (or stdin), rendered to prompt
    strings with the tokenizer's chat template, generated with vLLM
    (optionally through a LoRA adapter) and handed to
    ``convert_inference_to_openai_format`` together with the generation
    parameters.

    Args:
        model_name: Model identifier/path used for both the tokenizer and the vLLM engine.
        peft_model: Optional LoRA adapter path; when set, LoRA is enabled on the
            engine and the adapter is applied to every request.
        quantization: Accepted for CLI compatibility; not used by this function.
        max_new_tokens: Upper bound on generated tokens per prompt.
        min_new_tokens: Lower bound on generated tokens; forwarded to vLLM only when > 0.
        prompt_file: JSON file with a list of objects carrying a 'message' entry.
        seed: Seed applied to torch and forwarded to the vLLM engine.
        do_sample: When False, temperature is forced to 0 so decoding is greedy.
        use_cache: Recorded in the saved generation parameters only.
        top_p, temperature, top_k, repetition_penalty: Sampling controls forwarded to vLLM.
        use_fast_kernels: Ignored (a notice is printed when set).
        vllm_batch_size: Passed to vLLM as ``max_num_seqs``.
        tensor_parallel_size: Number of tensor-parallel workers for vLLM.
        output_path: Forwarded to ``convert_inference_to_openai_format``.
        max_model_len: Maximum model context length for the vLLM engine.
        **kwargs: Extra CLI flags are accepted and ignored.

    Exits with status 1 when the prompts cannot be loaded or templated.
    """
    # print the model being used
    print(f"Using model: {model_name}")

    # Load prompts
    dialogs = _load_dialogs(prompt_file)

    print(f"# User dialogs: {len(dialogs)}")
    print("\n==================================\n")

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    # Step 1: Render every dialog to a prompt string with the model's chat template
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    try:
        chats = tokenizer.apply_chat_template(dialogs, tokenize=False)
    except Exception as e:
        print(f"Error tokenizing the input: {e}")
        sys.exit(1)

    # Step 2: Load the model
    llm_kwargs = {
        "model": model_name,
        "tensor_parallel_size": tensor_parallel_size,
        "max_model_len": max_model_len,
        "max_num_seqs": vllm_batch_size,
        # The engine seeds itself independently of the torch calls above, so
        # the CLI seed has to be handed over explicitly.
        "seed": seed,
    }
    if peft_model:
        # A LoRARequest is only honoured by an engine created with LoRA support.
        llm_kwargs["enable_lora"] = True
    model = LLM(**llm_kwargs)

    if use_fast_kernels:
        # BetterTransformer operates on transformers models, not on a vLLM
        # engine object, so applying it here cannot work; skip it and say so.
        print("'use_fast_kernels' is ignored: the vLLM engine selects its own attention kernels.")

    # Generate the output
    sampling_kwargs = {
        "top_p": top_p,
        # vLLM treats temperature 0 as greedy decoding.
        "temperature": temperature if do_sample else 0.0,
        "max_tokens": max_new_tokens,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    if min_new_tokens and min_new_tokens > 0:
        # Only sent when requested so the default call stays unchanged.
        sampling_kwargs["min_tokens"] = min_new_tokens
    sampling_params = SamplingParams(**sampling_kwargs)

    lora_request = LoRARequest("lora_request", 1, peft_model) if peft_model else None
    outputs = model.generate(chats, sampling_params, lora_request=lora_request)

    # Keep the first candidate of every request.
    output_texts = [o.outputs[0].text for o in outputs]

    print(f"Generated {len(output_texts)}")

    # Prepare generation parameters to include in the response data
    generation_params = {
        'top_p'             : top_p,
        'temperature'       : temperature,
        'top_k'             : top_k,
        'repetition_penalty': repetition_penalty,
        'max_new_tokens'    : max_new_tokens,
        'min_new_tokens'    : min_new_tokens,
        'do_sample'         : do_sample,
        'use_cache'         : use_cache,
        'seed'              : seed,
        'tensor_parallel_size': tensor_parallel_size,
        'max_num_seqs'       : vllm_batch_size,
    }

    print("Finished! Starting saving chat completion...")

    # Convert outputs to OpenAI format and save
    convert_inference_to_openai_format(
        prompt_file=prompt_file,
        model_name=model_name,
        peft_model=peft_model,
        model_outputs=output_texts,
        save=True,
        output_path=output_path,
        generation_params=generation_params
    )


# Script entry point: expose main()'s parameters as command-line flags via fire.
if __name__ == "__main__":
    fire.Fire(component=main)
