import pickle
import numpy as np
import datetime
import importlib
import os
import pandas as pd
import time

import matplotlib.pyplot as plt
from IPython.display import Image
import seaborn as sns

from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset



def figure_num(log, path=None, std=False,
               xlim=[None,None],
               ylim=[None,None],ylim1=[None,None],ylim2=[None,None],ylim3=[None,None],acc_plot_xlim=[None,None],acc_plot_ylim=[None,None],
               show_list=None,show_list1=None,show_list2=None,show_list3=None,
               opt=None,opt1=None,opt2=None,opt3=None,
               c_list=None, c_list1=None, c_list2=None,c_list3=None,
               l_list=None, l_list1=None, l_list2=None,l_list3=None,
               dashed=[], dashed1=[], dashed2=[],dashed3=[],
               t_list=None,
               list_opt=None,list_opt1=None,list_opt2=None,list_opt3=None,
               ncol=1,ncol1=1,ncol2=1,ncol3=1,
               scatter=False,scatter1=False,scatter2=False,scatter3=False,
               linewidth=1,
               ticks=1000,
               s=5,
               save_name='fig_num',
               figsize=(15,5),
               gap=0.15,
               axhlines=None,axhlines1=None,axhlines2=None,axhlines3=None,
               axvlines=None,
               text=False,
               log_plot=False, 
               step=1,
               sgd=1,
               savefig=False, 
               snMlim=4000,
               add_log=None, 
               acc_plot=True,
               opt_sum=False,
               acc_start=0,acc_end=-1):
    """Plot columns of a tab-separated training log on up to four y-axes.

    Locating and parsing the log:

    * The log is read from ``path + log``. If no such file exists, the
      lexicographically last file in ``path`` whose name starts with
      ``log`` is used instead.
    * Line 0 of the file is printed and otherwise ignored.
    * Every later non-empty line must contain ``'] -'``. The text after
      that marker is split on tabs, and only lines with as many fields as
      line 1 are kept.
    * All columns are converted to float and smoothed with a rolling mean
      of width ``step``. Column 0, multiplied by ``sgd``, is the x-axis.

    ``show_list`` selects the series for the main axis. ``show_list1``,
    ``show_list2`` and ``show_list3`` select those for up to three twin
    y-axes. Each entry is one of:

    * a column index, transformed according to ``opt*``:
      ``'delay'`` (each x shows the previous row's value), ``'abs'``,
      ``'sqrt'``, ``'square'``, or anything else for the raw column;
    * a list of column indices ``[a, b, ...]`` combined according to
      ``list_opt*``: ``'divide'``/``'x/y'``, ``'product'``/``'x*y'``,
      ``'sum'``/``'x+y'``, ``'x^2y'``, ``'x-y'``, ``'custom'``
      (a**2 * (b + c)), ``'custom1'`` (a**2 * (b / c)) or ``'cx'``
      (``entry[1]`` is a scalar factor applied to column a).

    Per-axis options:

    * ``c_list*`` / ``l_list*`` override colours and labels.
    * ``dashed*`` lists the entry indices drawn dashed.
    * ``scatter*`` switches to scatter plots.
    * ``ncol*`` sets the number of legend columns.
    * ``ylim*`` sets the y-limits.
    * ``axhlines*`` adds horizontal reference lines.
    * ``t_list[k]`` is the y-label of axis k (0 = main axis) and is
      applied when that axis exists.
    * ``xlim`` defaults to ``[0, number of rows]``.

    If ``savefig`` is true, the figure is written to
    ``pdfs/<save_name><timestamp>`` as .png and .pdf.

    If ``acc_plot`` is true, a second figure plots columns 5 ('Train Acc')
    and 9 ('Test Acc'), plus 7 and 11 unless ``std`` is set. It also prints
    the best values overall and restricted to rows ``acc_start:acc_end``.

    ``text``, ``log_plot``, ``snMlim``, ``add_log`` and ``opt_sum`` are
    accepted but unused.

    Raises:
        FileNotFoundError: no file in ``path`` starts with ``log``.
        ValueError: a list entry is used with an unsupported ``list_opt*``.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    # Work on a copy. Assigning into the argument (see the defaulting below)
    # used to mutate the shared default list and the caller's list.
    xlim = list(xlim)

    if path is None:
        path = './log/'

    if os.path.isfile(path + log):
        log_path = path + log
    else:
        # os.listdir order is arbitrary; sort so that [-1] is a stable choice.
        prefixed = sorted(name for name in os.listdir(path) if name.startswith(log))
        print(prefixed)
        if not prefixed:
            raise FileNotFoundError('no file starting with %r in %r' % (log, path))
        log_path = path + prefixed[-1]
    with open(log_path, 'r') as f:
        lines = f.read().split('\n')

    # Line 1 fixes the expected number of fields (it is also parsed as data).
    n_fields = len(lines[1].split('] -')[1].split('\t'))
    print(lines[0])
    df_list = []
    for line in lines[1:]:
        if len(line) > 0:
            fields = line.split('] -')[1].split('\t')
            if len(fields) == n_fields:
                df_list.append(fields)

    # np.float was removed in NumPy 1.24; the builtin float is the same dtype.
    df = pd.DataFrame(df_list, dtype=float).rolling(window=step).mean()
    print('# steps:', df.shape[0])
    print('# cols:', df.shape[1], '; 0~' + str(df.shape[1] - 1))
    fig, ax = plt.subplots(figsize=figsize)
    plt.rc('font', size=15)
    tab_c = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink','tab:gray','tab:olive','tab:cyan']

    def _combine(cols, how):
        """Combine the columns listed in ``cols`` as selected by ``how``."""
        first = df[cols[0]]
        if how in ('divide', 'x/y'):
            return first / df[cols[1]]
        if how in ('product', 'x*y'):
            return first * df[cols[1]]
        if how in ('sum', 'x+y'):
            return first + df[cols[1]]
        if how == 'x^2y':
            return first ** 2 * df[cols[1]]
        if how == 'x-y':
            return first - df[cols[1]]
        if how == 'custom':
            return first ** 2 * (df[cols[1]] + df[cols[2]])
        if how == 'custom1':
            return first ** 2 * (df[cols[1]] / df[cols[2]])
        if how == 'cx':
            # cols[1] is a scalar factor here, not a column index.
            return cols[1] * first
        raise ValueError('unsupported list_opt: %r' % (how,))

    def _series(entry, how, list_how):
        """Return the (x, y) pair to draw for one ``show_list*`` entry."""
        if isinstance(entry, list):
            return sgd * df[0], _combine(entry, list_how)
        if how == 'delay':
            # Pair the x of row k with the value of row k-1.
            return sgd * df[0][1:], df[entry][:-1]
        x = sgd * df[0]
        if how == 'abs':
            return x, df[entry].abs()
        if how == 'sqrt':
            return x, np.sqrt(df[entry])
        if how == 'square':
            return x, df[entry] ** 2
        return x, df[entry]

    def _draw(axis, entries, how, list_how, colors_, labels_, use_scatter,
              dashed_idx, legend_kw, line_kw):
        """Draw every entry on ``axis``, then build the axis legend once."""
        drawn = False
        for idx, entry in enumerate(entries):
            x, y = _series(entry, how, list_how)
            if use_scatter:
                axis.scatter(x, y, c=colors_[idx], label=labels_[idx], s=s)
            elif idx in dashed_idx:
                axis.plot(x, y, c=colors_[idx], label=labels_[idx], linestyle='--', **line_kw)
            else:
                axis.plot(x, y, c=colors_[idx], label=labels_[idx], **line_kw)
            drawn = True
        if drawn:
            if use_scatter:
                axis.legend(scatterpoints=10, **legend_kw)
            else:
                axis.legend(**legend_kw)

    # Main axis.
    colors = c_list if c_list is not None else tab_c
    label = l_list if l_list is not None else [str(i) for i in show_list]
    _draw(ax, show_list, opt, list_opt, colors, label, scatter, dashed,
          {'loc': 'lower left', 'bbox_to_anchor': (0, 1), 'ncol': ncol},
          {'linewidth': linewidth})

    ax.set_xlabel('Step')

    if xlim[0] is None:
        xlim[0] = 0
        xlim[1] = len(df)

    major_ticks = np.arange(xlim[0], xlim[1] + 1, ticks)
    ax.set_xticks(major_ticks)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(True, which='both', axis='x')

    if axhlines is not None:
        for axhline in axhlines:
            ax.axhline(y=axhline, alpha=0.5, linestyle='--', c=colors[0])
    if axvlines is not None:
        for axvline in axvlines:
            ax.axvline(x=axvline, alpha=0.5, linestyle='--', c='k')

    # Twin y-axes. The k-th twin (0-based) has its legend anchored at
    # x = 1 + k*gap. Its spine is moved to that position as well, except for
    # the first twin, which keeps matplotlib's default right spine.
    twin_specs = (
        (show_list1, opt1, list_opt1, c_list1, l_list1, dashed1, scatter1, ncol1, axhlines1, ylim1),
        (show_list2, opt2, list_opt2, c_list2, l_list2, dashed2, scatter2, ncol2, axhlines2, ylim2),
        (show_list3, opt3, list_opt3, c_list3, l_list3, dashed3, scatter3, ncol3, axhlines3, ylim3),
    )
    twin_axes = []
    for k, (entries, how, list_how, c_k, l_k, dashed_k, scatter_k,
            ncol_k, hlines_k, ylim_k) in enumerate(twin_specs):
        if not entries:
            twin_axes.append(None)
            continue
        twin = ax.twinx()
        twin_axes.append(twin)
        colors = c_k if c_k is not None else tab_c
        label = l_k if l_k is not None else [str(i) for i in entries]
        if hlines_k is not None:
            for axhline in hlines_k:
                twin.axhline(y=axhline, alpha=0.5, linestyle='--', c=colors[0])
        offset = 1 + k * gap
        _draw(twin, entries, how, list_how, colors, label, scatter_k, dashed_k,
              {'loc': 'lower center', 'bbox_to_anchor': (offset, 1), 'ncol': ncol_k},
              {})
        if k > 0:
            twin.spines["right"].set_position(("axes", offset))
        twin.set_ylim(ylim_k)

    if t_list:
        # Label every existing axis that has an entry. The old code labelled
        # a single axis chosen by len(t_list), and failed when it was missing.
        for axis, title in zip([ax] + twin_axes, t_list):
            if axis is not None:
                axis.set_ylabel(title)

    plt.tight_layout()
    if savefig:
        plt.savefig('pdfs/'+save_name+time_now+'.png')
        plt.savefig('pdfs/'+save_name+time_now+'.pdf')
    plt.show()

    if acc_plot:
        # `colors` is the palette of the last axis drawn above.
        fig, ax2 = plt.subplots()
        ax2.plot(df[0], df[5], label='Train Acc',c=colors[-1])
        ax2.plot(df[0][1:], df[9][:-1], label='Test Acc',c=colors[-1], linestyle='--')
        print('Best Test Acc',df[9].max())
        print('Best Train Acc',df[5].max())
        print('Best Train Loss',df[6].min())

        print('Best Test Acc (cut)',df[9][acc_start:acc_end].max())
        print('Best Train Acc (cut)',df[5][acc_start:acc_end].max())
        print('Best Train Loss (cut)',df[6][acc_start:acc_end].min())
        if not std:
            ax2.plot(df[0], df[7], label='Train Adv Acc',c=colors[0])
            ax2.plot(df[0][1:], df[11][:-1], label='Test Adv Acc',c=colors[0], linestyle='--')
            ax2.axhline(y=df[11].max(),c='k',linestyle=':')
            ax2.text(df[11].argmax()+1,df[11].max(),str(df[11].max()))

        ax2.legend(loc='lower right',bbox_to_anchor=(1,1))
        ax2.set_ylabel('Acc')
        ax2.set_xlim(acc_plot_xlim)
        ax2.set_ylim(acc_plot_ylim)
        plt.show()
    
def figure_dual(log, path=None, std=False,
               xlim=[None,None],
               ylim=[None,None],ylim1=[None,None],ylim2=[None,None],ylim3=[None,None],acc_plot_xlim=[None,None],acc_plot_ylim=[None,None],
               show_list=None,show_list1=None,show_list2=None,show_list3=None,
               opt=None,opt1=None,opt2=None,opt3=None,
               c_list=None, c_list1=None, c_list2=None,c_list3=None,
               l_list=None, l_list1=None, l_list2=None,l_list3=None,
               dashed=[], dashed1=[], dashed2=[],dashed3=[],
               t_list=None,
               list_opt=None,list_opt1=None,list_opt2=None,list_opt3=None,
               ncol=1,ncol1=1,ncol2=1,ncol3=1,
               scatter=False,scatter1=False,scatter2=False,scatter3=False,
               linewidth=1,
               ticks=1000,
               s=5,
               save_name='fig_num',
               figsize=(15,5),
               gap=0.15,
               axhlines=None,axhlines1=None,axhlines2=None,axhlines3=None,
               axvlines=None,
               text=False,
               log_plot=False, 
               step=1,
               sgd=1,
               savefig=False, 
               snMlim=4000,
               add_log=None, 
               acc_plot=True,
               opt_sum=False):
    """Scatter pairs of log series against each other in a single axes.

    Locating and parsing the log:

    * The log is read from ``path + log``. If no such file exists, the
      lexicographically last file in ``path`` whose name starts with
      ``log`` is used instead.
    * Line 0 is printed and otherwise ignored.
    * Every later non-empty line must contain ``'] -'``. It is split on
      tabs after that marker, and only lines with as many fields as line 1
      are kept.
    * Columns are converted to float and smoothed with a rolling mean of
      width ``step``.
    * Rows ``xlim[0]:xlim[1]`` (by position; all rows by default) are kept.

    For the k-th pair ``(show_list[k], show_list1[k])``, one scatter series
    is drawn:

    * The first entry, transformed by ``opt``/``list_opt``, gives the
      horizontal coordinates.
    * The second entry, transformed by ``opt1``/``list_opt1``, gives the
      vertical coordinates.

    An entry is either a column index or a list of column indices:

    * For a column index, ``opt*`` selects the transform: ``'delay'``
      (drops the last row), ``'abs'``, ``'sqrt'``, ``'square'``, or
      anything else for the raw column.
    * For a list of indices, ``list_opt*`` selects how they are combined:
      ``'divide'``/``'x/y'``, ``'product'``/``'x*y'``, ``'sum'``/``'x+y'``,
      ``'x^2y'``, ``'x-y'``, ``'custom'``, ``'custom1'`` or ``'cx'``.

    Colours and labels come from ``c_list1``/``l_list1``. Nothing is
    scattered when ``show_list1`` is empty. ``t_list`` of length 1 sets the
    y-label.

    If ``acc_plot`` is true, a second figure plots columns 5/9 (and 7/11
    unless ``std`` is set) and prints best values. Keyword arguments not
    described here are accepted for signature compatibility and have no
    effect.

    Raises:
        FileNotFoundError: no file in ``path`` starts with ``log``.
        ValueError: a list entry is used with an unsupported ``list_opt*``.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    # Work on a copy so the shared default / caller's list is never mutated.
    xlim = list(xlim)

    if path is None:
        path = './log/'

    if os.path.isfile(path + log):
        log_path = path + log
    else:
        # os.listdir order is arbitrary; sort so that [-1] is a stable choice.
        prefixed = sorted(name for name in os.listdir(path) if name.startswith(log))
        print(prefixed)
        if not prefixed:
            raise FileNotFoundError('no file starting with %r in %r' % (log, path))
        log_path = path + prefixed[-1]
    with open(log_path, 'r') as f:
        lines = f.read().split('\n')

    # Line 1 fixes the expected number of fields (it is also parsed as data).
    n_fields = len(lines[1].split('] -')[1].split('\t'))
    print(lines[0])
    df_list = []
    for line in lines[1:]:
        if len(line) > 0:
            fields = line.split('] -')[1].split('\t')
            if len(fields) == n_fields:
                df_list.append(fields)

    # np.float was removed in NumPy 1.24; the builtin float is the same dtype.
    df = pd.DataFrame(df_list, dtype=float).rolling(window=step).mean()
    print('# steps:', df.shape[0])
    print('# cols:', df.shape[1], '; 0~' + str(df.shape[1] - 1))
    fig, ax = plt.subplots(figsize=figsize)
    plt.rc('font', size=15)
    tab_c = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink','tab:gray','tab:olive','tab:cyan']

    colors = c_list if c_list is not None else tab_c

    ax.set_xlabel('Step')

    if xlim[0] is None:
        xlim[0] = 0
        xlim[1] = len(df)
    df = df[xlim[0]:xlim[1]]

    def _combine(cols, how):
        """Combine the columns listed in ``cols`` as selected by ``how``."""
        first = df[cols[0]]
        if how in ('divide', 'x/y'):
            return first / df[cols[1]]
        if how in ('product', 'x*y'):
            return first * df[cols[1]]
        if how in ('sum', 'x+y'):
            return first + df[cols[1]]
        if how == 'x^2y':
            return first ** 2 * df[cols[1]]
        if how == 'x-y':
            return first - df[cols[1]]
        if how == 'custom':
            return first ** 2 * (df[cols[1]] + df[cols[2]])
        if how == 'custom1':
            return first ** 2 * (df[cols[1]] / df[cols[2]])
        if how == 'cx':
            # cols[1] is a scalar factor here, not a column index.
            return cols[1] * first
        raise ValueError('unsupported list_opt: %r' % (how,))

    def _values(entry, how, list_how):
        """Return the series selected by one ``show_list*`` entry."""
        if isinstance(entry, list):
            return _combine(entry, list_how)
        if how == 'delay':
            return df[entry][:-1]
        if how == 'abs':
            return df[entry].abs()
        if how == 'sqrt':
            return np.sqrt(df[entry])
        if how == 'square':
            return df[entry] ** 2
        return df[entry]

    if show_list1:
        colors = c_list1 if c_list1 is not None else tab_c
        label = l_list1 if l_list1 is not None else [str(i) for i in show_list1]
        drawn = False
        for idx, (i, j) in enumerate(zip(show_list, show_list1)):
            # The x series must use list_opt (the old code mixed in list_opt1).
            x_vals = _values(i, opt, list_opt)
            y_vals = _values(j, opt1, list_opt1)
            ax.scatter(x_vals, y_vals, c=colors[idx], label=label[idx], s=s)
            drawn = True
        if drawn:
            ax.legend(loc='lower center', bbox_to_anchor=(1, 1), scatterpoints=10, ncol=ncol1)

    if t_list and len(t_list) == 1:
        ax.set_ylabel(t_list[0])

    plt.tight_layout()
    if savefig:
        plt.savefig('pdfs/'+save_name+time_now+'.png')
        plt.savefig('pdfs/'+save_name+time_now+'.pdf')
    plt.show()

    if acc_plot:
        fig, ax2 = plt.subplots()
        ax2.plot(df[0], df[5], label='Train Acc',c=colors[-1])
        ax2.plot(df[0][1:], df[9][:-1], label='Test Acc',c=colors[-1], linestyle='--')
        print('Best Test Acc',df[9].max())
        print('Best Train Acc',df[5].max())
        print('Best Train Loss',df[6].min())
        if not std:
            ax2.plot(df[0], df[7], label='Train Adv Acc',c=colors[0])
            ax2.plot(df[0][1:], df[11][:-1], label='Test Adv Acc',c=colors[0], linestyle='--')
            ax2.axhline(y=df[11].max(),c='k',linestyle=':')
            ax2.text(df[11].argmax()+1,df[11].max(),str(df[11].max()))

        ax2.legend(loc='lower right',bbox_to_anchor=(1,1))
        ax2.set_ylabel('Acc')
        ax2.set_xlim(acc_plot_xlim)
        ax2.set_ylim(acc_plot_ylim)
        plt.show()
    
    
