from tqdm import tqdm
from transformers import pipeline

def get_formatted(text, top_k=15):
    """Return the fill-mask predictions for *text* as one formatted line.

    The module-level ``unmasker`` pipeline is queried for the ``top_k`` most
    likely fillers, and the result is rendered as space-separated
    ``word:probability`` pairs in the order the pipeline returned them.  The
    probability mass not covered by those predictions is appended last under
    an empty word, i.e. as ``:<remainder>``.

    Args:
        text: Input passed to ``unmasker`` unchanged.  Assumed to contain
            exactly one mask token, so that the pipeline returns a flat list
            of predictions (TODO confirm against callers).
        top_k: Number of candidate tokens requested from the pipeline.

    Returns:
        A string such as ``"the:0.4 a:0.2 :0.4"`` with no leading or
        trailing spaces.
    """
    predictions = unmasker(text, top_k=top_k)

    # Tokens that differ only in surrounding whitespace print as the same
    # word once stripped, so merge them instead of emitting the word twice.
    merged = {}
    for prediction in predictions:
        word = prediction['token_str'].strip()
        merged[word] = merged.get(word, 0.0) + prediction['score']

    # Float rounding can push the sum marginally above 1; never report a
    # negative remainder.
    leftover = max(0.0, 1 - sum(merged.values()))

    # Keep the empty-word entry last, and fold in any mass that belonged to a
    # prediction which stripped down to the empty string.
    merged[''] = merged.pop('', 0.0) + leftover

    return ' '.join(f'{word}:{score}' for word, score in merged.items())
    
def write(f_path_in, f_path_out, total=10_600):
    """Write one prediction line to *f_path_out* per line of *f_path_in*.

    Each input line is split on tabs; the second-to-last field is used as the
    text left of the gap and the last field as the text right of it.  The
    string ``<left tail> <mask> <right head>`` is passed to
    ``get_formatted``.  If that call raises, the number of context characters
    kept on each side is reduced step by step and the call is retried; once
    the context falls below the minimum, the fallback line ``:1`` is written.

    Lines with fewer than two tab-separated fields cannot be processed and get
    the fallback line immediately, so the output keeps one line per input
    line.

    Args:
        f_path_in: Path of the tab-separated input file.
        f_path_out: Path of the output file (overwritten).
        total: Expected number of input lines; used only for the progress
            bar.
    """
    initial_context = 400   # characters kept on each side of the gap
    context_step = 50       # reduction applied after every failed attempt
    min_context = 60        # give up once the context drops below this
    fallback = ':1'

    with open(f_path_in) as f_in, open(f_path_out, 'w') as f_out:
        for line in tqdm(f_in, total=total):
            fields = line.rstrip().split('\t')
            if len(fields) < 2:
                # Retrying with less context cannot repair a malformed line.
                print('malformed line, writing fallback')
                f_out.write(fallback + '\n')
                continue
            left_text, right_text = fields[-2], fields[-1]

            char_context = initial_context
            while True:
                l_in = (left_text[-char_context:] + ' <mask> '
                        + right_text[:char_context])
                try:
                    a = get_formatted(l_in)
                    break
                except Exception as exc:
                    # `Exception` rather than a bare except, so Ctrl-C and
                    # SystemExit still stop the run.
                    print(f'lowering context: {exc!r}')
                    char_context -= context_step
                    if char_context < min_context:
                        print('lower threshold context exceeded')
                        a = fallback
                        break

            f_out.write(a + '\n')

# Build the fill-mask pipeline once at module level; get_formatted() looks up
# `unmasker` as a global.
# NOTE(review): device=0 presumably selects the first GPU -- confirm against
# the transformers pipeline documentation.
model = 'roberta-large'
unmasker = pipeline('fill-mask', model=model, device=0)

# Generate the prediction file for each input file, in this order.
for in_path, out_path in (
    ('../dev-0/in.tsv', '../dev-0/out.tsv'),
    ('../test-A/in.tsv', '../test-A/out.tsv'),
):
    write(in_path, out_path)
