import os

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from matplotlib.lines import Line2D

def plot_value(df, v, gamma_list, beta_list, plotdir, format='png'):
    """Save the reward plots for the gamma and beta hyperparameter sweeps.

    Files written under ``./plots/<plotdir>`` (created if missing):

    * ``value_gamma.<format>`` and ``value_beta.<format>``: per-party rewards
      drawn by ``plot_abs``, with grey reference lines taken from ``v``.
    * ``value_r_1-r_2.<format>`` and ``value_r_1-r_3.<format>``: reward
      differences drawn by ``plot_diff``, with both sweeps on one axes
      (gamma in blues, beta in oranges).

    Note: ``df`` is modified in place. Its columns ``$r_1$``, ``$r_2$``,
    ``$r_3$``, ``$r_1-r_2$`` and ``$r_1-r_3$`` are renamed to their starred
    forms (e.g. ``$r_1^*$``).

    Args:
        df: DataFrame holding a ``'$t$'`` column, one column per
            hyperparameter (gamma / beta) and the reward columns listed above.
        v: Mapping keyed by frozensets of party indices. The singletons
            ``{0}``, ``{1}``, ``{2}`` must be present, and so must the full
            set ``{0, ..., n-1}``, where ``n`` is the size of the largest key.
        gamma_list: Gamma values to plot.
        beta_list: Beta values to plot.
        plotdir: Sub-directory of ``./plots`` to write into.
        format: Image format passed to ``savefig``; also used as the file
            extension.
    """
    plotdir = os.path.join('./plots', plotdir)
    os.makedirs(plotdir, exist_ok=True)
    figsize = (8, 6)
    # Renamed in place on purpose: callers may rely on the starred column
    # names afterwards. A repeated call is harmless because rename ignores
    # keys that are no longer present.
    df.rename(columns={'$r_1$': '$r_1^*$',
                       '$r_2$': '$r_2^*$',
                       '$r_3$': '$r_3^*$',
                       '$r_1-r_2$': '$r_1^*-r_2^*$',
                       '$r_1-r_3$': '$r_1^*-r_3^*$'},
              inplace=True)
    v_self = [v[frozenset({i})] for i in range(3)]
    n = max(len(s) for s in v)
    v_N = v[frozenset(range(n))]
    methods = [r'$\gamma$', r'$\beta$']
    hypers = [gamma_list, beta_list]
    colormaps = [plt.cm.Blues, plt.cm.Oranges]

    # One absolute-reward figure per hyperparameter sweep.
    for method, hyper in zip(methods, hypers):
        method_str = method.replace('\\', '').replace('$', '')
        plotname = f'value_{method_str}.{format}'
        fig, ax = plt.subplots(figsize=figsize)
        try:
            plot_abs(ax, df, v_self, v_N, values=hyper, method=method)
            fig.tight_layout()
            fig.savefig(os.path.join(plotdir, plotname), format=format,
                        bbox_inches='tight', dpi=300)
        finally:
            # Close even when plotting fails, so figures don't accumulate.
            plt.close(fig)

    # One reward-difference figure per column, overlaying both sweeps.
    for diff in ['$r_1^*-r_2^*$', '$r_1^*-r_3^*$']:
        diff_str = diff.replace('$', '').replace('^', '').replace('*', '')
        plotname = f'value_{diff_str}.{format}'
        fig, ax = plt.subplots(figsize=figsize)
        try:
            for method, hyper, colormap in zip(methods, hypers, colormaps):
                plot_diff(ax, df, values=hyper, method=method,
                          y_label=diff, colormap=colormap)
            fig.tight_layout()
            fig.savefig(os.path.join(plotdir, plotname), format=format,
                        bbox_inches='tight', dpi=300)
        finally:
            plt.close(fig)

