import numpy as np
from scipy import special

# Small positive floor (exactly 2**-50) used to keep probabilities, normalizers and norms
# strictly away from 0 (and erf values away from +-1) in the functions below.
epsilon = 2.0 ** -50

def _ws(M):
    """
    Creates the vector of hidden variables.
    """
    return np.random.normal(0, 1, M)
    
def _F(N, M):
    """
    Creates the random features.
    """
    return np.random.normal(0,M**(-1/2),(N,M))

def _cs(ci, co):
    """
    Returns the affinity matrix given affinity coefficients.
    """
    return np.array([[ci, co],[co, ci]])

def _pS(ss, rho=0):
    """
    Build the N x 2 prior over community labels.

    ss: array N of 0/1 labels (indexed with an integer array, so it must be a NumPy array when rho > 0)
    rho: fraction of nodes whose label is revealed; rho=0 is the unsupervised case

    Every row starts at (0.5, 0.5). For int(N*rho) nodes chosen uniformly without replacement,
    the row is set to (epsilon + (1-epsilon)*s, epsilon + (1-epsilon)*(1-s)). Column 0 is
    therefore the weight of label 1, and revealed rows never contain an exact 0.
    """
    N = len(ss)
    prior = np.full((N, 2), 0.5)
    if rho == 0:
        return prior

    revealed = np.random.choice(N, int(N*rho), replace=False)
    labels = ss[revealed]
    prior[revealed, 0] = epsilon + (1-epsilon)*labels
    prior[revealed, 1] = epsilon + (1-epsilon)*(1-labels)
    return prior

def _padIndices(rows, fill):
    """
    Stack ragged rows of integer indices into an int array of shape (len(rows), longest row),
    right-padding shorter rows with `fill`.

    The result is integer-typed even when `rows` is empty or every row is empty. It is therefore
    always valid for NumPy fancy indexing; np.array([[], []]) would be float and is rejected.
    """
    width = max((len(row) for row in rows), default=0)
    padded = np.full((len(rows), width), fill, dtype=int)
    for r, row in enumerate(rows):
        padded[r, :len(row)] = row
    return padded


def createGraph(ss, ci, co):
    """
    Sample the observed SBM graph for the given communities and build the index arrays used by BP.

    ss: array (or sequence) of 0/1 of length N, the communities
    ci, co: floats, parameters of the SBM. A same-community pair is joined when a uniform draw
            is below ci/N; a cross-community pair when it is below co/N.

    Returns (L, edgesIn, edgesInT, corrLN):
    L: int, number of directed edges (each sampled undirected edge is stored in both directions)
    edgesIn: int array of shape (L, K). Row l lists the directed edges k->i entering the source
             node i of edge l = (i->j), excluding the reversed edge j->i. Rows are padded with L.
    edgesInT: int array of shape (N, K'). Row n lists every directed edge entering node n,
              padded with L.
    corrLN: list of length L. Entry l is the node from which edge l starts; it translates
            l-indices into n-indices.

    The padding value L is one past the last valid edge index. Both index arrays keep an
    integer dtype even when K or K' is 0, for example on graphs with no edges or only isolated edges.
    """
    ss = np.asarray(ss)
    N = len(ss)

    # Sample the undirected edges. Same-community pairs (upper triangle, so no self-loops or
    # duplicates) use threshold ci/N; cross-community pairs use co/N.
    is0 = np.nonzero(1-ss)[0]
    L0 = len(is0)
    is1 = np.nonzero(ss)[0]
    L1 = len(is1)
    iis0 = np.nonzero(np.triu(np.random.random((L0, L0)) < ci/N, 1))
    iis1 = np.nonzero(np.triu(np.random.random((L1, L1)) < ci/N, 1))
    iis01 = np.nonzero(np.random.random((L0, L1)) < co/N)

    edgesIJ = np.hstack((np.vstack((is0[iis0[0]], is0[iis0[1]])),
                         np.vstack((is1[iis1[0]], is1[iis1[1]])),
                         np.vstack((is0[iis01[0]], is1[iis01[1]])))).T

    # Index both orientations of every undirected edge: (i, j) gets index l and (j, i) gets l+1.
    edgesD = {}
    edgesL = []
    corrLN = []
    edgesI = [[] for _ in range(N)]  # neighbours of each node
    L = 0
    for i, j in edgesIJ:
        edgesD[(i, j)] = L
        edgesD[(j, i)] = L+1
        edgesL.append((i, j))
        edgesL.append((j, i))
        corrLN.append(i)
        corrLN.append(j)
        edgesI[i].append(j)
        edgesI[j].append(i)
        L += 2

    # Edge l = (i -> j) collects every edge (k -> i) except the reversed edge (j -> i).
    edgesIn = _padIndices([[edgesD[(k, i)] for k in edgesI[i] if k != j] for i, j in edgesL], L)
    # Node i collects every edge (j -> i).
    edgesInT = _padIndices([[edgesD[(j, i)] for j in edgesI[i]] for i in range(N)], L)

    return L, edgesIn, edgesInT, corrLN


