"""Combine clusters into meta-clusters, e.g. groups of clusters whose claims entail each other."""

import click
import ujson as json
from typing import Text
import os


def _convert_to_cluster_filename(input_filename: Text) -> Text:
    """Map an input filename to its clustering filename ("result-split-" -> "result-")."""
    return input_filename.replace("result-split-", "result-")


def _read_jsonl(path: Text) -> list:
    """Read a JSON-lines file and return one decoded object per non-blank line.

    Blank lines (e.g. an extra newline at the end of the file) are skipped
    rather than passed to the JSON parser, which would reject them.
    """
    with open(path, "r", encoding='utf-8') as file_:
        return [json.loads(line) for line in file_ if line.strip()]


def _build_meta_clusters(item: dict, cluster_spec_list: list) -> list:
    """Build the meta-cluster records for a single input item.

    Each spec in ``cluster_spec_list`` lists, under ``'set_ids'``, indices into
    ``item['clusters']``; those clusters are merged into one meta-cluster.
    Every returned record has the shape::

        {
            "meta_cluster_id": Int,          # position of the spec in cluster_spec_list
            "sent_set_ids": List[Int],       # sorted sentence ids of all member clusters (duplicates kept)
            "cluster_set_ids": List[Int],    # the spec's 'set_ids', unchanged
            "sentences": List[Text],         # item['answers'][i] for each sentence id
            "claims": List[Text],            # 'claim' of each member cluster
            "meta-claim": Text               # shortest member claim (first one on ties)
        }

    Raises:
        ValueError: if a spec has an empty ``'set_ids'`` (no claim to choose).
    """
    existing_clusters = item['clusters']
    results = []
    for meta_cidx, cluster_spec in enumerate(cluster_spec_list):
        cluster_set_ids = cluster_spec['set_ids']
        sent_set_ids = sorted(
            sidx for cs in cluster_set_ids for sidx in existing_clusters[cs]['set_ids']
        )
        claims = [existing_clusters[c]['claim'] for c in cluster_set_ids]
        results.append({
            "meta_cluster_id": meta_cidx,
            "sent_set_ids": sent_set_ids,
            "cluster_set_ids": cluster_set_ids,
            "sentences": [item['answers'][s] for s in sent_set_ids],
            "claims": claims,
            # min() returns the first shortest claim, same as a stable sort by length.
            "meta-claim": min(claims, key=len),
        })
    return results


@click.command()
@click.option('--input-dir', type=click.Path(exists=True), required=True, help='Directory containing the input files')
@click.option('--cluster-dir', type=click.Path(exists=True), required=True, help='Directory containing the clusters')
@click.option('--output-dir', required=True, help='Directory to save the output files')
def main(
    input_dir,
    cluster_dir,
    output_dir
):
    """Attach meta-clusters to every .jsonl file in --input-dir.

    For each input file, the clustering file in --cluster-dir is found by
    replacing "result-split-" with "result-" in the filename. Line i of the
    clustering file holds the meta-cluster specs for line i of the input
    file. Each input record is written to --output-dir, under the same
    filename, with an added "meta_clusters" field.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so files are processed in a reproducible order.
    for input_filename in sorted(os.listdir(input_dir)):
        if not input_filename.endswith(".jsonl"):
            continue
        cluster_filename = _convert_to_cluster_filename(input_filename)
        data = _read_jsonl(os.path.join(input_dir, input_filename))
        clusterings = _read_jsonl(os.path.join(cluster_dir, cluster_filename))

        # zip() would silently drop trailing records if the files disagree in
        # length; fail before writing a truncated output file.
        if len(data) != len(clusterings):
            raise click.ClickException(
                "{} has {} records but {} has {} clusterings".format(
                    input_filename, len(data), cluster_filename, len(clusterings)
                )
            )

        with open(os.path.join(output_dir, input_filename), "w", encoding='utf-8') as file_:
            for item, cluster_spec_list in zip(data, clusterings):
                meta_clusters = _build_meta_clusters(item, cluster_spec_list)
                file_.write(json.dumps({**item, "meta_clusters": meta_clusters}) + "\n")
                
if __name__ == "__main__":
    # Script entry point: hand control to the click command defined above.
    main()