import scipy.io as sio
# import time
import os
# import tensorflow as tf
# import torch
# from torch.utils.data import Dataset, DataLoader
import numpy as np
# import scipy.sparse as sp
# import scipy.special as ss
import scipy as spy
import matplotlib.pyplot as plt
import argparse
from utils import (getData, random_generate_masks,
                   unsupervised_guided_split, make_dir)
from models import (AGC, SGC, lp, sgc_agc, lp_agc,
                    run_DGI, run_GMI)
from other_models import (SGC_torch, SGC_AGC_torch, cands_agc)


def _str2bool(value):
    """Parse a command-line boolean such as 'true', 'False', 'yes', '0'.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True``, so any
    non-empty string would switch the option on. Non-string defaults (already
    ``bool``) are passed through unchanged.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='manual to this script')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--numruns', type=int, default=10)
parser.add_argument('--min_num_per_clus', type=int, default=1)
parser.add_argument('--max_num_labels', type=int, default=20)
parser.add_argument('--AGC_tol', type=float, default=0.001)
parser.add_argument('--AGC_power', type=int, default=60)
parser.add_argument('--num_neighbors', type=int, default=60)
# Boolean options go through _str2bool so that e.g. "--useKnn False" really
# disables the option (type=bool would treat every non-empty string as True).
parser.add_argument('--useKnn', type=_str2bool, default=True)
parser.add_argument('--modifyKnn', type=_str2bool, default=False)
parser.add_argument('--SGC_power', type=int, default=5)
parser.add_argument('--LP_power', type=int, default=60)
parser.add_argument('--use_best_power', type=_str2bool, default=False)
parser.add_argument('--AGC_norm', type=str, default=None)
# parser.add_argument('--plot_save_dir', type=str, default='Results_plots_v12Guided_guidedminRandomKnnGraph5')
parser.add_argument('--plot_save_dir', type=str, default='Dummy')
# Settings used for previously recorded result versions:
# v12: --type_guided=guided-min-random with --min_num_per_clus=1 vs random --type_random=guided --AGC_norm None --AGC_power 60 --AGC_tol 0.001 --alpha 0.33 0.33 0.33 --useKnn true --modifyKnn False --num_neighbors 60
# v13: --type_guided=guided-min-random with --min_num_per_clus=1 vs random --type_random=random --AGC_norm None --AGC_power 60 --AGC_tol 0.001 --alpha 0.33 0.33 0.33 --useKnn true --modifyKnn False --num_neighbors 60
# NOTE: --type_guided=guided-min-all-random performed badly; do not use the "*-all" split types.
# parser.add_argument('--data_save_dir', type=str, default='Results_mat_v12Guided_guidedminRandomKnnGraph5')
parser.add_argument('--data_save_dir', type=str, default='Dummy')
parser.add_argument('--cluster_dynamic', type=_str2bool, default=False)
parser.add_argument('--type_random', type=str, default='guided',
                    help='per class or not i.e., guided or random')
parser.add_argument('--type_guided', type=str, default='guided-min-random',
                    help=' use agc results or not i.e., random, guided-min-random, '
                         'guided-min, guided-min-max, guided-min-all, guided-min-all-random, '
                         'guided-min-max-all')
# One or more floats that REPLACE the default (the previous action='append'
# appended a list of strings onto the default list instead).
parser.add_argument('--alpha', nargs='+', type=float, default=[0.33, 0.33, 0.33],
                    help='laplacian weights,'
                         '[KnnG, Truelab, adj]')

args = parser.parse_args()


def StatsSavingPloting():
    """Save the accumulated few-shot accuracy curves to .mat and plot them.

    Takes no arguments. It reads module-level state filled in by the
    ``__main__`` section: ``dataset``, ``args``, ``num_labels``, the
    ``acc_*`` / ``std_*`` lists and the power/alpha settings. It can
    therefore only be called after that loop has run.

    Side effects:
      * writes ``fewShot_SGC_<dataset>.mat`` into ``args.data_save_dir``
        (created via ``make_dir``);
      * opens a new matplotlib figure with one mean curve per method, plus a
        +/- one-std band for every method except AGC, and saves it as
        ``.png`` and ``.pdf`` into ``args.plot_save_dir``.
    """
    # (variable-name suffix, values); each .mat variable is named
    # ``dataset + suffix``.
    columns = (
        ('acc_SGC', acc_SGC_list),
        ('acc_AGC', acc_AGC_list),
        ('acc_LP', acc_LP_list),
        ('acc_GCN', acc_GCN_list),
        ('acc_LP_AGC', acc_LP_AGC_list),
        ('acc_SGC_AGC', acc_SGC_AGC_list),
        ('acc_SGC_AGC_LP', acc_SGC_AGC_LP_list),
        ('ac_CandS_AGC', ac_CandS_agc_list),
        ('ac_DGI', acc_DGI_list),
        ('ac_GMI', acc_GMI_list),
        ('std_SGC', std_SGC_list),
        ('std_AGC', std_AGC_list),
        ('std_LP', std_LP_list),
        ('std_GCN', std_GCN_list),
        ('std_LP_AGC', std_LP_AGC_list),
        ('std_SGC_AGC', std_SGC_AGC_list),
        ('std_SGC_AGC_LP', std_SGC_AGC_LP_list),
        ('std_CandS_AGC', std_CandS_agc_list),
        ('std_DGI', std_DGI_list),
        ('std_GMI', std_GMI_list),
    )
    stats_to_save = {dataset + suffix: values for suffix, values in columns}

    make_dir(args.data_save_dir)
    mat_path = os.path.join(args.data_save_dir, f'fewShot_SGC_{dataset}.mat')
    sio.savemat(mat_path, stats_to_save)

    # (means, stds, colour, legend label, draw std band?)
    # The drawing order determines the legend order, so keep it fixed.
    curves = (
        (acc_AGC_list, std_AGC_list, 'k', f'AGC:{AGC_power}', False),
        (acc_SGC_list, std_SGC_list, 'b', f'SGC: power={SGC_power}', True),
        (acc_LP_list, std_LP_list, 'g', f'LP: power={LP_power}', True),
        (acc_LP_AGC_list, std_LP_AGC_list, 'y', f'LP-AGC={alpha_few_shot_lp}', True),
        (acc_SGC_AGC_list, std_SGC_AGC_list, 'm', f'SGC-AGC={alpha_few_shot_sgc}', True),
        (acc_GMI_list, std_GMI_list, 'c', 'GMI', True),
        (acc_DGI_list, std_DGI_list, 'r', 'DGI', True),
    )

    plt.figure()
    label_counts = np.array(num_labels)
    for means, stds, colour, legend_label, with_band in curves:
        mean_arr = np.array(means)
        if with_band:
            std_arr = np.array(stds)
            plt.fill_between(label_counts, mean_arr - std_arr, mean_arr + std_arr,
                             alpha=0.2, color=colour)
        plt.plot(label_counts, mean_arr, colour + '-*', linewidth=1, label=legend_label)

    plt.xlabel('Number of labels per class')
    plt.ylabel('% Accuracy')
    plt.xlim([num_labels[0], num_labels[-1]])
    plt.ylim([0, 1])
    plt.title(f'{args.dataset}')
    plt.legend()

    make_dir(args.plot_save_dir)
    for extension in ('png', 'pdf'):
        plt.savefig(os.path.join(args.plot_save_dir, f'fewShot_SGC_{dataset}.{extension}'))


if __name__ == '__main__':
    # Load the scipy.stats subpackage explicitly: ``import scipy`` alone does
    # not import it on older SciPy releases, and ``spy.stats.mode`` is used in
    # the summary print below.
    import scipy.stats  # noqa: F401

    dataset = args.dataset
    adj, gnd, k, feature = getData(dataset)
    # Remember the cluster count returned by getData so that --cluster_dynamic
    # scales from it (num * k_base). Previously ``k = num*k`` compounded the
    # prior iteration's value, so k grew as k_base * num!.
    k_base = k

    Deg = adj.sum(axis=0)
    print("========================================================")
    print(f'{dataset} adj.shape{adj.shape}, feature.shape {feature.shape},'
          f' k: {k}, len(gnd): {len(gnd)}, mean degree: {Deg.mean()},'
          f'mode degree: {spy.stats.mode(Deg)}, min degree: {Deg.min()},'
          f'max degree: {Deg.max()}, isolated nodes: {Deg[Deg==0].shape}')

    # Mean accuracy over ``rep`` runs, one entry per labels-per-class budget.
    # These module-level names are read by StatsSavingPloting().
    acc_SGC_list = []
    acc_LP_list = []
    acc_AGC_list = []
    acc_GCN_list = []
    acc_LP_AGC_list = []
    acc_SGC_AGC_list = []
    ac_CandS_agc_list = []
    acc_SGC_AGC_LP_list = []
    acc_DGI_list = []
    acc_GMI_list = []

    # Matching standard deviations over the ``rep`` runs.
    std_SGC_list = []
    std_LP_list = []
    std_AGC_list = []
    std_GCN_list = []
    std_LP_AGC_list = []
    std_SGC_AGC_list = []
    std_CandS_agc_list = []
    std_SGC_AGC_LP_list = []
    std_DGI_list = []
    std_GMI_list = []

    rep = args.numruns
    num_min = args.min_num_per_clus
    type_guided = args.type_guided
    type_random = args.type_random
    modKG = args.modifyKnn
    useKG = args.useKnn
    num_nei = args.num_neighbors
    SGC_power = args.SGC_power
    LP_power = args.LP_power
    AGC_power = args.AGC_power
    AGC_norm = args.AGC_norm
    AGC_tol = args.AGC_tol
    CS_prop = 50
    alpha_few_shot_sgc = args.alpha  # clustering, true, original
    topK_sgc = 1
    alpha_few_shot_lp = args.alpha  # clustering, true, original
    topK_lp = 1
    alpha_few_shot_sgc_lp = args.alpha  # clustering, true, original
    topK_sgc_lp = 1
    # Switches selecting which optional methods are evaluated; disabled
    # methods record an accuracy of 0.
    LP_agc = True
    SGC_agc = True
    DGI = True
    GMI = True
    SGC_LP_agc = False
    CandS_agc = False
    alpha_CS = 0.5
    lp_aplha = 0.5

    # NOTE: np.arange excludes its stop value, so this sweeps
    # 1 .. max_num_labels - 1 labels per class.
    num_labels = np.arange(1, args.max_num_labels, 1)

    for num in num_labels:

        if args.cluster_dynamic:
            # Scale the number of clusters with the labels-per-class budget.
            k = num * k_base

        if num == 1 or args.cluster_dynamic:
            print("========================================================")
            # AGC is unsupervised: run it once (or once per budget when the
            # cluster count changes) and reuse its outputs in every run below.
            ac_agc, _, _, kmeans_, u_, prelab, best_p, KG = AGC(adj, feature, gnd,
                                                                num_NN=num_nei,
                                                                tol=AGC_tol,
                                                                pow=AGC_power,
                                                                k=k, num_runs=1,
                                                                norm=AGC_norm)

        if not useKG:
            KG = None

        if args.use_best_power:
            SGC_ACG_power = best_p
            LP_agc_power = best_p
        else:
            SGC_ACG_power = SGC_power
            LP_agc_power = LP_power

        ac_SGC = np.zeros(rep)
        ac_AGC = np.zeros(rep)
        ac_LP = np.zeros(rep)
        ac_GCN = np.zeros(rep)
        ac_LP_AGC = np.zeros(rep)
        ac_SGC_AGC = np.zeros(rep)
        ac_CandS_agc = np.zeros(rep)
        ac_SGC_LP_AGC = np.zeros(rep)
        ac_DGI = np.zeros(rep)
        ac_GMI = np.zeros(rep)

        for r in range(rep):

            print(f' --- {dataset} For {num} Labels --- ')

            # AGC was run above and is not re-run per repetition, so every run
            # records the same AGC accuracy.
            ac_AGC[r] = ac_agc

            train_mask, test_mask = random_generate_masks(gnd, num_for_train=num,
                                                           type_generate=type_random)
            print("========================================================")
            # SGC_torch(...) with the same arguments is an alternative backend.
            ac_SGC[r], _, _, _ = SGC(feature, adj, SGC_power,
                                     test_mask, train_mask, gnd)

            print("========================================================")
            ac_LP[r], _, _ = lp(adj, train_mask, test_mask,
                                gnd, LP_power, p=0.6, alpha=lp_aplha)

            if DGI:
                print("========================================================")
                ac_DGI[r], _, _ = run_DGI(dataset,
                                          mask_train=train_mask,
                                          mask_test=test_mask, gnd_lab=gnd)
            else:
                ac_DGI[r] = 0

            if GMI:
                print("========================================================")
                ac_GMI[r], _, _ = run_GMI(dataset,
                                          mask_train=train_mask,
                                          mask_test=test_mask, gnd_lab=gnd)
            else:
                ac_GMI[r] = 0

            if CandS_agc:
                print("========================================================")
                ac_CandS_agc[r], _, _ = cands_agc(kmeans_, prelab, u_, adj, gnd, test_mask, p=0.6,
                                                  correct_num=50, smooth_num=CS_prop, lp_p=0.5,
                                                  lp_alpha=lp_aplha, alpha=alpha_CS, keep=topK_lp)
            else:
                ac_CandS_agc[r] = 0

            # Unless --type_guided is 'random', replace the random split with
            # one derived from the AGC clustering. The AGC-based methods below
            # are evaluated on this split.
            if args.type_guided != 'random':
                print("========================================================")
                train_mask, test_mask = unsupervised_guided_split(kmeans_, prelab, u_, gnd,
                                                                  num_lab=num, type_split=type_guided,
                                                                  min_num=num_min, type_random=type_random)

            if LP_agc:
                print("========================================================")
                ac_LP_AGC[r], _, _, lp_agc_lab = lp_agc(kmeans_, prelab, u_, adj, gnd,
                                                        train_mask, test_mask,
                                                        lp_prop_num=LP_agc_power,
                                                        lp_p=0.6, lp_alpha=lp_aplha,
                                                        alpha=alpha_few_shot_lp,
                                                        keep=topK_lp, knngraph=KG,
                                                        modify_knn=modKG)  # v1 used best_p
            else:
                ac_LP_AGC[r] = 0
                lp_agc_lab = [0]*feature.shape[0]

            if SGC_agc:
                print("========================================================")
                # SGC_AGC_torch(...) is an alternative backend (without the
                # knngraph/modify_knn arguments).
                ac_SGC_AGC[r], _, _, _ = sgc_agc(kmeans_, prelab, u_, feature, adj, gnd,
                                                 test_mask, train_mask, prop_num=SGC_ACG_power,
                                                 alpha=alpha_few_shot_sgc, keep=topK_sgc,
                                                 message='only', knngraph=KG,
                                                 modify_knn=modKG)  # v1 used best_p
            else:
                ac_SGC_AGC[r] = 0

            if SGC_LP_agc:
                print("========================================================")
                # Same as SGC-AGC but seeded with the LP-AGC labels.
                ac_SGC_LP_AGC[r], _, _, _ = sgc_agc(kmeans_, lp_agc_lab, u_, feature, adj, gnd,
                                                    test_mask, train_mask, prop_num=best_p,
                                                    alpha=alpha_few_shot_sgc_lp, keep=topK_sgc_lp,
                                                    message='LP')  # v1 used best_p
            else:
                ac_SGC_LP_AGC[r] = 0

        # Fold this budget's per-run accuracies into the mean/std curves that
        # StatsSavingPloting() saves and plots.
        for runs, mean_list, std_list in (
                (ac_SGC, acc_SGC_list, std_SGC_list),
                (ac_AGC, acc_AGC_list, std_AGC_list),
                (ac_LP, acc_LP_list, std_LP_list),
                (ac_GCN, acc_GCN_list, std_GCN_list),
                (ac_LP_AGC, acc_LP_AGC_list, std_LP_AGC_list),
                (ac_SGC_AGC, acc_SGC_AGC_list, std_SGC_AGC_list),
                (ac_SGC_LP_AGC, acc_SGC_AGC_LP_list, std_SGC_AGC_LP_list),
                (ac_CandS_agc, ac_CandS_agc_list, std_CandS_agc_list),
                (ac_DGI, acc_DGI_list, std_DGI_list),
                (ac_GMI, acc_GMI_list, std_GMI_list),
        ):
            mean_list.append(np.mean(runs))
            std_list.append(np.std(runs))

    print("========================================================")
    print('ploting and saving')
    StatsSavingPloting()
    print("Done!!")