import os
import json
import argparse
from collections import Counter


def get_data(base_dir):
    """Load and concatenate every chunk file under ``<base_dir>/chunks``.

    Chunk files must be named ``chunk_<idx>.json`` with indices 0..N-1. Each
    file's decoded JSON is appended (via ``+=``) to the result in index order.
    Other files in the directory are ignored.

    Args:
        base_dir: Directory that contains the ``chunks`` sub-directory.

    Returns:
        A single list holding the items of all chunks, in chunk-index order.

    Raises:
        FileNotFoundError: If ``<base_dir>/chunks`` does not exist, or if the
            chunk indices are not contiguous from 0 (a chunk is missing).
    """
    import re

    chunk_dir = os.path.join(base_dir, "chunks")
    # Count only files that follow the chunk naming scheme. Counting every
    # directory entry breaks as soon as any unrelated file appears there.
    chunk_name = re.compile(r"chunk_\d+\.json")
    num_chunks = sum(1 for name in os.listdir(chunk_dir) if chunk_name.fullmatch(name))

    all_chunk_data = []
    for chunk_idx in range(num_chunks):
        # A gap in the numbering surfaces here as FileNotFoundError rather
        # than silently producing incomplete data.
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        with open(file_name, "r", encoding="utf-8") as f:
            all_chunk_data += json.load(f)
    return all_chunk_data


def get_statistics(data, goal):
    """Count how often each value of ``record['tag'][goal]`` occurs in *data*.

    Args:
        data: Iterable of records, each indexable by ``'tag'``; the tag is
            either ``None`` or something indexable by *goal*.
        goal: Key to read from each record's tag (e.g. ``'Domain'``).

    Returns:
        A dict mapping each observed value to its count, ordered from most to
        least frequent (ties keep first-seen order). Records whose tag is
        ``None`` are counted under the string ``"None"``.
    """
    labels = []
    for record in data:
        tag = record['tag']
        labels.append("None" if tag is None else tag[goal])
    return {label: count for label, count in Counter(labels).most_common()}


def run(args):
    """Compute Domain/Type tag statistics for one tagging method and save them.

    Reads the chunked tagging output of ``args.method`` (see ``get_data``)
    from ``<cwd>/../../output/tagging/<method>`` and writes a JSON summary to
    ``<cwd>/../../output/tagging/analysis/tag_statistics/<method>.json``.
    The summary contains the record count plus the "Domain" and "Type"
    value distributions.

    Args:
        args: Parsed command-line namespace; only ``args.method`` is read.

    Raises:
        ValueError: If ``args.method`` is missing or empty.
    """
    # Fail early with a clear message instead of a TypeError from
    # os.path.join(..., None) after the output directory was already created.
    if not args.method:
        raise ValueError("a tagging method is required (pass --method)")

    # NOTE(review): paths are resolved against the current working directory,
    # not this file's location. The script presumably has to be launched from
    # its own directory; confirm before running it from elsewhere.
    task_dir = os.path.join(os.getcwd(), "../../output/tagging")
    output_dir = os.path.join(task_dir, "analysis/tag_statistics")
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    data = get_data(os.path.join(task_dir, args.method))
    domain_counts = get_statistics(data, 'Domain')
    type_counts = get_statistics(data, 'Type')

    statistics_data = {
        "Number of Diagrams": len(data),
        "Domain Distribution": domain_counts,
        "Type Distribution": type_counts,
    }
    statistics_file = os.path.join(output_dir, f"{args.method}.json")
    with open(statistics_file, "w", encoding="utf-8") as f:
        json.dump(statistics_data, f, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute Domain/Type tag statistics for a tagging method.'
    )
    # --method is mandatory: run() uses it to locate both the input chunks
    # and the output file, and cannot proceed with None.
    parser.add_argument(
        '--method',
        type=str,
        required=True,
        help='Name of the tagging method sub-directory under output/tagging.',
    )
    args = parser.parse_args()
    run(args)
