import pandas as pd
from config import get_args
import torch
from datasets import Dataset
import dgl
from dgl import ops
from sklearn.feature_selection import mutual_info_classif
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import os

import warnings
warnings.filterwarnings("ignore")

# Dataset identifiers. Result tables built in this file carry one column per
# entry, in this order.
DATASETS = [
    'minesweeper',
    'roman-empire',
    'amazon-ratings',
    'tolokers',
    'questions',
    'squirrel-filtered',
    'chameleon-filtered',
    'actor',
    'texas-4-classes',
    'cornell',
    'wisconsin',
    'cora',
    'citeseer',
    'pubmed',
]
# Short tick labels, index-aligned with DATASETS (used by the bar plots).
DATASETS_SHOW = [
    'Mines.',
    'Roman.',
    'Amazon.',
    'Tolokers',
    'Questions',
    'Squirrel',
    'Chameleon',
    'Actor',
    'Texas',
    'Cornell',
    'Wisconsin',
    'Cora',
    'Citeseer',
    'Pubmed',
]
# Two disjoint subsets of DATASETS; judging by the names these are the
# homophilous / heterophilous groups -- TODO confirm.
DATASETS_HOM = [
    'minesweeper',
    'tolokers',
    'questions',
    'cora',
    'citeseer',
    'pubmed',
]
DATASETS_HETERO = [
    'roman-empire',
    'amazon-ratings',
    'squirrel-filtered',
    'chameleon-filtered',
    'actor',
    'texas-4-classes',
    'cornell',
    'wisconsin',
]
# Dataset identifier -> column header printed in the LaTeX tables.
MP_DATASETS = {
    'roman-empire': 'Roman.',
    'amazon-ratings': 'Amazon.',
    'minesweeper': 'Mines.',
    'tolokers': 'Tolokers',
    'questions': 'Questions',
    'squirrel-filtered': 'Squirrel',
    'chameleon-filtered': 'Chameleon',
    'actor': 'Actor',
    'texas-4-classes': 'Texas',
    'cornell': 'Cornell',
    'wisconsin': 'Wisconsin',
    'cora': 'Cora',
    'corafull': 'CoraFull',
    'citeseer': 'CiteSeer',
    'pubmed': 'PubMed',
    'flickr': 'Flickr',
    'amazon-photo': 'Amazon-Photo',
    'amazon-computer': 'Amazon-Computer',
    'coauthor-cs': 'Coauthor-CS',
    'coauthor-physics': 'Coauthor-Physics',
    'wikics': 'WikiCS',
    'blog-catalog': 'Blog-Catalog',
    'ogbn-arxiv': 'Ogbn-Arxiv',
    'genius': 'Genius',
    'twitch-DE': 'Twitch-DE',
    'twitch-ENGB': 'Twitch-ENGB',
    'twitch-ES': 'Twitch-ES',
    'twitch-FR': 'Twitch-FR',
    'twitch-PTBR': 'Twitch-PTBR',
    'twitch-RU': 'Twitch-RU',
    'twitch-TW': 'Twitch-TW',
}

# Model identifiers as they appear in the result CSVs, and the name each one
# is printed under ('ResNet' is shown as 'MLP').
MODELS = ['ResNet', 'GCN', 'SGC', 'SAGE', 'GAT']
MODELS_SHOW = dict(zip(MODELS, ['MLP', 'GCN', 'SGC', 'SAGE', 'GAT']))

# Rewriting-basis identifiers and their index-aligned math-text labels.
BASISES = [
    'feature',
    'agge_feature',
    'mlp_feature',
    'gcn_feature',
    'grace_feature',
]
BASISES_SHOW = [
    r'$X$',
    r'$\hat{A}X$',
    r'MLP($X$)',
    r'GCN($X$,$A$)',
    r'GCL($X$,$A$)',
]

