import os
import json
import torch
import numpy as np
import argparse
from tqdm import tqdm
import pandas as pd
from collections import Counter


def get_data(base_dir):
    """Load and concatenate every ``chunks/chunk_<i>.json`` file under *base_dir*.

    Chunk files are read in ascending numeric order of ``<i>`` (so ``chunk_10``
    follows ``chunk_9``). The JSON content of each file is appended to the
    result with ``+=``. Entries in the ``chunks`` directory that are not named
    exactly ``chunk_<i>.json`` (e.g. ``.DS_Store``, editor backups) are ignored
    instead of being counted as chunks.

    Args:
        base_dir: Directory that contains a ``chunks`` sub-directory.

    Returns:
        A list holding the concatenated contents of all chunk files, or an
        empty list when there are no chunk files.

    Raises:
        FileNotFoundError: If the ``chunks`` directory does not exist, or if
            the chunk indices are not contiguous from 0 (a chunk is missing).
    """
    chunk_dir = os.path.join(base_dir, "chunks")

    chunk_indices = []
    for name in os.listdir(chunk_dir):
        if not (name.startswith("chunk_") and name.endswith(".json")):
            continue
        idx = name[len("chunk_"):-len(".json")]
        # Accept only canonical names (no leading zeros, decimal digits) so each
        # index maps back to exactly the file name that was listed.
        if idx.isdecimal() and name == f"chunk_{int(idx)}.json":
            chunk_indices.append(int(idx))
    chunk_indices.sort()

    all_chunk_data = []
    for expected_idx, chunk_idx in enumerate(chunk_indices):
        if chunk_idx != expected_idx:
            # A gap means a chunk is missing; fail instead of silently
            # returning partial data.
            raise FileNotFoundError(
                os.path.join(chunk_dir, f"chunk_{expected_idx}.json")
            )
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        with open(file_name, "r", encoding="utf-8") as f:
            all_chunk_data += json.load(f)
    return all_chunk_data


def _categorize_tags(data, selected_domains, selected_types):
    """Turn tagged items into ``{'Domain': ..., 'Type': ...}`` records.

    Items whose ``'tag'`` is ``None`` are skipped. A Domain or Type value not in
    the corresponding selected list is replaced by ``"Other"``.
    """
    records = []
    for item in data:
        tag = item['tag']
        if tag is None:
            continue
        domain = tag['Domain'] if tag['Domain'] in selected_domains else "Other"
        chart_type = tag['Type'] if tag['Type'] in selected_types else "Other"
        records.append({'Domain': domain, 'Type': chart_type})
    return records


def _build_cross_tab(records, selected_domains, selected_types):
    """Count Domain x Type pairs in a fixed row/column order with Total margins.

    Rows are ``selected_domains + ["Other", "Total"]`` and columns are
    ``selected_types + ["Other", "Total"]``. Missing combinations are 0. An empty
    *records* list yields an all-zero table instead of raising.
    """
    ordered_index = selected_domains + ["Other"]
    ordered_columns = selected_types + ["Other"]
    if records:
        df = pd.DataFrame(records)
        counts = pd.crosstab(df['Domain'], df['Type'])
        counts = counts.reindex(index=ordered_index, columns=ordered_columns, fill_value=0)
    else:
        # pd.DataFrame([]) has no 'Domain'/'Type' columns, so build the zero table directly.
        counts = pd.DataFrame(0, index=ordered_index, columns=ordered_columns)
    # Margins are added after reindexing, so they already include the zero-filled cells.
    counts['Total'] = counts.sum(axis=1)
    counts.loc['Total'] = counts.sum()
    return counts


def run(args):
    """Write a Domain x Type cross-tabulation CSV for one tagging method.

    Reads the chunked tagging output from
    ``<cwd>/../../output/tagging/<args.method>`` via ``get_data``. Unselected
    domains and types are bucketed as "Other". The counts table, with Total
    row and column, is written to
    ``<cwd>/../../output/tagging/analysis/tag_cross/cross_(<args.method>).csv``
    (UTF-8 with BOM).

    Args:
        args: Namespace whose ``method`` attribute names the tagging output
            sub-directory.

    Returns:
        The path of the written CSV file.
    """
    # Both paths are resolved relative to the current working directory.
    output_dir = os.path.join(os.getcwd(), "../../output/tagging/analysis/tag_cross")
    os.makedirs(output_dir, exist_ok=True)

    data = get_data(os.path.join(os.getcwd(), "../../output/tagging", args.method))
    selected_types = ["Bar Chart", "Line Graph", "Map", "Network Chart"]
    selected_domains = ["Astronomy", "Biology", "Chemistry", "Health Science", "History", "Mathematics", "Music", "Physics"]

    records = _categorize_tags(data, selected_domains, selected_types)
    cross_tab = _build_cross_tab(records, selected_domains, selected_types)

    cross_file = os.path.join(output_dir, f"cross_({args.method}).csv")
    cross_tab.to_csv(cross_file, encoding="utf-8-sig", index_label="")

    return cross_file


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Cross-tabulate tagged Domain vs. Type counts for one tagging method.'
    )
    # Required: run() joins args.method into a path, so None would crash with a TypeError.
    parser.add_argument(
        '--method',
        type=str,
        required=True,
        help='Tagging output sub-directory under ../../output/tagging (relative to the working directory).',
    )
    args = parser.parse_args()
    run(args)
