import torch
import numpy as np


def log_t_normalizing_const(nu, d):
    """Return the log normalizing constant of a ``d``-dimensional Student-t with ``nu`` dof.

    Evaluates ``lgamma((nu + d) / 2) - lgamma(nu / 2) - (d / 2) * (log(nu) + log(pi))``.
    The gamma terms are computed in torch, so for scalar ``nu``/``d`` the result is a
    0-dim tensor; the ``log(nu * pi)`` part is computed with NumPy.
    """
    half_d = d / 2
    log_gamma_num = torch.lgamma(torch.tensor((nu + d) / 2))
    log_gamma_den = torch.lgamma(torch.tensor(nu / 2))
    log_scale = half_d * (np.log(nu) + np.log(np.pi))
    return log_gamma_num - (log_gamma_den + log_scale)

def gamma_regularizer(mu, logvar, n_dim, const_2bar1, gamma, tau, nu):
    """Return the batch mean of a regularization term built from ``mu`` and ``logvar``.

    For each index ``i`` along dim 0 the summand is::

        ||mu_i||_2^2
        + nu / (nu + n_dim - 2) * sum_j exp(logvar_ij)
        - nu * const_2bar1 * exp(-gamma / (2 + 2 * gamma) * sum_j logvar_ij)
        + nu * tau

    Reductions are over dim 1, and the final result is the mean of all summands.
    NOTE(review): presumably ``mu`` and ``logvar`` are (batch, latent_dim); confirm with callers.
    """
    var_scale = nu / (nu + n_dim - 2)
    det_exponent = -gamma / (2 + 2 * gamma)

    squared_norm = torch.linalg.norm(mu, ord=2, dim=1) ** 2
    scaled_trace = var_scale * logvar.exp().sum(dim=1)
    # exp(c * sum(logvar)) == prod(var) ** c, a power of the diagonal determinant.
    det_power = (det_exponent * logvar.sum(dim=1)).exp()

    per_sample = squared_norm + scaled_trace - nu * const_2bar1 * det_power + nu * tau
    return per_sample.mean()

def log_Pareto_normalizing_const(m, nu1, nu2):
    """Return ``lgamma(nu1 + m) - (lgamma(nu1) + m * log(nu2))``.

    The log-gamma terms are evaluated in torch (0-dim tensors for scalar inputs);
    ``log(nu2)`` is evaluated with NumPy.
    """
    log_gamma_top = torch.lgamma(torch.tensor(nu1 + m))
    log_gamma_bottom = torch.lgamma(torch.tensor(nu1))
    return log_gamma_top - (log_gamma_bottom + m * np.log(nu2))

def gamma_pow_div(nu, gamma, m, theta1, theta2):
    """Evaluate a gamma-power-divergence expression between ``theta1`` and ``theta2``.

    With ``r = gamma / (1 + gamma)`` and ``C = (nu + m) * exp(r * log_Pareto_normalizing_const(m, nu - 1, nu))``,
    this returns::

        C * ( prod(theta2)^(-r) * (1 + sum(theta1 / theta2) / (nu - 1))
            - prod(theta1)^(-r) * (nu + m - 1) / (nu - 1) )

    Products and sums are taken over dim 1, so one value is returned per index of dim 0.
    NOTE(review): theta1/theta2 are presumably positive (batch, m) tensors; confirm.
    """
    ratio = gamma / (1 + gamma)
    scale = (nu + m) * (ratio * log_Pareto_normalizing_const(m, nu - 1, nu)).exp()

    theta2_part = (-ratio * theta2.log().sum(dim=1)).exp()
    # Elementwise theta1 / theta2, formed in log space.
    quotient_sum = (torch.log(theta1) - torch.log(theta2)).exp().sum(dim=1)
    first = theta2_part * (1 + quotient_sum / (nu - 1))

    theta1_part = (-ratio * theta1.log().sum(dim=1)).exp()
    second = theta1_part * ((nu + m - 1) / (nu - 1))

    return scale * (first - second)

