import torch
import networkx as nx
import numpy as np
import math


# -------- Mask batch of node features with 0-1 flags tensor --------
def mask_x(x, flags):
    """Zero out the feature rows of flagged-off nodes.

    :param x: batch of node features; indexed here as B x N x ... (the
        first two axes are used to build default flags).
    :param flags: B x N tensor of 0-1 flags, or None to keep every node.
    :return: ``x`` multiplied by ``flags`` broadcast over the last axis.
    """
    if flags is None:
        batch, nodes = x.shape[0], x.shape[1]
        flags = torch.ones((batch, nodes), device=x.device)
    # Insert a trailing axis so each node flag scales its whole feature row.
    node_mask = flags.unsqueeze(2)
    return x * node_mask


# -------- Mask batch of adjacency matrices with 0-1 flags tensor --------
def mask_adjs(adjs, flags):
    """
    Zero out the rows and columns that belong to flagged-off nodes.

    :param adjs:  B x N x N or B x C x N x N
    :param flags: B x N tensor of 0-1 flags; None keeps every node
    :return: tensor shaped like ``adjs`` with masked rows/columns set to zero
    """
    if flags is None:
        flags = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)

    if adjs.dim() == 4:
        # Add a channel axis (B x 1 x N) so the flags broadcast over C.
        flags = flags.unsqueeze(1)

    row_mask = flags.unsqueeze(-1)
    col_mask = flags.unsqueeze(-2)
    return adjs * row_mask * col_mask


# -------- Create flags tensor from graph dataset --------
def node_flags(adj, eps=1e-5):
    """Derive 0-1 node flags from adjacency matrices.

    A node is flagged 1.0 when the absolute values along the last axis of
    its row sum to more than ``eps``, otherwise 0.0.

    :param adj: adjacency tensor; when the reduced result has three axes
        (i.e. ``adj`` had four), only index 0 of the second axis is kept.
    :param eps: threshold the absolute row sum must exceed.
    :return: float32 tensor of flags.
    """
    row_mass = adj.abs().sum(dim=-1)
    flags = (row_mass > eps).to(dtype=torch.float32)
    if flags.dim() == 3:
        flags = flags[:, 0, :]
    return flags


# -------- Sample initial flags tensor from the training graph set --------
def get_init_flags(graph_list, config, batch_size=None):
    """Sample node flags from randomly chosen graphs of ``graph_list``.

    :param graph_list: graphs to sample from (converted with
        ``graphs_to_tensor``).
    :param config: object providing ``config.data.batch_size`` and
        ``config.data.max_node_num``.
    :param batch_size: number of flag rows to draw; falls back to
        ``config.data.batch_size`` when None.
    :return: flags computed by ``node_flags`` for the sampled graphs
        (graphs are drawn with replacement via ``np.random.randint``).
    """
    n_samples = config.data.batch_size if batch_size is None else batch_size
    all_adjs = graphs_to_tensor(graph_list, config.data.max_node_num)
    picks = np.random.randint(0, len(graph_list), n_samples)
    return node_flags(all_adjs[picks])


# -------- Generate noise --------
def gen_noise(x, flags, sym=True):
    """Draw standard normal noise shaped like ``x`` and mask it with ``flags``.

    :param x: tensor whose shape, dtype and device the noise copies.
    :param flags: 0-1 node flags forwarded to ``mask_adjs`` / ``mask_x``.
    :param sym: when True, only the strictly upper triangle of the sample is
        kept and mirrored, giving a symmetric matrix with a zero diagonal,
        which is then masked with ``mask_adjs``; otherwise the raw sample is
        masked with ``mask_x``.
    :return: the masked noise tensor.
    """
    noise = torch.randn_like(x)
    if not sym:
        return mask_x(noise, flags)
    upper = noise.triu(1)
    mirrored = upper + upper.transpose(-1, -2)
    return mask_adjs(mirrored, flags)


# -------- Quantize generated graphs --------
def quantize(adjs, thr=0.5):
    """Binarise ``adjs``: entries below ``thr`` become 0, all others become 1.

    :param adjs: tensor to threshold.
    :param thr: cut-off value; an entry equal to ``thr`` maps to 1.
    :return: new tensor with the dtype of ``adjs`` holding only zeros and ones.
    """
    off = torch.zeros_like(adjs)
    on = torch.ones_like(adjs)
    return torch.where(adjs < thr, off, on)


