import torch
import torch.nn as nn

from margflow.model_utils import sample_gaussian_mixture, log_prob_gaussian_mixture
from margflow.datasets.utils import possible_mu_f


class TargetMixture(nn.Module):
    """Gaussian-mixture target distribution.

    The mixture is described by a set of component means and a single
    ``sigma`` that is shared by :meth:`log_prob` and :meth:`sample`; the actual
    density/sampling math lives in ``log_prob_gaussian_mixture`` and
    ``sample_gaussian_mixture``.

    Args:
        n_dim: Dimensionality of the ambient space.
        n_target_modes: Number of mixture components.
        sigma: Spread parameter forwarded to the mixture helpers. A tensor is
            stored unchanged; anything else is converted to a tensor of
            ``dtype`` on ``device``.
        means: Optional component means. A tensor is stored unchanged;
            anything else (e.g. nested lists) is converted to a tensor of
            ``dtype`` on ``device``. When ``None``, means of shape
            ``(n_target_modes, n_dim)`` are drawn uniformly from
            ``[-bounds, bounds)``.
        bounds: Half-width of the box used for random means.
        device: Device for tensors created here.
        dtype: Dtype for tensors created here.
    """

    def __init__(
        self,
        n_dim,
        n_target_modes,
        sigma,
        means=None,
        bounds=1.0,
        device="cuda",
        dtype=torch.float32,
    ):
        super(TargetMixture, self).__init__()
        self.n_target_modes = n_target_modes
        self.n_dim = n_dim
        self.sigma = self._to_tensor(sigma, dtype, device)
        self.bounds = bounds
        self.dtype = dtype
        self.device = device

        if means is None:
            # Map U[0, 1) to U[-bounds, bounds) independently per coordinate.
            self.means = (
                2 * self.bounds * torch.rand(n_target_modes, n_dim, device=device, dtype=dtype)
                - self.bounds
            )
        else:
            # Treat user means exactly like sigma: non-tensors (lists, arrays)
            # would otherwise reach the mixture helpers as raw Python objects.
            self.means = self._to_tensor(means, dtype, device)

    @staticmethod
    def _to_tensor(value, dtype, device):
        """Return ``value`` as a tensor; existing tensors are returned as-is.

        Tensors are deliberately not moved or cast so that their device,
        dtype and autograd / ``nn.Parameter`` status are preserved.
        """
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, dtype=dtype).to(device)

    def log_prob(self, x):
        """Evaluate the mixture log-density at ``x``.

        Delegates to ``log_prob_gaussian_mixture`` with this module's means
        and sigma. ``x`` is presumably ``(batch, n_dim)`` — confirm against
        that helper.
        """
        log_prob = log_prob_gaussian_mixture(x=x, mixtures=self.means, sigma=self.sigma)

        return log_prob

    def sample(self, n_samples):
        """Draw ``n_samples`` points from the mixture via ``sample_gaussian_mixture``."""
        samples = sample_gaussian_mixture(means=self.means, sigma=self.sigma, n_samples=n_samples)
        return samples


class TargetMixtureManifold(TargetMixture):
    """Gaussian mixture whose component means are placed by ``possible_mu_f``.

    ``possible_mu_f(..., kind=manifold_type)`` returns a function that is
    evaluated at ``n_target_modes`` evenly spaced parameters in ``[0, 1]`` to
    produce the means. ``sigma``, :meth:`log_prob` and :meth:`sample` are
    inherited from :class:`TargetMixture`.

    Args:
        manifold_type: One of ``"line"``, ``"sin"``, ``"circle"``,
            ``"spiral"``; forwarded to ``possible_mu_f`` as ``kind``.
        n_dim: Dimensionality, forwarded to the parent and ``possible_mu_f``.
        n_target_modes: Number of mixture components.
        sigma: Forwarded to the parent constructor.
        bounds: Forwarded to the parent and to ``possible_mu_f`` as ``bound``.
        device: Device for the parameter grid.
        dtype: Dtype for the parameter grid.

    Raises:
        AssertionError: If ``manifold_type`` is not a supported kind.
    """

    def __init__(
        self,
        manifold_type,
        n_dim,
        n_target_modes,
        sigma,
        bounds,
        device="cuda",
        dtype=torch.float32,
    ):
        # Validate before any work, and with an explicit raise so the check
        # survives ``python -O``; AssertionError keeps the original contract.
        if manifold_type not in ("line", "sin", "circle", "spiral"):
            raise AssertionError("Manifold must be in ['line', 'sin', 'circle', 'spiral']")

        # NOTE: the parent draws random means (advancing the global torch RNG)
        # that are overwritten below; the draw is kept so seeded runs reproduce.
        super(TargetMixtureManifold, self).__init__(
            n_dim, n_target_modes, sigma, bounds=bounds, device=device, dtype=dtype
        )
        self.manifold_type = manifold_type
        self.circular = manifold_type == "circle"

        if self.circular:
            # On a closed curve t=0 and t=1 presumably land on the same point,
            # so the endpoint is dropped to avoid a duplicated mode.
            self.m_z = torch.linspace(0, 1, n_target_modes + 1, device=device, dtype=dtype)[:-1]
        else:
            self.m_z = torch.linspace(0, 1, n_target_modes, device=device, dtype=dtype)

        mu_f = possible_mu_f(
            torch.tensor(0, dtype=dtype, device=device),
            dim=n_dim,
            kind=manifold_type,
            bound=self.bounds,
        )
        self.means = mu_f(self.m_z)
