import numpy as np

def mk_hVec(method, c, p):
    """Build the normalised weight vectors h_N for every N = 1..c.

    Only method 'A' is supported. ``p`` is a 2-row array-like: ``p[0]`` holds
    the support points ``s_j`` and ``p[1]`` the matching probabilities ``w_j``.
    For each N the vector is

        h_N[k] = sum_j w_j * s_j**k * (1 - s_j)**(N - k),   k = 0..N,

    divided by ``max(|h_N|)`` so its largest magnitude is 1.

    Parameters
    ----------
    method : str
        Construction method; must be 'A'.
    c : int
        Largest N to build (a value < 1 yields an empty list).
    p : 2-row array-like
        Support points and their probabilities.

    Returns
    -------
    list of ndarray
        Element ``N - 1`` has length ``N + 1``.

    Raises
    ------
    ValueError
        If ``method`` is not 'A' or the two rows of ``p`` differ in length.
    """
    if method != 'A':
        # Previously the function silently returned None here, which only
        # surfaced later as an unrelated TypeError in the caller.
        raise ValueError(f"Unsupported method: {method!r}")
    supp = p[0]
    prob = p[1]
    if len(supp) != len(prob):
        raise ValueError("p[0] (supports) and p[1] (probabilities) must have equal length")

    hVec = []
    for N in range(1, c + 1):
        degrees = np.arange(N + 1)
        h = np.zeros(N + 1, dtype=np.double)
        for s, w in zip(supp, prob):
            # (N - degrees) runs N, N-1, ..., 0.
            h += s ** degrees * (1 - s) ** (N - degrees) * w
        hVec.append(h / np.max(np.abs(h)))
    return hVec
    
    
# A.shape[0]: number of tasks, A.shape[1]: number of workers
def initialization(A, msg):
    """Create the initial worker-to-task and task-to-worker message arrays.

    Parameters
    ----------
    A : ndarray, shape (n_tasks, n_workers)
        Task/worker matrix; nonzero entries are edges.
    msg : str
        'ones'   -> worker-to-task messages start as ``A.T`` converted to float.
        'norm11' -> ``A.T`` scaled elementwise by ``1 + N(0, 1)`` noise drawn
                    from numpy's global random state.

    Returns
    -------
    msg_J2I : ndarray, shape (n_workers, n_tasks)
        Worker-to-task messages (always a new array, never a view of ``A``).
    msg_I2J : ndarray, shape (n_tasks, n_workers)
        Task-to-worker messages, all zeros.

    Raises
    ------
    ValueError
        If ``msg`` is not one of the supported options.
    """
    if msg == 'norm11':
        gaus_bias = np.random.randn(A.shape[1], A.shape[0])
        msg_J2I = (1 + gaus_bias) * np.array(A, dtype=np.double).T
    elif msg == 'ones':
        # np.array copies by default and always yields an ndarray, so later
        # in-place message updates can never write back into the caller's A.
        msg_J2I = np.array(A, dtype=np.double).T
    else:
        raise ValueError("Invalid initialMsg")
    msg_I2J = np.zeros_like(A, dtype=np.double)
    return msg_J2I, msg_I2J

def all_msg_update_sump_logratio_conv(msg_in, labs_in, lgh):
    """Compute all outgoing messages of one worker as log-ratios.

    ``E = leaveoneout_v2_ele_sym_logexp(labs_in * msg_in)`` holds, per row,
    leave-one-out log elementary-symmetric sums. The shifted weights
    ``lgh[1:]`` and ``lgh[:-1]`` are each added to every row of ``E``, and the
    result for a row is

        logsumexp(lgh[1:] + E_row) - logsumexp(lgh[:-1] + E_row).

    Each log-sum-exp subtracts its row maximum before exponentiating, for
    numerical stability.

    Returns
    -------
    ndarray, shape (E.shape[0],)
        One message per row of ``E``.
    """
    log_sums = leaveoneout_v2_ele_sym_logexp(labs_in * msg_in)

    def _shifted_lse(terms):
        # Returns (log of sum of exp(terms - rowmax), rowmax) per row.
        peak = np.max(terms, axis=1)
        return np.log(np.sum(np.exp(terms - peak[:, np.newaxis]), axis=1)), peak

    upper, upper_peak = _shifted_lse(lgh[np.newaxis, 1:] + log_sums)
    lower, lower_peak = _shifted_lse(lgh[np.newaxis, :-1] + log_sums)
    return upper - lower + upper_peak - lower_peak