# -------- Quantize generated molecules --------
# adjs: 32 x 9 x 9
def quantize_mol(adjs):
    """Round adjacency values to the integer levels 0, 1, 2 and 3.

    Values >= 2.5 map to 3, [1.5, 2.5) to 2, [0.5, 1.5) to 1 and values
    below 0.5 (including negatives) to 0.

    :param adjs: a torch tensor, or anything ``torch.tensor`` accepts
        (nested list, numpy array, ...).
    :return: numpy array of int64 with the same shape as the input.
    """
    if isinstance(adjs, torch.Tensor):
        # detach().cpu() shares storage with a tensor that already lives on
        # the CPU, so clone before the in-place writes below; otherwise the
        # caller's tensor would be overwritten.
        adjs = adjs.detach().cpu().clone()
    else:
        adjs = torch.tensor(adjs)
    # Highest band first: a value written by one step never falls inside a
    # band handled by a later step.
    adjs[adjs >= 2.5] = 3
    adjs[(adjs >= 1.5) & (adjs < 2.5)] = 2
    adjs[(adjs >= 0.5) & (adjs < 1.5)] = 1
    adjs[adjs < 0.5] = 0
    return np.array(adjs.to(torch.int64))


def adjs_to_graphs(adjs, is_cuda=False):
    """Convert a batch of adjacency matrices into NetworkX graphs.

    Self-loops and isolated nodes are removed from every graph; a graph
    left without any node gets a single node labelled ``1`` so that it is
    never empty.

    :param adjs: iterable of square adjacency matrices (numpy arrays or
        torch tensors).
    :param is_cuda: kept for backward compatibility. Torch tensors are now
        detected and moved to host memory automatically, so the flag no
        longer has to be set by the caller.
    :return: list of ``nx.Graph``, one per adjacency matrix.
    """
    graph_list = []
    for adj in adjs:
        if torch.is_tensor(adj):
            adj = adj.detach().cpu().numpy()
        # from_numpy_matrix was removed in NetworkX 3.0; from_numpy_array is
        # the supported equivalent.
        G = nx.from_numpy_array(adj)
        G.remove_edges_from(nx.selfloop_edges(G))
        G.remove_nodes_from(list(nx.isolates(G)))
        if G.number_of_nodes() < 1:
            G.add_node(1)
        graph_list.append(G)
    return graph_list


# -------- Check if the adjacency matrices are symmetric --------
def check_sym(adjs, print_val=False):
    """Raise unless the batch of matrices is (numerically) symmetric.

    The error measure is the sum of ``|adjs - adjs^T|`` over axes 0, 1 and 2.

    :param adjs: tensor with three axes; the last two are transposed.
    :param print_val: when True and the check passes, print the error.
    :raises ValueError: if the summed error is not below 1e-2.
    """
    asym = adjs - adjs.transpose(-1, -2)
    sym_error = asym.abs().sum([0, 1, 2])
    if sym_error < 1e-2:
        if print_val:
            print(f'{sym_error:.4e}')
        return
    raise ValueError(f'Not symmetric: {sym_error:.4e}')


# -------- Create higher order adjacency matrices --------
def pow_tensor(x, cnum):
    """Stack ``x`` with higher-order, degree-normalised powers of itself.

    Channel 0 is ``x``. Every following channel is the previous channel
    multiplied by ``D^-1/2 x D^-1/2`` (``D`` being the absolute row sums of
    ``x`` clamped to at least 1) with its diagonal reset to zero.

    :param x: B x N x N tensor.
    :param cnum: number of channels to produce; values below 2 yield only
        the channel holding ``x``.
    :return: B x cnum x N x N tensor.
    """
    # The normalised matrix and the off-diagonal mask do not change between
    # iterations, so build them once (the mask directly on x's device).
    deg_inv_sqrt = x.abs().sum(dim=-1).clamp(min=1).pow(-0.5)
    x_normalized = deg_inv_sqrt.unsqueeze(-1) * x * deg_inv_sqrt.unsqueeze(-2)
    off_diag = (1.0 - torch.eye(x.shape[-1], device=x.device)).unsqueeze(0)

    x_ = x
    xc = [x.unsqueeze(1)]
    for _ in range(cnum - 1):
        x_ = torch.bmm(x_, x_normalized) * off_diag
        xc.append(x_.unsqueeze(1))
    return torch.cat(xc, dim=1)


