import matplotlib.pyplot as plt
import datetime
import numpy as np

def plotAcc(arrs_to_plot, acc_idx, save_path, iter,
            legend=("uniQuant", "choco-Quant"),
            title_str="Accuracy over Time (Differential) w/ dith=", split_char='.',
            args=None, dist=5, b=20):
    """Plot accuracy series against iteration count and save a timestamped JPG.

    Args:
        arrs_to_plot: Iterable of 1-D series. Each one is plotted against
            ``np.arange(1, iter + 1, acc_idx)``, so its length must match.
        acc_idx: Spacing, in iterations, between recorded values.
        save_path: Prefix the generated file name is appended to.
        iter: Last iteration on the x-axis.
        legend: Labels for the series, in plotting order.
        title_str: Figure title.
        split_char: Not used; accepted for call-site compatibility.
        args: Mapping providing ``'lr'``, ``'threshold'``, ``'bit'`` and
            ``'E'``. Their values are embedded in the file name.
        dist: Value embedded after ``'dist_'`` in the file name.
        b: Value embedded after ``'_b'`` in the file name.

    Raises:
        TypeError: If ``args`` is None, because the file name cannot be built.
    """
    if args is None:
        # Fail before drawing, so the shared pyplot figure is not left holding
        # half-finished lines that would leak into the next plot.
        raise TypeError("plotAcc() requires 'args' with keys 'lr', 'threshold', 'bit', 'E'")

    iter_range = np.arange(1, iter + 1, acc_idx)
    for res in arrs_to_plot:
        plt.plot(iter_range, res)

    plt.title(title_str)
    plt.xlabel("Number of Iterations")
    plt.ylabel("Accuracy")

    save_path_acc = (save_path + "acc:" + "it" + str(iter)
                     + "_lr" + str(args['lr'])
                     + "_tr" + str(args['threshold'])
                     + "_lvl" + str(args['bit'])
                     + "_b" + str(b)
                     + "_e" + str(args['E'])
                     + "dist_" + str(dist))
    # The timestamp keeps repeated runs from overwriting each other.
    save_path_acc += datetime.datetime.now().isoformat() + ".jpg"
    plt.legend(list(legend), loc="upper left")
    print(save_path_acc)
    plt.savefig(save_path_acc)
    plt.clf()

