import numpy as np
import torch
from scipy.stats._qmc import Halton


def sample_gaussian_mixture_spike(means, spikes, sigma, n_samples):
    """Draw samples from a uniform mixture of Gaussians with one "spike" direction each.

    A mixture component is picked uniformly at random for every sample. The
    sample is the component mean plus isotropic noise scaled by ``sigma`` plus
    a single scalar standard-normal draw multiplied by that component's spike
    vector.

    Args:
        means: 2-D tensor of component means, one row per component.
        spikes: tensor indexed by component; presumably the same shape as
            ``means`` (one spike vector per component) -- TODO confirm.
        sigma: scale of the isotropic noise (anything that broadcasts against
            a ``(n_samples, dim)`` tensor).
        n_samples: number of samples to draw.

    Returns:
        Tensor with one sample per row, created on ``means.device``.
    """
    n_means, dim = means.shape
    tensor_opts = {"device": means.device, "dtype": means.dtype}

    # NOTE: the order of the three random draws is kept fixed so that seeded
    # runs stay reproducible.
    component = torch.randint(0, n_means, (n_samples,), device=means.device)
    isotropic_noise = torch.randn(n_samples, dim, **tensor_opts)
    spike_noise = torch.randn(n_samples, 1, **tensor_opts)

    centers = means[component]
    directions = spikes[component]
    return centers + isotropic_noise * sigma + spike_noise * directions


def log_prob_diagonal_spike(x_no_spike, x_spike, sigma, spike_norm):
    """Log-density of a zero-mean Gaussian with one inflated ("spike") direction.

    The density factorises into an isotropic part with variance ``sigma**2``
    and a one-dimensional part along the spike with variance
    ``spike_norm**2 + sigma**2``.

    Args:
        x_no_spike: residuals with the spike component removed; squared and
            summed over the last axis.
        x_spike: scalar coordinate along the spike direction, one entry per
            residual (presumably ``x_no_spike`` without its last axis --
            TODO confirm against caller).
        sigma: isotropic standard deviation. May be a tensor or a plain Python
            number; numbers are promoted to a tensor matching ``x_no_spike``.
        spike_norm: length of the spike vector(s); must be a tensor because it
            is indexed below.

    Returns:
        Tensor of log-probabilities (broadcast of the four terms).
    """
    # torch.log only accepts tensors, so promote plain numbers; tensors are
    # left untouched to keep their dtype, device and autograd history.
    if not torch.is_tensor(sigma):
        sigma = torch.as_tensor(sigma, dtype=x_no_spike.dtype, device=x_no_spike.device)
    sigma_sq = sigma**2

    # the normalization constant dimension is artificially lowered by one
    # -> the spike direction is accounted for separately below
    constant_term = -0.5 * (x_no_spike.shape[-1] - 1) * torch.log(2 * torch.pi * sigma_sq)
    exponent_term = -0.5 * (x_no_spike**2 / sigma_sq).sum(dim=-1)

    # spike dimension: variance is the isotropic variance plus the squared spike length
    total_spike_var = spike_norm**2 + sigma_sq
    constant_term_spike = -0.5 * torch.log(2 * torch.pi * total_spike_var)
    exponent_term_spike = -0.5 * (x_spike**2 / total_spike_var)

    # [None] adds a leading axis so the per-spike constant broadcasts against
    # the leading axis of the other terms
    log_prob = constant_term + exponent_term + constant_term_spike[None] + exponent_term_spike

    return log_prob