def rewriting_acc_compare():
    """Tabulate the best test score per rewriting configuration and dataset.

    Reads ``results/rewriting.csv``. For every model in ``MODELS`` and every
    combination of ``rewrite_basis`` / ``rewrite_construct`` /
    ``rewrite_fusion`` / ``rewrite_fusion_state`` that has at least one row,
    one output row is produced holding the maximum ``test ROC AUC mean`` for
    each dataset in ``DATASETS``. The table is written to
    ``results/baseline_acc_compare/acc_compare.csv``.

    A dataset with no rows for a configuration is recorded as ``-1`` (the
    same "missing" sentinel ``get_test_by_best_val`` returns) rather than
    raising on the maximum of an empty array.
    """
    config_cols = ["model", "rewrite_basis", "rewrite_construct",
                   "rewrite_fusion", "rewrite_fusion_state"]
    data = pd.read_csv("results/rewriting.csv")
    res = {col: [] for col in config_cols}
    for dataset in DATASETS:
        res[dataset] = []

    # The candidate values never change while looping, so look them up once.
    bases = data['rewrite_basis'].unique()
    constructs = data['rewrite_construct'].unique()
    fusions = data['rewrite_fusion'].unique()
    fusion_states = data['rewrite_fusion_state'].unique()

    for model in MODELS:
        for rewrite_basis in bases:
            for rewrite_construct in constructs:
                for rewrite_fusion in fusions:
                    for rewrite_fusion_state in fusion_states:
                        config = (model, rewrite_basis, rewrite_construct,
                                  rewrite_fusion, rewrite_fusion_state)
                        selected = pd.Series(True, index=data.index)
                        for col, value in zip(config_cols, config):
                            selected &= data[col] == value
                        item = data[selected]
                        if len(item) == 0:
                            continue
                        for col, value in zip(config_cols, config):
                            res[col].append(value)
                        for dataset in DATASETS:
                            scores = item.loc[item['dataset'] == dataset,
                                              'test ROC AUC mean'].values
                            # An empty array has no maximum; keep the row
                            # aligned with a sentinel instead of crashing.
                            res[dataset].append(scores.max() if len(scores) > 0 else -1)
    res = pd.DataFrame(res)
    res.to_csv("results/baseline_acc_compare/acc_compare.csv", index=False)

def get_test_by_best_val(item):
    """Return the test score of the run with the best validation score.

    Args:
        item: DataFrame with ``val ROC AUC mean`` and ``test ROC AUC mean``
            columns, one row per run.

    Returns:
        The ``test ROC AUC mean`` of the row whose ``val ROC AUC mean`` is
        largest (the first such row on ties), or ``-1`` when ``item`` has no
        rows.
    """
    if len(item) == 0:
        return -1
    # Select by position rather than by index label: a frame assembled with
    # pd.concat can repeat labels, and a label lookup would then return
    # several rows instead of a single score.
    best_pos = item['val ROC AUC mean'].reset_index(drop=True).idxmax()
    return item['test ROC AUC mean'].iloc[best_pos]

def rewriting_acc_compare_0725():
    """Build the accuracy table comparing un-rewritten runs with GSL runs.

    The GSL runs come from ``results/rewriting1.csv`` and
    ``results/rewriting2.csv`` (rows whose basis is ``do_not_rewrite`` are
    discarded); the un-rewritten runs come from
    ``results/fix_pretrain_feat/mlp_gnn_cuda0.csv``. For each model the table
    first gets one ``do_not_rewrite`` row per basis, then one row per
    non-empty (basis, construct, fusion, fusion state) combination. Every
    cell is the test score picked by ``get_test_by_best_val``.

    The table is saved to ``results/baseline_acc_compare/acc_compare_0831.csv``
    and then handed to ``clean_main_table``.
    """
    key_cols = ("model", "rewrite_basis", "rewrite_construct",
                "rewrite_fusion", "rewrite_fusion_state")

    first_part = pd.read_csv("results/rewriting1.csv")
    second_part = pd.read_csv("results/rewriting2.csv")
    gsl_runs = pd.concat([first_part, second_part])
    plain_runs = pd.read_csv("results/fix_pretrain_feat/mlp_gnn_cuda0.csv")
    gsl_runs = gsl_runs[gsl_runs['rewrite_basis'] != 'do_not_rewrite'].reset_index()

    table = {col: [] for col in key_cols}
    table.update({name: [] for name in DATASETS})

    def add_row(config, runs):
        # One output row: the configuration, then a score per dataset.
        for col, value in zip(key_cols, config):
            table[col].append(value)
        for name in DATASETS:
            table[name].append(get_test_by_best_val(runs[runs['dataset'] == name]))

    basis_values = gsl_runs['rewrite_basis'].unique()
    construct_values = gsl_runs['rewrite_construct'].unique()
    fusion_values = gsl_runs['rewrite_fusion'].unique()
    state_values = gsl_runs['rewrite_fusion_state'].unique()

    for model in MODELS:
        # Un-rewritten rows first, one per basis.
        for basis in basis_values:
            matching = plain_runs[(plain_runs['model'] == model)
                                  & (plain_runs['rewrite_basis'] == basis)]
            add_row((model, basis, "do_not_rewrite", "", ""), matching)

        # Then every GSL configuration that actually has runs.
        for basis in basis_values:
            for construct in construct_values:
                for fusion in fusion_values:
                    for state in state_values:
                        matching = gsl_runs[(gsl_runs['model'] == model)
                                            & (gsl_runs['rewrite_basis'] == basis)
                                            & (gsl_runs['rewrite_construct'] == construct)
                                            & (gsl_runs['rewrite_fusion'] == fusion)
                                            & (gsl_runs['rewrite_fusion_state'] == state)]
                        if len(matching) > 0:
                            add_row((model, basis, construct, fusion, state), matching)

    table = pd.DataFrame(table)
    table.to_csv("results/baseline_acc_compare/acc_compare_0831.csv", index=False)

    # Keep only the best rewriting basis per configuration.
    clean_main_table(table, "results/archive/gnn_plus_gsl_best_basis.csv")

