import numpy as np
import argparse
import scipy.io as sio
# import time
import os
import scipy as spy
import matplotlib.pyplot as plt
from utils import (make_dir, plot_sparce_vs_dense_cluster_adj,
                   getData, random_generate_masks,
                   unsupervised_guided_split)
from models import AGC, SGC_AGC, lp_agc


# Command-line interface of the script.  Each entry below is
# (flag, value type, default); the flags are registered in table order.
parser = argparse.ArgumentParser(description='manual to this script')
_CLI_OPTIONS = (
    ('--dataset', str, 'cora'),
    ('--fsize', int, 15),
    ('--num_runs', int, 10),
    ('--max_num_ablation_labels', int, 20),
    ('--plot_load_dir', str,
     'Results_mat_v12Guided_guidedminRandomKnnGraph60'),
    ('--plot_save_dir', str,
     'Paper_plot/v12Guided_guidedminRandomKnnGraph60'),
)
for _opt_flag, _opt_type, _opt_default in _CLI_OPTIONS:
    parser.add_argument(_opt_flag, type=_opt_type, default=_opt_default)


def ablation(dataset, max_num_labels_ablation=20,
             num_runs=10, fsize=12):
    """Run the label-budget ablation for one dataset, then save and plot it.

    For every label budget in ``1 .. max_num_labels_ablation - 1`` the two
    few-shot models (``lp_agc`` and ``SGC_AGC``) are each evaluated
    ``num_runs`` times in two ablated settings:

    * "no KG": masks come from ``unsupervised_guided_split`` and the models
      are called with ``knngraph=None, use_aug_graph=False``;
    * "no KL": masks come from ``random_generate_masks`` and the models are
      called with ``knngraph=KG`` (the graph returned by ``AGC``).

    The first value returned by each model call is stored as the run's
    accuracy.  Per-budget mean and standard deviation are written to
    ``Ablation_mat/Ablation_<dataset>.mat`` and drawn as mean curves with a
    +/- one-std band in ``Ablation_plot/Ablation_<dataset>.png`` / ``.pdf``.

    Args:
        dataset: Identifier forwarded to ``getData``.
        max_num_labels_ablation: Exclusive upper bound of the label budgets
            (budgets are ``np.arange(1, max_num_labels_ablation, 1)``).
        num_runs: Number of repetitions per budget and setting.
        fsize: Accepted for backward compatibility; not used by this function.

    Returns:
        None.  Results are written to disk.
    """
    adj, gnd, k, feature = getData(dataset)

    # Fixed experiment settings.
    min_num_per_clus = 1
    type_random = "random"
    type_guided = 'guided-min-random'
    SGC_ACG_power = 5
    LP_agc_power = 60
    alpha_few_shot_lp = [0.33, 0.33, 0.33]
    alpha_few_shot_sgc = [0.33, 0.33, 0.33]
    topK_sgc = 1
    topK_lp = 1
    lp_aplha = 0.5

    # AGC is run once up front; only its clustering outputs and the graph KG
    # are consumed below, the remaining return values are discarded.
    _, _, _, kmeans_, u_, prelab, _, KG = AGC(adj, feature, gnd,
                                              num_NN=60,
                                              tol=0.001,
                                              pow=60,
                                              k=k, num_runs=1,
                                              norm=None)
    num_labels = np.arange(1, max_num_labels_ablation, 1)

    # Per-budget mean / std of the accuracies, one entry per label budget.
    # The "full" (non-ablated) variant is not computed here; per the original
    # note it is taken from separately saved sparse-label results.
    no_kg_SGC_AGC_acc, no_kg_SGC_AGC_std = [], []
    no_KL_SGC_AGC_acc, no_KL_SGC_AGC_std = [], []
    no_kg_LP_AGC_acc, no_kg_LP_AGC_std = [], []
    no_KL_LP_AGC_acc, no_KL_LP_AGC_std = [], []

    for num_lab in num_labels:
        ac_LP_AGC_no_KG = np.zeros(num_runs)
        ac_LP_AGC_no_KL = np.zeros(num_runs)
        ac_SGC_AGC_no_KG = np.zeros(num_runs)
        ac_SGC_AGC_no_KL = np.zeros(num_runs)

        for r in range(num_runs):
            print("========================================================")
            print(f"[WITH KL GENERATE LABELS] num_labels: {num_lab}")
            print("========================================================")
            train_mask, test_mask = unsupervised_guided_split(kmeans_, prelab, u_, gnd,
                                                              num_lab=num_lab,
                                                              type_split=type_guided,
                                                              min_num=min_num_per_clus,
                                                              type_random=type_random)
            print("======================")
            print("[NO KG WITH KL LP_AGC]")
            ac_LP_AGC_no_KG[r], _, _, _ = lp_agc(kmeans_,
                                                 prelab, u_, adj, gnd,
                                                 train_mask, test_mask,
                                                 lp_prop_num=LP_agc_power,
                                                 lp_p=0.6, lp_alpha=lp_aplha,
                                                 alpha=alpha_few_shot_lp,
                                                 keep=topK_lp, knngraph=None,
                                                 modify_knn=False,
                                                 use_aug_graph=False)  # v1 used best_p

            print("======================")
            print("[NO KG WITH KL SGC_AGC]")
            # NOTE(review): SGC_AGC receives (test_mask, train_mask) whereas
            # lp_agc receives (train_mask, test_mask) -- confirm this matches
            # the SGC_AGC signature in models.
            ac_SGC_AGC_no_KG[r], _, _, _ = SGC_AGC(kmeans_, prelab, u_,
                                                   feature, adj, gnd,
                                                   test_mask, train_mask,
                                                   prop_num=SGC_ACG_power,
                                                   alpha=alpha_few_shot_sgc,
                                                   keep=topK_sgc,
                                                   message='only', knngraph=None,
                                                   modify_knn=False,
                                                   use_aug_graph=False)  # v1 used best_p

            print("========================================================")
            print(f"[NO KL GENERATE LABELS] num_labels: {num_lab}")
            print("========================================================")
            train_mask, test_mask = random_generate_masks(gnd, num_for_train=num_lab,
                                                          type_generate=type_random)
            print("======================")
            print("[NO KL WITH KG LP_AGC]")
            # NOTE(review): use_aug_graph is left at the callee's default in
            # the two "no KL" calls -- confirm the default actually uses KG.
            ac_LP_AGC_no_KL[r], _, _, _ = lp_agc(kmeans_,
                                                 prelab, u_, adj, gnd,
                                                 train_mask, test_mask,
                                                 lp_prop_num=LP_agc_power,
                                                 lp_p=0.6, lp_alpha=lp_aplha,
                                                 alpha=alpha_few_shot_lp,
                                                 keep=topK_lp, knngraph=KG,
                                                 modify_knn=False)  # v1 used best_p

            print("======================")
            print("[NO KL WITH KG SGC_AGC]")
            ac_SGC_AGC_no_KL[r], _, _, _ = SGC_AGC(kmeans_, prelab, u_,
                                                   feature, adj, gnd,
                                                   test_mask, train_mask,
                                                   prop_num=SGC_ACG_power,
                                                   alpha=alpha_few_shot_sgc,
                                                   keep=topK_sgc,
                                                   message='only', knngraph=KG,
                                                   modify_knn=False)  # v1 used best_p

        no_kg_LP_AGC_acc.append(np.mean(ac_LP_AGC_no_KG))
        no_KL_LP_AGC_acc.append(np.mean(ac_LP_AGC_no_KL))
        no_kg_LP_AGC_std.append(np.std(ac_LP_AGC_no_KG))
        no_KL_LP_AGC_std.append(np.std(ac_LP_AGC_no_KL))

        no_kg_SGC_AGC_acc.append(np.mean(ac_SGC_AGC_no_KG))
        no_KL_SGC_AGC_acc.append(np.mean(ac_SGC_AGC_no_KL))
        no_kg_SGC_AGC_std.append(np.std(ac_SGC_AGC_no_KG))
        no_KL_SGC_AGC_std.append(np.std(ac_SGC_AGC_no_KL))

    save_dict = {
        'no_kg_SGC_AGC_acc': no_kg_SGC_AGC_acc,
        'no_KL_SGC_AGC_acc': no_KL_SGC_AGC_acc,
        'no_kg_SGC_AGC_std': no_kg_SGC_AGC_std,
        'no_KL_SGC_AGC_std': no_KL_SGC_AGC_std,
        'no_kg_LP_AGC_acc': no_kg_LP_AGC_acc,
        'no_KL_LP_AGC_acc': no_KL_LP_AGC_acc,
        'no_kg_LP_AGC_std': no_kg_LP_AGC_std,
        'no_KL_LP_AGC_std': no_KL_LP_AGC_std,
    }

    data_save_dir = "Ablation_mat"
    make_dir(data_save_dir)
    sio.savemat(os.path.join(data_save_dir, f'Ablation_{dataset}.mat'), save_dict)

    # One mean curve plus a +/- one-std band per ablated setting:
    # (mean accuracies, std deviations, matplotlib colour code, legend label).
    curves = (
        (no_kg_LP_AGC_acc, no_kg_LP_AGC_std, 'b', 'no KG LP'),
        (no_KL_LP_AGC_acc, no_KL_LP_AGC_std, 'g', 'no KL LP'),
        (no_kg_SGC_AGC_acc, no_kg_SGC_AGC_std, 'y', 'no KG SGC'),
        (no_KL_SGC_AGC_acc, no_KL_SGC_AGC_std, 'm', 'no KL SGC'),
    )
    fig = plt.figure()
    x = np.array(num_labels)
    for mean_acc, std_acc, colour, label in curves:
        y, error = np.array(mean_acc), np.array(std_acc)
        plt.fill_between(x, y - error, y + error, alpha=0.2, color=colour)
        plt.plot(x, y, f'{colour}-*', linewidth=1, label=label)

    plt.xlabel('Number of labels per class')
    plt.ylabel('% Accuracy')
    plt.xlim([num_labels[0], num_labels[-1]])
    plt.ylim([0, 1])
    plt.title(f'{dataset}')
    plt.legend(loc=4)

    plot_save_dir = "Ablation_plot"
    make_dir(plot_save_dir)
    plt.savefig(os.path.join(plot_save_dir, f'Ablation_{dataset}.png'))
    plt.savefig(os.path.join(plot_save_dir, f'Ablation_{dataset}.pdf'))
    # pyplot keeps every figure alive until it is closed explicitly; release
    # it now that both files have been written.
    plt.close(fig)