def parzen_log_prob_spike(x, x_p, spikes, sigma):
    """Parzen-window log-density of ``x`` under spiked Gaussian kernels at ``x_p``.

    Each kernel is centred on a row of ``x_p`` and has variance ``sigma**2`` in
    every direction except along its spike, where the variance is
    ``|spike|**2 + sigma**2``. The kernels are averaged with equal weight.

    Args:
        x: query points, shape ``(n, d)``.
        x_p: kernel centres, shape ``(m, d)``.
        spikes: one spike vector per kernel, shape ``(m, d)``.
        sigma: isotropic standard deviation, forwarded to
            ``log_prob_diagonal_spike``.

    Returns:
        Tensor of shape ``(n,)`` with the log-density of each query point.
    """
    dists = x[:, None] - x_p  # pairwise differences (n, m, d)
    spike_norm = torch.norm(spikes, dim=-1, keepdim=True)  # (m, 1)

    # A zero-length spike has no direction; dividing by its norm would yield
    # NaN and poison the logsumexp for every query point. Dividing by 1
    # instead leaves a zero "direction", so nothing is projected out and the
    # kernel reduces to the plain isotropic Gaussian, which is the correct limit.
    safe_norm = torch.where(spike_norm > 0, spike_norm, torch.ones_like(spike_norm))
    spike_normed = spikes / safe_norm  # m, d

    dproj = (
        spike_normed[None, :, None, :] @ dists[..., None]
    )  # (1, m, 1, d) x (n, m, d, 1) = (n, m, 1, 1)
    # remove the component along the spike from every difference vector
    dist_no_spike = dists - dproj[..., 0] * spike_normed[None]
    ps = log_prob_diagonal_spike(
        x_no_spike=dist_no_spike,
        x_spike=dproj[..., 0, 0],
        sigma=sigma,
        spike_norm=spike_norm[:, 0],
    )
    # equal-weight mixture: log-mean-exp over the m kernels
    log_prob = torch.logsumexp(ps, dim=-1)

    return log_prob - np.log(x_p.shape[0])


def get_distance_related_costs_unsorted(
    sampled_means,
    samples_base,
    complexity,
    use_manifold_length_costs,
    use_manifold_trivial_solution_costs,
):
    """Distance-based costs computed from all pairwise distances.

    This variant only regularizes the positions of the means, NOT the path
    through them: the summed entries of the full pairwise distance matrix of
    ``sampled_means`` are compared with the same quantity for
    ``samples_base``.

    Returns:
        The ``(manifold_length_costs, manifold_trivial_solution_costs)`` pair
        produced by ``calc_distance_related_costs``.
    """
    # total of all pairwise distances within each point set
    total_means_dist = torch.cdist(sampled_means, sampled_means).sum()
    total_base_dist = torch.cdist(samples_base, samples_base).sum()

    return calc_distance_related_costs(
        means_dist=total_means_dist,
        base_dist=total_base_dist,
        dim=sampled_means.shape[-1],
        zdim=samples_base.shape[-1],
        complexity=complexity,
        use_manifold_length_costs=use_manifold_length_costs,
        use_manifold_trivial_solution_costs=use_manifold_trivial_solution_costs,
    )


def get_distance_related_costs_sorted(
    sampled_means,
    samples_base,
    complexity,
    use_manifold_length_costs,
    use_manifold_trivial_solution_costs,
):
    """Distance-based costs computed from distances between consecutive rows.

    The summed Euclidean distance between each row and the next one is taken
    for ``sampled_means`` and for ``samples_base``; both totals are then
    handed to ``calc_distance_related_costs``.

    Returns:
        The ``(manifold_length_costs, manifold_trivial_solution_costs)`` pair
        produced by ``calc_distance_related_costs``.
    """
    # differences between neighbouring rows
    mean_steps = sampled_means[1:] - sampled_means[:-1]
    base_steps = samples_base[1:] - samples_base[:-1]

    # summed step lengths along each sequence
    path_means_dist = torch.norm(mean_steps, dim=-1).sum()
    path_base_dist = torch.norm(base_steps, dim=-1).sum()

    return calc_distance_related_costs(
        means_dist=path_means_dist,
        base_dist=path_base_dist,
        dim=sampled_means.shape[-1],
        zdim=samples_base.shape[-1],
        complexity=complexity,
        use_manifold_length_costs=use_manifold_length_costs,
        use_manifold_trivial_solution_costs=use_manifold_trivial_solution_costs,
    )


