import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)


def gett_all(ijlm: str, A, B):
    """
    Contract A and B over a four-letter row-index pattern.

    The letters of ``ijlm`` name, in order, the row index of the operands
    A, B, A, B (repeated letters tie those rows together). The column letters
    'a' and 'b' are used internally, so they must not appear in ``ijlm``.

    Returns:
        (full, off) where
          full = sum A[i, a] * B[j, a] * A[l, b] * B[m, b] over all a, b
          off  = the same sum restricted to a != b (full minus the a == b part)
    """
    first, second, third, fourth = ijlm
    two_columns = f"{first}a,{second}a,{third}b,{fourth}b->"
    one_column = f"{first}a,{second}a,{third}a,{fourth}a->"

    full = jnp.einsum(two_columns, A, B, A, B)
    same_column = jnp.einsum(one_column, A, B, A, B)
    return full, full - same_column


@jax.jit
def estimate_dimensionality(A, B):
    """
    Estimate the dimensionality of the population A.
    B is either the same as A or a different trial of the same population.

    Works with M = (A/nf) (B/nf)^T, where nf = sqrt(P*Q), and returns
    (numerator, denominator) pairs whose ratio is the dimensionality
    estimate. For the naive pair, with H = I - 11^T/P (rows, i.e. axis 0,
    mean-centred):
        numerator   = tr(H M)^2
        denominator = tr((H M)^2)

    Args:
        A: 2-D array of shape (P, Q).
        B: array expected to have the same shape as A (not checked here).

    Returns:
        [naive, row_exp, col_exp, double_exp], each a list
        [numerator, denominator] of scalars:
          - naive: the plug-in formulas above.
          - row_exp: a corrected version over the P axis (prefactor P/(P-3)).
          - col_exp: the naive formulas applied to the "d" terms, i.e. with
            the same-column (a == b) contributions removed.
          - double_exp: the row correction applied to the "d" terms, with an
            extra Q/(Q-1) prefactor.
        Prefactors shared by a numerator/denominator pair cancel in the ratio.

    Note:
        P and Q are static Python ints under jit; P in {1, 2, 3} or Q == 1
        raises ZeroDivisionError while tracing.
    """

    P, Q = A.shape

    nf = (P * Q) ** 0.5

    # Each t_k is sum over the index pattern of M[i, j] * M[l, m]; t_kd is
    # the same contraction with the column pairs a == b excluded (gett_all).
    # In terms of M (1 = all-ones vector):
    t1, t1d = gett_all('ijji', A/nf, B/nf)  # tr(M M)
    t2, t2d = gett_all('iiii', A/nf, B/nf)  # sum_i M_ii^2
    t3, t3d = gett_all('ijjj', A/nf, B/nf)  # sum_ij M_ij M_jj
    # Patterns 'iiij' (t4) and 'ijll' (t8) are not needed below.
    t5, t5d = gett_all('ijjl', A/nf, B/nf)  # 1^T M M 1
    t6, t6d = gett_all('iijj', A/nf, B/nf)  # tr(M)^2
    t7, t7d = gett_all('iijl', A/nf, B/nf)  # tr(M) * (1^T M 1)
    t9, t9d = gett_all('ijlm', A/nf, B/nf)  # (1^T M 1)^2

    f1 = P / (P - 2)
    f2 = 2 / (P - 2)
    f3 = (1/(P-1))*(1/(P-2))

    # Denominator. denom_n = tr((H M)^2) expanded in the t-terms.
    # denom_s / denom_d are the row-corrected versions; presumably a
    # finite-P bias correction (U-statistic style) — derivation not in this
    # file, TODO confirm.
    denom_n = t1 - 2/P * t5 + (1/P)**2 * t9
    denom_s = P/(P-3) * (
        t1
        - f1 * t2
        + f2 * (2*t3 - t5)
        + f3 * (t6 - 2*t7 + t9)
    )
    denom_d = (P/(P-3))*(Q/(Q-1)) * (
        t1d
        - f1 * t2d
        + f2 * (2*t3d - t5d)
        + f3 * (t6d - 2*t7d + t9d)
    )

    # Numerator. numer_n = tr(H M)^2 = (tr(M) - 1^T M 1 / P)^2 expanded.
    # numer_s / numer_d: row-corrected counterparts (same caveat as above).
    numer_n = t6 - 2/P * t7 + (1/P)**2 * t9
    numer_s = P/(P-3) * (
        t6
        - 2/(P-1) * t7
        + 1/(P-2) * (4*t3 - P*t2)
        + f3 * (t9 - 4*t5 + 2*t1 - t6)
    )
    numer_d = (P/(P-3))*(Q/(Q-1))*(
        t6d
        - 2/(P-1) * t7d
        + 1/(P-2) * (4*t3d - P*t2d)
        + f3 * (t9d - 4*t5d + 2*t1d - t6d)
    )

    # Column-only correction: the naive formulas on the a != b terms.
    numer_s_col = t6d - 2/P * t7d + (1/P)**2 * t9d
    denom_s_col = t1d - 2/P * t5d + (1/P)**2 * t9d

    naive = [numer_n, denom_n]
    row_exp = [numer_s,  denom_s]
    col_exp = [numer_s_col, denom_s_col]
    double_exp = [numer_d, denom_d]

    return [naive, row_exp, col_exp, double_exp]


