import os
import re
import json
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib.backends.backend_pdf import PdfPages
from fixed_variables import *
from utils import *


def length(response):
    """Return the number of whitespace-separated tokens in ``response``."""
    words = response.split()
    return len(words)

def lexical_diversity(response):
    """Return how many distinct lower-cased whitespace tokens ``response`` has."""
    vocabulary = {word for word in response.lower().split()}
    return len(vocabulary)

def sentence_complexity(response, long_length=20):
    """Measure sentence-length statistics of ``response``.

    The text is passed through ``nlp`` (not defined in this block; presumably
    a spaCy pipeline brought in by the star-imports -- TODO confirm) and the
    resulting ``doc.sents`` are inspected.

    Args:
        response: Text to analyse.
        long_length: A sentence counts as "long" when ``len(sent)`` is strictly
            greater than this value.

    Returns:
        A tuple ``(longest, long_count)``: the largest number of tokens found
        in any sentence (0 when there are no sentences) and the number of
        sentences longer than ``long_length``.
    """
    doc = nlp(response)

    # Largest per-sentence token count; starts at 0 so an empty document
    # yields 0.
    longest = 0
    for sent in doc.sents:
        n_tokens = sum(1 for _ in sent)
        if n_tokens > longest:
            longest = n_tokens

    long_count = len([sent for sent in doc.sents if len(sent) > long_length])
    return longest, long_count

def ambiguity(response):
    """Count occurrences of hedging phrases in ``response``.

    Matching is case-insensitive and substring-based (``str.count``), so a
    phrase embedded in a longer word is counted as well.
    """
    lowered = response.lower()
    total = 0
    for phrase in ("maybe", "possibly", "could be", "might be", "it seems", "probably"):
        total += lowered.count(phrase)
    return total

def redundancy(response):
    """Return the fraction of whitespace tokens that are repeats.

    Computed as ``1 - distinct_tokens / total_tokens`` (case-sensitive).
    Returns 0 when ``response`` contains no tokens.
    """
    tokens = response.split()
    if not tokens:
        return 0
    distinct = len(set(tokens))
    return 1 - (distinct / len(tokens))

def repetitions(response):
    """Count sentences that occur more than once in ``response``.

    The text is split on the literal ``". "``; every piece whose exact text
    appears at least twice contributes one to the result (so a sentence
    repeated three times contributes three).

    Note that because the split is on ``". "``, a final sentence keeps its
    trailing period and therefore does not compare equal to earlier copies
    of the same sentence.

    Args:
        response: Text to analyse.

    Returns:
        Total number of pieces that belong to a repeated group.
    """
    from collections import Counter

    sentences = response.split(". ")
    # One pass over the pieces instead of calling list.count() per piece.
    occurrences = Counter(sentences)
    return sum(n for n in occurrences.values() if n > 1)

def smaller_repetitions(response):
    """Count repeated clauses/sentences in ``response``.

    The text is split wherever one of ``, . ! ?`` or a newline is followed by
    whitespace and then an uppercase ASCII letter. Pieces are stripped and
    empty ones dropped; every remaining piece whose exact text appears at
    least twice contributes one to the result.

    Args:
        response: Text to analyse.

    Returns:
        Total number of pieces that belong to a repeated group.
    """
    from collections import Counter

    sentence_endings = r"[,.!?\n]\s+(?=[A-Z])"
    pieces = (piece.strip() for piece in re.split(sentence_endings, response))
    # One pass over the pieces instead of calling list.count() per piece.
    occurrences = Counter(piece for piece in pieces if piece)
    return sum(n for n in occurrences.values() if n > 1)


def get_data(base_dir):
    """Load and concatenate all chunk files under ``<base_dir>/chunks``.

    Chunks are read in index order from ``chunk_0.json`` up to
    ``chunk_<n-1>.json``, where ``n`` is the number of directory entries
    named ``chunk_<digits>.json``. Other entries in the directory (editor
    or OS artefacts, for instance) are ignored rather than being counted
    as chunks.

    Args:
        base_dir: Directory that contains the ``chunks`` sub-directory.

    Returns:
        A single list with the contents of every chunk, in chunk order. Each
        chunk file is expected to hold a JSON array.

    Raises:
        FileNotFoundError: If the ``chunks`` directory is missing, or the
            chunk indices are not contiguous starting at 0.
    """
    chunk_dir = os.path.join(base_dir, "chunks")
    # Count only real chunk files so stray entries cannot inflate the range.
    num_chunks = sum(
        1 for entry in os.listdir(chunk_dir)
        if re.fullmatch(r"chunk_\d+\.json", entry)
    )

    all_chunk_data = []
    for chunk_idx in range(num_chunks):
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        with open(file_name, "r", encoding="utf-8") as f:
            all_chunk_data += json.load(f)
    return all_chunk_data


