import fire
import os
import sys
import json
import torch
from vllm import LLM, SamplingParams
from src.turtlegfx.utils.base64img import convert_base64_to_img
from accelerate.utils import is_xpu_available
from src.turtlegfx_datagen.utils.convert_inference_output import convert_inference_to_openai_format
from transformers import AutoTokenizer


def get_model_config(model_name: str, tokenizer):
    """Get model-specific configurations.

    The model family is detected by case-insensitive substring matching on
    ``model_name``; the first matching family wins, and unknown models get a
    generic fallback configuration.

    Args:
        model_name: Model identifier or path used to pick the family.
        tokenizer: Tokenizer for the model. Only used for InternVL2, where
            its ``convert_tokens_to_ids`` method maps the textual stop tokens
            to ids.

    Returns:
        A ``(config, stop_token_ids)`` tuple. ``config`` is a dict of extra
        keyword arguments for the ``LLM`` constructor, and ``stop_token_ids``
        is a list of token ids, or ``None`` when the family defines no
        explicit stop tokens.
    """
    name = model_name.lower()
    # Default for every family that does not define explicit stop tokens.
    # Assigning it up front guarantees the name is bound on every path (the
    # DeepSeek-VL2 branch previously left it unset and crashed on return).
    stop_token_ids = None

    if "aria" in name:
        config = {
            "tokenizer_mode": "slow",
            "dtype": "bfloat16",
            "trust_remote_code": True,
        }
        stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
    elif "internvl2" in name:
        config = {
            "trust_remote_code": True,
            "max_model_len": 4096,
        }
        stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
        stop_token_ids = [tokenizer.convert_tokens_to_ids(tok) for tok in stop_tokens]
    elif "molmo" in name:
        config = {
            "trust_remote_code": True,
            "max_model_len": 4096,
        }
    elif "llava" in name:
        config = {
            "trust_remote_code": True,
            "max_model_len": 12282,
        }
    # Accept both spellings: local directories tend to use "deepseek_vl2",
    # while hub checkpoints are published with a hyphen ("deepseek-vl2-tiny").
    elif "deepseek_vl2" in name or "deepseek-vl2" in name:
        config = {
            "max_model_len": 4096,
            "hf_overrides": {"architectures": ["DeepseekVLV2ForCausalLM"]},
            "trust_remote_code": True,
        }
    else:
        config = {
            "trust_remote_code": True,
            "max_model_len": 8192,
        }

    return config, stop_token_ids


