import matplotlib
from matplotlib import pyplot as plt
from matplotlib import rc, markers
from matplotlib.ticker import ScalarFormatter
import torch
import numpy as np

# Turn matplotlib's interactive mode off; the script calls plt.show() explicitly.
plt.interactive(False)

# Switches selecting which figures this script produces.
plot_individual = False  # one figure per task: loss traces plus step-size traces
plot_average =  False  # loss averaged over tasks, one curve per optimizer
plot_per_iteration = False  # plot against the sample index instead of time where supported
plot_indexes = False  # loss ratio against a reference optimizer, one figure per optimizer
plot_time_time = True  # calls plot_tt for every optimizer in plot_optms against `we`
plot_min_min = False  # calls plot_mm(we, [])
plot_avg_step = False  # calls plot_avg(we)
plot_box = True  # calls plot_boxp([0, 1, 2], we)

# Global figure styling. NOTE: usetex=True makes matplotlib render text
# through an external LaTeX installation.
rc('font', **{'family':'serif', 'serif':['Times'], 'size':16})
rc('text', usetex=True)
rc('axes', labelsize=20)
rc('figure', autolayout=True)


# Result file to plot; uncomment exactly one of the alternatives below.
filename = 'results/zip_res_nh_1_units_20_N_1000_mnist_3.pt'
# filename = 'results/zip_res_nh_2_units_400_N_1000_mnist_3.pt'
# filename = 'results/zip_res_nh_4_units_800_N_1000_mnist_3.pt'
#filename = 'results/zip_ep50_res_nh_1_units_20_N_1000_mnist_3.pt'
# filename = 'results/zip_ep50_res_nh_2_units_400_N_1000_mnist_3.pt'
# filename = 'results/zip_ep50_res_nh_4_units_800_N_1000_mnist_3.pt'

# NOTE(review): torch.load unpickles the file, so only load result files from
# a trusted source. Tensors are mapped onto the CPU.
data_in = torch.load(filename,map_location=torch.device('cpu'))
# Index (first axis of the result tensors) of the optimizer the comparison
# plots are made against -- presumably "ours"; confirm against the producer
# of the result file.
we = 3
# Gradient-norm thresholds handed to plot_tt.
crits = [ 1e-5, 1e-6, 1e-7] #% 1e-2, 1e-3,  1e-4,
#crits = [ 1e-3, 1e-4, 1e-5] #% 1e-2, 1e-3,  1e-4,


# data_in['res'] is indexed with four axes; axis 2 selects the recorded
# quantity. The resulting 3-D tensors are indexed below as
# [optimizer, sample, task] (sample is plotted as k, so presumably the
# iteration -- confirm).
loss = data_in['res'][:,:,0,:]
t = data_in['res'][:,:,1,:]
gnorm = data_in['res'][:,:,2,:]
steps = data_in['res'][:,:,3,:]
# Per-optimizer records; opt[i][2] is used as the display label of optimizer i.
opt = data_in['optimizers']

# Optimizer indices to include in the plots (all of them by default; a
# hand-picked list such as [0, 1, 5, 6, 7] can be used instead).
plot_optms =   range(t.size(0))#  [0, 1, 5, 6, 7] # range(t.size(0))
if plot_individual:
    # One figure per task: the loss of every optimizer in `plot_optms` on the
    # upper panel, the step sizes of optimizers `we - 1` and `we` on the lower.
    #
    # (task index, upper time limit) pairs. A hand-picked subset can be used
    # instead, e.g. [(41, 1.0), (63, 1.0), (183, 1.5)].
    task_limits = [(task, 3.0) for task in range(t.size(2))]
    trace_style = dict(ls='--', marker='o', markersize=3)

    for task, xlim in task_limits:
        plt.figure(task, figsize=[5.4, 5.0]).set_tight_layout(False)
        p_loss = plt.subplot(2, 1, 1, position=[0.20, 0.5, 0.75, 0.45])
        p_step = plt.subplot(2, 1, 2, position=[0.20, 0.15, 0.75, 0.25])

        for i in plot_optms:
            # Keep only samples with a positive time stamp below the limit.
            valid = (t[i, :, task] > 0) & (t[i, :, task] < xlim)
            label = opt[i][2][:10]
            if plot_per_iteration:
                p_loss.semilogy(loss[i, valid, task], label=label)
            else:
                p_loss.semilogy(t[i, valid, task], loss[i, valid, task],
                                label=label, **trace_style)

        p_loss.set_ylabel(f'$f(x_k)$, task = {task}')
        if plot_per_iteration:
            p_loss.set_xlabel('$k$')
        else:
            p_step.set_xlabel('time')
        # The legend is only drawn on the figure of task 41.
        if task == 41:
            p_loss.legend(loc='lower left')

        # Step-size traces; the last valid sample of each trace is dropped.
        for n_opt, color in ((we - 1, 'tab:green'), (we, 'tab:red')):
            valid = (t[n_opt, :, task] > 0) & (t[n_opt, :, task] < xlim)
            p_step.semilogy(t[n_opt, valid, task][:-1],
                            steps[n_opt, valid, task][:-1],
                            color=color, **trace_style)

        # The time window only applies to axes whose x coordinate is time; in
        # per-iteration mode the loss panel is indexed by sample number, so
        # clipping it to the time limit would hide almost all samples.
        if not plot_per_iteration:
            p_loss.set_xlim([0 - 0.01, xlim + 0.01])
        p_step.set_xlim([0 - 0.01, xlim + 0.01])

        p_step.yaxis.set_ticks([1e-1, 1])
        p_step.set_ylabel('$t_k$')

        plt.show(block=True)