def sensitivity(dataset, lab_sensitivity=(1, 5, 10),
                neigh_sensitivity=(1, 20, 40, 60, 80, 100),
                num_runs=10, fsize=12):
    """Measure how accuracy varies with the ``num_NN`` value given to ``AGC``.

    For every label budget in ``lab_sensitivity`` and every neighbour count
    in ``neigh_sensitivity``, ``AGC`` is run with ``num_NN`` set to that
    count, and ``lp_agc`` and ``SGC_AGC`` are then evaluated ``num_runs``
    times on masks from ``unsupervised_guided_split``.  The first value
    returned by each model call is stored as the run's accuracy.

    Per-setting mean and standard deviation are saved to
    ``Sensitivity_mat/Sensitivity_<dataset>.mat`` and plotted (mean curve
    with a +/- one-std band, one colour pair per label budget) to
    ``Sensitivity_plot/Sensitivity_<dataset>.png`` / ``.pdf``.

    Args:
        dataset: Identifier forwarded to ``getData``.
        lab_sensitivity: Sequence of label budgets to evaluate.
        neigh_sensitivity: Sequence of ``num_NN`` values to evaluate; its
            first and last entries set the x-axis limits.
        num_runs: Number of repetitions per (budget, neighbour count) pair.
        fsize: Accepted for backward compatibility; not used by this function.

    Returns:
        None.  Results are written to disk.
    """
    adj, gnd, k, feature = getData(dataset)

    # Fixed experiment settings.
    min_num_per_clus = 1
    type_random = "random"
    type_guided = 'guided-min-random'
    SGC_ACG_power = 5
    LP_agc_power = 60
    alpha_few_shot_lp = [0.33, 0.33, 0.33]
    alpha_few_shot_sgc = [0.33, 0.33, 0.33]
    topK_sgc = 1
    topK_lp = 1
    lp_aplha = 0.5

    fig = plt.figure()
    # Single-letter matplotlib colour codes.  Label budget number i uses
    # entries 2*i (LP-AGC) and 2*i + 1 (SGC-AGC), wrapping around so that any
    # number of budgets can be plotted.  Index-based selection replaces the
    # earlier in-place list.remove() calls, which skipped an entry after the
    # first removal and made consecutive budgets share a colour.
    palette = ["b", "r", "y", "m", "c", "g", "k"]
    x = np.array(neigh_sensitivity)

    save_dict_sens = {}

    for idx, num_lab in enumerate(lab_sensitivity):
        lp_colour = palette[(2 * idx) % len(palette)]
        sgc_colour = palette[(2 * idx + 1) % len(palette)]

        sensitivity_SGC_AGC_acc = []
        sensitivity_SGC_AGC_std = []
        sensitivity_LP_AGC_acc = []
        sensitivity_LP_AGC_std = []

        for num_niegh in neigh_sensitivity:
            # NOTE(review): the AGC inputs depend only on num_niegh, yet it is
            # re-run for every label budget.  Left as is because caching the
            # result would change the outcome if AGC is stochastic -- confirm.
            _, _, _, kmeans_, u_, prelab, _, KG = AGC(adj, feature, gnd,
                                                      num_NN=num_niegh,
                                                      tol=0.001,
                                                      pow=60,
                                                      k=k, num_runs=1,
                                                      norm=None)
            acc_lp_agc = np.zeros(num_runs)
            acc_SGC_AGC = np.zeros(num_runs)
            for r in range(num_runs):

                print("========================================================")
                print(f"[SENSITIVITY GENERATE LABELS] num_labels: {num_lab}, num_niegh: {num_niegh}")
                print("========================================================")
                train_mask, test_mask = unsupervised_guided_split(kmeans_, prelab, u_, gnd,
                                                                  num_lab=num_lab,
                                                                  type_split=type_guided,
                                                                  min_num=min_num_per_clus,
                                                                  type_random=type_random)
                print("======================")
                print("[SENSITIVITY LP_AGC]")
                acc_lp_agc[r], _, _, _ = lp_agc(kmeans_,
                                                prelab, u_, adj, gnd,
                                                train_mask, test_mask,
                                                lp_prop_num=LP_agc_power,
                                                lp_p=0.6, lp_alpha=lp_aplha,
                                                alpha=alpha_few_shot_lp,
                                                keep=topK_lp, knngraph=KG,
                                                modify_knn=False,
                                                use_aug_graph=True)  # v1 used best_p

                print("======================")
                print("[SENSITIVITY SGC_AGC]")
                # NOTE(review): SGC_AGC receives (test_mask, train_mask)
                # whereas lp_agc receives (train_mask, test_mask) -- confirm
                # this matches the SGC_AGC signature in models.
                acc_SGC_AGC[r], _, _, _ = SGC_AGC(kmeans_, prelab, u_,
                                                  feature, adj, gnd,
                                                  test_mask, train_mask,
                                                  prop_num=SGC_ACG_power,
                                                  alpha=alpha_few_shot_sgc,
                                                  keep=topK_sgc,
                                                  message='only', knngraph=KG,
                                                  modify_knn=False,
                                                  use_aug_graph=True)  # v1 used best_p

            sensitivity_SGC_AGC_acc.append(np.mean(acc_SGC_AGC))
            sensitivity_SGC_AGC_std.append(np.std(acc_SGC_AGC))

            sensitivity_LP_AGC_acc.append(np.mean(acc_lp_agc))
            sensitivity_LP_AGC_std.append(np.std(acc_lp_agc))

        save_dict_sens[f'num_lab{num_lab}_sensitivity_SGC_AGC_acc'] = sensitivity_SGC_AGC_acc
        save_dict_sens[f'num_lab{num_lab}_sensitivity_SGC_AGC_std'] = sensitivity_SGC_AGC_std
        save_dict_sens[f'num_lab{num_lab}_sensitivity_LP_AGC_acc'] = sensitivity_LP_AGC_acc
        save_dict_sens[f'num_lab{num_lab}_sensitivity_LP_AGC_std'] = sensitivity_LP_AGC_std

        # Mean curve and +/- one-std band for each model at this label budget.
        curves = (
            (sensitivity_LP_AGC_acc, sensitivity_LP_AGC_std,
             lp_colour, f'LP-AGC-num_lab{num_lab}'),
            (sensitivity_SGC_AGC_acc, sensitivity_SGC_AGC_std,
             sgc_colour, f'SGC-AGC-num_lab{num_lab}'),
        )
        for mean_acc, std_acc, colour, label in curves:
            y, error = np.array(mean_acc), np.array(std_acc)
            plt.fill_between(x, y - error, y + error, alpha=0.2, color=colour)
            plt.plot(x, y, f'{colour}-*', linewidth=1, label=label)

    plt.xlabel('Number of neighbors used')
    plt.ylabel('% Accuracy')
    plt.xlim([neigh_sensitivity[0], neigh_sensitivity[-1]])
    plt.ylim([0, 1])
    plt.title(f'{dataset}')
    plt.legend(loc=4)

    plot_save_dir = "Sensitivity_plot"
    make_dir(plot_save_dir)
    # Both outputs share the 'Sensitivity_<dataset>' stem, matching the .mat
    # file written below.
    plt.savefig(os.path.join(plot_save_dir, f'Sensitivity_{dataset}.png'))
    plt.savefig(os.path.join(plot_save_dir, f'Sensitivity_{dataset}.pdf'))
    # pyplot keeps every figure alive until it is closed explicitly; release
    # it now that both files have been written.
    plt.close(fig)

    data_save_dir = "Sensitivity_mat"
    make_dir(data_save_dir)
    sio.savemat(os.path.join(data_save_dir, f'Sensitivity_{dataset}.mat'), save_dict_sens)

            
if __name__ == "__main__":

    # Parse the command-line options declared on the module-level `parser`.
    args = parser.parse_args()

    # NOTE: both experiment calls below are commented out, so running this
    # file currently only parses the arguments and exits.  Uncomment the
    # experiment you want to run.

    # Neighbour-count sensitivity experiment:
    # sensitivity(dataset=args.dataset, lab_sensitivity=[1,5,10],
    #             neigh_sensitivity=[1,20, 40, 60, 80, 100],
    #             num_runs=args.num_runs, fsize=args.fsize)

    # Label-budget ablation experiment:
    # ablation(dataset=args.dataset,
    #          max_num_labels_ablation=args.max_num_ablation_labels,
    #          num_runs=args.num_runs, fsize=args.fsize)