def plotAccFinal(arrs_to_plot, acc_idx, save_path, iter,
            legend=("uniQuant", "choco-Quant"),
            title_str="Accuracy over Time (Differential) w/ dith=", split_char='.',
            args=None, dist=5, b=20, redux=1, axis_font=4, axis_ticks=5, legend_font=1,
            y_lims=(0, 100)):
    """Draw the publication-style accuracy figure and save it as EPS and JPG.

    Args:
        arrs_to_plot: 2-D array indexed ``[series, sample]`` with at least
            three rows. Rows are drawn in the order 1, 2, 0; row 2 is dashed
            black.
        acc_idx: Spacing, in iterations, between recorded samples.
        save_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension (e.g. ``lr0.2`` -> ``lr02``), which matches the names
            produced so far. A file name without a dot is used whole.
        iter: Last iteration on the x-axis.
        legend: Labels applied to the lines in drawing order (rows 1, 2, 0).
        title_str, split_char, args, dist, b: Not used; accepted for
            call-site compatibility.
        redux: Keep every ``redux``-th sample (and x position).
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.
        y_lims: y-axis limits passed to ``plt.ylim``.

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    x_vals = iter_range[1:]
    # Thin by `redux`, then drop the first sample so the data lines up with
    # iter_range[1:].
    loss_arr = arrs_to_plot[:, ::redux][:, 1:]

    for res in [1, 2, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        if res == 2:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0)

    plt.ylim(y_lims)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Test Accuracy", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = save_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)
    print("save_path:", jpg_save_path)
    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()
    

def plotloss(arrs_to_plot, acc_idx, save_path, iter,
            legend=("uniQuant", "choco-Quant"),
            title_str="loss over Time (Differential) w/ dith=", split_char='.',
            args=None, dist=5, b=20):
    """Plot loss series (first sample skipped) and save a timestamped JPG.

    Args:
        arrs_to_plot: Iterable of 1-D series. ``res[1:]`` of each one is
            plotted against ``np.arange(1, iter + 1, acc_idx)[1:]``.
        acc_idx: Spacing, in iterations, between recorded values.
        save_path: Prefix the generated file name is appended to.
        iter: Last iteration on the x-axis.
        legend: Labels for the series, in plotting order.
        title_str: Figure title.
        split_char: Not used; accepted for call-site compatibility.
        args: Mapping providing ``'lr'``, ``'threshold'``, ``'bit'`` and
            ``'E'``. Their values are embedded in the file name.
        dist: Value embedded after ``'dist_'`` in the file name.
        b: Value embedded after ``'_b'`` in the file name.

    Raises:
        TypeError: If ``args`` is None, because the file name cannot be built.
    """
    if args is None:
        # Fail before drawing so the shared figure is left untouched.
        raise TypeError("plotloss() requires 'args' with keys 'lr', 'threshold', 'bit', 'E'")

    iter_range = np.arange(1, iter + 1, acc_idx)
    for res in arrs_to_plot:
        # The first recorded value is left out of the plot.
        plt.plot(iter_range[1:], res[1:])

    plt.title(title_str)
    plt.xlabel("Number of Iterations")
    plt.ylabel("loss")

    save_path_acc = (save_path + "loss:" + "it" + str(iter)
                     + "_lr" + str(args['lr'])
                     + "_tr" + str(args['threshold'])
                     + "_lvl" + str(args['bit'])
                     + "_b" + str(b)
                     + "_e" + str(args['E'])
                     + "dist_" + str(dist))
    # The timestamp keeps repeated runs from overwriting each other.
    save_path_acc += datetime.datetime.now().isoformat() + ".jpg"
    plt.legend(list(legend), loc="upper left")
    plt.savefig(save_path_acc)
    plt.clf()

def plotlossFinal(arrs_to_plot, acc_idx, save_path, iter,
            legend=("uniQuant", "choco-Quant"),
            title_str="loss over Time (Differential) w/ dith=", split_char='.',
            args=None, dist=5, b=20, redux=1, axis_font=4, axis_ticks=5, legend_font=1,
            y_lims=(0, 100)):
    """Draw the publication-style loss figure and save it as EPS and JPG.

    Args:
        arrs_to_plot: 2-D array indexed ``[series, sample]`` with at least
            three rows. Rows are drawn in the order 1, 2, 0; row 2 is dashed
            black.
        acc_idx: Spacing, in iterations, between recorded samples.
        save_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension (e.g. ``lr0.2`` -> ``lr02``), which matches the names
            produced so far. A file name without a dot is used whole.
        iter: Last iteration on the x-axis.
        legend: Labels applied to the lines in drawing order (rows 1, 2, 0).
        title_str, split_char, args, dist, b, y_lims: Not used; accepted for
            call-site compatibility.
        redux: Keep every ``redux``-th sample (and x position).
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    x_vals = iter_range[1:]
    # Thin by `redux`, then drop the first sample so the data lines up with
    # iter_range[1:].
    loss_arr = arrs_to_plot[:, ::redux][:, 1:]

    for res in [1, 2, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        if res == 2:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Loss", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = save_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    plt.tight_layout()
    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)

    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()

