import numpy as np
import scipy.special as special

from autograd import grad, hessian
from methods.reup import bayesian_utils

def projected_gradient_descent(R_ij, M_ij, prior_Sigma, prior_m, m, tau, d, iterations, lr):
    """Minimise ``l_posterior`` over covariance matrices by projected gradient descent.

    Starting from ``prior_Sigma``, each iteration takes a fixed-size gradient
    step on

        l(S) = -(prior_m / m) * log det(S) + trace(A @ S),
        A    = inv(prior_Sigma) + tau * R_ij * M_ij,

    and then maps the iterate back through ``projected_D``.

    Args:
        R_ij, M_ij, tau: define the term ``tau * R_ij * M_ij`` that is added to
            inv(prior_Sigma); the product must broadcast to prior_Sigma's shape
            (presumably R_ij is a scalar and M_ij a (d, d) matrix -- confirm).
        prior_Sigma: square, invertible (d, d) starting / prior matrix.
        prior_m: prior parameter appearing in the -(prior_m / m) log det term.
        m: the value of m used for this run.
        d: matrix dimension.
        iterations: number of gradient steps.
        lr: step size.

    Returns:
        (curr_Sigma, big_l_loss, lst_loss): the final iterate, the value of
        ``big_l_posterior`` at that iterate, and the list of ``l_posterior``
        values recorded after each step. With ``iterations == 0`` the first
        element is ``prior_Sigma`` itself (not a copy).
    """
    inv_prior_Sigma = np.linalg.inv(prior_Sigma)
    # Constant part of the objective / gradient; hoisted out of the loop.
    A = inv_prior_Sigma + tau * R_ij * M_ij

    # Eigenvalue shift handed to projected_D:
    #   epsilon = d^(1-d) * exp(-(m / prior_m) * ||A||_F * (sqrt(d) + d)).
    # float(d) keeps the negative power valid when d is a NumPy integer
    # (NumPy raises on integer ** negative integer).
    frob_norm_A = np.linalg.norm(A, 'fro')
    epsilon = float(d) ** (1 - d) * np.exp(-m / prior_m * frob_norm_A * (np.sqrt(d) + d))

    curr_Sigma = prior_Sigma
    lst_loss = []
    for _ in range(iterations):
        # d l / dS = -(prior_m / m) * S^{-1} + A for symmetric S.
        # NOTE(review): the exact gradient of trace(A @ S) is A.T; using A
        # assumes A is symmetric (e.g. symmetric M_ij) -- confirm.
        # Named `gradient` so it does not shadow the module-level autograd `grad`.
        gradient = -prior_m / m * np.linalg.inv(curr_Sigma) + A

        # Gradient step followed by projection onto the feasible set.
        curr_Sigma = projected_D(curr_Sigma - lr * gradient, epsilon, d)

        lst_loss.append(l_posterior(curr_Sigma, m, inv_prior_Sigma, prior_m, M_ij, R_ij, tau))

    big_l_loss = big_l_posterior(curr_Sigma, m, inv_prior_Sigma, prior_m, M_ij, R_ij, tau, d)

    return curr_Sigma, big_l_loss, lst_loss

def projected_D(S, epsilon, d):
    """Project a symmetric matrix by projecting its spectrum.

    The eigenvalues of ``S`` (from ``np.linalg.eigh``, so ``S`` is assumed
    symmetric) are shifted down by ``epsilon``. They are then passed through
    ``projection_simplex`` and shifted back up by ``epsilon``. The
    eigenvectors are left unchanged.

    Args:
        S: symmetric (d, d) matrix to project.
        epsilon: eigenvalue shift (also forwarded to projection_simplex).
        d: matrix dimension.

    Returns:
        The reassembled matrix ``V @ diag(projected eigenvalues) @ V.T``.
    """
    eigvals, eigvecs = np.linalg.eigh(S)
    shifted_projection = projection_simplex(eigvals - epsilon, epsilon, d)
    new_spectrum = np.diag(shifted_projection + epsilon)
    return eigvecs @ new_spectrum @ eigvecs.T

def K_set(u, d, epsilon):
    """Return the 0-based indices k at which the simplex threshold condition holds.

    For k in 0..d-1 the condition is

        (sum(u[:k+1]) - (1 - epsilon) * d) / (k + 1) < u[k].

    The standard sort-based simplex projection evaluates this on ``u`` sorted
    in descending order. For such input the largest returned index
    determines the projection threshold.

    Args:
        u: 1-D array with at least d entries (only the first d are summed).
        d: number of prefixes to test; also scales the radius (1 - epsilon) * d.
        epsilon: the radius is (1 - epsilon) * d.

    Returns:
        1-D integer array of the indices satisfying the condition (may be empty).
    """
    radius = (1 - epsilon) * d
    # One cumulative sum instead of re-summing every prefix (O(d) vs O(d^2)).
    prefix_sums = np.cumsum(u[:d])
    thresholds = (prefix_sums - radius) / np.arange(1, d + 1)
    return np.argwhere(thresholds < u).reshape(-1, )

