import os
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
import argparse
from utils import (make_dir, plot_sparce_vs_dense_cluster_adj,
                    getData)
from models import AGC


# Command-line options for the plotting entry point below.
# --plot_load_dir / --plot_save_dir default to the "v13Random" run. The
# earlier "v12Guided" run used these locations instead:
#   load: Results_mat_v12Guided_guidedminRandomKnnGraph60
#   save: Paper_plot/v12Guided_guidedminRandomKnnGraph60
parser = argparse.ArgumentParser(description='manual to this script')
for _flag, _type, _default in (
    ('--dataset', str, 'cora'),
    ('--fsize', int, 15),
    ('--max_num_labels', int, 5),
    ('--plot_load_dir', str, 'Results_mat_v13Random_guidedminRandomKnnGraph60'),
    ('--plot_save_dir', str, 'Paper_plot/v13Random_guidedminRandomKnnGraph60'),
):
    parser.add_argument(_flag, type=_type, default=_default)
del _flag, _type, _default



def _stats_load_curve(acc_save, dataset, acc_key, std_key, num_points):
    """Return ``(mean, std)`` arrays for one method from a loaded results dict.

    Reads keys ``<dataset><acc_key>`` and ``<dataset><std_key>`` and keeps
    row 0, first ``num_points`` columns of each.
    """
    mean = np.array(acc_save[dataset + acc_key][0, :num_points])
    std = np.array(acc_save[dataset + std_key][0, :num_points])
    return mean, std


def _stats_draw_curve(x, curve, fmt, label, plot_std):
    """Plot one ``(mean, std)`` curve, shading +/- std when ``plot_std`` is set.

    The band colour is the first character of the matplotlib format string
    ``fmt`` (e.g. ``'b'`` for ``'b-*'``), so band and line always match.
    """
    y, error = curve
    if plot_std:
        plt.fill_between(x, y - error, y + error, alpha=0.2, color=fmt[0])
    plt.plot(x, y, fmt, linewidth=2, label=label)


def _stats_finish_and_save(fig, num_labels, dataset, save_dir, plot_std,
                           fsize, xfsize):
    """Apply the shared accuracy-plot styling, save ``fig`` as PNG + PDF, close it."""
    plt.xlabel('Number of labeled nodes used for training', fontsize=fsize)
    plt.ylabel('Accuracy', fontsize=fsize)
    plt.xlim([num_labels[0], num_labels[-1]])
    plt.ylim([0, 1])
    plt.xticks(num_labels.astype(int), fontsize=xfsize)
    plt.yticks(fontsize=fsize)
    plt.title(f'{dataset}', fontsize=fsize)
    plt.legend(fontsize=int(fsize*0.8), loc=4)
    # tight_layout only has an effect once axes, labels and ticks exist.
    fig.tight_layout()

    make_dir(save_dir)
    fig.savefig(os.path.join(save_dir, f'std{plot_std}_{dataset}.png'))
    fig.savefig(os.path.join(save_dir, f'std{plot_std}_{dataset}.pdf'))
    # This runs once per dataset per figure type; without an explicit close
    # every pyplot figure stays alive until the process exits.
    plt.close(fig)


