import torch
from fairseq.models.roberta import  RobertaModel
from fairseq import hub_utils
from fairseq.models.roberta import RobertaModel, RobertaHubInterface

import os
from tqdm import tqdm


# Load the pretrained RoBERTa-base model through torch.hub and switch it to
# evaluation mode for inference.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
roberta.eval()
# Use the GPU when one is present; otherwise stay on CPU instead of failing
# with a CUDA error on GPU-less hosts.
if torch.cuda.is_available():
    roberta.cuda()

# Smoke test: a single fill_mask call so a broken model setup fails here,
# before the long prediction passes start. The result is not used further.
preds = roberta.fill_mask('I like <mask> and apples', topk=3)

# Ids of input lines whose fill_mask call raises a CUDA RuntimeError that the
# process cannot recover from; predict() writes the fallback ':1' for them
# without calling the model. A frozenset gives O(1), read-only membership tests.
BLACKLIST = frozenset([
    'aeeadb08042bbd49dcbefcefa1f13806',
    '01ba303704bb62bcb59f8cb7cb5663d7',
    '98bdfa711364f45f1bcffb1359793614',
    'a9da7950abcbd531a5207c04c3bdc840',
    '4cd7f730ee72451406afa89c5c6431d6',
])

def _format_predictions(preds):
    """Render ``roberta.fill_mask`` output as one ``word:prob ... :rest`` line.

    Args:
        preds: the list returned by ``roberta.fill_mask``; each entry is read
            as ``pred[1]`` (probability) and ``pred[2]`` (predicted token).

    Returns:
        Space-separated ``word:prob`` pairs followed by ``:rest``, where
        ``rest`` is the probability mass not assigned to any listed word.
    """
    hyps = []
    probs_sum = 0.0
    for pred in preds:
        word = pred[2].strip()
        # Skip '<unk>' and whitespace-only tokens (e.g. a predicted newline):
        # an empty word would render as ':p', indistinguishable from the
        # trailing rest-of-mass entry. Their mass is left in the rest instead.
        if pred[2] == '<unk>' or not word:
            continue
        hyps.append(word + ':' + str(pred[1]))
        probs_sum += pred[1]
    # Floating-point rounding can push the sum slightly above 1; never emit a
    # negative remaining mass.
    hyps.append(':' + str(max(0.0, 1 - probs_sum)))
    return ' '.join(hyps)


def predict(f_in_path, f_out_path):
    """Predict the masked word for every line of ``f_in_path``.

    Each input line must hold four tab-separated fields: an id, an ignored
    field, the left context and the right context. Literal ``\\n`` sequences in
    the contexts are turned into real newlines. The left context is cut to its
    last 40 space-separated words and the right context to its first 40; the
    two are joined around ``<mask>`` and passed to ``roberta.fill_mask`` with
    ``topk=10``.

    Exactly one output line is written per input line, formatted as
    ``word:prob ... :rest``. Lines whose id is in ``BLACKLIST``, or whose model
    call raises ``RuntimeError``, get the fallback line ``:1``.

    Args:
        f_in_path: path of the tab-separated input file (UTF-8).
        f_out_path: path of the output file; it is overwritten.
    """
    with open(f_in_path, 'r', encoding='utf-8', newline='\n') as f_in:
        with open(f_out_path, 'w', encoding='utf-8', newline='\n') as f_out:
            # `total` only sizes the progress bar; it does not limit the loop.
            for line in tqdm(f_in, total=88000):
                # Strip the line terminator so it does not end up inside the
                # right-context field (and thus in the model input).
                line_id, _, before, after = line.rstrip('\r\n').split('\t')
                if line_id in BLACKLIST:
                    f_out.write(':1\n')
                    continue
                before = before.replace('\\n', '\n')
                after = after.replace('\\n', '\n')
                # TODO: this could operate on the model's subword tokens
                # instead of assuming the text is split into words by spaces.
                before = ' '.join(before.split(' ')[-40:])
                after = ' '.join(after.split(' ')[:40])
                masked_text = before + ' <mask> ' + after
                try:
                    preds = roberta.fill_mask(masked_text, topk=10)
                except RuntimeError as err:
                    # Log and emit the fallback instead of dropping into an
                    # interactive debugger, which would stall a batch run.
                    print('RUNTIMEERROR', line_id, err)
                    f_out.write(':1\n')
                    continue
                f_out.write(_format_predictions(preds) + '\n')

if __name__ == '__main__':
    # Produce predictions for both evaluation splits; the guard keeps these
    # long passes from running as a side effect of importing this module.
    for split in ('dev-0', 'test-A'):
        predict('../{}/in.tsv'.format(split), '../{}/out.tsv'.format(split))