def plotBit(arrs_to_plot, acc_idx, save_path, iter, legend=("malcom", "choco-Quant"),
            title_str="Bits over Time (Differential) w/ dith=", plot_desc="bit",
            args=None, b=20, dist=5):
    """Plot average-bits-per-symbol series and save a timestamped JPG.

    Args:
        arrs_to_plot: Iterable of 1-D series. Each one is plotted against
            ``np.arange(1, iter + 1, acc_idx)``, so its length must match.
        acc_idx: Spacing, in iterations, between recorded values.
        save_path: Prefix the generated file name is appended to.
        iter: Last iteration on the x-axis.
        legend: Labels for the series, in plotting order.
        title_str: Figure title.
        plot_desc: Not used; accepted for call-site compatibility.
        args: Mapping providing ``'lr'``, ``'threshold'``, ``'bit'`` and
            ``'E'``. Their values are embedded in the file name.
        b: Value embedded after ``'_b'`` in the file name.
        dist: Value embedded after ``'dist_'`` in the file name.

    Raises:
        TypeError: If ``args`` is None, because the file name cannot be built.
    """
    if args is None:
        # Fail before drawing so the shared figure is left untouched.
        raise TypeError("plotBit() requires 'args' with keys 'lr', 'threshold', 'bit', 'E'")

    iter_range = np.arange(1, iter + 1, acc_idx)
    for res in arrs_to_plot:
        plt.plot(iter_range, res)

    plt.title(title_str)
    plt.xlabel("Number of Iterations")
    plt.ylabel("Average Bits per Symbol")

    save_path_acc = (save_path + "bits:" + "it" + str(iter)
                     + "_lr" + str(args['lr'])
                     + "_tr" + str(args['threshold'])
                     + "_lvl" + str(args['bit'])
                     + "_b" + str(b)
                     + "_e" + str(args['E'])
                     + "dist_" + str(dist))
    # The timestamp keeps repeated runs from overwriting each other.
    save_path_acc += datetime.datetime.now().isoformat() + ".jpg"
    plt.legend(list(legend), loc="upper left")
    plt.savefig(save_path_acc)
    plt.clf()

def plotBitFinal(arrs_to_plot, acc_idx, save_path, iter, legend=("malcom", "choco-Quant"),
            title_str="Bits over Time (Differential) w/ dith=", plot_desc="bit",
            args=None, b=20, dist=5, redux=1, axis_font=4, axis_ticks=5, legend_font=1,
            y_lims=(0, 100)):
    """Draw the publication-style communication-cost figure (EPS and JPG).

    Args:
        arrs_to_plot: Indexable of at least three 1-D numeric arrays holding
            bit counts. Rows are drawn in the order 1, 2, 0; row 2 is dashed
            black. Values are multiplied by 1.25e-7 (= 1 / 8e6) to express
            them in MB.
        acc_idx: Spacing, in iterations, between recorded samples.
        save_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension (e.g. ``lr0.2`` -> ``lr02``), which matches the names
            produced so far. A file name without a dot is used whole.
        iter: Last iteration on the x-axis.
        legend: Labels applied to the lines in drawing order (rows 1, 2, 0).
        title_str, plot_desc, args, b, dist, y_lims: Not used; accepted for
            call-site compatibility.
        redux: Keep every ``redux``-th sample. The data is thinned the same
            way as the x-axis, so x and y lengths stay equal for ``redux > 1``.
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    bits_to_mb = 1.25e-7  # 1 / (8 bits per byte * 1e6 bytes per MB)

    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    x_vals = iter_range[1:]
    for res in [1, 2, 0]:
        # Thin by `redux` like the x-axis, drop the first sample, convert to MB.
        y_vals = arrs_to_plot[res][::redux][1:] * bits_to_mb
        if res == 2:
            plt.plot(x_vals, y_vals, linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, y_vals, linewidth=4.0)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Communication Cost (MB)", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = save_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.tight_layout()
    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)

    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()

def bitsFinalPlot(npy_file_path, save_path, acc_idx, iter,
                  title_str="Number of bits per Iteration",
                  legend=("Choco-SGD", "Decentralized-SGD", "Malcom-PSGD"), axis_font=20, axis_ticks=16,
                  legend_font=19):
    """Load saved bit counts and draw the communication-cost figure (EPS and JPG).

    Args:
        npy_file_path: ``.npy`` file with at least three rows of bit counts.
            Rows are drawn in the order 1, 2, 0; row 2 is dashed black.
            Values are multiplied by 1.25e-7 (= 1 / 8e6) to express them in MB.
        save_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension (e.g. ``lr0.2`` -> ``lr02``), which matches the names
            produced so far. A file name without a dot is used whole.
        acc_idx: Spacing, in iterations, between recorded samples.
        iter: Last iteration on the x-axis.
        title_str: Not used; accepted for call-site compatibility.
        legend: Labels applied to the lines in drawing order (rows 1, 2, 0).
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    bits_to_mb = 1.25e-7  # 1 / (8 bits per byte * 1e6 bytes per MB)

    bits_arr = np.load(npy_file_path)
    print(bits_arr.shape)
    iter_range = np.arange(1, iter + 1, acc_idx)
    x_vals = iter_range[1:]
    for res in [1, 2, 0]:
        # Drop the first sample so the data lines up with iter_range[1:].
        y_vals = bits_arr[res][1:] * bits_to_mb
        if res == 2:
            plt.plot(x_vals, y_vals, linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, y_vals, linewidth=4.0)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Communication Cost (MB)", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = save_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.tight_layout()
    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)

    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()

