import numpy as np

# Tiny positive constant: added to the hard 0/1 priors in _logpU so that their logarithm stays
# finite, and used in overlapV as a lower bound on the norm of the estimate before dividing by it.
epsilon = 2**(-100)

def _v(P):
    """
    Draws the centroid v of the Gaussian mixture; its entries are i.i.d. standard normal.
    
    P: int, dimension of the features
    
    v: (P,) array, the centroid
    """
    centroid = np.random.normal(loc=0, scale=1, size=P)
    return centroid

def _u(N):
    """
    Draws the communities u, each node being assigned to group 0 or 1 with probability one half.
    
    N: int, number of nodes
    
    u: (N,) array of 0-1 integers, the communities
    """
    draws = np.random.random(N)
    inGroupOne = draws < 0.5
    return inGroupOne.astype(int)

def _B(v, u, mu):
    """
    Samples the observed features of the cSBM: a rank-one signal plus standard Gaussian noise.
    
    v: (P,) array, the centroid
    u: (N,) array, the 0-1 communities
    mu: float, snr
    
    B: (P,N) array, the features
    """
    P, N = len(v), len(u)
    spins = 2*u-1  # communities mapped from {0,1} to {-1,+1}
    signal = np.outer(v, spins)*np.sqrt(mu/N)
    noise = np.random.normal(0, 1, (P, N))
    return signal+noise

def mapParameters(d, _lambda):
    """
    Converts (average degree, snr) into the two affinity coefficients of the SBM.
    
    d: float, average degree
    _lambda: float, snr
    
    ci, co: floats, the affinity coefficients, equal to d plus/minus _lambda*sqrt(d)
    """
    spread = _lambda*np.sqrt(d)
    return d+spread, d-spread

def _logpU(u, rho=0):
    """
    Builds the log-prior on the communities and selects the test nodes.
    
    A random fraction rho of the nodes is revealed: their prior is put (up to epsilon) on the true
    group. All the other nodes keep the uniform prior and form the test set.
    
    u: (N,) array of 0-1, ground truth communities
    rho: float, proportion of train node labels; rho=0 for unsupervised.
    
    logpU: (N,2) array, the log-prior on the groups
    isTest: ((1-rho)*N,) array, the test node indices
    """
    N = len(u)
    prior = 0.5*np.ones((N, 2))
    if rho==0:
        # Unsupervised: nothing is revealed and every node is a test node.
        return np.log(prior), np.arange(N)

    shuffled = np.random.permutation(N)
    nTrain = int(N*rho)
    isTrain = shuffled[:nTrain]
    isTest = shuffled[nTrain:]
    labels = u[isTrain]
    # epsilon keeps the logarithm of the excluded group finite.
    prior[isTrain, 0] = epsilon+(1-labels)
    prior[isTrain, 1] = epsilon+labels
    return np.log(prior), isTest

def _padEdgeLists(rows, pad):
    """
    Right-pads every list of rows with pad up to the longest length and stacks them.
    
    rows: list of lists of ints
    pad: int, the padding value
    
    Returns a 2-D integer array of shape (len(rows), width). The shape and the integer dtype are
    enforced even when rows is empty or when all its lists are empty, so that the result can
    always be used as a fancy index.
    """
    width = max([len(row) for row in rows]+[0])
    padded = [row+[pad]*(width-len(row)) for row in rows]
    return np.array(padded, dtype=int).reshape(len(rows), width)