def plot_abs(ax, df, v_self, v_N, values, method):
    """Plot each party's reward against ``$t$`` for several hyperparameter values.

    One line is drawn per (hyperparameter value, party) pair:

    * the party fixes the colormap and line style
      (``$r_1^*$``: blues, solid; ``$r_2^*$``: oranges, dotted;
      ``$r_3^*$``: greens, dashed);
    * the hyperparameter value fixes the marker and the shade within that
      colormap.

    Grey horizontal lines are drawn at every entry of ``v_self`` and at
    ``v_N``. A twin y-axis labels them ``$v_1$``, ``$v_2$``, ``$v_3$`` and
    ``$v_N$``.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with a ``'$t$'`` column, a ``method`` column and the
            reward columns ``'$r_1^*$'``, ``'$r_2^*$'`` and ``'$r_3^*$'``.
        v_self: Sequence of three reference values, one per party.
        v_N: Additional reference value, drawn as a fourth line.
        values: Hyperparameter values to plot. Rows are selected where
            ``df[method] == value``. Only the first 8 values (one per marker)
            are drawn.
        method: Name of the hyperparameter column in ``df``; also used in the
            legend labels.

    Raises:
        ValueError: If ``values`` is empty.
    """
    # A list, not a set: set iteration order depends on PYTHONHASHSEED, which
    # would shuffle the party -> colour/linestyle mapping between runs.
    parties = ['$r_1^*$', '$r_2^*$', '$r_3^*$']
    colormaps = [plt.cm.Blues, plt.cm.Oranges, plt.cm.Greens]
    linestyles = ['-', 'dotted', 'dashed']
    markers = ['o', 's', '^', '.', '>', 's', 'v', 'x']
    markersize = 13
    linewidth = 4.0

    n = len(values)
    if n == 0:
        raise ValueError('plot_abs needs at least one hyperparameter value')
    del_ = (1 - 0.2) / n
    # Evenly spaced shades in [0.2, 1). Pushing vmin below 0.2 presumably
    # keeps the lightest shade away from the near-white end of the colormaps.
    map_values = [0.2 + i * del_ for i in range(n)]
    color_norm = mcolors.Normalize(vmin=min(map_values) - 0.5, vmax=1)
    scalar_map = {
        party: cm.ScalarMappable(norm=color_norm, cmap=colormap)
        for party, colormap in zip(parties, colormaps)
    }

    plotted_t = set()
    for value, map_value, marker in zip(values, map_values, markers):
        df_new = df[df[method] == value]
        plotted_t.update(df_new['$t$'].unique())
        for party, linestyle in zip(parties, linestyles):
            ax.plot(df_new['$t$'], df_new[party],
                    color=scalar_map[party].to_rgba(map_value),
                    markersize=markersize,
                    linewidth=linewidth,
                    marker=marker,
                    linestyle=linestyle)

    # Legend: one black marker entry per hyperparameter value ...
    legends = [
        Line2D([], [],
               color='black',
               linewidth=linewidth,
               markersize=markersize,
               marker=marker,
               label=f'{method}={value}')
        for value, marker in zip(values, markers)
    ]
    # ... plus one coloured line entry per party, shaded like the second value
    # (or like the only value, so a single value doesn't raise IndexError).
    label_color_value = map_values[min(1, n - 1)]
    legends.extend(
        Line2D([], [],
               color=scalar_map[party].to_rgba(label_color_value),
               linewidth=linewidth,
               linestyle=linestyle,
               label=party)
        for party, linestyle in zip(parties, linestyles)
    )

    ax.legend(handles=legends, ncol=2, prop={'size': 25}, loc='best')
    ax.set_xlabel('$t_1$', fontsize=40)
    ax.set_ylabel('Reward', fontsize=40)
    # Tick every plotted t, not only those of the last hyperparameter value.
    ax.set_xticks(sorted(plotted_t))
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', labelsize=30)

    # Grey reference lines, drawn on ``ax`` itself (plt.axhline would target
    # whatever axes happens to be current).
    yticks_ax2 = list(v_self) + [v_N]
    for y in yticks_ax2:
        ax.axhline(y=y, color='grey', alpha=0.3, linewidth=2)

    # Twin axis with the same limits, used only to label the reference lines.
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    ax2.set_yticks(yticks_ax2)
    ax2.set_yticklabels(['$v_1$', '$v_2$', '$v_3$', '$v_N$'], fontsize=25)
    ax2.tick_params(axis='y', labelsize=30)
    # Colour the $v_i$ labels like their party; zip stops before the $v_N$ label.
    for v_self_label, party in zip(ax2.get_yticklabels(), parties):
        v_self_label.set_color(scalar_map[party].to_rgba(label_color_value))

