import numpy as np
import pandas as pd
from scipy.special import ndtr
import time
import matplotlib.pyplot as plt
import seaborn as sns
from instances import linear_instance
from core import LMC

# Problem size, design correlation and noise level of the synthetic regression.
N, d = 20, 100
rho, sigma = 0.5, .5

# Fixed-seed generator; the draws below are taken in the order X, beta, noise.
rng = np.random.default_rng(1)

# Design covariance with entries rho ** |i - j|.
_idx = np.arange(d)
Sigma_X = rho ** np.abs(_idx[None, :] - _idx[:, None])

# Standard-normal rows are correlated through the Cholesky factor of Sigma_X.
_chol = np.linalg.cholesky(Sigma_X)
X = rng.standard_normal((N, d)) @ _chol.T

# Coefficient vector, then responses Y = X beta + sigma * noise.
beta = rng.standard_normal(d)
_noise = rng.standard_normal(N)
Y = X @ beta + _noise * sigma

# Sampling target handed to LMC, built from the data and the noise level.
inst = linear_instance(X, Y, sigma)

# Closed-form Gaussian reference: covariance (X'X / sigma^2 + I)^{-1} and mean
# Sigma_posterior X'Y / sigma^2.
# NOTE(review): this is the posterior under a N(0, I) prior on the coefficients;
# presumably that matches the target defined by linear_instance — confirm.
_precision = X.T @ X / sigma**2 + np.eye(d)
Sigma_posterior = np.linalg.inv(_precision)
mu_posterior = Sigma_posterior @ (X.T @ Y / sigma**2)

# Per-coordinate reference values that get_errors compares sample averages to.
_post_var = np.diag(Sigma_posterior)
E_x = mu_posterior                                  # first moment
E_x2 = _post_var + mu_posterior**2                  # second moment: var + mean^2
E_indc = ndtr(mu_posterior / np.sqrt(_post_var))    # Gaussian P(x_i > 0)


def get_errors(samples):
    """Return mean squared errors of three sample averages against E_x, E_x2, E_indc.

    The averages are taken over axis 0 of ``samples`` for the raw values, their
    squares, and the indicator ``samples > 0``. Each is compared with the
    matching module-level reference (``E_x``, ``E_x2``, ``E_indc``) and the
    squared differences are averaged.

    Returns a length-3 float array ordered as (mean, second moment, indicator).
    """
    estimates = (
        np.mean(samples, 0),
        np.mean(samples**2, 0),
        np.mean(samples > 0, 0),
    )
    targets = (E_x, E_x2, E_indc)

    mse = np.zeros(3)
    for k, (estimate, target) in enumerate(zip(estimates, targets)):
        mse[k] = np.mean((estimate - target)**2)
    return mse

# Experiment settings: nrep seeds per configuration, step size lr, and run
# sizes n = 2**m for m = 10..16.
nrep = 20
lr = 0.001
m_list = np.arange(10, 17)
methods = ['mc', 'cud']

# errors[(method, m)] is an (nrep, 3) table of get_errors outputs;
# times[method][j, seed] is the elapsed time in seconds of the LMC call.
errors = {}
times = {key: np.zeros((len(m_list), nrep)) for key in methods}
for j, m in enumerate(m_list):
    print(m)
    n = 2**m
    for method in methods:
        errors[(method, m)] = np.zeros((nrep, 3))
    for seed in range(nrep):
        # perf_counter is a high-resolution clock meant for measuring short
        # intervals and is unaffected by system clock adjustments.
        start = time.perf_counter()
        lmc = LMC(inst, lr, n, seed, cud=False)
        times['mc'][j, seed] = time.perf_counter() - start

        start = time.perf_counter()
        lmc_cud = LMC(inst, lr, n, seed, cud=True)
        times['cud'][j, seed] = time.perf_counter() - start

        # Errors are computed after both timed calls so that only LMC itself
        # is timed.
        # NOTE(review): assumes LMC returns an array of samples with draws
        # along axis 0 — confirm against core.LMC.
        errors[('mc', m)][seed] = get_errors(lmc)
        errors[('cud', m)][seed] = get_errors(lmc_cud)

