import dgl
import torch
import numpy as np
import scipy.sparse as sp

# Module-level caches for the samplers / transforms used by the pretext-task
# functions below. Every entry starts out as None and is filled in by the
# corresponding P_* function the first time it runs.

# Link prediction (P_link_prediction): negative-edge sampler and the SAINT
# subgraph sampler.
neg_sampler = None
neg_sampler_g = None

# Graph matching (P_graph_matching): subgraph sampler and list of augmentations.
gm_samplers = dict.fromkeys(['gm_sampler', 'augmentations'])

# DGI (P_dgi): SAINT subgraph sampler.
dgi_sampler = None

# METIS (P_metis): feature-masking transform.
metis_samplers = {'feature_dropper': None}

# PAR (P_par): feature-masking transform.
par_samplers = {'feature_dropper': None}

# GRACE (P_grace): subgraph sampler plus edge / feature droppers.
# P_grace also stores its per-view transforms under '*_1' / '*_2' keys.
grace_samplers = dict.fromkeys(['graph_sampler', 'edge_dropper', 'feature_dropper'])

def P_link_prediction(g, batch_size, neg_ratio=1, mask_ratio=0.4, device='cpu', use_saint=False):
    """Build one link-prediction pretext batch from ``g``.

    ``batch_size * 30`` edge IDs are drawn uniformly *with replacement* as
    positives. The 2-hop in-subgraph around their endpoints is extracted,
    negatives are drawn from it, the positive edges are removed from it and
    its ``'feat'`` node features go through ``feature_masking``.

    Args:
        g: DGL graph with a ``'feat'`` node feature.
        batch_size: Scales the number of positive edges (and the SAINT node
            budget) by a factor of 30.
        neg_ratio: Passed to ``GlobalUniform`` when the sampler is created.
        mask_ratio: Ratio forwarded to ``feature_masking``.
        device: Unused; kept for interface compatibility.
        use_saint: If true, ``g`` is first replaced by a SAINT node-sampled
            subgraph.

    Returns:
        ``(sg, pos_u, pos_v, neg_u, neg_v)``: the subgraph without the
        positive edges, the endpoints of the positive edges (as returned by
        ``sg.edges()``), and the endpoints produced by the negative sampler.

    Note:
        The module-level samplers are created on the first call only, so the
        ``neg_ratio`` / ``batch_size`` of later calls do not change them.
    """
    global neg_sampler, neg_sampler_g
    if neg_sampler is None:
        neg_sampler = dgl.dataloading.negative_sampler.GlobalUniform(neg_ratio, True)
        neg_sampler_g = dgl.dataloading.SAINTSampler('node', budget=batch_size*30)
    if use_saint:
        g = neg_sampler_g.sample(g, 0)

    # Sample a batch of positive edges (duplicates are possible).
    positive_eid = torch.randint(0, g.number_of_edges(), (batch_size*30,))
    # Nodes touched by these edges, and the 2-hop subgraph around them.
    involved_nodes = dgl.edge_subgraph(g, positive_eid).ndata[dgl.NID]
    sg = dgl.khop_in_subgraph(g, involved_nodes, 2)[0]
    all_eid = sg.edata[dgl.EID]

    # Map each positive edge ID to its position in the subgraph. Broadcasting
    # yields the same P x E mask as tiling `all_eid`, without the tiled copy.
    matches = positive_eid.unsqueeze(1) == all_eid.unsqueeze(0)
    new_positive_eid = torch.nonzero(matches, as_tuple=True)[1]
    u, v = sg.edges()
    pos_u, pos_v = u[new_positive_eid], v[new_positive_eid]

    # The negative sampler's call signature takes edge IDs; pass the positive
    # edge IDs (same count as `pos_u`) rather than node IDs.
    neg_u, neg_v = neg_sampler(sg, new_positive_eid)

    sg.remove_edges(new_positive_eid)
    sg.ndata['feat'] = feature_masking(sg.ndata['feat'], mask_ratio, 2)
    return sg, pos_u, pos_v, neg_u, neg_v