def leaveoneout_v2_ele_sym_logexp(a):
    """Leave-one-out log elementary symmetric sums of exp(a).

    For a squeezed 1-D input of length m >= 2 this returns EE, shape (m, m):

        EE[i, k] = log e_k(exp(a_0), ..., exp(a_{i-1}), exp(a_{i+1}), ..., exp(a_{m-1}))

    for k = 0..m-1 (e_0 = 1, so column 0 is 0). Interior rows convolve, in
    the exp domain, the prefix sums of ``get_sub_ele_sym_logexp(a)`` with the
    suffix sums of ``get_sub_ele_sym_logexp(a[::-1])``. Products are floored
    at the smallest positive double so the log never sees 0.

    Special cases after squeezing:
      * scalar / single-element input -> ``[[0, value]]``
      * empty input -> ``[[0]]``

    NOTE(review): a true leave-one-out of a single element would be ``[[0]]``.
    In all_msg_update_sump_logratio_conv the extra column adds the same term
    to both log-sum-exp sums, so it cancels mathematically there.
    """
    vals = np.squeeze(a)
    floor = np.finfo(float).tiny
    if vals.ndim == 0:
        return np.array([[0, vals]])
    m = len(vals)
    if m == 0:
        return np.array([[0]])
    if m == 1:
        # Unreachable after np.squeeze (a 1-D result never has length 1).
        return np.array([0, vals])

    prefix = get_sub_ele_sym_logexp(vals)        # prefix[r] covers vals[0..r]
    suffix = get_sub_ele_sym_logexp(vals[::-1])  # suffix[r] covers vals[m-1-r..m-1]

    EE = np.zeros((m, m), dtype=np.double)
    EE[0, :] = suffix[m - 2, :m]      # drop vals[0]: only the tail remains
    EE[m - 1, :] = prefix[m - 2, :m]  # drop vals[m-1]: only the head remains
    for skip in range(1, m - 1):
        left = prefix[skip - 1, :skip + 1]       # vals[0..skip-1], degrees 0..skip
        right = suffix[m - skip - 2, :m - skip]  # vals[skip+1..m-1], degrees 0..m-skip-1
        left_max = np.max(left)
        right_max = np.max(right)
        product = np.convolve(np.exp(left - left_max), np.exp(right - right_max))
        row = np.log(np.maximum(product, floor)) + left_max + right_max
        EE[skip, :len(row)] = row
    return EE


def get_sub_ele_sym_logexp(a):
    """Log elementary symmetric polynomials of exp(a) for every prefix of ``a``.

    Returns E with shape (len(a), len(a) + 1) where, for 0 <= k <= n + 1,

        E[n, k] = log e_k(exp(a[0]), ..., exp(a[n]))

    (e_0 = 1, so column 0 is 0). Entries with k > n + 1 are unused and left
    at 0. For empty input the scalar 0 is returned, as before.

    Rows are built with the recurrence
        e_{k+1}(first n+1) = e_k(first n) * exp(a[n]) + e_{k+1}(first n),
    evaluated in log space with ``np.logaddexp``. This is stable and returns
    -inf/+inf instead of NaN when both terms are infinite.
    """
    if len(a) == 0:
        return 0
    m = len(a)
    E = np.zeros((m, m + 1), dtype=np.double)
    E[0, 1] = a[0]
    for n in range(1, m):
        # Degrees 1..n: combine "include a[n]" and "exclude a[n]" terms.
        E[n, 1:n + 1] = np.logaddexp(E[n - 1, :n] + a[n], E[n - 1, 1:n + 1])
        # Top degree n+1: every element is included.
        E[n, n + 1] = E[n - 1, n] + a[n]
    return E