def lossFinalPlot(npy_file_path, save_path, acc_idx, iter,
                  title_str="Loss per Iteration", legend=("Choco-SGD", "Decentralized-SGD", "Malcom-PSGD"),
                  redux=3, axis_font=20, axis_ticks=16, legend_font=19):
    """Load a saved loss array and draw the loss figure (EPS and JPG).

    Args:
        npy_file_path: ``.npy`` file with a 2-D array indexed
            ``[series, sample]`` with at least three rows. Rows are drawn in
            the order 1, 2, 0; row 2 is dashed black.
        save_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension (e.g. ``lr0.2`` -> ``lr02``), which matches the names
            produced so far. A file name without a dot is used whole.
        acc_idx: Spacing, in iterations, between recorded samples.
        iter: Last iteration on the x-axis.
        title_str: Not used; accepted for call-site compatibility.
        legend: Labels applied to the lines in drawing order (rows 1, 2, 0).
        redux: Keep every ``redux``-th sample (and x position).
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    loss_arr = np.load(npy_file_path)
    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    x_vals = iter_range[1:]
    # Thin by `redux`, then drop the first sample so the data lines up with
    # iter_range[1:].
    loss_arr = loss_arr[:, ::redux][:, 1:]

    for res in [1, 2, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        if res == 2:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Loss", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = save_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    plt.tight_layout()
    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)

    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()

def accFinalPlot(npy_file_path, save_path, acc_idx, iter,
                  title_str="Test Accuracy per Iteration", legend=("Choco-SGD", "Decentralized-SGD", "Malcom-PSGD"),
                  redux=3, y_lims=(20, 85), axis_font=20, axis_ticks=16, legend_font=19):
    """Load a saved test-accuracy array and draw the accuracy figure (EPS and JPG).

    Args:
        npy_file_path: ``.npy`` file with a 2-D array indexed
            ``[series, sample]`` with at least three rows. Rows are drawn in
            the order 1, 2, 0; row 2 is dashed black.
        save_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension (e.g. ``lr0.2`` -> ``lr02``), which matches the names
            produced so far. A file name without a dot is used whole.
        acc_idx: Spacing, in iterations, between recorded samples.
        iter: Last iteration on the x-axis.
        title_str: Not used; accepted for call-site compatibility.
        legend: Labels applied to the lines in drawing order (rows 1, 2, 0).
        redux: Keep every ``redux``-th sample (and x position).
        y_lims: y-axis limits passed to ``plt.ylim``.
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    loss_arr = np.load(npy_file_path)
    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    x_vals = iter_range[1:]
    # Thin by `redux`, then drop the first sample so the data lines up with
    # iter_range[1:].
    loss_arr = loss_arr[:, ::redux][:, 1:]

    for res in [1, 2, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        if res == 2:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0)

    plt.ylim(y_lims)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Test Accuracy", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = save_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)
    print(jpg_save_path)
    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()



