"""Compute dataset-level statistics: the distribution of cluster counts per instance.
"""

import glob
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import ujson as json
from tqdm import tqdm
from overrides import overrides
from typing import Text, Dict, Any
from tasker import BaseTask
from ..data_readers import AnswerClusterAttachedDataReader


@BaseTask.register("dataset-stats-calculation")
class DatasetStatsCalculation(BaseTask):
    """Plot how many clusters each dataset instance has.

    Every ``*.jsonl`` file found directly under ``input_dir`` is handed to
    ``AnswerClusterAttachedDataReader``. For each item the reader yields,
    ``len(item.clusters)`` is recorded, and the resulting counts are drawn
    as a histogram that is saved as a PNG in the output directory.
    """

    # Width of every histogram bin, in clusters.
    _BIN_WIDTH = 5
    # Lower bound for the right-most bin edge, so datasets whose counts all
    # fit keep the fixed 0..95 axis and plots stay comparable across runs.
    _MIN_UPPER_EDGE = 95
    # Above this many bin edges, tick placement is left to matplotlib so the
    # x-axis labels do not overlap into an unreadable band.
    _MAX_TICKED_EDGES = 40

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
    ):
        """Store the input location and initialise the base task.

        Args:
            input_dir: Directory that is searched (non-recursively) for
                ``*.jsonl`` files.
            output_dir: Directory forwarded to ``BaseTask``; the plot is
                written there.
        """
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

    @classmethod
    def _bin_edges(cls, counts):
        """Return histogram bin edges that cover every value in ``counts``."""
        largest = max(counts, default=0)
        # Use the first multiple of the bin width strictly above the largest
        # count, so that count falls inside a bin instead of being dropped
        # (values outside the bin range are ignored by the histogram).
        covering_edge = (largest // cls._BIN_WIDTH + 1) * cls._BIN_WIDTH
        upper_edge = max(cls._MIN_UPPER_EDGE, covering_edge)
        return list(range(0, upper_edge + 1, cls._BIN_WIDTH))

    @overrides
    def _run(self):
        """Read the dataset and build the cluster-count histogram.

        Returns:
            The matplotlib figure holding the histogram.
        """
        iterator = AnswerClusterAttachedDataReader(
            glob.glob(os.path.join(self._input_dir, "*.jsonl"))
        )

        # Stream the items: only the per-item count is kept, so the whole
        # dataset never has to sit in memory (the progress bar may lack a
        # total as a result).
        num_clusters = [len(item.clusters) for item in tqdm(iterator)]

        edges = self._bin_edges(num_clusters)

        # calculate the stats (bin-plot)
        fig, ax = plt.subplots()
        ax.hist(num_clusters, bins=edges)
        if len(edges) <= self._MAX_TICKED_EDGES:
            ax.set_xticks(edges)
        ax.set_xlabel("Number of Clusters")
        ax.set_ylabel("Number of Instances")
        ax.set_title("Number of Clusters Distribution")

        return fig

    @overrides
    def _write(self, outputs):
        """Save the figure returned by ``_run`` as a PNG.

        Args:
            outputs: The matplotlib figure to save.
        """
        # NOTE(review): self._output_dir is presumably set by
        # BaseTask.__init__ - confirm. Creating it here is a no-op when the
        # base class has already done so.
        os.makedirs(self._output_dir, exist_ok=True)
        outputs.savefig(os.path.join(self._output_dir, "num_clusters_distribution.png"))
        # pyplot keeps a reference to every figure it creates; release it
        # once the file is written so repeated runs do not accumulate figures.
        plt.close(outputs)