# -------- Create padded adjacency matrices --------
def pad_adjs(ori_adj, node_number):
    """Zero-pad a square adjacency matrix up to ``node_number`` nodes.

    :param ori_adj: 2-D array whose last axis length is taken as the current
        node count.
    :param node_number: target number of rows and columns.
    :return: ``ori_adj`` itself when it already has the target size,
        otherwise a new array with zero columns appended on the right and
        zero rows appended at the bottom.
    :raises ValueError: if the matrix is larger than ``node_number``.
    """
    ori_len = ori_adj.shape[-1]
    if ori_len > node_number:
        raise ValueError(f'ori_len {ori_len} > node_number {node_number}')
    if ori_len == node_number:
        return ori_adj

    missing = node_number - ori_len
    right_block = np.zeros([ori_len, missing])
    widened = np.concatenate([ori_adj, right_block], axis=-1)
    bottom_block = np.zeros([missing, node_number])
    return np.concatenate([widened, bottom_block], axis=0)


def graphs_to_tensor(graph_list, max_node_num):
    """Convert NetworkX graphs into one batch of padded adjacency matrices.

    :param graph_list: iterable of ``nx.Graph`` objects.
    :param max_node_num: size every adjacency matrix is zero-padded to
        (see ``pad_adjs``).
    :return: float32 tensor stacking the padded adjacency matrices, one per
        graph, in the order of ``graph_list``.
    """
    adjs_list = []
    for g in graph_list:
        assert isinstance(g, nx.Graph)
        # to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array is the
        # supported equivalent and returns a plain ndarray.
        adj = nx.to_numpy_array(g, nodelist=list(g.nodes))
        adjs_list.append(pad_adjs(adj, node_number=max_node_num))

    adjs_np = np.asarray(adjs_list)
    # Drop the per-graph arrays before the tensor copy is allocated to keep
    # peak memory down on large datasets.
    del adjs_list

    return torch.tensor(adjs_np, dtype=torch.float32)


def graphs_to_adj(graph, max_node_num):
    """Convert a single NetworkX graph into a padded adjacency tensor.

    :param graph: an ``nx.Graph``.
    :param max_node_num: size the adjacency matrix is zero-padded to
        (see ``pad_adjs``).
    :return: float32 tensor holding the padded adjacency matrix.
    """
    assert isinstance(graph, nx.Graph)
    # to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array is the
    # supported equivalent and returns a plain ndarray.
    adj = nx.to_numpy_array(graph, nodelist=list(graph.nodes))
    padded_adj = pad_adjs(adj, node_number=max_node_num)
    return torch.tensor(padded_adj, dtype=torch.float32)


def deviation(z, z_target, flags, is_adj):
    """Mean normalised L1 distance between ``z`` and a target.

    :param z: tensor with three axes; the L1 distance is summed over axes
        1 and 2.
    :param z_target: tensor compared against; when None, ``z`` quantised at
        threshold 0.5 is used.
    :param flags: 0-1 node flags; their sum over the last axis is used as
        the per-sample node count.
    :param is_adj: if True the distance is divided by the squared node
        count, otherwise by the node count times ``z.shape[-1]``.
    :return: Python float, the mean over the batch.
    """
    target = quantize(z, thr=0.5) if z_target is None else z_target
    l1 = (z - target).abs().sum(dim=(1, 2))
    node_count = flags.sum(dim=-1)
    if is_adj:
        scaled = l1 / node_count**2
    else:
        scaled = l1 / (node_count * z.shape[-1])
    return scaled.mean().item()


def convergence(z_list, flags, is_adj):
    """Deviation of every element of ``z_list`` from the quantised last one.

    :param z_list: non-empty sequence of tensors; the last one, quantised at
        threshold 0.5, serves as the reference.
    :param flags: node flags forwarded to ``deviation``.
    :param is_adj: normalisation switch forwarded to ``deviation``.
    :return: list with one ``deviation`` value per element of ``z_list``.
    """
    reference = quantize(z_list[-1], thr=0.5)
    trace = []
    for z in z_list:
        trace.append(deviation(z, reference, flags, is_adj))
    return trace


