import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import random
import os
import copy
import json

# Shared font sizes used by the plotting functions in this module.
title_fontsize = 14  # passed as ``fontsize=`` to ``ax.set_title``
legend_font = 12  # passed as ``fontsize=`` to ``ax.legend``

def draw_optimal_policy(save_dir):
    """Save a heatmap of the optimal policy as ``optimal_policy.png``.

    The policy is drawn as a 4 x 10 grid (rows labelled "Prompt x", columns
    labelled "Response y") that is zero everywhere except for the four
    diagonal cells ``(i, i)``, which are set to 1.

    Args:
        save_dir: Directory that receives the image. It is created when it
            does not exist yet, matching the other plotting helpers.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    policy = np.zeros([4, 10])
    diagonal = np.arange(4)
    policy[diagonal, diagonal] = 1

    fig, ax = plt.subplots()
    image = ax.matshow(policy, cmap='viridis')

    # Make colorbar same height as the plot
    divider = make_axes_locatable(ax)
    colorbar_ax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(image, cax=colorbar_ax)

    ax.set_xlabel('Response y')
    ax.set_ylabel('Prompt x')
    ax.set_title('Optimal Policy', fontsize=title_fontsize, pad=20)

    # An empty save_dir means "current directory"; os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'optimal_policy.png'))
    plt.close(fig)

def plot_probability(average_chosen_prob, min_chosen_prob, average_reject_prob, max_reject_prob, max_others_prob, average_others_prob, save_dir):
    """Plot per-epoch probability statistics and dump them to JSON.

    Three figures are written to ``save_dir``:

    * ``probability_varying.png`` - the raw series on a linear y-axis,
    * ``log_probability_varying.png`` - the same series on a log y-axis,
    * ``relative_probability_varying.png`` - each series divided by its own
      first entry.

    The raw series are also written to ``statisitic_prob_data.json`` (the
    file name is kept as-is because other code may read it).

    Args:
        average_chosen_prob: Sequence plotted as "Average Chosen Probability".
        min_chosen_prob: Sequence plotted as "Minimum Chosen Probability".
        average_reject_prob: Sequence plotted as "Average Reject Probability".
        max_reject_prob: Sequence plotted as "Maximum Reject Probability".
        max_others_prob: Sequence that is only stored in the JSON file; it is
            not drawn in any figure.
        average_others_prob: Sequence plotted as "Average Unseen Probability".
        save_dir: Output directory; created when missing.
    """
    # An empty save_dir means "current directory"; os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    curves = [
        ('Average Chosen Probability', average_chosen_prob),
        ('Minimum Chosen Probability', min_chosen_prob),
        ('Average Reject Probability', average_reject_prob),
        ('Maximum Reject Probability', max_reject_prob),
        ('Average Unseen Probability', average_others_prob),
    ]

    # NOTE: a secondary axis showing log(average chosen) - log(average reject)
    # was sketched for this first figure at some point but is not drawn.
    _save_probability_figure(
        curves,
        '',
        'Probability Change',
        os.path.join(save_dir, 'probability_varying.png'),
        title_kwargs={'fontsize': title_fontsize},
    )

    _save_probability_figure(
        curves,
        ' (log)',
        'Log Probability Change',
        os.path.join(save_dir, 'log_probability_varying.png'),
        log_scale=True,
    )

    relative_curves = [(name, _relative_to_first(series)) for name, series in curves]
    _save_probability_figure(
        relative_curves,
        ' (Relative)',
        'Relative Probability Change',
        os.path.join(save_dir, 'relative_probability_varying.png'),
    )

    # Save data to JSON
    data = {
        'average_chosen_prob': np.asarray(average_chosen_prob).tolist(),
        'min_chosen_prob': np.asarray(min_chosen_prob).tolist(),
        'average_reject_prob': np.asarray(average_reject_prob).tolist(),
        'max_reject_prob': np.asarray(max_reject_prob).tolist(),
        # Legacy key: despite its name it has always held max_others_prob.
        # Kept so existing readers of the file keep working.
        'average_unseen_prob': np.asarray(max_others_prob).tolist(),
        'average_others_prob': np.asarray(average_others_prob).tolist(),
        # Correctly named copy of the series stored under the legacy key.
        'max_others_prob': np.asarray(max_others_prob).tolist(),
    }

    json_file_path = os.path.join(save_dir, 'statisitic_prob_data.json')
    with open(json_file_path, 'w') as json_file:
        json.dump(data, json_file)


def _relative_to_first(series):
    """Return the entries of *series* divided by its first entry."""
    if len(series) == 0:
        # No baseline to divide by; an empty list simply draws no line.
        return []
    baseline = series[0]
    return [value / baseline for value in series]


def _save_probability_figure(curves, label_suffix, title, file_path, title_kwargs=None, log_scale=False):
    """Draw every ``(name, series)`` pair of *curves* on one axes and save it.

    ``label_suffix`` is appended to each legend label, ``title_kwargs`` is
    forwarded to ``set_title`` and ``log_scale`` switches the y-axis to a
    logarithmic scale.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Probability')
    for name, series in curves:
        ax.plot(series, label=f'{name}{label_suffix}')
    ax.tick_params(axis='y')
    ax.set_title(title, **(title_kwargs or {}))
    if log_scale:
        ax.set_yscale('log')
    ax.legend(loc='best', fontsize=legend_font)
    ax.grid(True)
    fig.savefig(file_path)
    plt.close(fig)

