from models import GNN_Models
import numpy as np
import torch
import dgl.function as fn
from dgl import DGLGraph
import dgl
from dgl import data
import random
import os
import pickle
import networkx as nx
import argparse






        


def _build_parser():
    """Create the command-line parser for this experiment script.

    Every option takes a single value of the listed type and falls back to
    the listed default when the flag is omitted.
    """
    cli = argparse.ArgumentParser(description='single')

    # (flag, type, default). Registration order is kept exactly as listed:
    # it determines the attribute order of the parsed namespace, which the
    # report writer iterates over.
    option_table = (
        ('--dataset', str, 'cora'),
        ('--gnn', str, 'gcn'),
        ('--new', int, 0),
        ('--gpu', str, '0'),
        ('--hidden_size', int, 64),
        ('--wd', float, 5e-6),
        ('--dp', float, 0.8),
        ('--epochs', int, 400),
        ('--save_model', int, 0),
        ('--lr', float, 0.1),
        ('--num_seeds', int, 30),
        ('--ts', int, 1),
        ('--ms', int, 0),
        ('--feat_normalize', int, 1),
        ('--num_shuffle', int, 10),
        ('--activation', str, 'relu'),
        ('--balance', float, 0.5),
    )
    for flag, value_type, default_value in option_table:
        cli.add_argument(flag, type=value_type, default=default_value)
    return cli


parser = _build_parser()

# Parsed once at import time; the rest of the module reads this namespace.
args = parser.parse_args()

def load_obj(file_name):
    """Read one pickled object from ``file_name`` and return it.

    NOTE: unpickling can execute arbitrary code, so only point this at
    files that come from a trusted source.
    """
    with open(file_name, mode='rb') as stream:
        loaded = pickle.load(stream)
    return loaded



def read_data(data_name):
    """Load the four pickles of a dataset from the binary citation folder.

    The files ``<data_name>_adj.pkl``, ``<data_name>_features.pkl``,
    ``<data_name>_labels.pkl`` and ``<data_name>_tvt_nids.pkl`` are read, in
    that order, from ``../data/citation_networks_binary/``.

    Returns:
        The tuple ``(adj, features, labels, tvt_nids)`` of unpickled objects.
    """
    prefix = '../data/citation_networks_binary/' + data_name
    suffixes = ('_adj.pkl', '_features.pkl', '_labels.pkl', '_tvt_nids.pkl')
    adj, features, labels, tvt_nids = [load_obj(prefix + tail) for tail in suffixes]
    return adj, features, labels, tvt_nids




def new_read_data(data_name):
    """Load adjacency, features and labels of a dataset from ``../new_data/``.

    Unlike ``read_data`` this does not load any ``tvt_nids`` file; only
    ``<data_name>_adj.pkl``, ``<data_name>_features.pkl`` and
    ``<data_name>_labels.pkl`` are read, in that order.

    Returns:
        The tuple ``(adj, features, labels)`` of unpickled objects.
    """
    prefix = '../new_data/' + data_name
    adj, features, labels = [
        load_obj(prefix + tail)
        for tail in ('_adj.pkl', '_features.pkl', '_labels.pkl')
    ]
    return adj, features, labels




    



    
def run(data_name, adj, features, labels, tvt_nids, torch_seed, model_name,
        hidden_size, wd, dp, epochs, save_model, d=None, dev=None, lr=0.01,
        args=None):
    """Build one ``GNN_Models`` instance, fit it and return the fit result.

    Args:
        data_name: Dataset name; used only to build the output directory.
        adj, features, tvt_nids: Forwarded unchanged to ``GNN_Models``.
        labels: Converted with ``torch.LongTensor`` before being forwarded.
        torch_seed: Passed to ``GNN_Models`` as ``seed`` and embedded in the
            output directory name.
        model_name: Passed as ``model_name``; also the last path component
            of the model directory.
        hidden_size, wd, dp, epochs, lr: Hyper-parameters forwarded as
            ``hidden_size``, ``weight_decay``, ``dropout``, ``epochs`` and
            ``lr``; all of them are also embedded in the directory name.
        save_model: Truthy to create the model directory; also forwarded.
        d, dev: Unused; kept so existing callers keep working.
        args: Namespace providing ``activation``, ``feat_normalize`` and
            ``balance``. Required despite the ``None`` default.

    Returns:
        Whatever ``GNN_Models.fit()`` returns (callers treat it as an
        accuracy).

    Raises:
        ValueError: If ``args`` is ``None``.
    """
    # Fail fast with a clear message instead of an AttributeError on
    # ``None`` after directories have already been created.
    if args is None:
        raise ValueError(
            "run() requires 'args' with 'activation', 'feat_normalize' "
            "and 'balance' attributes"
        )

    labels = torch.LongTensor(labels)

    save_dir = (
        'N_original_results/'
        + 'torch_seed_{}/{}/lr_{}_hidden_size_{}_wd_{}_dp_{}_ep_{}/'.format(
            torch_seed, data_name, lr, hidden_size, wd, dp, epochs)
    )
    # The trailing slash is kept because the path is handed to GNN_Models
    # as a string prefix.
    model_dir = save_dir + model_name + '/'

    if save_model:
        # exist_ok avoids the check-then-create race when several runs
        # start at the same time.
        os.makedirs(model_dir, exist_ok=True)

    # NOTE(review): ``cuda`` is fixed to 0 here; neither ``dev`` nor the
    # ``--gpu`` flag is consulted -- confirm this is intended.
    gnn = GNN_Models(
        adj, features, labels, tvt_nids,
        cuda=0,
        hidden_size=hidden_size,
        epochs=epochs,
        seed=torch_seed,
        lr=lr,
        weight_decay=wd,
        dropout=dp,
        log=False,
        activation=args.activation,
        model_name=model_name,
        save_path=model_dir,
        save_model=save_model,
        feat_normalize=args.feat_normalize,
        balance=args.balance,
    )
    return gnn.fit()
  
    
    