def plot_diff(ax, df, values, method, y_label, colormap):
    """Plot column ``y_label`` against ``$t$`` for several hyperparameter values.

    Each value in ``values`` gets its own line, marker and shade of
    ``colormap``. A dashed grey line marks ``y = 0``. The function can be
    called repeatedly on the same ``ax`` to overlay several sweeps; each call
    rebuilds the legend from every labelled line on the axes.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with a ``'$t$'`` column, a ``method`` column and a
            ``y_label`` column.
        values: Hyperparameter values to plot. Rows are selected where
            ``df[method] == value``. Only the first 8 values (one per marker)
            are drawn.
        method: Name of the hyperparameter column; also used in legend labels.
        y_label: Column to plot; also used as the y-axis label.
        colormap: Matplotlib colormap supplying the line shades.

    Raises:
        ValueError: If ``values`` is empty.
    """
    markers = ['o', 's', '^', '.', '>', 's', 'v', 'x']

    n = len(values)
    if n == 0:
        raise ValueError('plot_diff needs at least one hyperparameter value')
    del_ = (1 - 0.2) / n
    # Evenly spaced shades in [0.2, 1), normalised against [0.2 - 0.5, 1].
    map_values = [0.2 + i * del_ for i in range(n)]
    color_norm = mcolors.Normalize(vmin=min(map_values) - 0.5, vmax=1)
    scalar_map = cm.ScalarMappable(norm=color_norm, cmap=colormap)

    plotted_t = set()
    for value, map_value, marker in zip(values, map_values, markers):
        df_new = df[df[method] == value]
        plotted_t.update(df_new['$t$'].unique())
        ax.plot(df_new['$t$'], df_new[y_label],
                color=scalar_map.to_rgba(map_value),
                linewidth=4.0,
                markersize=13,
                marker=marker,
                label=f'{method}={value}')

    ax.legend(ncol=2, prop={'size': 25}, loc='best')
    ax.set_xlabel('$t_1$', fontsize=40)
    # Tick every t plotted in this call, not only the last value's.
    ax.set_xticks(sorted(plotted_t))
    ax.set_ylabel(y_label, fontsize=40)
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', labelsize=30)

    # Draw the zero line on ``ax`` itself; plt.axhline would target whatever
    # axes is current, which need not be ``ax``.
    ax.axhline(y=0, color='grey', alpha=0.3, linewidth=2, linestyle='dashed')

