import os
import json
import torch
import numpy as np
import argparse
from tqdm import tqdm
from collections import Counter


def get_data(base_dir):
    """Load and concatenate all ``chunk_<i>.json`` files under ``<base_dir>/chunks``.

    Chunks are read in order ``chunk_0.json``, ``chunk_1.json``, ... so the
    returned list preserves the original chunk ordering. Entries in the
    directory whose names do not match ``chunk_<int>.json`` (stray files,
    editor/notebook checkpoint folders, ...) are ignored instead of being
    counted as chunks.

    Args:
        base_dir: Directory that contains a ``chunks`` sub-directory.

    Returns:
        One list holding the concatenated JSON arrays of every chunk.

    Raises:
        FileNotFoundError: if ``chunks`` does not exist, or if the chunk
            numbering has a gap (e.g. ``chunk_0`` and ``chunk_2`` but no
            ``chunk_1``).
    """
    import re

    chunk_pattern = re.compile(r"chunk_\d+\.json")
    chunk_dir = os.path.join(base_dir, "chunks")
    # Count only real chunk files; a plain len(os.listdir()) would also count
    # unrelated entries and then try to open chunk files that do not exist.
    num_chunks = sum(1 for name in os.listdir(chunk_dir) if chunk_pattern.fullmatch(name))

    all_chunk_data = []
    for chunk_idx in range(num_chunks):
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        with open(file_name, "r", encoding="utf-8") as f:
            all_chunk_data += json.load(f)
    return all_chunk_data


def get_statistics(data, idx_list, goal):
    """Count how often each value of tag key ``goal`` occurs among selected records.

    Args:
        data: Sequence of records, each a dict whose ``'tag'`` entry is either
            ``None`` or a mapping that contains ``goal``.
        idx_list: Positions in ``data`` to include in the tally.
        goal: The tag key to count (e.g. ``'Domain'`` or ``'Type'``).

    Returns:
        A dict mapping tag value -> count, ordered from most to least frequent
        (ties keep first-seen order). Records whose tag is ``None`` are
        counted under the string ``"None"``.
    """
    tally = Counter()
    for position in idx_list:
        tag = data[position]['tag']
        if tag is None:
            label = "None"
        else:
            label = tag[goal]
        tally[label] += 1
    return dict(tally.most_common())


def use_equal(data_1, data_2):
    """Find the positions at which two taggings of the same items agree.

    Records are paired by position. If the lists differ in length, only the
    overlapping prefix is compared, because ``zip`` stops at the shorter list.
    A pair is skipped unless both records have a truthy ``'tag'``. The second
    record's ``'tag'`` is only read when the first one's is truthy.

    Args:
        data_1: Records produced by the first tagging method.
        data_2: Records produced by the second tagging method.

    Returns:
        Tuple ``(domain_idx_list, type_idx_list, both_idx_list)``: positions
        whose ``'Domain'`` tags match, whose ``'Type'`` tags match, and whose
        tags match on both keys.
    """
    same_domain, same_type, same_both = [], [], []
    for position, (left, right) in enumerate(zip(data_1, data_2)):
        tag_left = left['tag']
        if not tag_left:
            continue
        tag_right = right['tag']
        if not tag_right:
            continue

        domain_match = tag_left['Domain'] == tag_right['Domain']
        type_match = tag_left['Type'] == tag_right['Type']
        if domain_match:
            same_domain.append(position)
        if type_match:
            same_type.append(position)
        if domain_match and type_match:
            same_both.append(position)
    return same_domain, same_type, same_both


def run(args):
    """Compare two tagging runs and write their tag-agreement statistics to JSON.

    Loads the chunked outputs of ``args.method_1`` and ``args.method_2`` from
    ``../../output/tagging/<method>`` (relative to the current working
    directory). It then finds the positions where their ``Domain`` and/or
    ``Type`` tags agree. Finally it writes the counts and value distributions to
    ``../../output/tagging/analysis/tag_equal_(<method_1>)(<method_2>).json``.

    Args:
        args: Namespace with string attributes ``method_1`` and ``method_2``
            naming sub-directories of the tagging output directory.
    """
    tagging_dir = os.path.join(os.getcwd(), "../../output/tagging")
    output_dir = os.path.join(tagging_dir, "analysis")
    # exist_ok replaces the racy "exists? then makedirs" check-then-create.
    os.makedirs(output_dir, exist_ok=True)

    data_1 = get_data(os.path.join(tagging_dir, args.method_1))
    data_2 = get_data(os.path.join(tagging_dir, args.method_2))

    domain_idx_list, type_idx_list, both_idx_list = use_equal(data_1, data_2)
    # At these positions the compared key has the same value in both runs,
    # so reading the distribution from data_1 alone is sufficient.
    domain_domain_counts = get_statistics(data_1, domain_idx_list, 'Domain')
    type_type_counts = get_statistics(data_1, type_idx_list, 'Type')
    both_domain_counts = get_statistics(data_1, both_idx_list, 'Domain')
    both_type_counts = get_statistics(data_1, both_idx_list, 'Type')

    statistics_data = {
        "Domain Equal": {
            "Number of Diagrams": len(domain_idx_list),
            "Domain Distribution": domain_domain_counts,
        },
        "Type Equal": {
            "Number of Diagrams": len(type_idx_list),
            "Type Distribution": type_type_counts,
        },
        "Both Equal": {
            "Number of Diagrams": len(both_idx_list),
            "Domain Distribution": both_domain_counts,
            "Type Distribution": both_type_counts,
        },
    }
    statistics_file = os.path.join(output_dir, f"tag_equal_({args.method_1})({args.method_2}).json")
    with open(statistics_file, "w", encoding="utf-8") as f:
        json.dump(statistics_data, f, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compare the Domain/Type tags produced by two tagging methods.'
    )
    # Both methods are mandatory: a missing value would otherwise reach
    # os.path.join as None and fail with an unhelpful TypeError.
    parser.add_argument('--method_1', type=str, required=True,
                        help='sub-directory of ../../output/tagging for the first method')
    parser.add_argument('--method_2', type=str, required=True,
                        help='sub-directory of ../../output/tagging for the second method')
    args = parser.parse_args()
    run(args)