def StatsSavingPloting(plot_save_dir, plot_load_dir, NUM_LABELS=5, 
                       plot_std=False, plot_experiments=True, 
                       appendix_plot=False, plot_begining_demo=False,
                       AGC_plot=False, fsize=5, xfsize=8):
    """Plot few-shot accuracy curves for every dataset and save them to disk.

    For each dataset, loads ``<plot_load_dir>/fewShot_SGC_<dataset>.mat``. Each
    curve is plotted against ``(1..NUM_LABELS) * num_classes`` labeled nodes.
    Every enabled figure type is written as
    ``<plot_save_dir>/<subdir>/std<plot_std>_<dataset>.{png,pdf}``.

    Args:
        plot_save_dir: Root directory for the output figures.
        plot_load_dir: Directory containing the ``fewShot_SGC_*.mat`` files.
        NUM_LABELS: Number of points (labels per class) to plot.
        plot_std: Shade a +/- std band around each curve (not applied to AGC).
        plot_experiments: Plot SGC, LP, LP-ELI, SGC-ELI, GMI and DGI into
            ``Expeririments``.
        appendix_plot: Same curves plus AGC, saved to ``Appendix`` (only when
            ``plot_experiments`` is False; otherwise the AGC curve is added to
            the ``Expeririments`` figure).
        plot_begining_demo: Plot SGC, LP, GMI and DGI into ``Demo``.
        AGC_plot: Plot the AGC curve alone into ``Appendix_AGC``.
        fsize: Font size for labels, y-ticks and title (legend uses 80%).
        xfsize: Font size for the x-tick labels.
    """
    print('===============================================')

    dataset_list = ['cora', 'pubmed', 'photo', 'cs', 'citeseer', 'wiki', 'computers']
    num_classes = [7, 3, 8, 15, 6, 17, 10]
    for dataset, n_cls in zip(dataset_list, num_classes):
        acc_save = sio.loadmat(os.path.join(plot_load_dir,
                                            f'fewShot_SGC_{dataset}.mat'))
        num_labels = np.array(range(1, NUM_LABELS+1))*n_cls

        print(f'[INFO] dataset: {dataset}, num_clases: {n_cls}, ' 
              f'num_labels_for_plot: {num_labels.max()//n_cls}, '
              f'min_num_labels_in_plot: {num_labels.min()//n_cls} ,'
              f'min number in num_labels: {num_labels.min()}')

        def load(acc_key, std_key):
            return _stats_load_curve(acc_save, dataset, acc_key, std_key,
                                     NUM_LABELS)

        sgc = load('acc_SGC', 'std_SGC')
        agc = load('acc_AGC', 'std_AGC')
        lp = load('acc_LP', 'std_LP')
        lp_eli = load('acc_LP_AGC', 'std_LP_AGC')
        sgc_eli = load('acc_SGC_AGC', 'std_SGC_AGC')
        # DGI/GMI means are stored under an 'ac_' (not 'acc_') prefix.
        dgi = load('ac_DGI', 'std_DGI')
        gmi = load('ac_GMI', 'std_GMI')

        if AGC_plot:
            save_name = "Appendix_AGC"
            print(f"[{save_name.upper()}] ploting {dataset} for experiments with STD: {plot_std}")
            fig = plt.figure()
            # AGC is always drawn without a std band.
            _stats_draw_curve(num_labels, agc, 'k-*', 'AGC', plot_std=False)
            _stats_finish_and_save(fig, num_labels, dataset,
                                   os.path.join(plot_save_dir, save_name),
                                   plot_std, fsize, xfsize)

        if plot_experiments or appendix_plot:
            # 'Expeririments' spelling is kept so existing output paths still match.
            save_name = 'Expeririments' if plot_experiments else 'Appendix'
            print(f"[{save_name.upper()}] ploting {dataset} for experiments with STD: {plot_std}")
            fig = plt.figure()
            if appendix_plot:
                _stats_draw_curve(num_labels, agc, 'k-*', 'AGC', plot_std=False)
            for curve, fmt, label in ((sgc, 'b-*', 'SGC'),
                                      (lp, 'g-*', 'LP'),
                                      (lp_eli, 'y-*', 'LP-ELI'),
                                      (sgc_eli, 'm-*', 'SGC-ELI'),
                                      (gmi, 'c-*', 'GMI'),
                                      (dgi, 'r-*', 'DGI')):
                _stats_draw_curve(num_labels, curve, fmt, label, plot_std)
            _stats_finish_and_save(fig, num_labels, dataset,
                                   os.path.join(plot_save_dir, save_name),
                                   plot_std, fsize, xfsize)

        if plot_begining_demo:
            fig = plt.figure()
            for curve, fmt, label in ((sgc, 'b-*', 'SGC'),
                                      (lp, 'g-*', 'LP'),
                                      (gmi, 'c-*', 'GMI'),
                                      (dgi, 'r-*', 'DGI')):
                _stats_draw_curve(num_labels, curve, fmt, label, plot_std)
            _stats_finish_and_save(fig, num_labels, dataset,
                                   os.path.join(plot_save_dir, "Demo"),
                                   plot_std, fsize, xfsize)


