import argparse

# Go over each word.
# Go over each line (prompt).
# Calculate the suggestiveness score for this word.
# Save the suggestiveness scores.
# Filter for words whose suggestiveness is above a threshold.
from shared_utils import *


def get_suggestiveness(
    lines, word1, word2, batch_size=500, num_examples=0, extra_line=True
):
    """Compute a per-prompt score contrasting ``word1`` against ``word2``.

    For every prompt in ``lines``, output is generated with
    ``generate_steered_output`` using the module-level ``model`` and
    ``tokenizer``. ``get_last_word_fraction`` is then evaluated on that output
    once for ``word1`` and once for ``word2``. Judging by its name, this is the
    fraction of generations ending in the given word (TODO confirm). The
    per-prompt score is ``fraction(word1) - fraction(word2)``.

    Prompts for which both fractions are exactly zero are skipped. They add no
    entry to the returned scores.

    Args:
        lines: Iterable of prompts to evaluate.
        word1: Word whose fraction counts positively in the score.
        word2: Word whose fraction counts negatively in the score.
        batch_size: Forwarded to ``generate_steered_output``. This is
            presumably the number of samples per prompt (TODO confirm).
        num_examples: Forwarded to ``get_last_word_fraction``.
        extra_line: Forwarded to ``get_last_word_fraction``.

    Returns:
        A tuple ``(scores, word1_fraction, word2_fraction)``:

        - ``scores`` is a NumPy array of the scores for non-skipped prompts,
          in input order.
        - The two fractions are those computed for the LAST prompt processed,
          even if that prompt was skipped.
        - Both fractions are ``None`` when ``lines`` is empty.
    """
    suggestiveness_scores = []
    # Bind these up front so an empty ``lines`` returns cleanly instead of
    # raising UnboundLocalError at the return statement below.
    suggestiveness_word1 = None
    suggestiveness_word2 = None
    for prompt in lines:
        couplet = generate_steered_output(None, model, tokenizer, prompt, batch_size)
        suggestiveness_word1 = get_last_word_fraction(
            couplet, word1, num_examples=num_examples, extra_line=extra_line
        )
        suggestiveness_word2 = get_last_word_fraction(
            couplet, word2, num_examples=num_examples, extra_line=extra_line
        )
        # Neither word showed up at all: no signal for this prompt.
        if suggestiveness_word1 == 0 and suggestiveness_word2 == 0:
            continue
        suggestiveness = suggestiveness_word1 - suggestiveness_word2
        suggestiveness_scores.append(suggestiveness)
        print(
            f"Prompt: {prompt}, Suggestiveness {word1}: {suggestiveness_word1}, Suggestiveness {word2}: {suggestiveness_word2}, Suggestiveness: {suggestiveness}"
        )
    return np.array(suggestiveness_scores), suggestiveness_word1, suggestiveness_word2


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
