import pdb
import fire
import os
import sys
import json
import torch
from vllm import LLM, SamplingParams
from accelerate.utils import is_xpu_available
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info
from src.turtlegfx_datagen.utils.convert_inference_output import convert_inference_to_openai_format


def _load_prompt_data(prompt_file):
    """Return the parsed JSON content of ``prompt_file``.

    Prints a message and exits with status 1 when no file was given or the
    path does not exist, so the caller never continues without prompts.
    """
    if prompt_file is None:
        print("No prompt file provided; pass --prompt_file=<path to a JSON file>")
        sys.exit(1)
    if not os.path.exists(prompt_file):
        print(f"Provided prompt file does not exist: {prompt_file}")
        sys.exit(1)
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(prompt_file, 'r', encoding='utf-8') as file:
        return json.load(file)


def _build_llm_inputs(processor, dialogs, prompt_data):
    """Turn dialogs into vLLM inputs, dropping those whose vision data fails.

    Returns ``(inputs, kept_records)`` where ``kept_records[k]`` is the
    ``prompt_data`` entry that produced ``inputs[k]``.
    """
    inputs = []
    kept_records = []
    for record, dialog in zip(prompt_data, dialogs):
        # Generate the prompt text using the processor's chat template
        prompt = processor.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)

        try:
            image_inputs, video_inputs = process_vision_info(dialog)
        except Exception as e:
            # Deliberate best-effort: e.g. aspect ratio issues; skip the dialog
            print(f"Error processing vision info: {e}")
            continue

        mm_data = {}
        if image_inputs is not None:
            mm_data["image"] = image_inputs
        if video_inputs is not None:
            mm_data["video"] = video_inputs

        inputs.append({"prompt": prompt, "multi_modal_data": mm_data})
        # Keep the source records aligned one-to-one with the inputs
        kept_records.append(record)
    return inputs, kept_records


def main(
        model_name,
        peft_model: str = None,
        quantization: str = None,  # Options: 4bit, 8bit
        max_new_tokens=256,  # The maximum numbers of tokens to generate
        min_new_tokens: int = 0,  # The minimum numbers of tokens to generate
        prompt_file: str = None,
        seed: int = 42,  # seed value for reproducibility
        do_sample: bool = True,  # Whether or not to use sampling ; use greedy decoding otherwise.
        use_cache: bool = True,
        # [optional] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
        top_p: float = 1.0,
        # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
        top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
        use_fast_kernels: bool = False,
        # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
        vllm_batch_size: int = 2,
        tensor_parallel_size: int = 4,
        output_path: str = None,  # save the output to a file
        **kwargs
):
    """Run vLLM generation over the dialogs of a prompt file and save the results.

    The prompt file is a JSON list whose entries each carry a ``'message'``
    dialog. Every dialog is rendered with the model's chat template, paired
    with its image/video inputs, generated in one ``llm.generate`` call, and
    handed to ``convert_inference_to_openai_format`` with ``save=True``.

    Args:
        model_name: Model identifier given to both ``LLM`` and ``AutoProcessor``.
        peft_model: Not used for inference; only forwarded to the converter.
        quantization: Accepted for CLI compatibility; not used by this function.
        max_new_tokens: Upper bound on generated tokens (``max_tokens``).
        min_new_tokens: Lower bound on generated tokens (``min_tokens``).
        prompt_file: Path to the JSON prompt file. Required.
        seed: Seed for torch and for the vLLM engine.
        do_sample: When False, temperature 0 is used (greedy decoding in vLLM).
        use_cache: Only recorded in the saved generation parameters.
        top_p, temperature, top_k, repetition_penalty: Sampling settings.
        use_fast_kernels: Accepted for CLI compatibility; not used here.
        vllm_batch_size: Passed to vLLM as ``max_num_seqs``.
        tensor_parallel_size: Passed to vLLM as ``tensor_parallel_size``.
        output_path: Forwarded to the converter as the save location.
        **kwargs: Ignored.

    Exits with status 1 when the prompt file is missing, not provided, or
    yields no usable dialog.
    """
    # Print the model being used
    print(f"Using model: {model_name}")

    # Load prompts (exits instead of hitting a NameError when no file is given)
    prompt_data = _load_prompt_data(prompt_file)
    dialogs = [x['message'] for x in prompt_data]

    print(f"# User dialogs: {len(dialogs)}")
    print("\n==================================\n")

    if not dialogs:
        # Fail fast before the expensive model load
        print(f"Prompt file contains no dialogs: {prompt_file}")
        sys.exit(1)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    # Load the model; the seed is handed to vLLM because seeding torch in
    # this process alone does not control the engine's sampling.
    llm = LLM(
        model=model_name,
        limit_mm_per_prompt={"image": 10, "video": 10},
        max_model_len=20000,
        tensor_parallel_size=tensor_parallel_size,
        max_num_seqs=vllm_batch_size,
        seed=seed,
    )

    # Prepare sampling parameters; temperature 0 selects greedy decoding
    sampling_params = SamplingParams(
        top_p=top_p,
        temperature=temperature if do_sample else 0.0,
        max_tokens=max_new_tokens,
        min_tokens=min_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    # Initialize the processor
    processor = AutoProcessor.from_pretrained(model_name)

    # Prepare inputs
    inputs, filtered_prompt_data = _build_llm_inputs(processor, dialogs, prompt_data)

    if not inputs:
        # Every dialog was skipped; nothing to generate and inputs[0] would fail
        print("No dialog could be prepared for inference; nothing to generate")
        sys.exit(1)

    # For debug purposes, print the first input
    print(f"\n\nFOR DEBUG: First input: {inputs[0]}\n\n")

    # Generate outputs
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    output_texts = [o.outputs[0].text for o in outputs]

    print(f"Generated {len(output_texts)} outputs")

    # Prepare generation parameters to include in the response data
    generation_params = {
        'top_p'             : top_p,
        'temperature'       : temperature,
        'top_k'             : top_k,
        'repetition_penalty': repetition_penalty,
        'max_new_tokens'    : max_new_tokens,
        'min_new_tokens'    : min_new_tokens,
        'do_sample'         : do_sample,
        'use_cache'         : use_cache,
        'seed'              : seed,
        'tensor_parallel_size': tensor_parallel_size,
        'max_num_seqs'       : vllm_batch_size,
    }

    # Convert outputs to OpenAI format and save
    convert_inference_to_openai_format(
        prompt_file=prompt_file,
        prompt_data=filtered_prompt_data,
        model_name=model_name,
        peft_model=peft_model,
        model_outputs=output_texts,
        save=True,
        output_path=output_path,
        generation_params=generation_params
    )


# Script entry point: python-fire builds the command-line interface from
# main's signature, so each parameter can be passed as --name=value.
if __name__ == "__main__":
    fire.Fire(main)
