import json
from pathlib import Path
from math import isnan
from cycler import cycler
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# The "figures" directory sits next to this script's parent directory; the
# aggregated experiment results are read from a JSON file stored there.
figure_path = Path(__file__).resolve().parents[1] / "figures"
result_file = figure_path.joinpath("large_sample_results.json")
with result_file.open("r") as fp:
    result_obj = json.load(fp)
# Split the payload into the per-metric x coordinates and the measurements.
x_values_dict, results = result_obj["x_values_dict"], result_obj["results"]
def _impute_nan(values):
    """Fill NaN entries of *values* in place from the nearest valid neighbour.

    A NaN takes the next valid value to its right (backward fill). Trailing
    NaNs have nothing to their right, so they take the last valid value to
    their left instead. A list without any valid value is left untouched.
    """
    # Right-to-left pass: carry the most recent valid value backwards.
    fill = None
    for idx in range(len(values) - 1, -1, -1):
        if not isnan(values[idx]):
            fill = values[idx]
        elif fill is not None:
            values[idx] = fill

    # Left-to-right pass: only a trailing run of NaNs can still be present;
    # previously this case indexed past the end of the list (IndexError).
    fill = None
    for idx, value in enumerate(values):
        if not isnan(value):
            fill = value
        elif fill is not None:
            values[idx] = fill


# Impute every series of every metric of every model, in place.
for llm_res in results.values():
    for metric_res in llm_res.values():
        for pos_lst in metric_res.values():
            _impute_nan(pos_lst)

def plot_confidence_lines(data_dict, x_values_dict, output_dir="plots",
                         xlabel="X Axis", figsize=(8, 5), dpi=300):
    """Save one mean-with-confidence-band line plot per metric as a PNG.

    For every metric, each entry of ``data_dict`` contributes one line (its
    ``'mean'`` series) and a translucent band between its ``'lower'`` and
    ``'upper'`` series. Each figure is written to ``output_dir`` under the
    metric name with spaces replaced by underscores, lower-cased, and
    suffixed with ``_ci.png``.

    Args:
        data_dict: Mapping ``line_name -> {metric -> {'mean', 'lower',
            'upper'}}``. Metric names are taken from the first entry, and
            every entry is indexed with each of them. The legend label is
            the part of ``line_name`` before the first underscore. When
            empty, no figure is produced.
        x_values_dict: Mapping ``metric -> x coordinates`` shared by all
            lines drawn for that metric.
        output_dir: Directory receiving the PNG files; created together
            with any missing parent directories.
        xlabel: Label of the x axis of every figure.
        figsize: Figure size forwarded to matplotlib.
        dpi: Value stored in the ``figure.dpi`` and ``savefig.dpi`` rcParams.

    Note:
        The seaborn style and ``plt.rcParams`` are modified globally and are
        not restored when the function returns.
    """
    out_dir = Path(output_dir)
    # parents=True: a nested output directory must not fail on a missing parent.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Nothing to draw; also avoids StopIteration from next() further down.
    if not data_dict:
        return

    sns.set_style("whitegrid")
    plt.rcParams.update({
        "font.family": "serif",
        "font.serif": ["Times New Roman"],
        "font.size": 18,
        "axes.labelsize": 12,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
        "legend.fontsize": 18,
        "figure.dpi": dpi,
        "savefig.dpi": dpi,
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.4,
        "axes.prop_cycle": cycler(color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b'])
    })

    # The first entry defines which metrics get a figure.
    metrics = next(iter(data_dict.values())).keys()

    for metric in metrics:
        fig, ax = plt.subplots(figsize=figsize)
        try:
            x_values = x_values_dict[metric]
            for line_name, inner_dict in data_dict.items():
                series = inner_dict[metric]
                (line,) = ax.plot(x_values, series['mean'],
                                  label=line_name.split("_")[0],
                                  marker='o',
                                  markersize=5,
                                  linewidth=2,
                                  markeredgecolor='w',
                                  markeredgewidth=0.5)
                # Shade the band in the same colour as its mean line.
                ax.fill_between(x_values, series['lower'], series['upper'],
                                color=line.get_color(), alpha=0.2)

            # Open frame: keep only the left and bottom spines.
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.spines['left'].set_linewidth(0.8)
            ax.spines['bottom'].set_linewidth(0.8)
            ax.grid(axis='y', linewidth=0.5, alpha=0.6, linestyle='--')

            # Fixed ticks and y-range for metrics whose name starts with "metric_1".
            if metric.startswith("metric_1"):
                ax.set_xticks(np.arange(0, 101, 10))
                ax.set_ylim(0.495, 0.69)
            ax.set_xlabel(xlabel)
            ax.set_ylabel(metric)
            ax.legend(frameon=True, framealpha=0.9, loc='best')
            fig.tight_layout()

            filename = f"{metric.replace(' ', '_').lower()}_ci.png"
            fig.savefig(out_dir / filename, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)

# Draw one confidence-band figure per metric from the loaded results.
# NOTE: "figures" is resolved against the current working directory, not
# against figure_path.
plot_confidence_lines(
    results,
    x_values_dict,
    output_dir="figures",
    xlabel="Number of samples",
)