def plot_AGC(plot_save_dir, plot_load_dir, fsize=10):
    """Bar-plot the AGC accuracy of every dataset in a single figure.

    For each dataset, reads ``<plot_load_dir>/fewShot_SGC_<dataset>.mat`` and
    takes element ``[0, 0]`` of ``<dataset>acc_AGC``. The chart is saved as
    ``<plot_save_dir>/Appendix_AGC/AGC_ALL.{png,pdf}``.

    Args:
        plot_save_dir: Root directory for the output figure.
        plot_load_dir: Directory containing the ``fewShot_SGC_*.mat`` files.
        fsize: Font size for labels, ticks and title.
    """
    print('===============================================')
    save_name = "Appendix_AGC"
    dataset_list = ['cora', 'pubmed', 'photo', 'cs', 'citeseer', 'wiki', 'computers']
    num_classes = [7, 3, 8, 15, 6, 17, 10]

    acc_AGC_list = []
    for dataset, n_cls in zip(dataset_list, num_classes):
        acc_save = sio.loadmat(os.path.join(plot_load_dir,
                                            f'fewShot_SGC_{dataset}.mat'))
        print(f'[INFO] dataset: {dataset}, num_clases: {n_cls}')
        acc_AGC_list.append(acc_save[dataset + 'acc_AGC'][0, 0])

    fig = plt.figure()
    plt.bar(dataset_list, acc_AGC_list, color='maroon', width=0.4)

    plt.xlabel('Datasets', fontsize=fsize)
    # NOTE(review): label says "%" but ylim is [0, 1]; values look like
    # fractions — confirm and adjust the label if so.
    plt.ylabel('% Clustering Accuracy', fontsize=fsize)
    plt.ylim([0, 1])
    plt.xticks(fontsize=fsize)
    plt.yticks(fontsize=fsize)
    plt.title('Modified AGC clustering performance', fontsize=fsize)
    # tight_layout must run after the content exists to have any effect.
    fig.tight_layout()

    save_dir = os.path.join(plot_save_dir, save_name)
    make_dir(save_dir)
    # PNG and PDF share the same file stem.
    fig.savefig(os.path.join(save_dir, 'AGC_ALL.png'))
    fig.savefig(os.path.join(save_dir, 'AGC_ALL.pdf'))
    plt.close(fig)


def optimized_Graph(dataset, fsize=12):
    """Run AGC once on ``dataset`` and plot the sparse vs. dense cluster adjacency.

    Loads the data with ``getData`` and calls ``AGC`` with ``num_NN=60``,
    ``tol=0.001``, ``pow=60``, ``num_runs=1`` and ``norm=None``. The
    predicted labels and ``KG`` that AGC returns are passed to
    ``plot_sparce_vs_dense_cluster_adj``.
    """
    adj, gnd, n_clusters, feature = getData(dataset)

    # AGC returns 8 values; only the 6th (predicted labels) and the 8th (KG)
    # are used for the plot.
    (_ac_agc, _, _, _kmeans, _u, pred_labels, _best_p, kg) = AGC(
        adj, feature, gnd,
        num_NN=60,
        tol=0.001,
        pow=60,
        k=n_clusters,
        num_runs=1,
        norm=None,
    )

    plot_sparce_vs_dense_cluster_adj(pred_labels, adj, KG=kg, use_knn=True,
                                     fsize=fsize)
 

def _sa_band_plot(x, y, error, fmt, label, linewidth):
    """Plot ``y`` vs ``x`` with a shaded +/- ``error`` band in the line's colour."""
    plt.fill_between(x, y - error, y + error, alpha=0.2, color=fmt[0])
    plt.plot(x, y, fmt, linewidth=linewidth, label=label)


def _sa_save_and_close(fig, save_dir, file_stem):
    """Write ``fig`` to ``<save_dir>/<file_stem>.{png,pdf}`` and close it."""
    make_dir(save_dir)
    fig.savefig(os.path.join(save_dir, f'{file_stem}.png'))
    fig.savefig(os.path.join(save_dir, f'{file_stem}.pdf'))
    plt.close(fig)


