import os
import re
import json
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import defaultdict
from matplotlib.backends.backend_pdf import PdfPages
from fixed_variables import *
from utils import *


def length(response):
    """Return the number of whitespace-delimited tokens in *response*."""
    tokens = response.split()
    return len(tokens)

def lexical_diversity(response):
    """Return how many distinct lower-cased whitespace tokens *response* contains.

    This is a raw vocabulary size (an int), not a type/token ratio.
    """
    vocabulary = {token for token in response.lower().split()}
    return len(vocabulary)

def sentence_complexity(response, long_length=20):
    """Measure sentence-level complexity of *response*.

    Parses the text with the module-level ``nlp`` object (brought in by the
    star imports; presumably a spaCy pipeline — confirm) and inspects
    ``doc.sents``.

    Args:
        response: text to analyse.
        long_length: a sentence whose ``len()`` exceeds this counts as long.

    Returns:
        ``(longest, num_long)``, where ``longest`` is the token count of the
        longest sentence (0 when there are no sentences) and ``num_long`` is
        the number of sentences longer than ``long_length``.
    """
    parsed = nlp(response)
    # Token count per sentence (the original called these "depths", but no
    # parse-tree depth is computed).
    token_counts = [sum(1 for _ in sentence) for sentence in parsed.sents]
    longest = max(token_counts, default=0)
    num_long = len([sentence for sentence in parsed.sents if len(sentence) > long_length])
    return longest, num_long

def ambiguity(response, vague_phrases=None):
    """Count hedging phrases ("maybe", "probably", ...) in *response*.

    Matching is case-insensitive and restricted to whole words. For example,
    "impossibly" does not count as "possibly", and "could become" does not
    count as "could be". Each phrase is counted independently
    (non-overlapping matches) and the per-phrase counts are summed.

    Args:
        response: text to scan.
        vague_phrases: optional iterable of phrases to look for; defaults to
            maybe / possibly / could be / might be / it seems / probably.

    Returns:
        Total number of phrase occurrences (int).
    """
    if vague_phrases is None:
        vague_phrases = ["maybe", "possibly", "could be", "might be", "it seems", "probably"]
    text = response.lower()
    count = 0
    for phrase in vague_phrases:
        # \b anchors prevent matches inside longer words ("impossibly").
        pattern = r"\b" + re.escape(phrase.lower()) + r"\b"
        count += len(re.findall(pattern, text))
    return count

def redundancy(response):
    """Return the fraction of whitespace tokens that are repeats.

    Computed as ``1 - unique / total`` (case-sensitive). Returns 0 for text
    with no tokens.
    """
    words = response.split()
    if not words:
        return 0
    distinct = set(words)
    return 1 - (len(distinct) / len(words))

def repetitions(response):
    """Count ". "-separated pieces of *response* that occur more than once.

    Every occurrence of a repeated piece is counted, so a piece appearing
    twice contributes 2.
    """
    pieces = response.split(". ")
    return sum(1 for piece in pieces if pieces.count(piece) > 1)

def smaller_repetitions(response):
    """Count repeated clause-like fragments in *response*.

    The text is split wherever ``,``, ``.``, ``!``, ``?`` or a newline is
    followed by whitespace and then an uppercase letter. Fragments are
    stripped and empty ones are dropped. Every occurrence of a fragment that
    appears more than once is counted.
    """
    boundary = re.compile(r"[,.!?\n]\s+(?=[A-Z])")
    fragments = []
    for piece in boundary.split(response):
        cleaned = piece.strip()
        if cleaned:
            fragments.append(cleaned)
    return sum(1 for fragment in fragments if fragments.count(fragment) > 1)


def get_data(base_dir):
    """Load and concatenate every ``chunk_<k>.json`` file under ``base_dir/chunks``.

    Chunks are read in index order 0..K (K = highest index present) and each
    one's JSON list is appended to the result. Files that do not match the
    ``chunk_<k>.json`` pattern (e.g. ``.DS_Store``) are ignored instead of
    being counted as chunks.

    Args:
        base_dir: directory containing a ``chunks`` sub-directory.

    Returns:
        A single list holding the items of all chunks, in chunk order.

    Raises:
        FileNotFoundError: if ``chunks`` is missing or an index between 0 and
            the highest one found has no file (a gap would otherwise silently
            shift positional alignment for callers).
    """
    chunk_dir = os.path.join(base_dir, "chunks")
    chunk_pattern = re.compile(r"chunk_(\d+)\.json")
    indices = []
    for name in os.listdir(chunk_dir):
        match = chunk_pattern.fullmatch(name)
        if match:
            indices.append(int(match.group(1)))
    num_chunks = max(indices, default=-1) + 1

    all_chunk_data = []
    for chunk_idx in range(num_chunks):
        file_name = os.path.join(chunk_dir, f"chunk_{chunk_idx}.json")
        with open(file_name, "r", encoding="utf-8") as f:
            all_chunk_data += json.load(f)
    return all_chunk_data


