import os
import re
import json
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib.backends.backend_pdf import PdfPages
from fixed_variables import *
from utils import *


def number(triple_str):
    """Return how many newline characters ``triple_str`` contains.

    Used as a per-sample statistic by ``get_modality_stat`` (looked up by the
    name ``"number"``). Presumably triples are newline-separated, so this
    tracks the triple count — TODO confirm whether the last triple carries a
    trailing newline (otherwise the result is count - 1).
    """
    separator = "\n"
    return triple_str.count(separator)


def get_modality_stat(func_name, model_self, modality):
    """Count correct/incorrect answers per question, grouped by a per-sample statistic.

    Loads the test samples from ``pkl/data_test.pkl`` and the per-sample
    correctness table from ``comparison/<model_self>/<modality>.csv``. For every
    question in the module-level ``questions`` list, it tallies how many
    samples were answered correctly (``'1'``) or incorrectly (``'0'``) for each
    value of ``func_name(item['triples'])``.

    Args:
        func_name: Name of a function visible in this module's globals
            (e.g. ``"number"``) that maps ``item['triples']`` to a sortable key.
        model_self: Sub-directory of ``comparison/`` holding the model's CSVs.
        modality: CSV file stem inside that directory.

    Returns:
        ``{question: {str(value): {'1': n_correct, '0': n_incorrect}}}`` with the
        inner keys ordered by ascending (un-stringified) value. Returns ``{}``
        when the pickle holds no samples.

    Raises:
        NameError: if ``func_name`` is not defined in this module's globals.
    """
    with open("pkl/data_test.pkl", 'rb') as file:
        # NOTE(review): pickle.load can execute arbitrary code; this must only
        # ever read a trusted, locally produced file.
        data_test = pickle.load(file)

    comparison_df = pd.read_csv(os.path.join("comparison", model_self, f"{modality}.csv"))

    # Resolve the statistic function once (instead of eval-ing the name for
    # every question/sample pair); keep NameError for unknown names like eval did.
    try:
        stat_func = globals()[func_name]
    except KeyError:
        raise NameError(f"name {func_name!r} is not defined") from None

    # The statistic depends only on the sample, not on the question.
    values = [stat_func(pkl_item['triples']) for pkl_item in data_test]
    if not values:
        return {}

    question_stat = defaultdict(lambda: defaultdict(lambda: {'1': 0, '0': 0}))
    for question in questions:
        # NOTE(review): assumes CSV row i corresponds to data_test[i] — confirm
        # the comparison CSVs are written in pickle order.
        column = comparison_df[question]
        for idx, val in enumerate(values):
            correctness = int(column.iloc[idx])
            question_stat[question][val][str(correctness)] += 1
    return {
        question: {str(val): stat_dict for val, stat_dict in sorted(question_stat[question].items(), key=lambda x: x[0])}
        for question in question_stat
    }


def load_stat(dir, func_name, model, modality):
    """Return per-question stats for ``model``/``modality``, using a JSON cache.

    The cache file is ``<dir>/<func_name>/<model_map[model]>_<modality>.json``.
    On a hit it is loaded and returned. On a miss the stats are computed with
    ``get_modality_stat(func_name, <last path component of model>, modality)``,
    written to the cache, and returned.

    Args:
        dir: Cache root directory (name kept for caller compatibility even
            though it shadows the ``dir`` builtin).
        func_name: Statistic function name; also used as a cache sub-directory.
        model: Model identifier; must be a key of ``model_map``.
        modality: Modality name (e.g. ``"real"``).

    Returns:
        The stats dict as produced by ``get_modality_stat``.
    """
    func_name_dir = os.path.join(dir, func_name)
    stat_file = os.path.join(func_name_dir, f"{model_map[model]}_{modality}.json")
    if os.path.exists(stat_file):
        with open(stat_file, "r") as f:
            return json.load(f)

    stat = get_modality_stat(func_name, model.split("/")[-1], modality)
    os.makedirs(func_name_dir, exist_ok=True)
    # Write to a temp file and swap it in atomically so an interrupted dump
    # never leaves a truncated JSON that later runs would treat as a valid cache.
    tmp_file = stat_file + ".tmp"
    with open(tmp_file, "w") as f:
        json.dump(stat, f, indent=4)
    os.replace(tmp_file, stat_file)
    return stat


