import json
import random
from copy import deepcopy
from typing import List, Tuple

import jsonlines
from absl import flags

# Command-line flags configuring the benchmark prompt generators in this file.

# Total number of filler characters; passed as `n_garbage` to
# generate_prompt_landmark by prompt_landmark_formatted.
flags.DEFINE_integer(
    "landmark_n_garbage_chars", 0, "num of noise chars in landmark task"
)
# NOTE(review): not read anywhere in the visible part of this file.
flags.DEFINE_integer("lrt_n_lines", 2, "num of lines in the lrt benchmark")
# JSON-lines file loaded by generate_prompt_longeval_lrt.
flags.DEFINE_string("lrt_test_file", "200_lines.jsonl", "path to the lrt test file")
# Input file opened with xopen by generate_prompt_stanford, one JSON object per line.
flags.DEFINE_string(
    "lim_test_file", "kv-retrieval-75_keys.jsonl.gz", "path to the lim test file"
)
# When set, each example's key-value records are truncated to this many
# entries; it is also the upper bound for the random gold index (see below).
flags.DEFINE_integer("lim_n_lines", None, "num of lines for lost in the middle")
# Position the gold key-value pair is moved to. None keeps the position stored
# in the file; -1 makes generate_prompt_stanford draw a random position.
flags.DEFINE_integer("lim_gold_index", None, "index of the golden passage")
# NOTE(review): not read anywhere in the visible part of this file.
flags.DEFINE_boolean("lim_query_aware", False, "ask about the key at the beginning")
# Template file read by get_kv_retrieval_prompt and filled in via str.format
# with `formatted_kv_records` and `key`.
flags.DEFINE_string("lim_prompt_path", "kv_retrieval.prompt", "path to the lim prompt")

# Shared handle used by the functions below to read the parsed flag values.
FLAGS = flags.FLAGS


def generate_prompt_landmark(n_garbage):
    """Build a pass-key retrieval prompt padded with filler text.

    The prompt consists of five newline-separated parts: a task description,
    a filler prefix, the line stating the pass key, a filler suffix and the
    final question. The prefix and suffix lengths are chosen at random and
    always add up to exactly ``n_garbage`` characters.

    Args:
        n_garbage: Total number of filler characters (prefix + suffix).
            Must be a non-negative integer; ``random.randint`` raises
            ``ValueError`` for negative values.

    Returns:
        A ``(prompt, pass_key)`` tuple where ``pass_key`` is the random
        integer in ``[1, 50000]`` hidden inside the prompt.
    """
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix

    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    # Repeat the filler sentence just often enough to cover `n_garbage`
    # characters. This replaces a fixed 2000 repetitions guarded by an
    # `assert`, which capped `n_garbage` and disappeared under `python -O`.
    # Any slice taken below is identical to the one the fixed-size filler
    # produced, because both are prefixes of the same repeated text.
    n_repeats = n_garbage // len(garbage) + 1
    garbage_inf = " ".join([garbage] * n_repeats)
    # Both paddings are taken from the start of the filler text.
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = (
        f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    )
    final_question = "What is the pass key? The pass key is"
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    return "\n".join(lines), pass_key


def prompt_landmark_formatted():
    """Return a landmark prompt and its answer formatted as a continuation.

    The amount of filler text comes from ``FLAGS.landmark_n_garbage_chars``.
    The pass key is returned as a string with a single leading space.
    """
    text, pass_key = generate_prompt_landmark(
        n_garbage=FLAGS.landmark_n_garbage_chars
    )
    return text, f" {pass_key}"


def generate_prompt_longeval_lrt(idx=[0]):
    """Return the (prompt, answer) pair for the next line-retrieval record.

    All records are read from the JSON-lines file ``FLAGS.lrt_test_file`` on
    every call.

    Args:
        idx: List whose last element is the position of the record to use.
            The mutable default is deliberate: it keeps the position between
            calls, so successive calls walk through the file.

    Returns:
        Whatever ``get_prompt_and_answer`` produces for the selected record.
    """
    with jsonlines.open(FLAGS.lrt_test_file) as reader:
        records = list(reader)

    # Advance the stored position before the lookup, so the counter moves on
    # even when the lookup below fails.
    current = idx[-1]
    idx[-1] = current + 1
    return get_prompt_and_answer(records[current])


def get_prompt_and_answer(data_dict):
    """Split one line-retrieval record into a prompt and its expected answer.

    Args:
        data_dict: Mapping with a ``"correct_line"`` entry (the line to be
            retrieved, containing a ``<``-prefixed value) and a ``"prompt"``
            entry (the full newline-separated prompt text).

    Returns:
        A ``(prompt, answer)`` tuple. ``prompt`` is the stored prompt without
        its last two lines, followed by the part of the correct line before
        the first ``<`` and a trailing space. ``answer`` is ``<`` plus the
        stripped text between the first and second ``<``.
    """
    target_line = data_dict["correct_line"]

    # NOTE: dropping the final line and the boilerplate before it is
    # essential; the original author notes performance is very bad otherwise.
    context = "\n".join(data_dict["prompt"].split("\n")[:-2])

    pieces = target_line.split("<")
    line_stub = pieces[0].strip()
    expected = "<" + pieces[1].strip()

    return f"{context}\n{line_stub} ", expected