def logMatExp(logChi, cs):
    """
    Compute log(exp(logChi) @ cs) row by row without overflow, in the spirit of
    scipy.special.logsumexp.

    The largest entry of each row is factored out before exponentiating and added back after
    the log. Each row of logChi therefore has a largest term of exp(0) = 1.
    """
    shift = np.max(logChi, axis=1, keepdims=True)
    scaled = np.exp(logChi - shift)
    return shift + np.log(np.dot(scaled, cs))

def _go(chisP, om, V):
    """
    GLM output function, evaluated element-wise.

    chisP: array N, weight of community 1 coming from the SBM side (stepAMP passes chis[:,0])
    om: array N, current projections
    V: scalar variance

    Returns (2*chisP-1) * (2*pi*V)**(-1/2) * exp(-om**2/(2V)) / Z, where
    Z = (1 + (2*chisP-1)*erf(om/sqrt(2V)))/2 is floored at epsilon so the division is finite.
    """
    signedWeight = 2*chisP - 1
    Z = np.maximum(epsilon, (1 + signedWeight*special.erf(om*(2*V)**(-1/2)))/2)
    prefactor = (2*np.pi*V)**(-1/2)
    return prefactor*signedWeight*np.exp(-om**2/2/V)/Z

def fa(Lambda, Gamma):
    """
    Mean update for w: Gamma / (Lambda + 1), element-wise over Gamma.

    Presumably this is the posterior mean under the N(0, 1) prior on w, given the quadratic
    field (Lambda, Gamma). TODO: confirm against the model derivation.
    """
    denominator = 1 + Lambda
    return Gamma/denominator

def fv(Lambda, Gamma):
    """
    Variance update for w: 1 / (Lambda + 1).

    Gamma is accepted for symmetry with fa but not used. The result has the shape of Lambda,
    so it is a scalar whenever Lambda is (as in stepAMP).
    """
    denominator = 1 + Lambda
    return 1/denominator
    
def _psis(om, V):
    """
    Field sent from the GLM side to the SBM side, as an N x 2 array.

    With e = erf(om / sqrt(2V)) clipped to [epsilon-1, 1-epsilon], column 0 is (1+e)/2 and
    column 1 is (1-e)/2. These are P(z > 0) and P(z < 0) for z ~ N(om, V). The clipping keeps
    both columns strictly positive.
    """
    e = special.erf(om*(2*V)**(-1/2))
    e = np.minimum(1-epsilon, np.maximum(epsilon-1, e))
    out = np.empty((len(om), 2))
    out[:, 0] = 1 + e
    out[:, 1] = 1 - e
    return out/2


def stepAMP(a, v, chisP, goPrev, F):
    """
    Run one AMP iteration on the GLM side and return (a, v, psis, go).

    a: array M, current means of w
    v: array M (or the scalar produced by fv); only its mean is used
    chisP: array N, namely chis[:,0]
    goPrev: array N, output function from the previous iteration
    F: array NxM

    Returns:
    a: array M, new means fa(Lambda, Gamma)
    v: new variances fv(Lambda, Gamma)
    psis: array Nx2, field for the SBM side computed from the new projections
    go: array N, output function to pass back as goPrev next time
    """
    nFeatures = len(a)

    Vbar = np.mean(v)
    # Projection F @ a, corrected by the previous output function scaled by Vbar.
    omega = np.dot(F, a) - Vbar*goPrev

    fieldToSBM = _psis(omega, Vbar)
    gOut = _go(chisP, omega, Vbar)

    lam = np.sum(gOut**2)/nFeatures
    gam = a*lam + np.dot(gOut, F)

    return fa(lam, gam), fv(lam, gam), fieldToSBM, gOut