@jax.jit
def estimate_dimensionality_no_centering(A, B):
    """
    Estimate the dimensionality of the population A.
    B is either the same as A or a different trial of the same population.

    Uncentred variant. With M = (A/nf) (B/nf)^T and nf = sqrt(P*Q), the naive
    numerator is tr(M)^2 and the naive denominator is sum_ij M_ij^2.

    Returns:
        [naive, row_exp, col_exp, double_exp], each [numerator, denominator]:
          - row_exp drops the i == j terms and rescales by P/(P-1);
          - col_exp uses only column pairs a != b and rescales by Q/(Q-1);
          - double_exp applies both.

    NOTE(review): the denominator uses the 'ijij' pattern (sum_ij M_ij^2,
    i.e. tr(M M^T)) whereas estimate_dimensionality uses 'ijji' (tr(M M)).
    They coincide when B is A but differ for two distinct trials — confirm
    this is intended.
    """
    P, Q = jnp.shape(A)

    nf = (P * Q) ** 0.5
    A_s, B_s = A / nf, B / nf

    # Each pair is (sum over all column pairs, sum over column pairs a != b).
    tr_sq, tr_sq_d = gett_all('iijj', A_s, B_s)      # tr(M)^2
    fro_sq, fro_sq_d = gett_all('ijij', A_s, B_s)    # sum_ij M_ij^2
    diag_sq, diag_sq_d = gett_all('iiii', A_s, B_s)  # sum_i M_ii^2

    row = P/(P-1)
    col = Q/(Q-1)

    naive = [tr_sq, fro_sq]
    row_exp = [row * (tr_sq - diag_sq), row * (fro_sq - diag_sq)]
    col_exp = [col * tr_sq_d, col * fro_sq_d]
    double_exp = [
        row * Q/(Q-1) * (tr_sq_d - diag_sq_d),
        row * Q/(Q-1) * (fro_sq_d - diag_sq_d),
    ]
    return [naive, row_exp, col_exp, double_exp]


def get_dimensionality(A, B, P: int, Q: int):
    """
    Estimate the dimensionality of the population A.
    B is either the same as A or a different trial of the same population.

    Converts A and B to JAX arrays, validates their shapes and evaluates
    ``estimate_dimensionality``.

    Args:
        A, B: array-likes of shape (P, Q).
        P, Q: expected number of rows and columns.

    Returns:
        NumPy array of shape (4, 2): rows are (naive, row_exp, col_exp,
        double_exp), columns are (numerator, denominator).

    Raises:
        AssertionError: if A and B differ in shape or A.shape != (P, Q).
            Raised explicitly (not via ``assert``) so the checks still run
            under ``python -O``; the exception type is kept for callers.
    """
    A = jnp.array(A)
    B = jnp.array(B)
    if A.shape != B.shape:
        raise AssertionError("A and B must have the same dimensions")
    if A.shape != (P, Q):
        raise AssertionError(f"A.shape != (P, Q), {A.shape} != ({P}, {Q})")

    return np.array(estimate_dimensionality(A, B))


