import random
from typing import List, Tuple

import numpy as np  # type: ignore
import torch
from sklearn.datasets import make_circles, make_moons  # type: ignore
from torch.utils.data import Dataset

# Short alias for ``torch.Tensor``, used in the type annotations of this file.
T = torch.Tensor


class ToyDataset(Dataset):
    """Common base class for the synthetic toy datasets defined in this file.

    It records the requested number of samples and the seed, and seeds the
    *global* random number generators of ``torch``, ``random`` and ``numpy``.
    Subclasses in this file draw from those global generators right after
    calling ``super().__init__``, which is what makes them reproducible.
    """

    def __init__(self, n: int, seed: int = 0):
        """Store ``n`` and ``seed`` and seed the global RNGs.

        Args:
            n: Number of samples reported by ``__len__``.
            seed: Seed applied to ``torch``, ``random`` and ``numpy``.
        """
        self.seed = seed
        self.n = n

        # NOTE: these calls reset process-wide RNG state, not a private generator.
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    def __len__(self) -> int:
        """Return the number of samples requested at construction time."""
        return self.n

    def sample_uniform(self, low: float = -3.0, high: float = 3.0, steps: int = 100) -> T:
        """Return a regular 2-D grid of points as a ``(steps * steps, 2)`` tensor.

        The grid is the Cartesian product of ``steps`` evenly spaced values in
        ``[low, high]`` with itself. Rows are ordered with the first coordinate
        varying slowest, i.e. ``(a0, a0), (a0, a1), ..., (a1, a0), ...``.

        Args:
            low: Smallest coordinate value along each axis.
            high: Largest coordinate value along each axis.
            steps: Number of grid values along each axis.

        Returns:
            A float tensor of shape ``(steps * steps, 2)``.
        """
        axis = torch.linspace(low, high, steps)
        # ``cartesian_prod`` yields the same rows, in the same order, as stacking
        # an "ij"-indexed meshgrid and flattening it, while avoiding a call to
        # ``torch.meshgrid`` without its (now expected) ``indexing`` argument.
        return torch.cartesian_prod(axis, axis)


class Moons(ToyDataset):
    """Two-class "moons" dataset built with scikit-learn's ``make_moons``."""

    def __init__(self, n: int, seed: int = 0):
        """Generate ``n`` labelled points.

        Args:
            n: Number of samples passed to ``make_moons``.
            seed: Seed forwarded to the base class and used as ``random_state``.
        """
        super().__init__(n, seed=seed)
        self.name = "moons"
        self.n = n

        # The noise level is drawn from NumPy's global RNG, which the base
        # constructor has just seeded, so it is a deterministic function of `seed`.
        noise_level = 0.25 * np.random.rand()
        points, labels = make_moons(
            n_samples=n,
            noise=noise_level,
            random_state=self.seed,
        )

        # Features as float tensors, labels as integer (long) tensors.
        self.x = torch.from_numpy(points).float()
        self.y = torch.from_numpy(labels).long()

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """Return the ``(features, label)`` pair stored at index ``i``."""
        features, label = self.x[i], self.y[i]
        return features, label


class Circles(ToyDataset):
    """Two-class concentric-circles dataset built with ``make_circles``."""

    def __init__(self, n: int, seed: int = 0):
        """Generate ``n`` labelled points.

        Args:
            n: Number of samples passed to ``make_circles``.
            seed: Seed forwarded to the base class and used as ``random_state``.
        """
        super().__init__(n, seed=seed)
        self.name = "circles"
        self.n_way = 2

        # Both values come from NumPy's global RNG (seeded by the base
        # constructor). The draw order matters: noise first, then the factor.
        noise_level = 0.25 * np.random.rand()
        inner_factor = 0.8 * np.random.rand()
        points, labels = make_circles(
            n_samples=n,
            noise=noise_level,
            factor=inner_factor,
            random_state=self.seed,
        )

        # Features as float tensors, labels as integer (long) tensors.
        self.x = torch.from_numpy(points).float()
        self.y = torch.from_numpy(labels).long()

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """Return the ``(features, label)`` pair stored at index ``i``."""
        features, label = self.x[i], self.y[i]
        return features, label