def _sa_plot_sensitivity(plot_load_dir, plot_save_dir, dataset,
                         lab_sensitivity, neigh_sensitivity, fsize):
    """Plot LP-ELI / SGC-ELI accuracy against the number of neighbours.

    Draws one pair of curves per entry of ``lab_sensitivity``, read from
    ``<plot_load_dir>/Sensitivity_<dataset>.mat``.
    """
    acc_save = sio.loadmat(os.path.join(plot_load_dir,
                                        f'Sensitivity_{dataset}.mat'))

    print("======================================")
    print(f"[INFO] PLOTTING SENSITIVITY {dataset}")

    fig = plt.figure()
    x = np.array(neigh_sensitivity)
    formats = ["b-*", "r-*", "y-*", "m-*", "c-*", "g-*", "k-*"]
    for idx, num_lab in enumerate(lab_sensitivity):
        # Each label count takes the next two formats. Wrap around rather
        # than running out when more than three label counts are requested.
        lp_fmt = formats[(2 * idx) % len(formats)]
        sgc_fmt = formats[(2 * idx + 1) % len(formats)]
        prefix = f'num_lab{num_lab}_sensitivity_'

        _sa_band_plot(x,
                      np.array(acc_save[prefix + 'LP_AGC_acc'][0, :]),
                      np.array(acc_save[prefix + 'LP_AGC_std'][0, :]),
                      lp_fmt, f'LP-ELI-num_lab{num_lab}', linewidth=1)
        _sa_band_plot(x,
                      np.array(acc_save[prefix + 'SGC_AGC_acc'][0, :]),
                      np.array(acc_save[prefix + 'SGC_AGC_std'][0, :]),
                      sgc_fmt, f'SGC-ELI-num_lab{num_lab}', linewidth=1)

    plt.xlabel('Number of neighbors used', fontsize=fsize)
    plt.ylabel('Accuracy', fontsize=fsize)
    plt.xlim([neigh_sensitivity[0], neigh_sensitivity[-1]])
    plt.xticks(list(neigh_sensitivity), fontsize=10)
    plt.ylim([0, 1])
    plt.yticks(fontsize=fsize)
    plt.title(f'{dataset}', fontsize=fsize)
    plt.legend(loc=4)

    # PNG and PDF share one file stem.
    _sa_save_and_close(fig, os.path.join(plot_save_dir, "Sensitivity_plot"),
                       f'Sensitivity_{dataset}')


def _sa_plot_ablation(plot_load_dir, plot_save_dir, dataset, num_labels,
                      NUM_LABELS, fsize):
    """Plot the LP-ELI and SGC-ELI ablation figures (one figure per model).

    The "no KG" and "no KL" variants come from ``Ablation_<dataset>.mat``. The
    full model and the "no KL no KG" baseline (plain LP / SGC) come from
    ``fewShot_SGC_<dataset>.mat``.
    """
    ablation = sio.loadmat(os.path.join(plot_load_dir,
                                        f'Ablation_{dataset}.mat'))
    few_shot = sio.loadmat(os.path.join(plot_load_dir,
                                        f'fewShot_SGC_{dataset}.mat'))

    def first_row(mat, key):
        # Row 0, first NUM_LABELS columns, as stored by the experiment scripts.
        return np.array(mat[key][0, :NUM_LABELS])

    x = np.array(num_labels)
    for model in ('LP', 'SGC'):
        # (source file, mean key, std key, format, legend label) per variant.
        variants = (
            (ablation, f'no_kg_{model}_AGC_acc', f'no_kg_{model}_AGC_std',
             'y-*', f'no KG {model}-ELI'),
            (ablation, f'no_KL_{model}_AGC_acc', f'no_KL_{model}_AGC_std',
             'm-*', f'no KL {model}-ELI'),
            (few_shot, f'{dataset}acc_{model}_AGC', f'{dataset}std_{model}_AGC',
             'g-*', f'full {model}-ELI'),
            (few_shot, f'{dataset}acc_{model}', f'{dataset}std_{model}',
             'b-*', f'no KL no KG {model}-ELI'),
        )

        print("======================================")
        print(f"[INFO] PLOTTING ABLATION {model} {dataset}")
        fig = plt.figure()
        for source, acc_key, std_key, fmt, label in variants:
            _sa_band_plot(x, first_row(source, acc_key),
                          first_row(source, std_key), fmt, label, linewidth=2)

        plt.xlabel('Number of labeled nodes used for training', fontsize=fsize)
        plt.ylabel('Accuracy', fontsize=fsize)
        plt.xlim([num_labels[0], num_labels[-1]])
        plt.xticks(num_labels.astype(int), fontsize=8)
        plt.ylim([0, 1])
        plt.yticks(fontsize=fsize)
        plt.title(f'{dataset}', fontsize=fsize)
        plt.legend(loc=4)

        _sa_save_and_close(fig, os.path.join(plot_save_dir, "Ablation_plot"),
                           f'{model}_Ablation_{dataset}')


