"""
This is a script for contexual SBM model and its dataset generator.
contains functions:
        ContextualSBM
        parameterized_Lambda_and_mu
        save_data_to_pickle
    class:
        dataset_ContextualSBM

"""
from selectors import EpollSelector
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
#from utils import random_planetoid_splits
from torch.nn import Sigmoid, SiLU
import argparse
from IPython import embed
import torch
from basic_gnn import GCN, MLP
import wandb
import pandas as pd
import seaborn as sns
from utils import NNNodeBenchmarker, ContextualSBM, MMD
from deepjdot import NNNodeBenchmarker_JDOT
from cdan import NNNodeBenchmarker_CDAN
import random
import pickle
from tqdm import tqdm

# Columns of the per-run results table.  'OT_C', 'cmd_z' and 'cmd_z_y' are
# declared but never computed by this script; they are kept so the table has
# the originally declared columns and are filled with NaN (see record_result).
PLOT_COLUMNS = ('loss', 'test_roc', 'p/q', 'val_roc', 'delta',
                'OT_C', 'cmd_z', 'cmd_z_y')


def build_arg_parser():
    """Return the command-line parser for this benchmark script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--phi', type=float, default=1)
    parser.add_argument('--epsilon', type=float, default=3.25)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--root', default='../data/')
    parser.add_argument('--name', default='cSBM_demo')
    parser.add_argument('--num_nodes', type=int, default=128)
    parser.add_argument('--num_features', type=int, default=16)
    parser.add_argument('--avg_degree', type=float, default=10)
    parser.add_argument('--bias_type', type=str, default='hybrid')
    parser.add_argument('--model_type', type=str, default='gcn')
    parser.add_argument('--num_samples', type=int, default=100)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--log', action='store_true')
    return parser


def new_plot_data():
    """Return an empty results accumulator: one list per PLOT_COLUMNS entry."""
    return {column: [] for column in PLOT_COLUMNS}


def record_result(plot_data, test_data, val_res, test_res):
    """Append the metrics of one train/test run to ``plot_data`` in place.

    Args:
        plot_data: dict mapping column name to a list, as built by
            ``new_plot_data``.
        test_data: object exposing ``p_q`` and ``delta`` attributes.
        val_res: mapping with a ``'rocauc_ovr'`` entry (validation result).
        test_res: mapping with ``'rocauc_ovr'`` and ``'logloss'`` entries.

    Every column receives exactly one value per call; columns that have no
    metric for this run get NaN.  This keeps all lists the same length, which
    ``pd.DataFrame`` requires when it is built from a dict of lists.
    """
    # Build the whole row first so a missing key cannot leave the
    # accumulator with columns of different lengths.
    row = {
        'p/q': test_data.p_q,
        'delta': test_data.delta,
        # Small epsilon guards against log(0).
        'loss': np.log(test_res['logloss'] + 1e-8),
        'val_roc': val_res['rocauc_ovr'],
        'test_roc': test_res['rocauc_ovr'],
    }
    for column, values in plot_data.items():
        values.append(row.get(column, float('nan')))


def load_data_specs(path):
    """Unpickle and return the object stored at ``path``.

    The file handle is closed as soon as loading finishes.

    NOTE(security): ``pickle`` can execute arbitrary code while loading;
    only point this at dataset files you generated yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


if __name__ == '__main__':
    args = build_arg_parser().parse_args()

    # Fixed seeds for torch and numpy.
    torch.manual_seed(1)
    np.random.seed(1)
    device = torch.device(f'cuda:{args.gpu}' if args.gpu > -1 else "cpu")

    # NOTE(review): --model_type and --num_layers are parsed but not applied;
    # the JDOT benchmark below always uses a 2-layer GCN.  Confirm whether
    # the flags should be wired in.
    GCN_benchmark = NNNodeBenchmarker_JDOT(
        arch='I-GCN',
        model_class=GCN,
        benchmark_params={'lr': args.lr, 'epochs': 50},
        h_params={
            'in_channels': args.num_features,
            'hidden_channels': 16,
            'num_layers': 2,
            'dropout': 0.0,
            'out_channels': 2,
            'act': SiLU(),
        },
        device=device,
    )

    plot_data = new_plot_data()
    run = None  # passed through to train(); no run object is created here
    data_specs = load_data_specs(
        f'dataset/csbm_{args.bias_type}_repeat_{args.num_samples}.p')

    graph_pairs = zip(data_specs['train_graphs'], data_specs['test_graphs'])
    for train_graph, test_graph in tqdm(graph_pairs, total=args.num_samples):
        data, test_data = train_graph.to(device), test_graph.to(device)
        GCN_benchmark.reset_parameters()

        # Train/val masks come from the source graph, the test mask from the
        # target graph.
        GCN_benchmark.SetMasks(data.train_mask, data.val_mask, test_data.test_mask)

        # The target graph is passed as the second argument (the original
        # comment notes that None there would mean "no adaptation").
        losses, val_res = GCN_benchmark.train(data, test_data, 'logloss', True, run)
        test_res = GCN_benchmark.test(test_data, test_on_val=False, da=True)

        record_result(plot_data, test_data, val_res, test_res)

    plot_df = pd.DataFrame(data=plot_data)

    # Drop into an interactive shell to inspect plot_df and the last run.
    embed()


    