def genplots():
    """Regenerate the bits, loss and accuracy figures from saved Monte-Carlo runs.

    Each job reads ``<source dir><name>`` and names its output figures after
    ``data/dataPaper/<name>``. Jobs run in the order listed; the last three
    use the ResNet runs.
    """
    out_dir = "data/dataPaper/"
    mc_dir = "data/datag2/monte/np_arr/"
    resnet_dir = "data/datag2/final/res/monte/np_arr/"

    # (plot function, source dir, file name, acc_idx, iter, keyword overrides)
    jobs = [
        (bitsFinalPlot, mc_dir,
         "MC_bit:it2001_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T12:55:14.453449.npy",
         25, 2001, {}),
        (bitsFinalPlot, mc_dir,
         "MC_bit:it100_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T06:27:55.659862.npy",
         1, 100, {}),
        (lossFinalPlot, mc_dir,
         "MC_loss:it2001_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T12:55:14.451323.npy",
         25, 2001, {"redux": 2}),
        (lossFinalPlot, mc_dir,
         "MC_loss:it100_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T06:27:55.649224.npy",
         1, 100, {"redux": 2}),
        (accFinalPlot, mc_dir,
         "MC_acc:it2001_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T12:55:14.452730.npy",
         25, 2001, {"redux": 3, "y_lims": (10, 85)}),
        (accFinalPlot, mc_dir,
         "MC_acc:it100_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T06:27:55.659149.npy",
         1, 100, {"redux": 1, "y_lims": [10, 95]}),
        # ResNet runs
        (bitsFinalPlot, resnet_dir,
         "MC_bit:it50_lr0.1_tr5e-05_lvl7_b200_e1dist_12023-09-26T20:44:20.450918.npy",
         1, 50, {}),
        (accFinalPlot, resnet_dir,
         "MC_acc:it50_lr0.1_tr5e-05_lvl7_b200_e1dist_12023-09-26T20:44:20.450165.npy",
         1, 50, {"redux": 1, "y_lims": [10, 70]}),
        (lossFinalPlot, resnet_dir,
         "MC_loss:it50_lr0.1_tr5e-05_lvl7_b200_e1dist_12023-09-26T20:44:20.447752.npy",
         1, 50, {"redux": 1}),
    ]

    for plot_fn, src_dir, name, acc_idx, n_iter, overrides in jobs:
        plot_fn(src_dir + name, out_dir + name, acc_idx, n_iter, **overrides)
    
    
def genPlot4by1():
    """Produce the four merged loss/accuracy figures at paper-sized fonts.

    Each call combines an earlier run (first path) with a later run (third
    path). The output figure names come from the "data/dataPaper/..." path.
    """
    paper_fonts = {"axis_font": 24, "axis_ticks": 20, "legend_font": 23}

    mergeAndPlotlossDif(
        "data/datag2/monte/np_arr/MC_loss:it2001_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T12:55:14.451323.npy",
        "data/dataPaper/MC_loss:it2001Combined.npy",
        "data/datag2/final/full/monte/np_arr/MC_loss:it1501_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-27T02:00:33.007601.npy",
        25, 1501, redux=2, **paper_fonts)

    # NOTE(review): "COmbined" capitalisation differs from the other output
    # names; it is kept because it determines the written file name.
    mergeAndPlotloss(
        "data/datag2/monte/np_arr/MC_loss:it100_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T06:27:55.649224.npy",
        "data/dataPaper/MC_loss:it100COmbined.npy",
        "data/datag2/final/full/monte/np_arr/MC_loss:it100_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-26T14:55:21.550666.npy",
        1, 100, redux=2, **paper_fonts)

    # NOTE(review): this source lives under "final/full/monte/np_arr", while
    # genplots() reads the same file name from "data/datag2/monte/np_arr/".
    # Confirm which directory actually holds it.
    mergeAndPlotAccDiff(
        "data/datag2/final/full/monte/np_arr/MC_acc:it2001_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T12:55:14.452730.npy",
        "data/dataPaper/MC_acc:it2001Combined.npy",
        "data/datag2/final/full/monte/np_arr/MC_acc:it1501_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-27T02:00:33.009035.npy",
        25, 1501, redux=3, y_lims=(0, 85), **paper_fonts)

    mergeAndPlotAcc(
        "data/datag2/monte/np_arr/MC_acc:it100_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-23T06:27:55.659149.npy",
        "data/dataPaper/MC_acc:it100Combined.npy",
        "data/datag2/final/full/monte/np_arr/MC_acc:it100_lr0.2_tr5e-05_lvl7_b200_e1dist_52023-09-26T14:55:21.558387.npy",
        1, 100, redux=1, y_lims=[0, 95], **paper_fonts)
    
