import os
import sys
import fire
import json
import random
from prettytable import PrettyTable


def _load_json(path):
    """Load and return the JSON document stored at ``path`` (closing the file)."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _build_table(seed_ent_tuples, weighted_ent_tuples, n_present, top_k):
    """Build the comparison table for one relation.

    Columns:
      * the seed entity tuples,
      * the ``n_present`` highest-ranked harvested tuples,
      * a random sample of (at most) ``n_present`` tuples drawn from
        ``weighted_ent_tuples``.

    Args:
        seed_ent_tuples: Seed tuples for the relation (not modified).
        weighted_ent_tuples: Non-empty list of ``(ent_tuple, weight)`` pairs,
            already truncated to the top ``top_k`` entries by the caller.
        n_present: Minimum number of rows per column.
        top_k: Pool size, used only in the random-sample column header.

    Returns:
        A ``PrettyTable`` whose columns are all padded with ``'\\'`` to the same
        length (at least ``n_present``), as ``add_column`` requires.
    """
    # random.sample raises ValueError if asked for more items than exist.
    n_sample = min(n_present, len(weighted_ent_tuples))

    columns = {
        # Copy so padding below never mutates the caller's relation_info.
        'Seed samples': list(seed_ent_tuples),
        f'Ours (Top {n_present})': [
            str(ent_tuple) for ent_tuple, _ in weighted_ent_tuples[:n_present]],
        f'Ours (Random samples over top {top_k} tuples)': [
            str(ent_tuple) for ent_tuple, _ in random.sample(
                weighted_ent_tuples, n_sample)],
    }

    # PrettyTable rejects columns of unequal length, so pad every column to
    # the longest one (never fewer than n_present rows).
    n_rows = max(n_present, *(len(col) for col in columns.values()))
    table = PrettyTable()
    for key, col in columns.items():
        table.add_column(key, col + ['\\'] * (n_rows - len(col)))
    return table


def _print_results(output_file, rel, weighted_prompts, table):
    """Write one relation's weighted prompts and tuple table to ``output_file``."""
    print(f'Relation: {rel}', file=output_file)
    print('Prompts:', file=output_file)
    for prompt, weight in weighted_prompts:
        print(f'- {weight:.4f} {prompt}', file=output_file)
    print('Harvested Tuples:', file=output_file)
    print(table, file=output_file)
    print('=' * 50, file=output_file, flush=True)


def main(result_dir, n_present=20, top_k=200):
    """Summarize the harvested prompts and entity tuples of every relation.

    For each relation in ``relation_info/<rel_set>.json``, print its weighted
    prompts and a table that compares the seed tuples with harvested tuples.
    The report is written to ``<result_dir>/summary.txt`` and echoed to stdout.
    Relations whose output files are missing, or whose tuple list is empty,
    are skipped with a message.

    Args:
        result_dir: Directory that contains ``<rel>/prompts.json`` and
            ``<rel>/ent_tuples.json`` for each relation. Its second
            '/'-separated component is used as the relation-set name.
        n_present: Number of rows shown per table column.
        top_k: Only the first ``top_k`` harvested tuples are kept; the
            random-sample column draws from this pool.
    """
    # NOTE(review): assumes result_dir looks like "<root>/<rel_set>/..."; a
    # path such as "./results/x" would pick up "results" instead.
    rel_set = result_dir.split('/')[1]
    relation_info = _load_json(f'relation_info/{rel_set}.json')

    with open(f'{result_dir}/summary.txt', 'w',
              encoding='utf-8') as summary_file:
        for rel, info in relation_info.items():
            prompts_path = f'{result_dir}/{rel}/prompts.json'
            ent_tuples_path = f'{result_dir}/{rel}/ent_tuples.json'

            # Both files are needed; previously only ent_tuples.json was
            # checked and a missing prompts.json crashed the run.
            if not (os.path.exists(prompts_path)
                    and os.path.exists(ent_tuples_path)):
                print(f'outputs of relation \"{rel}\" not found. skipped.')
                continue

            weighted_prompts = _load_json(prompts_path)
            weighted_ent_tuples = _load_json(ent_tuples_path)[:top_k]

            if not weighted_ent_tuples:
                print(f'outputs of relation \"{rel}\" not found. skipped.')
                continue

            table = _build_table(info['seed_ent_tuples'], weighted_ent_tuples,
                                 n_present, top_k)

            for output_file in (summary_file, sys.stdout):
                _print_results(output_file, rel, weighted_prompts, table)

    # .name stays available after the file is closed.
    print(f'This summary has been saved into {summary_file.name}.')


if __name__ == '__main__':
    # Turn main() into a command-line interface: Fire maps CLI arguments
    # onto main's parameters (e.g. `--n_present 10`).
    fire.Fire(component=main)
