import json
import os
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import linregress

from inference_rlhf.code.plot_teaser import TASK_TO_SHAPE, MODEL_TO_COLOR

# Root directory (relative to the current working directory) scanned by main():
# it walks FIGURES_DIR/<task>/<policy>/ for the input JSON files, and each plot
# is saved next to its elliptical JSON input.
FIGURES_DIR = "./figures"
# Fraction of prompts shown by plot_scatter_sorted: the x-axis is zoomed onto the
# last QUANTILE share of prompts (those with the largest reference
# samples-to-correct), and the value is embedded in the output filename.
QUANTILE = 0.1

# Global plot styling, applied once at import time to every figure this module draws.
sns.set_theme(style="whitegrid")

def plot_scatter(elliptical_path, vanilla_path, task, policy, xlim=5000):
    """Plot per-prompt elliptical vs. vanilla samples-to-correct on log-log axes.

    Each JSON file is loaded as a mapping from prompt key to samples-to-correct.
    For every key in the elliptical file, the elliptical entry is reduced with
    ``np.mean`` (so it may be a scalar or a list of repeats) and paired with the
    vanilla entry for the same key, which is used as-is. A least-squares line is
    fitted in log-log space (``log y = slope * log x + intercept``) and drawn
    together with the ``y = x`` reference line. The figure is saved as
    ``scatter_plot.pdf`` in the directory containing ``elliptical_path``.

    Args:
        elliptical_path: Path to the elliptical-coreset samples-to-correct JSON.
        vanilla_path: Path to the vanilla (random) samples-to-correct JSON.
        task: Task name; selects the scatter marker from ``TASK_TO_SHAPE``.
        policy: Policy name; selects the color from ``MODEL_TO_COLOR``.
        xlim: Integer upper end of the x-range (starting at 1) over which the
            fitted line and the ``y = x`` line are drawn. Defaults to 5000, the
            value that was previously hard-coded.

    Raises:
        KeyError: If a key of the elliptical file is missing from the vanilla
            file, or ``task``/``policy`` has no entry in the style tables.
        ValueError: If fewer than two prompts have strictly positive, finite
            values on both axes, so no log-log fit is possible.
    """
    with open(elliptical_path, "r", encoding="utf-8") as f:
        elliptical_samples_to_get_correct = json.load(f)
    with open(vanilla_path, "r", encoding="utf-8") as f:
        vanilla_samples_to_get_correct = json.load(f)

    # Align both datasets on the elliptical file's prompt keys so that x[i] and
    # y[i] always describe the same prompt.
    prompt_keys = list(elliptical_samples_to_get_correct)
    elliptical_samples_to_correct = np.array(
        [np.mean(elliptical_samples_to_get_correct[k]) for k in prompt_keys], dtype=float
    )
    vanilla_samples_to_correct = np.array(
        [vanilla_samples_to_get_correct[k] for k in prompt_keys], dtype=float
    )

    # The log is undefined for non-positive values (and such points cannot be
    # drawn on log axes anyway), so only strictly positive, finite pairs enter
    # the fit; otherwise a single bad point would turn the whole fit into NaN.
    usable = (
        np.isfinite(vanilla_samples_to_correct)
        & np.isfinite(elliptical_samples_to_correct)
        & (vanilla_samples_to_correct > 0)
        & (elliptical_samples_to_correct > 0)
    )
    if np.count_nonzero(usable) < 2:
        raise ValueError(
            "need at least two prompts with positive samples-to-correct "
            "on both axes to fit a log-log line"
        )
    slope, intercept, *_ = linregress(
        np.log(vanilla_samples_to_correct[usable]),
        np.log(elliptical_samples_to_correct[usable]),
    )

    line_xs = np.linspace(1, xlim, xlim)
    line_of_best_fit = np.exp(slope * np.log(line_xs) + intercept)

    fig = plt.figure()
    try:
        plt.plot(
            line_xs,
            line_of_best_fit,
            color=MODEL_TO_COLOR[policy],
            # Raw f-string: "\l" is an invalid escape in a normal string literal.
            # ":+.2f" prints the intercept's own sign (avoids "+ -0.12").
            label=rf"$\log y = {slope:.2f} \log x {intercept:+.2f}$",
            linewidth=2.5,
        )
        plt.plot(line_xs, line_xs, color="black", label="y = x", linestyle="--")
        plt.scatter(
            vanilla_samples_to_correct,
            elliptical_samples_to_correct,
            s=40,
            marker=TASK_TO_SHAPE[task],
            facecolor=MODEL_TO_COLOR[policy],
            edgecolor="black",
            linewidths=0.5,  # thin marker outline
            alpha=0.7,  # overlapping points stay visible
        )
        plt.xscale("log")
        plt.yscale("log")
        plt.xlabel("Random samples-to-correct", fontsize=16)
        plt.ylabel("Elliptical samples-to-correct", fontsize=16)
        plt.legend(fontsize=14)
        plt.savefig(os.path.join(os.path.dirname(elliptical_path), "scatter_plot.pdf"))
    finally:
        # Always release the figure, even if plotting fails, so repeated calls
        # (e.g. one per policy from main) do not accumulate open figures.
        plt.close(fig)

