import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Typography applied to every figure drawn by this module: serif body text in
# Nimbus Roman, Computer Modern ('cm') mathtext, and fonttype 42 (TrueType)
# font embedding for PDF and PostScript output.
mpl.rc('font', family='serif', serif='Nimbus Roman')
mpl.rc('mathtext', fontset='cm', rm='serif')
mpl.rc('pdf', fonttype=42)
mpl.rc('ps', fonttype=42)

def plot_single(systems, savename=None, bgvars=(1, 0, 0, 0)):
    '''Plot one variable for a single classifier.

    Delegates everything to ``plot_grid``. The default ``bgvars`` enables
    only index 0, i.e. streamlines without a filled-contour background.
    '''
    plot_grid(systems, savename=savename, bgvars=bgvars)

def plot_row(systems, savename=None, bgvars=(1, 0, 0, 0)):
    '''Plot one variable for several classifiers, side by side in one row.

    Delegates everything to ``plot_grid``, which lays out one column per
    entry of ``systems``.
    '''
    plot_grid(systems, savename=savename, bgvars=bgvars)

def plot_vars(systems, savename=None, bgvars=(0, 1, 1, 1)):
    '''Plot several variables for a single classifier (one row per variable).

    NOTE(review): ``plot_grid`` draws and then shows or saves the figure
    *before* ``NotImplementedError`` is raised, so callers get the plot and
    then the exception. Confirm whether the raise is still intended.
    '''
    plot_grid(systems, savename=savename, bgvars=bgvars)
    raise NotImplementedError

# Major tick positions for both qualification-rate axes on the bottom row.
_QUALIFICATION_TICKS = [0.1, 0.3, 0.5, 0.7, 0.9]

# Filled-contour background for each ``bgvars`` index that has one:
# (attribute read from every system, colormap name, colorbar label).
# Index 0 draws streamlines only, so it has no entry. Each label is split so
# the mathtext half can be a raw string (no invalid-escape warnings for
# ``\Pr``/``\hat``/``\mid``) while ``\n`` remains a real newline.
_BACKGROUND_VARIABLES = {
    1: ('A1', 'Blues',
        'Group 1 acceptance rate \n'
        r' $\Pr(\hat{Y} = 1 \mid G = 1)$'),
    2: ('fpr1', 'Reds',
        'Group 1 false positive rate \n'
        r' $\Pr(\hat{Y} = 1 \mid Y = 0, G = 1)$'),
    3: ('fnr1', 'Greens',
        'Group 1 false negative rate \n'
        r' $\Pr(\hat{Y} = 0 \mid Y = 1, G = 1)$'),
}


def _color_range(arrays, decimals=2):
    '''Return ``(low, high)`` bounds that cover every value in ``arrays``.

    The bounds are snapped *outward* to ``decimals`` places so the colorbar
    tick labels stay short without clipping data. ``contourf`` leaves values
    outside its ``levels`` unfilled, so rounding to the nearest value (which
    can raise the lower bound or lower the upper one) would blank the extreme
    regions. A constant input is widened by one step so the contour levels
    remain strictly increasing.
    '''
    scale = 10 ** decimals
    low = np.floor(min(np.min(a) for a in arrays) * scale) / scale
    high = np.ceil(max(np.max(a) for a in arrays) * scale) / scale
    if high <= low:
        high = low + 1 / scale
    return float(low), float(high)