if plot_average:
    # Loss averaged over tasks, one curve per optimizer in `plot_optms`,
    # saved to figs/f_averaged.pdf and then shown.
    import os

    plt.figure()
    exclusions = []  # task indices to leave out of the average
    indexes = torch.arange(loss.shape[2])
    # Boolean task mask: start with every task selected, then drop exclusions.
    slicer = torch.ones_like(indexes, dtype=torch.bool)
    for e in exclusions:
        slicer &= indexes != e

    for i in plot_optms:
        mean_loss = loss[i, :, slicer].mean(1)
        if plot_per_iteration:
            plt.semilogy(mean_loss, label=opt[i][2])
        else:
            plt.semilogy(t[i, :, slicer].mean(1), mean_loss,
                         label=opt[i][2], ls='--', marker='o')

    plt.ylabel('$f(x_k)$, averaged')
    if plot_per_iteration:
        plt.xlabel('$k$')
        plt.xlim(0, 10)
    else:
        plt.xlabel('$t$')
    plt.legend(loc='upper right')

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('figs', exist_ok=True)
    plt.savefig('figs/f_averaged.pdf')
    plt.show(block=True)



if plot_indexes:
    # For every optimizer: ratio of its loss to the loss optimizer `ref` had
    # at the same time stamp, one thin curve per task plus a thick curve of
    # the across-task means.
    ref = 5

    def value_at_time(time, task_n):
        """Return the loss of optimizer `ref` on task `task_n` at `time`.

        Picks the sample just before the first one whose time stamp exceeds
        `time` (clamped to index 0); when no time stamp exceeds `time`, the
        last sample is used.
        """
        ref_times = t[ref, :, task_n]
        later = (ref_times > time).nonzero()
        if later.size(0) == 0:
            idx = ref_times.size(0) - 1
        else:
            idx = max(later[0, 0] - 1, 0)
        return loss[ref, idx, task_n]

    n_optimizers = t.size(0)
    n_tasks = t.size(2)
    for opt_idx in range(n_optimizers):
        plt.figure(opt_idx)
        ratio = torch.zeros_like(loss[opt_idx, :, :])
        n_samples = ratio.size(0)
        for task_idx in range(n_tasks):
            for k in range(n_samples):
                reference = value_at_time(t[opt_idx, k, task_idx], task_idx)
                ratio[k, task_idx] = loss[opt_idx, k, task_idx] / reference
            plt.semilogy(t[opt_idx, :, task_idx], ratio[:, task_idx],
                         label=opt[opt_idx][2])

        # Thick curve: mean time against mean ratio over the tasks.
        plt.semilogy(t[opt_idx].mean(1), ratio.mean(1), linewidth=4)

        # Add an explicit tick at 1 (equal to the reference) to the default ticks.
        tick_locs = plt.yticks()[0]
        plt.grid()
        plt.yticks(np.append(tick_locs, 1.e0))
        plt.ylabel(f'$I_k$, optimizer = {opt[opt_idx][2]}')
        plt.xlabel('$k$' if plot_per_iteration else '$t$')

        plt.savefig(f'figs/f_ik_opt_{opt[opt_idx][2]}.pdf')
        plt.show(block=True)