def clean_main_table(data, output):
    """Collapse the rewriting basis out of a results table and relabel it.

    For each model in ``MODELS`` and each (construct, fusion, fusion state)
    combination present in ``data``, the per-dataset maximum over the
    matching rows is kept, i.e. the best value across rewriting bases. Rows
    whose construct is ``editing`` or whose fusion state is ``early`` are then
    dropped, the remaining configuration values are replaced by their display
    labels, and the table is written as CSV.

    Args:
        data: DataFrame with ``model``, ``rewrite_construct``,
            ``rewrite_fusion``, ``rewrite_fusion_state`` and one column per
            entry of ``DATASETS``.
        output: Path of the CSV file to write.
    """
    config_cols = ["model", "rewrite_construct", "rewrite_fusion", "rewrite_fusion_state"]
    res = {col: [] for col in config_cols}
    for dataset in DATASETS:
        res[dataset] = []

    constructs = data['rewrite_construct'].unique()
    fusions = data['rewrite_fusion'].unique()
    fusion_states = data['rewrite_fusion_state'].unique()
    for model in MODELS:
        for rewrite_construct in constructs:
            for rewrite_fusion in fusions:
                for rewrite_fusion_state in fusion_states:
                    item = data[(data['model'] == model)
                                & (data['rewrite_construct'] == rewrite_construct)
                                & (data['rewrite_fusion'] == rewrite_fusion)
                                & (data['rewrite_fusion_state'] == rewrite_fusion_state)]
                    if len(item) == 0:
                        continue
                    res['model'].append(model)
                    res['rewrite_construct'].append(rewrite_construct)
                    res['rewrite_fusion'].append(rewrite_fusion)
                    res['rewrite_fusion_state'].append(rewrite_fusion_state)
                    for dataset in DATASETS:
                        res[dataset].append(item[dataset].max())
    res = pd.DataFrame(res)

    # Each reset_index() also stores the previous index as a column ('index',
    # then 'level_0'). latex_baseline_plus_rewriting_0901 drops exactly those
    # two columns when it reads the archived tables, so they are kept.
    res = res[res['rewrite_construct'] != 'editing'].reset_index()
    res = res[res['rewrite_fusion_state'] != 'early'].reset_index()

    # Raw value -> label printed in the paper table. Values missing from a
    # table keep their raw form; the fusion state defaults to "-".
    fusion_labels = {
        '': "-",
        'only_new': r"$\{\mathcal{G}'\}$",
        'both_share_param': r"$\{\mathcal{G},\mathcal{G}'\}$",
        'both_seperate_param': r"$\{\mathcal{G},\mathcal{G}'\}$",
    }
    fusion_state_labels = {
        'both_share_param': r"$\theta_1=\theta_2$",
        'both_seperate_param': r"$\theta_1\neq\theta_2$",
    }
    construct_labels = {
        'do_not_rewrite': r"None",
        'cos_sim_graph': r"cos-graph",
        'cos_sim_node': r"cos-node",
        'knn': r"kNN",
    }
    raw_fusions = list(res['rewrite_fusion'])
    raw_constructs = list(res['rewrite_construct'])
    res['rewrite_fusion'] = [fusion_labels.get(f, f) for f in raw_fusions]
    res['rewrite_fusion_state'] = [fusion_state_labels.get(f, '-') for f in raw_fusions]
    res['rewrite_construct'] = [construct_labels.get(c, c) for c in raw_constructs]

    # Honour the caller-supplied destination; it used to be ignored in favour
    # of a hard-coded path.
    res.to_csv(output, index=False)

