import math
import utils
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import torch.distributions as dist

# from torch.distributions.utils import broadcast_all


def get_grad_ell_linear_gauss_p_gauss_q(y, likelihood_pdf, nat_params, approximation_pdf, device='cpu'):
    """Return the ELL-gradient terms for a Gaussian likelihood with an affine readout.

    With ``C = readout_fn.weight``, ``b = readout_fn.bias`` and
    ``R = softplus(likelihood_pdf.log_R)``, the result is the concatenation on
    the last dim of:

    * ``C^T R^{-1} (y - b)`` -- ``n_latents`` entries (``n_latents = C.shape[-1]``);
    * ``C^T R^{-1} C`` flattened row-major -- ``n_latents**2`` entries,
      broadcast to the leading (batch) shape of the first term.

    NOTE(review): the sign/scale convention of the second term relative to the
    ELL gradient w.r.t. the second moment is fixed by the caller -- confirm.

    Args:
        y: observations; last dim indexes neurons and must broadcast with ``b``.
        likelihood_pdf: object exposing ``log_R`` and a ``readout_fn`` with
            ``weight`` and ``bias`` tensors.
        nat_params: unused here (kept for a signature shared with the
            nonlinear variant).
        approximation_pdf: unused here (same reason).
        device: unused; kept for signature compatibility. Outputs live on
            the device of the inputs.

    Returns:
        Tensor of shape ``(..., n_latents + n_latents**2)``.
    """
    R_inv = 1 / Fn.softplus(likelihood_pdf.log_R)
    # ``.data`` reads the readout weights without tracking them in autograd.
    b = likelihood_pdf.readout_fn.bias.data
    C = likelihood_pdf.readout_fn.weight.data  # (n_neurons, n_latents)

    ell_grad_s_1 = torch.einsum('l n, ...n -> ...l', C.T, R_inv * (y - b))
    # Scale C^T's columns by 1/R, i.e. C^T diag(1/R), then multiply by C.
    ell_grad_s_2 = (C.mT * R_inv) @ C
    # Broadcast the (n_latents**2,) vector over the batch shape of the first
    # term. ``expand`` is a view on the inputs' device, so no CPU ones-tensor
    # is allocated and no device argument is needed.
    ell_grad_s_2_vec = ell_grad_s_2.reshape(-1).expand(*ell_grad_s_1.shape[:-1], -1)
    return torch.cat([ell_grad_s_1, ell_grad_s_2_vec], dim=-1)


def get_grad_ell_nonlinear_gauss_p_gauss_q(y, likelihood_pdf, nat_params, approximation_pdf, device='cpu'):
    """Return ELL-gradient terms for a Gaussian likelihood with a nonlinear readout.

    The readout is linearized around the mean ``m``, which is the first
    ``n_latents`` entries of ``approximation_pdf.natural_to_mean(nat_params)``.
    The linearization uses ``C = d readout_fn / dz`` evaluated at ``m`` and
    ``b = readout_fn(m) - C m``. With ``R = softplus(likelihood_pdf.log_R)``,
    the result is the concatenation on the last dim of:

    * ``C^T R^{-1} (y - b)`` -- ``n_latents`` entries;
    * ``C^T R^{-1} C`` flattened row-major -- ``n_latents**2`` entries,
      broadcast to the leading shape of the first term.

    NOTE(review): the sign/scale convention of the second term relative to the
    ELL gradient w.r.t. the second moment is fixed by the caller -- confirm.

    Args:
        y: observations; last dim indexes neurons and must broadcast with ``b``.
        likelihood_pdf: object exposing ``log_R`` and a callable ``readout_fn``.
            The readout is called both on single latent vectors (under
            vmap/jacfwd) and on the batched means.
        nat_params: natural parameters of the approximate posterior, with any
            number of leading batch dims. Previously only 2-D/3-D inputs worked;
            other ranks raised ``UnboundLocalError``.
        approximation_pdf: object exposing ``n_latents`` and ``natural_to_mean``.
        device: unused; kept for signature compatibility. Outputs live on the
            device of the inputs.

    Returns:
        Tensor of shape ``(..., n_latents + n_latents**2)``.
    """
    R_inv = 1 / Fn.softplus(likelihood_pdf.log_R)
    n_latents = approximation_pdf.n_latents
    readout_fn = likelihood_pdf.readout_fn

    mean_params = approximation_pdf.natural_to_mean(nat_params)
    m = mean_params[..., :n_latents]

    # Jacobian of the readout at every mean. All leading dims are flattened
    # into one vmapped axis, so any batch rank works; then they are restored.
    batch_shape = m.shape[:-1]
    C_flat = torch.vmap(torch.func.jacfwd(readout_fn))(m.reshape(-1, n_latents))
    C = C_flat.reshape(*batch_shape, *C_flat.shape[-2:])  # (..., n_neurons, n_latents)

    # Offset of the first-order Taylor expansion: readout_fn(m) - C m.
    b = readout_fn(m) - torch.einsum('...n l, ...l -> ...n', C, m)

    ell_grad_s_1 = torch.einsum('...l n, ...n -> ...l', C.mT, R_inv * (y - b))
    # (C^T diag(1/R)) @ C for every batch element.
    ell_grad_s_2 = (C.mT * R_inv) @ C
    # Broadcast to the first term's leading shape via a view on the inputs'
    # device. The previous CPU ``torch.ones`` broke when inputs were on a GPU.
    ell_grad_s_2_vec = ell_grad_s_2.flatten(-2).expand(*ell_grad_s_1.shape[:-1], -1)
    return torch.cat([ell_grad_s_1, ell_grad_s_2_vec], dim=-1)