def get_modality_stat(func_name, modality):
    """Tally correct/incorrect counts per metric value, per question, across all models.

    For every model in ``full_model_list``, loads its responses from
    ``benchmark/<model>/<modality>`` and its correctness table from
    ``comparison/<model>/<modality>.csv``. It then applies the metric named
    ``func_name`` to each response's ``['solution'][question]['analysis']``
    text.

    Args:
        func_name: name of a metric function defined in this module
            (e.g. ``"length"``).
        modality: dataset split, e.g. ``"real"`` or ``"synthetic"``.

    Returns:
        ``{question: {str(metric value): {"1": n_correct, "0": n_incorrect}}}``,
        with inner keys ordered by ascending metric value.

    Raises:
        NameError: if no module-level function named ``func_name`` exists.
    """
    # Resolve the metric once, instead of eval()-ing the name for every response.
    try:
        metric_fn = globals()[func_name]
    except KeyError:
        raise NameError(f"no metric function named {func_name!r}") from None

    question_stat = defaultdict(lambda: defaultdict(lambda: {'1': 0, '0': 0}))
    for model in tqdm(full_model_list):
        model_self = model.split("/")[-1]
        response_data = get_data(os.path.join("benchmark", model_self, modality))
        comparison_df = pd.read_csv(os.path.join("comparison", model_self, f"{modality}.csv"))
        for question in questions:
            # CSV rows are assumed to align positionally with response_data.
            correctness_col = comparison_df[question]
            for idx, item in enumerate(response_data):
                analysis_str = item['solution'][question]['analysis']
                val = metric_fn(analysis_str)
                correctness = int(correctness_col.iloc[idx])
                question_stat[question][val][str(correctness)] += 1
    return {
        question: {
            str(val): stat_dict
            for val, stat_dict in sorted(question_stat[question].items(), key=lambda x: x[0])
        }
        for question in question_stat
    }


def load_stat(dir, func_name, modality):
    """Return per-question statistics for a metric, using a JSON cache.

    The cache file is ``<dir>/<func_name>/all_<modality>.json``. On a hit it
    is loaded and returned. On a miss the statistics are computed with
    ``get_modality_stat`` and written to that file before being returned.

    Args:
        dir: cache root directory (name kept for caller compatibility even
            though it shadows the builtin).
        func_name: metric name; also used as the cache sub-directory name.
        modality: dataset split, e.g. ``"real"`` or ``"synthetic"``.

    Returns:
        The statistics mapping produced by ``get_modality_stat``.
    """
    func_name_dir = os.path.join(dir, func_name)
    stat_file = os.path.join(func_name_dir, f"all_{modality}.json")
    if os.path.exists(stat_file):
        with open(stat_file, "r", encoding="utf-8") as f:
            return json.load(f)

    stat = get_modality_stat(func_name, modality)
    # The file lives in the per-metric sub-directory, so that directory (not
    # just ``dir``) must exist before writing.
    os.makedirs(func_name_dir, exist_ok=True)
    with open(stat_file, "w", encoding="utf-8") as f:
        json.dump(stat, f, indent=4)
    return stat


def plot_length_scatter(func_name, modality, colors):
    """Scatter per-question accuracy against each distinct metric value.

    For every metric value recorded in the cached statistics, the y-value is
    ``n_correct / (n_correct + n_incorrect)``. One subplot of a 2x2 grid is
    drawn per entry in ``questions``.

    Args:
        func_name: metric whose cached statistics are plotted.
        modality: dataset split, e.g. ``"real"`` or ``"synthetic"``.
        colors: one matplotlib colour per question (indexed like ``questions``).

    Returns:
        The created matplotlib ``Figure``.
    """
    stat = load_stat("cache/overfit/response", func_name, modality)

    points_by_question = {}
    for question in questions:
        points = []
        for raw_val, counts in stat[question].items():
            n_correct, n_wrong = counts["1"], counts["0"]
            points.append((int(raw_val), n_correct / (n_correct + n_wrong)))
        points_by_question[question] = points

    fig, grid = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=True)
    flat_axes = grid.flatten()
    for i, question in enumerate(questions):
        ax = flat_axes[i]
        xs = [point[0] for point in points_by_question[question]]
        ys = [point[1] for point in points_by_question[question]]
        ax.scatter(xs, ys, color=colors[i])
        ax.grid(True)
        ax.set_title(question)
    fig.suptitle(f"Response Length - {modality}", fontsize=16, y=0.92)
    plt.tight_layout(rect=[0, 0, 1, 0.9])
    return fig


