import pickle
import numpy as np
import datetime
import importlib
import os
import pandas as pd
import time

import matplotlib.pyplot as plt
from IPython.display import Image
import seaborn as sns

from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset

from matplotlib.ticker import FormatStrFormatter

import matplotlib

    
def HSn_dLtrSn(log, lr, end_step, start_step=0, savefig=0, half=False, savename='HSn_dLtrSn',
               textloc=(100, 0.2), s=matplotlib.rcParams['lines.markersize'] ** 2,
               textloc1=(50, 0.4), ylim=None, xlim=(-10, 200), endx=False, title=None,
               figsize=(6, 4), n_classes=10):
    """Scatter the normalised per-step loss change against ||H||_{S_n} (full-batch log).

    The log is loaded with ``df_all``. For steps ``start_step..end_step``, the
    function plots:

    - x: column ``36 + n_classes`` (labelled ``||H||_{S_n}``);
    - y: ``(L[t+1] - L[t]) / column(23 + n_classes)**2``, where ``L`` is column 4.

    Points are coloured by step. A reference line ``y = lr**2/2 * (x - 2/lr)``
    (slope ``eta^2/2``, zero at ``x = 2/lr``) is overlaid, together with a dashed
    vertical line at ``2/lr``. The function prints the first step at which column
    13 ("Cohen EOS") and the x value ("our EOS") exceed ``2/lr``.

    Args:
        log: Log file name or prefix, as understood by ``df_all``.
        lr: Learning rate ``eta``; sets the reference lines and the ``2/eta`` marker.
        end_step: Last step to plot (inclusive). Clipped to the last step for which
            ``L[t+1]`` exists in the log.
        start_step: First step to plot.
        savefig: If truthy, save a timestamped PDF under ``../pdfs/``.
        half: Also draw the reference line scaled by 1/2 and by 1/4.
        savename: File-name stem of the saved PDF.
        textloc, textloc1: Data coordinates of the slope and ``2/eta`` annotations.
        s: Marker size of the scatter points.
        ylim: Optional ``(ymin, ymax)``. Defaults to a symmetric range
            proportional to ``lr**2``.
        xlim: ``(xmin, xmax)`` of the x axis. ``xmax`` also sets the five x ticks
            and the right end of the reference lines.
        endx: Mark the last plotted step with a dark-blue ``x``.
        title: Optional axes title.
        figsize: Figure size in inches.
        n_classes: Number of classes; offsets the log columns read above.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    df = df_all(log)
    print(df.shape)

    train_loss = np.array(df.iloc[:, 4])
    # Forward difference L[t+1] - L[t]; one element shorter than the log.
    d_train_loss = train_loss[1:] - train_loss[:-1]

    trSn_array = np.array(df.iloc[:, 23 + n_classes]) ** 2
    HSn = np.array(df.iloc[:, 36 + n_classes])
    H = np.array(df.iloc[:, 13])

    plt.subplots(figsize=figsize)
    plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    # end_step is inclusive. Clip it to the last step that has a loss difference;
    # otherwise the numerator and denominator slices below differ in length and
    # the division fails to broadcast.
    stop = min(end_step + 1, len(d_train_loss))
    x = HSn[start_step:stop]
    y = d_train_loss[start_step:stop] / trSn_array[start_step:stop]
    H = H[start_step:stop]
    if (H > 2 / lr).sum() > 0:
        print('entering Cohen EOS at', start_step + np.argmax(H > 2 / lr))
    if (x > 2 / lr).sum() > 0:
        print('entering our EOS at', start_step + np.argmax(x > 2 / lr))
    if endx:
        plt.scatter(x[-1], y[-1], marker='x', s=50, c='darkblue')
    # The colormap is passed by name: matplotlib.cm.get_cmap (plt.cm.get_cmap)
    # was removed in Matplotlib 3.9.
    plt.scatter(x[:-1], y[:-1], s=s, c=np.arange(start_step, stop - 1), cmap='RdYlBu')

    # Reference line y = lr**2/2 * (x - 2/lr): equals -lr at x = 0 and 0 at x = 2/lr.
    x_right = xlim[1]
    plt.plot([0, x_right], [-lr, lr**2/2*(x_right-2/lr)], alpha=0.3)
    if half:
        plt.plot([0, x_right], [-lr/2, lr**2/2*(x_right-2/lr)/2], alpha=0.3)
        plt.plot([0, x_right], [-lr/4, lr**2/2*(x_right-2/lr)/4], alpha=0.3)
    plt.text(textloc[0], textloc[1], r'slope=$\frac{\eta^2}{2}$')
    plt.text(textloc1[0], textloc1[1], r'$\frac{2}{\eta}=$%.f' % (2/lr), c='r')
    plt.xlabel(r'$||H||_{S_n}$')
    plt.ylabel(r'$(L_{t+1}-L_t)/tr(S_n)$')

    plt.xlim(xlim[0], xlim[1])
    if ylim:
        plt.ylim(ylim[0], ylim[1])
    else:
        plt.ylim(-10000/32*6/5*lr**2, 10000/32*6/5*lr**2)
    plt.xticks(xlim[1]*np.arange(5)/4)
    plt.axhline(0, alpha=0.4, c='k')
    plt.axvline(0, alpha=0.4, c='k')
    plt.axvline(2/lr, alpha=0.4, c='r', linestyle='--')
    plt.colorbar(label='step', format='%d')
    if title:
        plt.title(title)
    plt.tight_layout()
    if savefig:
        plt.savefig('../pdfs/'+savename+'_'+time_now+'.pdf')
        print(savename+time_now)
    plt.show()
    
    
def HSn_dLtrSn_SGD(log, lr, end_step, start_step=0, savefig=0, half=False,
                   s=matplotlib.rcParams['lines.markersize'] ** 2, endx=False,
                   savename='HSn_dLtrSn', textloc=(100, 0.2), textloc1=(50, 0.4), ylim=None,
                   xlim=(-10, 200), title=None, figsize=(6, 4), n_classes=10):
    """Scatter the normalised expected loss change against tr(HS_b)/tr(S_n) (SGD log).

    The log is loaded with ``df_all``. For steps ``start_step..end_step``, the
    function plots:

    - x: column ``36 + n_classes`` (labelled ``tr(HS_b)/tr(S_n)``);
    - y: ``(column(42 + n_classes) - column(45 + n_classes)) / column(23 + n_classes)**2``
      (labelled ``(E[L_{t+1}] - L_t)/tr(S_n)``).

    Points are coloured by step. The same ``lr**2/2`` reference lines and
    ``2/lr`` marker as ``HSn_dLtrSn`` are overlaid.

    Args:
        log: Log file name or prefix, as understood by ``df_all``.
        lr: Learning rate ``eta``.
        end_step: Last step to plot (inclusive). Clipped to the last logged step.
        start_step: First step to plot.
        savefig: If truthy, save a timestamped PDF under ``../pdfs/``.
        half: Also draw the reference line scaled by 1/2 and by 1/4.
        s: Marker size of the scatter points.
        endx: Mark (and print) the last plotted point with a dark-blue ``x``.
        savename: File-name stem of the saved PDF.
        textloc, textloc1: Data coordinates of the slope and ``2/eta`` annotations.
        ylim: Optional ``(ymin, ymax)``. Defaults to a symmetric range
            proportional to ``lr**2``.
        xlim: ``(xmin, xmax)`` of the x axis; ``xmax`` also sets the x ticks.
        title: Optional axes title.
        figsize: Figure size in inches.
        n_classes: Number of classes; offsets the log columns read above.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    df = df_all(log)
    print(df.shape)

    train_loss = np.array(df.iloc[:, 45 + n_classes])
    L_t_next = np.array(df.iloc[:, 42 + n_classes])
    d_train_loss = L_t_next - train_loss

    trSn_array = np.array(df.iloc[:, 23 + n_classes]) ** 2
    HSn = np.array(df.iloc[:, 36 + n_classes])

    plt.subplots(figsize=figsize)
    plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
    plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))

    # end_step is inclusive. Clip it to the log length so the colour array below
    # has exactly one entry per plotted point (scatter raises otherwise).
    stop = min(end_step + 1, len(d_train_loss))
    x = HSn[start_step:stop]
    y = d_train_loss[start_step:stop] / trSn_array[start_step:stop]
    if endx:
        plt.scatter(x[-1], y[-1], marker='x', s=50, c='darkblue')
        print(x[-1], y[-1])
    # The colormap is passed by name: matplotlib.cm.get_cmap (plt.cm.get_cmap)
    # was removed in Matplotlib 3.9.
    plt.scatter(x[:-1], y[:-1], s=s, c=np.arange(start_step, stop - 1), cmap='RdYlBu')

    # Reference line y = lr**2/2 * (x - 2/lr): equals -lr at x = 0 and 0 at x = 2/lr.
    x_right = xlim[1]
    plt.plot([0, x_right], [-lr, lr**2/2*(x_right-2/lr)], alpha=0.3)
    if half:
        plt.plot([0, x_right], [-lr/2, lr**2/2*(x_right-2/lr)/2], alpha=0.3)
        plt.plot([0, x_right], [-lr/4, lr**2/2*(x_right-2/lr)/4], alpha=0.3)
    plt.text(textloc[0], textloc[1], r'slope=$\frac{\eta^2}{2}$')
    plt.text(textloc1[0], textloc1[1], r'$\frac{2}{\eta}=$%.f' % (2/lr), c='r')
    plt.xlabel(r'$\frac{tr(HS_b)}{tr(S_n)}$')
    plt.ylabel(r'$(E[L_{t+1}]-L_t)/tr(S_n)$')

    plt.xlim(xlim[0], xlim[1])
    if ylim:
        plt.ylim(ylim[0], ylim[1])
    else:
        plt.ylim(-10000/32*6/5*lr**2, 10000/32*6/5*lr**2)
    plt.xticks(xlim[1]*np.arange(5)/4)
    plt.axhline(0, alpha=0.4, c='k')
    plt.axvline(0, alpha=0.4, c='k')
    plt.axvline(2/lr, alpha=0.4, c='r', linestyle='--')
    plt.colorbar(label='step')
    if title:
        plt.title(title)
    plt.tight_layout()
    if savefig:
        plt.savefig('../pdfs/'+savename+'_'+time_now+'.pdf')
        print(savename+time_now)
    plt.show()
    