def get_modality_stat(func_name, model_self, modality):
    """Tabulate correctness counts per metric value for each question.

    For every question in ``questions`` and every response item, the metric
    named ``func_name`` is applied to the item's ``analysis`` text, and the
    matching entry of the comparison CSV (same row position, column named
    after the question) is tallied under that metric value.

    Args:
        func_name: Name of a module-level metric function taking the analysis
            string (e.g. ``"length"``).
        model_self: Model directory name used under ``benchmark/`` and
            ``comparison/``.
        modality: Sub-directory name under ``benchmark/<model_self>/`` and
            base name of the comparison CSV.

    Returns:
        ``{question: {str(value): {'1': count, '0': count}}}`` with the inner
        mapping ordered by ascending metric value.

    Raises:
        NameError: If ``func_name`` does not name a callable in this module.
    """
    # Resolve the metric once by name instead of eval()-ing the string for
    # every (question, item) pair.
    metric_func = globals().get(func_name)
    if not callable(metric_func):
        raise NameError(f"name '{func_name}' is not defined")

    question_stat = defaultdict(lambda: defaultdict(lambda: {'1': 0, '0': 0}))
    response_data = get_data(os.path.join("benchmark", model_self, modality))
    comparison_df = pd.read_csv(os.path.join("comparison", model_self, f"{modality}.csv"))

    for question in questions:
        # Select the column once; rows are matched to responses by position.
        correctness_column = comparison_df[question]
        for idx, item in enumerate(response_data):
            val = metric_func(item['solution'][question]['analysis'])
            correctness = int(correctness_column.iloc[idx])
            question_stat[question][val][str(correctness)] += 1

    # Keys are stringified so the result has the same shape before and after
    # a JSON round-trip.
    return {
        question: {
            str(val): counts
            for val, counts in sorted(per_value.items(), key=lambda pair: pair[0])
        }
        for question, per_value in question_stat.items()
    }


def load_stat(dir, func_name, model, modality):
    """Return the statistics for a metric/model/modality, using a JSON cache.

    The cache file lives at
    ``<dir>/<func_name>/<model_map[model]>_<modality>.json``. When it exists
    it is loaded and returned; otherwise the statistics are computed with
    ``get_modality_stat``, written to that file and returned.

    Args:
        dir: Root directory of the cache.
        func_name: Metric function name; also the cache sub-directory.
        model: Model identifier, a key of ``model_map``. Its last
            ``/``-separated component is passed to ``get_modality_stat``.
        modality: Modality name forwarded to ``get_modality_stat``.

    Returns:
        The statistics dictionary (loaded from cache or freshly computed).
    """
    func_name_dir = os.path.join(dir, func_name)
    stat_file = os.path.join(func_name_dir, f"{model_map[model]}_{modality}.json")

    if os.path.exists(stat_file):
        with open(stat_file, "r") as f:
            return json.load(f)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(func_name_dir, exist_ok=True)
    stat = get_modality_stat(func_name, model.split("/")[-1], modality)
    with open(stat_file, "w") as f:
        json.dump(stat, f, indent=4)
    return stat


