# This script generates lines that could plausibly lead up to either of two specific words (a nice testbed).
# The alternative would be to give up on that and steer from examples that strongly suggest just one of the words.
# One option would be to filter for lines where Claude really is split 50/50 between the two words.
# Another option is to simply ask Claude to generate such lines directly. This is much easier.

# %%
import json
import os
import sys

script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Now use absolute import instead of relative import
from shared_utils import (
    WORD_PAIRS_DIFFERENT_RHYME_FAMILY,
    WORD_PAIRS_SAME_RHYME_FAMILY,
)

from utils.llm_utils import generate_rollout


def get_list(rollout):
    """Extract the first line of each paragraph of a model rollout.

    The rollout is split into paragraphs on blank lines. The first paragraph
    is discarded, and from every remaining paragraph only the first line is
    kept, terminated with a newline character. A leading list number of the
    form "<digits>. " is stripped from the kept line.

    Args:
        rollout: Text to parse (presumably the raw model response -- the
            caller passes the result of ``generate_rollout``).

    Returns:
        A list of strings, one per paragraph after the first, each ending
        with a newline.
    """
    paragraphs = rollout.split("\n\n")[1:]
    lines = []
    for paragraph in paragraphs:
        first_line = paragraph.split("\n")[0] + "\n"
        # Split only on the FIRST ". " so that sentences containing ". "
        # further along are kept whole, and only strip when what precedes it
        # is purely a number (so "3 blind mice" is left untouched instead of
        # raising IndexError).
        prefix, sep, rest = first_line.partition(". ")
        if sep and prefix.isdigit():
            first_line = rest
        lines.append(first_line)
    return lines


# %%

# All word pairs to build a dataset for: the same-rhyme-family pairs come
# first, followed by the different-rhyme-family pairs.
word_pairs = WORD_PAIRS_SAME_RHYME_FAMILY + WORD_PAIRS_DIFFERENT_RHYME_FAMILY

model_id = "anthropic/claude-3-7-sonnet"
generation_prompt_template = """Please generate exactly 15 diverse two-line{rhyming} couplets, where for the last word, both {word1} and {word2} are possible.
"""

# One config per word pair, keyed "word1-word2": the cache file path
# (relative to this script's directory) and the fully formatted prompt.
dataset_cfgs = {}
for i, (word1, word2) in enumerate(word_pairs):
    # The first 10 pairs are prompted for rhyming couplets, the rest for
    # unrhymed ones.
    # NOTE(review): the cutoff 10 presumably equals
    # len(WORD_PAIRS_SAME_RHYME_FAMILY) -- confirm against shared_utils.
    if i < 10:
        rhyming = " rhyming"
    else:
        rhyming = " unrhymed"
    dataset_cfgs[f"{word1}-{word2}"] = {
        "file_path": f"../data/temp/{word1}-{word2}.json",
        "generation_prompt": generation_prompt_template.format(
            word1=word1, word2=word2, rhyming=rhyming
        ),
    }

# Load each dataset from its cache file if present; otherwise generate it
# with the model, parse it, and write the cache file.
datasets = {}
for dataset_name, cfg in dataset_cfgs.items():
    file_path = os.path.join(script_dir, cfg["file_path"])
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            dataset = json.load(f)
    else:
        # Make sure the cache directory exists BEFORE calling the model, so a
        # missing directory cannot make the write fail after the rollout has
        # already been generated.
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        print(f"Generating {dataset_name} data")
        rollout = generate_rollout([cfg["generation_prompt"]], model_id)[0]
        # Preview the start of the response for a quick visual check.
        print(rollout[:300])
        dataset = get_list(rollout)
        with open(file_path, "w") as f:
            json.dump(dataset, f)
    datasets[dataset_name] = dataset

# %%
# Sanity check: print the number of entries in every dataset.
for dataset in datasets.values():
    print(len(dataset))

# %%
# Write all datasets into one JSON file, keyed by dataset name.
out_dir = os.path.join(script_dir, "..", "data", "test")
# Create the output directory if needed so the write below cannot fail with
# FileNotFoundError on a fresh checkout.
os.makedirs(out_dir, exist_ok=True)
path = os.path.join(out_dir, "specific_word_pairs.json")
with open(path, "w") as f:
    json.dump(datasets, f)