def get_dimensionality_avg(Phi, P_ratio, Q_ratio, numit, is_jax=True):
    """
    Evaluate ``estimate_dimensionality`` on ``numit`` random sub-samples of Phi.

    Each draw picks two distinct trials of ``Phi`` and a random subset of
    ``P = max(4, int(P_tot * P_ratio))`` rows and
    ``Q = max(4, int(Q_tot * Q_ratio))`` columns, shared by both trials.
    The per-draw results are returned; no averaging is done here.

    Args:
        Phi: array of shape (T, P_tot, Q_tot), or (P_tot, Q_tot), in which
            case the single matrix is used as both trials.
        P_ratio, Q_ratio: fraction of rows / columns kept per draw.
        numit: number of draws.
        is_jax: if True, draw with ``jax.random`` from the fixed key
            ``PRNGKey(42)``; otherwise with NumPy's global RNG, reseeded with
            ``np.random.seed(i)`` at draw ``i`` (mutates global RNG state).

    Returns:
        (numerator, denominator) pairs for (naive, row_exp, col_exp,
        double_exp), with draws on the last axis: a JAX array of shape
        (4, 2, numit) when ``is_jax`` is True, or a NumPy array of shape
        (1, 4, 2, numit) otherwise.

    Raises:
        ValueError: if Phi has fewer than two trials.
    """
    if Phi.ndim == 2:
        Phi = np.broadcast_to(Phi, (2, *Phi.shape))

    T_tot, P_tot, Q_tot = Phi.shape
    if T_tot < 2:
        # Two distinct trials are drawn without replacement below.
        raise ValueError(f"Phi must contain at least 2 trials, got {T_tot}")
    P = max(4, int(P_tot*P_ratio))
    Q = max(4, int(Q_tot*Q_ratio))

    if is_jax:
        Phi = jnp.array(Phi)

        def single_trial(key):
            # Independent sub-keys for the trial, row and column draws.
            key_T, key_P, key_Q = jax.random.split(key, 3)
            idx_T = jax.random.choice(
                key_T, jnp.arange(T_tot), shape=(2,), replace=False)
            idx_P = jax.random.permutation(key_P, P_tot)[:P]
            idx_Q = jax.random.permutation(key_Q, Q_tot)[:Q]

            # Same row/column subset for both trials.
            Phi_a = Phi[idx_T[0]][idx_P, :][:, idx_Q]
            Phi_b = Phi[idx_T[1]][idx_P, :][:, idx_Q]
            return estimate_dimensionality(Phi_a, Phi_b)

        keys = jax.random.split(jax.random.PRNGKey(42), numit)
        # Draws are evaluated one at a time, then stacked to (numit, 4, 2).
        Ms = jnp.array([single_trial(key) for key in keys])
        Ms = jnp.moveaxis(Ms, 0, -1)

    else:
        Ms = []
        for i in range(numit):
            np.random.seed(i)
            idx_T = np.random.choice(np.arange(T_tot), size=2, replace=False)
            idx_P = np.random.permutation(P_tot)[:P]
            idx_Q = np.random.permutation(Q_tot)[:Q]

            Phi_a = Phi[idx_T[0]][idx_P, :][:, idx_Q]
            Phi_b = Phi[idx_T[1]][idx_P, :][:, idx_Q]
            # NOTE(review): the wrapping list adds a leading axis of size 1
            # that the JAX path does not have; kept for compatibility.
            Ms.append([get_dimensionality(Phi_a, Phi_b, P, Q)])
        Ms = np.moveaxis(Ms, 0, -1)

    return Ms