def latex_baseline_plus_rewriting():
    """Print a LaTeX table of GCN scores for every rewriting configuration.

    Rows follow ``results/baseline_acc_compare/acc_compare_0809.csv``
    (restricted to model ``GCN`` and to constructs other than ``editing``);
    that file also supplies the per-dataset ranks used for colouring (rank 1
    red, rank 2 blue, rank 3 purple). The printed ``mean±std`` values are
    looked up in ``results/rewriting1.csv`` + ``results/rewriting2.csv`` and
    multiplied by 100. Columns are the datasets of ``DATASETS_HOM``.

    NOTE(review): each cell shows the run with the highest *test* score,
    whereas ``get_test_by_best_val`` selects by validation score -- confirm
    this is intended.
    """
    data = pd.concat([pd.read_csv("results/rewriting1.csv"),
                      pd.read_csv("results/rewriting2.csv")])
    data_rank = pd.read_csv("results/baseline_acc_compare/acc_compare_0809.csv")

    model = 'GCN'
    # Swap in DATASETS_HETERO here to print the other half of the table.
    datasets = DATASETS_HOM

    data = data[data['rewrite_construct'] != 'editing'].reset_index()
    # Filter and renumber in one step: the row loop below needs a clean,
    # independent frame, and ranks must be assigned on a copy, not a slice.
    data_rank = data_rank[(data_rank['rewrite_construct'] != 'editing')
                          & (data_rank['model'] == model)].reset_index(drop=True)
    for dataset in datasets:
        data_rank[dataset] = data_rank[dataset].rank(ascending=False)

    def as_label(value):
        # Empty CSV cells come back as NaN, which has no .replace().
        return '' if pd.isna(value) else str(value).replace('_', ' ')

    def colorize(text, rank):
        for limit, color in ((2, 'red'), (3, 'blue'), (4, 'purple')):
            if rank < limit:
                return "\\textcolor{" + color + "}{" + text + "}"
        return text

    # The footer closes a \resizebox group, so open it here as well.
    print(r"\resizebox{1\hsize}{!}{")
    print(r"\begin{tabular}{*{"+str(4+len(datasets))+r"}{c}}\toprule")
    print(r"Basis & Construct & Fusion & Param Mode &", end='')
    print(' & '.join(MP_DATASETS[dataset] for dataset in datasets),
          end='\\\\ \\toprule \n')

    for _, row in data_rank.iterrows():
        rewrite_basis = row['rewrite_basis']
        rewrite_construct = row['rewrite_construct']
        rewrite_fusion = row['rewrite_fusion']
        rewrite_fusion_state = row['rewrite_fusion_state']
        if rewrite_basis == 'do_not_rewrite':
            item = data[(data['model'] == model) & (data['rewrite_basis'] == rewrite_basis)]
        else:
            item = data[(data['model'] == model)
                        & (data['rewrite_basis'] == rewrite_basis)
                        & (data['rewrite_construct'] == rewrite_construct)
                        & (data['rewrite_fusion'] == rewrite_fusion)
                        & (data['rewrite_fusion_state'] == rewrite_fusion_state)]
        if len(item) == 0:
            continue

        cells = [as_label(name) for name in
                 (rewrite_basis, rewrite_construct, rewrite_fusion, rewrite_fusion_state)]
        for dataset in datasets:
            item_d = item[item['dataset'] == dataset]
            if len(item_d) > 0:
                id_max = item_d['test ROC AUC mean'].idxmax()
                acc = 100*item_d.loc[id_max, 'test ROC AUC mean']
                std = 100*item_d.loc[id_max, 'test ROC AUC std']
            else:
                acc, std = 0, 0
            value_print = "{:.2f}".format(acc)+r"$\pm$"+"{:.2f}".format(std)
            cells.append(colorize(value_print, row[dataset]))
        print(' & '.join(cells), end='\\\\ \n')
    print("\\bottomrule\n\\end{tabular}\n}")

