import os
import re
import json
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib.backends.backend_pdf import PdfPages
from fixed_variables import *
from utils import *


def number(triple_str):
    """Return the number of newline characters in ``triple_str``.

    This is the per-sample statistic selected by the name ``"number"``.
    NOTE(review): if each triple sits on its own line, this equals the triple
    count only when the text ends with a trailing newline -- confirm the
    serialization of ``'triples'``.
    """
    # Splitting on "\n" yields one more piece than there are separators.
    pieces = triple_str.split("\n")
    return len(pieces) - 1


def get_modality_stat(func_name, modality):
    """Tally per-question correct/incorrect counts grouped by a sample statistic.

    For every model in ``full_model_list`` this reads
    ``comparison/<last path component of model>/<modality>.csv``. For every
    question in ``questions`` it counts how many test samples were answered
    correctly (``'1'``) or incorrectly (``'0'``). Samples are grouped by the
    value that the module-level function named ``func_name`` returns for the
    sample's ``'triples'`` entry.

    Args:
        func_name: Name of a module-level function (e.g. ``"number"``) that
            maps a sample's ``'triples'`` to a sortable bucket value.
        modality: Stem of the comparison CSV to read (e.g. ``"real"``).

    Returns:
        ``{question: {str(value): {'1': n_correct, '0': n_wrong}}}``, with the
        inner keys ordered by ascending (un-stringified) value.

    Raises:
        ValueError: if ``func_name`` does not name a callable in this module.
    """
    # Plain name lookup instead of eval(): no arbitrary code execution.
    stat_func = globals().get(func_name)
    if not callable(stat_func):
        raise ValueError(f"Unknown statistic function: {func_name!r}")

    with open("pkl/data_test.pkl", 'rb') as file:
        # NOTE(review): pickle can execute arbitrary code; only load trusted files.
        data_test = pickle.load(file)

    # The statistic depends only on the sample, so compute it once rather than
    # once per (model, question, sample).
    # Optional range filter (previously commented out): skip values < 5 or > 20.
    sample_vals = [stat_func(pkl_item['triples']) for pkl_item in data_test]

    question_stat = defaultdict(lambda: defaultdict(lambda: {'1': 0, '0': 0}))
    for model in tqdm(full_model_list):
        model_self = model.split("/")[-1]
        comparison_df = pd.read_csv(os.path.join("comparison", model_self, f"{modality}.csv"))
        for question in questions:
            # Row idx of the CSV is assumed to correspond to sample idx of data_test.
            column = comparison_df[question]
            for idx, val in enumerate(sample_vals):
                correctness = int(column.iloc[idx])
                question_stat[question][val][str(correctness)] += 1
    return {
        question: {str(val): stat_dict for val, stat_dict in sorted(question_stat[question].items(), key=lambda x: x[0])}
        for question in question_stat
    }


def load_stat(dir, func_name, modality):
    """Return statistics from the JSON cache, computing and caching them on a miss.

    The cache file is ``<dir>/<func_name>/all_<modality>.json``. On a miss the
    statistics come from ``get_modality_stat(func_name, modality)`` and are
    written to that file before being returned.

    Args:
        dir: Cache root directory. The name shadows the builtin ``dir`` but is
            kept for backward compatibility with keyword callers.
        func_name: Statistic name; also the cache subdirectory name.
        modality: Modality name used in the cache file name.

    Returns:
        The statistics dict, as loaded from or written to the JSON file.
    """
    func_name_dir = os.path.join(dir, func_name)
    stat_file = os.path.join(func_name_dir, f"all_{modality}.json")
    if os.path.exists(stat_file):
        with open(stat_file, "r") as f:
            return json.load(f)
    # Create the per-statistic subdirectory, not just the root: the cache file
    # is written inside it. Creating only ``dir`` made the open() below fail
    # whenever ``dir`` existed but ``dir/func_name`` did not.
    os.makedirs(func_name_dir, exist_ok=True)
    stat = get_modality_stat(func_name, modality)
    with open(stat_file, "w") as f:
        json.dump(stat, f, indent=4)
    return stat


def plot_number_scatter(func_name, modality, colors):
    """Build a 2x2 scatter figure of accuracy versus statistic value.

    Each panel corresponds to one entry of ``questions``. For every statistic
    value in the cached stats, it plots the fraction of correct answers
    ``n_correct / (n_correct + n_wrong)``.

    Args:
        func_name: Statistic name used to locate or compute the cached stats.
        modality: Modality name, forwarded to ``load_stat``.
        colors: Per-panel colours, indexed in the same order as ``questions``.

    Returns:
        The matplotlib ``Figure``; the caller is responsible for closing it.
    """
    stat = load_stat("cache/overfit/triple", func_name, modality)

    # Compute accuracy points for every question before building the figure.
    points = {}
    for question in questions:
        points[question] = [
            (int(value), counts["1"] / (counts["1"] + counts["0"]))
            for value, counts in stat[question].items()
        ]

    fig, grid = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=True)
    panels = grid.flatten()
    y_ticks = [round(0.45 + 0.05 * step, 2) for step in range(12)]
    for i, question in enumerate(questions):
        pts = points[question]
        xs = [x for x, _ in pts]
        ys = [y for _, y in pts]
        ax = panels[i]
        ax.scatter(xs, ys, color=colors[i])
        ax.grid(True)
        ax.set_title(question)
        ax.set_ylim(0.45, 1)
        ax.set_yticks(y_ticks)
    fig.suptitle(f"Triple Number - {modality}", fontsize=16, y=0.92)
    # Leave the top 10% of the figure free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.9])
    return fig