def _run_all_seeds(data_name, adj, features, labels, tvt_nids, model_name):
    """Call ``run`` once per seed in ``range(args.num_seeds)``; return the results as a list."""
    if isinstance(labels, np.ndarray):
        labels = torch.LongTensor(labels)
    return [
        run(data_name, adj, features, labels, tvt_nids, seed, model_name,
            args.hidden_size, args.wd, args.dp, args.epochs, args.save_model,
            lr=args.lr, args=args)
        for seed in range(args.num_seeds)
    ]


def _write_report(data_name, model_name, multi_splits, accs, mean, std):
    """Pickle the raw results and append a text summary under ``experiment_results/``."""
    save_dir = 'experiment_results/' + data_name + '/'
    os.makedirs(save_dir, exist_ok=True)
    save_path = save_dir + model_name + '_ms_' + str(multi_splits)

    # Despite the ``.txt`` suffix this file holds a binary pickle; the name
    # is kept so existing readers of the results still find it.
    with open(save_path + '_results.txt', 'wb') as f:
        pickle.dump(accs, f)

    # Appended, so repeated invocations accumulate in one log file.
    with open(save_path, 'a') as f:
        f.write('==============================================================\n')
        for attr, value in vars(args).items():
            f.write(str(attr) + ': ' + str(value) + '\n')
        f.write('--------------------------\n')
        f.write('overall: mean: {}, std: {}\n'.format(mean, std))
        f.write('==============================================================\n')


def main():
    """Train the selected model over all seeds (and splits) and record results.

    With ``--ms`` set, each of the ``--num_shuffle`` split files
    ``../new_data/<dataset>_tvt_nids_<i>.pkl`` is paired with the data from
    ``new_read_data`` and trained with every seed. Otherwise a single split
    is used: split file 0 from ``../new_data/`` when ``--new`` is set, else
    the split returned by ``read_data``.

    The collected results are printed as mean/std, pickled, and summarised
    in a text log under ``experiment_results/<dataset>/``.
    """
    torch.manual_seed(187)

    data_name = args.dataset
    # Bound up front so the report can be written even when no split runs.
    model_name = args.gnn
    multi_splits = args.ms
    split_template = '../new_data/' + data_name + '_tvt_nids_{}.pkl'

    accs = []
    if multi_splits:
        for split_idx in range(args.num_shuffle):
            # Data is reloaded for every split, as before, so each split
            # starts from freshly unpickled objects.
            adj, features, labels = new_read_data(data_name)
            tvt_nids = load_obj(split_template.format(split_idx))
            accs.extend(_run_all_seeds(
                data_name, adj, features, labels, tvt_nids, model_name))
    else:
        if args.new:
            adj, features, labels = new_read_data(data_name)
            tvt_nids = load_obj(split_template.format(0))
        else:
            adj, features, labels, tvt_nids = read_data(data_name)
        # Previously this branch only loaded the data and never trained,
        # so the default invocation reported an empty result list.
        accs.extend(_run_all_seeds(
            data_name, adj, features, labels, tvt_nids, model_name))

    if accs:
        scores = np.asarray(accs)
        mean, std = scores.mean(), scores.std()
    else:
        # No runs requested (zero seeds or zero splits): report nan
        # explicitly rather than via NumPy's empty-slice warning.
        mean = std = float('nan')

    print(mean, std)
    _write_report(data_name, model_name, multi_splits, accs, mean, std)










        
            



            

            
            




if __name__ == "__main__":
    main()
