from transformers import AutoTokenizer, AutoModelForCausalLM


def get_attn_dialogpt_medium(input: str):
    """Generate a DialoGPT-medium reply and collect its self-attention maps.

    The prompt is terminated with the tokenizer's EOS token, a reply is
    generated with the model's default generation settings, and the full
    sequence (prompt + reply) is then run through the model once more to
    read out the attention weights of every layer and head.

    NOTE: the tokenizer and the model are loaded from
    ``"microsoft/DialoGPT-medium"`` on every call.

    Args:
        input: The user utterance to respond to. The name shadows the
            builtin ``input`` but is kept so existing callers keep working.

    Returns:
        A 3-tuple ``(attn, input_tokens, output_tokens)``:

        * ``attn`` -- ``{"dec_attn": heads}`` where ``heads[h][l]`` is the
          attention tensor of head ``h`` in layer ``l`` for the first
          sequence of the batch.
        * ``input_tokens`` -- the prompt decoded one token at a time,
          including the appended EOS token.
        * ``output_tokens`` -- the tokens generated after the prompt,
          decoded one token at a time.
    """
    # Imported locally so this function stays self-contained; torch is
    # already required by the ``return_tensors="pt"`` call below.
    import torch

    tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
    model = AutoModelForCausalLM.from_pretrained(
        "microsoft/DialoGPT-medium", output_attentions=True
    )

    encoded = tokenizer(input + tokenizer.eos_token, return_tensors="pt")
    input_ids = encoded.input_ids

    # Inference only: skip building the autograd graph for both passes.
    with torch.no_grad():
        # DialoGPT has no dedicated pad token; passing the EOS id and the
        # tokenizer's attention mask explicitly matches what generate()
        # would otherwise fall back to, without the warnings.
        generated = model.generate(
            input_ids,
            attention_mask=encoded.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Re-run the whole sequence (prompt + reply) to get its attentions.
        forward = model(input_ids=generated)

    output_tokens = [
        tokenizer.decode(token_id, skip_special_tokens=False)
        for token_id in generated[0]
    ]
    input_tokens = [
        tokenizer.decode(token_id, skip_special_tokens=False)
        for token_id in input_ids[0]
    ]

    # ``attentions`` holds one tensor per layer, documented by transformers
    # as (batch, num_heads, seq_len, seq_len). Read the head count from the
    # data instead of hard-coding the 24-layer / 16-head architecture.
    attentions = forward.attentions
    num_heads = attentions[0].shape[1] if attentions else 0

    # Regroup from layer-major to head-major: heads[h][l].
    heads = [
        [layer_attn[0][head_idx] for layer_attn in attentions]
        for head_idx in range(num_heads)
    ]
    attn = {"dec_attn": heads}

    # generate() returns prompt + continuation; drop the prompt part.
    return attn, input_tokens, output_tokens[len(input_tokens) :]
