import numpy as np
import scipy as sp
import torch
from sklearn.datasets import make_moons, make_circles, make_swiss_roll

from margflow.datasets.dataset_abstracts import HybridDataset, HybridNologpDataset
from margflow.datasets.target_distribution import TargetMixture, TargetMixtureManifold


class MixtureOfGaussian(HybridDataset):
    """Gaussian-mixture target in ``args.x_dim`` dimensions with an exact ``log_prob``.

    Sampling and density evaluation are delegated to a ``TargetMixture`` built from
    ``args.n_mog`` modes with scale ``args.mog_sigma``.

    Args:
        args: experiment namespace; reads ``x_dim``, ``n_mog``, ``mog_sigma`` and ``device``.
        means: optional mode locations forwarded to ``TargetMixture``; after construction
            ``self.means`` holds the target's actual means.
    """

    def __init__(self, args, means=None):
        super(MixtureOfGaussian, self).__init__(args)
        self.means = means
        # ``_initialize_target`` passes ``self.domain_bound`` to the target, so the bound
        # must be assigned before the target is built.
        self.domain_bound = 2
        self._initialize_target(args)
        self.dataset_suffix += f"_nmog{self.args.n_mog:d}_tsigma{self.args.mog_sigma}"
        self.D = self.args.x_dim

    def _initialize_target(self, args):
        """Build ``self.target`` and sync ``self.means`` with the target's means."""
        sigma = torch.tensor(args.mog_sigma, device=args.device)
        self.target = TargetMixture(
            n_dim=args.x_dim,
            n_target_modes=args.n_mog,
            sigma=sigma,
            means=self.means,
            bounds=self.domain_bound,
            device=args.device,
            dtype=torch.float32,
        )
        self.means = self.target.means

    def sample(self, n_samples, data_type: str = "train") -> torch.Tensor:
        """Draw ``n_samples`` points from the mixture (``data_type`` is unused)."""
        return self.target.sample(n_samples=n_samples)

    def log_prob(self, x):
        """Return the mixture log-density of each row of ``x``.

        Numpy arrays are converted to a float32 tensor on ``args.device`` first.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float().to(self.args.device)
        return self.target.log_prob(x)


class MixtureOfGaussianManifold(MixtureOfGaussian):
    """Gaussian mixture whose modes live on a manifold chosen by ``args.manifold_type``.

    Uses ``TargetMixtureManifold`` in place of the flat ``TargetMixture``.
    """

    # Domain half-width for this dataset; also passed to the target as ``bounds``.
    _DOMAIN_BOUND = 1.5

    def __init__(self, args, means=None):
        super(MixtureOfGaussianManifold, self).__init__(args, means=means)
        self.dataset_suffix += f"_manifold{args.manifold_type}"
        # Re-assert after the parent ``__init__``, which assigns its own bound.
        self.domain_bound = self._DOMAIN_BOUND

    def _initialize_target(self, args):
        """Build the manifold target; called from the parent ``__init__``."""
        # The parent ``__init__`` calls this before the subclass ``__init__`` body runs,
        # so set the bound here to build the target with this class's bound.
        self.domain_bound = self._DOMAIN_BOUND
        # NOTE(review): ``self.means`` (the user-supplied ``means``) is not forwarded to
        # TargetMixtureManifold — confirm whether that target accepts explicit means.
        self.target = TargetMixtureManifold(
            manifold_type=args.manifold_type,
            n_dim=args.x_dim,
            n_target_modes=args.n_mog,
            sigma=args.mog_sigma,
            bounds=self.domain_bound,
            device=args.device,
            dtype=torch.float32,
        )
        # Keep ``self.means`` consistent with the parent class when the target exposes them.
        self.means = getattr(self.target, "means", self.means)


class UniformSphere(HybridDataset):
    """Uniform distribution on the sphere of radius ``args.radius`` in ``R^{args.x_dim}``.

    ``log_prob`` returns the constant ``-log(surface area)`` for every row of its input;
    it does not check that the points actually lie on the sphere.
    """

    def __init__(self, args):
        # ``super`` must name this class; naming another class breaks construction.
        super(UniformSphere, self).__init__(args)
        self.compute_log_surface()
        self.dataset_suffix += "_Usph"
        self.D = self.args.x_dim
        self.domain_bound = 1.1

    def sample(self, n_samples, data_type: str = "train") -> torch.Tensor:
        """Draw ``n_samples`` uniform points on the sphere (``data_type`` is unused)."""
        samples = torch.randn([n_samples, self.args.x_dim], device=self.args.device)
        # Normalizing isotropic Gaussian draws gives a uniformly distributed direction.
        samples /= torch.norm(samples, dim=1, keepdim=True)
        return samples * self.args.radius

    def log_prob(self, x):
        """Return ``-log_surface_area`` for each of the ``x.shape[0]`` input rows."""
        logp = torch.ones(x.shape[0], device=self.args.device)
        return logp * -self.log_surface_area

    def compute_log_surface(self):
        """Set ``self.log_surface_area`` to ``log(2 * pi^(d/2) * r^(d-1) / Gamma(d/2))``."""
        d = self.args.x_dim
        r = self.args.radius
        log_const_1 = np.log(2) + 0.5 * d * np.log(np.pi)
        log_const_2 = (d - 1) * np.log(r)
        log_const_3 = -sp.special.loggamma(0.5 * d)
        self.log_surface_area = log_const_1 + log_const_2 + log_const_3


class TwoMoons(HybridNologpDataset):
    """Two interleaving half-circles in 2-D, generated by scikit-learn's ``make_moons``.

    No analytic ``log_prob`` is defined here. An earlier KDE-based estimate was left
    disabled in the source and has been removed.
    """

    def __init__(self, args):
        super().__init__(args)
        assert args.x_dim == 2
        self.D = self.args.x_dim
        self.dataset_suffix += "_two_moons"
        self.domain_bound = 2.5

    def sample(self, n_samples, data_type: str = "train") -> torch.Tensor:
        """Return ``n_samples`` noisy moon points as float32 on ``args.device``.

        ``data_type`` is accepted for interface compatibility and ignored.
        """
        moon_noise = 0.05
        points = make_moons(n_samples=n_samples, noise=moon_noise)[0]
        return torch.from_numpy(points).float().to(self.args.device)


class TwoCircles(HybridNologpDataset):
    """Two concentric circles (inner radius factor 0.5) in 2-D with small Gaussian jitter.

    No analytic ``log_prob`` is defined here.
    """

    def __init__(self, args):
        super().__init__(args)
        assert args.x_dim == 2
        self.D = args.x_dim
        self.dataset_suffix += "_two_circles"
        self.domain_bound = 1.5

    def sample(self, n_samples, data_type: str = "train") -> torch.Tensor:
        """Return ``n_samples`` jittered circle points as float32 on ``args.device``.

        ``data_type`` is accepted for interface compatibility and ignored.
        """
        jitter_std = 0.05
        circles = make_circles(n_samples=n_samples, factor=0.5)[0]
        # The circles are 1-D curves; jitter gives the data support in the full plane.
        jitter = np.random.normal(loc=0.0, scale=jitter_std, size=circles.shape)
        return torch.from_numpy(circles + jitter).float().to(self.args.device)


class SwissRoll(HybridNologpDataset):
    """2-D Swiss-roll spiral: the (x, z) columns of ``make_swiss_roll``, scaled by 1/5 and jittered."""

    def __init__(self, args):
        super().__init__(args)
        assert args.x_dim == 2
        self.D = args.x_dim
        self.dataset_suffix += "_swiss_roll"
        self.domain_bound = 3

    def sample(self, n_samples, data_type: str = "train") -> torch.Tensor:
        """Return ``n_samples`` spiral points as float32 on ``args.device`` (``data_type`` unused)."""
        roll = make_swiss_roll(n_samples=n_samples, noise=0)[0]
        # Columns 0 and 2 trace the planar spiral; column 1 is the roll's height.
        planar = roll[:, [0, 2]] / 5
        # The spiral is a 1-D curve; add small jitter so the data has 2-D support.
        planar = planar + np.random.normal(loc=0.0, scale=5e-2, size=planar.shape)
        return torch.from_numpy(planar).float().to(self.args.device)


class Checkerboard(HybridDataset):
    """2-D checkerboard on [-1, 1)^2: 4x4 cells of side 0.5, half of them populated.

    ``log_prob`` is exact (uniform) on the populated cells and a fixed floor elsewhere.
    """

    def __init__(self, args):
        super(Checkerboard, self).__init__(args)
        assert args.x_dim == 2
        self.D = args.x_dim
        self.dataset_suffix += "_checkerboard"
        self.domain_bound = 1.5

    def sample(self, n_samples, data_type: str = "train") -> torch.Tensor:
        """Draw ``n_samples`` uniform points from the populated cells (``data_type`` unused)."""
        # Built on [-2, 2)^2 with unit cells: x2 lands in rows whose parity matches
        # floor(x1), so populated cells (i, j) satisfy i + j even.
        x1 = np.random.rand(n_samples) * 4 - 2
        x2_ = np.random.rand(n_samples) - np.random.randint(0, 2, n_samples) * 2
        x2 = x2_ + (np.floor(x1) % 2)
        samples = np.concatenate([x1[:, None], x2[:, None]], 1) * 2
        samples = torch.from_numpy(samples).float().to(self.args.device)
        samples /= 4  # net scale 0.5 -> points in [-1, 1)^2, cells of side 0.5

        return samples

    def log_prob(self, x):
        """Return the log-density of each row of ``x`` (numpy arrays are accepted).

        Points in a populated cell get ``-log(2)``: the populated area of [-1, 1]^2 is
        2, so the uniform density there is 1/2. All other points get the fixed floor
        ``-5``, which is not a normalized density.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float().to(self.args.device)
        inside_square = (x[:, 0] >= -1) & (x[:, 0] <= 1) & (x[:, 1] >= -1) & (x[:, 1] <= 1)
        # Shift by +1 and scale by 2 to get integer cell indices; populated cells have
        # an even index sum, matching ``sample``.
        on_board = (
            (torch.floor((x[:, 0] + 1) * 2) + torch.floor((x[:, 1] + 1) * 2)) % 2
        ) == 0
        log_prob_outside = -5
        log_prob = log_prob_outside * torch.ones(x.shape[0], device=self.args.device)
        log_prob[on_board & inside_square] = -np.log(2)
        return log_prob