# For this task we use in-batch negatives.
def P_graph_matching(g, batch_size, khop=2, device='cpu', sub_graph_size=100,
    node_drop_ratio=0.1, edge_drop_ratio=0.1, feat_drop_ratio=0.1):
    """Build two augmented views of ``batch_size`` sampled subgraphs.

    Each subgraph is drawn with a SAINT node sampler. One augmentation type
    (edge drop, node drop or feature mask) is picked at random per subgraph
    and applied independently to the subgraph and to a clone of it. No
    explicit negative samples are returned.

    Args:
        g: DGL graph with a ``'feat'`` node feature.
        batch_size: Number of subgraphs per view.
        khop: Unused; kept for interface compatibility.
        device: Unused; kept for interface compatibility.
        sub_graph_size: Node budget of the SAINT sampler.
        node_drop_ratio: Probability given to ``DropNode``.
        edge_drop_ratio: Probability given to ``DropEdge``.
        feat_drop_ratio: Probability given to ``FeatMask``.

    Returns:
        ``(bg1, bg2)``: the two views, each batched with ``dgl.batch``.

    Note:
        The sampler and augmentations are created on the first call only, so
        size / ratio arguments of later calls do not change them.
    """
    global gm_samplers
    if gm_samplers['gm_sampler'] is None:
        gm_samplers['gm_sampler'] = dgl.dataloading.SAINTSampler('node', budget=sub_graph_size)
        gm_samplers['augmentations'] = [dgl.transforms.DropEdge(edge_drop_ratio),
                                        dgl.transforms.DropNode(node_drop_ratio),
                                        dgl.transforms.FeatMask(feat_drop_ratio, ['feat'])]
    sampler = gm_samplers['gm_sampler']
    augmentations = gm_samplers['augmentations']

    # One augmentation index per subgraph, shared by both views.
    aug_type = np.random.choice(3, batch_size, replace=True)
    graphs_v1 = [sampler.sample(g, i) for i in range(batch_size)]
    graphs_v2 = [sub.clone() for sub in graphs_v1]

    # Keep the graph each transform returns: a transform may hand back a new
    # graph instead of mutating its input, in which case discarding the
    # result silently skips the augmentation.
    graphs_v1 = [augmentations[kind](sub) for kind, sub in zip(aug_type, graphs_v1)]
    graphs_v2 = [augmentations[kind](sub) for kind, sub in zip(aug_type, graphs_v2)]
    return dgl.batch(graphs_v1), dgl.batch(graphs_v2)


def P_dgi(g, batch_size, k_hop=3, device='cpu',
          feat_drop_ratio=0.3, batch_size_multiplier=10, use_saint=False):
    """Build one DGI pretext batch: a graph, its features and row-shuffled features.

    If ``g`` has more than ``batch_size * batch_size_multiplier`` nodes it is
    reduced to a subgraph, either with a SAINT node sampler or by taking the
    ``k_hop`` in-subgraph of that many randomly chosen nodes. Otherwise a
    clone of ``g`` is used.

    Args:
        g: DGL graph with a ``'feat'`` node feature.
        batch_size: With ``batch_size_multiplier``, sets the node budget.
        k_hop: Hop count for the non-SAINT subgraph extraction.
        device: Unused; kept for interface compatibility.
        feat_drop_ratio: Unused; kept for interface compatibility.
        batch_size_multiplier: Multiplier applied to ``batch_size``.
        use_saint: Use the SAINT sampler instead of random seeds + k-hop.

    Returns:
        ``(graph, feat, permuted_feat)`` where ``permuted_feat`` is
        ``feature_perm(feat)``.

    Note:
        The SAINT sampler is created once; its budget is that of the first
        call that needed it.
    """
    global dgi_sampler
    budget = batch_size * batch_size_multiplier
    if g.number_of_nodes() <= budget:
        g = g.clone()
    elif use_saint:
        if dgi_sampler is None:
            dgi_sampler = dgl.dataloading.SAINTSampler('node', budget=budget)
        g = dgi_sampler.sample(g, 0)
    else:
        node_idx = np.random.choice(g.number_of_nodes(), budget, replace=False)
        g = dgl.khop_in_subgraph(g, node_idx, k=k_hop)[0]
    return g, g.ndata['feat'], feature_perm(g.ndata['feat'])


