import numpy as np
import scipy
import torch
from einops import rearrange
from torch.distributions import Beta


def sample_gaussian_mixture(n_samples: int, means: torch.Tensor, sigma: torch.Tensor):
    """Draw samples from equally weighted Gaussian mixtures.

    For each sample a component is chosen uniformly at random, and
    standard-normal noise scaled by ``sigma`` is added to that component's mean.

    Args:
        n_samples: Number of samples to draw per mixture.
        means: Component means with shape ``(..., n_components, d)``. A 2-D
            tensor ``(n_components, d)`` describes a single mixture; extra
            leading dimensions describe a batch of independent mixtures. A 1-D
            tensor ``(d,)`` is treated as a single-component mixture.
        sigma: Noise scale, multiplied elementwise with standard-normal noise
            of shape ``(..., n_samples, d)``; it must broadcast against that.

    Returns:
        Samples of shape ``(..., n_samples, d)`` with the dtype and device of
        ``means``.
    """
    if means.dim() == 1:
        # A lone mean vector is a one-component mixture.
        means = means.unsqueeze(0)
    *batch_shape, n_components, dim = means.shape
    component_idx = torch.randint(
        0, n_components, (*batch_shape, n_samples), device=means.device
    )
    # Gather the chosen mean for every sample along the component axis. This
    # works for any number of leading batch dimensions and keeps every index
    # tensor on the same device as ``means``.
    gather_idx = component_idx.unsqueeze(-1).expand(*batch_shape, n_samples, dim)
    sampled_means = torch.gather(means, -2, gather_idx)
    epsilon = torch.randn(sampled_means.shape, device=means.device, dtype=means.dtype)
    return epsilon * sigma + sampled_means


def log_prob_scalar(sumdists2, sigma2, dim):
    """Log-density of an isotropic Gaussian given squared distances to its mean.

    Args:
        sumdists2: Squared Euclidean distances between points and the mean.
        sigma2: Shared variance (a tensor, since ``torch.log`` is applied to it).
        dim: Dimensionality of the underlying points.

    Returns:
        ``-0.5 * (dim * log(2*pi*sigma2) + sumdists2 / sigma2)``, broadcast
        against ``sumdists2``.
    """
    log_normaliser = dim * torch.log(2 * torch.pi * sigma2)
    mahalanobis = sumdists2 / sigma2
    return -0.5 * (log_normaliser + mahalanobis)


def log_prob_gaussian_mixture(x: torch.Tensor, mixtures: torch.Tensor, sigma: torch.Tensor):
    """Log-density of an equally weighted Gaussian mixture evaluated at ``x``.

    Each row of ``mixtures`` is the mean of one component; all components share
    the covariance described by ``sigma``.

    Args:
        x: Query points, shape ``(n1, d)`` (or batched ``(b, n1, d)``, as
            accepted by ``torch.cdist``).
        mixtures: Component means, shape ``(n2, d)`` (or ``(b, n2, d)``).
        sigma: Shared standard deviation. A single-element tensor (or a plain
            Python number) gives an isotropic covariance; a tensor with more
            than one element is used as per-dimension standard deviations
            (diagonal covariance).

    Returns:
        Log-density of each query point, shape ``(n1,)`` (or ``(b, n1)``).
    """
    # TODO: support a separate isotropic sigma for each x.
    if not torch.is_tensor(sigma):
        sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
    if sigma.numel() > 1:
        # Diagonal covariance: rescaling every dimension by 1/sigma turns the
        # squared Mahalanobis distance into a plain squared Euclidean distance,
        # which cdist computes for all (point, mean) pairs at once.
        mahalanobis2 = torch.cdist(x / sigma, mixtures / sigma, p=2) ** 2
        component_log_probs = log_prob_diagonal_normal(
            x=mahalanobis2, diag_cov=sigma**2, no_scale=True
        )
    else:
        sumdists2 = torch.cdist(x, mixtures) ** 2
        component_log_probs = log_prob_scalar(
            sumdists2=sumdists2, sigma2=sigma**2, dim=x.shape[-1]
        )
    # Equal weights 1/K: log((1/K) * sum_k p_k) = logsumexp_k(log p_k) - log K.
    return torch.logsumexp(component_log_probs, dim=-1) - np.log(mixtures.shape[-2])