def latex_baseline_plus_rewriting_0901():
    """Print the main LaTeX table: MLP plus one GNN per section, with ranks.

    Means are read from ``results/archive/acc_gnn_plus_gsl_best_basis.csv``
    and standard deviations from
    ``results/archive/std_gnn_plus_gsl_best_basis.csv``; the two files are
    assumed to list the same rows in the same order. For every model in
    ``MODELS`` other than ``ResNet`` a section is printed containing the
    ``ResNet`` rows (shown as MLP) and that model's rows. Within a section
    each dataset column is ranked (best red, second blue) and the last column
    is the row's mean rank over all datasets. Values are multiplied by 100.
    """
    data_all = pd.read_csv("results/archive/acc_gnn_plus_gsl_best_basis.csv")
    data_all_std = pd.read_csv("results/archive/std_gnn_plus_gsl_best_basis.csv")
    data_all = data_all.drop(columns=['level_0', 'index'])
    data_all_std = data_all_std.drop(columns=['level_0', 'index'])

    print(r"\resizebox{1\hsize}{!}{")
    print(r"\begin{tabular}{*{"+str(5+len(DATASETS))+r"}{c}}\toprule")
    print(r"Model & Construct & Fusion & Param Sharing &", end='')
    for dataset in DATASETS:
        print(MP_DATASETS[dataset], end=' & ')
    print("Rank", end='\\\\ \\toprule \n')

    # Work on a filtered copy: removing 'ResNet' from the shared MODELS list
    # would break a second call and every later user of MODELS.
    gnn_models = [name for name in MODELS if name != "ResNet"]

    def colorize(text, rank):
        for limit, color in ((2, 'red'), (3, 'blue')):
            if rank < limit:
                return "\\textcolor{" + color + "}{" + text + "}"
        return text

    for model in gnn_models:
        data = data_all[(data_all['model'] == 'ResNet')
                        | (data_all['model'] == model)].reset_index(drop=True)
        data_std = data_all_std[(data_all_std['model'] == 'ResNet')
                                | (data_all_std['model'] == model)].reset_index(drop=True)
        data_rank = data.copy()
        for dataset in DATASETS:
            data_rank[dataset] = data_rank[dataset].rank(ascending=False)

        for i in range(len(data_rank)):
            cells = [str(MODELS_SHOW[data_rank.loc[i, 'model']]),
                     str(data_rank.loc[i, 'rewrite_construct']),
                     str(data_rank.loc[i, 'rewrite_fusion']),
                     str(data_rank.loc[i, 'rewrite_fusion_state'])]
            for dataset in DATASETS:
                acc = data.loc[i, dataset]*100
                std = data_std.loc[i, dataset]*100
                value_print = "{:.2f}".format(acc)+r"$\pm$"+"{:.2f}".format(std)
                cells.append(colorize(value_print, data_rank.loc[i, dataset]))
            # Address the rank columns by name so the mean does not depend on
            # where the dataset columns sit in the frame.
            mean_rank = data_rank.loc[i, DATASETS].mean()
            cells.append("{:.2f}".format(mean_rank))
            print(' & '.join(cells), end='\\\\ \n')
        end_symbol = '\\toprule \n' if model != gnn_models[-1] else '\\bottomrule\n'
        print(end_symbol)
    print("\\end{tabular}\n}")

def synthetic_mutual_info(hl,seed,feature_type='raw'):
    """Measure label information in raw and aggregated synthetic features.

    Builds a ``'syn'`` ``Dataset`` with a kNN-rewritten graph, then scores
    three feature matrices against the labels: the base features, the base
    features aggregated over the original graph, and the base features
    aggregated over the rewritten graph (``dataset.graph_new``).

    Args:
        hl: Value passed to the dataset as ``syn_label_homophily``.
        seed: Seed for torch, the dataset and the mutual-information
            estimator.
        feature_type: ``'raw'`` uses ``dataset.node_features`` as the base
            features; ``'agg'`` uses them aggregated once over the original
            graph.

    Returns:
        ``(mi_raw, mi_old, mi_new, acc_raw, acc_old, acc_new)``: the mean
        ``mutual_info_classif`` value and the ``mlp_classifier`` accuracy for
        the base, old-graph and new-graph features respectively.

    Raises:
        ValueError: If ``feature_type`` is neither ``'raw'`` nor ``'agg'``.
    """
    basis_by_feature_type = {'raw': 'feature', 'agg': 'agge_feature'}
    if feature_type not in basis_by_feature_type:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(
            "feature_type must be 'raw' or 'agg', got {!r}".format(feature_type))
    rewrite_basis = basis_by_feature_type[feature_type]

    args = get_args()
    torch.manual_seed(seed)
    hs = 1
    hf = 0
    dataset = Dataset(name='syn',
                      model_name=args.model,
                      add_self_loops=True,
                      device="cpu",
                      use_sgc_features=args.use_sgc_features,
                      use_identity_features=args.use_identity_features,
                      use_adjacency_features=args.use_adjacency_features,
                      do_not_use_original_features=args.do_not_use_original_features,
                      topk=args.topk,
                      toprank=args.toprank,
                      syn_num_node=1000,
                      syn_num_class=5,
                      syn_num_degree="5 10",
                      syn_feat_dim=10,
                      syn_label_homophily=hl,
                      syn_structural_homophily=hs,
                      syn_feature_homomophily=hf,
                      syn_train_ratio=args.syn_train_ratio,
                      syn_test_ratio=args.syn_test_ratio,
                      seed=seed,
                      prefer_feat=args.prefer_feat,
                      rewrite_basis=rewrite_basis,
                      rewrite_construct='knn',
                      # NOTE(review): hard-coded; presumably the k of the kNN
                      # construction -- confirm before changing.
                      rewrite_construct_param=3,
                      use_gsl=1
                      )

    labels = dataset.labels

    if feature_type=='raw':
        node_features = dataset.node_features
    else:
        node_features = get_aggregate_features(dataset.graph, dataset.node_features)

    agg_old_feat = get_aggregate_features(dataset.graph, node_features)
    agg_new_feat = get_aggregate_features(dataset.graph_new, node_features)

    # Seed the estimator as well: it is randomised, and torch.manual_seed
    # does not reach scikit-learn.
    mi_raw = mutual_info_classif(node_features, labels, random_state=seed).mean()
    mi_old = mutual_info_classif(agg_old_feat, labels, random_state=seed).mean()
    mi_new = mutual_info_classif(agg_new_feat, labels, random_state=seed).mean()

    acc_raw = mlp_classifier(node_features, labels)
    acc_old = mlp_classifier(agg_old_feat, labels)
    acc_new = mlp_classifier(agg_new_feat, labels)

    return mi_raw, mi_old, mi_new, acc_raw, acc_old, acc_new

