from functools import lru_cache

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


@lru_cache(maxsize=4)
def _load_helsinki(model_name):
    """Load (and memoize) the tokenizer and attention-emitting model for ``model_name``."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name, output_hidden_states=True, output_attentions=True
    )
    return tokenizer, model


def get_attn_helsinki(input, model_name="Helsinki-NLP/opus-mt-en-de"):
    """Translate ``input`` and return the model's attention maps with token labels.

    Args:
        input: Source sentence to translate (name kept for backward
            compatibility even though it shadows the builtin).
        model_name: Hugging Face hub id of a seq2seq model; defaults to the
            English->German Helsinki-NLP Marian model.

    Returns:
        A tuple ``(attn, input_tokens, output_tokens)`` where ``attn`` is a dict
        with keys ``"enc_attn"`` (encoder self-attention), ``"dec_attn"``
        (decoder self-attention) and ``"enc_dec_attn"`` (decoder-to-encoder
        cross-attention). Each value is a list with one tensor per layer, batch
        dimension already removed; per the HF ``Seq2SeqLMOutput`` docs each is
        shaped (num_heads, query_len, key_len). ``input_tokens`` /
        ``output_tokens`` are the per-id decoded strings labelling the
        encoder / decoder axes, special tokens included.
    """
    tokenizer, model = _load_helsinki(model_name)

    input_ids = tokenizer(input, return_tensors="pt").input_ids
    # generate() returns the full decoder sequence, including the leading
    # decoder-start token and trailing </s>.
    generated_ids = model.generate(input_ids)

    output_tokens = [
        tokenizer.decode(tok, skip_special_tokens=False) for tok in generated_ids[0]
    ]
    print(output_tokens)
    input_tokens = [
        tokenizer.decode(tok, skip_special_tokens=False) for tok in input_ids[0]
    ]
    print(input_tokens)

    # Re-run the model teacher-forced on the exact generated ids so the
    # decoder/cross attention rows line up one-to-one with output_tokens.
    # (Decoding to text and re-tokenizing would use source-side tokenization
    # and drop the decoder-start token, misaligning the matrices.)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, decoder_input_ids=generated_ids)

    # Encoder and decoder may have different depths, so iterate each tuple
    # independently; [0] strips the batch dimension (batch size is 1).
    attn = {
        "enc_attn": [layer[0] for layer in outputs.encoder_attentions],
        "dec_attn": [layer[0] for layer in outputs.decoder_attentions],
        "enc_dec_attn": [layer[0] for layer in outputs.cross_attentions],
    }
    return attn, input_tokens, output_tokens
