import math

import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np
from scipy.optimize import least_squares, minimize


# -------- Mask batch of node features with 0-1 flags tensor --------
def mask_x(x, flags):
    """Zero out the feature rows of nodes whose flag is 0.

    :param x:     B x N x F node-feature tensor.
    :param flags: B x N tensor of 0/1 node flags, or None to keep every node.
    :return: ``x`` multiplied by the flags broadcast over the feature axis.
    """
    node_mask = (torch.ones((x.shape[0], x.shape[1]), device=x.device)
                 if flags is None else flags)
    return x * node_mask.unsqueeze(2)

# -------- Mask batch of adjacency matrices with 0-1 flags tensor --------
def mask_adjs(adjs, flags):
    """Zero the rows and columns of absent nodes in (multi-channel) adjacencies.

    :param adjs:  B x N x N or B x C x N x N
    :param flags: B x N 0/1 node flags, or None to keep every node.
    :return: masked tensor with the same shape as ``adjs``.
    """
    node_mask = flags
    if node_mask is None:
        node_mask = torch.ones((adjs.shape[0], adjs.shape[-1]), device=adjs.device)
    if adjs.dim() == 4:
        # Add a channel axis so the B x N mask broadcasts over C.
        node_mask = node_mask.unsqueeze(1)
    row_mask = node_mask.unsqueeze(-1)
    col_mask = node_mask.unsqueeze(-2)
    return adjs * row_mask * col_mask

# -------- Create flags tensor from graph dataset --------
def node_flags(adj, eps=1e-5):
    """Flag nodes whose adjacency row has non-negligible total magnitude.

    :param adj: B x N x N or B x C x N x N adjacency tensor.
    :param eps: a node is flagged when the sum of |adj| along its row exceeds eps.
    :return: B x N float32 tensor of 0/1 flags; for 4-D input only channel 0 is used.
    """
    row_mass = adj.abs().sum(dim=-1)
    present = (row_mass > eps).to(torch.float32)
    if present.dim() == 3:
        present = present[:, 0, :]
    return present

