import os
import random
from argparse import Namespace
from typing import Any, List, Tuple

import numpy as np  # type: ignore
import torch
from sklearn.datasets import make_circles, make_moons  # type: ignore
from torch.utils.data import Dataset

# Short alias for torch.Tensor, used in the type annotations throughout this file.
T = torch.Tensor


def get_biased_sample_idx(x: Any, y: Any, k_shot: int) -> Tuple[Any, ...]:
    """Split (x, y) into a biased support set and a query set, class by class.

    For every class, the samples are sorted along one randomly chosen feature
    column (column 0 or column 1) and cut into two equal halves. One randomly
    chosen half is the "support half": `k_shot` random samples from it become
    the support set. The remaining samples of that half, plus the entire other
    half, go to the query set.

    Args:
        x: array-like of shape (n_samples, n_features, ...). At least two
            feature columns are required because the sort column is drawn
            from {0, 1}.
        y: array-like of shape (n_samples,) with the class label of each row.
        k_shot: number of support samples to draw per class.

    Returns:
        A tuple `(sx, sy, qx, qy)` of numpy arrays: support features, support
        labels, query features and query labels. All four are float64 because
        they are accumulated onto empty float64 arrays, so callers that need
        integer labels must cast them.

    Notes:
        - If a class has an odd number of samples, the sample with the largest
          value in the sort column is dropped (halves use floor division).
        - Uses the global numpy RNG (`np.random`).
    """
    x, y = np.asarray(x), np.asarray(y)
    classes = np.unique(y)
    n_sections = 2  # the number of samples per class should be divisible by n_sections

    # Empty accumulators whose trailing shape follows x, so inputs with any
    # number of feature columns (not only 2) can be concatenated onto them.
    feature_shape = tuple(x.shape[1:])
    sx, sy = np.empty((0,) + feature_shape), np.empty((0,))
    qx, qy = np.empty((0,) + feature_shape), np.empty((0,))

    for c in classes:
        class_idx = np.argwhere(y == c).squeeze(1)
        class_x, class_y = x[class_idx], y[class_idx]

        x_or_y = 0 if np.sign(np.random.rand() - 0.5) < 0 else 1  # choose feature column 0 or 1 randomly
        section = np.random.permutation(n_sections)  # section[0] -> support half, section[1] -> query half
        x_idx = np.argsort(class_x[:, x_or_y])
        section_len = x_idx.shape[0] // n_sections

        def section_indices(n: int) -> Any:
            # indices (into this class) of the n-th section of the sorted samples
            start = int(n) * section_len
            return x_idx[start : start + section_len]

        # support and query candidates for this class, split by the biased section we chose
        spt_idx, qry_idx = section_indices(section[0]), section_indices(section[1])
        spt_x, spt_y = class_x[spt_idx], class_y[spt_idx]
        qry_x, qry_y = class_x[qry_idx], class_y[qry_idx]

        # take k random samples of the biased support half; everything else goes to the query set
        spt_perm = np.random.permutation(spt_x.shape[0])
        chosen, leftover = spt_perm[:k_shot], spt_perm[k_shot:]
        sx = np.concatenate((sx, spt_x[chosen]))
        sy = np.concatenate((sy, spt_y[chosen]))
        qx = np.concatenate((qx, spt_x[leftover], qry_x))
        qy = np.concatenate((qy, spt_y[leftover], qry_y))

    return sx, sy, qx, qy