def plot_number_confidence(func_name, modality, colors):
    """Plot binned accuracy (mean line plus +/- std band) per question on a 2x2 grid.

    For each question, the cached stats are passed to ``compute_metrics``.
    Its Pearson, Spearman and MI results appear as legend text. The stats are
    then grouped by ``get_bins``. Each bin is drawn at the mean of its values,
    with the band given by ``mean_std`` on the bin's correct/incorrect counts.

    Args:
        func_name: Statistic name used to locate or compute the cached stats.
        modality: Modality name, forwarded to ``load_stat``.
        colors: Per-panel colours, indexed in the same order as ``questions``.

    Returns:
        The matplotlib ``Figure``; the caller is responsible for closing it.
    """
    stat = load_stat("cache/overfit/triple", func_name, modality)
    x_min, x_max = 5, 20
    # Same tick spacing as before (x_max // 10), but keep only ticks inside the
    # visible window. set_xticks expands the view limits to include every tick,
    # which previously overrode set_xlim(5, 20) and displayed 0..20.
    x_ticks = [tick for tick in range(0, x_max + 1, x_max // 10) if tick >= x_min]
    y_ticks = [round(0.1 * i, 2) for i in range(16)]

    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=True)
    axes = axes.flatten()
    for idx, question in enumerate(questions):
        (pearson, pearson_p), (spearman, spearman_p), mi = compute_metrics(stat[question])
        legend_labels = [
            f"Pearson = {pearson:.4f}, p-value = {pearson_p:.4f}",
            f"Spearman = {spearman:.4f}, p-value = {spearman_p:.4f}",
            f"MI = {mi:.4f}"
        ]
        # NOTE(review): each bin item is indexed as (value, correctness) below;
        # that layout is inferred from usage -- confirm against utils.get_bins.
        bins, _ = get_bins(stat[question])
        bin_x, bin_mean, bin_std = [], [], []
        for b in bins:
            cnt_1 = sum(int(x[1] == 1) for x in b)
            cnt_0 = sum(int(x[1] == 0) for x in b)
            # Presumably a mean/std of correctness derived from the counts -- see utils.mean_std.
            mean, std = mean_std(cnt_1, cnt_0)
            bin_x.append(np.mean([x[0] for x in b]))
            bin_mean.append(mean)
            bin_std.append(std)
        bin_x = np.array(bin_x)
        bin_mean = np.array(bin_mean)
        bin_std = np.array(bin_std)

        ax = axes[idx]
        ax.plot(bin_x, bin_mean, marker='o', color=colors[idx], linewidth=2)
        ax.fill_between(bin_x, bin_mean - bin_std, bin_mean + bin_std, color=colors[idx], alpha=0.3)
        ax.set_title(question)
        ax.set_xticks(x_ticks)
        # Set limits after ticks so the intended window always wins.
        ax.set_xlim(x_min, x_max)
        ax.set_yticks(y_ticks)
        ax.set_ylim(0, 1.5)
        ax.grid(True)
        # Invisible handles so the correlation metrics appear as legend text.
        for label in legend_labels:
            ax.plot([], [], ' ', label=label)
        ax.legend(loc="upper left", fontsize=9, frameon=False)
    fig.suptitle(f"Triple Number - {modality}", fontsize=16, y=0.92)
    return fig


def e_triple(func_name, plot_type):
    """Render the real and synthetic plots for one statistic into one PDF.

    The plotting function ``plot_<func_name>_<plot_type>`` (for example
    ``plot_number_scatter``) is looked up among this module's globals. It is
    called once for ``"real"`` and once for ``"synthetic"``, and each figure
    becomes one page of
    ``evaluation/overfit/e_triple_<func_name>_(<plot_type>).pdf``.

    Args:
        func_name: Statistic name (e.g. ``"number"``).
        plot_type: Plot flavour suffix (e.g. ``"scatter"``, ``"confidence"``).

    Raises:
        ValueError: if no plotting function with the derived name exists.
    """
    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
    plot_name = f"plot_{func_name}_{plot_type}"
    # Plain name lookup instead of eval(). Validate before opening the PDF so
    # a bad name does not leave an empty file behind.
    plot_func = globals().get(plot_name)
    if not callable(plot_func):
        raise ValueError(f"Unknown plot function: {plot_name!r}")

    out_path = f"evaluation/overfit/e_triple_{func_name}_({plot_type}).pdf"
    # PdfPages does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with PdfPages(out_path) as pdf:
        for modality in ("real", "synthetic"):
            fig = plot_func(func_name, modality, colors)
            try:
                pdf.savefig(fig)
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)


if __name__ == '__main__':
    # Produce both plot flavours (scatter first, then confidence) for every
    # statistic listed here.
    for metric_name in ("number",):
        for flavour in ("scatter", "confidence"):
            e_triple(metric_name, flavour)