class Pinwheel(HybridNologpDataset):
    """2-D pinwheel: five curved arms obtained by rotating Gaussian blobs.

    No analytic ``log_prob`` is defined here. KDE/Parzen estimates were sketched in the
    source but left disabled, and have been removed.
    """

    def __init__(self, args):
        super(Pinwheel, self).__init__(args)
        assert args.x_dim == 2
        self.D = args.x_dim
        self.dataset_suffix += "_pinwheel"
        self.domain_bound = 1

    def sample(self, n_samples, data_type: str = "train") -> torch.Tensor:
        """Return exactly ``n_samples`` pinwheel points as float32 on ``args.device``.

        When ``n_samples`` is not a multiple of the number of arms, the first
        ``n_samples % num_classes`` arms receive one extra point. ``data_type`` is unused.
        """
        radial_std = 0.3
        tangential_std = 0.1
        num_classes = 5
        rate = 0.25
        rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

        # Previously ``n_samples // 5`` points per arm were drawn, silently dropping the
        # remainder (and returning nothing for n_samples < 5). Distribute it instead.
        per_class, remainder = divmod(n_samples, num_classes)
        counts = np.full(num_classes, per_class)
        counts[:remainder] += 1
        labels = np.repeat(np.arange(num_classes), counts)

        features = np.random.randn(n_samples, 2) * np.array([radial_std, tangential_std])
        features[:, 0] += 1.0

        # The rotation angle grows with the radial coordinate, which bends each blob into an arm.
        angles = rads[labels] + rate * np.exp(features[:, 0])
        rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
        rotations = np.reshape(rotations.T, (-1, 2, 2))
        samples = 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations))
        samples = torch.from_numpy(samples).float().to(self.args.device)
        samples /= 5

        return samples