def get_aggregate_features(graph, node_features):
    """Aggregate node features over a graph with degree-normalised weights.

    Self-loops are first removed and re-added, so every node ends up with
    exactly one. Each edge ``(u, v)`` then gets the weight
    ``1 / sqrt(deg(u) * deg(v) + 1)``, where ``deg`` is the out-degree
    including the self-loop, and the weighted source features are summed with
    ``ops.u_mul_e_sum``.

    Args:
        graph: DGL graph to aggregate over; not modified (the self-loop calls
            return new graphs).
        node_features: Per-node feature tensor.

    Returns:
        The aggregated feature tensor.
    """
    graph = graph.remove_self_loop()
    graph = graph.add_self_loop()
    degrees = graph.out_degrees().float()
    # NOTE(review): the "+ 1" is on top of the self-loop already counted in
    # the degrees -- confirm this extra smoothing is intended.
    degree_edge_products = ops.u_mul_v(graph, degrees, degrees) + 1
    coefs = 1 / degree_edge_products ** 0.5
    return ops.u_mul_e_sum(graph, node_features, coefs)

def mlp_classifier(X, y):
    """Score an MLP trained on one stratified half of the data on the other.

    The 50/50 split and the classifier both use ``random_state=1``; the
    classifier is limited to 300 iterations. Returns ``clf.score`` on the
    held-out half.
    """
    split = train_test_split(X, y, stratify=y, random_state=1, test_size=0.5)
    fit_X, held_out_X, fit_y, held_out_y = split
    model = MLPClassifier(random_state=1, max_iter=300)
    model = model.fit(fit_X, fit_y)
    return model.score(held_out_X, held_out_y)

def check_synthetic_mi_acc(feature_type):
    """Sweep label homophily and seeds, logging MI and accuracy to CSV.

    Calls ``synthetic_mutual_info`` for ``hl`` in 0.0, 0.1, ..., 1.0 and seeds
    0..9, prints each result, and writes all of them to
    ``results/theoretical/check_{feature_type}_feature_mi_acc3.csv``.
    """
    columns = ('seed', 'hl', 'mi_raw', 'mi_old', 'mi_new',
               'acc_raw', 'acc_old', 'acc_new')
    log_line = "seed {} hl {}, mi_raw {:.4f} mi_old {:.4f} mi_new {:.4f}, acc_raw {:.4f} acc_old {:.4f} acc_new {:.4f}"
    records = {name: [] for name in columns}

    for hl in np.arange(11)/10:
        for seed in range(0,10):
            metrics = synthetic_mutual_info(hl, seed, feature_type=feature_type)
            mi_raw, mi_old, mi_new, acc_raw, acc_old, acc_new = metrics
            row = (seed, hl, mi_raw, mi_old, mi_new, acc_raw, acc_old, acc_new)
            print(log_line.format(*row))
            for name, value in zip(columns, row):
                records[name].append(value)

    out_path = f"results/theoretical/check_{feature_type}_feature_mi_acc3.csv"
    pd.DataFrame(records).to_csv(out_path)