def key_alg(maxIter, msg_I2J, msg_J2I, A, lghVec):
    """Run task/worker message passing and return a raw score per task.

    Parameters
    ----------
    maxIter : int
        Maximum number of sweeps (task update followed by worker update).
    msg_I2J : ndarray, shape (n_tasks, n_workers)
        Task-to-worker messages; updated in place.
    msg_J2I : ndarray, shape (n_workers, n_tasks)
        Worker-to-task messages; updated in place.
    A : ndarray, shape (n_tasks, n_workers)
        Nonzero entries mark task/worker edges; their values multiply the
        incoming messages.
    lghVec : sequence of ndarray
        ``lghVec[N - 1]`` is the log weight vector used for a worker with
        ``N`` tasks, so each worker's task count must be <= ``len(lghVec)``.

    Returns
    -------
    ndarray, shape (n_tasks,)
        ``sum_j A[i, j] * msg_J2I[j, i]`` over each task's workers. The raw
        scores are kept (not normalised) to preserve the information gained
        from additional reviews.

    Iteration stops early once the largest absolute change of
    ``tanh(msg / 2)`` over all edges, in both directions, is <= 1e-6.
    """
    # Snapshot the labels so in-place message updates can never alter them,
    # even if a caller's message array shares memory with A.
    A = np.array(A)
    task_edges = np.nonzero(A)
    worker_edges = np.nonzero(A.T)
    # Neighbour lists depend only on A, so compute them once.
    task_nbrs = [np.flatnonzero(row) for row in A]
    worker_nbrs = [np.flatnonzero(col) for col in A.T]

    for _ in range(maxIter):
        old_I2J = msg_I2J.copy()
        old_J2I = msg_J2I.copy()

        # Task -> worker: total labelled evidence minus the target's own term.
        for i, nbrs in enumerate(task_nbrs):
            weighted = A[i, nbrs] * msg_J2I[nbrs, i]
            msg_I2J[i, nbrs] = np.sum(weighted) - weighted

        # Worker -> task, using the task messages computed just above.
        for j, nbrs in enumerate(worker_nbrs):
            if nbrs.size == 0:
                continue  # no tasks, no outgoing messages
            lgh = lghVec[nbrs.size - 1]
            msg_J2I[j, nbrs] = all_msg_update_sump_logratio_conv(
                msg_I2J[nbrs, j], A[nbrs, j], lgh)

        # Use the magnitude of the change: a signed maximum can be negative
        # and stop iteration while messages are still moving. initial=0.0
        # keeps the reduction defined when there are no edges.
        err_I2J = np.max(np.abs(np.tanh(old_I2J[task_edges] / 2)
                                - np.tanh(msg_I2J[task_edges] / 2)), initial=0.0)
        err_J2I = np.max(np.abs(np.tanh(old_J2I[worker_edges] / 2)
                                - np.tanh(msg_J2I[worker_edges] / 2)), initial=0.0)
        if max(err_I2J, err_J2I) <= 1e-6:
            break

    return np.array(
        [np.sum(A[i, nbrs] * msg_J2I[nbrs, i]) for i, nbrs in enumerate(task_nbrs)],
        dtype=np.double,
    )

def bp_modified(A, msg, maxIter, paper_per_reviewer, p):
    """Score every task (row of ``A``) with belief propagation.

    Parameters
    ----------
    A : ndarray, shape (n_tasks, n_workers)
        Nonzero entries link a task to a worker; the values act as labels.
    msg : str
        Initialisation scheme forwarded to ``initialization`` ('norm11' or 'ones').
    maxIter : int
        Maximum number of message-passing sweeps.
    paper_per_reviewer : int
        Maximum number of papers assigned to one reviewer; log weight vectors
        are built for every count 1..paper_per_reviewer.
    p : 2-row array-like
        ``p[0]`` support points and ``p[1]`` their probabilities (method 'A'
        of ``mk_hVec``).

    Returns
    -------
    ndarray, shape (n_tasks,)
        Raw per-task scores produced by ``key_alg``.
    """
    worker_to_task, task_to_worker = initialization(A, msg)
    log_weights = [np.log(weights) for weights in mk_hVec('A', paper_per_reviewer, p)]
    return key_alg(maxIter, task_to_worker, worker_to_task, A, log_weights)