import os
import pandas as pd
import seaborn as sns
from matplotlib.colors import LogNorm
from matplotlib.ticker import MaxNLocator
import math
import matplotlib.pyplot as plt
import numpy as np


def _count_score_pairs(df, n_grids):
    """Bin (chosen, rejected) safety-score pairs into an ``n_grids`` x ``n_grids`` count grid."""
    chosen = df["chosen_safety_score"].to_numpy(dtype=float)
    rejected = df["rejected_safety_score"].to_numpy(dtype=float)
    # A pair can only be placed on the grid when both of its scores are real numbers.
    scored = np.isfinite(chosen) & np.isfinite(rejected)

    def to_bin(scores):
        # Clamp before truncating so that a score of exactly 1.0 lands in the
        # top bin (instead of indexing one past the end) and a negative score
        # cannot wrap around to the opposite edge through negative indexing.
        return np.clip(scores * n_grids, 0, n_grids - 1).astype(int)

    counts = np.zeros((n_grids, n_grids))
    # np.add.at accumulates repeated (row, col) pairs; plain fancy-index `+= 1`
    # would count each distinct cell only once.
    np.add.at(counts, (to_bin(chosen[scored]), to_bin(rejected[scored])), 1)
    return counts


def plot_regions(df, save_name='tmp', title=None):
    """Plot a log-scaled heatmap of chosen vs. rejected safety probabilities.

    The scores are binned on a 10 x 10 grid: the x axis is the
    ``rejected_safety_score`` column, the y axis is the ``chosen_safety_score``
    column, and the colour of a cell encodes how many rows fall into it.
    Scores are expected to lie in [0, 1]; values outside that range are
    clamped into the nearest edge bin, and rows where either score is missing
    or non-finite are left out. Cells without any rows are drawn blank.

    The figure is written to ``figures/<save_name>.pdf`` and
    ``figures/<save_name>.png`` (the directory is created if needed).

    Args:
        df: DataFrame with ``chosen_safety_score`` and
            ``rejected_safety_score`` columns.
        save_name: Base file name (without extension) of the saved figures.
        title: Optional figure title; omitted when falsy.

    Raises:
        ValueError: If no row has both scores present, so there is nothing
            to plot.
    """
    n_grids = 10

    # Build the grid before opening a figure so that bad input cannot leave a
    # half-initialised figure behind.
    data = _count_score_pairs(df, n_grids)

    populated = data > 0
    if not populated.any():
        raise ValueError('No rows with both safety scores present; nothing to plot.')

    # Colorbar ticks at the powers of ten spanning the populated counts. Empty
    # cells are excluded because log10(0) is undefined.
    low = math.floor(math.log10(data[populated].min()))
    high = math.ceil(math.log10(data[populated].max()))
    cbar_ticks = [math.pow(10, i) for i in range(low, high + 1)]

    os.makedirs('figures', exist_ok=True)

    fig = plt.figure(figsize=(3.5, 3.5))
    try:
        ax = sns.heatmap(
            data=data,
            mask=~populated,  # zero counts have no place on a log colour scale
            cmap='PuBu',
            square=True,
            norm=LogNorm(),
            cbar_kws={'ticks': cbar_ticks, 'format': '%.e'},
        )
        ax.invert_yaxis()

        # Ticks sit at the centre of every second cell and are labelled with
        # that cell's upper score bound (0.2, 0.4, ..., 1.0).
        tick_positions = np.arange(0, n_grids, 2) + 1.5
        tick_labels = np.round(np.linspace(1. / n_grids * 2, 1.0, n_grids // 2), 2)
        ax.tick_params(axis='x', labelsize=10)
        ax.tick_params(axis='y', labelsize=10)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels)

        # Fix the colorbar position and size.
        cbar = ax.collections[0].colorbar
        cbar.ax.set_position([0.75, 0.2, 0.05, 0.6])  # left, bottom, width, height

        ax.set_xlabel('Safety prob. of $y_l$ (by MD-Judge)', fontsize=12)
        ax.set_ylabel('Safety prob. of $y_w$ (by MD-Judge)', fontsize=12)
        if title:
            ax.set_title(title)

        fig.savefig(f'figures/{save_name}.pdf', bbox_inches='tight')
        fig.savefig(f'figures/{save_name}.png', bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)


def main():
    """Render the MD-Judge safety-probability heatmap from the evaluated RLHF data.

    Ensures the ``figures`` output directory exists, then reads
    ``rlhf_with_eval.csv`` (first column as index) and passes it to
    ``plot_regions``. If the CSV is absent, prints a hint instead of plotting.
    """
    output_dir = 'figures'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_csv = 'rlhf_with_eval.csv'
    if not os.path.exists(eval_csv):
        print('Run eval_rlhf.py first to prepare the plot data.')
        return

    rlhf_merge_df = pd.read_csv(eval_csv, index_col=0)
    plot_regions(rlhf_merge_df, save_name='safety_prob_scatter_md_judge', title=None)

    # Disabled: one plot per salad category (kept for reference).
    # categories = sorted(rlhf_merge_df['salad_category_0'].dropna().unique().tolist())
    # for cat in categories:
    #     cat_df = rlhf_merge_df[(rlhf_merge_df['salad_category_0'] == cat) | (rlhf_merge_df['salad_category_1'] == cat)].copy()
    #     cat_number = int(cat.split(':')[0][1:])
    #     cat_name = cat.split(':')[1][1:]
    #     plot_regions(cat_df, save_name=f'scatter_{cat_number}', title=f'[{cat_number}] {cat_name}', )


# Run the plotting pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