def log_prob_diagonal_normal(x: torch.Tensor, diag_cov: torch.Tensor, no_scale=False):
    """Log-density of a zero-mean Gaussian with diagonal covariance.

    Args:
        x: When ``no_scale`` is False, deviations from the mean with the
            feature dimension last. When ``no_scale`` is True, precomputed
            squared Mahalanobis distances (``sum_i x_i**2 / var_i``) of any
            shape.
        diag_cov: Variances. With ``no_scale=False`` it may be a scalar, a
            per-feature vector, or per-sample variances; it is broadcast
            against ``x``. With ``no_scale=True`` it is a single covariance
            shared by every entry of ``x`` and all its elements are summed.
        no_scale: Whether ``x`` already holds squared Mahalanobis distances.

    Returns:
        The log-density. The feature dimension is reduced when
        ``no_scale=False``; with ``no_scale=True`` the shape of ``x`` is kept.
    """
    if no_scale:
        # x is already sum_i x_i**2 / var_i, so only the normaliser needs the
        # variances.
        constant_term = -0.5 * torch.log(2 * torch.pi * diag_cov).sum()
        exponent_term = -0.5 * x
    else:
        # Broadcast first so that a scalar (isotropic) variance is counted
        # once per feature and per-sample variances produce per-sample
        # normalisers instead of being summed across the whole batch.
        x, diag_cov = torch.broadcast_tensors(x, diag_cov)
        constant_term = -0.5 * torch.log(2 * torch.pi * diag_cov).sum(dim=-1)
        exponent_term = -0.5 * (x**2 / diag_cov).sum(dim=-1)

    return constant_term + exponent_term


def log_prob_generalized_diagonal_normal(x: torch.Tensor, sigma, p=2):
    """Log-density of a diagonal generalized normal given its precomputed exponent.

    The per-dimension density is ``p / (2*sigma*Gamma(1/p)) * exp(-(|z|/sigma)**p)``.
    This function adds the log normaliser, summed over the last dimension of
    ``sigma``, to ``-x``.

    Args:
        x: The already-summed exponent ``sum_i (|z_i|/sigma_i)**p``; any shape
            that broadcasts against the reduced normaliser.
        sigma: Per-dimension scale (tensor or Python number).
        p: Shape parameter of the generalized normal.

    Returns:
        ``sum_i log(p / (2*sigma_i*Gamma(1/p))) - x``.
    """
    import math

    if not torch.is_tensor(sigma):
        sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
    # math.lgamma avoids relying on ``scipy.special`` being reachable through
    # a bare ``import scipy``, which older SciPy releases do not guarantee.
    per_dim_constant = math.log(p) - torch.log(2 * sigma) - math.lgamma(1 / p)
    constant_term = per_dim_constant.sum(-1)
    exponent_term = -x

    return constant_term + exponent_term


def log_prob_mixture_betas(x, mixtures):
    """Log-density of an equally weighted mixture of product-of-Beta components.

    Args:
        x: Points to evaluate, shape ``(m, d)``. They are passed through a
            sigmoid before the Beta densities are evaluated.
        mixtures: Unconstrained parameters, shape ``(2 * k, d)``; the first
            half yields the alphas and the second half the betas (see
            ``transform_alphas_betas``).

    Returns:
        Log-density of each point, shape ``(m,)``.
    """
    alphas, betas = transform_alphas_betas(mixtures)
    n_components = alphas.shape[0]
    # Broadcast every point (m, 1, d) against every component (1, k, d).
    alphas = rearrange(alphas, "n d -> 1 n d")
    betas = rearrange(betas, "n d -> 1 n d")
    x = rearrange(x, "m d -> m 1 d")
    dist = Beta(alphas, betas)
    # NOTE(review): the density is evaluated at sigmoid(x) without the sigmoid's
    # log-Jacobian, while sample_mixture_betas returns points already in (0, 1).
    # Confirm which space ``x`` is meant to live in.
    x_ = torch.sigmoid(x)
    log_prob = dist.log_prob(x_).to(x.device)
    log_prob = log_prob.sum(-1)  # dimensions are independent within a component
    # Components are equally weighted (sample_mixture_betas picks them
    # uniformly), so the mixture density is (1/k) * sum_k p_k.
    log_prob = torch.logsumexp(log_prob, dim=-1) - np.log(n_components)

    return log_prob


def sample_mixture_betas(n_samples, mixtures):
    """Draw reparameterised samples from an equally weighted Beta mixture.

    For each sample a component is chosen uniformly at random. Then one draw
    is taken from that component's elementwise Beta distribution.

    Args:
        n_samples: Number of samples to draw.
        mixtures: Unconstrained parameters; the first half along dim 0 yields
            the alphas and the second half the betas.

    Returns:
        Samples of shape ``(n_samples, *mixtures.shape[1:])`` on the device of
        ``mixtures``.
    """
    alphas, betas = transform_alphas_betas(mixtures)
    n_components = alphas.shape[0]
    component = torch.randint(0, n_components, (n_samples,), device=mixtures.device)
    chosen = Beta(alphas[component], betas[component])
    return chosen.rsample().to(mixtures.device)


def transform_alphas_betas(mixtures):
    """Split raw mixture parameters into Beta concentrations via softplus.

    The first half of ``mixtures`` along dim 0 becomes ``alphas`` and the
    second half becomes ``betas``; both halves are mapped through a softplus.

    Raises:
        AssertionError: if ``mixtures.shape[0]`` is odd.
    """
    assert mixtures.shape[0] % 2 == 0
    raw_alphas, raw_betas = torch.tensor_split(mixtures, 2, dim=0)
    positive = torch.nn.functional.softplus
    return positive(raw_alphas), positive(raw_betas)
