# Standard-library, plotting, and data-analysis imports
import sys, os, shutil
import copy
import re
import json
from tqdm import tqdm
from enum import Enum, auto
from pathlib import Path

from upsetplot import UpSet, plot, generate_counts, from_indicators, from_memberships
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from collections import Counter


def plot_upset(data, save_path, min_subset_size=2):
    """Draw an UpSet plot of topic co-occurrence and save it as a PNG.

    Each item contributes one row of boolean topic indicators. Rows are
    grouped by their exact topic combination and the group sizes are plotted.

    Args:
        data: Iterable of dicts. Topics are read from
            ``item["meat_data"]["topic"]``; an item missing either key is
            treated as having no topics (the same lookup ``count_topic`` uses).
        save_path: Directory that receives ``topic_upset.png``. It is created
            if it does not exist yet.
        min_subset_size: Smallest combination size that is drawn; forwarded
            to ``upsetplot.plot``.

    Raises:
        ValueError: If no item in ``data`` carries any topic, since there is
            then nothing to group or plot.
    """
    # Read every item's topics once; sets make the membership tests below O(1).
    topic_sets = [set(item.get("meat_data", {}).get("topic", [])) for item in data]
    all_topics = sorted(set().union(*topic_sets))
    if not all_topics:
        # Grouping by an empty key list fails deep inside pandas; fail early
        # with a message that names the actual problem.
        raise ValueError("no topics found in data; nothing to plot")

    records = [{topic: topic in topics for topic in all_topics} for topics in topic_sets]
    df_boolean = pd.DataFrame(records)
    # Echo the indicator table and the raw rows for inspection.
    print(df_boolean)
    print(records)

    # One entry per distinct topic combination, valued by how many items have it.
    upset_data = df_boolean.groupby(all_topics).size()
    plot(
        upset_data,
        sort_by="cardinality",
        min_subset_size=min_subset_size,
        facecolor="darkblue",
    )
    plt.suptitle("Topic Distribution", fontsize=20)

    # An empty save_path means "current directory", which needs no creation.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    plt.savefig(os.path.join(save_path, "topic_upset.png"), bbox_inches="tight")
    plt.close()


def count_topic(data):
    """Count how often each topic occurs and print a short summary.

    Prints one ``- <topic>: <count>`` line per topic, most frequent first,
    followed by the average number of topics per item.

    Args:
        data: Iterable of dicts. Topics are read from
            ``item["meat_data"]["topic"]``; an item missing either key
            contributes no topics but still counts towards the average.

    Returns:
        A ``collections.Counter`` mapping each topic to its occurrence count.
    """
    # Read each item's topic list exactly once so that one-shot iterables work
    # and the nested lookup is not repeated.
    topic_lists = [item.get("meat_data", {}).get("topic", []) for item in data]

    topic_counts = Counter(topic for topics in topic_lists for topic in topics)
    for topic, count in topic_counts.most_common():
        print(f"- {topic}: {count}")

    total_topics = sum(len(topics) for topics in topic_lists)
    # With no items the average is undefined; report 0.0 instead of dividing by zero.
    average = total_topics / len(topic_lists) if topic_lists else 0.0
    print(average)
    return topic_counts


if __name__ == "__main__":
    # Location of the processed annotations and the folder for generated figures.
    data_path = "project/chartqa/data/label_studio/data_annotated/processed_annotations/all_annotations.json"
    save_path = "project/chartqa/src/analysis/pictures"

    # Load the whole annotation file up front, then run the analysis on it.
    annotations = json.loads(Path(data_path).read_text(encoding="utf-8"))

    # The UpSet plot is currently disabled; enable it with:
    # plot_upset(annotations, save_path)
    count_topic(data=annotations)