def plot_gradient(chosen_prob_grads, rejected_prob_grads, beta1, beta2, save_dir):
    """Plot absolute gradient values per epoch and save ``gradient.png``.

    Both series are passed through ``np.abs`` and drawn on a logarithmic
    y-axis.

    Args:
        chosen_prob_grads: Sequence plotted as "Average Gradient on pi_+".
        rejected_prob_grads: Sequence plotted as "Average Gradient on pi_-".
        beta1: Not used by this function; kept so existing callers still work.
        beta2: Not used by this function; kept so existing callers still work.
        save_dir: Output directory; created when missing.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_xlabel('Epoch')
    # The curves are |gradient| values, not probabilities.
    ax.set_ylabel('Gradient Magnitude')
    ax.plot(np.abs(chosen_prob_grads), label='Average Gradient on pi_+ (log)')
    ax.plot(np.abs(rejected_prob_grads), label='Average Gradient on pi_- (log)')
    ax.tick_params(axis='y')
    ax.set_title('Gradient')
    ax.set_yscale('log')

    ax.legend(loc='best', fontsize=legend_font)
    ax.grid(True)

    # An empty save_dir means "current directory"; os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, 'gradient.png'))
    plt.close(fig)

def draw_preference_dataset(preferences, save_dir):
    """Save a heatmap of the preference dataset as ``preference_dataset.png``.

    A 4 x 10 grid of zeros is filled in from ``preferences.values()``: for
    every value ``v`` the cells ``(v[0], v[0])`` and ``(v[0], v[1])`` are set
    to 1. The grid is then drawn with the 'coolwarm' colormap.

    Args:
        preferences: Mapping whose values are indexable with ``[0]`` and
            ``[1]``, both usable as indices into the grid.
        save_dir: Output directory; created when missing.
    """
    grid = np.zeros([4, 10])
    for pair in preferences.values():
        # The first entry selects the row; it also marks the diagonal cell.
        for column in (pair[0], pair[1]):
            grid[pair[0], column] = 1

    fig, ax = plt.subplots()
    image = ax.matshow(grid, cmap='coolwarm')

    # Attach a colorbar axes to the right so it matches the plot's height.
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    colorbar_ax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    fig.colorbar(image, cax=colorbar_ax)

    ax.set_xlabel('Response y')
    ax.set_ylabel('Prompt x')
    ax.set_title('Preference Dataset', fontsize=title_fontsize, pad=20)

    target = os.path.join(save_dir, 'preference_dataset.png')
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(target)
    plt.close(fig)

def plot_heatmap(epoch, outputs, save_dir):
    """Save an annotated heatmap of ``exp(outputs)`` as ``epoch_{epoch}.png``.

    Args:
        epoch: Value interpolated into the figure title and the file name.
        outputs: 2-D torch tensor. It is exponentiated before plotting, so it
            is presumably a matrix of log-probabilities (rows labelled
            "Prompt x", columns labelled "Response y") - confirm with caller.
        save_dir: Output directory; created when missing.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    # .cpu() makes the conversion work for tensors living on another device;
    # it is a no-op for tensors that are already on the CPU.
    probabilities = torch.exp(outputs).detach().cpu().numpy()

    # Create the figure and draw the probability grid.
    fig, ax = plt.subplots()
    image = ax.matshow(probabilities, cmap='viridis')

    # Make colorbar same height as the plot
    divider = make_axes_locatable(ax)
    colorbar_ax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(image, cax=colorbar_ax)

    # Write each cell's value on top of it.
    for (i, j), val in np.ndenumerate(probabilities):
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', color='white')

    # Title and axis labels.
    ax.set_xlabel('Response y')
    ax.set_ylabel('Prompt x')
    ax.set_title(f'Output Probabilities at Epoch {epoch}', fontsize=title_fontsize, pad=20)

    # Save the image. An empty save_dir means "current directory";
    # os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f'epoch_{epoch}.png'))
    plt.close(fig)