from ml2.shapley.gpt2_attn import get_gpt2_attn
from ml2.shapley.attention_flow import calculate_shapley_sum, sum_up_heads
import time
from datetime import datetime
import random
import wandb



def test_gpt_empirically(num_samples):
    """Return a random sample of source sentences from the WMT newstest2017 en-de set.

    Parameters
    ----------
    num_samples : int
        Sample-size parameter requested by the caller.

    Returns
    -------
    list[str]
        ``num_samples + 1`` randomly chosen lines, each still carrying its
        trailing newline from ``readlines()``.
        NOTE(review): the ``+ 1`` looks deliberate but is undocumented — confirm.

    Raises
    ------
    ValueError
        Propagated from ``random.sample`` if the file holds fewer than
        ``num_samples + 1`` lines.
    """
    with open("/ml2/ml2/shapley/newstest2017-ende.en", "r") as f:
        queries = f.readlines()
    return random.sample(queries, num_samples + 1)


def token_bias(num_samples):
    """Measure first-token attention bias for GPT-2 on randomly sampled sentences.

    For each sampled sentence that is shorter than 20 words, the sentence is
    truncated to its first 10 words, GPT-2 attention is extracted, heads are
    summed, and the Shapley attention sum is computed for the *last* output
    token only. Each full result row is appended to a timestamped text file.

    Parameters
    ----------
    num_samples : int
        Forwarded to ``test_gpt_empirically`` (which samples ``num_samples + 1``
        sentences before the length filter is applied).

    Returns
    -------
    list
        The first six entries of each per-sentence Shapley result row.
        (Previously accumulated but discarded; returning it is backward
        compatible since no caller used the old ``None`` return.)
    """
    lst = test_gpt_empirically(num_samples)
    results = []
    out_path = (
        "/home/c01nime/CISPA-home/firstTokenBias_"
        + str(num_samples)
        + "_"
        + datetime.now().strftime("%m-%d-%Y_%H-%M")
        + ".txt"
    )
    # Keep only short sentences (< 20 words) and truncate each to its
    # first 10 words so every query has a comparable length.
    new_lst = [
        " ".join(text.split(" ")[:10])
        for text in lst
        if len(text.split(" ")) < 20
    ]
    # Context manager guarantees the results file is closed even if one of
    # the model calls below raises (the old bare open() leaked on error).
    with open(out_path, "w") as f:
        for count, text in enumerate(new_lst, start=1):
            attn, input_tokens, output_tokens = get_gpt2_attn(text)
            # Progress denominator is the filtered list actually iterated,
            # not the raw sample size (the old print used num_samples and
            # could claim e.g. "3/1000" while only 40 items existed).
            print(str(count) + "/" + str(len(new_lst)))
            attn = sum_up_heads(attn)
            # Only the final output token is analysed (first-token-bias
            # experiment); the original one-iteration range() loop hid this.
            x = len(output_tokens) - 1
            result_list = calculate_shapley_sum(
                attn, input_tokens, output_tokens, layers_dec=12, output_token=x, decoder_only=True, plot=False
            )
            results.append(result_list[:6])
            f.write(str(result_list) + "\n")
    return results


if __name__ == "__main__":
    # NOTE(review): project name "shaplyay" looks like a typo for "shapley" —
    # confirm before renaming, since existing runs are already logged under it.
    wandb.init(project="shaplyay")
    # 1000 requested samples; token_bias filters to sentences < 20 words,
    # so the number actually processed is smaller.
    token_bias(1000)
