import numpy as np
import time
from instances import logistic_instance
from core import SGLD

# Problem size: d covariates, N observations; fixed seed for reproducibility.
d = 10
N = 100
rng = np.random.default_rng(1)

# Covariance of the covariates: entry (i, j) equals rho ** |i - j|.
rho = 0.5
Sigma_X = np.power(
    rho, np.abs(np.subtract.outer(np.arange(d), np.arange(d)))
)

# Standard-normal draws coloured by the Cholesky factor of Sigma_X.
# The order of the rng calls (X, then beta, then Y) fixes the simulated data.
X = np.matmul(rng.standard_normal((N, d)), np.linalg.cholesky(Sigma_X).T)
beta = rng.standard_normal(d)

# Logistic success probabilities and the resulting 0/1 responses.
prob = 1 / (1 + np.exp(-X @ beta))
Y = rng.binomial(n=1, p=prob)

# Design matrix with the response appended as the last column.
XY = np.column_stack((X, Y))

# Problem instance built from the simulated covariates and responses.
inst = logistic_instance(X, Y)

# Reference values for this (d, N, rho) configuration, read from disk.
# get_errors compares rows 0, 1 and 2 against the sample means, the means of
# squares and the fractions of positive samples, respectively.
truth_path = f'logistic_truth_d_{d}_N_{N}_rho_{rho}.csv'
truth = np.loadtxt(truth_path)


def get_errors(samples):
    """Return mean squared errors of three sample summaries against ``truth``.

    The summaries are taken along axis 0 of ``samples``: the mean, the mean
    of squares, and the fraction of strictly positive entries. Each one is
    compared with the matching row of the module-level ``truth`` (rows 0, 1
    and 2), and the squared differences are averaged into a single number.

    Returns:
        A length-3 float array ordered as (mean, mean of squares,
        positive fraction).
    """
    summaries = (
        np.mean(samples, 0),
        np.mean(samples**2, 0),
        np.mean(samples > 0, 0),
    )
    return np.array(
        [np.mean((est - truth[k])**2) for k, est in enumerate(summaries)],
        dtype=float,
    )

# Experiment configuration.
lr = 0.001                   # second positional argument of SGLD (presumably the step size -- confirm in core.SGLD)
nrep = 10                    # replications per setting; the replication index doubles as the seed
m_list = np.arange(10, 17)   # each m gives n = 2**m, passed to SGLD as its third argument
methods = ['mc', 'cud']

# errors[(method, m)] holds one row of get_errors output per replication;
# times[method][j, seed] holds the elapsed seconds of the matching SGLD call.
errors = {}
times = {key: np.zeros((len(m_list), nrep)) for key in methods}
for j, m in enumerate(m_list):
    print(m)
    n = 2**m
    for key in methods:
        errors[(key, m)] = np.zeros((nrep, 3))
    for seed in range(nrep):
        # time.perf_counter() is a monotonic, high-resolution clock meant for
        # measuring intervals; time.time() can jump when the system clock is
        # adjusted and may be coarse on some platforms.
        start = time.perf_counter()
        lmc = SGLD(inst, lr, n, seed, M=10, cud=False)
        times['mc'][j, seed] = time.perf_counter() - start

        start = time.perf_counter()
        lmc_cud = SGLD(inst, lr, n, seed, M=10, cud=True)
        times['cud'][j, seed] = time.perf_counter() - start

        # Error evaluation is kept outside the timed regions.
        errors[('mc', m)][seed] = get_errors(lmc)
        errors[('cud', m)][seed] = get_errors(lmc_cud)

