import numpy as np

from util import isPD, nearestPD, product_of_two_multivariate_normals, project_suff_stats, calc_posterior_params, symmetrize, symmetric_flatten_indices, generate_symmetric_from_triu, fast_sample_multivariate_normal


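# References on numpy's positive-semi-definite covariance warnings and on
# repairing a matrix to the nearest positive-(semi)definite one (the isPD /
# nearestPD repairs used below):
# https://stackoverflow.com/questions/41515522/numpy-positive-semi-definite-warning
# https://stackoverflow.com/questions/43238173/python-convert-matrix-to-positive-semi-definite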
def update_sufficient_statistics(Ex2, Cov_xx_xx, Z, model_prior_params, model_params, noise_covariance, N, delete=True):
    """Sample sufficient statistics S = {'XX', 'Xy', 'yy'} conditioned on the noisy statistics Z.

    The Gaussian approximation of the vectorised statistics [vec(XX), Xy, yy]
    (mean from ``calc_mu_S``, covariance from ``calc_Sigma_S``) is combined with
    the Gaussian N(Z_vec, noise_covariance) by ``product_of_two_multivariate_normals``.
    Up to 20 samples are drawn from the combined Gaussian. Each sample is
    unpacked into a dict, screened for positive yy / diagonal of XX, and passed
    through ``project_suff_stats``. The first projected sample whose last
    ``calc_posterior_params`` output exceeds 0.1 is returned. Otherwise the last
    sample is returned, projected.

    Args:
        Ex2: (d, d) second-moment matrix of x.
        Cov_xx_xx: (d, d, d, d) array, e.g. from ``calc_Cov_xx_xx``.
        Z: dict of noisy statistics with 'XX' (d x d), 'Xy' (d entries), 'yy' (scalar).
        model_prior_params: forwarded unchanged to ``calc_posterior_params``.
        model_params: (theta, sigma_squared) tuple.
        noise_covariance: square covariance aligned with the vectorised Z
            (d*d + d + 1 rows). It is copied before this function edits it.
        N: number of records; scales the moments and, when ``delete`` is True,
            is used as the fixed value of the last XX entry.
        delete: if True, drop the vec(XX) entries listed by
            ``symmetric_flatten_indices`` together with XX[d-1, d-1] before
            conditioning, and rebuild XX with ``generate_symmetric_from_triu``.

    Returns:
        dict with keys 'XX' (d x d), 'Xy' (column vector) and 'yy'.
    """
    # Gaussian approximation of [vec(XX), Xy, yy].
    mu_S = calc_mu_S(model_params, Ex2, N)
    Sigma_S = calc_Sigma_S(model_params, Ex2, Cov_xx_xx, N)

    d = Z['XX'].shape[0]
    Z_vec = np.hstack((Z['XX'].ravel(), Z['Xy'].ravel(), Z['yy']))[:, None]

    if delete:
        # Drop the entries listed by symmetric_flatten_indices (presumably the
        # strict lower triangle of vec(XX)) plus flat index d*d - 1, i.e.
        # XX[d-1, d-1], which is re-inserted as N when a sample is unpacked.
        dropped = np.append(symmetric_flatten_indices(d), d * d - 1)
        keep = np.setdiff1d(np.arange(len(mu_S)), dropped, assume_unique=True)

        # Halve the (triu, triu) block of the noise covariance before removing
        # redundant entries. NOTE(review): rationale not visible here -- confirm.
        triu_indices = symmetric_flatten_indices(d, lower=False)
        noise_covariance = noise_covariance.copy()
        noise_covariance[np.ix_(triu_indices, triu_indices)] /= 2

        mu_S = mu_S[keep]
        Sigma_S = Sigma_S[np.ix_(keep, keep)]
        Z_vec = Z_vec[keep]
        noise_covariance = noise_covariance[np.ix_(keep, keep)]

    mu_combined, Sigma_combined = product_of_two_multivariate_normals(Z_vec, noise_covariance,
                                                                      mu_S, Sigma_S)

    # Numerical repair: replace the combined covariance by its nearest PD
    # matrix, at most three times.
    pd_repairs = 0
    while not isPD(Sigma_combined) and pd_repairs < 3:
        Sigma_combined = nearestPD(Sigma_combined)
        pd_repairs += 1

    # With ``delete``, the sample holds the upper triangle of XX minus its
    # last entry (fixed to N): d*(d+1)/2 - 1 values, then Xy, then yy.
    n_xx_sampled = d * (d + 1) // 2 - 1
    max_tries = 20
    # Attempts 0..strict_until require every diagonal entry of XX to be
    # positive; later attempts only require XX[0, 0] > 0.
    strict_until = int(max_tries * 0.75)
    mean_flat = mu_combined.ravel()

    projected = False
    for attempt in range(max_tries):
        S_vec = fast_sample_multivariate_normal(mean=mean_flat, cov=Sigma_combined)

        if delete:
            S = {'XX': generate_symmetric_from_triu(d, np.append(S_vec[:n_xx_sampled], N)),
                 'Xy': S_vec[n_xx_sampled:-1][:, None],
                 'yy': S_vec[-1]}
        else:
            S = {'XX': symmetrize(S_vec[:d ** 2].reshape((d, d))),
                 'Xy': S_vec[d ** 2:-1][:, None],
                 'yy': S_vec[-1]}

        projected = False
        # Only project plausible samples. Previously the relaxed check ran on
        # every attempt, which made the strict window ineffective.
        if attempt <= strict_until:
            diagonal_ok = (np.diag(S['XX']) > 0).all()
        else:
            diagonal_ok = S['XX'][0, 0] > 0
        if not (S['yy'] > 0 and diagonal_ok):
            continue

        S = project_suff_stats(S)  # must project before calc_posterior_params
        projected = True

        if calc_posterior_params(S, N, model_prior_params)[-1] > .1:
            return S

    # No sample was accepted: make sure the returned one has been projected.
    if not projected:
        S = project_suff_stats(S)

    return S


def draw_S_vec(mu_combined, Sigma_combined):
    """Return one draw from N(mu_combined, Sigma_combined) using numpy's global RNG.

    ``mu_combined`` is flattened to a 1-D mean. A covariance that fails numpy's
    positive-semi-definite check only emits a warning (``check_valid='warn'``).
    """
    return np.random.multivariate_normal(
        mean=mu_combined.flatten(),
        cov=Sigma_combined,
        check_valid='warn',
    )


def calc_Cov_xx_xx(Ex2, Ex4):
    """Return Ex4 minus the outer product of Ex2 with itself.

    The result has entries ``Ex4[i, j, k, l] - Ex2[i, j] * Ex2[k, l]``, i.e.
    the covariance between entries of x x^T when Ex2 / Ex4 hold the second /
    fourth moments of x.
    """
    outer_second_moments = np.multiply.outer(Ex2, Ex2)
    return Ex4 - outer_second_moments


def calc_mu_S(model_params, Ex2, N):
    """Return N times the per-record means of [vec(XX), Xy, yy] as a column vector.

    With theta flattened, the stacked entries are Ex2 (row-major), the vector
    ``Exy[i] = sum_j theta[j] * Ex2[i, j]``, and
    ``Ey2 = sigma_squared + theta . Exy``. For a square (d, d) Ex2 the result
    has shape (d*d + d + 1, 1).
    """
    theta, sigma_squared = model_params
    weights = theta.ravel()

    cross_moment = np.einsum('j,ij->i', weights, Ex2)
    y_second_moment = sigma_squared + np.dot(weights, cross_moment)

    stacked = np.concatenate((
        Ex2.ravel(),
        cross_moment.ravel(),
        np.atleast_1d(y_second_moment),
    ))
    return N * stacked.reshape(-1, 1)


def calc_Ey2(Ex2, model_params):
    """Return Ey2 = sigma_squared + theta^T Ex2 theta.

    This is the same quantity ``calc_mu_S`` places in its last entry before
    scaling by N.

    Args:
        Ex2: (d, d) second-moment matrix of x.
        model_params: (theta, sigma_squared). theta is flattened, so a (d,),
            a (d, 1) or (for d == 1) a 0-d array is accepted.

    Returns:
        Scalar value of sigma_squared + sum_ij theta_i theta_j Ex2[i, j].
    """
    theta, sigma_squared = model_params
    # The former unused `d = len(theta)` made 0-d theta raise TypeError.
    theta_flat = np.ravel(theta)
    return sigma_squared + np.einsum('i,j,ij->', theta_flat, theta_flat, Ex2)


def calc_Sigma_S(model_params, Ex2, Cov_xx_xx, N):
    """Return N times the per-record covariance of [vec(XX), Xy, yy].

    The blocks are built by contracting Cov_xx_xx with theta and adding the
    sigma_squared noise terms. The (Xy, Xy) block is replaced by
    ``nearestPD`` of itself when ``isPD`` rejects it. The assembled square
    matrix has d*d + d + 1 rows, with d = len(theta).

    Args:
        model_params: (theta, sigma_squared).
        Ex2: (d, d) second-moment matrix of x.
        Cov_xx_xx: (d, d, d, d) array, e.g. from ``calc_Cov_xx_xx``.
        N: factor the whole matrix is multiplied by (in place).
    """
    theta, sigma_squared = model_params
    d = len(theta)
    t = theta.ravel()
    dd = d ** 2

    # One and two contractions of the fourth-order covariance with theta.
    xx_xy = np.einsum('l,ijkl->ijk', t, Cov_xx_xx)
    xx_yy = np.einsum('k,ijk->ij', t, xx_xy)

    # (Xy, Xy) block: theta-weighted term plus sigma_squared * Ex2.
    xy_xy = np.einsum('j,ijk->ik', t, xx_xy) + sigma_squared * Ex2

    # (Xy, yy) block.
    theta_cubed = np.einsum('j,ij->i', t, xx_yy)
    ex2_theta = np.einsum('j,ij->i', t, Ex2)
    xy_yy = theta_cubed + 2 * sigma_squared * ex2_theta

    # (yy, yy) entry.
    theta_fourth = np.dot(t, theta_cubed)
    quadratic = np.dot(t, ex2_theta)
    yy_yy = 2 * sigma_squared ** 2 + 4 * sigma_squared * quadratic + theta_fourth

    if not isPD(xy_xy):
        xy_xy = nearestPD(xy_xy)

    block_xx_xy = xx_xy.reshape(dd, d)
    block_xx_yy = xx_yy.reshape(dd, 1)
    block_xy_yy = xy_yy.reshape(d, 1)

    covariance = np.block([
        [Cov_xx_xx.reshape(dd, dd), block_xx_xy, block_xx_yy],
        [block_xx_xy.T, xy_xy.reshape(d, d), block_xy_yy],
        [block_xx_yy.T, block_xy_yy.T, np.reshape(yy_yy, (1, 1))],
    ])

    covariance *= N

    return covariance
