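"""Plot summary statistics for clustering result files.

Reads JSON-lines result files and saves histograms of the number of clusters
and meta clusters ("super clusters") per datapoint.
"""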

import click
import os
import networkx as nx
import numpy as np
import ujson as json
import matplotlib.pyplot as plt
from jinja2 import Template


def _is_entail(cluster_a, cluster_b, entailment_mat):
    """Check whether some member of ``cluster_a`` entails some member of ``cluster_b``.

    Args:
        cluster_a: Mapping whose ``'set_ids'`` entry holds the row indices to test.
        cluster_b: Mapping whose ``'set_ids'`` entry holds the column indices to test.
        entailment_mat: Object indexable as ``entailment_mat[i, j]``; an entry
            greater than zero counts as "i entails j".

    Returns:
        True as soon as one pair ``(i, j)`` has ``entailment_mat[i, j] > 0``;
        False if no pair does (including when either id list is empty).
    """
    rows = cluster_a['set_ids']
    cols = cluster_b['set_ids']
    # Row-major scan over every (row, col) pair; any() stops at the first hit.
    return any(
        entailment_mat[r, c] > 0
        for r in rows
        for c in cols
    )


def _load_results(input_dir):
    """Parse every regular file directly under ``input_dir`` as JSON lines.

    Args:
        input_dir: Directory containing the result files.

    Returns:
        A single list with one parsed object per non-blank line, across all
        files, in sorted-filename order.
    """
    data = []
    # Sort so the record order is deterministic; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(input_dir)):
        ipath = os.path.join(input_dir, filename)
        if not os.path.isfile(ipath):
            # Sub-directories (and other non-regular entries) cannot be parsed.
            continue
        with open(ipath, 'r', encoding='utf-8') as file_:
            # Skip blank lines, which json.loads would reject.
            data.extend(json.loads(line) for line in file_ if line.strip())
    return data


def _plot_count_histogram(counts, title, xlabel, output_path, y_max=500):
    """Save a histogram with one unit-wide bar per integer count value.

    Args:
        counts: Non-empty list of non-negative integer counts, one per datapoint.
        title: Figure title.
        xlabel: X-axis label.
        output_path: File path passed to ``Figure.savefig``.
        y_max: Upper y-axis limit.
    """
    # Bins start at 1 as before, but extend down to 0 when some datapoint has
    # no items; range(1, ...) bins would otherwise drop those points silently.
    low = min(1, min(counts))
    bins = range(low, max(counts) + 2)

    fig, ax = plt.subplots()
    try:
        ax.hist(counts, bins=bins, align='left', rwidth=0.8)
        ax.set_ylim(0, y_max)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("# Points")
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        plt.close(fig)


@click.command()
@click.option("--input-dir", type=click.Path(exists=True), help="Path to the input data.", required=True)
@click.option("--output-dir", type=click.Path(), help="Path to the output directory.", required=True)
@click.option("--y-max", type=click.IntRange(min=1), default=500, show_default=True,
              help="Upper y-axis limit of the histograms.")
def main(
    input_dir,
    output_dir,
    y_max=500,
):
    """Plot per-datapoint cluster and meta-cluster counts.

    Every file in INPUT_DIR is read as JSON lines; each datapoint must have
    'clusters' and 'meta_clusters' entries. Writes num_meta_clusters.png and
    num_clusters.png to OUTPUT_DIR, creating it if needed.
    """
    data = _load_results(input_dir)
    if not data:
        # max() over an empty list would otherwise fail with an opaque ValueError.
        raise click.ClickException(f"No datapoints found in {input_dir}.")

    # --output-dir is not required to exist (click.Path() without exists=True).
    os.makedirs(output_dir, exist_ok=True)

    _plot_count_histogram(
        [len(d['meta_clusters']) for d in data],
        title="Dist of Number of Super Clusters",
        xlabel="# of Super Clusters",
        output_path=os.path.join(output_dir, "num_meta_clusters.png"),
        y_max=y_max,
    )
    _plot_count_histogram(
        [len(d['clusters']) for d in data],
        title="Dist of Number of Clusters",
        xlabel="# of Clusters",
        output_path=os.path.join(output_dir, "num_clusters.png"),
        y_max=y_max,
    )

    # # sample stratified by number of clusters
    # # data = sorted(data, key=lambda x: len(x['clusters']))
    # sorted_ids = np.argsort(num_clusters)
    # sorted_data = [data[i] for i in sorted_ids]
    # # sample 10 datapoints in total
    # num_samples = 10

    # interval = len(data) // num_samples
    # sampled = [(sorted_data[(i + 1) * interval], (i + 1) * interval) for i in range(num_samples)]
    # sampled_entailment_mat = entailment_mats[sorted_ids[interval::interval]]
    
    # assert len(sampled) == len(sampled_entailment_mat)
    
    # for sitem, idx in zip(sampled, sorted_ids[interval::interval]):
    #     assert idmap[sitem[0]['example_id']] == idx, f"{idmap[sitem[0]['_id']]} != {idx}"

    # with open("data/templates/sample_presentation.html", 'r', encoding='utf-8') as file_:
    #     template = Template(file_.read())

    # for (datapoint, index), entailment_mat in zip(sampled, sampled_entailment_mat):
        
    #     # also create a entailment graph
    #     plt.clf()
    #     G = nx.DiGraph()
    #     for cluster in datapoint['clusters']:
    #         G.add_node(cluster['cluster_id'] + 1)

    #     for i, cluster_a in enumerate(datapoint['clusters']):
    #         for j, cluster_b in enumerate(datapoint['clusters']):
    #             if i != j and _is_entail(cluster_a, cluster_b, entailment_mat):
    #                 G.add_edge(cluster_a['cluster_id'] + 1, cluster_b['cluster_id'] + 1)
                    
    #     pos = nx.nx_agraph.graphviz_layout(G)
    #     nx.draw(G, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, font_color="black")
        
    #     # save svg
    #     plt.savefig(os.path.join(output_dir, f"sample_{index}.svg"))

    #     # open svg as a string and save it in a html file
    #     with open(os.path.join(output_dir, f"sample_{index}.svg"), 'r', encoding='utf-8') as file_:
    #         svg_content = file_.read()

    #     with open(os.path.join(output_dir, f"sample_{index}.html"), 'w', encoding='utf-8') as file_:
    #         texts = [cluster['claim'] for cluster in datapoint['clusters']]
    #         file_.write(template.render(texts=texts, index=datapoint['example_id'], svg_content=svg_content))


if __name__ == "__main__":
    # Script entry point: invoke the click command defined above.
    main()