def createGraph(u, ci, co):
    """
    Generates the observed graph given the parameters of the SBM and the communities.
    
    Each pair of nodes of the same group is linked with probability ci/N, each pair of nodes of
    different groups with probability co/N. Every undirected edge {i,j} gives two consecutive
    directed edges: (i,j) at an even index l and (j,i) at index l+1.
    
    u: (N,) array of 0-1, ground truth communities
    ci, co: floats, parameters of the SBM
    
    L: int, number of directed edges
    edgesIn: (L,x) integer array. Index l contains the list of edges going into the start node of
        edge l, the reversed edge excepted, padded with the value L
    edgesInT: (N,y) integer array. Index n contains the list of edges going to node n, padded with
        the value L
    corrLN: list of length L. Index l is the index of the node from which edge l starts.
        Translates l-indices into n-indices
    """
    N = len(u)
    
    is0 = np.nonzero(1-u)[0]
    L0 = len(is0)
    is1 = np.nonzero(u)[0]
    L1 = len(is1)
    # The three random draws are made in this order (0-0, 1-1, 0-1) so that a seeded run is reproducible.
    iis0 = np.nonzero(np.triu(np.random.random((L0, L0))<ci/N, 1))
    iis1 = np.nonzero(np.triu(np.random.random((L1, L1))<ci/N, 1))
    iis01 = np.nonzero(np.random.random((L0, L1))<co/N)
    
    # One row (i,j) per undirected edge, expressed with the original node indices.
    edgesIJ = np.hstack((np.vstack((is0[iis0[0]], is0[iis0[1]])),
                         np.vstack((is1[iis1[0]], is1[iis1[1]])),
                         np.vstack((is0[iis01[0]], is1[iis01[1]])))).T
    
    edgesD = {}  # (start, end) -> index of the directed edge
    edgesL = []  # index of the directed edge -> (start, end)
    corrLN = []  # index of the directed edge -> start
    edgesI = [[] for _ in range(N)]  # node -> list of its neighbours
    for i, j in edgesIJ:
        for start, end in ((i, j), (j, i)):
            edgesD[(start, end)] = len(edgesL)
            edgesL.append((start, end))
            corrLN.append(start)
            edgesI[start].append(end)
    L = len(edgesL)

    # For edge (i,j): all the edges (k,i) with k a neighbour of i other than j.
    edgesIn = [[edgesD[(k, i)] for k in edgesI[i] if k!=j] for i, j in edgesL]
    # For node i: all the edges (j,i).
    edgesInT = [[edgesD[(j, i)] for j in edgesI[i]] for i in range(N)]
    
    # L is used as padding: it indexes the extra entry that the consumers append to the messages.
    return L, _padEdgeLists(edgesIn, L), _padEdgeLists(edgesInT, L), corrLN

def sigmoid(xs):
    """
    The sigmoid function, set to exactly 0 for inputs smaller than or equal to -100.
    
    xs: float or array
    
    Returns 1/(1+exp(-xs)) where xs>-100 and 0 elsewhere, with the same shape as xs.
    """
    # The argument of the exponential is clamped at -100 so that very negative inputs do not
    # overflow np.exp. The clamped entries are zeroed by the mask, so the values are unchanged.
    clamped = np.maximum(xs, -100)
    return 1/(1+np.exp(-clamped))*(xs>-100)


def step(chi, uC, vC, B, corrLN, edgesIn, edgesInT, mu, ci, co, logpU):
    """
    Performs one iteration of AMP-BP.
    
    The centroid estimate is refreshed first, from the current community estimates. The new
    centroid is then used in the field that drives the update of the messages and of the marginals.
    
    chi: (L,) array, cavity messages for being 1
    uC: (N,) array, estimated marginals for the groups
    vC: (P,) array, estimated marginals for the centroid
    B: (P,N) array, observed features
    corrLN: (L,) array, as returned by createGraph
    edgesIn: (L,x) array, as returned by createGraph
    edgesInT: (L,y) array, as returned by createGraph
    mu: float, snr
    ci, co: floats, affinity coefficients
    logpU: (N,2) array, log-prior on the groups
    
    chi: (L,) array, updated cavity messages
    uC: (N,) array, updated estimated marginals for the groups
    vC: (P,) array, updated estimated marginals for the centroid
    """
    nFeatures, nNodes = len(vC), len(uC)
    gain = np.sqrt(mu/nNodes)
    
    # Centroid update; Bu uses the centroid estimate of the previous iteration.
    varU = 1-uC**2
    Au = np.sum(uC**2)*mu/nNodes
    Bu = np.dot(B, uC)*gain-np.sum(varU)*vC*mu/nNodes
    sigmaV = 1/(1+Au)
    vC = Bu/(1+Au)
    
    # Per-node field: global term, prior log-ratio and feature term built on the new centroid.
    Bv = np.dot(B.T, vC)*gain-sigmaV*uC*mu*nFeatures/nNodes
    field = (co-ci)*np.mean(uC)+logpU[:,1]-logpU[:,0]+2*Bv
    
    def logRatio(msgs):
        # Log-ratio contributed by each incoming message.
        return np.log(co+(ci-co)*msgs)-np.log(ci+(co-ci)*msgs)
    
    # The extra entry 0.5 is the one hit by the padding index L; its log-ratio is 0.
    padded = np.hstack((chi, [0.5]))
    marginals = sigmoid(field+np.sum(logRatio(padded[edgesInT]), axis=1))
    messages = sigmoid(field[corrLN]+np.sum(logRatio(padded[edgesIn]), axis=1))
    
    return messages, 2*marginals-1, vC

    