def load_df(log, path=None):
    """Load a tab-separated training log into a float DataFrame.

    ``path + log`` is read if it is an existing file. Otherwise, the
    lexicographically last entry of ``path`` whose name starts with ``log``
    is used (the candidate list is printed).

    Line 0 is printed and skipped. Every later non-empty line must contain
    ``'] -'``; the text after it is split on tabs into that row's fields.
    The number of fields on line 1 is the expected width: longer rows are
    truncated to it, and shorter rows are passed through unchanged.

    Args:
        log: File name, or prefix of a file name, inside ``path``.
        path: Prefix string-concatenated with ``log``; defaults to
            ``'./log/'``.

    Returns:
        pandas.DataFrame with integer column labels and float values.

    Raises:
        FileNotFoundError: no file in ``path`` starts with ``log``.
    """
    if path is None:
        path = './log/'

    if os.path.isfile(path + log):
        log_path = path + log
    else:
        # os.listdir returns entries in arbitrary order; sort so the choice
        # of the "last" match is deterministic.
        prefixed = sorted(name for name in os.listdir(path) if name.startswith(log))
        print(prefixed)
        if not prefixed:
            raise FileNotFoundError('no file starting with %r in %r' % (log, path))
        log_path = path + prefixed[-1]

    with open(log_path, 'r') as f:
        lines = f.read().split('\n')

    width = len(lines[1].split('] -')[1].split('\t'))
    print(lines[0])
    rows = [line.split('] -')[1].split('\t')[:width] for line in lines[1:] if line]

    # np.float was removed in NumPy 1.24; the builtin float is the same dtype.
    return pd.DataFrame(rows, dtype=float)



