import numpy as np
from tqdm import trange
from instances import CRE_instance

# Simulate a two-way crossed random-effects (CRE) dataset, then sample the
# parameters of ``CRE_instance`` with the Metropolis-adjusted Langevin
# algorithm (MALA) and write the post-burn-in sample mean to ``cre_mala.txt``.

# --- synthetic data ---------------------------------------------------------
I = 3  # number of row levels (rows of Y)
J = 5  # number of column levels (columns of Y)
# Parameter dimension passed to the sampler; presumably I row effects +
# J column effects + 3 global parameters -- TODO confirm against CRE_instance.
d = I + J + 3
rng = np.random.default_rng(1)
mu_t = rng.normal(size=1)  # true grand mean, shape (1,)
# NOTE(review): ``rng.normal(1)`` draws one value from N(loc=1, scale=1); it is
# not a size-1 standard normal like the line above. Confirm this is intended.
s_a_t = np.exp(rng.normal(1))  # true row-effect scale
s_b_t = np.exp(rng.normal(1))  # true column-effect scale
a_t = rng.normal(size=(I, 1)) * s_a_t  # true row effects, shape (I, 1)
b_t = rng.normal(size=(J, 1)) * s_b_t  # true column effects, shape (J, 1)
# Y[i, j] = mu + a[i] + b[j] + standard Gaussian noise.
Y = rng.normal(size=(I, J)) + mu_t + np.tile(a_t, (1, J)) + np.tile(b_t.T, (I, 1))

# --- MALA sampler -----------------------------------------------------------
inst = CRE_instance(Y)
maxit = int(1e7)  # total number of MALA iterations
burnin = 1000  # number of initial iterations excluded from the estimate
lr = 0.01  # Langevin step size; per-coordinate proposal noise variance is 2*lr
report_every = 100000  # print the running estimate this often

theta_t = rng.normal(size=(d,))
# Cache the potential and its gradient at the current state, so each
# iteration evaluates U and U_grad once, at the proposal only.
# Assumes inst.U returns a scalar and inst.U_grad a length-d vector -- confirm.
U_t = inst.U(theta_t)
g_t = inst.U_grad(theta_t)
est = np.zeros(d)  # running mean of the post-burn-in samples
n_kept = 0
for t in trange(maxit):
    eta = rng.standard_normal(d)
    theta_new = theta_t - lr * g_t + np.sqrt(2 * lr) * eta
    U_new = inst.U(theta_new)
    g_new = inst.U_grad(theta_new)
    # Langevin proposals are not symmetric, so the Metropolis-Hastings ratio
    # must include the proposal densities
    #     q(y | x) ∝ exp(-||y - x + lr * grad U(x)||^2 / (4 * lr)).
    # For the forward move, y - x + lr * grad U(x) = sqrt(2 * lr) * eta, so the
    # forward log-density reduces to -||eta||^2 / 2.
    log_q_fwd = -0.5 * np.dot(eta, eta)
    back = theta_t - theta_new + lr * g_new
    log_q_bwd = -np.dot(back, back) / (4 * lr)
    log_alpha = (U_t - U_new) + (log_q_bwd - log_q_fwd)
    u = rng.uniform()
    # Clamp the exponent at 0 so exp() cannot overflow; a NaN ratio (e.g. U
    # undefined at the proposal) is rejected rather than silently accepted.
    if not np.isnan(log_alpha) and u < np.exp(min(0.0, log_alpha)):
        theta_t, U_t, g_t = theta_new, U_new, g_new
    if t >= burnin:
        n_kept += 1
        est += (theta_t - est) / n_kept  # incremental mean update
    if (t + 1) % report_every == 0:
        print(est)

np.savetxt('cre_mala.txt', est)

