import json
import os

from utils.data import a_date_or_a_pear_questions

from paper_experiments.shared_utils import *

# A prompt is kept when its mean get_last_word_correct score reaches this
# value (see the `>=` filter in main()).
THRESHOLD: float = 0.5
# Target number of generations per rollout; only used to derive
# NUM_PROMPTS_PER_ROLLOUT below.
BATCH_SIZE: int = 100
# Number of cleaned completions attributed to each prompt within a rollout
# (the rollout output is sliced in runs of this length).
BATCH_SIZE_PER_PROMPT: int = 20
# How many prompts are packed into a single rollout.
NUM_PROMPTS_PER_ROLLOUT: int = BATCH_SIZE // BATCH_SIZE_PER_PROMPT


def get_suggestivenesses(model, tokenizer, prompts, words):
    """Score how strongly each prompt elicits each target word.

    Prompts are processed in chunks of ``NUM_PROMPTS_PER_ROLLOUT``. For each
    chunk, ``generate_steered_output`` produces completions. These are
    cleaned with ``get_cleaned_up_text`` and split into consecutive slices of
    ``BATCH_SIZE_PER_PROMPT`` texts, one slice per prompt in the chunk. Each
    score is printed together with its word and prompt as it is computed.

    Args:
        model: Model forwarded to ``generate_steered_output``.
        tokenizer: Tokenizer forwarded to ``generate_steered_output``.
        prompts: Sequence of prompt strings to score.
        words: List of target words, passed as-is to
            ``get_last_word_correct``.

    Returns:
        Nested dict ``{word: {prompt: score}}``. ``score`` is
        ``get_last_word_correct(texts, words)[j].mean()`` for the word at
        index ``j``; presumably this is the fraction of completions ending
        in that word (confirm against ``get_last_word_correct``). If a
        prompt appears more than once, the later score overwrites the
        earlier one.
    """
    suggestivenesses = {word: {} for word in words}
    # Stop strictly before len(prompts) so no empty trailing chunk is sent
    # to the model when len(prompts) is not a multiple of the chunk size.
    for start in range(0, len(prompts), NUM_PROMPTS_PER_ROLLOUT):
        # Slicing clamps at the end of the list, so the last chunk may be short.
        rollout_prompts = prompts[start : start + NUM_PROMPTS_PER_ROLLOUT]
        rollout = generate_steered_output(
            None,
            model,
            tokenizer,
            rollout_prompts,
            BATCH_SIZE_PER_PROMPT,
            # NOTE(review): the final chunk can contain fewer prompts than
            # this value; confirm generate_steered_output tolerates that.
            NUM_PROMPTS_PER_ROLLOUT,
        )
        rollout_clean = [get_cleaned_up_text(text) for text in rollout]
        for prompt_idx, prompt in enumerate(rollout_prompts):
            # Assumes completions are laid out contiguously per prompt, in
            # the same order as rollout_prompts.
            offset = prompt_idx * BATCH_SIZE_PER_PROMPT
            texts_for_prompt = rollout_clean[offset : offset + BATCH_SIZE_PER_PROMPT]
            correct_list = get_last_word_correct(texts_for_prompt, words)
            for j, word in enumerate(words):
                suggestivenesses[word][prompt] = correct_list[j].mean()
                print(suggestivenesses[word][prompt], word, prompt)
    return suggestivenesses


def main(model_name, model, tokenizer, word):
    """Filter the candidate lines for ``word`` down to the suggestive ones.

    Reads ``data/train/specific_word_lines.json``, resolved relative to the
    parent of this script's directory. Each line for ``word`` is scored with
    ``get_suggestivenesses`` using the prompt ``"A rhyming couplet:\\n{line}"``.
    Lines whose score is at least ``THRESHOLD`` are written as a JSON list to
    ``data/train/<model_name>/suggestive_<word>_lines.json``. If that output
    file already exists, the function prints a message and returns without
    doing any work.

    Args:
        model_name: Used for logging and as the per-model output subdirectory.
        model: Model forwarded to ``get_suggestivenesses``.
        tokenizer: Tokenizer forwarded to ``get_suggestivenesses``.
        word: Target word; also the key looked up in the input JSON.

    Raises:
        KeyError: If the loaded JSON mapping has no entry for ``word``.
    """
    print(f"Processing {word}...")
    print(f"Model: {model_name}")

    script_dir = os.path.dirname(os.path.abspath(__file__))
    train_dir = os.path.join(os.path.dirname(script_dir), "data", "train")
    output_dir = os.path.join(train_dir, model_name)
    output_path = os.path.join(output_dir, f"suggestive_{word}_lines.json")

    if os.path.exists(output_path):
        print(f"Skipping {word} because it already exists")
        return

    # Use a context manager so the input handle is closed deterministically.
    input_path = os.path.join(train_dir, "specific_word_lines.json")
    with open(input_path, encoding="utf-8") as input_file:
        specific_word_lines = json.load(input_file)[word]

    specific_word_prompts = [
        f"A rhyming couplet:\n{line}" for line in specific_word_lines
    ]
    suggestivenesses = get_suggestivenesses(
        model, tokenizer, specific_word_prompts, [word]
    )[word]

    suggestive_lines = [
        # The prompt prefix contains no newline, so field 1 is the (first
        # segment of the) original line; it is re-terminated with "\n".
        prompt.split("\n")[1] + "\n"
        for prompt, suggestiveness in suggestivenesses.items()
        if suggestiveness >= THRESHOLD
    ]

    os.makedirs(output_dir, exist_ok=True)

    # Closing the file explicitly guarantees the JSON is flushed to disk,
    # rather than depending on when the handle is garbage-collected.
    with open(output_path, "w", encoding="utf-8") as output_file:
        json.dump(suggestive_lines, output_file)


if __name__ == "__main__":
    args = get_common_args().parse_args()

    # The model is loaded a single time and shared across every word.
    print("Loading model...")
    model, tokenizer = get_model(args.model_name)

    try:
        # Every word of every same-rhyme-family pair, in pair order.
        words_to_process = (
            word for word_pair in WORD_PAIRS_SAME_RHYME_FAMILY for word in word_pair
        )
        for word in words_to_process:
            main(args.model_name, model, tokenizer, word)
    finally:
        # Drop the model references and free GPU memory even if a word fails.
        del model, tokenizer
        cleanup_gpu_memory()