def df_all(log, path=None, std=False,
               xlim=[None,None],
               ylim=[None,None],ylim1=[None,None],ylim2=[None,None],ylim3=[None,None],ylim4=[None,None],
               acc_plot_xlim=[None,None],acc_plot_ylim=[None,None],
               show_list=None,show_list1=None,show_list2=None,show_list3=None,show_list4=None,
               opt=None,opt1=None,opt2=None,opt3=None,opt4=None,
               c_list=None, c_list1=None, c_list2=None,c_list3=None,c_list4=None,
               l_list=None, l_list1=None, l_list2=None,l_list3=None,l_list4=None,
               dashed=[], dashed1=[], dashed2=[],dashed3=[],dashed4=[],
               t_list=None,
               list_opt=None,list_opt1=None,list_opt2=None,list_opt3=None,list_opt4=None,
               ncol=1,ncol1=1,ncol2=1,ncol3=1,ncol4=1,
               scatter=False,scatter1=False,scatter2=False,scatter3=False,scatter4=False,
               linewidth=1,
               ticks=1000,
               s=5,
               save_name='fig_num',
               figsize=(15,5),
               gap=0.15,
               ygap=0.,
               axhlines=None,axhlines1=None,axhlines2=None,axhlines3=None,axhlines4=None,
               axvlines=None,
               text=False,
               log_plot=False, 
               log0=0, log1=0,log2=0,log3=0,log4=0,
               step=1,
               sgd=1,
               savefig=False, 
               snMlim=4000,
               add_log=None, 
               acc_plot=True,
               opt_sum=False,
               acc_start=0,acc_end=-1):
    """Load a tab-separated training log into a float ``DataFrame``.

    Only ``log``, ``path`` and ``step`` are used. All other keywords mirror the
    plotting signature of ``figure_num`` and are ignored (none of them are
    mutated, so their list defaults are harmless).

    File lookup: ``path + log`` is read if it exists. Otherwise the
    lexicographically last file in ``path`` whose name starts with ``log`` is
    read, and the list of candidates is printed.

    Parsing: line 0 is printed and skipped. Every following non-empty line is cut
    at ``'] -'`` and the segment after the first marker is split on tabs. Rows
    whose field count differs from line 1's are dropped (e.g. a truncated final
    line).

    Args:
        log: Exact log file name, or a file-name prefix.
        path: Directory prefix, concatenated as a string; defaults to ``'../log/'``.
        step: Window of the trailing rolling mean applied to every column. The
            first ``step - 1`` rows become NaN.

    Returns:
        ``pandas.DataFrame`` of floats with integer column labels ``0..k-1``.

    Raises:
        IndexError: no file in ``path`` starts with ``log``, the file has fewer
            than two lines, or a non-empty line lacks ``'] -'``.
    """
    if path is None:
        path = '../log/'

    if os.path.isfile(path + log):
        filename = path + log
    else:
        prefixed = sorted(name for name in os.listdir(path) if name.startswith(log))
        print(prefixed)
        filename = path + prefixed[-1]

    # Context manager so the handle is always closed (it was previously leaked).
    with open(filename, 'r') as f:
        lines = f.read().split('\n')

    # Line 1 is only used to fix the expected number of fields per row.
    header = lines[1].split('] -')[1].split('\t')
    print(lines[0])
    rows = []
    for line in lines[1:]:
        if len(line) > 0:
            fields = line.split('] -')[1].split('\t')
            if len(fields) == len(header):
                rows.append(fields)

    # np.float was removed in NumPy 1.24; the builtin float gives the same float64 dtype.
    return pd.DataFrame(rows, dtype=float).rolling(window=step).mean()
    
    
