from __future__ import division
import numpy as np
import torch


def isotropic_gauss_loglike(x, mu, sigma, do_sum=True):
    """Element-wise log-density of ``x`` under a Gaussian with mean ``mu`` and std ``sigma``.

    Args:
        x: values at which the density is evaluated.
        mu: Gaussian mean(s), combined with ``x`` by subtraction.
        sigma: Gaussian standard deviation(s); passed to ``torch.log``, so it
            has to be a tensor.
        do_sum: when true, reduce the element-wise result with ``.sum()``.

    Returns:
        The summed log-likelihood if ``do_sum`` is true, otherwise the
        element-wise log-likelihood.
    """
    # Normalisation constant shared by every element: -0.5 * log(2 * pi).
    log_norm_const = -0.5 * np.log(2 * np.pi)
    standardised = (x - mu) / sigma
    loglike = log_norm_const - torch.log(sigma) - 0.5 * standardised ** 2
    return loglike.sum() if do_sum else loglike

def sample(self, n_samples=1):
    """Draw ``n_samples`` values of ``mean + std_dev * eps`` with ``eps ~ N(0, 1)``.

    Args:
        self: object exposing ``mean`` (must provide ``.size()``) and
            ``std_dev``.  NOTE(review): this is defined at module level but
            takes ``self`` -- presumably meant to be bound to a class; confirm.
        n_samples: number of noise draws; becomes the leading dimension of
            the noise tensor.

    Returns:
        ``self.mean + self.std_dev * eps`` where ``eps`` has shape
        ``(n_samples, *self.mean.size())``.
    """
    standard_normal = torch.distributions.Normal(0, 1)
    noise_shape = (n_samples, *self.mean.size())
    epsilon = standard_normal.sample(sample_shape=noise_shape)
    scaled_noise = self.std_dev * epsilon
    return self.mean + scaled_noise


def log_gauss_approximation(lambda_ijl, dd):
    """Evaluate ``-0.5 * log(2*pi*lambda) - (dd + lambda)**2 / (2*lambda)``.

    Args:
        lambda_ijl: tensor used both inside the logarithm and as the
            denominator scale; passed to ``torch.log``.
        dd: value added to ``lambda_ijl`` before squaring.

    Returns:
        The element-wise value of the expression above.
    """
    log_norm = -0.5 * torch.log(2 * np.pi * lambda_ijl)
    quadratic = torch.pow(dd + lambda_ijl, 2) / (2 * lambda_ijl)
    return log_norm - quadratic


class laplace_prior(object):
    """Laplace prior with location ``mu`` and scale ``b``."""

    def __init__(self, mu, b):
        self.mu, self.b = mu, b

    def loglike(self, x, do_sum=True):
        """Return the Laplace log-density of ``x``.

        Args:
            x: tensor of values to evaluate.
            do_sum: when true, reduce the element-wise result with ``.sum()``.

        Returns:
            ``-log(2b) - |x - mu| / b``, summed over all elements when
            ``do_sum`` is true and element-wise otherwise.
        """
        log_density = -np.log(2 * self.b) - torch.abs(x - self.mu) / self.b
        if not do_sum:
            return log_density
        return log_density.sum()

    def __str__(self):
        return "laplace(mu={}, b={})".format(self.mu, self.b)


class isotropic_gauss_prior(object):
    """Gaussian prior with mean ``mu`` and standard deviation ``sigma``.

    ``det_sig_term`` (``-log(sigma)``) is derived from the current ``sigma``
    on every access, so changing ``sigma`` -- through ``set_sigma`` or by
    direct attribute assignment -- is always reflected in ``loglike``.
    """

    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

        # Constant part of the Gaussian log-density: -0.5 * log(2 * pi).
        self.cte_term = -(0.5) * np.log(2 * np.pi)

    @property
    def det_sig_term(self):
        """``-log(sigma)`` for the current ``sigma`` (never stale)."""
        return -np.log(self.sigma)

    def set_sigma(self, sig1, sig2=None):
        """Replace the standard deviation with ``sig1``.

        ``sig2`` is accepted and ignored; the signature mirrors
        ``spike_slab_2GMM.set_sigma``.
        """
        self.sigma = sig1

    def loglike(self, x, do_sum=True):
        """Return the Gaussian log-density of ``x``.

        Args:
            x: tensor of values to evaluate.
            do_sum: when true, reduce the element-wise result with ``.sum()``.

        Returns:
            ``-0.5*log(2*pi) - log(sigma) - 0.5*((x - mu)/sigma)**2``, summed
            over all elements when ``do_sum`` is true and element-wise
            otherwise.
        """
        dist_term = -(0.5) * ((x - self.mu) / self.sigma) ** 2
        out = self.cte_term + self.det_sig_term + dist_term
        return out.sum() if do_sum else out

    def __str__(self):
        return "Gauss(mu={}, sigma={})".format(self.mu, self.sigma)

class spike_slab_2GMM(object):
    """Two-component Gaussian mixture prior with weights ``pi`` and ``1 - pi``.

    NOTE(review): ``loglike`` calls each component's ``loglike`` with its
    default ``do_sum=True``, so the component log-likelihoods are summed over
    all elements of ``x`` *before* they are mixed -- confirm this (rather than
    a per-element mixture) is what callers expect.
    """

    def __init__(self, mu1, mu2, sigma1, sigma2, pi):
        self.N1 = isotropic_gauss_prior(mu1, sigma1)
        self.N2 = isotropic_gauss_prior(mu2, sigma2)

        self.pi1 = pi
        self.pi2 = (1 - pi)

    @property
    def sigma(self):
        """The larger of the two component standard deviations."""
        if self.N1.sigma > self.N2.sigma:
            return self.N1.sigma
        return self.N2.sigma

    def set_sigma(self, sig1, sig2=None):
        """Update the component standard deviations.

        Args:
            sig1: new standard deviation for the first component.
            sig2: new standard deviation for the second component; when
                ``None`` the second component is left unchanged instead of
                being overwritten with ``None``.
        """
        self.N1.set_sigma(sig1)
        if sig2 is not None:
            self.N2.set_sigma(sig2)

    def loglike(self, x):
        """Return ``log(pi1 * exp(N1_ll) + pi2 * exp(N2_ll))`` for ``x``."""
        N1_ll = self.N1.loglike(x)
        N2_ll = self.N2.loglike(x)

        # Numerical stability trick: factor out the larger log-likelihood so
        # the exponentials below cannot all underflow to zero.
        max_loglike = torch.max(N1_ll, N2_ll)
        # Mixture weights multiply the component likelihoods.
        normalised_like = (self.pi1 * torch.exp(N1_ll - max_loglike)
                           + self.pi2 * torch.exp(N2_ll - max_loglike))
        loglike = torch.log(normalised_like) + max_loglike

        return loglike

    def __str__(self):
        return "spike_slab2GMM(pi={}, sigma1={}, sigma2={})".format(self.pi1, self.N1.sigma, self.N2.sigma)