def plot_synthetic_mi_acc(type):
    """Plot mutual information and accuracy against label homophily.

    Reads ``results/theoretical/check_{type}_feature_mi_acc2.csv``, averages
    every metric over the rows sharing an ``hl`` value (0.0, 0.1, ..., 1.0)
    and draws two panels -- mutual information on the left, accuracy on the
    right -- each with a mean line and a +/- one standard deviation band per
    feature variant. The figure is saved as
    ``results/plot/{type}_mi_acc.png`` and ``.pdf``.

    NOTE(review): ``check_synthetic_mi_acc`` writes ``..._mi_acc3.csv`` while
    this reads ``..._mi_acc2.csv`` -- confirm which run is meant to be
    plotted.

    Args:
        type: Feature-type tag used in the input and output file names.
    """
    res = pd.read_csv(f"results/theoretical/check_{type}_feature_mi_acc2.csv")
    metric_keys = ['mi_raw', 'mi_old', 'mi_new', 'acc_raw', 'acc_old', 'acc_new']
    mp_label = {'mi_raw':r"I(B;Y)",'mi_old':r"I(H;Y)",'mi_new':r"I(H';Y)",
           'acc_raw':r"Acc(B,Y)",'acc_old':r"Acc(H,Y)",'acc_new':r"Acc(H',Y)"}

    xticks = np.arange(11)/10
    acc = {k: [] for k in metric_keys}
    std = {k: [] for k in metric_keys}
    for hl in xticks:
        # Tolerant match: the values went through a CSV round trip, so exact
        # float equality is not guaranteed.
        item = res[np.isclose(res['hl'], hl)]
        for k in metric_keys:
            acc[k].append(item[k].mean())
            std[k].append(item[k].std())
    for k in metric_keys:
        acc[k] = np.array(acc[k])
        std[k] = np.array(std[k])

    fig, axs = plt.subplots(1,2,figsize=(5,2.5))
    for idx,metric in enumerate(['mi','acc']):
        for k in metric_keys:
            if metric in k:
                axs[idx].plot(xticks,acc[k],label=mp_label[k])
                axs[idx].fill_between(xticks,acc[k]-std[k],acc[k]+std[k],alpha=0.2)
        axs[idx].set_xlim(0,1)
        axs[idx].legend(loc='upper left',fontsize = 9)
        axs[idx].set_xlabel("Homophily")
    axs[1].set_ylim(0.2,1)
    axs[0].set_ylabel("Mutual Information")
    axs[1].set_ylabel("Accuracy")
    fig.tight_layout(pad=0.2)
    fig.savefig(f"results/plot/{type}_mi_acc.png")
    fig.savefig(f"results/plot/{type}_mi_acc.pdf")
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)

def plot_gnn_rewriting_basis():
    """Draw one grouped bar chart per model comparing the rewriting bases.

    Reads ``results/archive/all_acc_gnn_plus_gsl.csv`` and, for each model in
    ``MODELS``, takes the ``do_not_rewrite`` rows. Every dataset in
    ``DATASETS`` gets a group of bars, one per entry of ``BASISES``, using the
    first matching row; negative values are clipped to 0. Figures are saved
    to ``results/plot/basis/{model}.png`` and ``.pdf``.
    """
    data = pd.read_csv("results/archive/all_acc_gnn_plus_gsl.csv")
    COLORS = ['#ff595e','#ffca3a','#8ac926','#1982c4','#6a4c93']

    x = np.arange(len(DATASETS))
    width = 0.15
    # Centre each group of bars on its tick, however many bases there are.
    centre = (len(BASISES) - 1) / 2
    for model in MODELS:
        item = data[(data['model']==model)&(data['rewrite_construct']=='do_not_rewrite')]
        fig, ax = plt.subplots(1,1,figsize=(6.6,5))
        for i,basis in enumerate(BASISES):
            basis_rows = item[item['rewrite_basis']==basis]
            acc_list = [max(0, basis_rows[dataset].iloc[0]) for dataset in DATASETS]
            offset = (i - centre)*width
            ax.bar(x+offset, acc_list, width, label=BASISES_SHOW[i], color=COLORS[i])
        ax.legend(loc='upper left', ncol=len(BASISES))
        ax.set_ylabel("Accuracy / AUC-ROC")
        ax.set_ylim(0.2,1)
        ax.set_xticks(x)
        ax.set_xticklabels(DATASETS_SHOW, rotation=45)
        fig.tight_layout()
        fig.savefig(f"results/plot/basis/{model}.png")
        fig.savefig(f"results/plot/basis/{model}.pdf")
        # One figure per model: release it before starting the next one.
        plt.close(fig)