def _figure_num_load(log, path, step):
    """Read the log (exact file or latest prefix match) into a rolling-mean float DataFrame."""
    if path is None:
        path = '../log/'
    if os.path.isfile(path + log):
        filename = path + log
    else:
        prefixed = sorted(name for name in os.listdir(path) if name.startswith(log))
        print(prefixed)
        filename = path + prefixed[-1]
    # Context manager so the handle is always closed (it was previously leaked).
    with open(filename, 'r') as f:
        lines = f.read().split('\n')
    # Line 1 only fixes the expected number of fields per row.
    header = lines[1].split('] -')[1].split('\t')
    print(lines[0])
    rows = []
    for line in lines[1:]:
        if len(line) > 0:
            fields = line.split('] -')[1].split('\t')
            if len(fields) == len(header):
                rows.append(fields)
    # np.float was removed in NumPy 1.24; the builtin float gives the same float64 dtype.
    return pd.DataFrame(rows, dtype=float).rolling(window=step).mean()


def _figure_num_group(df, item, list_opt):
    """Combine the columns named by the list *item* as selected by *list_opt*."""
    if list_opt in ('divide', 'x/y'):
        return df[item[0]] / df[item[1]]
    if list_opt in ('product', 'x*y'):
        return df[item[0]] * df[item[1]]
    if list_opt in ('sum', 'x+y'):
        return df[item[0]] + df[item[1]]
    if list_opt == 'x^2y':
        return df[item[0]] ** 2 * df[item[1]]
    if list_opt == 'x/y^2':
        return df[item[0]] / (df[item[1]] ** 2)
    if list_opt == 'x-y':
        return df[item[0]] - df[item[1]]
    if list_opt == 'custom':
        return df[item[0]] ** 2 * (df[item[1]] + df[item[2]])
    if list_opt == 'custom1':
        return df[item[0]] ** 2 * (df[item[1]] / df[item[2]])
    if list_opt == 'cx':
        # item[1] is a constant factor here, not a column id.
        return item[1] * df[item[0]]
    # Previously an unknown option silently re-plotted the previous series' y.
    raise ValueError('unknown list_opt %r for column group %r' % (list_opt, item))