def projection_simplex(x, epsilon, d, floor=0.1):
    """Project ``x`` onto the scaled simplex, then clip below at ``floor``.

    The simplex is ``{z : sum(z) = (1 - epsilon) * d, z >= 0}``. The code uses
    the sort-based algorithm:
      1. Sort ``x`` in DESCENDING order as ``u``.
      2. Find the largest index ``rho`` with
         ``u[rho] > (sum(u[:rho+1]) - radius) / (rho + 1)``.
      3. Subtract the corresponding threshold ``eta`` from every entry.

    Args:
        x: 1-D array of length d.
        epsilon: the simplex radius is (1 - epsilon) * d.
        d: length of x.
        floor: lower clip applied to the shifted values. The exact Euclidean
            projection uses 0. The default 0.1 preserves this module's
            historical behaviour. With floor > 0 the result can sum to more
            than the radius.

    Returns:
        Array of length d, in the same element order as ``x``.

    Raises:
        ValueError: if no index satisfies the threshold condition
            (this happens when (1 - epsilon) * d <= 0).
    """
    radius = (1 - epsilon) * d
    # The algorithm requires descending order. An ascending sort selects the
    # wrong active set, so clipped components are ignored and the result
    # does not lie on the simplex.
    u = np.sort(x)[::-1]
    prefix_sums = np.cumsum(u)
    thresholds = (prefix_sums - radius) / np.arange(1, d + 1)
    active = np.flatnonzero(thresholds < u)
    if active.size == 0:
        raise ValueError("projection_simplex: no index satisfies the threshold condition")
    rho = active[-1]
    eta = (prefix_sums[rho] - radius) / (rho + 1)
    return np.maximum(x - eta, floor)

def big_l_posterior(pos_Sigma, pos_m, inv_prior_Sigma, prior_m, M_ij, R_ij, tau, d):
    """Evaluate the model-selection objective for a candidate (pos_Sigma, pos_m).

    Computes

        pos_m * trace(A @ pos_Sigma) - (pos_m / 2) * d
        - multigammaln(pos_m / 2, d)
        + ((pos_m - prior_m) / 2) * multivariate_digamma(pos_m / 2, d),

    with ``A = inv_prior_Sigma + tau * R_ij * M_ij``. ``multigammaln`` is
    ``scipy.special.multigammaln`` and ``multivariate_digamma`` comes from
    ``bayesian_utils``.

    Returns:
        The scalar objective value (lower is preferred by posterior_inference).
    """
    half_m = pos_m / 2
    psi_d = bayesian_utils.multivariate_digamma(half_m, d)
    precision = inv_prior_Sigma + tau * R_ij * M_ij
    trace_term = pos_m * np.trace(precision @ pos_Sigma)
    log_gamma_d = special.multigammaln(half_m, d)
    digamma_weight = (pos_m - prior_m) / 2
    return trace_term - half_m * d - log_gamma_d + digamma_weight * psi_d

def l_posterior(pos_Sigma, pos_m, inv_prior_Sigma, prior_m,  M_ij, R_ij, tau):
    """Evaluate the per-iteration objective l(S) at ``S = pos_Sigma``.

        l(S) = -(prior_m / pos_m) * log det(S) + trace(A @ S),
        A    = inv_prior_Sigma + tau * R_ij * M_ij.

    ``log det`` comes from ``np.linalg.slogdet`` rather than
    ``np.log(np.linalg.det(S))``. The determinant itself easily under- or
    overflows in float64, e.g. det(1e-200 * I_3) rounds to 0. The
    log-determinant stays finite.

    Returns:
        The scalar objective. A negative determinant gives NaN and a singular
        ``S`` gives +inf (for prior_m / pos_m > 0), as the log-of-det form did.
    """
    sign, log_abs_det = np.linalg.slogdet(pos_Sigma)
    # log of a negative determinant is undefined: keep the NaN that np.log gave.
    # For a singular matrix slogdet returns (0, -inf), which is passed through.
    log_det = log_abs_det if sign >= 0 else np.nan
    A = inv_prior_Sigma + tau * R_ij * M_ij
    ell = -prior_m / pos_m * log_det + np.trace(A @ pos_Sigma)
    return ell

def posterior_inference(set_m, R_ij, M_ij, prior_Sigma, prior_m, tau, d, iterations, lr):
    """Pick the candidate m whose projected-gradient run has the lowest final loss.

    ``projected_gradient_descent`` is run once per value in ``set_m``. The
    run whose final ``big_l_posterior`` value is strictly smaller than every
    earlier one is kept: ties keep the earliest candidate, and NaN or +inf
    losses are never selected.

    Returns:
        (opt_Sigma, opt_m, opt_lst_loss): the selected run's matrix, its m and
        its per-iteration loss list. All three are None when ``set_m`` is
        empty or no run produced a loss below +inf.
    """
    best_loss = np.inf
    best_run = (None, None, None)  # (Sigma, m, per-iteration losses)

    for candidate_m in set_m:
        sigma_hat, final_loss, history = projected_gradient_descent(
            R_ij, M_ij, prior_Sigma, prior_m, candidate_m, tau, d, iterations, lr
        )
        if final_loss < best_loss:
            best_loss = final_loss
            best_run = (sigma_hat, candidate_m, history)

    return best_run