def plot_gsl_visualization(data_root='/home/yilun/HOM_GNN/syn-heterophilous-graphs'):
    """Render adjacency heat-maps of the original and rewritten graphs.

    For the ``wisconsin`` dataset three kinds of adjacency matrix are drawn,
    with nodes ordered by label so each class forms a contiguous block:

    1. the original graph from ``{data_root}/data/{dataset}.npz``;
    2. a graph linking every pair of nodes that receive the same predicted
       class from the saved ``pretrain/{basis}/{dataset}_0.pt`` outputs
       (``mlp_feature`` and ``gcn_feature``), after dropping nodes whose top
       softmax probability is below 0.9995;
    3. the rewritten graphs stored in
       ``results/gsl_quality/{dataset}_{basis}_{construction}.pt`` for every
       entry of ``BASISES``.

    The share of edges joining same-label nodes is printed and shown in each
    title. Images go to ``results/plot/adj_visual/`` as ``.png`` and ``.pdf``.

    Args:
        data_root: Directory containing the ``data/`` folder with the
            dataset ``.npz`` files.
    """
    # Construction variants to draw. Variants previously toggled here:
    # cos_sim_graph_{0.1,0.5,1.0}, cos_sim_node_{0.1,0.5,1.0,5.0},
    # knn_{2.0,3.0,5.0,10.0}.
    CONSTRUCTION = [
        'cos_sim_graph_5.0',
    ]

    def plot_adj(adj,save_path,title=""):
        # Draw one adjacency matrix and save it under both extensions.
        fig, ax = plt.subplots(1,1,figsize=(3,3))
        ax.imshow(adj,interpolation='none',cmap='copper')
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.set_title(title,y=-0.2)
        fig.tight_layout()
        fig.savefig(f"results/plot/adj_visual/{save_path}.png")
        fig.savefig(f"results/plot/adj_visual/{save_path}.pdf")
        # Many figures are produced per call; release each one once saved.
        plt.close(fig)

    def same_value_adj(values):
        # 1 where two distinct nodes carry the same value, 0 elsewhere.
        adj = (values.unsqueeze(1)==values.unsqueeze(0)).int()
        return adj.fill_diagonal_(0)

    for dataset in ['wisconsin']:
        npz_path = os.path.join(data_root,'data', f'{dataset.replace("-", "_")}.npz')
        # torch.tensor copies the arrays, so the archive can be closed at once.
        with np.load(npz_path) as data:
            labels = torch.tensor(data['node_labels'])
            edges = torch.tensor(data['edges'])
        num_node = labels.shape[0]
        sort_idx = labels.sort()[1]

        # Convert once so every indexing operation below gets int64 indices.
        src0, tag0 = edges[:,0].long(), edges[:,1].long()
        num_hom_edge0 = labels[src0] == labels[tag0]
        hom0 = float(num_hom_edge0.sum()/num_hom_edge0.shape[0])
        print("{} {:.4f}".format(dataset,hom0))

        adj = torch.zeros((num_node,num_node))
        adj[src0,tag0] = 1
        adj[tag0,src0] = 1
        adj = adj.fill_diagonal_(0)
        adj = adj[sort_idx,:][:,sort_idx]
        plot_adj(adj, f"{dataset}_original", r"Original, $h_{edge}$="+"{:.2f}".format(hom0))

        label_adj = same_value_adj(labels)
        for basis in ['mlp_feature','gcn_feature']:
            prediction_data = torch.load(f"pretrain/{basis}/{dataset}_0.pt")
            prediction = prediction_data.max(dim=1)[1]
            uncertain = (prediction_data.softmax(dim=1).max(dim=1)[0]<0.9995)
            adj = same_value_adj(prediction)
            adj[uncertain,:] = 0
            adj[:,uncertain] = 0

            hom_pred = (label_adj*adj).sum()/adj.sum()
            print("{}\t {}\t {:.4f}".format(basis,adj.sum(),hom_pred))

            adj = adj[sort_idx,:][:,sort_idx]
            # The MLP sees features only, so its title must not mention A.
            title = r"${\hat{Y}}=GCN({X},{A})$" if basis=='gcn_feature' else r"${\hat{Y}}=MLP({X})$"
            plot_adj(adj, f"{dataset}_{basis}_cls", title+r", $h_{edge}$="+"{:.2f}".format(hom_pred))

        for i,basis in enumerate(BASISES):
            for cst in CONSTRUCTION:
                path = f"results/gsl_quality/{dataset}_{basis}_{cst}.pt"
                src, tag = torch.load(path)
                src, tag = src.long(), tag.long()
                # Self-loops are ignored for both the score and the picture.
                save_idx = src!=tag
                src = src[save_idx]
                tag = tag[save_idx]

                num_hom_edge = labels[src] == labels[tag]
                hom = float(num_hom_edge.sum()/num_hom_edge.shape[0])
                print("{}\t {}\t {}\t {:.4f}".format(cst,basis,src.shape[0],hom))

                adj = torch.zeros((num_node,num_node))
                adj[src,tag] = 1
                adj[tag,src] = 1
                adj = adj[sort_idx,:][:,sort_idx]
                title = r"$B=$"+BASISES_SHOW[i]+r", $h_{edge}$="+"{:.2f}".format(hom)
                plot_adj(adj, f"{dataset}_{basis}_{cst}",title)

if __name__=='__main__':
    # Script entry point. Only the synthetic MI/accuracy plots run by default;
    # the other stages are run by calling them here instead:
    #   rewriting_acc_compare()
    #   rewriting_acc_compare_0725()
    #   latex_baseline_plus_rewriting()
    #   latex_baseline_plus_rewriting_0901()
    #   check_synthetic_mi_acc('raw') / check_synthetic_mi_acc('agg')
    #   plot_gnn_rewriting_basis()
    #   plot_gsl_visualization()
    for feature_type in ('raw', 'agg'):
        plot_synthetic_mi_acc(feature_type)