def figure_multi(logs, path=None, std=False,
               xlim=[None,None],
               ylim=[None,None],ylim1=[None,None],ylim2=[None,None],
               show_list=None,show_list1=None,show_list2=None,
               opt=None,opt1=None,opt2=None,
               c_list=None, c_list1=None, c_list2=None,
               l_list=None, l_list1=None, l_list2=None,
               dashed=[], dashed1=[], dashed2=[],
               t_list=None,
               list_opt=None,list_opt1=None,list_opt2=None,
               ncol=1,ncol1=1,ncol2=1,
               scatter=False,scatter1=False,scatter2=False,
               linewidth=2,
               linestyle = ['-','--',':'],
               markerstyle = [None,None,None],
               markevery=5,
               ticks=1000,
               s=5,
               acc_plot=True,
               save_name='fig_num',
               figsize=(15,5),
               gap=0.15,
               axhlines=None,
               axvlines=None,
               moving_average=True,
               sgd=None,
               text=False,
               log_plot=False, 
               step=1,
               savefig=False, 
               snMlim=4000,
               add_log=None, 
               opt_sum=False):
    """Overlay the same log columns from several runs on up to three y-axes.

    Loading and smoothing:

    * Each entry of ``logs`` is loaded with ``load_df(log, path)``.
    * The window for the i-th log is ``step``, or ``step[i]`` when ``step``
      is a list.
    * With ``moving_average``, each frame is smoothed with a rolling mean of
      that width. Its first row keeps the raw values, and rows 1..width-1
      are dropped.
    * Without ``moving_average``, every width-th row is kept instead.
    * ``sgd[i]`` scales the x values (column 0) of run i and defaults to 1.

    ``show_list``, ``show_list1`` and ``show_list2`` pick the series for the
    main axis and for up to two twin y-axes. Each entry is either:

    * a column index, transformed by ``opt*``: ``'delay'``, ``'abs'``,
      ``'sqrt'`` or ``'square'``;
    * a list of indices, combined by ``list_opt*``: ``'divide'``/``'x/y'``,
      ``'product'``/``'x*y'``, ``'sum'``/``'x+y'``, ``'x^2y'``, ``'x-y'``,
      ``'custom'``, ``'custom1'`` or ``'cx'``.

    Styling and labels:

    * Run i is drawn with ``linestyle[i]`` and ``markerstyle[i]``.
    * Only run 0 contributes legend labels, except for scatter plots and
      for main-axis indices listed in ``dashed``, which are labelled for
      every run.
    * ``t_list[k]`` labels axis k (0 = main axis) when that axis exists.

    If ``acc_plot`` is true, one extra figure per run plots column 5
    ('Train Acc') and the one-row-delayed column 9 ('Test Acc'), and prints
    best values.

    ``std``, ``text``, ``log_plot``, ``snMlim``, ``add_log``, ``opt_sum``,
    ``dashed1`` and ``dashed2`` are accepted but unused.

    Raises:
        ValueError: a list entry is used with an unsupported ``list_opt*``.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    # Work on a copy so the shared default / caller's list is never mutated.
    xlim = list(xlim)

    dfs = []
    for log_idx, log in enumerate(logs):
        window = step[log_idx] if isinstance(step, list) else step
        raw = load_df(log, path)
        if moving_average:
            smoothed = raw.rolling(window=window).mean()
            # rolling() leaves the first window-1 rows NaN: restore row 0 from
            # the raw data and drop rows 1..window-1.
            smoothed[:1] = raw[:1]
            # axis must be a keyword: pandas >= 2.0 rejects it positionally.
            dfs.append(smoothed.drop(np.arange(1, window), axis=0))
        else:
            dfs.append(raw[::window])
    if sgd is None:
        sgd = [1]*len(dfs)

    fig, ax = plt.subplots(figsize=figsize)
    plt.rc('font', size=15)
    tab_c = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink','tab:gray','tab:olive','tab:cyan']

    def _combine(frame, cols, how):
        """Combine the columns ``cols`` of ``frame`` as selected by ``how``."""
        first = frame[cols[0]]
        if how in ('divide', 'x/y'):
            return first / frame[cols[1]]
        if how in ('product', 'x*y'):
            return first * frame[cols[1]]
        if how in ('sum', 'x+y'):
            return first + frame[cols[1]]
        if how == 'x^2y':
            return first ** 2 * frame[cols[1]]
        if how == 'x-y':
            return first - frame[cols[1]]
        if how == 'custom':
            return first ** 2 * (frame[cols[1]] + frame[cols[2]])
        if how == 'custom1':
            return first ** 2 * (frame[cols[1]] / frame[cols[2]])
        if how == 'cx':
            # cols[1] is a scalar factor here, not a column index.
            return cols[1] * first
        raise ValueError('unsupported list_opt: %r' % (how,))

    def _series(frame, entry, how, list_how):
        """Return the unscaled (x, y) pair for one ``show_list*`` entry."""
        if isinstance(entry, list):
            return frame[0], _combine(frame, entry, list_how)
        if how == 'delay':
            # Pair the x of row k with the value of row k-1.
            return frame[0][1:], frame[entry][:-1]
        if how == 'abs':
            return frame[0], frame[entry].abs()
        if how == 'sqrt':
            return frame[0], np.sqrt(frame[entry])
        if how == 'square':
            return frame[0], frame[entry] ** 2
        return frame[0], frame[entry]

    def _draw_all(axis, entries, how, list_how, colors_, labels_, use_scatter,
                  dashed_idx, line_kw, report_steps=False):
        """Draw ``entries`` for every run on ``axis`` (legend built by caller)."""
        for df_idx, frame in enumerate(dfs):
            if report_steps:
                print('# steps:', frame.shape[0])
            for idx, entry in enumerate(entries):
                x, y = _series(frame, entry, how, list_how)
                x = sgd[df_idx] * x
                if use_scatter:
                    axis.scatter(x, y, c=colors_[idx], label=labels_[idx], s=s)
                elif idx in dashed_idx:
                    axis.plot(x, y, c=colors_[idx], label=labels_[idx], linestyle='--', **line_kw)
                else:
                    style = {'linestyle': linestyle[df_idx],
                             'marker': markerstyle[df_idx],
                             'markevery': markevery}
                    style.update(line_kw)
                    if df_idx == 0:
                        # Label only the first run so each series appears once.
                        style['label'] = labels_[idx]
                    axis.plot(x, y, c=colors_[idx], **style)

    # Main axis.
    colors = c_list if c_list is not None else tab_c
    label = l_list if l_list is not None else [str(i) for i in show_list]
    _draw_all(ax, show_list, opt, list_opt, colors, label, scatter, dashed,
              {'linewidth': linewidth}, report_steps=True)
    ax.legend(loc='lower left', bbox_to_anchor=(0, 1), ncol=ncol)

    ax.set_xlabel('Step')

    if xlim[0] is None:
        xlim[0] = 0
        xlim[1] = len(dfs[-1])

    major_ticks = np.arange(xlim[0], xlim[1] + 1, ticks)
    ax.set_xticks(major_ticks)

    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(True, which='both', axis='x')

    if axhlines is not None:
        for idx, axhline in enumerate(axhlines):
            ax.axhline(y=axhline, alpha=0.5, linestyle=linestyle[idx], c=colors[0])
    if axvlines is not None:
        for axvline in axvlines:
            ax.axvline(x=axvline, alpha=0.5, linestyle='--', c='k')

    # Twin y-axes: the k-th twin (0-based) is anchored at x = 1 + k*gap; its
    # spine is moved there for k > 0.
    twin_specs = (
        (show_list1, opt1, list_opt1, c_list1, l_list1, scatter1, ncol1, ylim1, 'lower center'),
        (show_list2, opt2, list_opt2, c_list2, l_list2, scatter2, ncol2, ylim2, 'lower left'),
    )
    twin_axes = []
    for k, (entries, how, list_how, c_k, l_k, scatter_k, ncol_k, ylim_k,
            legend_loc) in enumerate(twin_specs):
        if not entries:
            twin_axes.append(None)
            continue
        twin = ax.twinx()
        twin_axes.append(twin)
        colors = c_k if c_k is not None else tab_c
        label = l_k if l_k is not None else [str(i) for i in entries]
        _draw_all(twin, entries, how, list_how, colors, label, scatter_k, (), {})
        offset = 1 + k * gap
        twin.legend(loc=legend_loc, bbox_to_anchor=(offset, 1), ncol=ncol_k)
        if k > 0:
            twin.spines["right"].set_position(("axes", offset))
        twin.set_ylim(ylim_k)

    if t_list:
        # Label each existing axis; missing twins are skipped instead of
        # raising NameError.
        for axis, title in zip([ax] + twin_axes, t_list):
            if axis is not None:
                axis.set_ylabel(title)

    plt.tight_layout()
    if savefig:
        plt.savefig('pdfs/'+save_name+time_now+'.png')
        plt.savefig('pdfs/'+save_name+time_now+'.pdf')
    plt.show()

    if acc_plot:
        for df_idx, df in enumerate(dfs):
            fig, ax2 = plt.subplots()
            ax2.plot(df[0], df[5], label='Train Acc',c=colors[-1])
            ax2.plot(df[0][1:], df[9][:-1], label='Test Acc',c=colors[-1], linestyle='--')
            print('Best Test Acc',df[9].max())
            print('Best Train Acc',df[5].max())
            print('Best Train Loss',df[6].min())

            ax2.legend(loc='lower right',bbox_to_anchor=(1,1))
            ax2.set_ylabel('Acc')
            plt.show()


def figure_std(log, path=None, std=False,
               xlim=[None,None],
               ylim=[None,None],ylim1=[None,None],ylim2=[None,None],ylim3=[None,None],acc_plot_xlim=[None,None],acc_plot_ylim=[None,None],
               show_list=None,show_list1=None,show_list2=None,show_list3=None,
               opt=None,opt1=None,opt2=None,opt3=None,
               c_list=None, c_list1=None, c_list2=None,c_list3=None,
               l_list=None, l_list1=None, l_list2=None,l_list3=None,
               dashed=[], dashed1=[], dashed2=[],dashed3=[],
               t_list=None,
               list_opt=None,list_opt1=None,list_opt2=None,list_opt3=None,
               ncol=1,ncol1=1,ncol2=1,ncol3=1,
               scatter=False,scatter1=False,scatter2=False,scatter3=False,
               linewidth=1,
               ticks=1000,
               s=5,
               save_name='fig_num',
               figsize=(15,5),
               gap=0.15,
               axhlines=None,axhlines1=None,axhlines2=None,axhlines3=None,
               axvlines=None,
               text=False,
               log_plot=False,
               step=1,
               sgd=1,
               savefig=False,
               snMlim=4000,
               add_log=None,
               acc_plot=True,
               opt_sum=False,
               alpha=0.9):
    """Plot one logged series with a shaded ``y - z .. y + z`` band.

    The log is parsed as follows. Line 0 is printed verbatim. Every later
    line that contains the ``'] -'`` marker contributes the tab-separated
    fields that follow the marker. Rows whose field count differs from
    that of line 1 are dropped. All kept fields must parse as floats.
    Every column, including column 0 (the x axis), is then smoothed with
    a rolling mean of width ``step``.

    Parameters
    ----------
    log : str
        File name inside ``path``. If no such file exists, the
        lexicographically last file whose name starts with ``log`` is used.
    path : str, optional
        Log directory, default ``'./log/'``.
    std : bool
        If true, the accuracy figure omits the adversarial columns 7 and 11.
    xlim : [low, high]
        x range. A ``None`` low becomes 0 and a ``None`` high becomes the
        number of rows. The passed list is never modified.
    ylim, acc_plot_xlim, acc_plot_ylim : [low, high]
        Axis limits for the main and accuracy figures.
    show_list : list
        ``show_list[0]`` is the centre line and ``show_list[1]`` is the
        half-width of the band. Each entry is a column index or a pair
        ``[a, b]`` combined according to ``list_opt``.
    opt : {'delay', 'abs', 'sqrt', 'square', None}
        Transform for non-pair entries. ``'delay'`` drops the last row and
        plots each value at the x of the following row.
    list_opt : {'divide', 'x/y', 'product', 'x*y', 'sum', 'x+y', 'x^2y', 'x-y'}
        How a pair entry ``[a, b]`` is combined.
    c_list, l_list : list, optional
        Colours and labels. ``colors[0]`` draws the line and ``colors[1]``
        fills the band.
    dashed : list
        The line is dashed when ``1`` is in this list (see the note in the
        body).
    t_list : list, optional
        ``t_list[0]`` becomes the y label. This figure has a single y axis,
        so further entries are ignored.
    scatter, s, ncol : scatter mode, marker size and legend column count.
    linewidth, ticks, figsize, alpha : line width, x tick spacing,
        figure size and band opacity.
    axhlines, axvlines : list, optional
        Horizontal and vertical reference lines.
    step : int
        Rolling-mean window.
    sgd : scalar
        Multiplier applied to column 0 to form x.
    savefig, save_name : bool, str
        If true, save PNG and PDF to ``pdfs/<save_name><timestamp>``.
    acc_plot : bool
        If true, draw a second figure of columns 5/9 (and 7/11 unless
        ``std``) and print the best values of columns 9, 5 and 6.

    The remaining parameters (``ylim1``..``3``, ``show_list1``..``3``,
    ``gap``, ``text``, ``log_plot``, ``snMlim``, ``add_log``, ``opt_sum``,
    etc.) are not used by this function. They are accepted so that the
    call signature matches ``figure_num``.

    Raises
    ------
    FileNotFoundError
        If neither ``log`` nor any file starting with ``log`` is in ``path``.
    ValueError
        If a pair entry is combined with an unrecognised ``list_opt``.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')

    if path is None:
        path = './log/'

    # Prefer the exact file name; otherwise fall back to a prefix match.
    # os.listdir returns names in arbitrary order, so sort for a
    # deterministic pick.
    log_file = os.path.join(path, log)
    if not os.path.isfile(log_file):
        prefixed = sorted(name for name in os.listdir(path) if name.startswith(log))
        print(prefixed)
        if not prefixed:
            raise FileNotFoundError(f'no log file starting with {log!r} in {path!r}')
        log_file = os.path.join(path, prefixed[-1])

    with open(log_file, 'r') as f:
        lines = f.read().split('\n')

    # Line 1 fixes the expected number of fields per row.
    n_fields = len(lines[1].split('] -')[1].split('\t'))
    print(lines[0])
    rows = []
    for line in lines[1:]:
        parts = line.split('] -')
        if len(parts) < 2:  # blank lines or lines without the log marker
            continue
        fields = parts[1].split('\t')
        if len(fields) == n_fields:
            rows.append(fields)

    # np.float was removed in NumPy 1.24; the builtin float is equivalent.
    df = pd.DataFrame(rows, dtype=float).rolling(window=step).mean()
    print('# steps:', df.shape[0])
    print('# cols:', df.shape[1], '; 0~' + str(df.shape[1] - 1))
    _, ax = plt.subplots(figsize=figsize)
    plt.rc('font', size=15)
    tab_c = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink','tab:gray','tab:olive','tab:cyan']

    colors = c_list if c_list is not None else tab_c
    label = l_list if l_list is not None else [str(i) for i in show_list]

    def _combine(a, b):
        """Combine two columns according to ``list_opt``."""
        if list_opt in ('divide', 'x/y'):
            return a / b
        if list_opt in ('product', 'x*y'):
            return a * b
        if list_opt in ('sum', 'x+y'):
            return a + b
        if list_opt == 'x^2y':
            return a ** 2 * b
        if list_opt == 'x-y':
            return a - b
        raise ValueError(f'unknown list_opt: {list_opt!r}')

    def _series(col):
        """Return column ``col`` (or pair ``[a, b]``) transformed by opt/list_opt."""
        if isinstance(col, list):
            return _combine(df[col[0]], df[col[1]])
        if opt == 'delay':
            return df[col].iloc[:-1]
        if opt == 'abs':
            return df[col].abs()
        if opt == 'sqrt':
            return np.sqrt(df[col])
        if opt == 'square':
            return df[col] ** 2
        return df[col]

    if opt == 'delay' and not isinstance(show_list[0], list):
        x = sgd * df[0].iloc[1:]
    else:
        x = sgd * df[0]
    y = _series(show_list[0])
    z = _series(show_list[1])

    if scatter:
        ax.scatter(x, y, c=colors[0], label=label[0], s=s)
        ax.legend(loc='lower left', bbox_to_anchor=(0, 1), scatterpoints=10, ncol=ncol)
    else:
        # NOTE(review): the dash check has always been keyed on index 1, not 0.
        # It is kept unchanged so existing callers' `dashed` values still work.
        style = {'linestyle': '--'} if 1 in dashed else {}
        ax.plot(x, y, c=colors[0], label=label[0], linewidth=linewidth, **style)
        ax.fill_between(x, y - z, y + z, facecolor=colors[1], alpha=alpha)

    ax.set_xlabel('Step')

    # Work on a copy so the shared default list (or the caller's list) is
    # never mutated, and fill each missing bound independently.
    xlim = list(xlim)
    if xlim[0] is None:
        xlim[0] = 0
    if xlim[1] is None:
        xlim[1] = len(df)
    ax.set_xticks(np.arange(xlim[0], xlim[1] + 1, ticks))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(True, which='both', axis='x')

    if axhlines is not None:
        for level in axhlines:
            ax.axhline(y=level, alpha=0.5, linestyle='--', c=colors[0])
    if axvlines is not None:
        for pos in axvlines:
            ax.axvline(x=pos, alpha=0.5, linestyle='--', c='k')

    # Only one y axis exists here; the twin axes belong to figure_num.
    if t_list:
        ax.set_ylabel(t_list[0])

    plt.tight_layout()
    if savefig:
        plt.savefig('pdfs/' + save_name + time_now + '.png')
        plt.savefig('pdfs/' + save_name + time_now + '.pdf')
    plt.show()

    if acc_plot:
        _, ax2 = plt.subplots()
        ax2.plot(df[0], df[5], label='Train Acc', c=colors[-1])
        # Test columns are drawn one row to the right of the train columns.
        ax2.plot(df[0].iloc[1:], df[9].iloc[:-1], label='Test Acc', c=colors[-1], linestyle='--')
        print('Best Test Acc', df[9].max())
        print('Best Train Acc', df[5].max())
        print('Best Train Loss', df[6].min())
        if not std:
            ax2.plot(df[0], df[7], label='Train Adv Acc', c=colors[0])
            ax2.plot(df[0].iloc[1:], df[11].iloc[:-1], label='Test Adv Acc', c=colors[0], linestyle='--')
            best_adv = df[11].max()
            ax2.axhline(y=best_adv, c='k', linestyle=':')
            ax2.text(df[11].argmax() + 1, best_adv, str(best_adv))

        ax2.legend(loc='lower right', bbox_to_anchor=(1, 1))
        ax2.set_ylabel('Acc')
        ax2.set_xlim(acc_plot_xlim)
        ax2.set_ylim(acc_plot_ylim)
        plt.show()
            