# for ZetaVAE
def log_Zeta_normalizing_const(sigma, s, t):
    """Return ``s*log(t) + (s-1)*log(sigma) + log(zeta(s, t*sigma))``.

    ``zeta`` is ``torch.special.zeta``, the Hurwitz zeta function ``sum_k (k + q)^(-s)``.
    ``sigma`` is expected to be a positive tensor; ``s`` and ``t`` may be Python numbers.
    """
    hurwitz = torch.special.zeta(s, t * sigma)
    return s * np.log(t) + (s - 1) * torch.log(sigma) + torch.log(hurwitz)

def mu_function(sigma, nu, n):
    """Return ``(nu + n) * sigma * (exp(L(nu+n-1) - L(nu+n)) - 1)``.

    Here ``L(s)`` is ``log_Zeta_normalizing_const(sigma, s, nu + n)``, i.e. both terms
    share ``t = nu + n`` and differ only in the order ``s``.
    """
    s_hi = nu + n
    log_ratio = (log_Zeta_normalizing_const(sigma, s_hi - 1, s_hi)
                 - log_Zeta_normalizing_const(sigma, s_hi, s_hi))
    return s_hi * sigma * (torch.exp(log_ratio) - 1)

def continuous_zeta(x, sigma, nu):
    """Return ``(sigma * nu + floor(x)) ** (-nu)``, piecewise-constant in ``x``.

    Args:
        x: evaluation point(s); a Python number or a tensor.
        sigma: scale parameter (number or tensor).
        nu: exponent parameter (number or tensor).

    Returns:
        A tensor holding the kernel value at ``floor(x)``.
    """
    # ``torch.tensor(x)`` on an existing tensor emits a UserWarning and makes a copy.
    # ``as_tensor`` reuses it; ``detach`` keeps the original's "no gradient through x"
    # semantics (floor has zero derivative anyway).
    x = torch.as_tensor(x).detach()
    # Integer inputs are already integral; only floating tensors need flooring.
    # This also avoids relying on integer-dtype floor support in older torch releases.
    if x.is_floating_point():
        x = x.floor()
    return (sigma * nu + x) ** (-nu)

def G_inv(y, sigma, nu):
    """Map ``y`` in [0, 1) through the inverse CDF of a two-piece envelope density.

    The unnormalized envelope is the constant ``A = continuous_zeta(0, sigma, nu)`` on
    [0, 1] and ``(sigma*nu + x - 1) ** (-nu)`` for ``x > 1``. The tail mass is
    ``B = (sigma*nu) ** (1 - nu) / (nu - 1)``, which is finite only for ``nu > 1``.
    Values ``y <= A / (A + B)`` map linearly into [0, 1]; larger ``y`` are inverted on
    the tail.

    Both branches are evaluated for every element before ``torch.where`` selects one.
    """
    flat_height = continuous_zeta(0, sigma, nu)
    tail_mass = (sigma * nu) ** (-nu + 1) / (nu - 1)
    total = flat_height + tail_mass
    split = flat_height / total

    tail_branch = ((1 - y) * total * (nu - 1)) ** (-1 / (nu - 1)) + (1 - (sigma * nu))
    uniform_branch = y / split
    return torch.where(uniform_branch <= 1, uniform_branch, tail_branch)

def acceptance_ratio(x, sigma, nu):
    """Return ``f(x) / (M g(x))``, where ``f`` is ``continuous_zeta``.

    The scaled envelope ``M g(x)`` is ``continuous_zeta(0, sigma, nu)`` for ``x <= 1``
    and ``(sigma*nu + x - 1) ** (-nu)`` otherwise. Presumably this is the acceptance
    probability for rejection sampling with proposals drawn via ``G_inv``; confirm
    with callers. ``x`` must be a tensor, because of the ``x <= 1`` comparison inside
    ``torch.where``.
    """
    target = continuous_zeta(x, sigma, nu)
    # Envelope normalizer (unused here): A + B with A = continuous_zeta(0, sigma, nu)
    # and B = (sigma*nu) ** (1 - nu) / (nu - 1).
    envelope = torch.where(
        x <= 1,
        continuous_zeta(0, sigma, nu),
        (sigma * nu + x - 1) ** (-nu),
    )
    return target / envelope