def mergeAndPlotlossDif(og_path, out_path, new_choco_path, acc_idx, iter, redux=1,
                    y_lims=(10, 95), axis_font=24, axis_ticks=20, legend_font=23,
                    legend=("Choco-SGD", "Decentralized-SGD", "Malcom-PSGD")):
    """Merge two loss runs of possibly different lengths and plot them (EPS and JPG).

    From ``og_path``, row 1 is removed and rows 0 and 2 are kept. From
    ``new_choco_path``, rows 1 and 2 are removed and row 0 is kept. The
    original comments call the removed rows "choco" and "base".
    NOTE(review): the kept row of the new run is presumably the re-run
    Choco-SGD series, since it is drawn under the first legend label; confirm.

    Both arrays are thinned by ``redux`` and then trimmed to their common
    width before stacking. Stacked rows are drawn in the order 2, 1, 0; row 1
    is dashed black. Values are scaled by 1e-35 to match the y-axis label.

    Args:
        og_path: ``.npy`` file with a 2-D array indexed ``[series, sample]``.
        out_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension. A file name without a dot is used whole.
        new_choco_path: ``.npy`` file with at least three rows.
        acc_idx: Spacing, in iterations, between recorded samples.
        iter: Last iteration on the x-axis.
        redux: Keep every ``redux``-th sample (and x position).
        y_lims: Not used; accepted for call-site compatibility.
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.
        legend: Labels applied to the lines in drawing order (rows 2, 1, 0).

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    og_acc_arr = np.load(og_path)
    choco_acc = np.load(new_choco_path)

    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    og_acc_arr = og_acc_arr[:, ::redux]
    choco_acc = choco_acc[:, ::redux]

    og_acc_arr = np.delete(og_acc_arr, 1, axis=0)        # removes choco
    choco_acc = np.delete(choco_acc, [1, 2], axis=0)     # removes choco and base
    print(og_acc_arr.shape)
    print("choco: ", choco_acc.shape)

    # The runs can cover different numbers of samples, so trim both to the
    # shorter width. Slicing with [:, :-k] would give an empty array for
    # k == 0 and the wrong width when the new run is the longer one.
    width = min(og_acc_arr.shape[1], choco_acc.shape[1])
    og_acc_arr = og_acc_arr[:, :width]
    choco_acc = choco_acc[:, :width]

    # Drop the first sample so the data lines up with iter_range[1:].
    loss_arr = np.vstack((og_acc_arr, choco_acc))[:, 1:]
    x_vals = iter_range[1:]

    for res in [2, 1, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        y_vals = loss_arr[res] * 1e-35  # matches the "Loss(~1e35)" axis label
        if res == 1:
            plt.plot(x_vals, y_vals, linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, y_vals, linewidth=4.0)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Loss(~1e35)", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = out_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)
    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()

def mergeAndPlotAccDiff(og_path, out_path, new_choco_path, acc_idx, iter, redux=1,
                    y_lims=(10, 95), axis_font=24, axis_ticks=20, legend_font=23,
                    legend=("Choco-SGD", "Decentralized-SGD", "Malcom-PSGD")):
    """Merge two accuracy runs of possibly different lengths and plot them (EPS and JPG).

    From ``og_path``, row 1 is removed and rows 0 and 2 are kept. From
    ``new_choco_path``, rows 1 and 2 are removed and row 0 is kept. The
    original comments call the removed rows "choco" and "base".
    NOTE(review): the kept row of the new run is presumably the re-run
    Choco-SGD series, since it is drawn under the first legend label; confirm.

    Both arrays are thinned by ``redux`` and then trimmed to their common
    width before stacking. Stacked rows are drawn in the order 2, 1, 0; row 1
    is dashed black.

    Args:
        og_path: ``.npy`` file with a 2-D array indexed ``[series, sample]``.
        out_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension. A file name without a dot is used whole.
        new_choco_path: ``.npy`` file with at least three rows.
        acc_idx: Spacing, in iterations, between recorded samples.
        iter: Last iteration on the x-axis.
        redux: Keep every ``redux``-th sample (and x position).
        y_lims: y-axis limits passed to ``plt.ylim``.
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.
        legend: Labels applied to the lines in drawing order (rows 2, 1, 0).

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    og_acc_arr = np.load(og_path)
    choco_acc = np.load(new_choco_path)

    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    og_acc_arr = og_acc_arr[:, ::redux]
    choco_acc = choco_acc[:, ::redux]

    og_acc_arr = np.delete(og_acc_arr, 1, axis=0)        # removes choco
    choco_acc = np.delete(choco_acc, [1, 2], axis=0)     # removes choco and base
    print(choco_acc.shape)
    print(og_acc_arr.shape)

    # The runs can cover different numbers of samples, so trim both to the
    # shorter width. Slicing with [:, :-k] would give an empty array for
    # k == 0 and the wrong width when the new run is the longer one.
    width = min(og_acc_arr.shape[1], choco_acc.shape[1])
    og_acc_arr = og_acc_arr[:, :width]
    choco_acc = choco_acc[:, :width]

    # Drop the first sample so the data lines up with iter_range[1:].
    loss_arr = np.vstack((og_acc_arr, choco_acc))[:, 1:]
    x_vals = iter_range[1:]

    for res in [2, 1, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        if res == 1:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0)

    plt.ylim(y_lims)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Test Accuracy", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = out_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)
    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()