def get_kv_retrieval_prompt(
    data: List[Tuple[str, str]],
    key: str,
):
    """Render the key-value retrieval prompt for `data` and the queried `key`.

    The template is read from ``FLAGS.lim_prompt_path`` (trailing newlines
    removed) and filled in with ``formatted_kv_records`` and ``key``.

    Args:
        data: Key-value records; must hold at least two entries with unique
            keys, one of which is `key`.
        key: The key the prompt asks about.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If `data` or `key` is empty, `key` is not among the
            records, keys are duplicated, or fewer than two records are given.
    """
    if not data:
        raise ValueError(f"Provided `data` must be truthy, got: {data}")
    if not key:
        raise ValueError(f"Provided `key` must be truthy, got: {key}")

    record_keys = [record[0] for record in data]
    if key not in record_keys:
        raise ValueError(f"Did not find provided `key` {key} in data {data}")
    if len(data) != len(set(record_keys)):
        raise ValueError(f"`data` has duplicate keys: {data}")
    if len(data) < 2:
        raise ValueError(f"Must have at least 2 items in data: {data}")

    with open(FLAGS.lim_prompt_path) as template_file:
        template = template_file.read().rstrip("\n")

    # Lay the records out like a JSON object: one pair per line, with every
    # line after the first indented by a single space.
    rendered_pairs = [f'"{record[0]}": "{record[1]}"' for record in data]
    kv_block = "{" + ",\n ".join(rendered_pairs) + "}"

    return template.format(formatted_kv_records=kv_block, key=key)


def generate_prompt_stanford(gold_index_kv_examples=[], example_idx=[0]):
    """Return the next key-value retrieval (prompt, answer) pair.

    While the example cache is empty, every JSON line of
    ``FLAGS.lim_test_file`` is read and converted into an example; later
    calls only step through the cached examples, one per call.

    Placement of the gold pair is controlled by ``FLAGS.lim_gold_index``:

    * ``None``: every example keeps the gold pair where the file has it.
    * ``-1``: one position is drawn at random from
      ``[0, FLAGS.lim_n_lines)`` and used for all examples of this load.
    * any other value: the gold pair is moved to that index.

    When ``FLAGS.lim_n_lines`` is set, the records of each example are
    truncated to that many entries. Examples whose gold pair does not fall
    inside the (possibly truncated) records are dropped.

    Args:
        gold_index_kv_examples: Cache of prepared examples. The mutable
            default is intentional: it persists between calls so the file is
            only parsed once.
        example_idx: List whose last element is the index of the example to
            return; it is advanced by one on success. The mutable default is
            intentional for the same reason.

    Returns:
        A ``(prompt, answer)`` tuple; ``answer`` is the gold value rendered
        as `` "<value>",``.

    Raises:
        ValueError: If ``FLAGS.lim_gold_index`` is ``-1`` while
            ``FLAGS.lim_n_lines`` is unset.
        IndexError: If no example exists at the current index.
    """
    from xopen import xopen

    max_lines = FLAGS.lim_n_lines
    input_path = FLAGS.lim_test_file
    requested_gold_index = FLAGS.lim_gold_index

    if not gold_index_kv_examples:
        with xopen(input_path) as fin:
            for line in fin:
                input_example = json.loads(line)
                # TODO: remove
                if requested_gold_index == -1:
                    if max_lines is None:
                        raise ValueError(
                            "lim_gold_index=-1 (random position) requires "
                            "lim_n_lines to be set"
                        )
                    # Drawn once; the same position is reused for every
                    # following example of this load.
                    requested_gold_index = random.randrange(0, max_lines)

                ordered_kv_records = deepcopy(input_example["ordered_kv_records"])
                key = input_example["key"]
                value = input_example["value"]
                original_kv_index = ordered_kv_records.index([key, value])
                if requested_gold_index is None:
                    # Keep the gold pair in place. The position is tracked
                    # per example, so one example's position can no longer
                    # leak into the examples that follow it.
                    gold_index = original_kv_index
                else:
                    gold_index = requested_gold_index
                    # Move the gold pair from its original index to the
                    # requested one.
                    original_kv = ordered_kv_records.pop(original_kv_index)
                    ordered_kv_records.insert(gold_index, original_kv)

                if max_lines is not None:
                    ordered_kv_records = ordered_kv_records[:max_lines]

                # Skip examples whose gold pair is not among the kept records.
                if gold_index < len(ordered_kv_records):
                    gold_index_kv_examples.append(
                        {
                            "key": key,
                            "value": value,
                            "ordered_kv_records": ordered_kv_records,
                            "gold_index": gold_index,
                        }
                    )

    index = example_idx[-1]
    try:
        example = gold_index_kv_examples[index]
    except IndexError as err:
        raise IndexError(
            f"No key-value example at index {index}: only "
            f"{len(gold_index_kv_examples)} usable examples were loaded "
            f"from {input_path}"
        ) from err
    # Got the example; advance the stored index for the next call.
    example_idx[-1] = index + 1

    prompt = get_kv_retrieval_prompt(
        data=example["ordered_kv_records"], key=example["key"]
    )
    answer = example["value"]
    answer = f' "{answer}",'
    return prompt, answer
