import random
random.seed(10)
import numpy as np
import argparse
import os.path as osp
import random
#import nni
from torch_geometric.utils import negative_sampling
import torch
from torch_geometric.utils import dropout_adj, degree, to_undirected
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv
from simple_param.sp import SimpleParam
from pGRACE.model import Encoder, GRACE
from pGRACE.functional import drop_feature, drop_edge_weighted, \
    degree_drop_weights, \
    evc_drop_weights, pr_drop_weights, \
    feature_drop_weights, drop_feature_weighted_2, feature_drop_weights_dense
from pGRACE.eval import log_regression, MulticlassEvaluator
from pGRACE.utils import get_base_model, get_activation, \
    generate_split, compute_pr, eigenvector_centrality, feature_norm, load_pokec, label_classification,load_fb, link_prediction, print_statistics, sens_classification
from pGRACE.dataset import get_dataset
from torch_geometric.data import Data

def train():
    """Run one GRACE optimization step on two randomly augmented graph views.

    Relies on module-level state prepared in the ``__main__`` section:
    ``model``, ``optimizer``, ``args``, ``edges``, ``features``,
    ``drop_weights`` and ``feature_weights``.

    Returns:
        The loss of this step as a Python float (``loss.item()``).

    Raises:
        Exception: if ``args.drop_scheme`` is not one of
            'uniform', 'degree', 'evc' or 'pr'.
    """
    model.train()
    optimizer.zero_grad()

    # Per-view edge drop rates, keyed by view index.
    edge_rate = {1: args.drop_edge_rate_1, 2: args.drop_edge_rate_2}
    weighted_schemes = ['degree', 'evc', 'pr']

    def drop_edge(idx: int):
        # Build the edge index of view ``idx``; an idx other than 1/2 yields None.
        scheme = args.drop_scheme
        if scheme != 'uniform' and scheme not in weighted_schemes:
            raise Exception(f'undefined drop scheme: {args.drop_scheme}')
        if idx not in edge_rate:
            return None
        if scheme == 'uniform':
            return dropout_adj(edges, p=edge_rate[idx])[0]
        # Centrality-weighted edge dropping using the precomputed drop_weights.
        return drop_edge_weighted(edges, drop_weights, p=edge_rate[idx], threshold=0.7)

    edge_index_1 = drop_edge(1)
    edge_index_2 = drop_edge(2)

    # NOTE(review): for weighted schemes these uniform masks are discarded below.
    # They are still drawn so the random-number sequence matches earlier runs
    # (assuming drop_feature samples from torch's RNG — confirm).
    x_1 = drop_feature(features, args.drop_feature_rate_1)
    x_2 = drop_feature(features, args.drop_feature_rate_2)
    if args.drop_scheme in weighted_schemes:
        x_1 = drop_feature_weighted_2(features, feature_weights, args.drop_feature_rate_1)
        x_2 = drop_feature_weighted_2(features, feature_weights, args.drop_feature_rate_2)

    view_1 = model(x_1, edge_index_1)
    view_2 = model(x_2, edge_index_2)

    # Only 'Coauthor-Phy' uses a batched loss (batch_size=1024); others pass None.
    batch_size = 1024 if args.dataset == 'Coauthor-Phy' else None
    loss = model.loss(view_1, view_2, batch_size=batch_size)
    loss.backward()
    optimizer.step()
    return loss.item()
def test_lp():
    """Evaluate link prediction on three saved negative-edge splits and print mean/std.

    Uses the module-level ``model``, ``features``, ``edges`` (training edges),
    ``edges_t`` (held-out edges), ``sens`` and ``name``. Negative edges are
    loaded from ``neg_edges_tr{c}_{name}.npy`` and ``neg_edges_t{c}_{name}.npy``
    in the working directory for c = 1, 2, 3. For every metric key returned by
    ``link_prediction``, the mean and standard deviation across the three
    splits are passed to ``print_statistics``.
    """
    model.eval()
    # Inference only: do not build an autograd graph for the embeddings.
    with torch.no_grad():
        z = model(features, edges)

    results = []
    for c in range(1, 4):
        neg_train = np.load(f'neg_edges_tr{c}_{name}.npy')
        neg_test = np.load(f'neg_edges_t{c}_{name}.npy')
        results.append(link_prediction(z, edges, edges_t, neg_train, neg_test, sens))

    statistics = {}
    for key in results[0]:
        values = [res[key] for res in results]
        statistics[key] = {
            'mean': np.mean(values),
            'std': np.std(values)}
    print_statistics(statistics, 'Link prediction')