def plot_scatter_sorted(elliptical_path, vanilla_path, vanilla_ref_path, task, policy,
                        quantile=None, tick_step=5):
    """Scatter samples-to-correct per prompt, ordering prompts by a vanilla reference run.

    Prompts (the keys of the reference JSON) are sorted in ascending order of
    their reference samples-to-correct value and numbered 1..N. The vanilla
    value and the mean elliptical value of each prompt are scattered against
    that rank. The x-axis is restricted to the last ``quantile`` fraction of
    ranks, i.e. the prompts with the largest reference values. The figure is
    saved as ``scatter_sorted_{quantile}.pdf`` next to ``elliptical_path``.

    Args:
        elliptical_path: Path to the elliptical samples-to-correct JSON; its
            values are averaged with ``np.mean``.
        vanilla_path: Path to the vanilla samples-to-correct JSON; values are
            plotted as-is.
        vanilla_ref_path: Path to the reference vanilla JSON that defines the
            prompt ordering.
        task: Task name (not used in the plot itself; kept for parity with
            ``plot_scatter``).
        policy: Policy name (not used in the plot itself; kept for parity with
            ``plot_scatter``).
        quantile: Fraction in (0, 1] of highest-ranked prompts to show.
            ``None`` (default) uses the module-level ``QUANTILE``, read at call
            time, which also keeps the previous output filename.
        tick_step: Spacing between x-axis ticks; 5 is the previous hard-coded value.

    Raises:
        ValueError: If ``quantile`` is outside (0, 1].
        KeyError: If a reference key is missing from the elliptical or vanilla file.
    """
    if quantile is None:
        quantile = QUANTILE
    if not 0 < quantile <= 1:
        raise ValueError(f"quantile must be in (0, 1], got {quantile!r}")

    with open(elliptical_path, "r", encoding="utf-8") as f:
        elliptical_samples_to_get_correct = json.load(f)
    with open(vanilla_path, "r", encoding="utf-8") as f:
        vanilla_samples_to_get_correct = json.load(f)
    with open(vanilla_ref_path, "r", encoding="utf-8") as f:
        vanilla_samples_to_get_correct_ref = json.load(f)

    sorted_prompt_idxs = sorted(
        vanilla_samples_to_get_correct_ref, key=vanilla_samples_to_get_correct_ref.get
    )
    n_prompts = len(sorted_prompt_idxs)
    xs = np.arange(1, n_prompts + 1)
    elliptical_ys = np.array(
        [np.mean(elliptical_samples_to_get_correct[p]) for p in sorted_prompt_idxs]
    )
    vanilla_ys = np.array([vanilla_samples_to_get_correct[p] for p in sorted_prompt_idxs])

    # First rank in the zoomed window; ranks below it form the (1 - quantile)
    # share of prompts with the smallest reference values.
    window_start = int((1 - quantile) * n_prompts)

    fig = plt.figure()
    try:
        plt.scatter(xs, vanilla_ys, s=10, label="Vanilla")
        plt.scatter(xs, elliptical_ys, s=10, label="Elliptical")
        plt.legend()
        plt.xlim(window_start, n_prompts)
        plt.xticks(np.arange(window_start, n_prompts + 1, tick_step))
        plt.xlabel("Sorted prompt index")
        plt.ylabel("Samples-to-correct")
        plt.savefig(
            os.path.join(os.path.dirname(elliptical_path), f"scatter_sorted_{quantile}.pdf")
        )
    finally:
        # Release the figure even on failure so callers looping over policies
        # do not leak open figures.
        plt.close(fig)

def _find_json(directory, include, exclude=()):
    """Return the path of a JSON file in ``directory`` matching the name filters.

    A file matches when its name ends with ``.json``, contains every substring
    in ``include`` and none of the substrings in ``exclude``. Matches are sorted
    so the choice is deterministic when several files qualify (``os.listdir``
    order is arbitrary); the first one is returned.

    Raises:
        FileNotFoundError: If no file in ``directory`` matches.
    """
    matches = sorted(
        name
        for name in os.listdir(directory)
        if name.endswith(".json")
        and all(word in name for word in include)
        and not any(word in name for word in exclude)
    )
    if not matches:
        raise FileNotFoundError(
            f"no JSON file containing {list(include)!r} and none of {list(exclude)!r} "
            f"in {directory}"
        )
    return os.path.join(directory, matches[0])


def main(plot_sorted=False):
    """Generate the comparison plots for every policy of the ``mbpp`` task.

    Walks ``FIGURES_DIR/<task>/<policy>/`` (only the ``mbpp`` task directory is
    processed), locates the elliptical JSON and the vanilla (non-"ref") JSON in
    each policy directory, and calls ``plot_scatter`` on them.

    Args:
        plot_sorted: Also call ``plot_scatter_sorted``, which additionally needs
            a vanilla "ref" JSON in each policy directory. Off by default; the
            reference file is only looked up when this is enabled, so a missing
            reference file no longer aborts the default run.

    Raises:
        FileNotFoundError: If a required JSON file is missing from a policy directory.
    """
    tasks = sorted(
        name for name in os.listdir(FIGURES_DIR)
        if os.path.isdir(os.path.join(FIGURES_DIR, name))
    )
    for task in tasks:
        if task != "mbpp":  # only the mbpp task is plotted
            continue
        task_dir = os.path.join(FIGURES_DIR, task)
        policies = sorted(
            name for name in os.listdir(task_dir)
            if os.path.isdir(os.path.join(task_dir, name))
        )
        for policy in policies:
            policy_dir = os.path.join(task_dir, policy)
            elliptical_path = _find_json(policy_dir, ("elliptical",))
            vanilla_path = _find_json(policy_dir, ("vanilla",), exclude=("ref",))
            plot_scatter(elliptical_path, vanilla_path, task, policy)
            if plot_sorted:
                vanilla_ref_path = _find_json(policy_dir, ("vanilla", "ref"))
                plot_scatter_sorted(elliptical_path, vanilla_path, vanilla_ref_path, task, policy)

# Running this file as a script regenerates the plots under FIGURES_DIR via main().
if __name__ == "__main__":
    main()