import haiku as hk
import jax.numpy as np
import jax.random as random

def interference(
    update_func,
    predict_func,
    qparams_main: "hk.Params",
    qparams_target: "hk.Params",
    s,
    a,
    r,
    snew,
    done,
):
    """Measure how a single-sample update changes Q-values on every sample.

    For each sample ``i``, ``qparams_main`` is updated with ``update_func`` on
    that one sample. Then, for each sample ``j``, the resulting change in the
    Q-value of the taken action ``a_j`` at state ``s_j`` is recorded.

    Args:
        update_func: Called as ``update_func(qparams_main, qparams_target,
            s_i, a_i, r_i, snew_i, done_i)`` on single-row slices. Only the
            first element of the returned 3-tuple (the updated params) is used.
        predict_func: Called as ``predict_func(params, key, s_j)`` on a
            single-row slice. Assumed to return Q-values of shape
            ``(1, num_actions)`` -- TODO confirm against caller.
        qparams_main: Parameters whose Q-values are compared before and after
            each single-sample update.
        qparams_target: Passed through to ``update_func`` unchanged.
        s, a, r, snew, done: Per-sample arrays sharing leading dimension N.
            ``a`` holds one integer action index per sample, with shape
            ``(N,)`` or ``(N, 1)``.

    Returns:
        An ``(N, N)`` array ``interf`` in the default JAX float dtype, with
        ``interf[j, i] = Q_updated_on_i(s_j, a_j) - Q_main(s_j, a_j)``.
    """
    n = s.shape[0]
    # The same fixed key is used for every prediction, so any key-dependent
    # behaviour of predict_func is identical before and after an update.
    dummy_key = random.PRNGKey(0)

    def q_taken(params, j):
        # Q-value of the action taken in sample j, as a Python float.
        # .item() enforces that exactly one value is selected.
        a_j = a[j:j + 1].reshape(1, -1)
        q = predict_func(params, dummy_key, s[j:j + 1])
        return np.take_along_axis(q, a_j, axis=1).item()

    # Q-values under the un-updated parameters do not depend on i, so they are
    # computed once instead of once per (i, j) pair.
    q_cur = [q_taken(qparams_main, j) for j in range(n)]

    # columns[i][j] holds the change at sample j after updating on sample i.
    # Accumulating Python floats avoids rebuilding an N x N array per entry.
    columns = []
    for i in range(n):
        batch_i = slice(i, i + 1)
        qparams_i, _, _ = update_func(
            qparams_main,
            qparams_target,
            s[batch_i],
            a[batch_i],
            r[batch_i],
            snew[batch_i],
            done[batch_i],
        )
        columns.append([q_taken(qparams_i, j) - q_cur[j] for j in range(n)])

    # np.zeros(()) carries the default float dtype (float32, or float64 with
    # x64 enabled). reshape(n, n) also covers n == 0, where asarray([]) is 1-D.
    out_dtype = np.zeros(()).dtype
    return np.asarray(columns, dtype=out_dtype).reshape(n, n).T