import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np
import pdb


def mask_x(x, flags):
    """Zero the feature rows of nodes whose flag is 0.

    :param x: node features, indexed as B x N x ... (assumed B x N x F)
    :param flags: B x N node-existence mask, or None for an all-ones mask
    :return: ``x`` multiplied by ``flags`` broadcast over the feature axis
    """
    if flags is not None:
        return x * flags[:, :, None]
    # No mask supplied: multiply by ones so the result is still a fresh tensor.
    all_nodes = torch.ones((x.shape[0], x.shape[1]), device=x.device)
    return x * all_nodes[:, :, None]


def mask_adjs(adjs, flags):
    """Zero the rows and columns of adjacency matrices for absent nodes.

    :param adjs:  B x N x N or B x C x N x N
    :param flags: B x N node mask, or None for an all-ones mask
    :return: masked tensor with the same shape as ``adjs``
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)

    if adjs.ndim == 4:
        # Insert a channel axis so one mask applies to every channel.
        flags = flags.unsqueeze(1)  # B x 1 x N

    row_mask = flags.unsqueeze(-1)
    col_mask = flags.unsqueeze(-2)
    return adjs * row_mask * col_mask


def mask_adjs_tensor(adjs, flags):
    """Zero rows/columns of absent nodes in a stack of adjacency matrices.

    :param adjs:  B x T x N x N or B x C x T x N x N
    :param flags: B x N node mask, or None for an all-ones mask
    :return: masked tensor with the same shape as ``adjs``
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)

    # One singleton axis for T (B x 1 x N); one more for C on 5-D input (B x 1 x 1 x N).
    flags = flags[:, None, None] if adjs.ndim == 5 else flags[:, None]

    return adjs * flags[..., :, None] * flags[..., None, :]


def node_flags(adj, eps=1e-5):
    """Infer a node-existence mask from adjacency rows.

    A node is flagged 1.0 when the absolute values in its adjacency row sum
    to more than ``eps``, otherwise 0.0.

    :param adj: B x N x N, or B x C x N x N (only channel 0 is used)
    :param eps: threshold on the absolute row sum
    :return: float32 tensor of shape B x N
    """
    row_mass = torch.abs(adj).sum(dim=-1)
    flags = (row_mass > eps).to(dtype=torch.float32)
    # Multi-channel input: take the flags from the first channel only.
    return flags[:, 0, :] if flags.ndim == 3 else flags


def init_features(init, adjs=None, nfeat=10):
    """Build initial node features from an adjacency batch.

    :param init: 'zeros' / 'ones' for constant features, or 'deg' for a
        one-hot encoding of each node's adjacency row sum (cast to long)
    :param adjs: adjacency batch; must not be None despite the default
        (assumed B x N x N -- TODO confirm for multi-channel input)
    :param nfeat: feature width; for 'deg' this is the number of one-hot
        classes, and ``F.one_hot`` raises if a degree is >= nfeat
    :return: float32 features masked by ``node_flags(adjs)``
    :raises NotImplementedError: for any other ``init`` value
    """
    if init in ('zeros', 'ones'):
        fill = torch.zeros if init == 'zeros' else torch.ones
        shape = (adjs.size(0), adjs.size(1), nfeat)
        feature = fill(shape, dtype=torch.float32, device=adjs.device)
    elif init == 'deg':
        degrees = adjs.sum(dim=-1).to(torch.long)
        feature = F.one_hot(degrees, num_classes=nfeat).to(torch.float32)
    else:
        raise NotImplementedError(f'{init} not implemented')

    return mask_x(feature, node_flags(adjs))


def init_flags(graph_list, batch_size, config):
    """Return node flags for a random batch of graphs from ``graph_list``.

    All graphs are converted to a padded adjacency tensor, then
    ``batch_size`` indices are drawn with ``np.random.randint`` (uniform,
    with replacement) and their node flags are returned.

    :param graph_list: sequence of networkx graphs
    :param batch_size: number of graphs to sample
    :param config: object exposing ``config.data.max_node_num``
    :return: float32 tensor of shape (batch_size, max_node_num)
    """
    adjs = graphs_to_tensor(graph_list, config.data.max_node_num)
    picks = np.random.randint(0, len(graph_list), batch_size)
    return node_flags(adjs[picks])