def plot_length_confidence(func_name, modality, colors):
    """Plot binned accuracy (mean with ±std band) against metric value, per question.

    Loads the cached statistics for ``func_name``/``modality``, groups metric
    values with ``get_bins``, and draws one subplot of a 2x2 grid per entry
    in ``questions``. Each subplot is annotated with the Pearson/Spearman/MI
    values returned by ``compute_metrics``.

    Args:
        func_name: metric whose cached statistics are plotted.
        modality: dataset split, e.g. ``"real"`` or ``"synthetic"``.
        colors: one matplotlib colour per question (indexed like ``questions``).

    Returns:
        The created matplotlib ``Figure``; this function does not close it.
    """
    stat = load_stat("cache/overfit/response", func_name, modality)
    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=True)
    axes = axes.flatten()
    for idx, question in enumerate(questions):
        question_stat = stat[question]
        (pearson, pearson_p), (spearman, spearman_p), mi = compute_metrics(question_stat)
        legend_labels = [
            f"Pearson = {pearson:.4f}, p-value = {pearson_p:.4f}",
            f"Spearman = {spearman:.4f}, p-value = {spearman_p:.4f}",
            f"MI = {mi:.4f}",
        ]
        x_max = max(map(int, question_stat.keys()))
        # x_max // 10 is 0 whenever x_max < 10, and range() rejects a zero step.
        tick_step = max(1, x_max // 10)

        bins, _ = get_bins(question_stat)
        bin_x, bin_mean, bin_std = [], [], []
        for b in bins:
            # Presumably each bin entry is a (metric value, 0/1 label) pair;
            # confirm against get_bins.
            cnt_1 = sum(int(x[1] == 1) for x in b)
            cnt_0 = sum(int(x[1] == 0) for x in b)
            mean, std = mean_std(cnt_1, cnt_0)
            bin_x.append(np.mean([x[0] for x in b]))
            bin_mean.append(mean)
            bin_std.append(std)
        bin_x = np.array(bin_x)
        bin_mean = np.array(bin_mean)
        bin_std = np.array(bin_std)

        ax = axes[idx]
        ax.plot(bin_x, bin_mean, marker='o', color=colors[idx], linewidth=2)
        ax.fill_between(bin_x, bin_mean - bin_std, bin_mean + bin_std, color=colors[idx], alpha=0.3)
        ax.set_title(question)
        ax.set_xlim(0, x_max)
        ax.set_xticks(list(range(0, x_max + 1, tick_step)))
        # y runs to 1.5 (not 1.0) to leave headroom for the legend text.
        ax.set_ylim(0, 1.5)
        ax.grid(True)
        ax.set_yticks([round(0.1 * i, 2) for i in range(16)])
        # Invisible artists carry the metric text into the legend.
        for label in legend_labels:
            ax.plot([], [], ' ', label=label)
        ax.legend(loc="upper left", fontsize=9, frameon=False)
    fig.suptitle(f"Response Length - {modality}", fontsize=16, y=0.92)
    return fig


def e_response(func_name, plot_type):
    """Render the real and synthetic plots for one metric into a single PDF.

    Looks up ``plot_<func_name>_<plot_type>`` in this module and calls it for
    the ``"real"`` and ``"synthetic"`` modalities. Each figure is saved as a
    page of ``evaluation/overfit/e_response_<func_name>_(<plot_type>).pdf``.

    Args:
        func_name: metric name, e.g. ``"length"``.
        plot_type: plot variant, e.g. ``"scatter"`` or ``"confidence"``.

    Raises:
        NameError: if no matching plotting function exists. This is checked
            before the PDF is opened, so no empty file is left behind.
    """
    colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red']
    plot_name = f"plot_{func_name}_{plot_type}"
    try:
        plot_func = globals()[plot_name]
    except KeyError:
        raise NameError(f"no plotting function named {plot_name!r}") from None

    out_dir = os.path.join("evaluation", "overfit")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f"e_response_{func_name}_({plot_type}).pdf")
    with PdfPages(out_path) as pdf:
        for modality in ("real", "synthetic"):
            fig = plot_func(func_name, modality, colors)
            try:
                pdf.savefig(fig)
            finally:
                # Release the figure even if saving fails.
                plt.close(fig)


if __name__ == '__main__':
    # Only "length" has matching plot_<metric>_<type> functions in this file;
    # the other metrics stay disabled until their plotting functions exist.
    metrics = [
        "length",
        # "lexical_diversity",
        # "sentence_complexity",
        # "ambiguity",
        # "redundancy",
        # "repetitions",
        # "smaller_repetitions",
    ]
    plot_types = ("scatter", "confidence")
    for metric in metrics:
        for plot_type in plot_types:
            e_response(metric, plot_type)

