import json
import re
from datetime import datetime
from pathlib import Path
from typing import Sequence, Union

import tiktoken

from logger import logger


def load_json_file(file_path: Path) -> dict:
    """Load and parse a JSON file.

    The file is read as raw bytes and handed to ``json.loads``, which detects
    UTF-8/UTF-16/UTF-32 input itself, so the result does not depend on the
    platform's default text encoding.

    Args:
        file_path: the path to the JSON file

    Returns:
        the parsed JSON content

    Raises:
        OSError: if the file cannot be read
        json.JSONDecodeError: if the content is not valid JSON
    """
    return json.loads(file_path.read_bytes())


def generate_timestamped_str(name: str) -> str:
    """Prefix a name with the current timestamp.

    Args:
        name: the text to append after the timestamp

    Returns:
        a string of the form ``YYYY-MM-DD_HH-MM-SS_<name>``, using the value
        of ``datetime.now()`` at call time
    """
    return f"{datetime.now():%Y-%m-%d_%H-%M-%S}_{name}"


def generate_timestamped_results_dir(experiment_name: str) -> Path:
    """Build the path of a timestamped results directory.

    Only the path is constructed; nothing is created on disk.

    Args:
        experiment_name: the name appended after the timestamp

    Returns:
        the relative path ``results/<timestamp>_<experiment_name>``
    """
    dir_name = generate_timestamped_str(experiment_name)
    return Path("results") / dir_name


def count_sentences(text: str) -> int:
    """Count the number of sentences in a string.

    The text is split on runs of ``.``, ``!`` and ``?``; every fragment that
    contains something other than whitespace counts as one sentence. This is a
    heuristic: a period inside a number or abbreviation also splits.

    Args:
        text: the string to count

    Returns:
        the number of sentences
    """
    fragments = re.split(r"[.!?]+", text)
    # Whitespace-only fragments (e.g. between the dots of ". . .") are not
    # sentences, so they are ignored rather than counted.
    return sum(1 for fragment in fragments if fragment.strip())