def figure_box(log, path=None, std=False,
               xlim=[None,None],
               ylim=[None,None],ylim1=[None,None],ylim2=[None,None],ylim3=[None,None],acc_plot_xlim=[None,None],acc_plot_ylim=[None,None],
               show_list=None,show_list1=None,show_list2=None,show_list3=None,
               opt=None,opt1=None,opt2=None,opt3=None,
               c_list=None, c_list1=None, c_list2=None,c_list3=None,
               l_list=None, l_list1=None, l_list2=None,l_list3=None,
               dashed=[], dashed1=[], dashed2=[],dashed3=[],
               t_list=None,
               list_opt=None,list_opt1=None,list_opt2=None,list_opt3=None,
               ncol=1,ncol1=1,ncol2=1,ncol3=1,
               scatter=False,scatter1=False,scatter2=False,scatter3=False,
               linewidth=1,
               ticks=1000,
               s=5,
               save_name='fig_num',
               figsize=(15,5),
               gap=0.15,
               axhlines=None,axhlines1=None,axhlines2=None,axhlines3=None,
               axvlines=None,
               text=False,
               log_plot=False,
               step=1,
               sgd=1,
               savefig=False,
               snMlim=4000,
               add_log=None,
               acc_plot=True,
               opt_sum=False,
               alpha=0.9):
    """Plot one logged series with a shaded band between two other series.

    The log is parsed as follows. Line 0 is printed verbatim. Every later
    line that contains the ``'] -'`` marker contributes the tab-separated
    fields that follow the marker. Rows whose field count differs from
    that of line 1 are dropped. All kept fields must parse as floats.
    Every column, including column 0 (the x axis), is then smoothed with
    a rolling mean of width ``step``.

    The first point of every curve is anchored at x = 0 with the
    unsmoothed (``step``-free) row-0 value. A straight segment from row 0
    to row ``step + 1`` is drawn for the line and for the band, so that the
    plot starts at that anchor.

    Parameters
    ----------
    log : str
        File name inside ``path``. If no such file exists, the
        lexicographically last file whose name starts with ``log`` is used.
    path : str, optional
        Log directory, default ``'./log/'``.
    std : bool
        If true, the accuracy figure omits the adversarial columns 7 and 11.
    xlim : [low, high]
        x range. A ``None`` low becomes 0 and a ``None`` high becomes the
        number of rows. The passed list is never modified.
    ylim, acc_plot_xlim, acc_plot_ylim : [low, high]
        Axis limits for the main and accuracy figures.
    show_list : list
        ``show_list[0]`` is the line, and ``show_list[1]`` / ``show_list[2]``
        are the lower and upper edges of the band. Each entry is a column
        index or a pair ``[a, b]`` combined according to ``list_opt``.
    opt : {'delay', 'abs', 'sqrt', 'square', None}
        Transform for non-pair entries. ``'delay'`` drops the last row and
        plots each value at the x of the following row.
    list_opt : {'divide', 'x/y', 'product', 'x*y', 'sum', 'x+y', 'x^2y', 'x-y'}
        How a pair entry ``[a, b]`` is combined.
    c_list, l_list : list, optional
        Colours and labels. ``colors[0]`` draws the line and ``colors[1]``
        fills the band.
    dashed : list
        The line is dashed when ``2`` is in this list (see the note in the
        body).
    t_list : list, optional
        ``t_list[0]`` becomes the y label. This figure has a single y axis,
        so further entries are ignored.
    scatter, s, ncol : scatter mode, marker size and legend column count.
    linewidth, ticks, figsize, alpha : line width, x tick spacing,
        figure size and band opacity.
    axhlines, axvlines : list, optional
        Horizontal and vertical reference lines.
    step : int
        Rolling-mean window.
    sgd : scalar
        Multiplier applied to column 0 to form x.
    savefig, save_name : bool, str
        If true, save PNG and PDF to ``pdfs/<save_name><timestamp>``.
    acc_plot : bool
        If true, draw a second figure of columns 5/9 (and 7/11 unless
        ``std``) and print the best values of columns 9, 5 and 6.

    The remaining parameters (``ylim1``..``3``, ``show_list1``..``3``,
    ``gap``, ``text``, ``log_plot``, ``snMlim``, ``add_log``, ``opt_sum``,
    etc.) are not used by this function. They are accepted so that the
    call signature matches ``figure_num``.

    Raises
    ------
    FileNotFoundError
        If neither ``log`` nor any file starting with ``log`` is in ``path``.
    ValueError
        If a pair entry is combined with an unrecognised ``list_opt``.
    """
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')

    if path is None:
        path = './log/'

    # Prefer the exact file name; otherwise fall back to a prefix match.
    # os.listdir returns names in arbitrary order, so sort for a
    # deterministic pick.
    log_file = os.path.join(path, log)
    if not os.path.isfile(log_file):
        prefixed = sorted(name for name in os.listdir(path) if name.startswith(log))
        print(prefixed)
        if not prefixed:
            raise FileNotFoundError(f'no log file starting with {log!r} in {path!r}')
        log_file = os.path.join(path, prefixed[-1])

    with open(log_file, 'r') as f:
        lines = f.read().split('\n')

    # Line 1 fixes the expected number of fields per row.
    n_fields = len(lines[1].split('] -')[1].split('\t'))
    print(lines[0])
    rows = []
    for line in lines[1:]:
        parts = line.split('] -')
        if len(parts) < 2:  # blank lines or lines without the log marker
            continue
        fields = parts[1].split('\t')
        if len(fields) == n_fields:
            rows.append(fields)

    # np.float was removed in NumPy 1.24; the builtin float is equivalent.
    df_origin = pd.DataFrame(rows, dtype=float)
    df = df_origin.rolling(window=step).mean()
    print('# steps:', df.shape[0])
    print('# cols:', df.shape[1], '; 0~' + str(df.shape[1] - 1))
    _, ax = plt.subplots(figsize=figsize)
    plt.rc('font', size=15)
    tab_c = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink','tab:gray','tab:olive','tab:cyan']

    colors = c_list if c_list is not None else tab_c
    label = l_list if l_list is not None else [str(i) for i in show_list]

    def _combine(a, b):
        """Combine two columns according to ``list_opt``."""
        if list_opt in ('divide', 'x/y'):
            return a / b
        if list_opt in ('product', 'x*y'):
            return a * b
        if list_opt in ('sum', 'x+y'):
            return a + b
        if list_opt == 'x^2y':
            return a ** 2 * b
        if list_opt == 'x-y':
            return a - b
        raise ValueError(f'unknown list_opt: {list_opt!r}')

    def _series(frame, col):
        """Return ``frame[col]`` (or pair ``[a, b]``) transformed by opt/list_opt."""
        if isinstance(col, list):
            return _combine(frame[col[0]], frame[col[1]])
        if opt == 'delay':
            return frame[col].iloc[:-1]
        if opt == 'abs':
            return frame[col].abs()
        if opt == 'sqrt':
            return np.sqrt(frame[col])
        if opt == 'square':
            return frame[col] ** 2
        return frame[col]

    if opt == 'delay' and not isinstance(show_list[0], list):
        x = sgd * df[0].iloc[1:]
    else:
        x = sgd * df[0]
    # .copy(): element 0 is overwritten below. Without a copy the write
    # could propagate into `df`, which the accuracy plot reads later.
    y = _series(df, show_list[0]).copy()
    z = _series(df, show_list[1]).copy()
    w = _series(df, show_list[2]).copy()

    # Anchor each curve at x = 0 with its unsmoothed row-0 value, using the
    # same transform as the smoothed series. The rolling mean leaves the
    # leading rows NaN. Positional access keeps x/y aligned for opt='delay'.
    x.iloc[0] = 0
    y.iloc[0] = _series(df_origin, show_list[0]).iloc[0]
    z.iloc[0] = _series(df_origin, show_list[1]).iloc[0]
    w.iloc[0] = _series(df_origin, show_list[2]).iloc[0]

    if scatter:
        ax.scatter(x, y, c=colors[0], label=label[0], s=s)
        ax.legend(loc='lower left', bbox_to_anchor=(0, 1), scatterpoints=10, ncol=ncol)
    else:
        # NOTE(review): the dash check has always been keyed on index 2, not 0.
        # It is kept unchanged so existing callers' `dashed` values still work.
        style = {'linestyle': '--'} if 2 in dashed else {}
        head = (0, step + 1)
        head_x = [x.iloc[k] for k in head]
        ax.plot(head_x, [y.iloc[k] for k in head], c=colors[0], label=label[0], linewidth=linewidth, **style)
        ax.plot(x, y, c=colors[0], label=label[0], linewidth=linewidth, **style)
        ax.fill_between(head_x, [z.iloc[k] for k in head], [w.iloc[k] for k in head], facecolor=colors[1], alpha=alpha)
        ax.fill_between(x, z, w, facecolor=colors[1], alpha=alpha)

    ax.set_xlabel('Step')

    # Work on a copy so the shared default list (or the caller's list) is
    # never mutated, and fill each missing bound independently.
    xlim = list(xlim)
    if xlim[0] is None:
        xlim[0] = 0
    if xlim[1] is None:
        xlim[1] = len(df)
    ax.set_xticks(np.arange(xlim[0], xlim[1] + 1, ticks))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.grid(True, which='both', axis='x')

    if axhlines is not None:
        for level in axhlines:
            ax.axhline(y=level, alpha=0.5, linestyle='--', c=colors[0])
    if axvlines is not None:
        for pos in axvlines:
            ax.axvline(x=pos, alpha=0.5, linestyle='--', c='k')

    # Only one y axis exists here; the twin axes belong to figure_num.
    if t_list:
        ax.set_ylabel(t_list[0])

    plt.tight_layout()
    if savefig:
        plt.savefig('pdfs/' + save_name + time_now + '.png')
        plt.savefig('pdfs/' + save_name + time_now + '.pdf')
    plt.show()

    if acc_plot:
        _, ax2 = plt.subplots()
        ax2.plot(df[0], df[5], label='Train Acc', c=colors[-1])
        # Test columns are drawn one row to the right of the train columns.
        ax2.plot(df[0].iloc[1:], df[9].iloc[:-1], label='Test Acc', c=colors[-1], linestyle='--')
        print('Best Test Acc', df[9].max())
        print('Best Train Acc', df[5].max())
        print('Best Train Loss', df[6].min())
        if not std:
            ax2.plot(df[0], df[7], label='Train Adv Acc', c=colors[0])
            ax2.plot(df[0].iloc[1:], df[11].iloc[:-1], label='Test Adv Acc', c=colors[0], linestyle='--')
            best_adv = df[11].max()
            ax2.axhline(y=best_adv, c='k', linestyle=':')
            ax2.text(df[11].argmax() + 1, best_adv, str(best_adv))

        ax2.legend(loc='lower right', bbox_to_anchor=(1, 1))
        ax2.set_ylabel('Acc')
        ax2.set_xlim(acc_plot_xlim)
        ax2.set_ylim(acc_plot_ylim)
        plt.show()

def figure_umap(data, start, end, step=20,n_max=500,n_start=0,name='umap',savefig=0):
    time_now = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    start = start-n_start
    end = end-n_start
    plt.scatter(data[start:end+1,0],data[start:end+1,1], c=np.arange(n_start+start,np.min((n_start+end+1,n_max))), cmap='jet', s=5)
    plt.plot(data[start:end+1, 0], data[start:end+1, 1], linewidth=0.3, c='k')
    cbar = plt.colorbar(boundaries=np.arange(n_start+start,n_start+end+1), ticks=np.hstack([np.arange(n_start+start,n_start+end,step),[n_start+end]]))
    cbar.set_label('Step')
    if savefig:
        plt.savefig('./pdfs/'+name+str(time_now)+'.pdf')