import os
import tqdm
import fire
import glob
import json
from transformers import AutoTokenizer
from rouge_score import rouge_scorer


def decode_example(example, tokenizer, rouger):
    """Decode a generated response in place and attach its ROUGE-L label.

    The ``'response_ids'`` entry is removed from ``example`` and decoded
    with ``tokenizer.decode(..., skip_special_tokens=True)``; the resulting
    text is stored under ``'response'``.  That text is then scored with
    ``rouger.score_multi`` against ``example['references']`` and the
    ``fmeasure`` of the ``'rougeL'`` entry is stored under ``'label'``.

    Args:
        example: Mutable mapping holding at least ``'response_ids'`` and
            ``'references'``.  It is modified in place.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        rouger: Object exposing ``score_multi(targets=..., prediction=...)``
            whose result is indexable by ``'rougeL'``.

    Returns:
        The same ``example`` object that was passed in.
    """
    decoded_text = tokenizer.decode(
        example.pop('response_ids'), skip_special_tokens=True)
    example['response'] = decoded_text

    scores = rouger.score_multi(
        targets=example['references'], prediction=decoded_text)
    example['label'] = scores['rougeL'].fmeasure
    return example


def _decode_file(filename, output_filename, tokenizer, rouger):
    """Decode every example in ``filename`` and write it to ``output_filename``."""
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)

    # Write to a sibling temp file and rename on success, so an interrupted
    # run never leaves a truncated file that a later run would report as
    # "already decoded".  A leftover temp file is simply overwritten.
    tmp_filename = f'{output_filename}.tmp'
    with open(filename, encoding='utf-8') as input_file, \
            open(tmp_filename, 'w', encoding='utf-8') as output_file:
        for line in tqdm.tqdm(input_file, desc=f'decoding {filename}'):
            example = decode_example(
                example=json.loads(line),
                tokenizer=tokenizer,
                rouger=rouger)
            output_file.write(json.dumps(example) + '\n')
    os.replace(tmp_filename, output_filename)


def _collect_edits(filename, t0_data_dir):
    """Group the decoded responses in ``filename`` by the data file they target."""
    edits = {}
    with open(filename, encoding='utf-8') as input_file:
        for line in tqdm.tqdm(input_file, desc=f'loading from {filename}'):
            example = json.loads(line)
            dataset_name = example['dataset_name']
            template_name = example['template_name']

            file_to_edit = \
                f'{t0_data_dir}/{dataset_name}/{template_name}.jsonl'
            # A later line with the same example_idx overrides an earlier one.
            edits.setdefault(file_to_edit, {})[example['example_idx']] = \
                (example['response'], example['label'])
    return edits


def _apply_edits(file_to_edit, file_edits):
    """Append new responses to ``file_to_edit``; return (n_updates, n_examples)."""
    examples = []
    n_updates = 0
    with open(file_to_edit, encoding='utf-8') as input_file:
        for line in input_file:
            example = json.loads(line)
            edit = file_edits.get(example['example_idx'])
            if edit is not None:
                response, label = edit
                # Only add responses whose text is not already recorded.
                if not any(known[0] == response
                           for known in example['responses']):
                    example['responses'].append([response, label])
                    n_updates += 1
            examples.append(example)

    # Replace the data file atomically so a crash while writing cannot
    # destroy the existing examples.
    tmp_filename = f'{file_to_edit}.tmp'
    with open(tmp_filename, 'w', encoding='utf-8') as output_file:
        for example in examples:
            output_file.write(json.dumps(example) + '\n')
    os.replace(tmp_filename, file_to_edit)

    return n_updates, len(examples)


def main(t0_data_dir='./processed_t0_data',
         gen_dir='./t0_gen_data_results',
         decoded_gen_dir='./t0_gen_data_results_decoded',
         do_decode=False,
         do_update=False):
    """Decode generation results and/or merge them into the data files.

    Args:
        t0_data_dir: Root of the data files to update; each target is
            ``{t0_data_dir}/{dataset_name}/{template_name}.jsonl``.
        gen_dir: Directory searched for ``{model}_*/*.jsonl`` generation
            files whose examples carry ``'response_ids'``.
        decoded_gen_dir: Directory that receives the decoded files, using
            the same path relative to ``gen_dir``.
        do_decode: When true, decode every generation file that has no
            decoded counterpart yet and label it with ROUGE-L.
        do_update: When true, read every ``*/*.jsonl`` file under
            ``decoded_gen_dir`` and append its ``[response, label]`` pairs
            to the ``'responses'`` list of the matching examples in
            ``t0_data_dir``, skipping responses that are already present.

    Both steps may be requested in one call; decoding runs first.  With
    neither flag set the function does nothing.
    """
    if do_decode:
        rouger = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        for model_name in ['yuchenlin/BART0', 'bigscience/T0_3B']:
            short_name = model_name.split('/')[-1]

            pending = []
            for filename in glob.glob(f'{gen_dir}/{short_name}_*/*.jsonl'):
                # Re-root the path instead of str.replace(), which would
                # also rewrite any later occurrence of gen_dir in the path
                # (e.g. the '.' of '.jsonl' when gen_dir is '.').
                output_filename = os.path.join(
                    decoded_gen_dir, os.path.relpath(filename, gen_dir))
                if os.path.exists(output_filename):
                    print(f'{filename} already decoded.')
                else:
                    pending.append((filename, output_filename))

            if not pending:
                # Nothing left for this model: skip loading its tokenizer.
                continue

            tokenizer = AutoTokenizer.from_pretrained(model_name)
            for filename, output_filename in pending:
                _decode_file(filename, output_filename, tokenizer, rouger)

    if do_update:
        for filename in glob.glob(f'{decoded_gen_dir}/*/*.jsonl'):
            edits = _collect_edits(filename, t0_data_dir)
            for file_to_edit, file_edits in edits.items():
                n_updates, n_examples = _apply_edits(file_to_edit, file_edits)
                print(f'\t{n_updates} / {n_examples} examples '
                      f'updated in {file_to_edit}.')


# Script entry point: hand ``main`` to python-fire so that its parameters
# can be supplied from the command line.
if __name__ == '__main__':
    fire.Fire(main)