def P_metis(g, batch_size, feat_drop_ratio=0.3, k_hop=3, device='cpu'):
    """Build one METIS pretext batch: a feature-masked subgraph and its labels.

    Up to ``batch_size`` distinct seed nodes are drawn at random, their
    ``k_hop`` in-subgraph is extracted and its ``'feat'`` node features are
    masked with ``FeatMask``.

    Args:
        g: DGL graph with ``'feat'`` and ``'node_assignment'`` node data.
        batch_size: Number of seed nodes (capped at the number of nodes).
        feat_drop_ratio: Probability given to ``FeatMask`` on first use.
        k_hop: Hop count for the subgraph extraction.
        device: Unused; kept for interface compatibility.

    Returns:
        ``(subgraph, subgraph.ndata['feat'], subgraph.ndata['node_assignment'])``.

    Note:
        The transform is created on the first call only, so the
        ``feat_drop_ratio`` of later calls does not change it.
    """
    global metis_samplers
    if metis_samplers['feature_dropper'] is None:
        metis_samplers['feature_dropper'] = dgl.transforms.FeatMask(feat_drop_ratio, ['feat'])
    # Sampling without replacement raises if more seeds than nodes are
    # requested, so cap the request at the graph size.
    n_nodes = g.number_of_nodes()
    node_idx = np.random.choice(n_nodes, min(batch_size, n_nodes), replace=False)
    g = dgl.khop_in_subgraph(g, node_idx, k=k_hop)[0]
    # Use the graph returned by the transform rather than relying on it
    # mutating its argument.
    g = metis_samplers['feature_dropper'](g)
    return g, g.ndata['feat'], g.ndata['node_assignment']


def P_par(g, batch_size, feat_drop_ratio=0.3, k_hop=3, device='cpu'):
    """Build one PAR pretext batch: a feature-masked subgraph and its labels.

    Up to ``batch_size * 2`` distinct seed nodes are drawn at random, their
    ``k_hop`` in-subgraph is extracted and its ``'feat'`` node features are
    masked with ``FeatMask``.

    Args:
        g: DGL graph with ``'feat'`` and ``'par'`` node data.
        batch_size: Half the number of seed nodes (the total is capped at the
            number of nodes).
        feat_drop_ratio: Probability given to ``FeatMask`` on first use.
        k_hop: Hop count for the subgraph extraction.
        device: Unused; kept for interface compatibility.

    Returns:
        ``(subgraph, subgraph.ndata['feat'], subgraph.ndata['par'])``.

    Note:
        The transform is created on the first call only, so the
        ``feat_drop_ratio`` of later calls does not change it.
    """
    global par_samplers
    if par_samplers['feature_dropper'] is None:
        par_samplers['feature_dropper'] = dgl.transforms.FeatMask(feat_drop_ratio, ['feat'])
    # Sampling without replacement raises if more seeds than nodes are
    # requested, so cap the request at the graph size.
    n_nodes = g.number_of_nodes()
    node_idx = np.random.choice(n_nodes, min(batch_size*2, n_nodes), replace=False)
    g = dgl.khop_in_subgraph(g, node_idx, k=k_hop)[0]
    # Use the graph returned by the transform rather than relying on it
    # mutating its argument.
    g = par_samplers['feature_dropper'](g)
    return g, g.ndata['feat'], g.ndata['par']

