from src.xlogomini.utils.load_data import load_task_json, load_code_json
from src.xlogomini.utils.enums import XLOGO_TASK_IDS
import json
from src.xlogominiprog.translator import taskjs2ascii, taskjs2nl, codejs2python
from src.xlogominiprog.prompts.prompt_template import PROMPT_TEMPLATE_NL_BASE_MODEL, \
    PROMPT_TEMPLATE_NL_BASE_MODEL_VISION, \
    PROMPT_TEMPLATE_ASCII_BASE_MODEL
from src.xlogomini.utils.image_conversions import task2image


def build_task_prompt(dataset_path, template='nl'):
    """
    Build one prompt per sample of the JSON dataset stored at `dataset_path`.

    The file is parsed as JSON and iterated; every sample must provide the keys
    `task_json`, `code_json` and `constraints`, and every `task_json` must
    provide a `description`.

    `template` selects the prompt template and the task rendering:
        - `ascii`: ASCII template, task rendered with `taskjs2ascii`
        - `nl`: natural-language template, task rendered with `taskjs2nl`
        - `nl-vision`: vision template, task rendered with `taskjs2nl`; a
          base64 PNG data URL is additionally produced for every task

    Returns a dictionary with keys:
        - `prompts`
        - `task_jsons`
        - `code_jsons`
        - `cons_jsons`
        - `image_urls` (only if template == 'nl-vision')

    Raises `ValueError` if `template` is not one of the supported names.
    """
    if template == 'ascii':
        prompt_template = PROMPT_TEMPLATE_ASCII_BASE_MODEL
    elif template == 'nl':
        prompt_template = PROMPT_TEMPLATE_NL_BASE_MODEL
    elif template == 'nl-vision':
        prompt_template = PROMPT_TEMPLATE_NL_BASE_MODEL_VISION
    else:
        raise ValueError(f"Template {template} not supported")

    # Use a context manager so the file handle is released even when the
    # JSON payload is malformed and parsing raises.
    with open(dataset_path, 'r') as dataset_file:
        dataset = json.load(dataset_file)

    task_jsons, code_jsons, cons_jsons = [], [], []
    for sample in dataset:
        task_jsons.append(sample['task_json'])
        code_jsons.append(sample['code_json'])
        cons_jsons.append(sample['constraints'])

    prompts = [
        prompt_template.format_map({
            "description": task_json['description'],
            "task"       : taskjs2ascii(task_json) if template == 'ascii' else taskjs2nl(task_json),
            "code"       : codejs2python(code_json),
        })
        for task_json, code_json in zip(task_jsons, code_jsons)
    ]

    result = {
        "prompts"   : prompts,
        "task_jsons": task_jsons,
        "code_jsons": code_jsons,
        "cons_jsons": cons_jsons,
    }

    if template == 'nl-vision':
        # Each task image is embedded as a data URL next to its text prompt.
        result["image_urls"] = [
            f"data:image/png;base64,{task2image(task, show=False, save=False, show_desc=False, return_base64=True)}"
            for task in task_jsons
        ]

    return result


def build_task_prompt_python():
    """
    Placeholder: not implemented yet.

    Takes no arguments, performs no work and returns `None`.
    """
    return None


def convert_to_chat_format(data, vision=False):
    """
    Wrap every prompt of `data` in a single-message chat conversation.

    `data` is a dictionary; only these keys are read here:
        - `prompts`
        - `image_urls` (only when vision=True, one entry per prompt)

    Returns a list with one conversation per prompt. Each conversation is a
    list holding exactly one `user` message. Without vision the message content
    is the prompt string itself; with vision it is a two-part list made of a
    `text` part (the prompt) followed by an `image_url` part.
    """
    texts = data['prompts']

    if not vision:
        return [[{"role": "user", "content": text}] for text in texts]

    # `image_urls` is looked up per item, so an empty prompt list never
    # touches that key.
    return [
        [{
            "role"   : "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image_url", "image_url": {"url": data['image_urls'][idx]}},
            ],
        }]
        for idx, text in enumerate(texts)
    ]


if __name__ == "__main__":
    # Script entry point: build the prompts for one dataset/template pair,
    # convert them to chat format and dump everything to a JSON file.

    # template = 'nl'
    template = 'nl-vision'

    # NOTE(review): this value is handed to build_task_prompt as `dataset_path`
    # and opened as a file there - confirm it resolves to an actual JSON file.
    dataset = 'v0-test-85'
    # dataset = 'v1-test-1k'

    # build_task_prompt raises ValueError(f"Template {template} not supported")
    # for unknown templates, so no separate dispatch is needed here.
    data = build_task_prompt(dataset, template=template)

    # Only the vision template yields `image_urls`, which the vision chat
    # format needs.
    data['prompts'] = convert_to_chat_format(data, vision=(template == 'nl-vision'))

    save_file = f"./chat_prompts_{template}_{dataset}.json"
    # Context manager guarantees the output is flushed and the handle closed.
    with open(save_file, "w") as out_file:
        json.dump(data, out_file, indent=4)
    print(f"Chat prompts saved to {save_file}")
