from knowledge_neurons import (
    KnowledgeNeurons,
    initialize_model_and_tokenizer,
    model_type,
)
import random


def test_gpt(
    MODEL_NAME: str = "gpt2",
    *,
    model_dir: str = "/home/pod/shared-nvme/ptms",
    compare_random: bool = False,
) -> None:
    """Smoke-test knowledge-neuron discovery and editing on a GPT-style model.

    Finds refined knowledge neurons for "capital of France -> Paris" from
    several paraphrased prompts. It then runs suppression, erasure and
    enhancement of those neurons on a few-shot prompt, plus a control run that
    suppresses them on an unrelated prompt.

    This function itself prints only section headings and the neuron list.
    Any per-edit reporting is left to the ``KnowledgeNeurons`` methods.
    Nothing is asserted.

    Args:
        MODEL_NAME: Model identifier passed to ``initialize_model_and_tokenizer``
            and ``model_type``. It has a default because pytest collects this
            function (``test_`` prefix). Without a default, pytest would try to
            resolve the argument as a fixture and error out, since arguments
            with defaults are not treated as fixtures. Note that pytest will
            therefore also run this function with the default model.
        model_dir: Second positional argument to
            ``initialize_model_and_tokenizer``. It is presumably the local
            directory models are loaded from or cached in (TODO confirm).
            Defaults to the previously hard-coded path.
        compare_random: If True, also suppress and enhance an equally sized
            set of randomly chosen neurons as a baseline for comparison.
    """
    TEXT = (
        "Q: What is the capital of England?\n"
        "A: The capital of England is London\n"
        "Q: What is the capital of France?\n"
        "A: The capital of France is"
    )
    GROUND_TRUTH = " Paris"
    BATCH_SIZE = 20
    STEPS = 20
    PERCENTILE = 99.7
    P = 0.6
    # Paraphrases of the same fact; neurons shared across them are kept.
    GPT_TEXTS = [
        "The capital of france is",
        "There are some things to be that: The capital of france is",
        "As far as we know: The capital of france is",
        "There is an obvious fact that: The capital of france is",
    ]
    # Control prompt about an unrelated fact, used to check locality of edits.
    UNRELATED_TEXT = (
        "Q: What is the official language of Spain?\n"
        "A: The official language of Spain is Spanish.\n"
        "Q: What is the official language of the Solomon Islands?\n"
        "A: The official language of the Solomon Islands is"
    )
    UNRELATED_GROUND_TRUTH = " English"

    # setup model
    model, tokenizer = initialize_model_and_tokenizer(MODEL_NAME, model_dir)
    kn = KnowledgeNeurons(model, tokenizer, model_type=model_type(MODEL_NAME))

    refined_neurons = kn.get_refined_neurons(
        GPT_TEXTS,
        GROUND_TRUTH,
        p=P,
        batch_size=BATCH_SIZE,
        steps=STEPS,
        coarse_percentile=PERCENTILE,
    )
    print(f"Refined knowledge neurons: {refined_neurons}")

    print("\nSuppressing refined neurons: \n")
    kn.suppress_knowledge(TEXT, GROUND_TRUTH, refined_neurons)

    random_neurons = []
    if compare_random:
        print("\nSuppressing random neurons: \n")
        # Same count as the refined set; each entry is [layer, neuron index].
        random_neurons = [
            [
                random.randint(0, kn.n_layers() - 1),
                random.randint(0, kn.intermediate_size() - 1),
            ]
            for _ in refined_neurons
        ]
        kn.suppress_knowledge(TEXT, GROUND_TRUTH, random_neurons)

    print("\nSuppressing refined neurons for an unrelated prompt: \n")
    kn.suppress_knowledge(UNRELATED_TEXT, UNRELATED_GROUND_TRUTH, refined_neurons)

    print("\nErasing refined neurons: \n")
    kn.erase_knowledge(
        TEXT, refined_neurons, target=GROUND_TRUTH, erase_value="zero"
    )

    print("\nEnhancing refined neurons: \n")
    kn.enhance_knowledge(TEXT, GROUND_TRUTH, refined_neurons)

    if compare_random:
        print("\nEnhancing random neurons: \n")
        kn.enhance_knowledge(TEXT, GROUND_TRUTH, random_neurons)


def test_gpt2():
    """Run the knowledge-neuron smoke test against the ``gpt2`` checkpoint."""
    test_gpt("gpt2")


def test_gptneo():
    """Run the knowledge-neuron smoke test against GPT-Neo 125M."""
    test_gpt("EleutherAI/gpt-neo-125M")



if __name__ == "__main__":
    # Script entry point: only the GPT-2 run is enabled; call test_gptneo()
    # here to exercise the GPT-Neo 125M checkpoint instead.
    test_gpt2()