"""Use this to combine info from input and output files.
"""

import click
import ujson as json
from glob import glob
import logging
import os


# Module-level logger that reports INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Timestamped record layout shared by every message this module emits.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Stream handler (stderr by default), fully configured before it is attached.
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

logger.addHandler(handler)


@click.command()
@click.option("--input-dir", type=click.Path(exists=True, dir_okay=True, file_okay=False), help="Path to the input directory holding the *.json id-map files and their result .jsonl files.", required=True)
@click.option("--output-dir", type=click.Path(exists=False), help="Path to the output directory (created if it does not exist).", required=True)
@click.option("--reference", type=click.Path(exists=True, file_okay=True, dir_okay=False), help="Path to the file where the additional information is acquired.", required=True)
def main(
    input_dir,
    output_dir,
    reference
):
    """Merge reference records with per-file results and write them as JSONL.

    For every *.json id-map in INPUT_DIR, the sibling results file (same base
    name with "idmap-" replaced by "result-" and a .jsonl extension) is read.
    Each reference record whose example_id appears in the id-map is written
    to OUTPUT_DIR with an added "clusters" field taken from the results line
    the id-map points at. Outputs that already exist are skipped.
    """
    # Index reference records by stringified example_id: keys of a JSON
    # object are always strings, so the id-map lookups below compare as str.
    reference_dictionary = {}
    with open(reference, 'r', encoding='utf-8') as file_:
        for line in file_:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            ref_dict = json.loads(line)
            reference_dictionary[str(ref_dict['example_id'])] = ref_dict

    os.makedirs(output_dir, exist_ok=True)
    for filepath in glob(os.path.join(input_dir, "*.json")):
        # Derive the results name from the base name only, so "idmap-" or
        # ".json" appearing in a parent directory name is left untouched.
        directory, filename = os.path.split(filepath)
        stem = os.path.splitext(filename)[0]
        result_name = stem.replace("idmap-", "result-") + ".jsonl"
        result_filepath = os.path.join(directory, result_name)
        write_path = os.path.join(output_dir, result_name)

        if os.path.exists(write_path):
            logger.info(f"Skipping {write_path}")
            continue

        with open(filepath, 'r', encoding='utf-8') as file_:
            idmap = json.load(file_)
        with open(result_filepath, 'r', encoding='utf-8') as file_:
            results = [json.loads(line) for line in file_]

        # Serialize everything before touching the output, so a bad index or
        # an unserializable record cannot leave a truncated file behind.
        # Records are emitted in reference-file order.
        lines = [
            json.dumps({**info, "clusters": results[idmap[_id]]}) + '\n'
            for _id, info in reference_dictionary.items()
            if _id in idmap
        ]

        # Write to a temporary sibling and rename it into place: an
        # interrupted run must not leave a partial file that the
        # skip-if-exists check above would treat as finished.
        tmp_path = write_path + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as file_:
            file_.writelines(lines)
        os.replace(tmp_path, write_path)
                
                
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()