def plot_tt(sel, crits):
    """Scatter the time two optimizers need to reach each gradient-norm tolerance.

    For every tolerance ``r`` in ``crits`` and every task, the time stamp of
    the first sample with ``gnorm < r`` is looked up for the optimizers in
    ``sel``. Tasks where both optimizers reach the tolerance are drawn in the
    main panel (x: ``sel[0]``, y: ``sel[1]``). Tasks where only one of them
    reaches it are drawn in the side panels whose single tick is labelled
    infinity. A win/draw summary is printed per tolerance.

    Reads the module-level tensors ``t``, ``gnorm`` and the ``opt`` records.

    Args:
        sel: optimizer indices; ``sel[0]`` and ``sel[1]`` are compared.
        crits: iterable of gradient-norm tolerances.
    """
    # Sentinel stored when an optimizer never reaches the tolerance. Its large
    # magnitude is relied upon by the win ratio below.
    not_reached = -1e9
    first, second = sel[0], sel[1]
    n_tasks = t.size(2)
    out = torch.zeros(n_tasks, t.size(0))

    plt.figure(figsize=[5.4, 5.0])
    p_y = plt.subplot(2, 2, 2, position=[0.85, 0.15, 0.1, 0.65])
    p_x = plt.subplot(2, 2, 3, position=[0.15, 0.85, 0.65, 0.1])
    p_xy = plt.subplot(2, 2, 1, position=[0.15, 0.15, 0.65, 0.65])

    # Largest finite time seen over *all* tolerances; used for the axis limits
    # so that no scattered point falls outside the panels.
    m_lim = 0.0
    for r in crits:
        for task in range(n_tasks):
            for n_opt in sel:
                t_aux = t[n_opt, :, task]
                below = gnorm[n_opt, :, task] < r
                out[task, n_opt] = t_aux[below][0] if below.any() else not_reached

        first_ok = out[:, first] != not_reached
        second_ok = out[:, second] != not_reached

        both_ok = first_ok & second_ok
        p_xy.scatter(out[both_ok, first], out[both_ok, second],
                     label=rf'$\|g_k\|<${r:.0e}')

        # Only sel[1] reached the tolerance: right-hand panel.
        y_inf = out[~first_ok & second_ok, second]
        p_y.scatter(np.zeros_like(y_inf), y_inf, label=None)

        # Only sel[0] reached the tolerance: top panel.
        x_inf = out[first_ok & ~second_ok, first]
        p_x.scatter(x_inf, np.zeros_like(x_inf), label=None)

        # |t0 / t1| > 1 holds when sel[1] was faster, and also when only
        # sel[0] failed (sentinel in the numerator). Equal entries, including
        # the case where both failed, count as a draw.
        win = ((out[:, first] / out[:, second]).abs() > 1).sum()
        draw = (out[:, first] == out[:, second]).sum()
        print(f'|g_k|<{r:.0e} - Win: {win/(t.size(2)/100.0):.1f}, Draw: {draw/(t.size(2)/100.0):.1f}')

        reached = out[:, sel]
        reached = reached[reached != not_reached]
        if reached.numel() > 0:
            m_lim = max(m_lim, reached.max().item())

    if m_lim <= 0:
        # Nothing finite to show; keep the axes non-degenerate.
        m_lim = 1.0

    p_xy.set_xlabel(opt[first][2])
    p_xy.set_ylabel(opt[second][2])
    p_xy.legend(loc=[0.7, 0.9])
    # Diagonal: points on it took the same time with both optimizers.
    diagonal = [0.0, m_lim]
    p_xy.plot(diagonal, diagonal)

    p_xy.grid()
    p_xy.set_xlim(0, m_lim * 1.05)
    p_xy.set_ylim(0, m_lim * 1.05)

    p_x.grid()
    p_x.set_xlim(0, m_lim * 1.05)
    p_x.set_ylim(-1, 1)
    p_y.grid()
    p_y.set_xlim(-1, 1)
    p_y.set_ylim(0, m_lim * 1.05)

    p_x.set_xticklabels([])
    p_y.set_xticks([0])
    p_y.set_xticklabels([r'$\infty$'])

    p_x.set_yticks([0])
    p_x.set_yticklabels([r'$\infty$'])
    p_y.set_yticklabels([])
    plt.show()


def plot_avg(sel, exc=()):
    """Plot the mean step size per iteration with a one-standard-deviation band.

    Reads the module-level ``steps`` and ``loss`` tensors. The statistics are
    taken along dimension 1 of ``steps[sel, :, mask]`` (presumably the task
    axis when ``sel`` is a single index -- confirm), where ``mask`` keeps
    every task index not listed in ``exc``.

    Args:
        sel: index of the optimizer whose step sizes are plotted.
        exc: task indices to exclude from the statistics.
    """
    plt.figure()
    indexes = torch.arange(loss.shape[2])
    # Boolean task mask: start with every task selected, then drop exclusions.
    slicer = torch.ones_like(indexes, dtype=torch.bool)
    for e in exc:
        slicer &= indexes != e

    selected = steps[sel, :, slicer]
    m_step = selected.mean(1)
    std_step = selected.std(1)

    plt.semilogy(m_step, label='BFGS', ls='--', marker='o')
    plt.fill_between(range(len(m_step)), m_step - std_step, m_step + std_step,
                     color='0.9', ls='--')

    plt.ylabel('Averaged $t_k$')
    plt.xlabel('$k$')
    plt.show(block=True)