def _figure_num_column(df, col, opt, sgd):
    """Return (x, y) for a single column *col*, transformed as selected by *opt*."""
    if opt == 'delay':
        # Plot each value one row later: y[t] is drawn at x[t + 1].
        return sgd * df[0][1:], df[col][:-1]
    x = sgd * df[0]
    if opt == 'abs':
        y = df[col].abs()
    elif opt == 'sqrt':
        y = np.sqrt(df[col])
    elif opt == 'square':
        y = df[col] ** 2
    elif opt == '1/x':
        y = 1 / df[col]
    elif opt == '2/x':
        y = 2 / df[col]
    else:
        y = df[col]
    return x, y


def _figure_num_draw(axis, df, items, list_opt, opt, sgd, colors, labels,
                     scatter, dashed, s, linewidth, legend_kw):
    """Draw every entry of *items* on *axis*, add its legend, and return the last x used."""
    x = None
    for idx, item in enumerate(items):
        if isinstance(item, list):
            x = sgd * df[0]
            y = _figure_num_group(df, item, list_opt)
        else:
            x, y = _figure_num_column(df, item, opt, sgd)
        if scatter:
            axis.scatter(x, y, c=colors[idx], label=labels[idx], s=s)
        else:
            # linewidth=None means "matplotlib default" (used by the offset twins).
            style = {} if linewidth is None else {'linewidth': linewidth}
            if idx in dashed:
                style['linestyle'] = '--'
            axis.plot(x, y, c=colors[idx], label=labels[idx], **style)
    if items:
        if scatter:
            axis.legend(scatterpoints=10, **legend_kw)
        else:
            axis.legend(**legend_kw)
    return x


