from ml2.shapley.gpt2_attn import get_gpt2_attn
from ml2.shapley.attention_rollout import calculate_rollout_sum, sum_up_heads
import ml2.shapley.attention_flow
import time
from datetime import datetime
import random
from ml2.shapley.helsinki_attn import get_attn_helsinki
import wandb



def test_gpt_empirically(num_samples, path="/ml2/ml2/shapley/newstest2017-ende.en"):
    """Randomly draw ``num_samples + 1`` lines from a newstest text file.

    Despite its name this is not a test; it only samples candidate input
    sentences for the attention experiments. Returned lines keep their
    trailing newline.

    Args:
        num_samples: One less than the number of lines returned.
        path: Text file to sample from. Defaults to the newstest2017 en-de
            English file this script was written against.

    Returns:
        A list of ``num_samples + 1`` lines picked by ``random.sample``
        (distinct line positions, no replacement).

    Raises:
        ValueError: If the file has fewer than ``num_samples + 1`` lines
            (propagated from ``random.sample``).
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        queries = f.readlines()
    # NOTE(review): draws one more line than requested; kept for compatibility
    # with existing callers — confirm the +1 is intended.
    return random.sample(queries, num_samples + 1)


# The ``test_`` prefix would otherwise make pytest collect this helper as a test.
test_gpt_empirically.__test__ = False


def helsinki_rollout(
    num_samples,
    word_limit=25,
    output_prefix="/home/c01nime/CISPA-home/attention_rollout_",
    token_positions=(2,),
):
    """Score sampled sentences with attention rollout and attention flow.

    Sentences come from ``test_gpt_empirically(num_samples)``. Those with
    ``word_limit`` or more space-separated words are skipped. Each remaining
    sentence goes through ``get_attn_helsinki``. Then, for every output-token
    index in ``token_positions``, two scores are computed, both with
    ``layers_enc = 8`` and ``layers_dec = 8``:

    - a rollout score via ``calculate_rollout_sum``;
    - a flow score via ``ml2.shapley.attention_flow.calculate_shapley_sum``.

    One record per (sentence, position) is written to
    ``<output_prefix><num_samples>_<MM-DD-YYYY_HH-MM>.txt`` and echoed to
    stdout, formatted as
    ``input_tokens ; output_tokens ; rollout_result ; flow_result``.

    Args:
        num_samples: Forwarded to ``test_gpt_empirically`` (which draws
            ``num_samples + 1`` lines) and embedded in the output filename.
        word_limit: Sentences with this many or more words are skipped.
        output_prefix: Path prefix of the result file.
        token_positions: Output-token indices to score. The default ``(2,)``
            matches the previously hard-coded ``range(2, 3)``.

    Returns:
        ``(results_rollout, results_flow)``: lists with one entry per
        (sentence, position) pair, in processing order.
    """
    sampled = test_gpt_empirically(num_samples)
    short_texts = [text for text in sampled if len(text.split(" ")) < word_limit]

    results_rollout = []
    results_flow = []
    out_path = (
        output_prefix
        + str(num_samples)
        + "_"
        + datetime.now().strftime("%m-%d-%Y_%H-%M")
        + ".txt"
    )
    # ``with`` guarantees the results file is closed even if scoring raises.
    with open(out_path, "w", encoding="utf-8") as f:
        for count, text in enumerate(short_texts, start=1):
            attn, input_tokens, output_tokens = get_attn_helsinki(text)
            # Progress is relative to the sentences that survived the word filter.
            print(str(count) + "/" + str(len(short_texts)))
            attn_rollout = sum_up_heads(attn)
            attn_flow = ml2.shapley.attention_flow.sum_up_heads(attn)
            for x in token_positions:
                rollout = calculate_rollout_sum(
                    attn_rollout, input_tokens, output_tokens,
                    layers_enc=8, layers_dec=8, output_token=x,
                )
                flow = ml2.shapley.attention_flow.calculate_shapley_sum(
                    attn_flow, input_tokens, output_tokens,
                    layers_enc=8, layers_dec=8, output_token=x,
                )
                results_rollout.append(rollout)
                results_flow.append(flow)
                # Record only this sentence/position's scores. Previously the
                # cumulative lists were written, so each line repeated every
                # earlier sentence's results next to the current tokens.
                record = f"{input_tokens} ; {output_tokens} ; {rollout} ; {flow}\n"
                f.write(record)
                print(record)
    return results_rollout, results_flow


if __name__ == "__main__":
    # Open a Weights & Biases run, then score a batch of 20 sampled sentences.
    wandb.init(project="shaplyay")
    helsinki_rollout(num_samples=20)