def plot_length_scatter(func_name, model, colors):
    """Scatter-plot the share of "1" outcomes against metric value, per question.

    One subplot of a 2x2 grid is drawn per entry of ``questions``. Each point
    is a metric value (parsed as ``int``) observed more than once, placed at
    the fraction of its observations tallied under ``"1"``. Pearson and
    Spearman correlations of the plotted points are printed when at least
    two points exist.

    Args:
        func_name: Metric name used to look up the cached statistics.
        model: Model identifier, a key of ``model_map``.
        colors: Sequence of matplotlib colours, indexed by question position.

    Returns:
        The matplotlib figure.
    """
    x_min, x_max = 0, 550
    x_ticks = list(range(0, x_max + 1, max(1, x_max // 10)))

    fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=True, sharey=True)
    axes = axes.flatten()
    print(f"\n### {model_map[model]}: ")

    # The statistics do not depend on the question, so load them once rather
    # than re-reading the cache file on every iteration.
    stat = load_stat("cache/overfit/response", func_name, model, "real")

    for idx, question in enumerate(questions):
        print(f"# {question}: ")
        ax = axes[idx]
        x, y = [], []
        for val_str, count_dict in stat[question].items():
            num_1 = count_dict["1"]
            num_0 = count_dict["0"]
            total = num_1 + num_0
            # Values seen at most once give no meaningful ratio; skip them.
            if total <= 1:
                continue
            x.append(int(val_str))
            y.append(num_1 / total)

        ax.scatter(x, y, color=colors[idx])
        ax.set_title(f"{question}")
        ax.grid(True)
        ax.set_xlim(x_min, x_max)
        ax.set_xticks(x_ticks)
        ax.set_ylim(0, 1.05)

        # x and y always have the same length, so one check suffices.
        if len(x) >= 2:
            pearson, pearson_p = pearsonr(x, y)
            spearman, spearman_p = spearmanr(x, y)
            print(f"Pearson:{pearson:.4f} ({pearson_p:.4f})\tSpearman:{spearman:.4f} ({spearman_p:.4f})")

    fig.suptitle(f"{model_map[model]}", fontsize=16, y=0.94)
    # Leave room at the top of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.92])
    return fig


def plot_length_confidence(func_name, model, colors):
    """Plot binned mean +/- std of outcomes against metric value, per question.

    One subplot of a 2x2 grid is drawn per entry of ``questions``. The
    question's statistics are binned with ``get_bins``; for each bin the
    mean and spread come from ``mean_std`` applied to the bin's counts of
    1-labelled and 0-labelled entries, and the x position is the mean of the
    bin's values. Correlation figures from ``compute_metrics`` are printed
    and shown in the legend.

    Args:
        func_name: Metric name used to look up the cached statistics.
        model: Model identifier, a key of ``model_map``.
        colors: Sequence of matplotlib colours, indexed by question position.

    Returns:
        The matplotlib figure.
    """
    x_min, x_max = 0, 550
    x_ticks = list(range(0, x_max + 1, max(1, x_max // 10)))
    y_ticks = [round(0.2 * i, 2) for i in range(9)]

    fig, axes = plt.subplots(2, 2, figsize=(16, 10), sharex=True, sharey=True)
    axes = axes.flatten()
    print(f"\n### {model_map[model]}: ")

    # The statistics do not depend on the question, so load them once rather
    # than re-reading the cache file on every iteration.
    stat = load_stat("cache/overfit/response", func_name, model, "real")

    for idx, question in enumerate(questions):
        print(f"# {question}: ")
        ax = axes[idx]
        (pearson, pearson_p), (spearman, spearman_p), mi = compute_metrics(stat[question])
        print(f"Pearson:{pearson:.4f} ({pearson_p:.4f})\tSpearman:{spearman:.4f} ({spearman_p:.4f})")
        legend_labels = [
            f"Pearson = {pearson:.4f}, p-value = {pearson_p:.4f}",
            f"Spearman = {spearman:.4f}, p-value = {spearman_p:.4f}",
            f"MI = {mi:.4f}"
        ]

        # Each bin is iterated as (value, label) pairs -- presumably one pair
        # per response; confirm against get_bins.
        bins, _ = get_bins(stat[question])
        bin_x, bin_mean, bin_std = [], [], []
        for b in bins:
            cnt_1 = sum(int(point[1] == 1) for point in b)
            cnt_0 = sum(int(point[1] == 0) for point in b)
            mean, std = mean_std(cnt_1, cnt_0)
            bin_x.append(np.mean([point[0] for point in b]))
            bin_mean.append(mean)
            bin_std.append(std)
        bin_x = np.array(bin_x)
        bin_mean = np.array(bin_mean)
        bin_std = np.array(bin_std)

        ax.plot(bin_x, bin_mean, marker='o', color=colors[idx], linewidth=2)
        ax.fill_between(bin_x, bin_mean - bin_std, bin_mean + bin_std, color=colors[idx], alpha=0.3)
        ax.set_title(f"{question}")
        ax.set_xlim(x_min, x_max)
        ax.set_xticks(x_ticks)
        ax.set_ylim(0, 1.5)
        ax.grid(True)
        ax.set_yticks(y_ticks)
        # Invisible artists whose only job is to carry the metric text into
        # the legend.
        for label in legend_labels:
            ax.plot([], [], ' ', label=label)
        ax.legend(loc="upper left", fontsize=9, frameon=False)

    fig.suptitle(f"{model_map[model]}", fontsize=16, y=0.94)
    return fig


def e_response(func_name, plot_type):
    """Render one figure per model into a multi-page PDF report.

    The plotting function is the module-level callable named
    ``plot_<func_name>_<plot_type>``; it is called as
    ``plot_func(func_name, model, colors)`` for every model in
    ``full_model_list`` and each returned figure becomes one page of
    ``evaluation/overfit/e_response_<func_name>_sep_(<plot_type>).pdf``.

    Args:
        func_name: Metric name (e.g. ``"length"``).
        plot_type: Plot flavour (e.g. ``"scatter"`` or ``"confidence"``).

    Raises:
        NameError: If no ``plot_<func_name>_<plot_type>`` callable exists in
            this module.
    """
    # Resolve the plotting function once, before the PDF is opened, so an
    # unsupported metric/plot combination fails up front instead of being
    # eval()-ed again for every model.
    plot_func_name = f"plot_{func_name}_{plot_type}"
    plot_func = globals().get(plot_func_name)
    if not callable(plot_func):
        raise NameError(f"name '{plot_func_name}' is not defined")

    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
    with PdfPages(f"evaluation/overfit/e_response_{func_name}_sep_({plot_type}).pdf") as pdf:
        for model in full_model_list:
            fig = plot_func(func_name, model, colors)
            pdf.savefig(fig)
            # Close each figure once saved so figures do not accumulate.
            plt.close(fig)


if __name__ == '__main__':
    # Metrics to report on. The commented-out names are the other metric
    # functions defined in this file; e_response looks up a
    # plot_<metric>_<plot_type> function, and only plot_length_* are defined
    # here.
    metrics = (
        "length",
        # "lexical_diversity",
        # "sentence_complexity",
        # "ambiguity",
        # "redundancy",
        # "repetitions",
        # "smaller_repetitions",
    )
    for metric in metrics:
        # Switch to the line below for the per-value scatter report instead.
        # e_response(metric, "scatter")
        e_response(metric, "confidence")