def figure_num(log, path=None, std=False,
               xlim=[None,None],
               ylim=[None,None],ylim1=[None,None],ylim2=[None,None],ylim3=[None,None],ylim4=[None,None],
               acc_plot_xlim=[None,None],acc_plot_ylim=[None,None],
               show_list=None,show_list1=None,show_list2=None,show_list3=None,show_list4=None,
               opt=None,opt1=None,opt2=None,opt3=None,opt4=None,
               c_list=None, c_list1=None, c_list2=None,c_list3=None,c_list4=None,
               l_list=None, l_list1=None, l_list2=None,l_list3=None,l_list4=None,
               dashed=[], dashed1=[], dashed2=[],dashed3=[],dashed4=[],
               t_list=None,
               list_opt=None,list_opt1=None,list_opt2=None,list_opt3=None,list_opt4=None,
               ncol=1,ncol1=1,ncol2=1,ncol3=1,ncol4=1,
               scatter=False,scatter1=False,scatter2=False,scatter3=False,scatter4=False,
               linewidth=1,linewidth1=1,
               ticks=1000,
               s=5,
               save_name='fig_num',
               figsize=(15,5),
               gap=0.15,
               ygap=0.,
               axhlines=None,axhlines1=None,axhlines2=None,axhlines3=None,axhlines4=None,
               axvlines=None,
               text=False,
               log_plot=False, 
               log0=0, log1=0,log2=0,log3=0,log4=0,
               step=1,
               sgd=1,
               savefig=False, 
               snMlim=4000,
               add_log=None, 
               acc_plot=True,
               opt_sum=False,
               lshift=False,
               yticks1=None,
               fill1=None,
               inlegend=False,
               title=None,
               acc_start=0,acc_end=-1):
    """Plot selected log columns against step on up to five y axes, then an accuracy figure.

    Loading:
        The log is read from ``path + log``. If that file does not exist, the
        lexicographically last file in ``path`` (default ``'../log/'``) starting
        with ``log`` is used instead. Line 0 is printed and skipped. On every
        following line, the tab-separated fields after ``'] -'`` become a float
        row; rows whose field count differs from line 1's are dropped. The
        columns are then smoothed by a trailing rolling mean of width ``step``.
        Columns are addressed by integer position, and column 0 times ``sgd``
        is the x axis.

    Axes:
        Suffix ``''`` is the primary axis and suffix ``1`` a twin axis sharing x.
        Suffixes ``2``, ``3`` and ``4`` are further twins whose right spine (and
        legend) is shifted by ``gap``, ``2*gap`` and ``3*gap`` (axes fraction,
        legends also by multiples of ``ygap``).

    Per-axis parameters, for suffix ``k``:
        show_list{k}: Entries to draw.
            An int is one column, transformed by ``opt{k}``:

            - ``'abs'``, ``'sqrt'``, ``'square'``, ``'1/x'``, ``'2/x'``;
            - ``'delay'``: draw each value one row later;
            - anything else: the raw column.

            A list is a column group, combined by ``list_opt{k}``:

            - ``'x/y'``/``'divide'``, ``'x*y'``/``'product'``,
              ``'x+y'``/``'sum'``, ``'x^2y'``, ``'x/y^2'``, ``'x-y'``;
            - ``'custom'``: ``x**2 * (y + z)``;
            - ``'custom1'``: ``x**2 * (y / z)``;
            - ``'cx'``: constant ``item[1]`` times column ``item[0]``.
        c_list{k}, l_list{k}: Per-series colours and legend labels. Default to a
            tab palette and the entries as strings.
        dashed{k}: Indices into ``show_list{k}`` drawn with ``'--'`` (line mode).
        scatter{k}: Draw points of size ``s`` instead of lines.
        ncol{k}: Number of legend columns.
        ylim{k}: y limits.
        log{k}: Use a log y scale.
        axhlines{k}: Dashed horizontal lines drawn in the axis' first colour.

    Other parameters:
        xlim: x range. A ``None`` entry defaults to 0 (left) or ``len(df)``
            (right). ``ticks`` is the major-tick spacing from ``xlim[0]``.
        linewidth, linewidth1: Line widths on the primary and first twin axes;
            the other twins use matplotlib's default.
        axvlines: Dashed black vertical lines on the primary axis.
        inlegend: Put legends inside the top corners instead of above the axes.
        lshift: Anchor the first twin's outside legend at ``'lower right'``
            instead of ``'lower center'``.
        fill1: With ``axhlines1``, shade between 0 and each line over the x
            range of the last primary series.
        yticks1: Explicit y ticks for the first twin axis.
        t_list: y-axis labels by position (primary, twin 1, 2, 3, 4). Entries
            for axes that were not drawn are ignored.
        savefig, save_name: Save timestamped PNG and PDF copies under ``../pdfs/``.
        acc_plot: Also plot columns 5 and 9 (labelled train/test accuracy) and,
            unless ``std``, column 12 (labelled test adversarial accuracy). Prints
            best values overall and over rows ``acc_start:acc_end``.
            ``acc_plot_xlim``/``acc_plot_ylim`` set that figure's limits.
        text, log_plot, snMlim, add_log, opt_sum: Accepted for compatibility;
            unused.

    Side effects:
        Sets the global matplotlib font size to 15 via ``plt.rc``.

    Raises:
        ValueError: A column group is given with an unrecognised ``list_opt{k}``.
        IndexError: No file in ``path`` starts with ``log``.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')

    df = _figure_num_load(log, path, step)
    print('# steps:', df.shape[0])
    print('total time taken (train):(%2.2f h,%4.2f m)' % (np.array(df[1])[-1] / 3600, np.array(df[1])[-1] / 60))
    print('avg time taken   (test) :(%4.2f s)' % (np.array(df[2])[df[2] > 0].mean()))
    print('# cols:', df.shape[1], '; 0~' + str(df.shape[1] - 1))
    _, ax = plt.subplots(figsize=figsize, dpi=400)
    if title:
        plt.title(title)
    plt.rc('font', size=15)
    tab_c = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink','tab:gray','tab:olive','tab:cyan','deeppink','royalblue']

    # ---- primary axis ----
    colors = c_list if c_list is not None else tab_c
    label = l_list if l_list is not None else [str(i) for i in show_list]
    primary_loc = 'upper left' if inlegend else 'lower left'
    # x of the last primary series; reused by fill1 below.
    x = _figure_num_draw(ax, df, show_list, list_opt, opt, sgd, colors, label,
                         scatter, dashed, s, linewidth,
                         dict(loc=primary_loc, bbox_to_anchor=(0, 1), ncol=ncol))

    ax.set_xlabel('Step')
    # Work on a copy: filling in the shared default list made later calls
    # reuse the first call's x-range.
    xlim = list(xlim)
    if xlim[0] is None:
        xlim[0] = 0
    if xlim[1] is None:
        xlim[1] = len(df)
    ax.set_xticks(np.arange(xlim[0], xlim[1] + 1, ticks))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    if log0:
        ax.set_yscale('log')
    ax.grid(True, which='both', axis='x')

    if axhlines is not None:
        for level in axhlines:
            ax.axhline(y=level, alpha=0.5, linestyle='--', c=colors[0])
    if axvlines is not None:
        for pos in axvlines:
            ax.axvline(x=pos, alpha=0.5, linestyle='--', c='k')

    # Slot 0 is the primary axis, slots 1-4 the optional twins; t_list labels
    # are matched to these slots by position.
    axes_by_slot = [ax, None, None, None, None]

    # ---- first twin axis (shares the right spine position) ----
    if show_list1:
        ax_twin = ax.twinx()
        axes_by_slot[1] = ax_twin
        colors = c_list1 if c_list1 is not None else tab_c
        label = l_list1 if l_list1 is not None else [str(i) for i in show_list1]
        if axhlines1 is not None:
            for level in axhlines1:
                ax_twin.axhline(y=level, alpha=0.5, linestyle='--', c=colors[0])
                if fill1:
                    ax_twin.fill_between(np.array([x.min(), x.max()]), level, color='r', alpha=0.1)
        if inlegend:
            twin_loc = 'upper right'
        elif lshift:
            twin_loc = 'lower right'
        else:
            twin_loc = 'lower center'
        _figure_num_draw(ax_twin, df, show_list1, list_opt1, opt1, sgd, colors, label,
                         scatter1, dashed1, s, linewidth1,
                         dict(loc=twin_loc, bbox_to_anchor=(1, 1), ncol=ncol1))
        ax_twin.set_ylim(ylim1)
        if yticks1:
            ax_twin.set_yticks(yticks1)
        if log1:
            ax_twin.set_yscale('log')

    # ---- twins 2-4, each with its right spine pushed further out ----
    offset_twins = (
        (2, show_list2, c_list2, l_list2, axhlines2, list_opt2, opt2, scatter2, dashed2, ncol2, ylim2, log2),
        (3, show_list3, c_list3, l_list3, axhlines3, list_opt3, opt3, scatter3, dashed3, ncol3, ylim3, log3),
        (4, show_list4, c_list4, l_list4, axhlines4, list_opt4, opt4, scatter4, dashed4, ncol4, ylim4, log4),
    )
    for (slot, items, c_k, l_k, hlines_k, list_opt_k, opt_k,
         scatter_k, dashed_k, ncol_k, ylim_k, log_k) in offset_twins:
        if not items:
            continue
        shift = slot - 1  # 1, 2, 3 multiples of gap / ygap
        axis = ax.twinx()
        axes_by_slot[slot] = axis
        colors = c_k if c_k is not None else tab_c
        label = l_k if l_k is not None else [str(i) for i in items]
        if hlines_k is not None:
            for level in hlines_k:
                axis.axhline(y=level, alpha=0.5, linestyle='--', c=colors[0])
        # Each twin now uses its own ncol (twin 3 previously used ncol2).
        _figure_num_draw(axis, df, items, list_opt_k, opt_k, sgd, colors, label,
                         scatter_k, dashed_k, s, None,
                         dict(loc='lower center', bbox_to_anchor=(1 + shift * gap, 1 + shift * ygap),
                              ncol=ncol_k))
        axis.spines["right"].set_position(("axes", 1 + shift * gap))
        axis.set_ylim(ylim_k)
        if log_k:
            axis.set_yscale('log')

    if t_list:
        # Label every drawn axis that has an entry (previously only the axis at
        # position len(t_list) - 1 was labelled).
        for axis, ylabel in zip(axes_by_slot, t_list):
            if axis is not None and ylabel is not None:
                axis.set_ylabel(ylabel)

    plt.tight_layout()
    if savefig:
        plt.savefig('../pdfs/'+save_name+time_now+'.png')
        plt.savefig('../pdfs/'+save_name+time_now+'.pdf')
        print(save_name+time_now)
    plt.show()

    if acc_plot:
        # `colors` is the palette of the last axis drawn above.
        _, ax2 = plt.subplots()
        ax2.plot(df[0], df[5], label='Train Acc', c=colors[-1])
        ax2.plot(df[0][1:], df[9][:-1], label='Test Acc', c=colors[-1], linestyle='--')
        print('Best Test Acc', df[9].max())
        print('Best Train Acc', df[5].max())
        print('Best Train Loss %.6f' % (df[4].min()))

        print('Best Test Acc (cut)', df[9][acc_start:acc_end].max())
        print('Best Train Acc (cut)', df[5][acc_start:acc_end].max())
        print('Best Train Loss (cut) %.6f' % (df[4][acc_start:acc_end].min()))
        if not std:
            # NOTE(review): 'Train Adv Acc' re-plots column 5, the same column as
            # 'Train Acc' above; confirm which column holds adversarial train accuracy.
            ax2.plot(df[0], df[5], label='Train Adv Acc', c=colors[0])
            ax2.plot(df[0][1:], df[12][:-1], label='Test Adv Acc', c=colors[0], linestyle='--')
            ax2.axhline(y=df[12].max(), c='k', linestyle=':')
            # NOTE(review): x here is a row position, while the x axis is column 0
            # times... no, unscaled column 0; these agree only if column 0 equals the row index.
            ax2.text(df[12].argmax() + 1, df[12].max(), str(df[12].max()))

        ax2.legend(loc='lower right', bbox_to_anchor=(1, 1))
        ax2.set_ylabel('Acc')
        ax2.set_xlim(acc_plot_xlim)
        ax2.set_ylim(acc_plot_ylim)
        plt.show()