import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import pickle
from pprint import pprint as prt
import cv2
import seaborn as sns
import os
from glob import glob
import scipy.linalg as la

def plot(*x, close=0, **w):
    """Draw a line by forwarding every argument to ``plt.plot``.

    If ``close`` is truthy, all open figures are closed first and the
    result is displayed immediately with ``plt.show()``. Otherwise the line
    is simply added to the current axes and nothing is shown.
    """
    if not close:
        plt.plot(*x, **w)
        return
    plt.close('all')
    plt.plot(*x, **w)
    plt.show()







# Shared x-grid for every curve: 31 evenly spaced points on [-0.02, 0.02].
x = np.linspace(-0.02, 0.02, num=31)
# Five tick positions over the same range, rounded to 2 decimals.
xticks = np.round(np.linspace(-0.02, 0.02, num=5), 2)







names = [
    'f1_dmclil', 'f1_dm', 'f1_dmclil_simple', 'f1_dm_gp',
    'f2_rp1', 'f2_rp2', 'f2_rp1_poly', 'f2_rp2_gp',
    'f3_rpsi1', 'f3_rpsi2', 'f3_rpsi3', 'f3_rpsi1_gp', 'f3_rpsi2_simple', 'f3_rpsi3_poly',
    'f4_adam', 'f4_sgd', 'f4_mom', 'f4_mom_gp', 'f4_sgd_poly', 'f4_adam_simple',
]
beforeadd = 'exp_'

# Load each curve from '<beforeadd><name>.npy' and bind it to a module-level
# variable with the same name (e.g. ``f1_dm``), which the plotting code reads.
# Assigning through globals() replaces the previous exec() call. exec would
# run arbitrary code if an entry in ``names`` were ever anything other than
# a plain identifier.
for arrn in names:
    narr = f'{beforeadd}{arrn}.npy'
    arr = np.load(narr)
    globals()[arrn] = arr



















# One row per style slot: (matplotlib fmt string, linewidth, entry for
# linemk_size). Slots 0..2 are the thick styles the panels use for raw
# curves. Slots 3..5 (indexed as -3..-1) are the thin black marked styles
# used for the "*SR" overlays.
_LINE_STYLES = [
    ('b*', 10, None),
    ('g^', 10, None),
    ('rh', 10, None),
    ('k+-', 1, 10),
    ('k2-', 1, 10),
    ('k|-', 1, 10),
]
linemks, linews, linemk_size = map(list, zip(*_LINE_STYLES))

# Axes geometry in figure-fraction units: horizontal offset between panels,
# then panel width and height.
spacing = 0.8
L = 0.68
L2 = 0.85

alpha = 0.2  # opacity passed for the raw (thick) curves
ft1 = 17  # panel title font size








# Panel 1: DM models.
# Each row is (data, style slot, legend label, is_raw). Raw curves use the
# thick style with markersize = linewidth and translucency. Overlays use the
# thin style with its own marker size at full opacity. Rows are plotted in
# order, which also fixes the legend order.
# NOTE(review): the overlay labels disagree with the loaded file names.
# f1_dm_gp is labelled 'simpleSR' and f1_dmclil_simple is labelled 'polySR',
# whereas in every other panel the label prefix matches the file suffix.
# Confirm which side is correct.
plt.axes([0, 0, L, L2])
plt.title('DM Models', fontsize=ft1)
plt.xticks(xticks, xticks)

for curve, idx, label, is_raw in [
    (f1_dm, 0, r'DM @ $g_t$', True),
    (f1_dm_gp, -1, r'simpleSR: DM @ $g_t$', False),
    (f1_dmclil, 1, r'DM+CL+IL @ $g_t$', True),
    (f1_dmclil_simple, -2, r'polySR: DM+CL+IL @ $g_t$', False),
]:
    if is_raw:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linews[idx], alpha=alpha, label=label)
    else:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linemk_size[idx], label=label)

plt.legend(loc='upper right')




# Panel 2: RP model.
# Each row is (data, style slot, legend label, is_raw). Raw curves are thick
# and translucent. Overlays are thin, at full opacity.
plt.axes([spacing, 0, L, L2])
plt.title('RP Model', fontsize=ft1)
plt.xticks(xticks, xticks)

for curve, idx, label, is_raw in [
    (f2_rp1, 0, r'RP @ $\tilde{m}_t$', True),
    (f2_rp1_poly, -1, r'polySR: RP @ $\tilde{m}_t$', False),
    (f2_rp2, 1, r'RP @ $\tilde{g}_t$', True),
    (f2_rp2_gp, -2, r'gpSR: RP @ $\tilde{g}_t$', False),
]:
    if is_raw:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linews[idx], alpha=alpha, label=label)
    else:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linemk_size[idx], label=label)

plt.legend(loc='upper right')








# Panel 3: RP_si model.
# Each row is (data, style slot, legend label, is_raw). Raw curves are thick
# and translucent. Overlays are thin, at full opacity.
plt.axes([spacing * 2, 0, L, L2])
plt.title('RP_si model', fontsize=ft1)
plt.xticks(xticks, xticks)

for curve, idx, label, is_raw in [
    (f3_rpsi1, 0, r'RP_si @ $\tilde{m}_t$', True),
    (f3_rpsi1_gp, -1, r'gpSR: RP_si @ $\tilde{m}_t$', False),
    (f3_rpsi2, 1, r'RP_si @ $\tilde{g}_t$', True),
    (f3_rpsi2_simple, -2, r'simpleSR: RP_si @ $\tilde{g}_t$', False),
    (f3_rpsi3, 2, r'RP_si @ $\tilde{n}_t$', True),
    (f3_rpsi3_poly, -3, r'polySR: RP_si @ $\tilde{n}_t$', False),
]:
    if is_raw:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linews[idx], alpha=alpha, label=label)
    else:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linemk_size[idx], label=label)

plt.legend(loc='upper right')






# Panel 4: hand-crafted optimizers.
# Each row is (data, style slot, legend label, is_raw). Raw curves are thick
# and translucent. Overlays are thin, at full opacity.
# Each overlay gets its own slot (-1, -2, -3). Previously the Adam and SGD
# overlays both used slot -2 ('k2-') and were drawn indistinguishably.
plt.axes([spacing * 3, 0, L, L2])
plt.title('Hand-Crafted Optimizers', fontsize=ft1)
plt.xticks(xticks, xticks)

for curve, idx, label, is_raw in [
    (f4_adam, 0, r'Adam @ $g_t$', True),
    (f4_adam_simple, -1, r'simpleSR: Adam @ $g_t$', False),
    (f4_sgd, 1, r'SGD @ $g_t$', True),
    (f4_sgd_poly, -2, r'polySR: SGD @ $g_t$', False),
    (f4_mom, 2, r'mom(0.5) @ $g_t$', True),
    (f4_mom_gp, -3, r'gpSR: mom(0.5) @ $g_t$', False),
]:
    if is_raw:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linews[idx], alpha=alpha, label=label)
    else:
        plot(x, curve, linemks[idx], linewidth=linews[idx],
             markersize=linemk_size[idx], label=label)

plt.legend(loc='upper right')






# Write the whole multi-panel figure, trimming surrounding whitespace.
plt.savefig('nonlinmap.pdf', bbox_inches='tight')