def P_grace(g, batch_size, edge_drop_ratio=0.3, feat_drop_ratio=0.3, k_hop=3, device='cpu',
            batch_size_multiplier=10, use_saint=False):
    """Build the two augmented views used by the GRACE pretext task.

    If ``g`` has more than ``batch_size * batch_size_multiplier`` nodes it is
    reduced to a subgraph (SAINT node sampler, or the ``k_hop`` in-subgraph of
    that many random nodes); otherwise a clone of ``g`` is used. The result
    and a clone of it are each passed through their own ``DropEdge`` and
    ``FeatMask`` transforms.

    Args:
        g: DGL graph with a ``'feat'`` node feature.
        batch_size: With ``batch_size_multiplier``, sets the node budget.
        edge_drop_ratio: Probability given to both ``DropEdge`` transforms.
        feat_drop_ratio: Probability given to both ``FeatMask`` transforms.
        k_hop: Hop count for the non-SAINT subgraph extraction.
        device: Unused; kept for interface compatibility.
        batch_size_multiplier: Multiplier applied to ``batch_size``.
        use_saint: Use the SAINT sampler instead of random seeds + k-hop.

    Returns:
        ``(g1, g1.ndata['feat'], g2, g2.ndata['feat'])`` for the two views.

    Note:
        The sampler and transforms are created on the first call only, so
        ratio / budget arguments of later calls do not change them.
    """
    global grace_samplers
    budget = batch_size * batch_size_multiplier
    if grace_samplers['graph_sampler'] is None:
        grace_samplers['graph_sampler'] = dgl.dataloading.SAINTSampler('node', budget=budget)
        grace_samplers['edge_dropper_1'] = dgl.transforms.DropEdge(edge_drop_ratio)
        grace_samplers['feature_dropper_1'] = dgl.transforms.FeatMask(feat_drop_ratio, ['feat'])
        grace_samplers['edge_dropper_2'] = dgl.transforms.DropEdge(edge_drop_ratio)
        grace_samplers['feature_dropper_2'] = dgl.transforms.FeatMask(feat_drop_ratio, ['feat'])

    if g.number_of_nodes() > budget:
        if use_saint:
            g1 = grace_samplers['graph_sampler'].sample(g, 0)
        else:
            node_idx = np.random.choice(g.number_of_nodes(), budget, replace=False)
            g1 = dgl.khop_in_subgraph(g, node_idx, k=k_hop)[0]
    else:
        g1 = g.clone()
    g2 = g1.clone()

    # Keep the graphs returned by the transform chain: a transform may return
    # a new graph instead of mutating its input, in which case discarding the
    # result would leave both views identical and unaugmented.
    g1 = grace_samplers['feature_dropper_1'](grace_samplers['edge_dropper_1'](g1))
    g2 = grace_samplers['feature_dropper_2'](grace_samplers['edge_dropper_2'](g2))
    return g1, g1.ndata['feat'], g2, g2.ndata['feat']

def edge_dropping(g, drop_ratio, augmentation_type):
    """Return a copy of ``g`` with a random fraction of its edges removed.

    The augmentation is applied only when ``augmentation_type`` is 0; for any
    other value ``g`` itself is returned untouched.

    Args:
        g: Graph exposing ``clone``, ``number_of_edges`` and ``remove_edges``.
        drop_ratio: Fraction of edges to remove (truncated to a whole count).
        augmentation_type: Selector; 0 enables edge dropping.

    Returns:
        A pruned clone of ``g`` when enabled, otherwise ``g``.
    """
    if augmentation_type != 0:
        return g

    pruned = g.clone()
    total = pruned.number_of_edges()
    doomed = np.random.choice(total, int(total * drop_ratio), replace=False)
    pruned.remove_edges(doomed)
    return pruned

def node_dropping(g, drop_ratio, augmentation_type):
    """Return a copy of ``g`` with a random fraction of its nodes removed.

    The augmentation is applied only when ``augmentation_type`` is 1; for any
    other value ``g`` itself is returned untouched.

    Args:
        g: Graph exposing ``clone``, ``number_of_nodes`` and ``remove_nodes``.
        drop_ratio: Fraction of nodes to remove (truncated to a whole count).
        augmentation_type: Selector; 1 enables node dropping.

    Returns:
        A pruned clone of ``g`` when enabled, otherwise ``g``.
    """
    if augmentation_type != 1:
        return g

    pruned = g.clone()
    total = pruned.number_of_nodes()
    doomed = np.random.choice(total, int(total * drop_ratio), replace=False)
    pruned.remove_nodes(doomed)
    return pruned

def feature_masking(X, mask_ratio, augmentation_type):
    """Apply dropout to ``X`` when ``augmentation_type`` is 2.

    Masking is done with ``torch.nn.functional.dropout`` at probability
    ``mask_ratio``, so individual entries are zeroed and the surviving ones
    are rescaled by dropout. For any other ``augmentation_type`` the input
    is returned as is.

    Args:
        X: Feature tensor.
        mask_ratio: Dropout probability.
        augmentation_type: Selector; 2 enables masking.

    Returns:
        The masked tensor when enabled, otherwise ``X`` itself.
    """
    if augmentation_type != 2:
        return X
    return torch.nn.functional.dropout(X, p=mask_ratio)

def feature_perm(X):
    """Return ``X`` with its rows (first dimension) in a random order.

    Args:
        X: Tensor whose first dimension is shuffled.

    Returns:
        ``X`` indexed by a random permutation drawn with ``torch.randperm``.
    """
    return X[torch.randperm(X.shape[0])]