import numpy as np 
import itertools 
import torch 

def symmetrize(tensor, batch_index = 0):
    '''
    Fully symmetrize a batched NumPy tensor over its non-batch axes.

    Axis 0 is the batch axis and is left in place; the result is the mean of
    ``tensor`` transposed by every permutation of axes 1..ndim-1.

    Inputs:
        tensor: ndarray whose first axis is the batch axis.
        batch_index: must be 0 (only a leading batch axis is supported).
    Outputs:
        ndarray of the same shape, symmetric in all non-batch axes.
    '''
    assert batch_index == 0
    n_axes = len(tensor.shape) - 1

    orderings = list(itertools.permutations(range(1, n_axes + 1), n_axes))

    # Accumulate in the input's dtype, then average once at the end.
    acc = np.zeros_like(tensor)
    for ordering in orderings:
        acc += tensor.transpose((0, *ordering))
    return acc / len(orderings)


def harmonic_proj(tensor, symmetric = True, order = 2, verbose = False): 
    '''
    Remove every trace from a batched, fully symmetric 3-D tensor (NumPy).

    ``tensor`` has a leading batch axis followed by ``order`` axes of size 3
    and must already be symmetric in those axes (``symmetric`` must be True).

    Supported orders (mirroring ``harmonic_proj_torch``):
        order <= 1 : returned unchanged (nothing to trace).
        order == 2 : H_ij   = T_ij - 1/3 delta_ij T_kk
        order == 3 : H_ijk  = T_ijk - 1/5 (delta_ij v_k + delta_ik v_j + delta_jk v_i),
                     with v_k = T_iik
        order == 4 : H_ijkl = T_ijkl - 6/7 Sym(delta_ij T_ppkl)
                              + 3/35 Sym(delta_ij delta_kl) T_ppqq

    Raises:
        NotImplementedError: for any other order (previously the function
            silently returned None).
    '''
    assert symmetric

    if order <= 1:
        # Scalars and vectors have no traces to remove.
        return tensor

    if order == 2:
        tr = np.einsum('tii -> t', tensor)
        delta = np.eye(3)[None, :, :]
        if verbose:
            print(delta.shape, tr.shape)
        # Plain broadcasting: no reliance on einsum broadcasting a size-1 batch label.
        return tensor - 1/3*(tr[:, None, None]*delta)

    if order == 3:
        vec = np.einsum('tiij -> tj', tensor)
        eye = np.eye(3)
        # Explicit sum over the three distinct delta placements (equals 3*Sym(delta v)).
        correction = (np.einsum('ij, tk -> tijk', eye, vec)
                      + np.einsum('ik, tj -> tijk', eye, vec)
                      + np.einsum('jk, ti -> tijk', eye, vec))
        h_tensor = tensor.copy()
        h_tensor -= 1/5*correction
        return h_tensor

    if order == 4:
        tr1 = np.einsum('tppij -> tij', tensor)
        tr2 = np.einsum('tppqq -> t', tensor)
        delta = np.eye(3)[None, :, :]

        h_tensor = tensor.copy()
        h_tensor -= 6/7*symmetrize(np.einsum('tij, tkl -> tijkl', delta, tr1))
        h_tensor += 3/35*symmetrize(np.einsum('tij, tkl -> tijkl', delta, delta))*tr2[:, None, None, None, None]
        return h_tensor

    raise NotImplementedError(f"Harmonic projection for order {order} not implemented.")
    
def symmetrize_torch(tensor, batch_index = 0):
    """
    Fully symmetrize a batched torch tensor over its non-batch axes.

    Axis 0 is kept as the batch axis; the result is the mean of ``tensor``
    permuted by every ordering of axes 1..ndim-1.

    Args:
        tensor: torch tensor whose first axis is the batch axis.
        batch_index: must be 0 (only a leading batch axis is supported).

    Returns:
        Tensor of the same shape, symmetric in all non-batch axes.
    """
    assert batch_index == 0
    order = tensor.dim() - 1

    axis_orders = list(itertools.permutations(range(1, order + 1), order))

    total = torch.zeros_like(tensor)
    for axes in axis_orders:
        total += tensor.permute(0, *axes)
    return total / len(axis_orders)


