import json
import re
import matplotlib.pyplot as plt
import collections
import matplotlib
import seaborn as sns
from collections import defaultdict
import os


matplotlib.use("Agg")
matplotlib.rcParams.update({'font.family': 'Times New Roman'})
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

sns.set_theme(style="ticks", font="Times New Roman", font_scale=2.1, rc={'grid.linestyle': ':', 'axes.grid': True})

for dataset in ["mimic_iii", "eicu"]:
    with open(f"<YOUR_DATASET_PATH>", 'r') as f_in, \
            open(f"<YOUR_SAVED_DATASET_PATH>", "w") as f_out:
        list_num_q_tag_var = []
        list_num_tables = []
        list_num_columns = []
        num_q_tag_var_dict = defaultdict(list)
        num_tables_dict = defaultdict(list)
        num_columns_dict = defaultdict(list)
        for lines in f_in:
            x = json.loads(lines)

            # number of variables in q_tag
            num_q_tag_var = x["q_tag"].count("{") + x["q_tag"].count("[")

            # number of different tables used in the sql
            # (we only count the tables in the original datasets, but not the new ones created in the sql)
            tables = re.findall(r'\bfrom\s+(\w+)\b', x["query"])
            num_tables = len(set(tables))

            # number of different columns used in the sql
            # (we only count the columns in the original datasets, but not the new ones created in the sql)
            columns = re.findall(r'\b\w*\.\w*\b', x["query"])
            columns = [item for item in columns if not re.match(r'^t\d', item)]
            num_columns = len(set(columns))

            x["num_q_tag_var"] = num_q_tag_var
            x["num_tables"] = num_tables
            x["num_columns"] = num_columns

            f_out.write(json.dumps(x) + "\n")

            list_num_q_tag_var.append(num_q_tag_var)
            list_num_tables.append(num_tables)
            list_num_columns.append(num_columns)

            num_q_tag_var_dict[num_q_tag_var].append(x["id"])
            num_tables_dict[num_tables].append(x["id"])
            num_columns_dict[num_columns].append(x["id"])

    os.makedirs(f"{dataset}/num_q_tag_var/", exist_ok=True)
    for k, id_list in num_q_tag_var_dict.items():
        with open(f"{dataset}/num_q_tag_var/{k}.jsonl", "w") as f:
            json.dump(id_list, f)
    os.makedirs(f"{dataset}/num_tables/", exist_ok=True)
    for k, id_list in num_tables_dict.items():
        with open(f"{dataset}/num_tables/{k}.jsonl", "w") as f:
            json.dump(id_list, f)
    os.makedirs(f"{dataset}/num_columns/", exist_ok=True)
    for k, id_list in num_columns_dict.items():
        with open(f"{dataset}/num_columns/{k}.jsonl", "w") as f:
            json.dump(id_list, f)


    # plot the distribution
    all_list = [list_num_q_tag_var, list_num_tables, list_num_columns]
    xlabel = ["# q_tag variables", "# tables", "# columns"]
    file_name = ["num_q_tag_ver_distri.pdf", "num_tables_distri.pdf", "num_columns_distri.pdf"]

    os.makedirs(f"{dataset}/figures/", exist_ok=True)
    for i in range(len(all_list)):
        c = collections.Counter(all_list[i])
        c = sorted(c.items())
        plt.figure(figsize=(6, 5.5), dpi=120)
        ax = sns.barplot(x=[i[0] for i in c], y=[i[1] for i in c], color=sns.color_palette()[0])
        plt.ylabel("Frequency", size=30)
        plt.xlabel(xlabel[i], size=30)
        plt.tight_layout(rect=[-0.05, -0.05, 1.05, 1.05])
        plt.savefig(f"{dataset}/figures/{file_name[i]}")