def gen_noise(x, flags, sym=True):
    """Sample standard-normal noise shaped like ``x`` and mask absent nodes.

    :param x: template tensor; only its shape, dtype and device are used
    :param flags: node mask forwarded to the masking helper
    :param sym: if True, the noise is made symmetric over the last two dims
        with a zero diagonal. 3-D noise is masked with ``mask_adjs`` and
        4-D noise with ``mask_adjs_tensor``. If False, ``mask_x`` is used.
    :return: masked noise tensor with the same shape as ``x``
    :raises ValueError: if ``sym`` is True and ``x`` is not 3-D or 4-D
    """
    noise = torch.randn_like(x)
    if not sym:
        return mask_x(noise, flags)

    # Keep the strict upper triangle and mirror it, so the diagonal is zero.
    upper = noise.triu(1)
    noise = upper + upper.transpose(-1, -2)

    rank = noise.ndim
    if rank == 3:    # B x N x N
        return mask_adjs(noise, flags)
    if rank == 4:    # B x T x N x N
        return mask_adjs_tensor(noise, flags)
    raise ValueError('Wrong shape of adj')


def quantize(adjs, thr=0.5):
    """Binarize ``adjs``: entries below ``thr`` become 0, all others 1.

    Entries that are not below ``thr`` (including NaN) map to 1. The result
    keeps the dtype and device of ``adjs``.
    """
    below = adjs < thr
    return torch.logical_not(below).to(dtype=adjs.dtype)


def dequantize(x, flags, sym=True, c=0.8):
    """Add uniform dequantization noise drawn from U[0, c) to ``x``.

    :param x: tensor to perturb; when ``sym`` is True it is treated as a batch
        of adjacency matrices (last two dims N x N)
    :param flags: node mask forwarded to the masking helper (None = no mask)
    :param sym: if True, the noise is made symmetric with a zero diagonal and
        masked with ``mask_adjs``; otherwise it is masked with ``mask_x``
    :param c: noise amplitude
    :return: ``x + noise``, same shape as ``x``
    """
    noise = c * torch.rand_like(x)
    if sym:
        noise = noise.triu(1)
        # Mirror the strict upper triangle across the last two dims. The old
        # transpose(-1, 2) did nothing on 3-D (B x N x N) input, because -1 and
        # 2 name the same dim there, so the noise was doubled, not symmetrized.
        noise = noise + noise.transpose(-1, -2)
        noise = mask_adjs(noise, flags)
    else:
        noise = mask_x(noise, flags)
    return x + noise


def adjs_to_graphs(adjs, is_cuda=False):
    """Convert a batch of dense adjacency matrices into networkx graphs.

    Self-loops and isolated nodes are removed from every graph. A graph left
    with no nodes gets a single placeholder node (label ``1``), so the result
    never contains an empty graph.

    :param adjs: iterable of N x N adjacency matrices (numpy arrays or torch
        tensors on any device)
    :param is_cuda: kept for backward compatibility only; every tensor element
        is now moved to host memory regardless of this flag
    :return: list of ``nx.Graph``, one per input matrix
    """
    graph_list = []
    for adj in adjs:
        # Decide per element, independently of `is_cuda`. The old check looked
        # at the container type plus the flag, so CPU tensors with
        # is_cuda=False reached networkx unconverted, and a list of ndarrays
        # with is_cuda=True crashed on `.detach()`.
        if isinstance(adj, torch.Tensor):
            adj = adj.detach().cpu().numpy()
        # nx.from_numpy_matrix was removed in networkx 3.0.
        G = nx.from_numpy_array(adj)
        # Materialize the self-loop edges before removing them, so the graph
        # is not mutated while a lazy view over it is being consumed.
        G.remove_edges_from(list(nx.selfloop_edges(G)))
        G.remove_nodes_from(list(nx.isolates(G)))
        if G.number_of_nodes() < 1:
            G.add_node(1)
        graph_list.append(G)
    return graph_list