def plot_number_scatter(func_name, model, colors):
    """Scatter per-value accuracy for each question on a 2x2 grid of subplots.

    For every value of the statistic, the y-coordinate is the fraction of
    correct answers (``'1'`` count / total). Values with a total of at most 2
    samples are skipped. When at least two points remain, Pearson and
    Spearman correlations (with p-values) are printed.

    Args:
        func_name: Statistic name used to locate/compute the cached stats.
        model: Model identifier (key of ``model_map``).
        colors: Per-question colors, indexed in ``questions`` order.

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=True, sharey=True)
    axes = axes.flatten()
    print(f"\n### {model_map[model]}: ")
    # The cache entry does not depend on the question: load it once, not per subplot.
    stat = load_stat("cache/overfit/triple", func_name, model, "real")
    x_min, x_max = 0, 30
    for idx, question in enumerate(questions):
        print(f"# {question}: ")
        ax = axes[idx]
        x, y = [], []
        for val_str, count_dict in stat[question].items():
            num_1 = count_dict["1"]
            num_0 = count_dict["0"]
            total = num_1 + num_0
            # Presumably drops values with too few samples for a meaningful ratio.
            if total <= 2:
                continue
            x.append(int(val_str))
            y.append(num_1 / total)
        ax.scatter(x, y, color=colors[idx])
        ax.set_title(f"{question}")
        ax.grid(True)
        ax.set_xlim(x_min, x_max)
        ax.set_xticks([i for i in range(0, x_max + 1, max(1, x_max // 10))])
        ax.set_ylim(0, 1.05)
        # Correlations are undefined for fewer than two points.
        if len(x) >= 2 and len(y) >= 2:
            pearson, pearson_p = pearsonr(x, y)
            spearman, spearman_p = spearmanr(x, y)
            print(f"Pearson:{pearson:.4f} ({pearson_p:.4f})\tSpearman:{spearman:.4f} ({spearman_p:.4f})")

    fig.suptitle(f"{model_map[model]}", fontsize=16, y=0.94)
    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout(rect=[0, 0, 1, 0.92])
    return fig


def plot_number_confidence(func_name, model, colors):
    """Plot binned accuracy with a ±std band for each question on a 2x2 grid.

    The stats are grouped with ``get_bins``. Each bin is plotted at the mean
    of its values, with the ``(mean, std)`` returned by
    ``mean_std(cnt_1, cnt_0)``. The legend shows the Pearson, Spearman and MI
    values from ``compute_metrics``, and the same correlations are printed.

    Args:
        func_name: Statistic name used to locate/compute the cached stats.
        model: Model identifier (key of ``model_map``).
        colors: Per-question colors, indexed in ``questions`` order.

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=True, sharey=True)
    axes = axes.flatten()
    print(f"\n### {model_map[model]}: ")
    # The cache entry does not depend on the question: load it once, not per subplot.
    stat = load_stat("cache/overfit/triple", func_name, model, "real")
    for idx, question in enumerate(questions):
        print(f"# {question}: ")
        ax = axes[idx]
        (pearson, pearson_p), (spearman, spearman_p), mi = compute_metrics(stat[question])
        print(f"Pearson:{pearson:.4f} ({pearson_p:.4f})\tSpearman:{spearman:.4f} ({spearman_p:.4f})")
        legend_labels = [
            f"Pearson = {pearson:.4f}, p-value = {pearson_p:.4f}",
            f"Spearman = {spearman:.4f}, p-value = {spearman_p:.4f}",
            f"MI = {mi:.4f}"
        ]
        x_min, x_max = 5, 20
        # NOTE(review): each bin appears to be a sequence of (value, correctness)
        # pairs — confirm against get_bins in utils.
        bins, _ = get_bins(stat[question])
        bin_x, bin_mean, bin_std = [], [], []
        for b in bins:
            cnt_1 = sum(int(x[1] == 1) for x in b)
            cnt_0 = sum(int(x[1] == 0) for x in b)
            mean, std = mean_std(cnt_1, cnt_0)
            center_val = np.mean([x[0] for x in b])
            bin_x.append(center_val)
            bin_mean.append(mean)
            bin_std.append(std)
        bin_x = np.array(bin_x)
        bin_mean = np.array(bin_mean)
        bin_std = np.array(bin_std)
        ax.plot(bin_x, bin_mean, marker='o', color=colors[idx], linewidth=2)
        ax.fill_between(bin_x, bin_mean - bin_std, bin_mean + bin_std, color=colors[idx], alpha=0.3)
        ax.set_title(f"{question}")
        ax.set_xlim(x_min, x_max)
        ax.set_xticks([i for i in range(0, x_max + 1, max(1, x_max // 10))])
        ax.set_ylim(0, 1.5)
        ax.grid(True)
        ax.set_yticks([round(0.2 * i, 2) for i in range(9)])
        # Invisible artists so the metric strings show up as legend entries.
        for label in legend_labels:
            ax.plot([], [], ' ', label=label)
        ax.legend(loc="upper left", fontsize=9, frameon=False)
    fig.suptitle(f"{model_map[model]}", fontsize=16, y=0.94)
    return fig


def e_triple(func_name, plot_type):
    """Render one page per model into a multi-page PDF.

    Each page is produced by the function ``plot_<func_name>_<plot_type>``
    (e.g. ``plot_number_confidence``). The output is written to
    ``evaluation/overfit/e_triple_<func_name>_sep_(<plot_type>).pdf``.

    Args:
        func_name: Statistic name (e.g. ``"number"``).
        plot_type: Plot variant (e.g. ``"scatter"`` or ``"confidence"``).

    Raises:
        NameError: if no ``plot_<func_name>_<plot_type>`` function exists.
    """
    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
    # Resolve the plot function once, before creating the PDF, instead of
    # eval-ing the name for every model.
    plot_name = f"plot_{func_name}_{plot_type}"
    try:
        plot_func = globals()[plot_name]
    except KeyError:
        raise NameError(f"name {plot_name!r} is not defined") from None

    out_path = f"evaluation/overfit/e_triple_{func_name}_sep_({plot_type}).pdf"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with PdfPages(out_path) as pdf:
        for model in full_model_list:
            fig = plot_func(func_name, model, colors)
            try:
                pdf.savefig(fig)
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)


if __name__ == '__main__':
    # Statistic names to evaluate; each needs matching plot_<name>_<type> functions.
    metrics = ["number"]
    # Only the binned "confidence" view is produced; add "scatter" here to
    # also render the per-value scatter PDFs.
    plot_types = ("confidence",)
    for metric in metrics:
        for plot_type in plot_types:
            e_triple(metric, plot_type)