def plot_mm(sel, exc=()):
    """Scatter the best loss of optimizer ``sel`` against every other optimizer.

    For each task and optimizer the smallest strictly positive loss value is
    taken from the module-level ``loss`` tensor. Each optimizer other than
    ``sel`` (and not listed in ``exc``) gets one scatter series on log-log
    axes, with its best loss on x and the best loss of ``sel`` on y. The
    percentage of tasks on which ``sel`` reached a strictly lower value is
    printed.

    Args:
        sel: index of the optimizer shown on the y axis.
        exc: optimizer indices to leave out of the comparison.

    Raises:
        RuntimeError: presumably, if some (optimizer, task) trace has no
            positive loss value, since ``min`` of an empty tensor fails.
    """
    n_tasks, n_optimizers = t.size(2), t.size(0)
    out = torch.zeros(n_tasks, n_optimizers)

    plt.figure()
    p_xy = plt.subplot(1, 1, 1)

    for task in range(n_tasks):
        for n_opt in range(n_optimizers):
            v_aux = loss[n_opt, :, task]
            # Non-positive entries are ignored when taking the minimum.
            out[task, n_opt] = v_aux[v_aux > 0].min()

    for n_opt in range(n_optimizers):
        if n_opt == sel or n_opt in exc:
            continue
        p_xy.scatter(out[:, n_opt], out[:, sel], label=opt[n_opt][2][:10])
        win = (out[:, sel] < out[:, n_opt]).sum()
        print(f'Against{opt[n_opt][2]} - Win: {win / (t.size(2) / 100.0):.1f}')

    p_xy.set_xlabel('$f_{*}$ for other algorithms')
    p_xy.set_ylabel(opt[sel][2])
    p_xy.legend(loc='upper right')

    # Diagonal through the overall extremes; points below it favour `sel`.
    max_lin = out.max()
    min_lin = out.min()
    l = torch.cat([min_lin.view(-1), max_lin.view(-1)])
    p_xy.loglog(l, l)

    p_xy.grid()
    p_xy.set_xlim(min_lin * 0.95, max_lin * 1.05)
    p_xy.set_ylim(min_lin * 0.95, max_lin * 1.05)

    plt.show()


def plot_boxp(sel, we):
    """Box-plot the log ratio of best losses against a reference optimizer.

    For every optimizer in ``sel`` and every task, computes
    ``log(min positive loss of that optimizer / min positive loss of we)``
    from the module-level ``loss`` tensor and draws one box per optimizer.

    Args:
        sel: sequence of optimizer indices, one box each.
        we: index of the reference optimizer used as the denominator.
    """
    def best_positive(n_opt, task):
        # Smallest strictly positive loss recorded for this optimizer and task.
        values = loss[n_opt, :, task]
        return values[values > 0].min()

    flier_style = dict(markerfacecolor='k', marker='s')
    plt.figure(figsize=[5.4, 5.0])
    p_xy = plt.axes()

    n_tasks = t.size(2)
    # The reference minima do not depend on the compared optimizer, so they
    # are computed once rather than once per entry of `sel`.
    ref_best = [best_positive(we, task) for task in range(n_tasks)]

    out_np = []
    for n_opt in sel:
        log_ratio = torch.zeros(n_tasks)
        for task in range(n_tasks):
            log_ratio[task] = (best_positive(n_opt, task) / ref_best[task]).log()
        out_np.append(log_ratio.numpy())
    plt.boxplot(out_np, flierprops=flier_style)

    p_xy.set_xlabel('')
    p_xy.set_ylabel('${I}_a$')
    p_xy.set_title('nn')
    p_xy.set_xticklabels([opt[n_opt][2] for n_opt in sel])

    p_xy.grid()

    plt.show()


# Not read by any of the calls below.
sel = [2, 3]

# Time-to-tolerance comparison of every plotted optimizer against `we`.
if plot_time_time:
    for i in plot_optms:
        plot_tt(sel=[i, we], crits=crits)

# Best-loss scatter of `we` against all other optimizers.
if plot_min_min:
    plot_mm(sel=we, exc=[])

# Mean step size of `we`.
if plot_avg_step:
    plot_avg(sel=we)

# Box plot of optimizers 0, 1 and 2 relative to `we`.
if plot_box:
    plot_boxp(sel=[0, 1, 2], we=we)