def harmonic_proj_torch(tensor, symmetric = True, order = 2, verbose = False): 
    """
    Remove every trace from a batched, fully symmetric 3-D torch tensor.

    ``tensor`` has a leading batch axis followed by ``order`` axes of size 3
    and must already be symmetric in those axes (``symmetric`` must be True).

    Orders:
        <= 1 : returned unchanged.
        2    : T_ij - 1/3 delta_ij T_kk
        3    : T_ijk - 3/5 Sym(delta_ij T_ppk)
        4    : T_ijkl - 6/7 Sym(delta_ij T_ppkl) + 3/35 Sym(delta_ij delta_kl) T_ppqq
        >= 5 : NotImplementedError

    Args:
        tensor: batched symmetric tensor.
        symmetric: must be True.
        order: tensor order (number of non-batch axes).
        verbose: for order 2, print the shapes of the delta and trace tensors.
    """
    assert symmetric
    device, dtype = tensor.device, tensor.dtype
    n_batch = tensor.shape[0]

    def identity_batch():
        # Kronecker delta of shape (n_batch, 3, 3) on the input's device/dtype.
        return torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(n_batch, -1, -1)

    if order <= 1:
        return tensor

    if order == 2:
        trace = torch.einsum('tii -> t', tensor)
        delta = identity_batch()
        if verbose:
            print(delta.shape, trace.shape)
        return tensor - 1/3*torch.einsum('t, tij -> tij', trace, delta)

    if order == 3:
        vec_trace = torch.einsum('tiij -> tj', tensor)
        delta_vec = torch.einsum('tij, tk -> tijk', identity_batch(), vec_trace)
        projected = tensor.clone()
        projected -= 3/5*symmetrize_torch(delta_vec)
        return projected

    if order == 4:
        single_trace = torch.einsum('tppij -> tij', tensor)
        double_trace = torch.einsum('tppqq -> t', tensor)
        delta = identity_batch()

        projected = tensor.clone()
        projected -= 6/7*symmetrize_torch(torch.einsum('tij, tkl -> tijkl', delta, single_trace))
        projected += 3/35*symmetrize_torch(torch.einsum('tij, tkl -> tijkl', delta, delta))*double_trace[:, None, None, None, None]
        return projected

    if order >= 5:
        raise NotImplementedError("Harmonic projection for order > 4 not implemented yet.")
    return None
    
def STF_projection_torch(tensor, order = 2, verbose = False): 
    """
    Symmetric trace-free (STF) projection of a batched torch tensor.

    Symmetrizes ``tensor`` over its non-batch axes with ``symmetrize_torch``,
    then removes all traces with ``harmonic_proj_torch`` at the given order.
    """
    symmetric_part = symmetrize_torch(tensor)
    return harmonic_proj_torch(symmetric_part, symmetric=True, order=order, verbose=verbose)
    

def trace_product_r(A, B, r):
    """
    Contract r indices between batched symmetric tensors A (order n) and B (order m).

    Axis 0 of both inputs is a shared batch axis, so ``n = A.ndim - 1`` and
    ``m = B.ndim - 1``. The last ``r`` tensor indices of A are contracted with
    the first ``r`` tensor indices of B. For batched matrices, ``r=1`` is the
    matrix product and ``r=0`` is the outer product.

    Args:
        A: batched tensor of order n.
        B: batched tensor of order m.
        r: number of index pairs to contract (must be >= 0).

    Returns:
        Batched tensor of order n + m - 2r: A's free indices followed by B's.
        If ``r > min(n, m)`` the contraction does not exist. For backward
        compatibility this prints "zero trace." and returns a one-element zero
        tensor of shape (1,), with no batch axis, on A's device and dtype.

    Raises:
        ValueError: if ``r`` is negative, or if the contraction needs more
            index letters than einsum offers once 'z' is reserved for the
            batch axis.
    """
    n = A.ndim - 1
    m = B.ndim - 1

    if r < 0:
        raise ValueError(f"r must be non-negative, got {r}")
    if r > min(n, m):
        print("zero trace.")
        return torch.zeros(1, dtype=A.dtype, device=A.device)

    # 'z' labels the batch axis, so tensor indices use every other ASCII
    # letter (a-y, then A-Z). A plain a, b, c, ... sequence collides with 'z'
    # once n + m - r reaches 26.
    letters = ([chr(c) for c in range(ord('a'), ord('z'))]
               + [chr(c) for c in range(ord('A'), ord('Z') + 1)])
    n_labels = n + m - r
    if n_labels > len(letters):
        raise ValueError(f"too many tensor indices for einsum: {n_labels} > {len(letters)}")

    idx = letters[:n_labels]
    idx_A = idx[:n]
    # B's first r labels are A's last r labels, which is what gets contracted.
    idx_B = idx[n - r:n - r + m]
    idx_out = idx[:n - r] + idx[n:]
    einsum_str = f"z{''.join(idx_A)},z{''.join(idx_B)}->z{''.join(idx_out)}"

    return torch.einsum(einsum_str, A, B)

def tensor_power(tensor, power): 
    """
    Batched outer (tensor) power: ``tensor`` multiplied with itself ``power`` times.

    Each product is ``trace_product_r(..., r=0)``, so axis 0 stays a shared
    batch axis and the result has order ``power * (tensor.ndim - 1)``.

    Args:
        tensor: batched tensor, batch axis first.
        power: positive integer number of factors.

    Returns:
        ``tensor`` itself when ``power == 1``, otherwise a new tensor.

    Raises:
        ValueError: if ``power`` is not a positive integer. The earlier
            recursive form never terminated for such values.
    """
    if power < 1 or int(power) != power:
        raise ValueError(f"power must be a positive integer, got {power}")

    result = tensor
    # Same nesting as the old recursion: T x (T x (... x T)).
    for _ in range(int(power) - 1):
        result = trace_product_r(tensor, result, r = 0)
    return result

def trace_torch(A, r): 
    """
    Take ``r`` successive traces of a batched torch tensor.

    Traces tensor axes (1, 2), then (3, 4), and so on, for ``r`` pairs. Axis 0
    (label 'z') is the batch axis. Remaining axes are carried through by the
    einsum ellipsis. ``r = 0`` returns the values unchanged.
    """
    pair_labels = "".join(chr(ord("a") + k) * 2 for k in range(r))
    return torch.einsum("z" + pair_labels + "...->z...", A)


