import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import numpy as np
import json
import matplotlib.pyplot as plt
from utils import *
from find_threshold import *


# 1. count words
def count_words(data):
    """Count whitespace-separated words in strong-LLM prompts and responses.

    Args:
        data: Iterable of entries supporting ``in`` and item access (the
            script passes the list decoded from a pipeline JSON file).

    Returns:
        Tuple ``(prompt_words, response_words, total_words)`` where
        ``total_words`` is the sum of the first two.
    """
    fields = ("strong_llm_prompt", "strong_llm_response")
    tally = dict.fromkeys(fields, 0)

    for entry in data:
        for field in fields:
            # A missing field or a non-string value contributes nothing.
            text = entry[field] if field in entry else None
            if isinstance(text, str):
                tally[field] += len(text.split())

    prompt_words, response_words = (tally[field] for field in fields)
    return prompt_words, response_words, prompt_words + response_words

data_our_pipeline_dir = "data/gsm_plus/test_gsm_plus_our_pipeline_strong_no_strategy.json"
data_jung_pipeline_dir = "data/gsm_plus/test_gsm_plus_jung_pipeline.json"

# Words are divided by this factor to obtain the value printed as "Total tokens".
_WORDS_PER_TOKEN = 0.75
# NOTE(review): presumably prices per one million tokens (prompt vs. response)
# for the strong LLM -- confirm against the provider's price list.
_TOKENS_PER_PRICE_UNIT = 1000000
_PROMPT_PRICE = 0.1
_RESPONSE_PRICE = 0.4