class ToyDataset(Dataset):
    """Base class for toy few-shot datasets whose tasks are generated on the fly.

    Subclasses are expected to set `self.n_way` and `self.dim` and to implement
    `gen_random_task`; this class only stores the shared settings, seeds the
    random number generators and wires up the `Dataset` protocol.
    """

    def __init__(self, seed: int = 0, k_shot: int = 10, total_tasks: int = 100, test_shots: int = 50):
        """Store the task settings and seed the torch, `random` and numpy global RNGs."""
        self.seed, self.k_shot = seed, k_shot
        self.total_tasks, self.test_shots = total_tasks, test_shots
        self.dim: int  # declared here, assigned by subclasses

        # seed every global RNG this file draws from
        for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
            seed_fn(seed)

    def sample_uniform(self) -> T:
        """Return a regular 100 x 100 grid over [-3, 3] per axis, flattened to rows of length `self.dim`."""
        axis = torch.linspace(-3, 3, 100)
        grid = torch.meshgrid(axis, axis)
        return torch.stack(grid, dim=-1).view(-1, self.dim)

    def __len__(self) -> int:
        """Number of tasks reported to data loaders."""
        return self.total_tasks

    def gen_random_task(self) -> Tuple[T, T, T, T]:
        """Generate one task as (support x, support y, query x, query y); subclasses must override."""
        raise NotImplementedError()

    def __getitem__(self, i: int) -> Tuple[T, T, T, T]:
        """Return a freshly generated task; the index `i` is not used.

        The support tensor is tagged with `n_way` and `k_shot` attributes.
        """
        support_x, support_y, query_x, query_y = self.gen_random_task()
        support_x.n_way = self.n_way  # type: ignore
        support_x.k_shot = self.k_shot  # type: ignore

        return support_x, support_y, query_x, query_y


class MetaMoons(ToyDataset):
    """Two-way few-shot tasks built from `sklearn.datasets.make_moons`."""

    def __init__(self, seed: int = 0, k_shot: int = 10, total_tasks: int = 100, test_shots: int = 50):
        """Configure a 2-way, 2-dimensional moons dataset; arguments are forwarded to `ToyDataset`."""
        super().__init__(seed=seed, k_shot=k_shot, total_tasks=total_tasks, test_shots=test_shots)

        self.name = "moons"
        self.n_way = 2
        self.dim = 2
        self.path = os.path.join("toy-moons", "2-way", f"{k_shot}-shot", f"{test_shots}-testshot")

    def gen_random_task(self) -> Tuple[T, T, T, T]:
        """Generate one moons task.

        Returns:
            (support x, support y, query x, query y) with float features and
            long labels, split by `get_biased_sample_idx`.
        """
        # noise level drawn uniformly from [0, 0.25)
        noise_level = 0.25 * np.random.rand()
        n_points = self.n_way * (self.k_shot + self.test_shots)
        # NOTE(review): random_state is the fixed dataset seed on every call, so the
        # randomness between tasks comes from the global numpy RNG draws in this method
        # and in the split below - confirm this is intended.
        points, labels = make_moons(n_samples=n_points, noise=noise_level, random_state=self.seed)

        # swap the two class labels half of the time so the task is more random
        swap_classes = np.random.rand() > 0.5
        if swap_classes:
            labels = 1 - labels

        spt_x, spt_y, qry_x, qry_y = get_biased_sample_idx(points, labels, self.k_shot)
        return (
            torch.from_numpy(spt_x).float(),
            torch.from_numpy(spt_y).long(),
            torch.from_numpy(qry_x).float(),
            torch.from_numpy(qry_y).long(),
        )


class MetaCircles(ToyDataset):
    """Two-way few-shot tasks built from `sklearn.datasets.make_circles`."""

    def __init__(self, seed: int = 0, k_shot: int = 10, total_tasks: int = 100, test_shots: int = 50):
        """Configure a 2-way, 2-dimensional circles dataset; arguments are forwarded to `ToyDataset`."""
        super().__init__(seed=seed, k_shot=k_shot, total_tasks=total_tasks, test_shots=test_shots)

        self.name = "circles"
        self.n_way = 2
        self.dim = 2
        self.path = os.path.join("toy-circles", "2-way", f"{k_shot}-shot", f"{test_shots}-testshot")

    def gen_random_task(self) -> Tuple[T, T, T, T]:
        """Generate one circles task.

        Returns:
            (support x, support y, query x, query y) with float features and
            long labels, split by `get_biased_sample_idx`.
        """
        # noise level drawn from [0, 0.25), then the `factor` argument from [0, 0.8)
        noise_level = 0.25 * np.random.rand()
        circle_factor = 0.8 * np.random.rand()
        n_points = self.n_way * (self.k_shot + self.test_shots)
        # NOTE(review): random_state is the fixed dataset seed on every call - confirm this is intended.
        points, labels = make_circles(
            n_samples=n_points,
            noise=noise_level,
            factor=circle_factor,
            random_state=self.seed,
        )

        # swap the two class labels half of the time so the task is more random
        swap_classes = np.random.rand() > 0.5
        if swap_classes:
            labels = 1 - labels

        spt_x, spt_y, qry_x, qry_y = get_biased_sample_idx(points, labels, self.k_shot)
        return (
            torch.from_numpy(spt_x).float(),
            torch.from_numpy(spt_y).long(),
            torch.from_numpy(qry_x).float(),
            torch.from_numpy(qry_y).long(),
        )