def cross_product_torch(A, B): 
    """
    Symmetrized Levi-Civita ("cross") product of two batched tensors.

    With eps the Levi-Civita symbol, computes

        C_{w a.. b..} = eps_{w x y} A_{x a..} B_{y b..}

    The first tensor index of A and of B are contracted with eps. The new
    index w comes first, then A's remaining indices, then B's. The result is
    passed through ``symmetrize_torch``.

    Args:
        A: batched tensor, shape (batch, 3, ...).
        B: batched tensor, shape (batch, 3, ...).

    Returns:
        Tensor of order (A.ndim - 1) + (B.ndim - 1) - 1, symmetric in all
        non-batch indices.
    """
    # eps is built on A's device and dtype. einsum needs matching dtypes, so a
    # default-dtype CPU eps broke float64 and GPU inputs. It needs no batch axis.
    eps = torch.zeros((3, 3, 3), dtype=A.dtype, device=A.device)
    for (i, j, k) in ((0, 1, 2), (1, 2, 0), (2, 0, 1)):
        eps[i, j, k] = 1      # even permutations
        eps[i, k, j] = -1     # odd permutations

    n_a = A.dim() - 2   # free (uncontracted) tensor indices of A
    n_b = B.dim() - 2   # free (uncontracted) tensor indices of B
    a_free = "".join(chr(ord("a") + k) for k in range(n_a))
    b_free = "".join(chr(ord("a") + n_a + k) for k in range(n_b))

    ein_str = f"wxy,zx{a_free},zy{b_free}->zw{a_free}{b_free}"
    return symmetrize_torch(torch.einsum(ein_str, eps, A, B))