def data_rescale(adj, scales, is_list=False):
    """Multiply ``adj`` by ``scales``.

    :param adj: value to rescale.
    :param scales: factor (anything ``adj`` can be multiplied with).
    :param is_list: accepted but not used by this function.
    :return: the product ``adj * scales``.
    """
    return adj * scales
    

def rand_perm(x, adj):
    """Randomly permute the active nodes of every graph in the batch.

    Nodes flagged by ``node_flags(adj)`` are shuffled among their own
    positions, independently per sample; all other positions keep their
    index. The same permutation is applied to the rows of ``x`` and to both
    the rows and columns of ``adj``.

    :param x: batch of node features, multiplied as ``P @ x`` with
        ``torch.bmm``.
    :param adj: batch of adjacency matrices, transformed as ``P @ adj @ P^T``.
    :return: tuple ``(px, padj)`` of permuted features and adjacencies.
    """
    flags = node_flags(adj)
    batch_size, num_nodes = adj.shape[0], adj.shape[-1]
    idx = torch.arange(num_nodes, device=adj.device).repeat(batch_size, 1)
    for b in range(batch_size):
        active = torch.where(flags[b] > 0)[0]
        shuffle = torch.randperm(int(flags[b].sum(-1)))
        idx[b][active] = idx[b][active[shuffle]]
    # Indexing the identity by idx gives one permutation matrix per sample.
    perm = torch.eye(num_nodes, device=adj.device)[idx]
    px = torch.bmm(perm, x)
    padj = torch.bmm(torch.bmm(perm, adj), perm.transpose(-1, -2))
    return px, padj


def sample_deviation(z):
    """Summarise how far ``z`` strays from its quantised version.

    With ``zq = quantize(z)`` and ``zd = z - zq`` the string reports, for the
    entries quantised to 0 and then for those quantised to 1, the extreme
    negative and positive residuals, formatted as ``(lo hi) (lo hi)``.

    :param z: tensor to analyse.
    :return: formatted string with four two-decimal numbers.
    """
    def peak(t):
        return t.max().item()

    zq = quantize(z)
    zd = z - zq
    not_zq = 1. - zq
    d1p = peak(zd * zq)
    d1m = peak(-zd * zq)
    d0p = peak(zd * not_zq)
    d0m = peak(-zd * not_zq)
    return f'({-d0m:.2f} {d0p:.2f}) ({-d1m:.2f} {d1p:.2f})'


def mol_deviation(adj):
    """Report the largest entry of ``adj`` inside each half-unit band.

    The first number is the maximum of ``-adj`` over negative entries; the
    remaining seven are the maxima of ``adj`` over the open intervals
    (0, 0.5), (0.5, 1.0), (1.0, 1.5), ..., (3.0, 3.5). Entries outside a
    band are multiplied by zero before the maximum is taken.

    :param adj: tensor of values to inspect.
    :return: string of four ``(a b)`` pairs with two-decimal numbers.
    """
    def peak(values, keep):
        return (values * keep).max().item()

    below_zero = peak(-adj, adj < 0)
    first_band = peak(adj, adj * (0.5 - adj) > 0)
    edges = (0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5)
    b1, b2, b3, b4, b5, b6 = [
        peak(adj, (adj - lo) * (hi - adj) > 0)
        for lo, hi in zip(edges, edges[1:])
    ]
    return (f'({below_zero:.2f} {first_band:.2f}) '
            f'({b1:.2f} {b2:.2f}) '
            f'({b3:.2f} {b4:.2f}) '
            f'({b5:.2f} {b6:.2f})')


def pred_thr(pred, thr=0.1):
    """Snap entries of ``pred`` that are close to their quantised value.

    :param pred: tensor of predictions.
    :param thr: an entry is replaced by its quantised value when the absolute
        difference between the two is below this threshold.
    :return: tensor where close entries are quantised and the rest are left
        as in ``pred``.
    """
    snapped = quantize(pred)
    is_close = (pred - snapped).abs() < thr
    return torch.where(is_close, snapped, pred)