import json
import os
import time
import uuid
from typing import List, Optional, Union

import tiktoken  # For token counting
from src.turtlegfx_datagen.utils.extract_model_name import extract_model_name


def convert_inference_to_openai_format(
    prompt_data: Optional[List[dict]] = None,
    prompt_file: Optional[str] = None,
    model_name: Optional[str] = None,
    peft_model: Optional[str] = None,
    model_outputs: Optional[List[Union[str, List[str]]]] = None,
    save: bool = True,
    output_path: Optional[str] = None,
    generation_params: Optional[dict] = None
) -> List[dict]:
    """
    Convert model outputs to the OpenAI Chat Completion API response format.

    One response is produced per prompt. A prompt may have several generated
    outputs; each becomes a separate entry in that response's ``choices`` list.

    Parameters:
    - prompt_data: List of prompt dictionaries; each must contain an ``'id'`` key.
      If provided, prompts are NOT read from ``prompt_file``.
    - prompt_file: Path to the JSON prompt file used for inference. Required when
      ``prompt_data`` is None. Whenever given, it is recorded as ``src_file`` in
      every response.
    - model_name: Name of the base model (e.g., "meta-llama/Meta-Llama-3-8B").
    - peft_model: Path or name of the PEFT adapter, if one was used.
    - model_outputs: One entry per prompt, in the same order as the prompts. Each
      entry is either a single output string or a list of output strings.
    - save: Whether to write the responses to ``output_path``. Nothing is written
      when ``output_path`` is None.
    - output_path: Path of the JSON file to write. Missing parent directories are
      created.
    - generation_params: Generation parameters used (e.g., top_p, temperature,
      top_k). If non-empty, a copy is stored in each response under
      ``generation_parameters``.

    Format of the content of prompt_file:
    [
        {"id": ..., ...},
        {"id": ..., ...},
        ...
    ]

    Returns:
    - The list of response dictionaries, one per prompt, in prompt order.

    Raises:
    - ValueError: if neither ``prompt_data`` nor ``prompt_file`` is provided, if
      ``model_outputs`` is None, or if the number of model outputs differs from
      the number of prompts.
    """
    if prompt_data is None:
        if prompt_file is None:
            raise ValueError("Either prompt_data or prompt_file must be provided.")
        with open(prompt_file, 'r', encoding='utf-8') as file:
            prompt_data = json.load(file)

    if model_outputs is None:
        raise ValueError("model_outputs cannot be None.")
    if len(model_outputs) != len(prompt_data):
        raise ValueError("Number of model outputs does not match number of prompts.")

    # NOTE: Prompt token counts are intentionally not reported. Rendering the prompt via
    # AutoProcessor.from_pretrained(model_name).apply_chat_template(...) can cause
    # resource-allocation problems with vllm, so "usage" only carries completion tokens.
    encoding = tiktoken.get_encoding("gpt2")

    # The model label is the same for every prompt, so compute it once.
    combined_model_name = _build_model_name(model_name, peft_model)

    responses = []
    for prompt, outputs in zip(prompt_data, model_outputs):
        choices, total_completion_tokens = _build_choices(outputs, encoding)

        response = {
            "id"       : f"{prompt['id']}--chatcmpl-{uuid.uuid4()}",
            "prompt_id": prompt['id'],
            "src_file" : prompt_file,
            "object"   : "chat.completion",
            "created"  : int(time.time()),
            "model"    : combined_model_name,
            "usage"    : {
                "completion_tokens"        : total_completion_tokens,
                "completion_tokens_details": {
                    "reasoning_tokens": 0
                }
            },
            "choices"  : choices
        }

        if generation_params:
            # Copy so responses don't share (and can't mutate) the caller's dict.
            response['generation_parameters'] = dict(generation_params)

        responses.append(response)

    if save and output_path is not None:
        output_dir = os.path.dirname(output_path)
        # A bare filename has no directory part; os.makedirs("") would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(responses, f, indent=2)
        print(f"Results saved to {output_path}")

    return responses


def _build_model_name(model_name: Optional[str], peft_model: Optional[str]) -> str:
    """Return the concise label stored in each response's ``"model"`` field.

    The label is the last ``/``-separated component of the extracted base model
    name (or ``"unknown-model"`` when ``model_name`` is empty). When a PEFT
    adapter is given, ``+<adapter basename>`` is appended.
    """
    if model_name:
        base_model_name = extract_model_name(model_name).split("/")[-1]
    else:
        base_model_name = "unknown-model"

    if not peft_model:
        return base_model_name

    # normpath drops a trailing separator so "adapters/lora/" yields "lora", not "".
    peft_model_name = os.path.basename(os.path.normpath(peft_model))
    return f"{base_model_name}+{peft_model_name}"


def _build_choices(outputs: Union[str, List[str]], encoding) -> tuple:
    """Build the ``choices`` list for one prompt.

    ``outputs`` may be a single string or a list of strings. Returns a pair
    ``(choices, total_completion_tokens)``, where the token total is the sum of
    ``len(encoding.encode(output))`` over all outputs.
    """
    if isinstance(outputs, str):
        outputs = [outputs]

    choices = []
    total_completion_tokens = 0
    for idx, output in enumerate(outputs):
        total_completion_tokens += len(encoding.encode(output))
        choices.append({
            "message"      : {
                "role"   : "assistant",
                "content": output
            },
            "logprobs"     : None,
            "finish_reason": "stop",
            "index"        : idx
        })
    return choices, total_completion_tokens