def even_tvec_torch(A, B, r): 
    """
    Even transvectant of two batched tensors.

    Contracts ``r // 2`` index pairs between A and B with ``trace_product_r``.
    The result is then projected onto its symmetric trace-free part, at the
    order of the contracted tensor.
    """
    contracted = trace_product_r(A, B, r // 2)
    result_order = contracted.ndim - 1
    return STF_projection_torch(contracted, order=result_order, verbose=False)

def odd_tvec_torch(A, B, r): 
    """
    Odd transvectant of two batched tensors.

    Steps:
      1. Form the symmetrized Levi-Civita product with ``cross_product_torch``.
      2. Trace ``(r - 1) // 2`` leading index pairs with ``trace_torch``.
      3. Remove remaining traces with ``harmonic_proj_torch``.

    The cross product is already symmetrized, so no further symmetrization
    step is applied.
    """
    crossed = cross_product_torch(A, B)
    traced = trace_torch(crossed, (r - 1) // 2)
    return harmonic_proj_torch(traced, order=traced.ndim - 1, verbose=False)

def degree_2_invariants_torch(Rhn, Dhn, Qsn): 
    """
    Four quadratic invariants of (Rhn, Dhn, Qsn), one row per batch entry.

    Columns are full contractions over all non-batch indices:
        0: Q_ijk Q_ijk    1: R_ij R_ij    2: D_ij D_ij    3: R_ij D_ij

    Returns:
        float32 tensor of shape (batch, 4).
    """
    contractions = (
        ('tijk, tijk -> t', Qsn, Qsn),
        ('tij, tij -> t', Rhn, Rhn),
        ('tij, tij -> t', Dhn, Dhn),
        ('tij, tij -> t', Rhn, Dhn),
    )
    columns = [torch.einsum(spec, lhs, rhs) for spec, lhs, rhs in contractions]
    return torch.vstack(columns).T.float()

def degree_2_tensor_basis_model_torch(x, coeffs):
    """
    Degree-2 tensor-basis model: weighted sum of six STF rank-4 basis terms.

    Args:
        x: ``(Rhn, Dhn, Qsn)``. From the einsum subscripts below, ``Rhn`` and
            ``Dhn`` are batched rank-2 tensors (batch, 3, 3), and ``Qsn`` is a
            batched rank-3 tensor (batch, 3, 3, 3).
        coeffs: (batch, >=6) weights; column k multiplies basis term T_{k+1}.

    Basis terms, each projected with ``STF_projection_torch(..., order=4)``:
        T1 = eps_imn Q_mjp R_qn      T4 = R_ij R_pq
        T2 = eps_imn Q_mjp D_qn      T5 = D_ij D_pq
        T3 = Q_ijm Q_pqm             T6 = R_ij D_pq

    Returns:
        (batch, 3, 3, 3, 3) tensor  sum_k coeffs[:, k] * T_{k+1}.
    """
    Rhn, Dhn, Qsn = x

    # (batch, n_coeffs) -> (n_coeffs, batch, 1, 1, 1, 1), so that
    # coefficients[k] broadcasts against one (batch, 3, 3, 3, 3) term.
    coefficients = (coeffs[:].T).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    # Levi-Civita symbol, one copy per batch entry, with Qsn's dtype and device.
    epsilon_tensor = torch.zeros_like(Qsn)
    epsilon_tensor[:, 0, 1, 2] = epsilon_tensor[:, 1, 2, 0] = epsilon_tensor[:, 2, 0, 1] = 1
    epsilon_tensor[:, 0, 2, 1] = epsilon_tensor[:, 2, 1, 0] = epsilon_tensor[:, 1, 0, 2] = -1

    raw_terms = (
        torch.einsum("timn, tmjp, tqn->tijpq", epsilon_tensor, Qsn, Rhn),
        torch.einsum("timn, tmjp, tqn->tijpq", epsilon_tensor, Qsn, Dhn),
        # Contract the shared index m. The former subscript "tijm, tpqn" summed
        # m and n independently, giving (sum_m Q_ijm)(sum_n Q_pqn). That is not
        # rotation-equivariant, unlike every other basis term here.
        torch.einsum("tijm, tpqm->tijpq", Qsn, Qsn),
        torch.einsum("tij,tpq->tijpq", Rhn, Rhn),
        torch.einsum("tij,tpq->tijpq", Dhn, Dhn),
        torch.einsum("tij,tpq->tijpq", Rhn, Dhn),
    )

    ret = None
    for k, raw in enumerate(raw_terms):
        term = coefficients[k] * STF_projection_torch(raw, order = 4)
        ret = term if ret is None else ret + term
    return ret


def degree_3_invariants_torch(Rhn, Dhn, Qsn):
    """
    Eleven invariants (degree 2 and 3) of (Rhn, Dhn, Qsn), one row per batch entry.

    Notation follows the comments below: p = Rhn, q = Dhn, f = Qsn. Also,
    ``[(a, b)_r]^1`` is ``even_tvec_torch`` (even r) or ``odd_tvec_torch``
    (odd r).

    Columns:
        0-3 : Q.Q, R.R, D.D, R.D (full contractions)
        4   : [(p, [(p, p)_2]^1)_4]^1
        5   : [(q, [(q, q)_2]^1)_4]^1
        6   : [(p, [(q, q)_2]^1)_4]^1
        7   : [(q, [(p, p)_2]^1)_4]^1
        8   : [(f, [(p, q)_1]^1)_6]^1
        9   : [([(f, f)_4]^1, q)_4]^1
        10  : [([(f, f)_4]^1, p)_4]^1

    Returns:
        float32 tensor of shape (batch, 11).
    """
    I1 = torch.einsum('tijk, tijk -> t', Qsn, Qsn)
    I2 = torch.einsum('tij, tij -> t', Rhn, Rhn)
    I3 = torch.einsum('tij, tij -> t', Dhn, Dhn)
    I4 = torch.einsum('tij, tij -> t', Rhn, Dhn)

    # Shared intermediate transvectants. Each used to be computed twice
    # (I5_ == I8_, I6_ == I7_, I10_ == I11_).
    pp2 = even_tvec_torch(Rhn, Rhn, 2)   # [(p, p)_2]^1
    qq2 = even_tvec_torch(Dhn, Dhn, 2)   # [(q, q)_2]^1
    ff4 = even_tvec_torch(Qsn, Qsn, 4)   # [(f, f)_4]^1
    pq1 = odd_tvec_torch(Rhn, Dhn, 1)    # [(p, q)_1]^1

    I5 = even_tvec_torch(Rhn, pp2, 4)    # [(p, [(p, p)_2]^1)_4]^1
    I6 = even_tvec_torch(Dhn, qq2, 4)    # [(q, [(q, q)_2]^1)_4]^1
    I7 = even_tvec_torch(Rhn, qq2, 4)    # [(p, [(q, q)_2]^1)_4]^1
    I8 = even_tvec_torch(Dhn, pp2, 4)    # [(q, [(p, p)_2]^1)_4]^1
    I9 = even_tvec_torch(Qsn, pq1, 6)    # [(f, [(p, q)_1]^1)_6]^1
    I10 = even_tvec_torch(ff4, Dhn, 4)   # [([(f, f)_4]^1, q)_4]^1
    I11 = even_tvec_torch(ff4, Rhn, 4)   # [([(f, f)_4]^1, p)_4]^1

    return torch.vstack((I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11)).T.float()

def degree_3_tensor_basis_model_torch(x, coeffs):
    """
    Degree-3 tensor-basis model: weighted sum of 27 STF rank-4 basis terms.

    Notation: p = Rhn, q = Dhn, f = Qsn. ``[(a, b)_r]^1`` is
    ``even_tvec_torch`` (even r) or ``odd_tvec_torch`` (odd r).

    Args:
        x: ``(Rhn, Dhn, Qsn)``. From the einsum subscripts, Rhn and Dhn are
            (batch, 3, 3) and Qsn is (batch, 3, 3, 3).
        coeffs: (batch, >=27) weights; column k multiplies basis term T_{k+1}.

    Basis terms:
        T1..T6  : degree-2 terms, identical to degree_2_tensor_basis_model_torch
        T7..T27 : degree-3 transvectant terms, listed with their labels below.

    Returns:
        (batch, 3, 3, 3, 3) tensor  sum_k coeffs[:, k] * T_{k+1}.
    """
    Rhn, Dhn, Qsn = x
    # (batch, n_coeffs) -> (n_coeffs, batch, 1, 1, 1, 1) for broadcasting per term.
    coefficients = (coeffs[:].T).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

    # Levi-Civita symbol, one copy per batch entry, with Qsn's dtype and device.
    epsilon_tensor = torch.zeros_like(Qsn)
    epsilon_tensor[:, 0, 1, 2] = epsilon_tensor[:, 1, 2, 0] = epsilon_tensor[:, 2, 0, 1] = 1
    epsilon_tensor[:, 0, 2, 1] = epsilon_tensor[:, 2, 1, 0] = epsilon_tensor[:, 1, 0, 2] = -1

    # Intermediates used by several basis terms. Each is computed once instead
    # of being recomputed per term.
    pp2 = even_tvec_torch(Rhn, Rhn, 2)   # [(p, p)_2]^1
    qq2 = even_tvec_torch(Dhn, Dhn, 2)   # [(q, q)_2]^1
    pq2 = even_tvec_torch(Rhn, Dhn, 2)   # [(p, q)_2]^1
    ff2 = even_tvec_torch(Qsn, Qsn, 2)   # [(f, f)_2]^1
    ff4 = even_tvec_torch(Qsn, Qsn, 4)   # [(f, f)_4]^1
    fp3 = odd_tvec_torch(Qsn, Rhn, 3)    # [(f, p)_3]^1
    fq3 = odd_tvec_torch(Qsn, Dhn, 3)    # [(f, q)_3]^1

    terms = [
        STF_projection_torch(torch.einsum("timn, tmjp, tqn->tijpq", epsilon_tensor, Qsn, Rhn), order = 4),  # T1
        STF_projection_torch(torch.einsum("timn, tmjp, tqn->tijpq", epsilon_tensor, Qsn, Dhn), order = 4),  # T2
        # T3: contract the shared index m. The former "tijm, tpqn" summed m and
        # n independently, which is not rotation-equivariant.
        STF_projection_torch(torch.einsum("tijm, tpqm->tijpq", Qsn, Qsn), order = 4),
        STF_projection_torch(torch.einsum("tij,tpq->tijpq", Rhn, Rhn), order = 4),  # T4
        STF_projection_torch(torch.einsum("tij,tpq->tijpq", Dhn, Dhn), order = 4),  # T5
        STF_projection_torch(torch.einsum("tij,tpq->tijpq", Rhn, Dhn), order = 4),  # T6
        odd_tvec_torch(Qsn, pp2, 1),                                # T7  [(f, [(p, p)_2]^1)_1]^1
        odd_tvec_torch(Qsn, qq2, 1),                                # T8  [(f, [(q, q)_2]^1)_1]^1
        odd_tvec_torch(Qsn, pq2, 1),                                # T9  [(f, [(p, q)_2]^1)_1]^1
        even_tvec_torch(Qsn, odd_tvec_torch(Rhn, Dhn, 1), 2),       # T10 [(f, [(p, q)_1]^1)_2]^1
        even_tvec_torch(Qsn, even_tvec_torch(Qsn, Rhn, 4), 0),      # T11 [(f, [(f, p)_4]^1)_0]^1
        even_tvec_torch(Qsn, even_tvec_torch(Qsn, Dhn, 4), 0),      # T12 [(f, [(f, q)_4]^1)_0]^1
        even_tvec_torch(ff2, Rhn, 2),                               # T13 [([(f, f)_2]^1, p)_2]^1
        even_tvec_torch(ff2, Dhn, 2),                               # T14 [([(f, f)_2]^1, q)_2]^1
        odd_tvec_torch(ff4, Qsn, 1),                                # T15 [([(f, f)_4]^1, f)_1]^1
        even_tvec_torch(Rhn, ff4, 0),                               # T16 [(p, [(f, f)_4]^1)_0]^1
        even_tvec_torch(Rhn, fp3, 0),                               # T17 [(p, [(f, p)_3]^1)_0]^1
        even_tvec_torch(Rhn, fq3, 0),                               # T18 [(p, [(f, q)_3]^1)_0]^1
        even_tvec_torch(Dhn, ff4, 0),                               # T19 [(q, [(f, f)_4]^1)_0]^1
        even_tvec_torch(Dhn, fp3, 0),                               # T20 [(q, [(f, p)_3]^1)_0]^1
        even_tvec_torch(Dhn, fq3, 0),                               # T21 [(q, [(f, q)_3]^1)_0]^1
        even_tvec_torch(pp2, Rhn, 0),                               # T22 [([(p, p)_2]^1, p)_0]^1
        even_tvec_torch(pp2, Dhn, 0),                               # T23 [([(p, p)_2]^1, q)_0]^1
        even_tvec_torch(qq2, Rhn, 0),                               # T24 [([(q, q)_2]^1, p)_0]^1
        even_tvec_torch(qq2, Dhn, 0),                               # T25 [([(q, q)_2]^1, q)_0]^1
        even_tvec_torch(pq2, Rhn, 0),                               # T26 [([(p, q)_2]^1, p)_0]^1
        even_tvec_torch(pq2, Dhn, 0),                               # T27 [([(p, q)_2]^1, q)_0]^1
    ]

    # Accumulate in the same order as the explicit T1 + T2 + ... + T27 chain.
    ret = coefficients[0]*terms[0]
    for k in range(1, len(terms)):
        ret += coefficients[k]*terms[k]
    return ret

def rapid_pressure_strain_rate(M, G): 
    '''
    Rapid pressure-strain rate from M_ijpq and the mean velocity gradient.

        T_ij = 2 G_nm (M_imnj + M_jmni)

    Inputs:
        M: M_ijpq, shape (batch, 3, 3, 3, 3)
        G: mean velocity gradient, shape (batch, 3, 3)
    Outputs:
        T_ij, shape (batch, 3, 3)

    Reference: 2.8.6 in Kassinos.

    NOTE(review): a second ``def rapid_pressure_strain_rate`` appears later in
    this module. At import time that later definition rebinds the name, so
    this copy is never the one callers reach. Consider deleting one of them.
    '''
    g_imnj = np.einsum('timnj, tnm -> tij', M, G)   # G_nm M_imnj
    g_jmni = np.einsum('tjmni, tnm -> tij', M, G)   # G_nm M_jmni
    Tij = 2*g_imnj
    Tij += 2*g_jmni
    return Tij


def mean_velocity_gradients_from_parameters(param_values, 
                                           normalization = False): 
    '''
    Build mean velocity-gradient tensors from a 4-parameter description.

    For each row (alpha, s2, theta, phi) of ``param_values``:

        w = (sin(phi) cos(theta), sin(phi) sin(theta), cos(phi))
        G = alpha * diag(1, s2, -1 - s2)
            + (1 - |alpha|) * [[0, w3, w2], [-w3, 0, w1], [-w2, -w1, 0]]

    Inputs:
        param_values: array-like of shape (num_cases, >=4); only the first
            four columns are read.
        normalization: if True, divide each G by the Frobenius norm of its
            symmetric part S = (G + G^T) / 2. Rows with S exactly zero (pure
            rotation, alpha == 0) cannot be normalized this way. They are
            returned unscaled instead of becoming NaN/inf via 0/0.
    Returns:
        G: ndarray of shape (num_cases, 3, 3).
    '''
    params = np.asarray(param_values, dtype=float)
    if params.size == 0:
        return np.zeros((0, 3, 3))
    params = np.atleast_2d(params)

    alpha, s2, theta, phi = (params[:, k] for k in range(4))
    w1 = np.sin(phi)*np.cos(theta)
    w2 = np.sin(phi)*np.sin(theta)
    w3 = np.cos(phi)
    rot = 1 - np.abs(alpha)

    G = np.zeros((params.shape[0], 3, 3))
    # Straining part: alpha * diag(1, s2, -1 - s2).
    G[:, 0, 0] = alpha
    G[:, 1, 1] = alpha*s2
    G[:, 2, 2] = alpha*(-1 - s2)
    # Rotational (antisymmetric) part scaled by 1 - |alpha|.
    G[:, 0, 1], G[:, 0, 2] = rot*w3, rot*w2
    G[:, 1, 0], G[:, 1, 2] = -rot*w3, rot*w1
    G[:, 2, 0], G[:, 2, 1] = -rot*w2, -rot*w1

    if normalization:
        S = 0.5*(G + G.transpose(0, 2, 1))
        norm = np.sqrt(np.sum(S*S, axis=(1, 2)))
        # Avoid 0/0 for pure rotation: leave those rows unscaled.
        norm = np.where(norm > 0, norm, 1.0)
        G = G/norm[:, None, None]

    return G


def rapid_pressure_strain_rate(M, G): 
    '''
    Rapid pressure-strain rate from M_ijpq and the mean velocity gradient.

        T_ij = 2 G_nm (M_imnj + M_jmni)

    Inputs:
        M: M_ijpq, shape (*, 3, 3, 3, 3)
        G: mean velocity gradient, shape (*, 3, 3)
    Outputs:
        T_ij, shape (*, 3, 3)

    ``*`` is any number of leading batch axes, including none. These axes are
    broadcast between M and G. Previously exactly one batch axis was required,
    despite the ``*`` in this docstring.

    Reference: 2.8.6 in Kassinos.
    '''
    M = np.asarray(M)
    G = np.asarray(G)

    Tij = 2*np.einsum('...imnj,...nm->...ij', M, G)
    Tij = Tij + 2*np.einsum('...jmni,...nm->...ij', M, G)

    return Tij

def get_M_star_from_M(M):
    '''
    Fully symmetric part M* of a batched rank-4 tensor M_ijpq.

    Averages the six index arrangements
        M_ijpq, M_ipqj, M_iqjp, M_pjiq, M_qjip, M_pqij.
    If M is symmetric in (i, j) and in (p, q), these six terms stand for all
    24 index permutations, so their mean is the fully symmetric part.

    Inputs:
        M: shape (batch, 3, 3, 3, 3), assumed symmetric in ij and in pq.
    Outputs:
        M*: float64 ndarray of shape M.shape.
    '''
    # Each tuple lists, for output axes (t, i, j, p, q), which axis of M they
    # come from. This is equivalent to einsum '<src> -> tijpq' with src in
    # 'tijpq', 'tipqj', 'tiqjp', 'tpjiq', 'tqjip', 'tpqij'.
    axis_maps = (
        (0, 1, 2, 3, 4),   # M_ijpq
        (0, 1, 4, 2, 3),   # M_ipqj
        (0, 1, 3, 4, 2),   # M_iqjp
        (0, 3, 2, 1, 4),   # M_pjiq
        (0, 3, 2, 4, 1),   # M_qjip
        (0, 3, 4, 1, 2),   # M_pqij
    )

    Ms = np.zeros(M.shape)
    for axes in axis_maps:
        Ms += np.transpose(M, axes)
    return 1/6*Ms




def M_star_model_b_y_linear(xdata, *coefs, flatten = True):
    '''
    Linear model for M*, the fully symmetric part of the M_ijpq tensor.

    With c = coefs (flattened) and delta the identity:

        M*_ijpq = c0 * (delta_ij delta_pq + delta_ip delta_jq + delta_iq delta_jp)
                + c1 * S(delta, b)_ijpq + c2 * S(delta, y)_ijpq

    where
        S(a, e)_ijpq = a_ij e_pq + a_ip e_jq + a_iq e_jp
                     + a_jp e_iq + a_jq e_ip + a_pq e_ij.

    Inputs:
        xdata: (b, y). b is the Reynolds stress anisotropy and y the
            dimensionality anisotropy, each of shape (num_time_steps, 3, 3).
        *coefs: model coefficients; the first three (after flattening) are used.
        flatten: if True return a 1-D array, else shape (num_time_steps, 3, 3, 3, 3).

    NOTE(review): the (xdata, *coefs) signature looks shaped for
    scipy.optimize.curve_fit; confirm with the caller.
    '''
    coefficients = np.array(coefs).flatten()
    b, y = xdata[0], xdata[1]

    n_steps = b.shape[0]
    identity = np.array([np.eye(3) for _ in range(n_steps)])

    type1_specs = ('tij, tpq -> tijpq', 'tip, tjq -> tijpq', 'tiq, tjp -> tijpq')
    type2_specs = type1_specs + ('tjp, tiq -> tijpq', 'tjq, tip -> tijpq', 'tpq, tij -> tijpq')

    def _pair_sum(left, right, specs):
        # Sum of einsum(spec, left, right) over specs, accumulated in order.
        total = np.einsum(specs[0], left, right)
        for spec in specs[1:]:
            total += np.einsum(spec, left, right)
        return total

    M = np.zeros([n_steps, 3, 3, 3, 3])
    M += coefficients[0]*_pair_sum(identity, identity, type1_specs)
    M += coefficients[1]*_pair_sum(identity, b, type2_specs)
    M += coefficients[2]*_pair_sum(identity, y, type2_specs)

    if flatten:
        return M.flatten()
    return M


def normalized_error_T(a, b):
    '''
    Relative Frobenius-norm error of ``a`` with respect to the reference ``b``.

    For each batch index t (axis 0):

        ||a_t - b_t||_F / ||b_t||_F

    The norms sum over every non-batch axis, so this works for tensors of any
    order. For example, (*, 3, 3) rank-2 tensors or (*, 3, 3, 3, 3) rank-4
    tensors M_ijpq.

    Note that the denominator is the norm of ``b`` (the second argument), not
    ``a``. Batch entries with ||b_t|| == 0 yield inf/nan.

    Inputs:
        a: array of shape (*, ...)
        b: array of the same shape (reference)
    Outputs:
        Error, shape (*)
    '''
    a = np.asarray(a)
    b = np.asarray(b)

    diff = np.abs(a - b)
    num_axes = tuple(range(1, diff.ndim))
    den_axes = tuple(range(1, b.ndim))
    return np.sum(diff*diff, axis=num_axes)**0.5/(np.sum(b*b, axis=den_axes)**0.5)



def M_decomposition(M_star, Q_star, R, D): 
    '''
    Assemble a batched rank-4 tensor M_ijkl from M*, Q*, R and D.

    For every batch index t the result is the sum of:

        M*_ijkl
        + 1/2 eps_zkj Q*_zil - 1/2 eps_zil Q*_zkj
        + q^2/6 (delta_il delta_jk + delta_ik delta_lj - 2 delta_ij delta_kl)
        + 1/2 (delta_kl R_ij + delta_ij D_kl)
        + 1/6 (delta_kl D_ij + delta_ij R_kl)
        - 1/6 [delta_il (R_kj + D_kj) + delta_kj (R_il + D_il)
               + delta_ik (R_lj + D_lj) + delta_lj (R_ki + D_ki)]

    Here eps is the Levi-Civita symbol and q^2 = R_ii (the trace of R).

    NOTE(review): this function does not cite a source for the decomposition,
    nor say what R, D and Q* physically represent. The names suggest
    Reynolds-stress / dimensionality-type tensors; confirm before relying on
    them.

    Inputs:
        M_star: shape (num_cases, 3, 3, 3, 3).
        Q_star: rank-3 per batch entry (indexed 'tzil'), i.e. (num_cases, 3, 3, 3).
        R, D: shape (num_cases, 3, 3).
    Outputs:
        M: float64 ndarray with M_star's shape.
    '''

    num_cases = M_star.shape[0]
    
    # Levi-Civita symbol replicated over the batch axis: (num_cases, 3, 3, 3).
    levi_civita_tensor = np.zeros([num_cases, 3,3,3])

    def levi_civita(i, j, k):
        # Sign of the permutation (i, j, k) of (0, 1, 2); 0 if any index repeats.

        # Shift to 1-based so the permutation lists below read naturally.
        i+=1 
        j+=1 
        k+=1 
    
        if (i, j, k) in [(1,2,3), (2,3,1), (3,1,2)]: return 1 
        elif (i, j, k) in [(3,2,1), (1,3,2), (2,1,3)]: return -1
        else: return 0
    
    for i in range(3):
        for j in range(3): 
            for k in range(3): 
                levi_civita_tensor[:, i, j, k] = levi_civita(i, j, k)


    # q^2 = R_ii, reshaped to broadcast against (num_cases, 3, 3, 3, 3).
    q_sq = np.einsum('tii-> t', R).reshape(-1, 1, 1, 1, 1)
    # Batch of 3x3 identity matrices (Kronecker delta).
    I = np.array([np.eye(3) for i in range(num_cases)])   

    # print(q_sq)
    
    M = np.zeros(M_star.shape)

    M += M_star 

    # Q* contribution: 1/2 (eps_zkj Q*_zil - eps_zil Q*_zkj).
    M += 1/2*(np.einsum('tzkj, tzil -> tijkl', levi_civita_tensor, Q_star))
    M += -1/2*(np.einsum('tzil, tzkj -> tijkl', levi_civita_tensor, Q_star))

    # Isotropic q^2 terms: q^2/6 (d_il d_jk + d_ik d_lj - 2 d_ij d_kl).
    M += 1/6*(np.einsum('til, tjk -> tijkl', I, I))*q_sq
    M += 1/6*(np.einsum('tik, tlj -> tijkl', I, I))*q_sq
    M += -2*1/6*(np.einsum('tij, tkl -> tijkl', I, I))*q_sq

    # 1/2 (d_kl R_ij + d_ij D_kl).
    M += 3*1/6*(np.einsum('tkl, tij -> tijkl', I, R))
    M += 3*1/6*(np.einsum('tij, tkl -> tijkl', I, D))

    # 1/6 (d_kl D_ij + d_ij R_kl).
    M += 1/6*(np.einsum('tkl, tij -> tijkl', I, D))
    M += 1/6*(np.einsum('tij, tkl -> tijkl', I, R))

    # -1/6 d_il (R_kj + D_kj).
    M += -1/6*(np.einsum('til, tkj -> tijkl', I, R))
    M += -1/6*(np.einsum('til, tkj -> tijkl', I, D))

    # -1/6 d_kj (R_il + D_il).
    M += -1/6*(np.einsum('tkj, til -> tijkl', I, R))
    M += -1/6*(np.einsum('tkj, til -> tijkl', I, D))

    # -1/6 d_ik (R_lj + D_lj).
    M += -1/6*(np.einsum('tik, tlj -> tijkl', I, R))
    M += -1/6*(np.einsum('tik, tlj -> tijkl', I, D))

    # -1/6 d_lj (R_ki + D_ki).
    M += -1/6*(np.einsum('tlj, tki -> tijkl', I, R))
    M += -1/6*(np.einsum('tlj, tki -> tijkl', I, D))

    return M 


def rapid_term_GLM(bij_array, dudy, q2_array, C2, C3): 
    '''
    Rapid pressure-strain term of the General Linear Model (GLM), batched.

    With k = q2 / 2 and G = dudy, for each batch index t:

        rapid_ij = 2/5 k (G_ij + G_ji)
                 + C2 k (b_ik G_kj + b_jk G_ki - 2/3 delta_ij b_kl G_kl)
                 + C3 k (b_ik G_jk + b_jk G_ik - 2/3 delta_ij b_kl G_kl)

    The anisotropy b_ij is used directly. Per the original note, the model is
    from page 167 of Durbin and includes both the LRR and IP variants.

    Inputs:
        bij_array: (batch, 3, 3) anisotropy tensors.
        dudy: (batch, 3, 3) mean velocity gradients.
        q2_array: (batch,) values of q^2 = R_ii, so k = q2 / 2.
        C2, C3: model constants.
    Returns:
        (batch, 3, 3) rapid term.
    '''
    k = q2_array/2
    k_b = k[:, None, None]
    eye = np.eye(3)[None, :, :]

    # 2/3 delta_ij b_kl G_kl, shared by the C2 and C3 groups.
    iso = 2/3*np.einsum('tij, tkl, tkl-> tij', eye, bij_array, dudy)

    c2_group = (np.einsum('tik, tkj->tij', bij_array, dudy)
                + np.einsum('tjk,tki->tij', bij_array, dudy)) - iso
    c3_group = (np.einsum('tik, tjk->tij', bij_array, dudy)
                + np.einsum('tjk,tik->tij', bij_array, dudy)) - iso

    rapid = 2/5*k_b*(dudy + dudy.transpose(0, 2, 1))
    rapid += k_b*C2*c2_group
    rapid += k_b*C3*c3_group
    return rapid