import fire
import os
import sys
import json
import torch
from vllm import LLM, SamplingParams
from accelerate.utils import is_xpu_available
from src.turtlegfx_datagen.utils.convert_inference_output import convert_inference_to_openai_format


def _load_prompts(prompt_file):
    """Read the JSON prompt file and return ``(prompt_data, messages)``.

    ``prompt_data`` is the parsed JSON as-is. ``messages`` is each record's
    ``'messages'`` entry, in file order. The records are presumably chat-style
    dicts; this is not validated here. The process exits with status 1 when no
    file is given or the path does not exist, because nothing downstream can
    run without prompts.
    """
    if prompt_file is None:
        # Previously this fell through and crashed later with UnboundLocalError.
        print("No prompt file provided; pass --prompt_file <path.json>", file=sys.stderr)
        sys.exit(1)
    if not os.path.exists(prompt_file):
        print(f"Provided prompt file does not exist: {prompt_file}", file=sys.stderr)
        sys.exit(1)
    with open(prompt_file, "r", encoding="utf-8") as file:
        prompt_data = json.load(file)
    messages = [record["messages"] for record in prompt_data]
    return prompt_data, messages


def _seed_everything(seed):
    """Seed torch's CPU RNG and the accelerator RNG (XPU when available, else CUDA)."""
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)


def main(
        model_name: str = "mistralai/Pixtral-Large-Instruct-2411",
        max_new_tokens: int = 256,
        min_new_tokens: int = 0,
        prompt_file: str = None,
        seed: int = 42,
        do_sample: bool = True,
        use_cache: bool = True,
        top_p: float = 1.0,
        temperature: float = 1.0,
        top_k: int = 50,
        repetition_penalty: float = 1.0,
        use_fast_kernels: bool = False,
        vllm_batch_size: int = 2,
        tensor_parallel_size: int = 4,
        output_path: str = None,
        max_model_len: int = 16384,
        max_images_per_prompt: int = 4,
        **kwargs
):
    """Run batched vLLM chat inference with a Pixtral (Mistral-format) model.

    Prompts are loaded from ``prompt_file``. Each record's ``'messages'`` is
    sent through ``LLM.chat``. The first completion for each prompt is then
    handed to ``convert_inference_to_openai_format``, which saves it to
    ``output_path``.

    Args:
        model_name: Hugging Face model id or local path. It is loaded with the
            Mistral config, weight and tokenizer formats.
        max_new_tokens: Maximum tokens to generate per prompt.
        min_new_tokens: Minimum tokens to generate per prompt (vLLM ``min_tokens``).
        prompt_file: Path to the JSON prompt file. Required.
        seed: Seed for torch in this process and for vLLM's sampler.
        do_sample: When False, decoding is greedy (``temperature=0``).
        use_cache: Accepted for CLI compatibility. It is only recorded in the
            saved generation parameters and is not passed to vLLM.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature. Used only when ``do_sample`` is True.
        top_k: Top-k sampling cutoff.
        repetition_penalty: vLLM repetition penalty.
        use_fast_kernels: Accepted for CLI compatibility. It is unused.
        vllm_batch_size: vLLM ``max_num_seqs``, the maximum number of
            concurrent sequences.
        tensor_parallel_size: Number of GPUs to shard the model across.
        output_path: Destination passed through to the output converter.
        max_model_len: vLLM maximum context length.
        max_images_per_prompt: vLLM multimodal limit of images per prompt.
        **kwargs: Unrecognized CLI flags. They are reported and ignored.
    """
    print(f"Using model: {model_name}")

    if kwargs:
        # Fire routes unknown flags here; surface them so typos are not silently dropped.
        print(f"Ignoring unrecognized arguments: {sorted(kwargs)}", file=sys.stderr)

    prompt_data, messages = _load_prompts(prompt_file)

    print(f"# Prompts loaded: {len(messages)}")
    print("\n==================================\n")

    # Seeds torch in this process only. vLLM gets its own seed below, because
    # its workers are seeded separately.
    _seed_everything(seed)

    # Load the model with the Pixtral-specific (Mistral-format) configuration.
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        max_num_seqs=vllm_batch_size,
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        limit_mm_per_prompt={"image": max_images_per_prompt},
        trust_remote_code=True,
        max_model_len=max_model_len,
        seed=seed,
    )

    sampling_params = SamplingParams(
        top_p=top_p,
        # In vLLM, temperature 0 means greedy decoding. This is how the
        # HF-style `do_sample=False` flag is honored.
        temperature=temperature if do_sample else 0.0,
        max_tokens=max_new_tokens,
        min_tokens=min_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    # Use the chat API so the Mistral tokenizer applies Pixtral's chat template.
    outputs = llm.chat(messages, sampling_params=sampling_params)
    output_texts = [o.outputs[0].text for o in outputs]

    print(f"Generated {len(output_texts)} outputs")

    # Parameters recorded alongside the outputs, for provenance.
    generation_params = {
        'top_p': top_p,
        'temperature': temperature,
        'top_k': top_k,
        'repetition_penalty': repetition_penalty,
        'max_new_tokens': max_new_tokens,
        'min_new_tokens': min_new_tokens,
        'do_sample': do_sample,
        'use_cache': use_cache,
        'seed': seed,
        'tensor_parallel_size': tensor_parallel_size,
        'max_num_seqs': vllm_batch_size,
    }

    # Convert outputs to OpenAI format and save them.
    convert_inference_to_openai_format(
        prompt_file=prompt_file,
        prompt_data=prompt_data,
        model_name=model_name,
        model_outputs=output_texts,
        save=True,
        output_path=output_path,
        generation_params=generation_params
    )


if __name__ == "__main__":
    # Let python-fire expose every keyword argument of `main` as a CLI flag.
    fire.Fire(component=main)