def sensitivity_and_ablation(plot_load_dir, plot_save_dir, 
                             lab_sensitivity=(1, 5, 10),
                             NUM_LABELS=5, neigh_sensitivity=(1, 20, 40, 60, 80, 100), 
                             fsize=12):
    """Produce the neighbour-sensitivity and ablation figures for Cora.

    Writes three figures:
    * ``<plot_save_dir>/Sensitivity_plot/Sensitivity_cora.{png,pdf}``
    * ``<plot_save_dir>/Ablation_plot/LP_Ablation_cora.{png,pdf}``
    * ``<plot_save_dir>/Ablation_plot/SGC_Ablation_cora.{png,pdf}``

    Args:
        plot_load_dir: Directory holding ``Sensitivity_cora.mat``,
            ``Ablation_cora.mat`` and ``fewShot_SGC_cora.mat``.
        plot_save_dir: Root directory for the output figures.
        lab_sensitivity: Label counts whose sensitivity curves are plotted.
        NUM_LABELS: Number of points (labels per class) in the ablation plots.
        neigh_sensitivity: Neighbour counts on the sensitivity x-axis; must
            match the length of the stored sensitivity rows.
        fsize: Font size for labels, y-ticks and titles.
    """
    dataset = "cora"
    n_classes = 7
    num_labels = np.array(range(1, NUM_LABELS+1))*n_classes

    print(f'[INFO] dataset: {dataset}, num_clases: {n_classes}, ' 
            f'num_labels_for_plot: {num_labels.max()//n_classes}, '
            f'min_num_labels_in_plot: {num_labels.min()//n_classes} ,'
            f'min number in num_labels: {num_labels.min()}')

    _sa_plot_sensitivity(plot_load_dir, plot_save_dir, dataset,
                         lab_sensitivity, neigh_sensitivity, fsize)
    _sa_plot_ablation(plot_load_dir, plot_save_dir, dataset, num_labels,
                      NUM_LABELS, fsize)


if __name__ == "__main__":
    cli_args = parser.parse_args()

    # Main-paper experiment figures: the only run enabled by default.
    StatsSavingPloting(
        cli_args.plot_save_dir,
        cli_args.plot_load_dir,
        NUM_LABELS=4,
        plot_std=True,
        plot_experiments=True,
        appendix_plot=False,
        plot_begining_demo=False,
        AGC_plot=False,
        fsize=cli_args.fsize,
    )

    # Other figure sets used for the paper; uncomment the one you need.
    #
    # Appendix: AGC bar chart over all datasets
    #   plot_AGC(cli_args.plot_save_dir, cli_args.plot_load_dir, fsize=12)
    #
    # Appendix: sparse vs. dense cluster adjacency
    #   optimized_Graph(cli_args.dataset, fsize=12)
    #
    # Demo figure
    #   StatsSavingPloting(cli_args.plot_save_dir, cli_args.plot_load_dir, NUM_LABELS=5,
    #                      plot_std=True, plot_experiments=False,
    #                      appendix_plot=False, plot_begining_demo=True,
    #                      AGC_plot=False, fsize=cli_args.fsize)
    #
    # Appendix: full curves including AGC
    #   StatsSavingPloting(cli_args.plot_save_dir, cli_args.plot_load_dir, NUM_LABELS=19,
    #                      plot_std=True, plot_experiments=False,
    #                      appendix_plot=True, plot_begining_demo=False,
    #                      AGC_plot=True, fsize=12, xfsize=8)
    #
    # Appendix: sensitivity and ablation
    #   sensitivity_and_ablation('Results_mat_v13Random_guidedminRandomKnnGraph60',
    #                            'Paper_plot/v13Random_guidedminRandomKnnGraph60',
    #                            NUM_LABELS=19, fsize=15)