import torch
from utils.graph_utils import *
import networkx as nx

def implicitConstr_transform (xs, adjs, dataset):
    """Apply the dataset's implicit structural constraints to copies of ``xs`` and ``adjs``.

    Molecular datasets ('QM9', 'ZINC250k') use the constraint list
    ``["quantize_mol", "lcc"]``; every other dataset uses
    ``["quantize", "no_selfloop", "no_island"]``.

    Args:
        xs: node-feature tensor. It is cloned, so the input is not modified.
        adjs: batched adjacency tensor indexed as ``adjs[graph, row, col]``.
            It is cloned, so the input is not modified.
        dataset: dataset name that selects the constraint list.

    Returns:
        Tuple ``(T_xs, T_adjs)`` of transformed tensors.
    """
    if dataset in ['QM9', 'ZINC250k']:
        # molecular
        constraints = ["quantize_mol", "lcc"]
    else:
        # non-molecular
        constraints = ["quantize", "no_selfloop", "no_island"]
    T_xs, T_adjs = torch.clone(xs), torch.clone(adjs)
    for constraint in constraints:
        if constraint == 'quantize':
            # Multiplicative mask by the quantized adjacency
            # (`quantize` comes from utils.graph_utils).
            T_adjs = quantize(T_adjs) * T_adjs
        elif constraint == 'quantize_mol':
            # Keep entries whose `quantize_mol` value is >= 1, and mask the node
            # features by their `quantize` values (both helpers from utils.graph_utils).
            T_adjs = T_adjs * (quantize_mol(T_adjs) >= 1)
            T_xs = T_xs * quantize(T_xs)
        elif constraint == 'lcc':
            # Keep only entries inside each graph's largest connected component.
            lcc_mask = torch.zeros_like(T_adjs)
            for i in range(T_adjs.shape[0]):
                # Build the networkx graph from graph i alone, not from the whole batch.
                # `from_numpy_array` replaces `from_numpy_matrix`, which was removed
                # in networkx 3.0.
                G = nx.from_numpy_array(T_adjs[i].detach().cpu().numpy())
                largest_cc = list(max(nx.connected_components(G), key=len))
                nodes = torch.tensor(largest_cc, dtype=torch.long, device=T_adjs.device)
                # Mark the whole (component x component) block of graph i.
                lcc_mask[i, nodes[:, None], nodes] = 1
            T_adjs = T_adjs * lcc_mask
        elif constraint == 'no_selfloop':
            # Zero the diagonal of every adjacency matrix in the batch.
            T_adjs[:, torch.arange(T_adjs.shape[1]), torch.arange(T_adjs.shape[2])] = 0
        elif constraint == 'no_island':
            # No-op. Once self-loops are removed, an isolated node's row and
            # column are already all zero, so there is nothing to mask.
            pass
    return T_xs, T_adjs


def soft (x, lambda1):
    """Soft-threshold ``x`` by ``lambda1``.

    Entries with ``|x| > lambda1`` are shrunk toward zero to
    ``x - sign(x) * lambda1``. Every other entry, NaN included, becomes 0.
    """
    keep = x.abs() > lambda1
    shrunk = x - x.sign() * lambda1
    # `~keep` is also True for NaN entries, matching the "else zero" branch.
    return shrunk.masked_fill(~keep, 0)

def hard (x, lambda1):
    """Hard-threshold ``x``: keep entries with ``|x| > lambda1``, zero the rest (NaN included)."""
    significant = torch.abs(x) > lambda1
    zero = x.new_zeros(())  # 0-dim zero of x's dtype/device; broadcast by torch.where
    return torch.where(significant, x, zero)

def plus_fn (x):
    """Positive part of ``x``: keep strictly positive entries, zero all others (NaN included)."""
    not_positive = ~(x > 0)  # written as ~(x > 0), not x <= 0, so NaN is zeroed too
    return x.masked_fill(not_positive, 0)

def bisection(v, func, a=0, b=None, epsilon=1e-5, iter_max=100):
    """Find a root of ``func`` on the bracket ``[a, b]`` by repeated halving.

    Args:
        v: unused; kept so existing call sites keep working.
        func: function of one scalar. It is assumed to change sign on ``[a, b]``.
        a: left bracket endpoint.
        b: right bracket endpoint. Required despite the ``None`` default.
        epsilon: stop once the bracket width ``b - a`` is at most this.
        iter_max: maximum number of halvings (passed through ``int``).

    Returns:
        The last midpoint evaluated, or ``a`` if no iteration ran.

    Raises:
        ValueError: if ``b`` is None.
    """
    if b is None:
        raise ValueError("bisection requires the right bracket endpoint b")
    miu = a
    # Cache func at the left endpoint; it only changes when `a` moves.
    f_a = func(a)
    for _ in range(int(iter_max)):
        miu = (a + b) / 2
        f_miu = func(miu)  # evaluate once per iteration
        # Stop early if the midpoint is an exact root.
        if f_miu == 0.0:
            break
        # The root lies in whichever half shows a sign change.
        if f_miu * f_a < 0:
            b = miu
        else:
            a, f_a = miu, f_miu
        if (b - a) <= epsilon:
            break
    return miu