def test():
    """Evaluate the encoder's node embeddings with ``label_classification``.

    Embeds the module-level ``features``/``edges`` with ``model`` and passes
    the embeddings, ``labels`` and ``sens`` to ``label_classification`` with
    ``ratio=0.1``. Any value returned by ``label_classification`` is
    discarded; this function returns None.
    """
    model.eval()
    # Inference only: do not build an autograd graph for the embeddings.
    with torch.no_grad():
        z = model(features, edges)
    label_classification(z, labels, sens, ratio=0.1)




if __name__ == '__main__':
    # Everything below stays at module level on purpose: train(), test() and
    # test_lp() read model, optimizer, args, edges, features, labels, sens,
    # edges_t, name, drop_weights and feature_weights as module globals.
    parser = argparse.ArgumentParser()
    # NOTE(review): store_true with default=True makes this flag always True,
    # and it is never read; the device comes solely from --device.
    parser.add_argument('--no-cuda', action='store_true', default=True,
                        help='Disables CUDA training.')
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num-hidden', type=int, default=256)
    parser.add_argument('--num-proj-hidden', type=int, default=256)
    parser.add_argument('--activation', type=str, default='prelu')
    parser.add_argument('--drop-edge-rate-1', type=float, default=0.3)
    parser.add_argument('--drop-edge-rate-2', type=float, default=0.4)
    parser.add_argument('--drop-feature-rate-1', type=float, default=0.1)
    parser.add_argument('--drop-feature-rate-2', type=float, default=0.0)
    parser.add_argument('--tau', type=float, default=0.4)
    parser.add_argument('--epochs', type=int, default=400,
                        help='Number of epochs to train.')
    parser.add_argument('--drop-scheme', type=str, default='degree')
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument('--base-model', type=str, default='GCNConv')
    parser.add_argument('--dataset', type=str, default='pokec')
    parser.add_argument('--param', type=str, default='local:wikics.json')
    parser.add_argument('--seed', type=int, default=39788)
    parser.add_argument('--verbose', type=str, default='train,eval,final')
    parser.add_argument('--save_split', type=str, nargs='?')
    parser.add_argument('--load_split', type=str, nargs='?')
    # Unrecognized command-line options are ignored rather than rejected.
    args = parser.parse_known_args()[0]

    learning_rate = args.lr
    num_hidden = args.num_hidden
    num_proj_hidden = args.num_proj_hidden
    # Resolve the activation name now (unknown names fail fast with KeyError),
    # but instantiate it per repeat below: nn.PReLU holds a learnable weight,
    # and a single shared instance would carry trained state across repeats.
    make_activation = ({'relu': lambda: F.relu, 'prelu': nn.PReLU})[args.activation]
    base_model = ({'GCNConv': GCNConv})[args.base_model]
    num_layers = args.num_layers
    tau = args.tau
    num_epochs = args.epochs
    weight_decay = args.weight_decay
    # Drop rates and the drop scheme are read from ``args`` inside train().

    use_nni = args.param == 'nni'
    if use_nni and args.device != 'cpu':
        args.device = 'cuda'

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.set_num_threads(8)
    device = torch.device(args.device)

    # Pokec variants: --dataset value -> dataset name passed to load_pokec.
    pokec_datasets = {'pokec': 'region_job', 'pokec2': 'region_job_2'}
    # socfb variants: --dataset value -> (directory, .mat file, short name).
    fb_datasets = {
        'fbucsd': ('datasets/socfb-UCSD34', 'UCSD34.mat', 'ucsd'),
        'fbberkeley': ('datasets/socfb-Berkeley13', 'Berkeley13.mat', 'berkeley'),
    }
    is_fb = args.dataset in fb_datasets

    if args.dataset in pokec_datasets:
        dataset = pokec_datasets[args.dataset]
        sens_attr = "region"
        predict_attr = "I_am_working_in_field"
        path = "datasets/pokec_dataset/"
        edges, features, labels, sens = load_pokec(dataset, sens_attr, predict_attr, path)
    elif is_fb:
        path, dset, name = fb_datasets[args.dataset]
        edges_org, features, sens = load_fb(path, dset)
    else:
        # Fail fast: otherwise edges/features would be undefined further down.
        raise ValueError(f'unsupported dataset: {args.dataset}')
    features = feature_norm(features)

    # socfb datasets are evaluated on 5 saved edge orderings; Pokec runs once.
    repeat = 5 if is_fb else 1
    # Datasets whose feature drop weights use feature_drop_weights_dense.
    # NOTE(review): 'pokec2' is normalized like 'pokec' but is not listed, so it
    # uses feature_drop_weights; kept as-is to preserve results — confirm intent.
    dense_feature_datasets = {'pokec', 'fbberkeley', 'fbucsd'}

    results = []
    for r in range(repeat):
        if is_fb:
            # Reorder edges by the r-th saved permutation: first 90% train, last 10% test.
            edge_idx = np.load('lp_orders/' + args.dataset.replace('fb', '')
                               + '_edge_order' + str(r + 1) + '.npy')
            ordered = edges_org[edge_idx, :]
            split = int(0.9 * np.shape(ordered)[0])
            edges = torch.LongTensor(ordered[:split, :].T)
            edges_t = torch.LongTensor(ordered[split:, :].T)
            # As many negatives are requested as there are positive edges per split.
            # NOTE(review): test negatives are sampled against the test edges only,
            # so they may coincide with training edges — confirm this is intended.
            neg_edges_tr = negative_sampling(
                edge_index=edges,
                num_nodes=len(sens),
                num_neg_samples=edges.size(1),
            ).numpy().T
            neg_edges_t = negative_sampling(
                edge_index=edges_t,
                num_nodes=len(sens),
                num_neg_samples=edges_t.size(1),
            ).numpy().T
            edges_t = edges_t.to(device)
        else:
            labels = labels.to(device)
        edges = edges.to(device)
        features = features.to(device)
        sens = sens.to(device)

        # Fresh activation per repeat (see make_activation above).
        activation = make_activation()
        encoder = Encoder(features.shape[1], num_hidden, activation,
                          base_model=base_model, k=num_layers).to(device)
        model = GRACE(encoder, num_hidden, num_proj_hidden, tau).to(device)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        # Edge drop weights and node centrality for the chosen drop scheme;
        # train() reads drop_weights and feature_weights as globals.
        if args.drop_scheme == 'degree':
            drop_weights = degree_drop_weights(edges).to(device)
            node_c = degree(to_undirected(edges)[1])
        elif args.drop_scheme == 'pr':
            drop_weights = pr_drop_weights(edges, aggr='sink', k=200).to(device)
            node_c = compute_pr(edges)
        elif args.drop_scheme == 'evc':
            drop_weights = evc_drop_weights(Data(x=features, edge_index=edges)).to(device)
            node_c = eigenvector_centrality(Data(x=features, edge_index=edges))
        else:
            drop_weights = None
            node_c = None

        if node_c is None:
            feature_weights = torch.ones((features.size(1),)).to(device)
        else:
            weight_fn = (feature_drop_weights_dense
                         if args.dataset in dense_feature_datasets
                         else feature_drop_weights)
            feature_weights = weight_fn(features, node_c=node_c).to(device)

        for epoch in range(1, num_epochs + 1):
            loss = train()
            print(f'(T) | Epoch={epoch:03d}, loss={loss:.4f}')

        print("=== Final ===")
        if is_fb:
            model.eval()
            # Inference only: do not build an autograd graph for the embeddings.
            with torch.no_grad():
                z = model(features, edges)
            run_result = link_prediction(z, edges, edges_t, neg_edges_tr, neg_edges_t, sens)
            results.append(run_result)
            for key in results[0]:
                print(key + ' :', run_result[key])
        else:
            test()

    if is_fb:
        # Aggregate each metric across the repeats.
        statistics = {}
        for key in results[0]:
            values = [res[key] for res in results]
            statistics[key] = {
                'mean': np.mean(values),
                'std': np.std(values)}
        print_statistics(statistics, 'Link prediction')