def mergeAndPlotAcc(og_path, out_path, new_choco_path, acc_idx, iter, redux=1,
                    y_lims=(10, 95), axis_font=24, axis_ticks=20, legend_font=23,
                    legend=("Choco-SGD", "Decentralized-SGD", "Malcom-PSGD")):
    """Merge two equal-length accuracy runs and plot them (EPS and JPG).

    From ``og_path``, row 1 is removed and rows 0 and 2 are kept. From
    ``new_choco_path``, rows 1 and 2 are removed and row 0 is kept. The
    original comments call the removed rows "choco" and "base".
    NOTE(review): the kept row of the new run is presumably the re-run
    Choco-SGD series, since it is drawn under the first legend label; confirm.

    Both runs are thinned by ``redux`` so their widths match each other and
    the x-axis. The runs must otherwise have the same number of samples, or
    ``np.vstack`` raises. Stacked rows are drawn in the order 2, 1, 0; row 1
    is dashed black.

    Args:
        og_path: ``.npy`` file with a 2-D array indexed ``[series, sample]``.
        out_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension. A file name without a dot is used whole.
        new_choco_path: ``.npy`` file with at least three rows.
        acc_idx: Spacing, in iterations, between recorded samples.
        iter: Last iteration on the x-axis.
        redux: Keep every ``redux``-th sample (and x position).
        y_lims: y-axis limits passed to ``plt.ylim``.
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.
        legend: Labels applied to the lines in drawing order (rows 2, 1, 0).

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    og_acc_arr = np.load(og_path)
    choco_acc = np.load(new_choco_path)

    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    # Thin BOTH runs. Thinning only the first one made vstack fail for redux > 1.
    og_acc_arr = og_acc_arr[:, ::redux]
    choco_acc = choco_acc[:, ::redux]

    og_acc_arr = np.delete(og_acc_arr, 1, axis=0)        # removes choco
    choco_acc = np.delete(choco_acc, [1, 2], axis=0)     # removes choco and base

    # Drop the first sample so the data lines up with iter_range[1:].
    loss_arr = np.vstack((og_acc_arr, choco_acc))[:, 1:]
    x_vals = iter_range[1:]

    for res in [2, 1, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        if res == 1:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, loss_arr[res], linewidth=4.0)

    plt.ylim(y_lims)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Test Accuracy", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = out_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)
    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()

def mergeAndPlotloss(og_path, out_path, new_choco_path, acc_idx, iter, redux=1,
                    y_lims=(10, 95), axis_font=24, axis_ticks=20, legend_font=23,
                    legend=("Choco-SGD", "Decentralized-SGD", "Malcom-PSGD")):
    """Merge two equal-length loss runs and plot them (EPS and JPG).

    From ``og_path``, row 1 is removed and rows 0 and 2 are kept. From
    ``new_choco_path``, rows 1 and 2 are removed and row 0 is kept. The
    original comments call the removed rows "choco" and "base".
    NOTE(review): the kept row of the new run is presumably the re-run
    Choco-SGD series, since it is drawn under the first legend label; confirm.

    Both runs are thinned by ``redux``. They must then have the same number
    of samples, or ``np.vstack`` raises. Stacked rows are drawn in the order
    2, 1, 0; row 1 is dashed black. Values are scaled by 1e-24 to match the
    y-axis label.

    Args:
        og_path: ``.npy`` file with a 2-D array indexed ``[series, sample]``.
        out_path: Path whose file-name extension is stripped to form the
            output stem. Dots inside the file name are dropped along with the
            extension. A file name without a dot is used whole.
        new_choco_path: ``.npy`` file with at least three rows.
        acc_idx: Spacing, in iterations, between recorded samples.
        iter: Last iteration on the x-axis.
        redux: Keep every ``redux``-th sample (and x position).
        y_lims: Not used; accepted for call-site compatibility.
        axis_font: Font size of the axis labels.
        axis_ticks: Font size of the tick labels.
        legend_font: Font size of the legend.
        legend: Labels applied to the lines in drawing order (rows 2, 1, 0).

    Writes ``<stem><axis_font>_<axis_ticks>_<legend_font>.eps`` and ``.jpg``
    at 1200 dpi.
    """
    og_acc_arr = np.load(og_path)
    choco_acc = np.load(new_choco_path)

    iter_range = np.arange(1, iter + 1, acc_idx * redux)
    og_acc_arr = og_acc_arr[:, ::redux]
    choco_acc = choco_acc[:, ::redux]

    og_acc_arr = np.delete(og_acc_arr, 1, axis=0)        # removes choco
    choco_acc = np.delete(choco_acc, [1, 2], axis=0)     # removes choco and base
    print(og_acc_arr.shape)
    print("choco: ", choco_acc.shape)

    # Drop the first sample so the data lines up with iter_range[1:].
    loss_arr = np.vstack((og_acc_arr, choco_acc))[:, 1:]
    x_vals = iter_range[1:]

    for res in [2, 1, 0]:
        # Shapes are printed to help diagnose x/y length mismatches.
        print(x_vals.shape)
        print(loss_arr[res].shape)
        y_vals = loss_arr[res] * 1e-24  # matches the "Loss(~1e24)" axis label
        if res == 1:
            plt.plot(x_vals, y_vals, linewidth=4.0, linestyle='--', color='k')
        else:
            plt.plot(x_vals, y_vals, linewidth=4.0)

    plt.xlabel("Number of Iterations", fontsize=axis_font)
    plt.ylabel("Loss(~1e24)", fontsize=axis_font)
    plt.xticks(fontsize=axis_ticks)
    plt.yticks(fontsize=axis_ticks)
    plt.xticks(rotation=45)
    plt.tight_layout()

    # Strip the extension from the file-name part only. Dots in directory
    # names (e.g. "./out/") survive, and an extension-less name is kept whole
    # instead of collapsing to "".
    head, sep, filename = out_path.rpartition("/")
    name_parts = filename.split(".")
    stem = "".join(name_parts[:-1]) if len(name_parts) > 1 else filename
    jpg_save_path = head + sep + stem

    size_str = str(axis_font) + "_" + str(axis_ticks) + "_" + str(legend_font)
    plt.legend(list(legend), loc="best", fontsize=legend_font)
    plt.savefig(jpg_save_path + size_str + ".eps", dpi=1200, format='eps')
    plt.savefig(jpg_save_path + size_str + ".jpg", dpi=1200, format='jpg')
    plt.clf()