import os
import fire
import glob
import json
import numpy as np
import tqdm
from transformers import AutoTokenizer


def convert_gen_ids_to_target_inputs(gen_ids, *, pad_token_id=0, eos_token_id=1):
    """Turn raw generated token ids into a fixed-length target sequence and mask.

    Padding tokens are dropped, everything from the first EOS token onward is
    discarded, and a single EOS token is appended. The result is right-padded
    with ``pad_token_id`` back to ``len(gen_ids)``. If there is no room for the
    trailing EOS, because the input contained neither padding nor EOS, the
    sequence is truncated to ``len(gen_ids)`` rather than overflowing.

    Args:
        gen_ids: Sequence of integer token ids as produced by generation.
        pad_token_id: Id treated as padding (default 0, T5's pad id).
        eos_token_id: Id treated as end-of-sequence (default 1, T5's EOS id).

    Returns:
        Dict with ``'input_ids'`` and ``'attention_mask'``. Each is an int
        numpy array of length ``len(gen_ids)``; the mask is 1 on kept
        positions and 0 on padding.
    """
    length = len(gen_ids)

    kept = []
    for token_id in gen_ids:
        if token_id == eos_token_id:
            break
        if token_id != pad_token_id:
            kept.append(token_id)
    kept.append(eos_token_id)

    # Without any padding to drop and without an EOS in the input, ``kept``
    # ends up one element longer than ``gen_ids``. The same holds for empty
    # input. Truncate so the slice assignments below cannot fail with a
    # shape mismatch.
    kept = kept[:length]

    input_ids = np.full(length, pad_token_id, dtype=int)
    attention_mask = np.zeros(length, dtype=int)

    input_ids[:len(kept)] = kept
    attention_mask[:len(kept)] = 1

    return {'input_ids': input_ids, 'attention_mask': attention_mask}


def main(model_name='google/flan-t5-small',
         gen_dir='./bigbench_gen_results',
         output_dir='./bigbench_gen_results_decoded'):
    """Decode generation results and attach per-example NLL statistics.

    Every ``*.jsonl`` file matching ``{gen_dir}/*/*/*.jsonl`` is mirrored to
    the same relative path under ``output_dir``. Each input line is a JSON
    object with a ``gen_ids`` list. For files whose path relative to
    ``gen_dir`` contains ``'validation'``, the object must also carry an
    ``nlls`` list. Each output line is a JSON object with:

    * ``response``: ``gen_ids`` decoded with special tokens skipped.
    * ``nll_sum``: sum of ``nlls`` over the positions kept by
      ``convert_gen_ids_to_target_inputs``.
    * ``nll_avg``: ``nll_sum`` divided by the number of kept positions.

    ``nll_sum`` and ``nll_avg`` are ``null`` for non-validation files.

    Output files that already exist are skipped, so an interrupted run can be
    resumed. Each file is first written to a ``.tmp`` sibling and only moved
    into place once complete, so a crash never leaves a truncated file that
    would be skipped later.

    Args:
        model_name: Hugging Face model/tokenizer name used for decoding.
        gen_dir: Root directory of the raw generation results.
        output_dir: Root directory for the decoded results.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    for filename in sorted(glob.glob(f'{gen_dir}/*/*/*.jsonl')):
        rel_path = os.path.relpath(filename, gen_dir)
        output_filename = os.path.join(output_dir, rel_path)
        if os.path.exists(output_filename):
            continue
        os.makedirs(os.path.dirname(output_filename), exist_ok=True)

        # Decide based on the path inside gen_dir only, so that an output_dir
        # name containing "validation" does not affect the result.
        is_validation = 'validation' in rel_path
        tmp_filename = output_filename + '.tmp'

        with open(filename, encoding='utf-8') as input_file, \
                open(tmp_filename, 'w', encoding='utf-8') as output_file:
            for line in tqdm.tqdm(input_file, desc=filename):
                record = json.loads(line)
                gen_ids = record['gen_ids']
                response = tokenizer.decode(gen_ids, skip_special_tokens=True)

                if is_validation:
                    nlls = np.array(record['nlls'])
                    attention_mask = convert_gen_ids_to_target_inputs(
                        gen_ids=gen_ids)['attention_mask']

                    total = np.sum(nlls * attention_mask)
                    # Cast to Python float: numpy integer scalars are not
                    # JSON-serializable.
                    nll_sum = float(total)
                    nll_avg = float(total / np.sum(attention_mask))
                else:
                    nll_sum = nll_avg = None

                output_file.write(json.dumps({
                    'response': response,
                    'nll_sum': nll_sum,
                    'nll_avg': nll_avg
                }) + '\n')

        os.replace(tmp_filename, output_filename)


if __name__ == '__main__':
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(component=main)