def _draw_background(fig, axs, systems, row_var):
    '''Fill one row of panes with the contour map selected by ``row_var``.

    Every pane in the row shares one set of levels. A single colorbar is
    attached to the whole row, and the variable's label is written rotated
    to the right of the last pane.
    '''
    attr, cmap_name, cb_label = _BACKGROUND_VARIABLES[row_var]
    arrays = [getattr(system, attr) for system in systems]
    colormin, colormax = _color_range(arrays)
    levels = np.linspace(colormin, colormax, 9)

    for ax, system, values in zip(axs, systems, arrays):
        contours = ax.contourf(
            system.x, system.y, values,
            cmap=cmap_name,
            levels=levels,
            alpha=0.8,
        )

    # All panes use identical levels, so the last ContourSet represents the row.
    n_systems = len(systems)
    fig.colorbar(
        contours, ax=axs,
        fraction=0.046 / n_systems, pad=0.04 / n_systems,
        ticks=[colormin, colormax],
        ticklocation='left'
    )
    axs[-1].text(
        1.15, 0.5,
        cb_label,
        rotation=270,
        rotation_mode='anchor',
        horizontalalignment='center',
        verticalalignment='baseline',
        multialignment='center'
    )


def plot_grid(systems, savename=None, bgvars=(0, 1, 1, 1)):
    '''multiple classifiers; multiple variables

    Draw one column per entry of ``systems`` and one row per enabled flag in
    ``bgvars``. Index 0 draws streamlines only. Indices 1, 2 and 3 add a
    filled-contour background of ``A1``, ``fpr1`` and ``fnr1``, each with a
    per-row colour scale. Every pane gets the ``(Vx, Vy)`` streamplot and the
    equal-qualification diagonal.

    Parameters
    ----------
    systems : sequence
        Objects exposing ``name``, ``xx``, ``yy``, ``Vx``, ``Vy``, ``x``,
        ``y`` and the background attributes needed by the enabled rows.
    savename : str or None
        If None, call ``plt.show()``. Otherwise save the figure to
        ``images/<savename>.pdf``.
    bgvars : sequence of 0/1 flags
        Which of the four row variables to draw, in index order.
    '''
    n_cols = len(systems)
    n_rows = sum(bgvars)
    panel = 2.3  # inches per pane

    # squeeze=False always returns a 2-D array of Axes, so rows and columns
    # are indexed the same way for a single pane, a row, or a full grid.
    # Columns always share y; x is shared only when several rows are stacked.
    fig, axes_grid = plt.subplots(
        n_rows, n_cols,
        sharey=True, sharex=(n_rows != 1),
        squeeze=False,
        figsize=(n_cols * panel + 0.2, n_rows * panel)
    )
    fig.subplots_adjust(hspace=0.1, wspace=0.1)

    for row_var, row_on in enumerate(bgvars):
        if not row_on:
            continue

        # Position of this variable among the enabled rows.
        row_num = sum(bgvars[:row_var])
        axs = axes_grid[row_num]

        for ax in axs:
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)
            ax.set_aspect('equal')

        if row_num == n_rows - 1:
            for ax in axs:
                ax.set_xlabel('Group 1 qualification rate $s_1$')
                ax.xaxis.set_major_locator(ticker.FixedLocator(_QUALIFICATION_TICKS))
                ax.yaxis.set_major_locator(ticker.FixedLocator(_QUALIFICATION_TICKS))

        axs[0].set_ylabel('Group 2 qualification rate $s_2$')

        if row_num == 0:
            for ax, system in zip(axs, systems):
                ax.set_title(system.name)

        for ax, system in zip(axs, systems):
            ax.streamplot(
                system.xx, system.yy,
                system.Vx,
                system.Vy,
                color='black',
                linewidth=0.6,
                arrowsize=0.8
            )

        if row_var != 0:
            _draw_background(fig, axs, systems, row_var)

        # plot equal qualification line
        for ax in axs:
            ax.plot(
                [0.02, 0.98], [0.02, 0.98], color='black',
                linewidth=3.5, label='Equal qualification rates'
            )

    # Set legend location for equal qualification line (above the top-right pane)
    # https://stackoverflow.com/questions/4700614/how-to-put-the-legend-out-of-the-plot/43439132#43439132
    axes_grid[0, -1].legend(
        bbox_to_anchor=(0.6, 1.06),
        loc='lower center',
        frameon=False
    )

    # display or save plot
    if savename is None:
        plt.show()
    else:
        filename = f'images/{savename}.pdf'
        print('saving', filename)
        fig.savefig(filename, bbox_inches='tight')
