from selectors import EpollSelector
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
#from utils import random_planetoid_splits
from torch.nn import Sigmoid, SiLU
import argparse
from IPython import embed
import torch
from basic_gnn import GCN, MLP
import wandb
import pandas as pd
import seaborn as sns
from utils import NNNodeBenchmarker, ContextualSBM, MMD
from deepjdot import NNNodeBenchmarker_JDOT
#from cdan import NNNodeBenchmarker_CDAN
import random
import pickle
from tqdm import tqdm
from torch_geometric.data import Data
from syn_dataset import CustomDataset
from torch_geometric.utils import from_scipy_sparse_matrix
from utils import index_to_mask

# Values accepted by --method; anything else is rejected at argument-parsing time.
METHODS = ('cmd', 'cdan', 'gcn', 'gjdot')


def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    recognises the usual spellings instead.

    Args:
        value: The raw command-line string (an actual ``bool`` is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') used to give.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script.

    Option names and defaults are unchanged; ``--method`` is now restricted to
    ``METHODS`` and ``--da`` is parsed with :func:`str2bool`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--phi', type=float, default=1)
    parser.add_argument('--epsilon', type=float, default=3.25)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--dataset', default='syn-cora')
    parser.add_argument('--method', default='gcn', choices=METHODS)
    parser.add_argument('--model_type', type=str, default='gcn')
    # NOTE(review): --da is parsed but never read below; the evaluation call
    # passes da=True unconditionally. Confirm whether it should be wired in.
    parser.add_argument('--da', type=str2bool, default=True)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--log', action='store_true')
    return parser


def load_graph(root, name, p_q=None):
    """Load one ``CustomDataset`` graph and wrap it in a ``Data`` object.

    Args:
        root: Dataset root handed to ``CustomDataset`` (the ``--dataset`` value).
        name: Dataset name such as ``"h0.30-r2"``.
            # presumably "h" is the homophily level and "r" a replicate id --
            # confirm against CustomDataset.
        p_q: When given, the graph is treated as an evaluation graph: a
            ``test_mask`` is attached and the value is stored as ``graph.p_q``.
            When ``None`` only the train/val masks are attached.

    Returns:
        A ``Data`` object with dense float features, edge index and labels.
    """
    dataset = CustomDataset(root=root, name=name, setting="gcn", seed=15)
    # The edge weights returned alongside the index are not used by this script.
    edge_index, _ = from_scipy_sparse_matrix(dataset.adj)
    graph = Data(x=torch.FloatTensor(dataset.features.toarray()),
                 edge_index=torch.LongTensor(edge_index),
                 y=torch.LongTensor(dataset.labels))
    graph.train_mask = index_to_mask(torch.LongTensor(dataset.idx_train), graph.num_nodes)
    graph.val_mask = index_to_mask(torch.LongTensor(dataset.idx_val), graph.num_nodes)
    if p_q is not None:  # 0.0 is a valid value, so compare against None explicitly
        graph.test_mask = index_to_mask(torch.LongTensor(dataset.idx_test), graph.num_nodes)
        graph.p_q = p_q
    return graph


def load_test_graphs(root):
    """Load the 30 evaluation graphs: levels 0.0 .. 0.9 times replicates r1-r3.

    The order matches the original script (for each level: r1, r2, r3).
    """
    graphs = []
    for step in range(10):
        p_q = step * 0.1
        for replicate in (1, 2, 3):
            graphs.append(load_graph(root, "h%.2f-r%d" % (p_q, replicate), p_q=p_q))
    return graphs


def build_benchmarker(method, train_graph, lr, num_layers, device):
    """Create the benchmarker object for ``method``.

    Model sizes are taken from ``train_graph`` (the graph the model is fitted
    on) rather than from whichever test graph happened to be loaded last.

    Args:
        method: One of ``METHODS``.
        train_graph: Training ``Data`` object; supplies feature and class counts.
        lr: Learning rate placed in ``benchmark_params``.
        num_layers: Number of GCN layers.
        device: Torch device passed through to the benchmarker.

    Raises:
        RuntimeError: If ``method == 'cdan'`` but ``NNNodeBenchmarker_CDAN`` has
            not been imported.
        ValueError: If ``method`` is not one of ``METHODS``.
    """
    if method in ('cmd', 'gcn'):
        benchmarker_cls, arch = NNNodeBenchmarker, 'I_GCN'
    elif method == 'gjdot':
        benchmarker_cls, arch = NNNodeBenchmarker_JDOT, 'I_GCN'
    elif method == 'cdan':
        try:
            benchmarker_cls = NNNodeBenchmarker_CDAN
        except NameError as err:
            # The cdan import at the top of the file is commented out.
            raise RuntimeError(
                "method 'cdan' needs NNNodeBenchmarker_CDAN; enable "
                "'from cdan import NNNodeBenchmarker_CDAN' to use it"
            ) from err
        arch = 'I-GCN'
    else:
        raise ValueError(f"unknown method {method!r}; expected one of {METHODS}")

    h_params = {
        'in_channels': train_graph.x.shape[1],
        'hidden_channels': 16,
        'dropout': 0.0,
        'num_layers': num_layers,
        # Plain int rather than a 0-dim tensor for the layer size.
        'out_channels': int(train_graph.y.max()) + 1,
        'act': SiLU(),
    }
    return benchmarker_cls(arch=arch, model_class=GCN,
                           benchmark_params={'lr': lr, 'epochs': 50},
                           h_params=h_params, device=device)


def run_benchmark(benchmarker, train_graph, test_graphs, method, device):
    """Train on ``train_graph`` and evaluate once per graph in ``test_graphs``.

    The model parameters are reset before every evaluation graph, so each row
    of the result comes from an independent training run.

    Returns:
        A ``pandas.DataFrame`` with columns ``loss`` (log of the test logloss),
        ``test_roc``, ``p/q``, ``val_roc`` and ``accuracy``.
    """
    plot_data = {'loss': [], 'test_roc': [], 'p/q': [], 'val_roc': [], 'accuracy': []}
    run = None  # forwarded as the last positional argument of train(); always None here
    # The training graph is the same for every iteration, so move it once.
    source = train_graph.to(device)
    for target in tqdm(test_graphs, total=len(test_graphs)):
        target = target.to(device)
        benchmarker.reset_parameters()
        benchmarker.SetMasks(source.train_mask, source.val_mask, target.test_mask)
        # None as second parameter meaning no adaptation
        adaptation_graph = None if method == 'gcn' else target
        _, val_res = benchmarker.train(source, adaptation_graph, 'logloss', True, run)
        test_res = benchmarker.test(target, test_on_val=False, da=True)

        plot_data['p/q'].append(target.p_q)
        plot_data['loss'].append(np.log(test_res['logloss'] + 1e-8))
        plot_data['val_roc'].append(val_res['rocauc_ovr'])
        plot_data['test_roc'].append(test_res['rocauc_ovr'])
        plot_data['accuracy'].append(test_res['accuracy'])
    return pd.DataFrame(data=plot_data)


def save_results(plot_df, path):
    """Pickle ``plot_df`` to ``path``, creating the parent directory if needed."""
    import os  # local import: the module's import block does not provide os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as handle:
        pickle.dump(plot_df, handle)


def main():
    """Parse arguments, run the benchmark over all test graphs and save the results."""
    args = build_parser().parse_args()
    device = torch.device(f'cuda:{args.gpu}' if args.gpu > -1 else "cpu")

    train_graph = load_graph(args.dataset, "h1.00-r1")
    testdata_list = load_test_graphs(args.dataset)
    GCN_benchmark = build_benchmarker(args.method, train_graph, args.lr,
                                      args.num_layers, device)
    plot_df = run_benchmark(GCN_benchmark, train_graph, testdata_list,
                            args.method, device)
    save_results(plot_df, f"figure/thm1/{args.dataset}_{args.method}_result.pkl")


if __name__ == '__main__':
    main()