#####
def quantize_mol_tensor(adjs):
    """Collapse one-hot bond channels into integer bond labels.

    The argmax is taken over axis -3 (the channel axis, e.g. 32 x 4 x 9 x 9
    -> 32 x 9 x 9, or 4 x 9 x 9 -> 9 x 9). Channel indices 0, 1, 2, 3 are
    then relabelled to 1, 2, 3, 0.

    :param adjs: torch tensor or numpy array of channel scores
    :return: integer numpy array of bond labels
    """
    if type(adjs).__name__ == 'Tensor':
        adjs = adjs.detach().cpu().numpy()
    channel = np.argmax(adjs, axis=-3)
    # Shift every channel up by one, except channel 3, which wraps to 0.
    return np.where(channel == 3, 0, channel + 1)


#####
def quantize_mol(adjs):
    """Round continuous bond-order values to integer bond levels 0-3.

    Buckets: < 0.5 -> 0, [0.5, 1.5) -> 1, [1.5, 2.5) -> 2, >= 2.5 -> 3.
    NaN entries map to 0. The input is never modified.

    :param adjs: torch tensor (any device) or array-like of bond values,
        e.g. a batch of N x N bond matrices
    :return: ``np.ndarray`` of dtype int64 with the same shape as ``adjs``
    """
    if isinstance(adjs, torch.Tensor):
        adjs = adjs.detach().cpu()
    else:
        adjs = torch.as_tensor(adjs)
    # Write the result into a fresh int64 tensor. The previous version
    # assigned into `adjs` in place, and for a CPU tensor `.detach().cpu()`
    # shares storage with the caller's tensor, so the input was overwritten.
    levels = torch.zeros(adjs.shape, dtype=torch.int64)
    levels[adjs >= 0.5] = 1
    levels[adjs >= 1.5] = 2
    levels[adjs >= 2.5] = 3
    return levels.numpy()


def check_sym(adjs, tol=1e-2):
    """Raise ``ValueError`` unless ``adjs`` is symmetric in its last two dims.

    :param adjs: tensor of rank >= 2 (e.g. N x N, B x N x N, B x C x N x N)
    :param tol: upper bound (exclusive) on the summed absolute asymmetry
    :raises ValueError: if sum(|adjs - adjs^T|) is not below ``tol``;
        NaN entries also fail the check
    """
    asym = (adjs - adjs.transpose(-1, -2)).abs().sum()
    # Summing over every dim, instead of the old fixed [0, 1, 2], makes this
    # work for any rank: N x N used to raise IndexError, and 4-D input used to
    # raise an "ambiguous boolean" RuntimeError.
    if not asym < tol:
        raise ValueError('Not symmetric')


def node_feature_to_matrix(x):
    """Build pairwise node-feature tensors.

    Entry [b, i, j] is the concatenation of x[b, i] and x[b, j].

    :param x: BS x N x F
    :return: x_pair of shape BS x N x N x 2F
    """
    n = x.size(1)
    first = x.unsqueeze(2).expand(-1, -1, n, -1)    # [b, i, j] = x[b, i]
    second = x.unsqueeze(1).expand(-1, n, -1, -1)   # [b, i, j] = x[b, j]
    return torch.cat([first, second], dim=-1)


def pow_tensor(x, cnum):
    """Stack the matrix powers x^1 .. x^cnum along a new channel axis.

    :param x: B x N x N
    :param cnum: number of powers; values < 2 yield just x^1
    :return: B x max(cnum, 1) x N x N
    """
    powers = [x]
    running = x
    for _ in range(1, cnum):
        running = torch.bmm(running, x)
        powers.append(running)
    return torch.stack(powers, dim=1)