def calc_distance_related_costs(
    means_dist,
    base_dist,
    dim,
    zdim,
    complexity,
    use_manifold_length_costs,
    use_manifold_trivial_solution_costs,
):
    """Turn two distance totals into hinge-style penalty terms.

    ``base_dist`` is rescaled by ``sqrt(dim / zdim)`` to compensate for the
    different dimensionalities. Two one-sided penalties are then computed:

    * length cost: positive when ``means_dist`` exceeds
      ``ratio * complexity * base_dist`` (distances too large);
    * trivial-solution cost: positive when ``means_dist`` falls below
      ``ratio * base_dist`` (distances too small).

    A disabled penalty is returned as a zero tensor of shape ``(1,)`` on the
    device / dtype of ``means_dist``.

    Returns:
        Tuple ``(manifold_length_costs, manifold_trivial_solution_costs)``.
    """
    # adjustment of distances due to the different dimensions
    # (for a gaussian base: in general approx. sqrt(dim) * sqrt(2))
    ratio = np.sqrt(dim / zdim)

    # high distances add costs
    manifold_length_costs = means_dist.new_zeros(1)
    if use_manifold_length_costs:
        excess = means_dist - ratio * complexity * base_dist
        manifold_length_costs = torch.max(manifold_length_costs, excess)

    # very low distances add costs
    manifold_trivial_solution_costs = means_dist.new_zeros(1)
    if use_manifold_trivial_solution_costs:
        shortfall = ratio * base_dist - means_dist
        manifold_trivial_solution_costs = torch.max(manifold_trivial_solution_costs, shortfall)

    return manifold_length_costs, manifold_trivial_solution_costs


def gradient(y, x, grad_outputs=None):
    """Return d``y``/d``x`` via autograd, keeping the graph for higher-order derivatives.

    Args:
        y: output tensor to differentiate.
        x: input tensor with respect to which the derivative is taken.
        grad_outputs: optional weights for the vector-Jacobian product;
            defaults to a tensor of ones shaped like ``y``.

    Returns:
        The gradient tensor for ``x``.
    """
    weights = torch.ones_like(y) if grad_outputs is None else grad_outputs
    # create_graph=True makes the result itself differentiable
    (dy_dx,) = torch.autograd.grad(y, [x], grad_outputs=weights, create_graph=True)
    return dy_dx


def remove_neg_inf(x: torch.Tensor):
    """Replace ``-inf`` entries of ``x`` by the smallest finite entry of ``x``.

    The input is returned unchanged when every entry is finite or when no
    entry is finite (there is then no finite value to substitute). Only
    ``-inf`` is replaced; ``+inf`` and ``nan`` entries are left as they are.
    """
    finite_mask = torch.isfinite(x)
    has_finite = bool(finite_mask.any())
    has_non_finite = not bool(finite_mask.all())
    if not (has_finite and has_non_finite):
        return x

    smallest_finite = x[finite_mask].min()
    is_neg_inf = x == -float("inf")
    return torch.where(is_neg_inf, smallest_finite, x)


def sample_quasi_random_mixture(means, sigma, n_samples, eps_rel):
    """Sample a uniform Gaussian mixture using quasi-random (Halton) noise.

    A mixture component is picked uniformly at random for every sample. The
    Gaussian offset is obtained by pushing a scrambled Halton sequence through
    the inverse normal CDF (``sqrt(2) * erfinv(2u - 1)``) and scaling by
    ``sigma``. A small pseudo-random jitter of scale ``eps_rel`` is added on
    top.

    Args:
        means: 2-D tensor of component means, one row per component.
        sigma: standard deviation of the quasi-random Gaussian offset.
        n_samples: number of samples to draw.
        eps_rel: standard deviation of the additional pseudo-random jitter.

    Returns:
        Tensor with one sample per row, on the device / dtype of ``means``.
    """
    n_means, dim = means.shape
    sampled_idx = torch.randint(0, n_means, (n_samples,), device=means.device)
    sampled_means = means[sampled_idx]

    halton_sampler = Halton(dim, scramble=True)
    uniform = torch.tensor(
        halton_sampler.random(n_samples), device=means.device, dtype=means.dtype
    )

    # Halton points lie in [0, 1); after casting to the (usually lower
    # precision) dtype of `means`, 2u - 1 can round to exactly -1 or +1, for
    # which erfinv returns -inf / +inf. Keep the argument strictly inside
    # (-1, 1) by one machine epsilon so every sample stays finite.
    margin = torch.finfo(means.dtype).eps
    erf_arg = (2 * uniform - 1).clamp(-1 + margin, 1 - margin)

    # inverse normal CDF: Phi^{-1}(u) = sqrt(2) * erfinv(2u - 1)
    quasi_samples = sigma * float(np.sqrt(2.0)) * torch.erfinv(erf_arg)

    eps = torch.randn(n_samples, dim, device=means.device, dtype=means.dtype) * eps_rel
    samples = sampled_means + quasi_samples + eps
    return samples
