import json
import re
from copy import deepcopy   


def reconstruct_messages(template, query):
    """Fill the ``[INSERT_TEXT]`` placeholder in every message of a chat template.

    Args:
        template: List of message dicts, each with a ``'content'`` string
            (e.g. as returned by ``read_prompt_template``). It is not modified.
        query: Text substituted verbatim for every ``[INSERT_TEXT]`` occurrence.

    Returns:
        A deep copy of ``template`` with the placeholder replaced in each message.
    """
    filled_prompt = deepcopy(template)
    for prompt_dict in filled_prompt:
        # Literal replacement on purpose: as the replacement argument of
        # ``re.sub``, backslashes and group references in ``query`` (``\d``,
        # ``\1``, ``\n`` ...) would be interpreted as escapes, corrupting the
        # text or raising ``re.error`` for inputs such as paths or LaTeX.
        prompt_dict['content'] = prompt_dict['content'].replace(
            '[INSERT_TEXT]', query
        )
    return filled_prompt
        

def read_prompt_template(template_path, dataset=None):
    """Load a prompt template from a JSON file and return it as chat messages.

    Two on-disk formats are accepted:

    * Old (dict) format with keys ``"instruction"`` and ``"icl_samples_list"``
      (a list of ``{"user": ..., "assistant": ...}`` pairs) and optionally
      ``"version"``. It is expanded into a system message, alternating
      user/assistant ICL turns, and a final user turn holding the
      ``[INSERT_TEXT]`` placeholder.
    * New (list) format: an already-built list of messages, returned as is.

    Args:
        template_path: Path to the JSON template file (decoded as UTF-8).
        dataset: Optional dataset name. For the dict format, that dataset's
            extra ICL samples are appended via ``add_dataset_specific_samples``.

    Returns:
        List of ``{"role": ..., "content": ...}`` message dicts.

    Raises:
        ValueError: If the top-level JSON value is neither a dict nor a list.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(template_path, "r", encoding="utf-8") as f:
        template = json.load(f)

    if isinstance(template, dict):
        # Old version of the template: build the message list here.
        if dataset:
            template = add_dataset_specific_samples(template, dataset)
        # "version" is only used for this log line; a missing key must not abort loading.
        print("Loading prompt template for version", template.get('version'))

        # System message carries the task instruction.
        messages = [{"role": "system", "content": template["instruction"]}]

        # ICL samples become alternating user/assistant turns.
        for sample in template["icl_samples_list"]:
            messages.append({"role": "user", "content": sample["user"]})
            messages.append({"role": "assistant", "content": sample["assistant"]})

        # Final user turn is a placeholder filled later by the caller.
        messages.append({"role": "user", "content": "[INSERT_TEXT]"})
        return messages

    if isinstance(template, list):
        # New version of the template: messages are already constructed.
        return template

    raise ValueError("Invalid template format.")


def add_dataset_specific_samples(template, dataset):
    """Append a dataset's extra ICL samples to an old-format (dict) template.

    Besides its structural keys, the template may hold lists of extra samples
    keyed by lower-cased dataset name. When such a list exists for ``dataset``,
    it is appended to ``template["icl_samples_list"]``; otherwise a message is
    printed and the template is left unchanged.

    Note: ``template`` is modified in place and also returned.

    Args:
        template: Old-format template dict containing ``"icl_samples_list"``.
        dataset: Dataset name, matched case-insensitively against the keys.

    Returns:
        The same ``template`` object.
    """
    # TODO: consider rejecting unsupported datasets (previously sketched as a
    # c4/wikipedia allow-list) instead of only printing a message.
    key = dataset.lower()
    # Structural keys are not per-dataset sample lists. Matching them would
    # splice the characters of the instruction string, or a copy of the sample
    # list itself, into "icl_samples_list" (or fail on a non-iterable version).
    reserved_keys = ("version", "instruction", "icl_samples_list")
    if key in template and key not in reserved_keys:
        template['icl_samples_list'].extend(template[key])
    else:
        print(f"Dataset {dataset} is not in the template.")
    return template