class GaussianLikelihood(nn.Module):
    """Gaussian observation model ``y ~ N(readout_fn(z), diag(softplus(log_R)))``.

    Args:
        readout_fn: callable mapping latents ``z`` to observation means.
        n_neurons: number of observed dimensions (stored only).
        R_diag: initial diagonal of the observation covariance. Despite the
            attribute name, it is stored as ``softplus^{-1}(R_diag)`` in
            ``self.log_R``, and every use applies ``softplus`` to it.
        device: accepted for API compatibility; unused by this class.
        fix_R: if True, ``log_R`` is a plain tensor and is not trained.
            Otherwise it is an ``nn.Parameter``.
    """

    def __init__(self, readout_fn, n_neurons, R_diag, device='cpu', fix_R=False):
        super().__init__()

        self.n_neurons = n_neurons
        self.readout_fn = readout_fn

        log_R = utils.softplus_inv(R_diag)
        # A fixed log_R is a plain tensor, not a parameter, so it does not
        # follow ``module.to(device)``. The methods below move it to the data's
        # device explicitly.
        self.log_R = log_R if fix_R else torch.nn.Parameter(log_R)

    def get_ell(self, y, z):
        """Return ``log N(y | readout_fn(z), diag(R))`` summed over the last dim."""
        mean = self.readout_fn(z)
        cov = Fn.softplus(self.log_R.to(y.device))
        log_prob = -0.5 * ((y - mean) ** 2 / cov + torch.log(cov) + math.log(2 * math.pi))
        return log_prob.sum(dim=-1)

    def readout(self, z, add_noise=False):
        """Map latents to observations.

        Args:
            z: latents passed to ``readout_fn``.
            add_noise: if True, add zero-mean Gaussian noise with per-dimension
                variance ``softplus(log_R)``. The default (False) returns the
                noiseless mean, as before. Noise is no longer sampled and then
                discarded, so no random draw happens unless it is requested.

        Returns:
            ``readout_fn(z)``, plus noise if ``add_noise`` is True.
        """
        f = self.readout_fn(z)
        if not add_noise:
            return f
        # Move log_R to f's device: with fix_R=True it is a plain tensor that
        # may still be on the CPU while f is on a GPU.
        std = torch.sqrt(Fn.softplus(self.log_R.to(f.device)))
        return f + std * torch.randn_like(f)


class PoissonLikelihood(nn.Module):
    """Poisson observation model with log-rate ``log(delta) + readout_fn(z)``.

    Args:
        readout_fn: callable mapping latents to per-neuron log-rates (before
            the ``log(delta)`` offset).
        n_neurons: number of observed units (stored only).
        delta: positive scalar whose log is added to the readout. Presumably a
            time-bin width; confirm with callers.
        device: stored as ``self.device``; not otherwise used here.
        p_mask: accepted for API compatibility; unused.
    """

    def __init__(self, readout_fn, n_neurons, delta, device='cpu', p_mask=0.0):
        super().__init__()
        self.readout_fn = readout_fn
        self.n_neurons = n_neurons
        self.device = device
        self.delta = delta

    def get_ell(self, y, z, reduce_neuron_dim=True):
        """Return the Poisson log-likelihood of counts ``y`` given latents ``z``.

        Computed as ``-poisson_nll_loss(..., full=True)``. For counts above 1,
        that loss uses PyTorch's Stirling approximation of ``log(y!)``.

        Args:
            y: observed counts, broadcastable with ``readout_fn(z)``.
            z: latents passed to ``readout_fn``.
            reduce_neuron_dim: if True, sum over the last dim; otherwise
                return per-element log-probabilities.
        """
        log_rate = self.readout_fn(z) + math.log(self.delta)
        nll = Fn.poisson_nll_loss(log_rate, y, log_input=True, full=True, reduction='none')
        log_p_y = -nll
        return log_p_y.sum(dim=-1) if reduce_neuron_dim else log_p_y

class BernoulliLikelihood(nn.Module):
    def __init__(self, readout_fn, n_neurons, device='cpu'):
        super(BernoulliLikelihood, self).__init__()
        self.device = device
        self.n_neurons = n_neurons
        self.readout_fn = readout_fn

    def get_ell(self, y, z):
        logits = self.readout_fn(z)
        log_p_y = -Fn.binary_cross_entropy_with_logits(logits, y.type(z.dtype)[None, :, None], reduction='none').sum(dim=-1)
        # TODO: figure out better way
        return log_p_y.sum(dim=-1)

    def predict_and_score(self, y, m):
        logits_prd = self.readout_fn(m)
        y_hat = (logits_prd >= 0.)

        n_correct = (y == y_hat).sum()
        acc = n_correct / y.numel()

        return y_hat, acc.item()