def _load_and_count(json_path):
    """Load a pipeline JSON file; return ``(data, prompt, response, total)`` word counts."""
    with open(json_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return (records, *count_words(records))


def _report_pipeline(header, records, prompt_words, response_words, total_words):
    """Print the word/token/price summary for one pipeline and return its price."""
    print(header)
    print("Number of data points:", len(records))
    print("Prompt words:", prompt_words)
    print("Response words:", response_words)
    print("Total words:", total_words)
    print("Total tokens:", total_words / _WORDS_PER_TOKEN)
    # Same operation order as the per-pipeline formula this replaces, so the
    # floating-point result is unchanged.
    price = (
        prompt_words / _WORDS_PER_TOKEN / _TOKENS_PER_PRICE_UNIT * _PROMPT_PRICE
        + response_words / _WORDS_PER_TOKEN / _TOKENS_PER_PRICE_UNIT * _RESPONSE_PRICE
    )
    print("Total price:", price)
    return price


# Both files are loaded before anything is printed.
data_our, prompt_words_our, response_words_our, total_words_our = _load_and_count(
    data_our_pipeline_dir
)
data_jung, prompt_words_jung, response_words_jung, total_words_jung = _load_and_count(
    data_jung_pipeline_dir
)

total_price_our = _report_pipeline(
    "=== Our Pipeline ===",
    data_our, prompt_words_our, response_words_our, total_words_our,
)
total_price_jung = _report_pipeline(
    "\n=== Jung Pipeline ===",
    data_jung, prompt_words_jung, response_words_jung, total_words_jung,
)

print("\n=== Overall ===")
# Both reductions are printed as fractions of the Jung pipeline's value
# (not multiplied by 100), despite the "percentage" wording in the labels.
print(
    "Token reduction percentage:",
    (total_words_jung - total_words_our) / total_words_jung,
)  # 29.98%
print(
    "Price reduction percentage:",
    (total_price_jung - total_price_our) / total_price_jung,
)  # 30.41%
print("")


# 2. plot figures
data_our_pipeline_dir = "data/gsm_plus/test_gsm_plus_our_pipeline_strong_no_strategy.json"
data_our_pipeline_random_dir = "data/gsm_plus/test_gsm_plus_our_pipeline_strong_no_strategy_random_strategy.json"
data_jung_pipeline_dir = "data/gsm_plus/test_gsm_plus_jung_pipeline.json"


def _load_category_stats(json_path):
    """Return per-"perturbation_type" stats for a result file, plus an "all" entry.

    NOTE(review): ``stats_pipeline_category`` / ``stats_pipeline`` come from the
    star imports at the top of the file; each entry is assumed to be a dict with
    the count keys read by the metric helpers below -- confirm against utils.
    """
    stats = stats_pipeline_category(json_path, category_name="perturbation_type")
    stats["all"] = stats_pipeline(json_path)
    return stats


stats_our = _load_category_stats(data_our_pipeline_dir)
stats_our_random = _load_category_stats(data_our_pipeline_random_dir)
stats_jung = _load_category_stats(data_jung_pipeline_dir)

# The Jung pipeline's categories define the x-axis; the other two stats dicts
# are indexed with the same keys.
categories = list(stats_jung.keys())


def _accepted_count(stats):
    """Return total_problems minus unaccept_problems for one stats entry."""
    return stats["total_problems"] - stats["unaccept_problems"]


def _pipeline_accuracy(stats):
    """Correct answers (weak-used + strong-used) over accepted problems."""
    correct = stats["weak_llm_used_correct"] + stats["strong_llm_used_correct"]
    return correct / _accepted_count(stats)


def _strong_call_rate(stats):
    """strong_llm_used as a fraction of total_problems."""
    return stats["strong_llm_used"] / stats["total_problems"]


def _weak_accuracy(stats):
    """weak_total_correct over accepted problems."""
    return stats["weak_total_correct"] / _accepted_count(stats)


def _weak_accepted_correct_rate(stats):
    """weak_llm_used_correct over accepted problems."""
    return stats["weak_llm_used_correct"] / _accepted_count(stats)


def _coverage_rate(stats):
    """Accepted problems as a fraction of total_problems."""
    return _accepted_count(stats) / stats["total_problems"]


def _plot_metric_by_category(metric, ylabel):
    """Show a grouped bar chart of ``metric`` per category for the three pipelines.

    Args:
        metric: Callable mapping one stats entry to a float ratio.
        ylabel: Text for the y-axis label.
    """
    jung_values = [metric(stats_jung[cat]) for cat in categories]
    our_values = [metric(stats_our[cat]) for cat in categories]
    our_random_values = [metric(stats_our_random[cat]) for cat in categories]

    x = np.arange(len(categories))
    width = 0.25
    # Bars are drawn left to right as: jung, our random, our.
    series = (
        (x - width, jung_values, "jung"),
        (x, our_random_values, "our random"),
        (x + width, our_values, "our"),
    )

    fig, ax = plt.subplots(figsize=(12, 6))
    bar_groups = [
        ax.bar(positions, values, width, label=label)
        for positions, values, label in series
    ]
    ax.set_ylabel(ylabel)
    ax.set_xticks(x)
    ax.set_xticklabels(categories, rotation=45, ha="right")
    ax.legend()
    # Annotate every bar with its value formatted as a percentage.
    for bars, (_, values, _) in zip(bar_groups, series):
        ax.bar_label(bars, labels=[f"{v:.2%}" for v in values], padding=3)
    plt.tight_layout()
    plt.show()


# 2.1. overall accuracy
_plot_metric_by_category(_pipeline_accuracy, "Pipeline Accuracy")

# 2.2. fraction of problems with a strong llm call
_plot_metric_by_category(_strong_call_rate, "Strong LLM Calls")

# 2.3. weak accuracy
_plot_metric_by_category(_weak_accuracy, "Weak LLM Accuracy")

# 2.4. fraction of accepted problems answered correctly by the weak llm
_plot_metric_by_category(_weak_accepted_correct_rate, "Weak LLM Accepted Correct Problems")

# 2.5. coverage rate
_plot_metric_by_category(_coverage_rate, "Coverage Rate")


# 3. confidence vs. accuracy
def accuracy_vs_condifence(confidences, labels, num_thresholds=101):
    """Compute accuracy among predictions at or above each confidence threshold.

    Nothing is plotted here; the caller receives the curve data.

    Args:
        confidences: Array-like of confidence scores, compared against
            thresholds spanning [0, 1].
        labels: Array-like aligned with ``confidences``; the mean of the
            selected entries is reported as the accuracy.
        num_thresholds: Number of evenly spaced thresholds from 0 to 1
            (both ends included).

    Returns:
        Tuple ``(thresholds, accuracies)``: ``thresholds`` is an ``np.ndarray``
        and ``accuracies`` a list of the same length, holding ``np.nan``
        wherever no confidence reaches the threshold.
    """
    conf_arr = np.asarray(confidences)
    label_arr = np.asarray(labels)
    thresholds = np.linspace(0, 1, num_thresholds)

    def accuracy_at(cutoff):
        keep = conf_arr >= cutoff
        # Labels are only indexed when at least one prediction is selected.
        return label_arr[keep].mean() if keep.any() else np.nan

    return thresholds, [accuracy_at(cutoff) for cutoff in thresholds]

def load_accuracy_vs_confidence(json_path, num_thresholds=101):
    """
    Load prediction results from a JSON file and compute the accuracy vs.
    confidence curve of the weak LLM.

    Items whose "weak_llm_strategy" equals "No strategy" are excluded.

    Args:
        json_path (str): Path to the JSON file.
        num_thresholds (int): Number of thresholds for evaluation.

    Returns:
        weak_confidence: First element returned by
            ``minmax01(exp_from_logprobs(...))`` for the kept items
            (NOTE(review): presumably confidences rescaled to [0, 1] --
            confirm against the helper's definition).
        thresholds (np.ndarray): Thresholds from 0 to 1.
        accuracies (list): Accuracy values corresponding to thresholds.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Filter once so confidences and correctness flags stay index-aligned.
    strategy_items = [
        item for item in data if item["weak_llm_strategy"] != "No strategy"
    ]

    weak_log_confidence = [
        item["weak_llm_answer_logprob_sum"] for item in strategy_items
    ]
    # minmax01 returns three values; only the first is needed here.
    weak_confidence, _x_min, _rng = minmax01(exp_from_logprobs(weak_log_confidence))

    weak_prediction = [int(item["weak_llm_accuracy"]) for item in strategy_items]

    thresholds, accuracies = accuracy_vs_condifence(
        weak_confidence, weak_prediction, num_thresholds=num_thresholds
    )
    return weak_confidence, thresholds, accuracies

def plot_confidence_hist_three(datasets, colors, labels, save_path):
    """
    Plot three proportion histograms side by side using subplots,
    each with 10 equal-width bins over [0, 1] and a shared y-axis,
    then save the figure to ``save_path``.

    Values outside [0, 1] fall outside the bin edges and are not counted.
    A dataset with no counted values is drawn with all-zero bars.

    Note: this updates the global matplotlib font size to 15.

    Args:
        datasets (list of np.ndarray): List of three arrays containing confidence values.
        colors (list of str): Colors for each histogram.
        labels (list of str): Labels for each histogram (methods); also used as subplot titles.
        save_path (str): File path to save the figure.
    """
    plt.rcParams.update({"font.size": 15})
    _fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True)

    num_bins = 10
    bin_edges = np.linspace(0, 1, num_bins + 1)
    # Derived from num_bins so the bar width cannot drift from the bin edges.
    bin_width = 1.0 / num_bins

    for i, (data, color, label) in enumerate(zip(datasets, colors, labels)):
        counts, _ = np.histogram(data, bins=bin_edges)
        total = counts.sum()
        # Guard against 0/0 when nothing falls inside [0, 1].
        if total > 0:
            proportions = counts / total
        else:
            proportions = np.zeros(num_bins)

        axs[i].bar(bin_edges[:-1], proportions, width=bin_width, align='edge',
                   color=color, edgecolor="black", alpha=0.8, label=label)

        axs[i].set_title(label, fontsize=15)
        axs[i].set_xlabel("Confidence")
        axs[i].set_xticks(np.arange(0, 1.1, 0.1))
        axs[i].tick_params(axis='x', rotation=45)

    # The y-axis is shared, so the label and limits set here apply to all panels.
    axs[0].set_ylabel("Proportion")
    axs[0].set_ylim(0, 0.55)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, transparent=True, bbox_inches="tight")
    plt.close()

def plot_confidence_hist(data, color, name):
    """Save a 10-bin proportion histogram of ``data`` over [0, 1].

    The figure is written to
    ``data/gsm_plus/gsm_plus_confidence_<name>.png`` and then closed.
    This also updates the global matplotlib font size to 15.

    Args:
        data: Confidence values to histogram.
        color: Fill colour for the bars.
        name: Suffix used in the output file name.
    """
    plt.rcParams.update({"font.size": 15})
    plt.figure(figsize=(6,4))

    hist_kwargs = dict(
        bins=10, range=(0,1),
        edgecolor="black", color=color, alpha=0.7,
    )
    bin_counts, _edges, bars = plt.hist(data, **hist_kwargs)

    # plt.hist draws raw counts; rescale each bar so heights sum to one.
    proportions = bin_counts / bin_counts.sum()
    for bar, share in zip(bars, proportions):
        bar.set_height(share)

    plt.ylim(0, 1)
    plt.xlabel("Confidence")
    plt.ylabel("Proportion")
    plt.tight_layout()
    plt.savefig(f"data/gsm_plus/gsm_plus_confidence_{name}.png", dpi=300, transparent=True, bbox_inches="tight")
    plt.close()

# One row per curve: (legend label, line/bar colour, histogram file suffix, result file).
curve_specs = [
    (
        "Weak LLM", "#073E7F", "weak_llm",
        "data/gsm_plus/calibration_gsm_plus_weak_gpt3.5turbo_llm.json",
    ),
    (
        "Our Pipeline (Random)", "#BE0E23", "our_pipeline_random",
        "data/gsm_plus/test_gsm_plus_our_pipeline_strong_no_strategy_random_strategy.json",
    ),
    (
        "Our Pipeline (Retrieval)", "#2E5C41", "our_pipeline",
        "data/gsm_plus/test_gsm_plus_our_pipeline_strong_no_strategy.json",
    ),
]

# Each result is (weak_confidence, thresholds, accuracies); all files are
# loaded before any figure is created.
curve_results = [
    load_accuracy_vs_confidence(result_path)
    for _, _, _, result_path in curve_specs
]

# NOTE(review): not referenced by the visible code below; kept as-is.
colors = plt.get_cmap("Set2").colors
plt.rcParams.update({"font.size": 15})
plt.figure(figsize=(6, 4))
for (curve_label, curve_color, _, _), (_, thresholds, accuracies) in zip(curve_specs, curve_results):
    plt.plot(thresholds, accuracies, label=curve_label, linestyle="-", color=curve_color, linewidth=2)
plt.xlabel("Confidence Threshold")
plt.ylabel("Accuracy")
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.grid(True)
plt.legend(fontsize=11)
plt.tight_layout()
plt.savefig("data/gsm_plus/gsm_plus_accuracy_vs_confidence.png", dpi=300, transparent=True, bbox_inches="tight")
plt.close()

# Side-by-side comparison of the three confidence distributions.
plot_confidence_hist_three(
    [confidence for confidence, _, _ in curve_results],
    [curve_color for _, curve_color, _, _ in curve_specs],
    [curve_label for curve_label, _, _, _ in curve_specs],
    "data/gsm_plus/gsm_plus_confidence_comparison.png"
)

# One standalone histogram per curve.
for (_, curve_color, file_suffix, _), (confidence, _, _) in zip(curve_specs, curve_results):
    plot_confidence_hist(confidence, curve_color, file_suffix)