def _phiGlobal(name):
    """
    Fetches a module-level variable that phi used to read implicitly.
    
    name: str, name of the variable
    
    Raises NameError if the module does not define it.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError("phi() needs '%s': pass it as a keyword argument or define it at module level" % name)


def phi(chi, uC, vC, sigmaV, ci, co, mu, B=None, logpU=None, edgesInT=None):
    """
    Compute the free entropy.
    
    chi: (L,) array, cavity messages for being 1, as returned after a step
    uC: (N,) array, estimated marginals for the groups, as returned after a step
    vC: (P,) array, estimated marginals for the centroid, as returned after a step
    sigmaV: float, estimated variance for the centroids
    ci, co: floats, affinity coefficients
    mu: float, snr
    B: (P,N) array, observed features
    logpU: (N,2) array, log-prior on the groups
    edgesInT: (N,y) array, as returned by createGraph
    
    B, logpU and edgesInT should be passed explicitly. When one of them is left to None, the
    module-level variable of the same name is used, which is how this function read them before.
    The sizes N, P and L are taken from the lengths of uC, vC and chi.
    
    phi: float, the free entropy
    """
    if B is None:
        B = _phiGlobal("B")
    if logpU is None:
        logpU = _phiGlobal("logpU")
    if edgesInT is None:
        edgesInT = _phiGlobal("edgesInT")
    # Integer dtype is enforced so that an empty neighbour table can still be used as an index.
    edgesInT = np.asarray(edgesInT, dtype=int)
    N, P, L = len(uC), len(vC), len(chi)
    
    sigmaU = 1-uC**2
    Au = np.sum(uC**2)*mu/N
    Bu = np.dot(B, uC)*np.sqrt(mu/N)-np.sum(sigmaU)*vC*mu/N
    
    Fa = np.sum(-np.log(1+Au)+Bu**2/(1+Au))/2
    Fia = np.sqrt(mu/N)*np.dot(vC, np.dot(B, uC))-\
            mu/N*(np.sum(vC**2)*N/2+np.sum(uC**2)*P*sigmaV-np.sum(vC**2)*np.sum(uC**2)/2)

    Bv = np.dot(B.T, vC)*np.sqrt(mu/N)-sigmaV*uC*mu*P/N
    
    # Padded entries (index L) must contribute a factor 1 to both products. They are masked
    # explicitly rather than through a special padding message, which would require dividing
    # by ci-co and breaks when ci==co.
    chiPad = np.hstack((chi, [0.5]))
    msgs = chiPad[edgesInT]
    isPad = edgesInT==L
    w1 = np.where(isPad, 1., co+(ci-co)*msgs)
    w0 = np.where(isPad, 1., ci+(co-ci)*msgs)
    
    hT1 = -(ci+co)/2-(ci-co)/2*np.mean(uC)+logpU[:,1]+Bv
    hT0 = -(ci+co)/2-(co-ci)/2*np.mean(uC)+logpU[:,0]-Bv
    Fi = np.sum(np.log(np.exp(hT1)*np.prod(w1, axis=1)+\
                       np.exp(hT0)*np.prod(w0, axis=1)))
    
    # Messages at indices 2k and 2k+1 are the two directions of the same undirected edge.
    tmp = 2*(ci-co)*chi[np.arange(0,L,2)]*chi[np.arange(1,L,2)]+\
            (co-ci)*(chi[np.arange(0,L,2)]+chi[np.arange(1,L,2)])+ci
    Fij = np.sum(np.log(tmp))-N*(ci+co)/4
    
    return (Fi+Fa-Fij-Fia)/N

def stepParameterEstimation(chi, uC, vC, sigmaV, B, ci, co, mu):
    """
    Performs one step of parameter estimation. The messages chi, uC, vC and sigmaV must be the fixed point of AMP-BP run at ci, co, mu.
    
    chi: (L,) array, cavity messages for being 1, as returned by AMP-BP
    uC: (N,) array, estimated marginals for the groups, as returned by AMP-BP
    vC: (P,) array, estimated marginals for the centroid, as returned by AMP-BP
    sigmaV: float, estimated variance for the centroids, as estimated by AMP-BP
    B: (P,N) array, observed features
    ci, co: floats, current estimate of the affinity coefficients
    mu: float, current estimate of the snr (not read by the update itself)
    
    upCi, upCo, upMu: floats, new estimates of the parameters
    """
    nNodes, nFeatures = len(uC), len(vC)
    ratio = nNodes/nFeatures
    
    # Messages at indices 2k and 2k+1 are the two directions of the same undirected edge.
    fwd = chi[0::2]
    bwd = chi[1::2]
    
    # Per-edge weights for "both ends in the same group" and "ends in different groups".
    same = fwd*bwd+(1-fwd)*(1-bwd)
    different = (1-fwd)*bwd+fwd*(1-bwd)
    normaliser = 2*(ci-co)*fwd*bwd+(co-ci)*(fwd+bwd)+ci
    
    upCi = 4/nNodes*np.sum(ci*same/normaliser)
    upCo = 4/nNodes*np.sum(co*different/normaliser)
    
    overlap = np.dot(vC, np.dot(B, uC))
    denominator = ratio*np.sum(vC**2)+sigmaV*np.sum(uC**2)
    r = ratio*overlap/np.sqrt(nNodes)/denominator
    
    return upCi, upCo, r**2


def overlapU(uC, u, isTest):
    """
    Computes the overlap between the estimation and the ground truth for the u variables.
    
    The accuracy of the signs on the test nodes is taken up to a global swap of the two groups,
    then rescaled so that an accuracy of one half maps to 0 and a perfect accuracy maps to 1.
    
    u: (N,) array of 0-1, ground truth communities
    uC: (N,) array, estimated marginals
    isTest: ((1-rho)*N,) array, test node indices
    
    qU: float, between 0 and 1
    """
    chance = 0.5
    truth = 2.*u[isTest]-1
    guess = np.sign(uC[isTest])
    direct = np.mean(guess==truth)
    swapped = np.mean(guess==-truth)
    best = max(direct, swapped)
    return (best-chance)/(1-chance)
    
def overlapV(vC, v):
    """
    Computes the overlap between the estimation and the ground truth for the v variables.
    
    It is the absolute value of the inner product of the two vectors divided by their norms.
    
    v: (P,) array, ground truth centroid
    vC: (P,) array, estimated centroid
    
    qV: float, between 0 and 1
    """
    inner = abs(np.sum(vC*v))
    # epsilon prevents a division by zero when the estimate is identically zero.
    normEstimate = max(epsilon, np.linalg.norm(vC))
    normTruth = np.linalg.norm(v)
    return inner/normEstimate/normTruth


def initPrior(N, P, L, varInit):
    """
    Initializes the needed variables close to the uninformative point, with small uniform noise.
    
    N, P, L: ints, dimensions of the problem
    varInit: float, initial variation from 0
    
    chi: (L,) array, cavity messages for being 1, uniform in [0.5-varInit, 0.5+varInit)
    uC: (N,) array, estimated marginals for the groups, uniform in [-varInit, varInit)
    vC: (P,) array, estimated marginals for the centroid, uniform in [-varInit, varInit)
    """
    def perturbation(size):
        # Uniform noise on [-varInit, varInit).
        return (2*np.random.random(size)-1)*varInit
    
    # The draws are made in the order chi, uC, vC.
    chi = 0.5+perturbation(L)
    uC = perturbation(N)
    vC = perturbation(P)
    return chi, uC, vC

def initInformed(u, v, corrLN):
    """
    Initializes the needed variables at the ground truth.
    
    u: (N,) array of 0-1, the ground truth communities
    v: (P,) array, the centroid
    corrLN: (L,) array, as returned by createGraph
    
    chi: (L,) array, cavity messages for being 1; the true group of the start node of each edge
    uC: (N,) array, estimated marginals for the groups; the communities mapped to -1/+1
    vC: (P,) array, estimated marginals for the centroid; v itself, not a copy
    """
    signs = 2.*u-1
    messages = u[corrLN]
    return messages, signs, v