def add_loop(adj, flags):
    """Return a masked copy of ``adj`` with the diagonal set to 1.

    The diagonal of the last two dims is set for 2-D, 3-D and 4-D
    (B x C x N x N) input; other ranks are only masked. ``adj`` is not
    modified.
    """
    out = adj.clone()
    n = out.shape[-1]
    diag = torch.arange(n, dtype=torch.long, device=out.device)
    if 2 <= out.ndim <= 4:
        out[..., diag, diag] = 1.
    return mask_adjs(out, flags)
    

def delete_loop(adj, flags):
    """Return a masked copy of ``adj`` with the diagonal set to 0.

    The diagonal of the last two dims is cleared for 2-D, 3-D and 4-D
    (B x C x N x N) input; other ranks are only masked. ``adj`` is not
    modified.
    """
    out = adj.clone()
    n = out.shape[-1]
    diag = torch.arange(n, dtype=torch.long, device=out.device)
    if 2 <= out.ndim <= 4:
        out[..., diag, diag] = 0.
    return mask_adjs(out, flags)


def eyes_like(adj, flags):
    """Return a masked identity tensor shaped and typed like ``adj``.

    The diagonal is filled only for 2-D and 3-D input; other ranks yield
    a masked all-zeros tensor.
    """
    identity = torch.zeros_like(adj)
    diag = torch.arange(adj.shape[-1], dtype=torch.long, device=adj.device)
    if identity.ndim in (2, 3):
        identity[..., diag, diag] = 1.
    return mask_adjs(identity, flags)


def pad_adjs(ori_adj, node_number):
    """Zero-pad a square adjacency matrix to ``node_number`` x ``node_number``.

    :param ori_adj: 2-D square adjacency matrix (ndarray or np.matrix)
    :param node_number: target side length
    :return: ``ori_adj`` itself if it already has that size, else a padded copy
    :raises ValueError: if ``ori_adj`` is larger than ``node_number``
    """
    ori_len = ori_adj.shape[-1]
    if ori_len > node_number:
        raise ValueError(f'ori_len {ori_len} > node_number {node_number}')
    if ori_len == node_number:
        return ori_adj

    missing = node_number - ori_len
    right_pad = np.zeros([ori_len, missing])
    bottom_pad = np.zeros([missing, node_number])
    widened = np.concatenate([ori_adj, right_pad], axis=-1)
    return np.concatenate([widened, bottom_pad], axis=0)


def graphs_to_tensor(graph_list, max_node_num):
    """Stack networkx graphs into one zero-padded dense adjacency tensor.

    Each graph's adjacency matrix (rows/columns in ``g.nodes`` iteration
    order) is padded with ``pad_adjs`` to ``max_node_num`` x ``max_node_num``.

    :param graph_list: iterable of ``nx.Graph`` instances
    :param max_node_num: side length of every padded adjacency matrix
    :return: float32 tensor of shape (len(graph_list), max_node_num, max_node_num)
    :raises AssertionError: if an element is not an ``nx.Graph``
    :raises ValueError: (from ``pad_adjs``) if a graph has more than
        ``max_node_num`` nodes
    """
    adjs_list = []
    for g in graph_list:
        assert isinstance(g, nx.Graph)
        # nx.to_numpy_matrix was removed in networkx 3.0. to_numpy_array
        # (available since 2.0) returns a plain ndarray with the same entries.
        adj = nx.to_numpy_array(g, nodelist=list(g.nodes))
        adjs_list.append(pad_adjs(adj, node_number=max_node_num))

    if not adjs_list:
        # np.asarray([]) would give shape (0,); keep the (0, N, N) contract.
        return torch.zeros((0, max_node_num, max_node_num), dtype=torch.float32)

    adjs_np = np.asarray(adjs_list)
    # Free the per-graph matrices before the tensor copy to lower peak memory.
    del adjs_list
    return torch.tensor(adjs_np, dtype=torch.float32)