class MetaGaussians(ToyDataset):
    """N-way few-shot tasks where every class is a random multivariate Gaussian."""

    def __init__(
        self,
        seed: int = 0,
        n_way: int = 5,
        k_shot: int = 5,
        total_tasks: int = 100,
        test_shots: int = 15,
        mu_rng: List[int] = [-5, 5],
        var_rng: List[float] = [0.1, 1.0],
        dim: int = 2
    ):
        """Configure the Gaussian task generator.

        Args:
            seed: seed forwarded to `ToyDataset` for the global RNGs.
            n_way: number of classes (Gaussians) per task.
            k_shot: samples drawn per class for the first sample batch.
            total_tasks: value reported by `len()`.
            test_shots: samples drawn per class for the second sample batch.
            mu_rng: `[low, high]` range the class means are drawn from.
            var_rng: `[low, high]` range of the per-class scaling applied to the
                diagonal of the covariance factorisation.
            dim: dimensionality of the samples. `gen_random_task` writes to
                index 1 of the last axis, so it needs `dim >= 2`.
        """
        super().__init__(seed=seed, k_shot=k_shot, total_tasks=total_tasks, test_shots=test_shots)

        # Copy the ranges: the defaults are mutable lists created once at function
        # definition time, so storing them directly would share one list object
        # between every instance (and with the caller's own list).
        self.mu_rng = list(mu_rng)
        self.n_way = n_way
        self.var_rng = list(var_rng)
        self.var = self.var_rng  # second name for the same list, kept for existing users
        self.dim = dim
        self.name = "gausian"  # NOTE(review): spelling kept as-is, the value may be used outside this file
        self.path = os.path.join("toy-gaussian", f"{n_way}-way", f"{k_shot}-shot", f"{test_shots}-testshot")

    def sample(self, N: torch.distributions.MultivariateNormal, variant: str = "uniform") -> Tuple[T, T]:
        """Draw `k_shot` and then `test_shots` samples from `N`, swapping the first two axes of each.

        For the batched distribution built in `gen_random_task` this gives
        tensors indexed as (class, sample, dim). `variant` is currently unused.
        """
        train = N.sample((self.k_shot,)).transpose(0, 1)
        test = N.sample((self.test_shots,)).transpose(0, 1)
        return train, test

    def gen_random_task(self) -> Tuple[T, T, T, T]:
        """Generate one Gaussian task.

        Returns:
            (support x, support y, query x, query y) with float features and
            long labels. Features are standardised with the mean and standard
            deviation of the first sample batch before `get_biased_sample_idx`
            re-splits the pooled samples.
        """
        # sample the means uniformly within mu_rng
        mus = torch.rand((self.n_way, self.dim)) * (self.mu_rng[1] - self.mu_rng[0]) + self.mu_rng[0]

        # build a random covariance as O^T D O, with O taken from the QR decomposition of a random matrix
        # https://stats.stackexchange.com/questions/2746/how-to-efficiently-generate-random-positive-semidefinite-correlation-matrices
        O = torch.rand((self.n_way, self.dim, self.dim)) * 2 - 1
        O = torch.linalg.qr(O, mode='complete')[0]
        D = torch.stack([torch.eye(self.dim) * torch.rand(self.dim) for _ in range(self.n_way)])

        # scale the diagonal entries by values drawn from var_rng; the scale of entry 1 is
        # forced to one fifth of the scale of entry 0 so the classes are elongated
        tmp = (torch.rand((self.n_way, self.dim)) * (self.var_rng[1] - self.var_rng[0]) + self.var_rng[0]).unsqueeze(1)
        tmp[:, :, 1] = tmp[:, :, 0] / 5
        D = D * tmp
        sigmas = O.transpose(1, 2).bmm(D.bmm(O))

        N = torch.distributions.MultivariateNormal(mus, sigmas)
        labels = torch.randperm(self.n_way)  # random label assigned to each Gaussian

        train_x, test_x = self.sample(N)

        # standardise both batches with the statistics of the first batch only
        mu, sigma = train_x.mean(dim=(0, 1)), train_x.std(dim=(0, 1))
        train_x = (train_x - mu) / sigma
        test_x = (test_x - mu) / sigma

        # one label per sample, repeated along the sample axis
        train_y = labels.unsqueeze(-1).repeat(1, self.k_shot)
        test_y = labels.unsqueeze(-1).repeat(1, self.test_shots)

        # pool everything into flat numpy arrays so the biased split can re-partition it
        x = np.concatenate((train_x.reshape(-1, self.dim).numpy(), test_x.reshape(-1, self.dim).numpy()))
        y = np.concatenate((train_y.reshape(-1).numpy(), test_y.reshape(-1).numpy()))

        assert x.shape[0] % 2 == 0, f"x needs to be evenly divisible by 2 (got shape {x.shape}) for the toy Gaussian, if not you have to fix 'get biased sample function'"
        sx, sy, qx, qy = get_biased_sample_idx(x, y, self.k_shot)

        return (
            torch.from_numpy(sx).float(),
            torch.from_numpy(sy).long(),
            torch.from_numpy(qx).float(),
            torch.from_numpy(qy).long(),
        )