def stepBP(logChis, marginals, psis, corrLN, edgesIn, edgesInT, cs, pS):
    """
    Perform one step of BP on the SBM side and return (logChis, chisT, marginals).

    logChis: array Lx2, log-messages on directed edges
    marginals: array Nx2, current node marginals
    psis: array Nx2, field coming from the GLM side
    corrLN: array L, source node of each directed edge
    edgesIn: int array Lx.., incoming edges of each edge, padded with the index L
    edgesInT: int array Nx.., incoming edges of each node, padded with the index L
    cs: array 2x2, affinity matrix
    pS: array Nx2, prior on the communities

    Returns:
    logChis: array Lx2, updated log-messages; each row is log-normalized
    chisT: array Nx2, prior times the product over all incoming edges (no psis); rows sum to 1
    marginals: array Nx2, psis*chisT normalized so that rows sum to 1
    """
    logChiSum = logMatExp(logChis, cs)
    # Extra zero row at index L: padded entries of edgesIn/edgesInT add log(1) = 0 to the sums.
    logChiSum = np.vstack((logChiSum, np.zeros(2)))

    # Field averaged over all nodes' marginals; the same (1x2) term is subtracted for every node.
    hs = np.mean(np.dot(marginals, cs), axis=0, keepdims=True)
    logPS = np.log(pS)

    logChis = logPS[corrLN] + np.log(psis)[corrLN] - hs + np.sum(logChiSum[edgesIn, :], axis=1)
    logChis -= special.logsumexp(logChis, axis=1, keepdims=True)

    # Normalize in log space before exponentiating. Exponentiating the raw sum over all incoming
    # edges overflows for high-degree nodes, and inf/inf then turns the row into nan.
    logChisT = logPS - hs + np.sum(logChiSum[edgesInT, :], axis=1)
    logChisT -= special.logsumexp(logChisT, axis=1, keepdims=True)
    chisT = np.exp(logChisT)

    # Normalizing psis*chisT gives the same result whether or not chisT was normalized first.
    marginals = psis*chisT
    marginals /= np.sum(marginals, axis=1, keepdims=True)

    return logChis, chisT, marginals


def overlapS(marginals, ss, rho=0):
    """
    Compute the overlap between the estimated marginals and the ground truth for the s variables.

    marginals: array Nx2; column 0 is the weight of community 1
    ss: array N of 0/1 labels (the same encoding createGraph and _pS use); sequences are accepted
    rho: fraction of revealed nodes (see _pS); 0 for the unsupervised case

    Each node is assigned to community 1 when marginals[i, 0] > 0.5 and to community 0 when it
    is < 0.5. A node with exactly 0.5 contributes 0. With q the mean agreement in {-1, 0, 1} per
    node, the function returns (|q| - rho) / (1 - rho).
    - The result is 1 for a perfect assignment and also for a perfectly flipped one,
      because of the absolute value.
    - It is undefined for rho = 1 (division by zero).
    """
    marginals = np.asarray(marginals)
    # Convert first: with a Python list, 2*ss would be list repetition rather than arithmetic.
    ss = np.asarray(ss)
    guesses = np.sign(2*marginals[:, 0] - 1)
    q = np.mean(guesses*(2*ss - 1))
    return (abs(q) - rho)/(1 - rho)

def overlapW(a, ws):
    """
    Compute the absolute cosine similarity between the estimated means a and the ground truth ws.

    ws: array M
    a: array M

    Computes |sum(a*ws)| / max(epsilon, ||a||) / ||ws||. Flooring ||a|| at epsilon makes an
    all-zero estimate score 0 instead of nan (provided ws is nonzero).
    """
    normA = max(epsilon, np.linalg.norm(a))
    cosine = np.sum(a*ws)/normA/np.linalg.norm(ws)
    return abs(cosine)
    
def initPrior(N, L, M, varInit, F):
    """
    Initialize the iteration variables from an uninformative starting point.

    N: number of nodes, L: number of directed edges, M: number of hidden variables
    varInit: amplitude of the random perturbation applied to the edge messages and to a
    F: array NxM, feature matrix

    Returns (logChis, chisT, psis, marginals, a, v, goPrev):
    - logChis: uniform noise in [-varInit, varInit], log-normalized per row.
    - chisT and marginals: uniform 0.5.
    - a: varInit times standard normal noise; v: ones.
    - goPrev and psis: computed from om = F @ a with V = mean(v).
    """
    # SBM side. The random messages are drawn before a, so the RNG stream order matters.
    logChis = varInit*(2*np.random.random((L, 2)) - 1)
    logChis = logChis - special.logsumexp(logChis, axis=1, keepdims=True)
    chisT = np.full((N, 2), 0.5)
    marginals = np.full((N, 2), 0.5)

    # GLM side
    a = varInit*np.random.normal(0, 1, M)
    v = np.ones(M)
    Vbar = np.mean(v)
    omega = np.dot(F, a)
    goPrev = _go(chisT[:, 0], omega, Vbar)
    psis = _psis(omega, Vbar)

    return logChis, chisT, psis, marginals, a, v, goPrev
