from transformers import GPT2LMHeadModel, GPT2Tokenizer


# Module-level singletons: the pretrained "gpt2" tokenizer and LM-head model
# are loaded once at import time and shared by every call in this module.
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2")
# output_attentions=True is requested so that get_gpt2_attn can read
# `outputs.attentions` from the model's forward pass.
MODEL = GPT2LMHeadModel.from_pretrained("gpt2", output_attentions=True)


def get_gpt2_attn(text: str):
    """Generate a GPT-2 continuation of *text* and collect its attention maps.

    The prompt is tokenized, extended with ``MODEL.generate``, and the full
    generated sequence (prompt + continuation) is then fed through ``MODEL``
    once more to obtain the attention weights over the whole sequence.

    Args:
        text: Prompt string to tokenize and continue.

    Returns:
        A 3-tuple ``(attn, input_tokens, output_tokens)``:

        * ``attn`` -- ``{"dec_attn": [...]}`` with one entry per element of
          ``outputs.attentions``, each indexed at batch position 0. Per the
          Transformers documentation, ``attentions`` holds one tensor per
          layer of shape ``(batch, num_heads, seq_len, seq_len)``, so each
          entry should be ``(num_heads, seq_len, seq_len)``.
        * ``input_tokens`` -- the prompt token ids decoded one by one.
        * ``output_tokens`` -- the decoded tokens of the generated sequence
          with the first ``len(input_tokens)`` entries (the prompt) removed.
    """
    # Local import keeps this block self-contained; torch is already required
    # by the ``return_tensors="pt"`` tokenizer call below.
    import torch

    encoded = TOKENIZER(text, return_tensors="pt")
    input_ids = encoded.input_ids

    # Decode each prompt id individually so tokens line up with attention rows.
    input_tokens = [
        TOKENIZER.decode(token_id, skip_special_tokens=False)
        for token_id in input_ids[0]
    ]

    # Pass the attention mask and pad token explicitly instead of letting
    # generate() infer them (GPT-2 has no pad token; EOS is the usual stand-in).
    generated = MODEL.generate(
        input_ids,
        attention_mask=encoded.attention_mask,
        pad_token_id=TOKENIZER.eos_token_id,
    )

    output_tokens = [
        TOKENIZER.decode(token_id, skip_special_tokens=False)
        for token_id in generated[0]
    ]

    # Re-run the model on the full generated sequence to get attentions.
    # Inference only: disable autograd so the returned tensors do not carry a
    # graph (saves memory and lets callers convert them with .numpy()).
    with torch.no_grad():
        outputs = MODEL(input_ids=generated)

    # One entry per element of outputs.attentions, batch item 0 selected.
    layer_attentions = [layer_attn[0] for layer_attn in outputs.attentions]
    attn = {"dec_attn": layer_attentions}

    return attn, input_tokens, output_tokens[len(input_tokens) :]


def test_gpt_empirically():
    """Read the lines of ``newstest2017-ende.en`` from the working directory.

    Returns:
        A list with one string per line of the file, in file order, each
        keeping its trailing newline (as produced by ``readlines``).

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # NOTE(review): the `test_` prefix means pytest will collect this function
    # even though it only loads data and returns it -- confirm that is intended.
    # Decode explicitly as UTF-8 rather than the platform's locale encoding,
    # so the result does not vary between machines.
    # Assumes the corpus file is UTF-8 -- TODO confirm.
    with open("newstest2017-ende.en", "r", encoding="utf-8") as f:
        return f.readlines()


