from selectors import EpollSelector
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
#from utils import random_planetoid_splits
from torch.nn import Sigmoid, SiLU
import argparse
from IPython import embed
import torch
from basic_gnn import GCN, MLP
import wandb
import pandas as pd
import seaborn as sns
from utils import NNNodeBenchmarker, ContextualSBM, MMD
import random
import pickle
from tqdm import tqdm

# Kinds of train->test distribution shift this script can generate.
BIAS_TYPES = ('structure', 'translation', 'hybrid')


def rotate_toward(u, v, angle):
    """Rotate ``u`` by ``angle`` radians toward ``v`` inside the plane span(u, v).

    The component of ``v`` parallel to ``u`` is removed (one Gram-Schmidt
    step), so the result keeps the norm of ``u`` and forms exactly ``angle``
    with it.

    Args:
        u: 1-D array, the vector to rotate; must be non-zero.
        v: 1-D array of the same length giving the rotation direction; must
            not be parallel to ``u`` (otherwise the orthogonal part is zero).
        angle: rotation angle in radians.

    Returns:
        A new 1-D array with the same norm as ``u``.
    """
    norm_u = np.linalg.norm(u)
    unit_u = u / norm_u
    v_perp = v - np.dot(unit_u, v) * unit_u
    unit_v_perp = v_perp / np.linalg.norm(v_perp)
    return norm_u * (np.cos(angle) * unit_u + np.sin(angle) * unit_v_perp)


def sample_shift(bias_type, u, num_features, p_q):
    """Draw the distribution shift applied to one test graph.

    Args:
        bias_type: one of ``BIAS_TYPES``.
        u: feature-mean direction used for the matching training graph.
        num_features: feature dimension (length of ``u``).
        p_q: p/q ratio to keep when the shift is not structural.

    Returns:
        ``(test_p_q, u_prime, delta)`` to pass to the test ``ContextualSBM``.

    Raises:
        ValueError: if ``bias_type`` is not recognised.
    """
    if bias_type == 'structure':
        # Structural shift only: resample p/q in 1..10, features unchanged.
        return random.randrange(10) + 1, u, 0.0
    if bias_type == 'translation':
        # Feature translation only: delta in {0.1, ..., 1.0}.
        return p_q, u, (random.randrange(10) + 1) * 0.1
    if bias_type == 'hybrid':
        # Rotate the feature direction by up to 60 degrees (pi/3 rad), then
        # also translate. RNG call order matches the original script.
        angle = random.uniform(0, 60 / 90 * np.pi / 2)
        v = np.random.normal(0, 1, [num_features])
        u_prime = rotate_toward(u, v, angle)
        delta = (random.randrange(10) + 1) * 0.1
        return p_q, u_prime, delta
    raise ValueError(f"unknown bias_type {bias_type!r}; expected one of {BIAS_TYPES}")


if __name__ == '__main__':
    # Generate paired cSBM graphs (shifted test graph per training graph) and
    # pickle them as {'train_graphs': [...], 'test_graphs': [...]}.
    from pathlib import Path

    parser = argparse.ArgumentParser()
    parser.add_argument('--phi', type=float, default=1)
    parser.add_argument('--epsilon', type=float, default=3.25)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--root', default='../data/')
    parser.add_argument('--name', default='cSBM_demo')
    parser.add_argument('--num_nodes', type=int, default=128)
    parser.add_argument('--num_features', type=int, default=128)
    parser.add_argument('--avg_degree', type=float, default=10)
    parser.add_argument('--bias_type', type=str, default='hybrid', choices=BIAS_TYPES)
    parser.add_argument('--num_samples', type=int, default=100)
    parser.add_argument('--log', action='store_true')
    parser.add_argument('--seed', type=int, default=None,
                        help='if given, seeds the Python, NumPy and torch RNGs')
    parser.add_argument('--out_dir', default='dataset',
                        help='directory the pickled dataset is written to')
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # Create the output directory before the (slow) generation loop so a bad
    # path fails immediately rather than after all graphs are built.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_p_q = 5  # p/q ratio of every training graph
    train_graphs, test_graphs = [], []
    for _ in tqdm(range(args.num_samples)):
        u = np.random.normal(0, 1 / np.sqrt(args.num_features), [args.num_features])
        test_p_q, u_prime, delta = sample_shift(args.bias_type, u, args.num_features, train_p_q)

        train_graphs.append(ContextualSBM(args.num_nodes, args.avg_degree, train_p_q,
                                          args.num_features, mu=8, train_percent=0.5, u=u))
        test_graphs.append(ContextualSBM(args.num_nodes, args.avg_degree, test_p_q,
                                         args.num_features, mu=8, train_percent=0.0,
                                         u=u_prime, delta=delta))

    # The filename records the actual feature dimension (previously hard-coded to 128).
    out_path = out_dir / f'csbm_{args.bias_type}_repeat_{args.num_samples}_d_{args.num_features}.p'
    with open(out_path, 'wb') as f:
        pickle.dump({'train_graphs': train_graphs, 'test_graphs': test_graphs}, f)