from nesim.utils.json_stuff import load_json_as_dict
import matplotlib.pyplot as plt
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing

def plot_smoothness_vs_perplexity(data, binhuraib_smoothness_values, filename):
    """Plot perplexity against smoothness for each dataset and save the figure.

    Every dataset is drawn as one line with circle markers, and each point is
    labelled with its config name. The reference smoothness values are drawn
    as gray dotted vertical lines, each with a rotated name label.

    Args:
        data: Mapping ``dataset_name -> {config_name -> values}``, where each
            ``values`` mapping provides ``'smoothness'`` and ``'perplexity'``.
        binhuraib_smoothness_values: Mapping ``name -> smoothness`` of
            reference values to mark with vertical lines.
        filename: Destination passed to ``fig.savefig``.

    Raises:
        KeyError: If a config entry lacks ``'smoothness'`` or ``'perplexity'``.
    """
    # Line styles for the known datasets; any other dataset name falls back
    # to a solid line instead of raising KeyError.
    linestyles = {
        'openwebtext': 'dotted',
        'bookcorpus': 'dashed',
    }

    fig, ax = plt.subplots(figsize=(8, 6))

    # One connected line per dataset.
    for dataset_name, dataset_values in data.items():
        config_names = list(dataset_values)
        smoothness_values = [dataset_values[name]['smoothness'] for name in config_names]
        perplexity_values = [dataset_values[name]['perplexity'] for name in config_names]

        ax.plot(
            smoothness_values,
            perplexity_values,
            linestyle=linestyles.get(dataset_name, 'solid'),
            marker='o',
            label=f'{dataset_name}',
        )

        # Config name slightly above each point (offset is in data units).
        for config_name, smoothness, perplexity in zip(
            config_names, smoothness_values, perplexity_values
        ):
            ax.text(smoothness, perplexity + 0.05, config_name, fontsize=9, ha='center')

    # Reference lines. The label uses a blended transform (x in data
    # coordinates, y in axes coordinates) so it always sits at 95% of the axis
    # height; scaling the upper y-limit only works when the y-axis starts
    # near zero and can otherwise put the label outside the axes.
    label_transform = ax.get_xaxis_transform()
    for name, smoothness_value in binhuraib_smoothness_values.items():
        ax.axvline(x=smoothness_value, color='gray', linestyle='dotted')
        ax.text(
            smoothness_value,
            0.95,
            name,
            transform=label_transform,
            color='gray',
            fontsize=10,
            ha='center',
            va='top',
            rotation=90,
        )

    ax.set_xlabel('Smoothness', fontsize=14)
    ax.set_ylabel('Perplexity', fontsize=14)
    ax.legend(fontsize=10)
    ax.grid(True)
    plt.tight_layout()

    # Save before showing: a blocking show() closes the figure and unregisters
    # it from pyplot once the window is dismissed, so writing the file first
    # is the order matplotlib recommends and guarantees the output exists
    # even if the interactive session is interrupted.
    fig.savefig(filename)
    plt.show()

# Smoothness measurements keyed by model name; each entry is read below as a
# mapping whose values are averaged into a single mean smoothness per model.
smoothness_results = load_json_as_dict("smoothness_data.json")

# Y-axis values per dataset and model.
# NOTE(review): these are named losses (and the output file is
# "smoothness_vs_loss.pdf"), yet they are stored under the "perplexity" key
# and plotted on an axis labelled "Perplexity" -- confirm which quantity is
# intended before publishing the figure.
training_losses = {
    "openwebtext": {
        "untrained": 10.8564,
        "baseline": 4.7252,
        "topo_1": 4.9957,
        "topo_5": 4.6094,
        "topo_10": 4.6843,
        "topo_50": 4.7816
    },
    "bookcorpus": {
        "untrained": 10.8499,
        "baseline": 7.5694,
        "topo_1": 7.4134,
        "topo_5": 7.6125,
        "topo_10": 7.5371,
        "topo_50": 7.8460
    }
}

# Reference smoothness values, drawn as labelled vertical lines on the plot.
binhuraib_smoothness_values = {
    "queries": 0.62,
    "keys": 0.60,
    "fc_out": 0.11,
    "values": 0.16
}

# The mean smoothness of a model does not depend on the dataset, so compute it
# once per model instead of once per (dataset, model) pair.
mean_smoothness_by_model = {
    model_name: sum(per_model_values.values()) / len(per_model_values)
    for model_name, per_model_values in smoothness_results.items()
}

# dataset -> model -> {"perplexity", "smoothness"}, the shape expected by
# plot_smoothness_vs_perplexity. Every model in smoothness_results must have
# an entry in each dataset of training_losses (KeyError otherwise).
plot_data = {
    dataset_name: {
        model_name: {
            "perplexity": dataset_losses[model_name],
            "smoothness": mean_smoothness,
        }
        for model_name, mean_smoothness in mean_smoothness_by_model.items()
    }
    for dataset_name, dataset_losses in training_losses.items()
}
print(plot_data)
apply_ratan_matplotlib_thing()
plot_smoothness_vs_perplexity(
    data=plot_data,
    binhuraib_smoothness_values=binhuraib_smoothness_values,
    filename="smoothness_vs_loss.pdf"
)