class Gaussians(ToyDataset):
    """Labelled samples drawn from ``classes`` random multivariate Gaussians.

    Each class gets a mean sampled uniformly from ``mu_rng`` and a random
    covariance matrix built as ``O^T D O`` from a random orthogonal ``O`` and a
    random diagonal ``D``. Every class contributes ``n // classes`` samples, so
    the dataset holds ``(n // classes) * classes`` points, which is what
    ``len()`` reports. The features are standardised per dimension over the
    whole dataset.
    """

    def __init__(self, n: int, classes: int = 10, seed: int = 0, mu_rng: List[int] = [-5, 5], var_rng: List[float] = [0.1, 1.0], dim: int = 2):
        """Sample the mixture parameters and then the data.

        Args:
            n: Requested number of samples; ``n // classes`` are drawn per class.
            classes: Number of Gaussian components / distinct labels.
            seed: Seed forwarded to the base class.
            mu_rng: ``[low, high]`` range the component means are drawn from.
            var_rng: ``[low, high]`` range of the variance scale factors.
            dim: Dimensionality of the samples.
        """
        super().__init__(n, seed=seed)

        self.n = n
        # Store copies: the default arguments are module-lifetime list objects,
        # so keeping a reference would let a mutation on one instance leak into
        # every later instance created with the defaults.
        self.mu_rng = list(mu_rng)
        self.var_rng = list(var_rng)
        self.classes = classes
        self.dim = dim
        # NOTE(review): "gausians" is misspelled, but it is a runtime value that
        # callers may key on, so it is kept as-is.
        self.name = "gausians"

        # sample the means uniformly within mu_rng
        mus = torch.rand((self.classes, self.dim)) * (self.mu_rng[1] - self.mu_rng[0]) + self.mu_rng[0]

        # decompose PSD sigma as O^TDO with orthogonal O's to make random PSD covariance
        # https://stats.stackexchange.com/questions/2746/how-to-efficiently-generate-random-positive-semidefinite-correlation-matrices
        O = torch.rand((self.classes, self.dim, self.dim)) * 2 - 1
        O = torch.linalg.qr(O, mode='complete')[0]
        D = torch.stack([torch.eye(self.dim) * torch.rand(self.dim) for _ in range(self.classes)])

        # Per-dimension scale factors drawn uniformly within var_rng, shaped
        # (classes, 1, dim) so they broadcast over the columns of D.
        tmp = (torch.rand((self.classes, self.dim)) * (self.var_rng[1] - self.var_rng[0]) + self.var_rng[0]).unsqueeze(1)
        if self.dim > 1:
            # Make the clusters elongated: the second scale factor is forced to
            # 1/5 of the first. (The final eigenvalues are also multiplied by the
            # random diagonal above, so their ratio is not exactly 5:1.)
            # Skipped for dim == 1, where there is no second dimension to index.
            tmp[:, :, 1] = tmp[:, :, 0] / 5
        D = D * tmp
        sigmas = O.transpose(1, 2).bmm(D.bmm(O))

        per_class = self.n // self.classes
        N = torch.distributions.MultivariateNormal(mus, sigmas)
        # y: (classes, per_class); component k is assigned label perm[k].
        y = torch.randperm(self.classes).unsqueeze(1).repeat(1, per_class)
        # sample() returns (per_class, classes, dim); transpose to line up with y.
        x = N.sample((per_class,)).transpose(0, 1)

        # Standardise each feature dimension using statistics over all classes
        # and samples.
        mu, sigma = x.mean(dim=(0, 1)), x.std(dim=(0, 1))
        x = (x - mu) / sigma
        self.x = x.reshape(-1, self.dim).float()
        self.y = y.reshape(-1).long()

    def __len__(self) -> int:
        """Return the number of samples actually generated.

        This equals ``n`` when ``n`` is divisible by ``classes``; otherwise it is
        ``(n // classes) * classes``, so every index below ``len(self)`` is valid.
        """
        return self.x.shape[0]

    def __getitem__(self, i: int) -> Tuple[T, T]:
        """Return the ``(features, label)`` pair stored at index ``i``."""
        return self.x[i], self.y[i]  # type: ignore