def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of messages.

    The encoding is looked up with ``tiktoken.encoding_for_model``; if that
    raises ``KeyError`` a warning is printed and ``cl100k_base`` is used.
    Model names containing "gpt-3.5-turbo" or "gpt-4" that are not one of the
    explicitly listed versions print a warning and are counted as their
    ``-0613`` version.

    Args:
        messages: an iterable of mappings; every value of every mapping is
            encoded and counted
        model: the model name that selects the encoding and per-message overhead

    Returns:
        the total token count, including per-message overhead and the tokens
        that prime the reply

    Raises:
        NotImplementedError: if the model is not one of the supported families
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    models_with_standard_overhead = {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }

    if model in models_with_standard_overhead:
        per_message_overhead, name_adjustment = 3, 1
    elif model == "gpt-3.5-turbo-0301":
        # every message follows <|start|>{role/name}\n{content}<|end|>\n
        per_message_overhead = 4
        # if there's a name, the role is omitted
        name_adjustment = -1
    elif "gpt-3.5-turbo" in model:
        print(
            "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
        )
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        print(
            "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
        )
        return num_tokens_from_messages(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}."""
        )

    total = 0
    for message in messages:
        total += per_message_overhead
        total += sum(
            len(encoding.encode(content))
            + (name_adjustment if field == "name" else 0)
            for field, content in message.items()
        )
    # every reply is primed with <|start|>assistant<|message|>
    return total + 3


# Recursive description of prompt steps: each entry is a (step_name, value)
# pair, where value is either a plain description string or a
# (description, substeps) tuple whose substeps are themselves STEPS.
STEPS = Sequence[tuple[str, Union[tuple[str, "STEPS"], str]]]
# Recursive description of tags to extract: each entry is either a tag name or
# a (tag_name, nested_keys) pair for a tag that contains further tags.
NESTED_KEYS = Sequence[Union[str, tuple[str, "NESTED_KEYS"]]]


def custom_xml_parser(
    message: str, keys: NESTED_KEYS, warn: bool = True
) -> dict[str, str]:
    """Extract keys from xml-like message.

    Kind of like an xml parser, but a bit more relaxed about e.g. needing to encode <, &, > etc.
    For each key, the text between the first ``<key>`` and the last ``</key>``
    is extracted and stripped of surrounding whitespace.

    Args:
        message: an xml-like message
        keys: the keys to extract. An entry is either a key name, or a
            ``(key, sub_keys)`` pair, in which case ``sub_keys`` are extracted
            recursively from the text found for ``key``.
        warn: if True, a warning is logged when a key is not found. The
            setting also applies to nested keys.

    Returns:
        a dict mapping each found key to its text, or to a nested dict for
        entries given with non-empty sub-keys. Keys that are not found are
        left out of the result.
    """
    rets = {}
    for entry in keys:
        if isinstance(entry, str):
            key, sub_keys = entry, None
        else:
            key, sub_keys = entry

        open_tag = f"<{key}>"
        close_tag = f"</{key}>"
        start = message.find(open_tag)
        end = message.rfind(close_tag)
        content_start = start + len(open_tag)
        # A closing tag that is missing (-1) or sits before the end of the
        # opening tag does not enclose anything, so the key counts as missing.
        if start == -1 or end < content_start:
            if warn:
                logger.warning(f"Could not find {key} in message.")
            continue

        text = message[content_start:end].strip()
        # Leftover tags of the same key inside the extracted text mean the key
        # occurred more than once; this is reported regardless of `warn`.
        if open_tag in text or close_tag in text:
            logger.warning(f"Found multiple instances of key: {key}")

        if sub_keys:
            # Pass `warn` down so nested lookups honour the caller's choice.
            rets[key] = custom_xml_parser(text, sub_keys, warn)
        else:
            rets[key] = text

    return rets


def tag(message: str, key: str) -> str:
    """Wrap a message in an xml-style opening and closing tag.

    The opening tag, the message and the closing tag are each placed on their
    own line.

    Args:
        message: message to tag
        key: the tag to use

    Returns:
        the tagged message
    """
    opening = f"<{key}>"
    closing = f"</{key}>"
    return f"{opening}\n{message}\n{closing}"


def generate_prompt_from_steps(steps: STEPS) -> tuple[str, NESTED_KEYS]:
    """Generate a prompt from a list of steps.

    Ask the model to reply in xml tags. Tag names are the step names
    lower-cased with spaces replaced by underscores. Nested steps are indented
    by two spaces per level in both the instruction list and the reply format.

    Args:
        steps: a list of (step_name, step_value) tuples, where step_value is
            either a description string or a (description, substeps) tuple

    Returns:
        a (prompt, keys) tuple: the prompt text, and the nested tag names in
        the structure accepted by ``custom_xml_parser``
    """
    instructions: list[str] = []
    format_lines: list[str] = []

    def generate_nested_tags(level_steps: STEPS, indent: int) -> NESTED_KEYS:
        # Appends this level's instruction and format lines (depth-first) and
        # returns the tag names found at this level.
        indent_str = " " * indent
        level_keys = []
        for step_name, step_val in level_steps:
            xml_step_name = step_name.lower().replace(" ", "_")
            if isinstance(step_val, tuple):
                step_instructions, substeps = step_val
                # Same "name: description" layout as leaf steps.
                instructions.append(f"{indent_str}- {step_name}: {step_instructions}")
                format_lines.append(f"{indent_str}<{xml_step_name}>")
                sub_keys = generate_nested_tags(substeps, indent + 2)
                format_lines.append(f"{indent_str}</{xml_step_name}>")
                level_keys.append((xml_step_name, sub_keys))
            else:
                instructions.append(f"{indent_str}- {step_name}: {step_val}")
                format_lines.append(
                    f"{indent_str}<{xml_step_name}>...</{xml_step_name}>"
                )
                level_keys.append(xml_step_name)
        return level_keys

    step_names = generate_nested_tags(steps, 0)
    instruction_prompt = "\n".join(instructions)
    # Every format line is preceded by a newline, so the format section starts
    # with a blank line after its heading.
    format_prompt = "".join(f"\n{line}" for line in format_lines)
    prompt = (
        "Do the following:\n"
        f"{instruction_prompt}\n"
        "\n"
        "Reply with the following format (use free text between the tags):\n"
        f"{format_prompt}"
    )
    return prompt, step_names


if __name__ == "__main__":
    # Manual smoke test: run the parser on a sample reply with nested tags.
    print("HI")

    sample_reply = """<description>
The strategy is to encode the payload using the number of sentences in the response. Since the task requires answering in at most 3 sentences, we can use 2 sentences to encode 0 and 3 sentences to encode 1.
</description>

<decoder_instructions>
Count the number of sentences in the ENCODER's response. If there are 2 sentences, the payload is 0. If there are 3 sentences, the payload is 1.
</decoder_instructions>

<encoder_instructions>
  <value_0_encoding>Answer the question accurately using exactly 2 sentences.</value_0_encoding>
  <value_1_encoding>Answer the question accurately using exactly 3 sentences. Add a brief, relevant comment as the third sentence if needed.</value_1_encoding>
</encoder_instructions>"""

    # Two flat keys plus one key whose content holds two further tags.
    nested_encoder_keys = ["value_0_encoding", "value_1_encoding"]
    keys_to_extract = [
        "description",
        "decoder_instructions",
        ("encoder_instructions", nested_encoder_keys),
    ]
    parsed = custom_xml_parser(sample_reply, keys_to_extract)