# -------- Create initial node features --------
def init_features(init, adjs=None, nfeat=10):
    """Create initial node features for a batch of adjacency matrices.

    :param init:  'zeros', 'ones', or 'deg' (one-hot encoding of each node's
                  adjacency row sum, truncated to an integer).
    :param adjs:  B x N x N adjacency tensor; required despite the None default.
    :param nfeat: feature size. For 'deg' it is the number of one-hot classes,
                  so every degree must lie in [0, nfeat).
    :return: B x N x nfeat float32 features, zeroed for nodes that
             ``node_flags(adjs)`` marks as absent.
    :raises NotImplementedError: for an unknown ``init``, or when a degree does
             not fit into ``nfeat`` one-hot classes (original error chained).
    """
    if init == 'zeros':
        feature = torch.zeros((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'ones':
        feature = torch.ones((adjs.size(0), adjs.size(1), nfeat), dtype=torch.float32, device=adjs.device)
    elif init == 'deg':
        degrees = adjs.sum(dim=-1).to(torch.long)
        try:
            feature = F.one_hot(degrees, num_classes=nfeat).to(torch.float32)
        except RuntimeError as err:
            # one_hot rejects degrees outside [0, nfeat); report the offending
            # range in the message instead of printing it to stdout.
            raise NotImplementedError(
                f'max_feat_num mismatch: degrees range over '
                f'[{int(degrees.min())}, {int(degrees.max())}], '
                f'which does not fit num_classes={nfeat}'
            ) from err
    else:
        raise NotImplementedError(f'{init} not implemented')

    flags = node_flags(adjs)

    return mask_x(feature, flags)

# -------- Sample initial flags tensor from the training graph set --------
def init_flags(graph_list, config, batch_size=None):
    """Sample training graphs uniformly (with replacement) and return their node flags.

    :param graph_list: collection of nx.Graph training graphs.
    :param config:     config object exposing ``data.batch_size`` and
                       ``data.max_node_num``.
    :param batch_size: number of graphs to sample; defaults to
                       ``config.data.batch_size``.
    :return: ``(flags, selected_trains)``, the node flags of the sampled graphs
             and their padded adjacency tensor (as built by ``graphs_to_tensor``).
    """
    if batch_size is None:
        batch_size = config.data.batch_size
    max_node_num = config.data.max_node_num
    graphs = graph_list if isinstance(graph_list, (list, tuple)) else list(graph_list)
    idx = np.random.randint(0, len(graphs), batch_size)
    # Densify only the sampled graphs: converting the whole dataset on every
    # call costs O(len(graph_list) * max_node_num**2) time and memory.
    selected_trains = graphs_to_tensor([graphs[i] for i in idx], max_node_num)
    flags = node_flags(selected_trains)

    return flags, selected_trains


# -------- Generate noise --------
def gen_noise(x, flags, sym=True):
    """Draw standard-normal noise shaped like ``x`` and mask it with ``flags``.

    :param x:     template tensor (a batch of square adjacencies when sym=True).
    :param flags: B x N node flags, or None to keep every node.
    :param sym:   if True, return symmetric noise with a zero diagonal, masked
                  as an adjacency; otherwise mask it as node features.
    :return: noise tensor with the shape of ``x``.
    """
    noise = torch.randn_like(x)
    if not sym:
        return mask_x(noise, flags)
    # Keep only the strict upper triangle and mirror it: symmetric, zero diagonal.
    upper = noise.triu(1)
    return mask_adjs(upper + upper.transpose(-1, -2), flags)

def gen_spec_noise(adj, flags, u, la):
    """Draw standard-normal noise with the shape and dtype of ``la``.

    The noise is placed on ``adj``'s device; ``flags`` and ``u`` are currently unused.
    """
    target_device = adj.device
    return torch.randn_like(la, device=target_device)

# -------- Quantize generated graphs --------
def quantize(adjs, thr=0.5):
    """Binarize ``adjs``: entries below ``thr`` become 0, all others (NaN included) become 1.

    :return: tensor with the same shape, dtype and device as ``adjs``.
    """
    below = adjs < thr
    return (~below).to(adjs.dtype)

# -------- Quantize generated molecules --------
# adjs: 32 x 9 x 9
def quantize_mol(adjs):
    """Round adjacency values to integer bond orders in {0, 1, 2, 3}.

    Bins: < 0.5 -> 0, [0.5, 1.5) -> 1, [1.5, 2.5) -> 2, >= 2.5 -> 3.

    :param adjs: torch tensor or array-like of adjacency values. It is never
                 modified; quantization works on a private CPU copy.
    :return: numpy int64 array with the same shape as ``adjs``.
    """
    if isinstance(adjs, torch.Tensor):
        # detach() shares storage and cpu() returns the same tensor when it is
        # already on the CPU, so clone explicitly; otherwise the in-place writes
        # below would overwrite the caller's tensor.
        adjs = adjs.detach().cpu().clone()
    else:
        adjs = torch.tensor(adjs)
    adjs[adjs >= 2.5] = 3
    adjs[(adjs >= 1.5) & (adjs < 2.5)] = 2
    adjs[(adjs >= 0.5) & (adjs < 1.5)] = 1
    adjs[adjs < 0.5] = 0
    return adjs.to(torch.int64).numpy()

def adjs_to_graphs(adjs, is_cuda=False):
    """Convert a batch of dense adjacency matrices into networkx graphs.

    Self-loops and isolated nodes are removed. A graph left with no nodes gets
    a single placeholder node labelled 1.

    :param adjs:    iterable of N x N adjacency matrices (torch tensors on any
                    device, or numpy arrays).
    :param is_cuda: kept for backward compatibility. Torch tensors are now
                    always moved to CPU numpy, whatever this flag says.
    :return: list of nx.Graph, one per input matrix.
    """
    graph_list = []
    for adj in adjs:
        # nx.Graph recognises numpy arrays as adjacency matrices; a torch
        # tensor (even a CPU one) is not interpreted as an adjacency matrix.
        if isinstance(adj, torch.Tensor):
            adj = adj.detach().cpu().numpy()
        G = nx.Graph(adj)
        G.remove_edges_from(nx.selfloop_edges(G))
        G.remove_nodes_from(list(nx.isolates(G)))
        if G.number_of_nodes() < 1:
            G.add_node(1)
        graph_list.append(G)
    return graph_list

# -------- Check if the adjacency matrices are symmetric --------
def check_sym(adjs, print_val=False, tol=1e-2):
    """Raise if ``adjs`` is not symmetric in its last two dimensions.

    :param adjs:      tensor of shape ... x N x N (e.g. B x N x N or B x C x N x N).
    :param print_val: if True, print the total asymmetry.
    :param tol:       maximum allowed sum of |adjs - adjs^T| over all entries.
    :raises ValueError: when the total asymmetry is not below ``tol`` (NaN included).
    """
    # Reduce over every dimension: the previous fixed sum([0, 1, 2]) left a
    # vector for 4-D input and then crashed on its ambiguous truth value.
    sym_error = (adjs - adjs.transpose(-1, -2)).abs().sum().item()
    if not sym_error < tol:
        raise ValueError(f'Not symmetric: {sym_error:.4e}')
    if print_val:
        print(f'{sym_error:.4e}')

# -------- Create higher order adjacency matrices --------
def pow_tensor(x, cnum):
    """Stack the matrix powers x, x^2, ..., x^cnum along a new channel axis.

    :param x:    B x N x N tensor.
    :param cnum: number of powers to stack (at least x itself is always included).
    :return: B x max(cnum, 1) x N x N tensor.
    """
    powers = [x]
    for _ in range(1, cnum):
        powers.append(torch.bmm(powers[-1], x))
    return torch.stack(powers, dim=1)

# -------- Create padded adjacency matrices --------
def pad_adjs(ori_adj, node_number):
    """Zero-pad a square adjacency matrix on the right and bottom to node_number x node_number.

    :param ori_adj:     square 2-D numpy array.
    :param node_number: target size.
    :return: ``ori_adj`` itself when it already has the target size, otherwise
             a new padded array (promoted to at least float64 by the zero blocks).
    :raises ValueError: if ``ori_adj`` is larger than ``node_number``.
    """
    size = ori_adj.shape[-1]
    if size > node_number:
        raise ValueError(f'ori_len {size} > node_number {node_number}')
    if size == node_number:
        return ori_adj
    extra = node_number - size
    top = np.hstack([ori_adj, np.zeros([size, extra])])
    return np.vstack([top, np.zeros([extra, node_number])])

def graphs_to_tensor(graph_list, max_node_num):
    """Stack networkx graphs into a zero-padded dense adjacency tensor.

    Each graph's matrix follows the graph's own node iteration order and is
    padded with ``pad_adjs``.

    :param graph_list:   iterable of nx.Graph.
    :param max_node_num: padded size N.
    :return: float32 tensor of shape len(graph_list) x N x N
             (0 x N x N for empty input).
    :raises ValueError: (from pad_adjs) if a graph has more than N nodes.
    """
    graphs = list(graph_list)
    # Fill one preallocated float32 buffer instead of collecting float64
    # matrices in a list and copying twice; peak memory drops by about 3x.
    adjs_np = np.zeros((len(graphs), max_node_num, max_node_num), dtype=np.float32)
    for i, g in enumerate(graphs):
        assert isinstance(g, nx.Graph)
        adj = nx.to_numpy_array(g, nodelist=list(g.nodes))
        adjs_np[i] = pad_adjs(adj, node_number=max_node_num)

    return torch.from_numpy(adjs_np)

def graphs_to_adj(graph, max_node_num):
    """Convert one networkx graph into a zero-padded dense adjacency tensor.

    :param graph:        nx.Graph; rows/columns follow its node iteration order.
    :param max_node_num: padded size N.
    :return: float32 tensor of shape N x N.
    """
    assert isinstance(graph, nx.Graph)
    ordered_nodes = list(graph.nodes)
    dense = nx.to_numpy_array(graph, nodelist=ordered_nodes)
    return torch.tensor(pad_adjs(dense, node_number=max_node_num), dtype=torch.float32)

# The edge-detection threshold depends on the edge weights. Empirical defaults:
#   0.1 for community_small, grid, enzymes and ego_small (type='g'),
#   0.3 for qm9 and zinc250k (any other type, e.g. 'm').
def estimate_degrees_iterative(L_norm, type='m', threshold=None):
    """Estimate node degrees by counting significant off-diagonal entries per row.

    :param L_norm:    2-D (normalized Laplacian) matrix.
    :param type:      selects the default threshold: 'g' -> 0.1, anything else -> 0.3.
    :param threshold: explicit magnitude threshold; when None (default) the
                      ``type``-based value is used. Previously this argument
                      was silently ignored.
    :return: int numpy array with one count per row of ``L_norm``.
    """
    if threshold is None:
        threshold = 1e-1 if type == 'g' else 3e-1
    connected = np.abs(np.asarray(L_norm)) > threshold
    # Exclude self-connections (the diagonal).
    np.fill_diagonal(connected, False)
    return connected.sum(axis=1).astype(int)

def compute_normalized_laplacian(adj_tensor):
    """Eigendecompose L = I - D^{-1/2} A D^{-1/2} for each graph in a batch.

    D counts the nonzero entries of each row of A (unweighted degree), while A
    keeps its weights. Zero-degree nodes get a zero scaling factor.

    :param adj_tensor: N x N or B x N x N adjacency (tensor or array-like).
    :return: ``(eigenvalues, eigenvectors, degrees)``:
             eigenvalues  B x N float32, ascending, shifted by -1;
             eigenvectors B x N x N float32 (columns are eigenvectors);
             degrees      B x N integer nonzero counts.
    """
    if not isinstance(adj_tensor, torch.Tensor):
        adj_tensor = torch.tensor(adj_tensor, dtype=torch.float32)
    if adj_tensor.dim() != 3:
        adj_tensor = adj_tensor.unsqueeze(0)
    batch_size, n, _ = adj_tensor.shape

    degrees = (adj_tensor != 0).to(torch.int).sum(dim=2)

    eigenvalues = torch.zeros((batch_size, n), device=adj_tensor.device)
    eigenvectors = torch.zeros((batch_size, n, n), device=adj_tensor.device)

    # One host transfer for the whole batch; detach so grad-tracking input works.
    adj_np = adj_tensor.detach().cpu().numpy()
    deg_np = degrees.cpu().numpy()

    for b in range(batch_size):
        deg = deg_np[b]
        # Only take 1/sqrt on positive degrees: np.where evaluates both branches
        # and would divide by zero (RuntimeWarning) for padded/isolated nodes.
        d_inv_sqrt = np.zeros(n, dtype=float)
        positive = deg > 0
        d_inv_sqrt[positive] = 1.0 / np.sqrt(deg[positive])

        # Same values as D^{-1/2} @ A @ D^{-1/2} with diagonal D, without the matmuls.
        scaled = d_inv_sqrt[:, None] * adj_np[b] * d_inv_sqrt[None, :]
        L_norm = np.eye(n) - scaled

        evals, evecs = np.linalg.eigh(L_norm)

        eigenvalues[b] = torch.from_numpy(evals)
        eigenvectors[b] = torch.from_numpy(evecs)

    eigenvalues = eigenvalues - 1
    return eigenvalues, eigenvectors, degrees

def inverse_laplacian_transform(eigenvalues, eigenvectors, type):
    """Rebuild adjacency matrices from shifted normalized-Laplacian spectra.

    Undoes the -1 eigenvalue shift and rebuilds L = V diag(lambda) V^T. Degrees
    are estimated from L with ``estimate_degrees_iterative(L, type)``, and the
    result is A = D^{1/2} (I - L) D^{1/2}, symmetrized, with a zero diagonal
    and |entries| < 1e-10 set to 0.

    :param eigenvalues:  B x N shifted eigenvalues (tensor or array-like).
    :param eigenvectors: B x N x N eigenvectors, columns matching eigenvalues.
    :param type:         dataset type forwarded to the degree estimator.
    :return: B x N x N float32 tensor on ``eigenvalues``' device.
    """
    if not isinstance(eigenvalues, torch.Tensor):
        eigenvalues = torch.tensor(eigenvalues, dtype=torch.float32)
    if not isinstance(eigenvectors, torch.Tensor):
        eigenvectors = torch.tensor(eigenvectors, dtype=torch.float32)

    spectrum = eigenvalues + 1  # undo the -1 shift
    batch_size, n = spectrum.shape
    rebuilt = torch.zeros((batch_size, n, n), device=spectrum.device)
    identity = np.eye(n)

    for b in range(batch_size):
        lam = spectrum[b].cpu().numpy()
        vecs = eigenvectors[b].cpu().numpy()

        L_norm = vecs @ np.diag(lam) @ vecs.T
        deg = estimate_degrees_iterative(L_norm, type)
        D_sqrt = np.diag(np.where(deg > 0, np.sqrt(deg), 0))

        adj = D_sqrt @ (identity - L_norm) @ D_sqrt
        adj = 0.5 * (adj + adj.T)
        np.fill_diagonal(adj, 0)
        adj = np.where(np.abs(adj) < 1e-10, 0.0, adj)

        rebuilt[b] = torch.from_numpy(adj)

    return rebuilt

def node_feature_to_matrix(x):
    """Build pairwise node features by concatenating each pair's feature vectors.

    :param x:  BS x N x F
    :return:
    x_pair: BS x N x N x 2F, where x_pair[b, i, j] = concat(x[b, i], x[b, j])
    """
    n = x.size(1)
    rows = x.unsqueeze(2).expand(-1, -1, n, -1)  # rows[b, i, j] = x[b, i]
    cols = x.unsqueeze(1).expand(-1, n, -1, -1)  # cols[b, i, j] = x[b, j]
    return torch.cat((rows, cols), dim=-1)