def get_dimensionality_avg2(Phi, P_ratio, Q_ratio, numit, is_jax=False):
    """
    Evaluate ``estimate_dimensionality`` on ``numit`` random sub-samples of
    Phi, batching the evaluation with ``jax.vmap``.

    Each draw picks two trials of ``Phi`` (distinct whenever ``Phi`` has at
    least two) and a random subset of ``P = max(4, int(P_tot * P_ratio))``
    rows and ``Q = max(4, int(Q_tot * Q_ratio))`` columns, shared by both
    trials.

    Args:
        Phi: array of shape (T, P_tot, Q_tot), or (P_tot, Q_tot), in which
            case the single matrix is used as both trials.
        P_ratio, Q_ratio: fraction of rows / columns kept per draw.
        numit: number of draws.
        is_jax: if True, draw indices with ``jax.random`` from
            ``PRNGKey(i)`` for draw ``i``; otherwise with NumPy's global RNG
            reseeded by ``np.random.seed(i)`` (mutates global RNG state).

    Returns:
        If the vmapped evaluation succeeds, the nested list returned by
        ``estimate_dimensionality`` with every scalar replaced by an array of
        length ``numit``. If it fails, a warning is emitted, the draws are
        evaluated one at a time, and a NumPy array of shape (4, 2, numit) is
        returned.
    """
    import warnings

    if Phi.ndim == 2:
        Phi = np.broadcast_to(Phi, (2, *Phi.shape))

    T_tot, P_tot, Q_tot = Phi.shape
    P = max(4, int(P_tot*P_ratio))
    Q = max(4, int(Q_tot*Q_ratio))

    Phi = jnp.array(Phi)

    def single_trial(idx_T, idx_P, idx_Q):
        # Same row/column subset for both trials.
        Phi_a = Phi[idx_T[0]][idx_P, :][:, idx_Q]
        Phi_b = Phi[idx_T[1]][idx_P, :][:, idx_Q]
        return estimate_dimensionality(Phi_a, Phi_b)

    def draw_numpy(i):
        # Reseeding makes draw i reproducible on its own.
        np.random.seed(i)
        if T_tot > 1:
            # Without replacement, so A and B come from different trials.
            trials = np.random.choice(np.arange(T_tot), size=2, replace=False)
        else:
            # A single trial can only be compared with itself.
            trials = np.zeros(2, dtype=int)
        return (trials,
                np.random.permutation(P_tot)[:P],
                np.random.permutation(Q_tot)[:Q])

    def draw_jax(i):
        # Separate sub-keys: reusing one key would correlate the draws
        # (identical permutations whenever P_tot == Q_tot).
        key_T, key_P, key_Q = jax.random.split(jax.random.PRNGKey(i), 3)
        if T_tot > 1:
            trials = jax.random.choice(
                key_T, jnp.arange(T_tot), shape=(2,), replace=False)
        else:
            trials = np.zeros(2, dtype=int)
        return (trials,
                jax.random.permutation(key_P, P_tot)[:P],
                jax.random.permutation(key_Q, Q_tot)[:Q])

    draw = draw_jax if is_jax else draw_numpy
    draws = [draw(i) for i in range(numit)]
    idx_T = np.array([d[0] for d in draws])
    idx_P = np.array([d[1] for d in draws])
    idx_Q = np.array([d[2] for d in draws])

    try:
        Ms = jax.vmap(single_trial)(idx_T, idx_P, idx_Q)
    except Exception as e:
        # Deliberate best effort: if batching fails, evaluate draws one by one.
        warnings.warn(f"jax.vmap failed ({e}); evaluating draws sequentially")
        Ms = [single_trial(idx_T[i], idx_P[i], idx_Q[i]) for i in range(numit)]
        Ms = np.moveaxis(Ms, 0, -1)

    return Ms
