from src.turtlegfx_datagen.prompts.build_prompts import build_prompts_from_json
from src.turtlegfx_datagen.utils.sample_from_list import sample_seeds
import os
import json
import argparse
import uuid


def apply_prompt_template(item, prompt_template, references):
    """
    Apply the prompt template to the item with the given reference codes.

    The template is filled with ``str.format``. Available placeholders are
    ``{code_to_adapt}`` (the item's ``'code'`` value, or ``''`` if absent) and
    ``{reference_code_1}``, ``{reference_code_2}``, ... -- one per entry in
    ``references``, numbered from 1 in order. Values the template does not
    reference are simply ignored, so any number of references is accepted.

    Args:
        item (dict): A dictionary from the input JSON file.
        prompt_template (str): The template string with placeholders.
        references (list): Reference codes to include in the prompt so the
            model can edit the item's code by looking at them.

    Returns:
        prompt (str): The generated prompt.

    Raises:
        KeyError: If the template uses a placeholder that has no value, e.g.
            ``{reference_code_2}`` when fewer than two references are given.
    """
    # Number the references from 1 so they line up with the template's
    # {reference_code_1}, {reference_code_2}, ... placeholders.
    reference_fields = {
        f"reference_code_{index}": code
        for index, code in enumerate(references, start=1)
    }
    return prompt_template.format(code_to_adapt=item.get('code', ''),
                                  **reference_fields)


def custom_function(item, prompt_template, input_file, n_combine=2, n_combs_sample=10, seed_list=None):
    """
    Build chat-style prompts for one input item, one per sampled seed combination.

    For every combination returned by ``sample_seeds(seed_list, n_combine,
    n_combs_sample)``, the ``'code'`` values of the combination's seeds are
    substituted into ``prompt_template`` (via ``apply_prompt_template``)
    together with the item's own code. The result is wrapped in a
    system + user message list.

    Args:
        item (dict): A dictionary from the input JSON file. Its ``'code'`` is
            used in the prompt, and its ``'id'`` (default ``''``) prefixes the
            generated ids and is recorded as ``base_code``.
        prompt_template (str): The template string with placeholders.
        input_file (str): Path to the input file (recorded as ``src_file``).
        n_combine (int): Number of seeds to combine per prompt.
        n_combs_sample (int): Number of combinations to sample per item.
        seed_list (list): Reference seeds the model looks at when editing the
            code. Each sampled seed is read via its ``'code'`` and ``'id'``
            keys.

    Returns:
        results (list): One dictionary per sampled combination, with keys
            'id', 'message', 'src_file', 'params'.
    """
    id_from_input = item.get('id', '')

    # Presumably draws n_combs_sample combinations of n_combine seeds each;
    # see sample_seeds for the exact sampling semantics.
    sampled_combs = sample_seeds(seed_list, n_combine, n_combs_sample)

    results = []
    for comb in sampled_combs:
        ref_codes = [seed['code'] for seed in comb]
        prompt = apply_prompt_template(item, prompt_template, ref_codes)

        message = [
            {
                "role"   : "system",
                "content": "You are a helpful assistant that helps adapt code based on user instructions."
            },
            {
                "role"   : "user",
                "content": prompt
            }
        ]

        results.append({
            # A random UUID keeps ids unique across multiple samples of the same item.
            "id"      : f"{id_from_input}--bldprom-{uuid.uuid4()}",
            # NOTE(review): key is 'message' (singular); chat-format consumers
            # often expect 'messages' -- confirm what downstream readers use.
            "message" : message,
            "src_file": input_file,
            "params"  : {
                "n_combine"      : n_combine,
                "n_combs_sample" : n_combs_sample,
                # Store only the reference seeds' ids; embedding their code
                # would make the output file too large.
                "ref_codes"      : [seed['id'] for seed in comb],
                # Only the base item's id is stored, not its full data.
                "base_code"      : id_from_input,
            }
        })

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate prompts from input JSON and template")

    parser.add_argument('--input_file', type=str, required=True, help='Path to input JSON file')
    parser.add_argument('--output_file', type=str, required=True, help='Path to output JSON file')
    parser.add_argument('--prompt_template', type=str, required=True, help='Prompt template string or file path')
    # params
    parser.add_argument('--n_combine', type=int, default=2, help='Number of seeds to combine')
    parser.add_argument('--n_combs_sample', type=int, default=7, help='Number of samples per input item')
    parser.add_argument('--seed_file', type=str, default=None, help='Path to seed dataset file that includes the reference codes (required)')
    parser.add_argument('--max_num_items', type=int, default=100000, help='Maximum number of items per part')

    args = parser.parse_args()

    # The seed file supplies the reference codes; fail early with an accurate
    # message instead of reporting "Seed file None not found".
    if args.seed_file is None:
        raise ValueError("--seed_file is required: path to the seed dataset file that includes the reference codes")

    # --prompt_template may be either a path to a template file or the template text itself.
    if os.path.isfile(args.prompt_template):
        with open(args.prompt_template, 'r', encoding='utf-8') as f:
            prompt_template = f.read()
    else:
        prompt_template = args.prompt_template

    # A missing seed file surfaces as FileNotFoundError from open().
    with open(args.seed_file, 'r', encoding='utf-8') as f:
        seed_list = json.load(f)

    build_prompts_from_json(
        input_file=args.input_file,
        output_file=args.output_file,
        prompt_template=prompt_template,
        custom_function=custom_function,
        n_combine=args.n_combine,
        n_combs_sample=args.n_combs_sample,
        max_num_items=args.max_num_items,
        seed_list=seed_list
    )
