from src.llmutils.inference.build_prompts import build_prompts_from_json
import os
import argparse


def apply_prompt_template(item, prompt_template):
    """
    Fill the ``{model_output}`` and ``{task_id}`` placeholders of a prompt template.

    Both placeholders are substituted in a single pass over the template, so
    placeholder-like text that appears *inside* a substituted value (for
    example a model output that literally contains ``{task_id}``) is inserted
    verbatim instead of being substituted a second time. Any other text in the
    template, including other ``{...}`` sequences, is left unchanged.

    Args:
        item (dict): A dictionary from the input JSON file. Reads the optional
            keys 'model_output' and 'id'. Missing or None values become '';
            non-string values (e.g. an integer id) are converted with ``str()``.
        prompt_template (str): The template string with placeholders.

    Returns:
        prompt (str): The generated prompt.
    """
    import re

    def _as_text(value):
        # None (missing key or explicit null) maps to '' instead of crashing.
        return '' if value is None else str(value)

    values = {
        'model_output': _as_text(item.get('model_output')),
        'task_id': _as_text(item.get('id')),
    }
    # A callable replacement inserts each value verbatim: no backslash or
    # group-reference processing, and no re-scanning of inserted text.
    return re.sub(
        r'\{(model_output|task_id)\}',
        lambda match: values[match.group(1)],
        prompt_template,
    )


def custom_function(item, prompt_template, input_file, **kwargs):
    """
    Build one chat request that shows the model two images plus a text prompt.

    The user message contains, in order: the image at ``item['image_true']``,
    the chosen image from ``item['image_output']``, and the prompt produced by
    ``apply_prompt_template``. Items with no (or an empty/falsy)
    'image_output' produce no request.

    Args:
        item (dict): A dictionary from the input JSON file.
        prompt_template (str): The template string with placeholders.
        input_file (str): Path to the input file (for src_file in output).
        **kwargs: Accepted from the caller; not used here.

    Returns:
        results (list): An empty list when the item has no 'image_output';
            otherwise a one-element list holding a dict with keys
            'id', 'messages', 'src_file', 'params' (where params['data'] is
            the original item).
    """
    image_outputs = item.get('image_output')
    if not image_outputs:
        return []

    if isinstance(image_outputs, str):
        # A single string is already one URL; max(..., key=len) over a str
        # would iterate its characters and return just one of them.
        output_url = image_outputs
    else:
        # Several candidates: use the longest entry.
        # NOTE(review): the reason for preferring the longest entry is not
        # visible in this file — confirm with the data producer.
        output_url = max(image_outputs, key=len)

    id_from_input = item.get('id', '')
    prompt = apply_prompt_template(item, prompt_template)

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            # List-of-parts content (image_url + text parts), per the
            # original note: for pixtral-large.
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        # NOTE(review): an empty URL is sent when 'image_true'
                        # is missing — confirm the backend tolerates this.
                        "url": item.get('image_true', ''),
                    }
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": output_url,
                    }
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    result = {
        "id": id_from_input,
        "messages": messages,
        "src_file": input_file,
        "params": {
            "data": item
        }
    }
    return [result]


if __name__ == "__main__":
    # Command-line entry point: read arguments, load the prompt template and
    # hand everything to build_prompts_from_json with custom_function as the
    # per-item builder.
    parser = argparse.ArgumentParser(description="Generate prompts from input JSON and template")

    parser.add_argument('--input_file', type=str, required=True, help='Path to input JSON file')
    parser.add_argument('--output_file', type=str, required=True, help='Path to output JSON file')
    parser.add_argument('--prompt_template', type=str, required=True, help='Prompt template string or file path')
    parser.add_argument('--max_num_items', type=int, default=150000, help='Maximum number of items per part')
    args = parser.parse_args()

    # --prompt_template is treated as a file path if such a file exists,
    # otherwise as the literal template text.
    if os.path.isfile(args.prompt_template):
        # Explicit encoding: decoding must not depend on the machine's locale.
        with open(args.prompt_template, 'r', encoding='utf-8') as f:
            prompt_template = f.read()
    else:
        prompt_template = args.prompt_template

    build_prompts_from_json(
        input_file=args.input_file,
        output_file=args.output_file,
        prompt_template=prompt_template,
        custom_function=custom_function,
        max_num_items=args.max_num_items,
        split=False
    )