def plot_reward(ax, df, values, method, metric='mnlp'):
    """Plot each party's ``metric`` against ``$t$`` for several hyperparameter values.

    Only rows whose ``'index'`` column equals ``metric`` are used. One line is
    drawn per (hyperparameter value, party) pair:

    * the party fixes the colormap and line style
      (``$m_1$``: blues, solid; ``$m_2$``: oranges, dotted;
      ``$m_3$``: greens, dashed);
    * the hyperparameter value fixes the marker and the shade.

    Args:
        ax: Matplotlib axes to draw on.
        df: DataFrame with ``'$t$'``, ``'index'``, ``method`` and
            ``'$m_1$'``/``'$m_2$'``/``'$m_3$'`` columns.
        values: Hyperparameter values to plot. Only the first 8 values (one
            per marker) are drawn.
        method: Name of the hyperparameter column; also used in legend labels.
        metric: Value of ``df['index']`` to select (e.g. ``'mnlp'`` or
            ``'mse'``). Its upper-case form becomes the y-axis label.

    Raises:
        ValueError: If ``values`` is empty.
    """
    markers = ['o', 's', '^', '.', '>', 's', 'v', 'x']
    parties = ['$m_1$', '$m_2$', '$m_3$']
    party_labels = ['Party 1', 'Party 2', 'Party 3']
    linestyles = ['-', 'dotted', 'dashed']
    colormaps = [plt.cm.Blues, plt.cm.Oranges, plt.cm.Greens]
    markersize = 13
    linewidth = 4.0

    n = len(values)
    if n == 0:
        raise ValueError('plot_reward needs at least one hyperparameter value')
    del_ = (1 - 0.2) / n
    # Evenly spaced shades in [0.2, 1), normalised against [0.2 - 0.5, 1].
    map_values = [0.2 + i * del_ for i in range(n)]
    color_norm = mcolors.Normalize(vmin=min(map_values) - 0.5, vmax=1)
    scalar_map = {
        party: cm.ScalarMappable(norm=color_norm, cmap=colormap)
        for party, colormap in zip(parties, colormaps)
    }

    plotted_t = set()
    for value, map_value, marker in zip(values, map_values, markers):
        df_new = df[(df[method] == value) & (df['index'] == metric)]
        plotted_t.update(df_new['$t$'].unique())
        for party, linestyle in zip(parties, linestyles):
            ax.plot(df_new['$t$'], df_new[party],
                    color=scalar_map[party].to_rgba(map_value),
                    markersize=markersize,
                    linewidth=linewidth,
                    marker=marker,
                    linestyle=linestyle)

    # Legend: black marker entries for the hyperparameter values ...
    legends = [
        Line2D([], [],
               color='black',
               linewidth=linewidth,
               markersize=markersize,
               marker=marker,
               label=f'{method}={value}')
        for value, marker in zip(values, markers)
    ]
    # ... and coloured line entries for the parties, shaded like the second
    # value (or like the only value, so a single value doesn't IndexError).
    label_color_value = map_values[min(1, n - 1)]
    legends.extend(
        Line2D([], [],
               color=scalar_map[party].to_rgba(label_color_value),
               linewidth=linewidth,
               linestyle=linestyle,
               label=party_label)
        for party, linestyle, party_label in zip(parties, linestyles, party_labels)
    )

    ax.legend(handles=legends, ncol=2, prop={'size': 25}, loc='best')
    ax.set_xlabel('$t_1$', fontsize=40)
    ax.set_ylabel(metric.upper(), fontsize=40)
    # Tick every plotted t, not only those of the last hyperparameter value.
    ax.set_xticks(sorted(plotted_t))
    ax.tick_params(axis='x', labelsize=30)
    ax.tick_params(axis='y', labelsize=30)

def _plot_reward_figures(df, gamma_list, beta_list, plotdir, prefix, format):
    """Save one ``plot_reward`` figure per (hyperparameter, metric) pair.

    Files are written to
    ``./plots/<plotdir>/<prefix>_<method>_<metric>.<format>``, where
    ``<method>`` is ``gamma`` or ``beta`` and ``<metric>`` is ``mnlp`` or
    ``mse``. The directory is created if missing.
    """
    plotdir = os.path.join('./plots', plotdir)
    os.makedirs(plotdir, exist_ok=True)
    figsize = (8, 6)
    methods = [r'$\gamma$', r'$\beta$']
    metrics = ['mnlp', 'mse']
    hypers = [gamma_list, beta_list]
    for method, hyper in zip(methods, hypers):
        method_str = method.replace('\\', '').replace('$', '')
        for metric in metrics:
            plotname = f'{prefix}_{method_str}_{metric}.{format}'
            fig, ax = plt.subplots(figsize=figsize)
            try:
                plot_reward(ax, df, method=method, values=hyper, metric=metric)
                fig.tight_layout()
                fig.savefig(os.path.join(plotdir, plotname), format=format,
                            bbox_inches='tight', dpi=300)
            finally:
                # Close even when plotting fails, so figures don't accumulate.
                plt.close(fig)


def plot_noise(df, gamma_list, beta_list, plotdir, format='png'):
    """Save the noise-experiment reward plots as ``noise_<method>_<metric>.<format>``.

    See ``_plot_reward_figures`` for the exact output layout.
    """
    _plot_reward_figures(df, gamma_list, beta_list, plotdir, 'noise', format)


def plot_subset(df, gamma_list, beta_list, plotdir, format='png'):
    """Save the subset-experiment reward plots as ``subset_<method>_<metric>.<format>``.

    See ``_plot_reward_figures`` for the exact output layout.
    """
    _plot_reward_figures(df, gamma_list, beta_list, plotdir, 'subset', format)