def _upper(bound, tol):
    """Return ``bound + tol * bound``: an upper bound with relative slack ``tol``."""
    return bound + tol * bound


def _lower(bound, tol):
    """Return ``bound - tol * bound``: a lower bound with relative slack ``tol``."""
    return bound - tol * bound


def _laplacian(adjs):
    """Return ``diag(adjs.sum(dim=1)) - adjs`` for one matrix (2-D) or a batch (3-D).

    For a 2-D input the degrees are row sums. For a 3-D batch, ``dim=1`` sums
    columns. The two coincide for symmetric adjacency matrices.
    """
    return torch.diag_embed(torch.sum(adjs, dim=1)) - adjs


def _laplacian_svals(adjs):
    """Return the Laplacian's singular values, descending, one row per graph if batched."""
    _, svals, _ = torch.linalg.svd(_laplacian(adjs))
    return svals


def _zero_small(vals, zero_tol):
    """Return a copy of ``vals`` with entries of magnitude ``<= zero_tol`` set to 0."""
    # Out-of-place so tensors saved for autograd (e.g. SVD outputs) are not mutated.
    return vals.masked_fill(vals.abs() <= zero_tol, 0)


def satisfies (adjs, xs, constraint_config, ngraph=None, zero_tol=1e-4, tol=0.01):
    """Check whether graph(s) satisfy the constraint described by ``constraint_config``.

    Two calling modes:

    * ``ngraph is not None``: ``adjs`` / ``xs`` describe ONE graph (``adjs`` is
      indexed as a 2-D matrix). The result is a 0-dim boolean tensor, or
      ``True`` for the ``'None'`` constraint. ``ngraph`` is only used to select
      the reference graph from a loaded ``adj0`` in 'L1-adj' / 'L2-adj'.
    * ``ngraph is None``: ``adjs`` / ``xs`` are batched along dim 0, and the
      result is a boolean tensor with one entry per graph.

    Numeric bounds get a relative slack ``tol`` (``b + tol*b`` for upper bounds,
    ``b - tol*b`` for lower bounds).

    Args:
        adjs: adjacency matrix / batch, or a list of them (stacked on dim 0).
        xs: node-feature matrix / batch, or a list of them (stacked on dim 0).
        constraint_config: object with ``.constraint`` (the constraint name) and,
            for every constraint except 'None', ``.params`` (a constraint-specific list).
        ngraph: index of the graph being checked (single-graph mode), or None.
        zero_tol: eigen/singular values with magnitude ``<= zero_tol`` count as zero.
        tol: relative slack applied to the bounds.

    Raises:
        NotImplementedError: if the constraint name is not supported.
    """
    adjs = torch.stack(adjs) if isinstance(adjs, list) else adjs
    xs = torch.stack(xs) if isinstance(xs, list) else xs
    name = constraint_config.constraint
    # Read lazily: a 'None' config may not carry params at all.
    params = getattr(constraint_config, 'params', None)
    single = ngraph is not None

    if name == 'None':
        if single:
            return True
        return torch.ones(adjs.shape[0], dtype=bool, device=adjs.device)

    if name in ('L1-adj', 'L2-adj'):
        order = 1 if name == 'L1-adj' else 2
        if params[0] == 'zeros':
            # The empty-graph reference already has adjs' shape, so it is never
            # indexed with ngraph (which may exceed the node count).
            ref = torch.zeros_like(adjs)
        else:
            adj0 = torch.load(params[0])
            ref = adj0[ngraph] if single else adj0
        n = adjs.shape[-1]
        # Only the upper triangle (incl. diagonal) enters the norm.
        rows, cols = torch.triu_indices(n, n)
        if single:
            return torch.norm((adjs - ref)[rows, cols], p=order) <= _upper(params[1], tol)
        return torch.norm((adjs - ref)[:, rows, cols], p=order, dim=1) <= _upper(params[1], tol)

    if name == 'Spectral-radius':
        # eigvalsh assumes a symmetric input and returns eigenvalues in ascending
        # order, so the last one is the largest.
        eigs = _zero_small(torch.linalg.eigvalsh(adjs), zero_tol)
        return eigs[..., -1] <= _upper(params[0], tol)

    if name in ('Rank', 'Nconn_atleast', 'Nconn_atmost', 'Eigensum'):
        svals = _laplacian_svals(adjs)
        if name == 'Rank':
            return (svals.abs() > zero_tol).sum(dim=-1) <= _upper(params[0], tol)
        # The number of (near-)zero Laplacian eigenvalues counts connected
        # components for undirected graphs with non-negative weights.
        if name == 'Nconn_atleast':
            return (svals.abs() <= zero_tol).sum(dim=-1) >= _lower(params[0], tol)
        if name == 'Nconn_atmost':
            return (svals.abs() <= zero_tol).sum(dim=-1) <= _upper(params[0], tol)
        return _zero_small(svals, zero_tol).sum(dim=-1) <= _upper(params[0], tol)

    if name in ('Eigenvalue-Box', 'Cheeger-bound'):
        # Ascending Laplacian spectrum with numerical noise clamped to zero.
        eigs = _zero_small(torch.flip(_laplacian_svals(adjs), [-1]), zero_tol)
        if name == 'Eigenvalue-Box':
            lam = eigs[..., params[0]]
            return (lam >= _lower(params[1], tol)) & (lam <= _upper(params[2], tol))
        cheeger_chi, dmax = params[0], params[1]
        lbound, ubound = cheeger_chi**2 / (2 * dmax), 2 * cheeger_chi
        # Weight nonzero positions with a strictly decreasing ramp, so argmax picks
        # the first (smallest) non-trivial eigenvalue. Index 0 is used if all are zero.
        ramp = torch.arange(start=eigs.shape[-1], end=0, step=-1,
                            device=adjs.device, dtype=adjs.dtype)
        first_nnz = torch.argmax((eigs > 0) * ramp, dim=-1, keepdim=True)
        lam = eigs.gather(-1, first_nnz).squeeze(-1)
        return (lam >= _lower(lbound, tol)) & (lam <= _upper(ubound, tol))

    if name == 'Num-triangles':
        # A zero bound becomes zero_tol, giving a tiny positive threshold.
        bound = params[0] if params[0] != 0 else zero_tol
        cubed = torch.matrix_power(adjs, 3)
        # trace(A^3) / 6 counts triangles in a simple undirected graph.
        n_tri = torch.diagonal(cubed, dim1=-2, dim2=-1).sum(dim=-1) / 6
        return n_tri <= (1 + tol) * bound

    if name == 'Diameter':
        d = params[0]
        eye = torch.eye(adjs.shape[-1], dtype=adjs.dtype, device=adjs.device)
        # With non-negative weights, a (near-)zero entry of (I + A)^d means that
        # pair is not connected by any walk of length <= d.
        unreachable = torch.matrix_power(eye + adjs, d).abs() <= zero_tol
        if single:
            return unreachable.sum() == 0
        return unreachable.sum(dim=(1, 2)) == 0

    if name == 'Valency':
        valencies = torch.tensor(params[0], dtype=xs.dtype, device=xs.device)
        hidden_hs = "no_hidden_hs" not in params
        if single:
            if hidden_hs:
                # A node "exists" when any of its features exceeds 0.5.
                atoms_exist = torch.any(xs > 0.5, dim=-1)
                adj_n = adjs[atoms_exist][:, atoms_exist]
                wtd_vals = xs[atoms_exist] @ valencies
                return torch.all(adj_n.sum(dim=1) <= _upper(wtd_vals, tol))
            # NOTE(review): exact equality here, while the batched branch uses `<=`
            # with slack. Confirm which rule is intended.
            return torch.all(adjs.sum(dim=1) == xs @ valencies)
        # (B, N, F) @ (F,) -> (B, N). Unlike `.squeeze()`, this also keeps the
        # batch dim when B == 1.
        wtd_vals = xs @ valencies
        if hidden_hs:
            atoms_exist = torch.any(xs > 0.5, dim=2)
            pair_mask = atoms_exist[:, :, None] & atoms_exist[:, None, :]
            return torch.all((adjs * pair_mask).sum(dim=2) <= _upper(wtd_vals, tol), dim=1)
        return torch.all(adjs.sum(dim=2) <= _upper(wtd_vals, tol), dim=1)

    if name == 'Atom-Count':
        atom_counts = torch.tensor(params[0], dtype=xs.dtype, device=xs.device)
        # Zero allowances are relaxed to zero_tol.
        atom_counts = atom_counts.masked_fill(atom_counts == 0, zero_tol)
        if single:
            return torch.all(xs.sum(dim=0) <= atom_counts)
        return torch.all(xs.sum(dim=1) <= atom_counts[None, :], dim=1)

    if name == 'Mol-Weight':
        atom_weights = torch.tensor(params[0], dtype=xs.dtype, device=xs.device)
        # (..., N, F) @ (F,) -> per-node weight; summing over nodes gives one value
        # per graph (shape (B,) in batch mode).
        total = (xs @ atom_weights).sum(dim=-1)
        return total <= _upper(params[1], tol)

    if name == 'Regression':
        c_theta = torch.load(params[0])
        b_theta = torch.load(params[1])
        if single:
            # Flatten to 1-D so torch.dot applies: features first, then adjacency.
            graph_vec = torch.cat((xs.reshape(-1), adjs.reshape(-1)))
            return torch.dot(c_theta, graph_vec) <= _upper(b_theta, tol)
        graph_vec = torch.cat((xs.reshape(xs.shape[0], -1), adjs.reshape(xs.shape[0], -1)), dim=1)
        return torch.einsum('ij,j->i', graph_vec, c_theta) <= _upper(b_theta, tol)

    raise NotImplementedError (f"{name} not supported")