class MetaAll(ToyDataset):
    """Mixture dataset that serves tasks from moons, circles and Gaussians at random.

    The tuple arguments `n_way`, `k_shot` and `test_shots` are indexed in the
    order (moons, circles, gaussians).
    """

    def __init__(
        self,
        args: Namespace,
        seed: int = 0,
        n_way: Tuple[int, ...] = (2, 2, 10),
        k_shot: Tuple[int, ...] = (5, 5, 5),
        total_tasks: int = 100,
        test_shots: Tuple[int, ...] = (15, 15, 15),
        gaussians_mu_rng: List[int] = [-5, 5],
        gaussians_var_rng: List[float] = [0.1, 1.0],
        dim: int = 2
    ):
        """Build the three underlying datasets and record the per-dataset settings."""
        # the base class only needs placeholder shot counts here; the real ones live in the sub-datasets
        super().__init__(seed=seed, k_shot=1, total_tasks=total_tasks, test_shots=1)

        moons = MetaMoons(seed, k_shot[0], total_tasks, test_shots[0])
        circles = MetaCircles(seed, k_shot[1], total_tasks, test_shots[1])
        gaussians = MetaGaussians(
            seed, n_way[2], k_shot[2], total_tasks, test_shots[2], gaussians_mu_rng, gaussians_var_rng, dim
        )
        self.datasets = [moons, circles, gaussians]

        self.name = "toy-meta-all"
        self.args = args
        self.dim = dim
        self.n_way_all = n_way
        self.k_shot_all = k_shot
        self.test_shots_all = test_shots
        # one share of `total_tasks` for each of the three datasets
        self.total_tasks = 3 * total_tasks
        self.path = os.path.join("toy", f"{n_way}-way", f"{k_shot}-shot", f"{test_shots}-testshot")

    def sample_uniform(self) -> T:
        """Return a regular 100 x 100 grid over [-3, 3] per axis as rows of 2 coordinates."""
        axis = torch.linspace(-3, 3, 100)
        grid = torch.meshgrid(axis, axis)
        return torch.stack(grid, dim=-1).view(-1, 2)

    def __getitem__(self, i: int) -> Tuple[T, T, T, T]:
        """Return a fresh task from a randomly chosen dataset; the index `i` is not used.

        The support tensor is tagged with the `n_way` and `k_shot` values
        recorded for the chosen dataset.
        """
        which = np.random.permutation(3)[0]
        source = self.datasets[which]
        support_x, support_y, query_x, query_y = source.gen_random_task()  # type: ignore

        support_x.n_way = self.n_way_all[which]
        support_x.k_shot = self.k_shot_all[which]

        return support_x, support_y, query_x, query_y
