import os
import json
import logging
from src.turtlegfx_datagen.utils.data_split import split_data

# Configure the root logger (a no-op if it already has handlers) and obtain
# a logger named after this module for the functions below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if the path has one.

    ``os.path.dirname`` returns ``''`` for a bare filename such as
    ``"out.json"``, and ``os.makedirs('')`` raises ``FileNotFoundError``,
    so that case is skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _write_json(path, obj, **dump_kwargs):
    """Serialize ``obj`` as UTF-8 JSON to ``path``, creating parent dirs as needed.

    Extra keyword arguments are forwarded to ``json.dump`` (e.g. ``indent``).
    """
    _ensure_parent_dir(path)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, **dump_kwargs)


def _derived_path(output_file, tag):
    """Insert ``tag`` before the file extension: ``out.json`` -> ``out{tag}.json``.

    Only the final extension is considered. A plain ``str.replace('.json', ...)``
    would also rewrite ``.json`` occurring in directory names. It would also
    return ``output_file`` unchanged when the name lacks ``.json``, making
    every part and the stats file overwrite one another.
    """
    root, ext = os.path.splitext(output_file)
    return f"{root}{tag}{ext}"


def build_prompts_from_json(input_file, output_file, prompt_template, custom_function, max_num_items=200000, **kwargs):
    """
    Builds prompts from the input JSON file using the provided template and custom function,
    then saves the results to the output JSON file(s).

    Each element of the loaded JSON is passed to ``custom_function``, and the
    iterables it returns are concatenated into one output list.

    If ``kwargs['split']`` is falsy, the whole list is written to ``output_file``.
    Otherwise (the default), the list is divided by ``split_data`` into
    ``ceil(len(output) / max_num_items)`` parts. Part ``i`` is written to
    ``<root>_part_<i><ext>`` and a summary to ``<root>_stats<ext>``, where
    ``<root><ext>`` is ``output_file`` split at its extension. No part files
    are written when the output is empty.

    Args:
        input_file (str): Path to the input JSON file (read as UTF-8).
        output_file (str): Path to the output JSON file; parent directories are created.
        prompt_template (str): The template string with placeholders.
        custom_function (function): Called as
            ``custom_function(item, prompt_template, input_file, **kwargs)``
            for each input item; must return an iterable of output items.
        max_num_items (int): Target maximum number of output items per part
            file when splitting. Must be >= 1 if splitting is enabled.
        **kwargs: Additional keyword arguments to pass to the custom_function.
            The ``split`` key (default True) also controls whether the output is
            split into parts; it is forwarded to ``custom_function`` as well.

    Raises:
        ValueError: If splitting is enabled and ``max_num_items`` < 1.
    """
    split = kwargs.get('split', True)
    if split and max_num_items < 1:
        raise ValueError(f"max_num_items must be a positive integer, got {max_num_items!r}")

    logger.info("Opening input file %s", input_file)
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    logger.info("Processing %d prompts", len(data))
    output_list = []
    for item in data:
        output_list.extend(custom_function(item, prompt_template, input_file, **kwargs))

    if not split:
        # Only one file, no need to split (and no stats file is written).
        _write_json(output_file, output_list)
        logger.info("Output saved to %s", output_file)
        return

    # Ceiling division; yields 0 parts when there is no output at all.
    num_parts = (len(output_list) - 1) // max_num_items + 1
    # Don't ask split_data for zero parts when there is nothing to write.
    split_outputs = split_data(output_list, num_parts) if output_list else []

    # Save prompts
    for i, part in enumerate(split_outputs):
        part_output_file = _derived_path(output_file, f'_part_{i}')
        _write_json(part_output_file, part)
        logger.info("Part %d saved to %s", i, part_output_file)

    # Save statistics (directory is created here too, in case no part was written)
    stats_path = _derived_path(output_file, '_stats')
    _write_json(stats_path, {
        "n_prompts_total": len(data),
        "n_output_items": len(output_list),
        "n_output_parts": num_parts
    }, indent=2)
    logger.info("Statistics saved to %s", stats_path)