def main(
        model_name: str = None,
        peft_model: str = None,
        quantization: str = None,  # Options: 4bit, 8bit
        max_new_tokens=256,  # The maximum numbers of tokens to generate
        min_new_tokens: int = 0,  # The minimum numbers of tokens to generate
        prompt_file: str = None,
        seed: int = 42,  # seed value for reproducibility
        do_sample: bool = True,  # Whether or not to use sampling ; use greedy decoding otherwise.
        use_cache: bool = True,
        top_p: float = 1.0,
        temperature: float = 1.0,  # The value used to modulate the next token probabilities.
        top_k: int = 50,  # The number of highest probability vocabulary tokens to keep for top-k-filtering.
        repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
        use_fast_kernels: bool = False,
        vllm_batch_size: int = 2,
        tensor_parallel_size: int = 4,
        output_path: str = None,  # save the output to a file
        **kwargs
):
    """Run vLLM image+text inference over a JSON prompt file and save the results.

    Each item of the prompt file must provide a ``prompt`` string and a
    base64-encoded ``task_image``. Items that fail to load or have no usable
    image are skipped, and only the remaining items are passed on to the
    output converter, so prompts and outputs stay aligned.

    Args:
        model_name: Model identifier/path handed to the tokenizer and vLLM.
        peft_model: Only forwarded to the output converter.
        quantization: Accepted for CLI compatibility; not used here.
        max_new_tokens: Maximum number of tokens to generate per prompt.
        min_new_tokens: Minimum number of tokens to generate per prompt.
        prompt_file: Path to the JSON file holding the list of prompt items.
        seed: Seed applied to the torch RNGs of this process.
        do_sample: When false, decoding is greedy (temperature forced to 0).
        use_cache: Only recorded in the saved generation parameters.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature used when ``do_sample`` is true.
        top_k: Number of top tokens kept for sampling.
        repetition_penalty: Repetition penalty; 1.0 means no penalty.
        use_fast_kernels: Accepted for CLI compatibility; not used here.
        vllm_batch_size: Passed to vLLM as ``max_num_seqs``.
        tensor_parallel_size: Number of tensor-parallel workers for vLLM.
        output_path: Forwarded to the output converter as the save target.
        **kwargs: Extra CLI arguments; ignored.

    Raises:
        SystemExit: With status 1 when the model name or prompt file is
            missing, the prompt file does not exist, or no prompt item yields
            a usable input.
    """
    # Print the model being used
    print(f"Using model: {model_name}")

    if model_name is None:
        print("No model name provided (use --model_name).")
        sys.exit(1)

    # Load prompts. A missing prompt file used to fall through and crash with
    # a NameError on `prompt_data`; fail early with a clear message instead.
    if prompt_file is None:
        print("No prompt file provided (use --prompt_file).")
        sys.exit(1)
    if not os.path.exists(prompt_file):
        print(f"Provided prompt file does not exist: {prompt_file}")
        sys.exit(1)
    with open(prompt_file, 'r', encoding='utf-8') as file:
        prompt_data = json.load(file)

    print(f"# Prompts loaded: {len(prompt_data)}")
    print("\n==================================\n")

    # Set the seeds for reproducibility
    # NOTE(review): the seed is only applied to torch in this process; it is
    # not forwarded to vLLM (LLM / SamplingParams) -- confirm this is intended.
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Prepare inputs
    inputs = []
    filtered_prompt_data = []
    for i, item in enumerate(prompt_data):
        # Deliberately best-effort: one malformed item must not abort the run.
        try:
            prompt = item['prompt']

            # Decode the task image attached to the prompt.
            # presumably convert_base64_to_img returns a falsy value when the
            # image is missing or cannot be decoded -- verify against helper.
            base64image = item.get('task_image')
            image_data = convert_base64_to_img(base64image)

            if not image_data:
                continue

            inputs.append({
                "prompt": prompt,
                "multi_modal_data": {
                    "image": image_data
                },
            })
            # Keep the source item so outputs can be paired with prompts later.
            filtered_prompt_data.append(item)

        except Exception as e:
            print(f"Error processing item {i}: {e}")
            continue

    skipped = len(prompt_data) - len(inputs)
    if skipped:
        print(f"Skipped {skipped} prompt(s) without a usable image or with errors")

    # Bail out before the (expensive) model load when there is nothing to
    # generate; this also protects the `inputs[0]` debug print below.
    if not inputs:
        print("No usable prompts found; nothing to generate.")
        sys.exit(1)

    # Load the model with model-specific configurations
    model_config, stop_token_ids = get_model_config(
        model_name=model_name,
        tokenizer=tokenizer,
    )

    llm = LLM(
        model=model_name,
        tensor_parallel_size=tensor_parallel_size,
        max_num_seqs=vllm_batch_size,
        **model_config
    )

    # Prepare sampling parameters
    # vLLM decodes greedily when the temperature is 0, which is how the
    # `do_sample=False` request is honoured.
    effective_temperature = temperature if do_sample else 0.0
    sampling_params = SamplingParams(
        top_p=top_p,
        temperature=effective_temperature,
        max_tokens=max_new_tokens,
        min_tokens=min_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        stop_token_ids=stop_token_ids
    )

    # For debug purposes, print the first input
    print(f"\n\nFOR DEBUG: First input: {inputs[0]}\n\n")

    # Generate outputs
    outputs = llm.generate(inputs, sampling_params=sampling_params)
    # Only the first candidate of each request is kept.
    output_texts = [o.outputs[0].text for o in outputs]

    print(f"Generated {len(output_texts)} outputs")

    # Prepare generation parameters to include in the response data
    generation_params = {
        'top_p': top_p,
        'temperature': temperature,
        'top_k': top_k,
        'repetition_penalty': repetition_penalty,
        'max_new_tokens': max_new_tokens,
        'min_new_tokens': min_new_tokens,
        'do_sample': do_sample,
        'use_cache': use_cache,
        'seed': seed,
        'tensor_parallel_size': tensor_parallel_size,
        'max_num_seqs': vllm_batch_size,
    }

    # Convert outputs to OpenAI format and save
    convert_inference_to_openai_format(
        prompt_file=prompt_file,
        prompt_data=filtered_prompt_data,
        model_name=model_name,
        peft_model=peft_model,
        model_outputs=output_texts,
        save=True,
        output_path=output_path,
        generation_params=generation_params
    )


# Script entry point: expose `main` as a command-line interface through
# python-fire, so each keyword parameter of `main` becomes a CLI flag.
if __name__ == "__main__":
    fire.Fire(component=main)
