import numpy as np
import argparse
import itertools
from instances import CRE_instance
from core import LMC, stepsize_schedule

# Simulate an I x J table Y = noise + global offset + row effect + column
# effect, using a fixed seed so the run is reproducible.
# NOTE(review): 'cre_mala.txt' (loaded below) presumably holds reference values
# computed for exactly this Y, so the order and form of every RNG draw here
# must stay fixed -- confirm before touching this block.
I, J = 3, 5                # rows x columns of Y
d = I + J + 3              # presumably the sampler's parameter dimension -- TODO confirm
rng = np.random.default_rng(1)

mu_t = rng.normal(size=1)  # global offset, shape (1,)
# NOTE(review): `loc=1` is the positional meaning of the original
# `rng.normal(1)` -- a scalar draw from N(1, 1), not `size=1` as on the line
# above. Kept as-is because changing it would change Y.
s_a_t = np.exp(rng.normal(loc=1))  # scale of the row effects
s_b_t = np.exp(rng.normal(loc=1))  # scale of the column effects

a_t = rng.normal(size=(I, 1)) * s_a_t  # row effects, column vector (I, 1)
b_t = rng.normal(size=(J, 1)) * s_b_t  # column effects, column vector (J, 1)

# Broadcasting: a_t (I, 1) spreads across columns and b_t.T (1, J) spreads
# across rows, which is the same as tiling both to (I, J) before adding.
Y = rng.normal(size=(I, J)) + mu_t + a_t + b_t.T

# Model instance for the simulated data, plus the reference values that the
# per-run sample means are compared against.
# NOTE(review): 'cre_mala.txt' presumably holds a high-accuracy posterior mean
# for exactly this Y; it goes stale if the data generation above changes.
inst = CRE_instance(Y)
truth = np.loadtxt('cre_mala.txt')

# 'decreasing' selects the stepsize_schedule(...) built inside the loop below;
# any other value (e.g. 0.01 or 0.0001) is passed to LMC unchanged as `lr`.
stepsize = 'decreasing'

nrep = 20                   # replicates (seeds 0..nrep-1) per (method, m)
m_list = np.arange(10, 17)  # chain lengths n = 2**m for m = 10..16
methods = ['mc', 'cud']     # first element of every `errors` key
errors = {}                 # (method, m) -> array of shape (nrep, 3)

# Value of LMC's `cud` argument for each entry of `methods`. An unknown method
# raises KeyError here instead of silently running a different configuration.
_CUD_FLAG = {'mc': False, 'cud': True}


def _posterior_mean_mse(samples, target):
    """Return the mean squared error of the sample mean against *target*.

    ``samples`` is averaged over axis 0 (presumably one row per draw -- confirm
    against LMC), and the squared deviation from ``target`` is then averaged
    over all remaining entries, giving a single scalar.
    """
    return np.mean((np.mean(samples, 0) - target)**2)


for m in m_list:
    print(m)  # progress indicator
    n = 2**m
    if stepsize == 'decreasing':
        lr = stepsize_schedule(n, 0.01, 0.0001, gamma=0.33)
    else:
        lr = stepsize

    # NOTE(review): each replicate stores one scalar error, which broadcasts
    # into all 3 columns of its row, so the columns are identical. The
    # (nrep, 3) shape is kept for compatibility with downstream consumers --
    # confirm whether three distinct metrics were intended.
    for method in methods:
        errors[(method, m)] = np.zeros((nrep, 3))

    for seed in range(nrep):
        # Every method receives the same seed; methods run in `methods` order.
        for method in methods:
            chain = LMC(inst, lr, n, seed, cud=_CUD_FLAG[method])
            errors[(method, m)][seed] = _posterior_mean_mse(chain, truth)

