import pickle
from pathlib import Path

import torch
from torch.utils.data import DataLoader
import numpy as np

from datasets.base_dataset import BaseDataset
from utils.functions import seed_worker


def gen_data_line(dataset_len=1000, seed=42):
    """Sample points in [-1, 1]^2 labelled by which side of the diagonal they lie on.

    The labelling function is::

                ┌
                │ 1, x_0 > x_1
         ƒ(x) = │
                │ 0, x_0 <= x_1
                └

    The diagonal x_0 = x_1 splits the square into two halves of equal area,
    so the classes are balanced in expectation.

    Parameters
    ----------
    dataset_len, optional
        Number of points to generate, by default 1000
    seed, optional
        Value passed to ``np.random.seed`` before sampling, by default 42.
        Note that this reseeds NumPy's *global* legacy RNG.

    Returns
    -------
        ``(x, y)``: ``x`` has shape ``(dataset_len, 2)``; ``y`` is an integer
        array of shape ``(dataset_len,)`` with values in {0, 1}.
    """
    # The global RNG is reseeded (rather than using a local Generator) to keep
    # the exact same draws, and the same global RNG state afterwards, as before.
    np.random.seed(seed)
    x = np.random.uniform(low=-1, high=1, size=(dataset_len, 2))
    # bool -> int labels: 1 where the first coordinate exceeds the second.
    y = (x[:, 0] > x[:, 1]) * 1

    return x, y


def gen_data_ellipse(dataset_len=1000, seed=42):
    """Sample points in [-1, 1]^2 labelled by whether they lie outside an ellipse.

    The labelling function is::

                ┌
                │ 1, x_0^2 + (x_1/k)^2 - 1 >= 0
         ƒ(x) = │
                │ 0, x_0^2 + (x_1/k)^2 - 1 < 0
                └

    with ``k = 2/π``. The ellipse has semi-axes 1 and k (both <= 1, so it lies
    inside the square). Its area is π·1·k = 2, exactly half of the square's
    area (4), so the classes are balanced in expectation.

    Parameters
    ----------
    dataset_len, optional
        Number of points to generate, by default 1000
    seed, optional
        Value passed to ``np.random.seed`` before sampling, by default 42.
        Note that this reseeds NumPy's *global* legacy RNG.

    Returns
    -------
        ``(x, y)``: ``x`` has shape ``(dataset_len, 2)``; ``y`` is an integer
        array of shape ``(dataset_len,)`` with values in {0, 1}.
    """
    # Reseed the global RNG so the draws match the historical seed-42 data.
    np.random.seed(seed)
    k = 2 / np.pi
    x = np.random.uniform(low=-1, high=1, size=(dataset_len, 2))
    # Signed "distance" to the ellipse boundary: >= 0 on/outside, < 0 inside.
    y = x[:, 0] ** 2 + (x[:, 1] / k) ** 2 - 1
    y = (y >= 0) * 1

    return x, y


def gen_data_hypersphere(dataset_len=1000, seed=42, x_dim=10):
    """Sample points on two concentric hyperspheres, one per class.

    Class 0 points lie on the sphere of radius 1, class 1 points on the sphere
    of radius 2. Each class gets ``dataset_len // 2`` points; the returned
    arrays hold all class-0 points first, then all class-1 points.

    Every angle is drawn uniformly from [0, 2π). The resulting point always has
    norm equal to the radius. However, the points are not uniformly
    distributed over the sphere's surface.

    Parameters
    ----------
    dataset_len, optional
        Total number of points, by default 1000. Must be even.
    seed, optional
        Value passed to ``np.random.seed`` before sampling, by default 42.
        Note that this reseeds NumPy's *global* legacy RNG.
    x_dim, optional
        Dimensionality of each point, by default 10.

    Returns
    -------
        ``(x, y)``: ``x`` has shape ``(dataset_len, x_dim)``; ``y`` is an integer
        array of shape ``(dataset_len,)`` with values in {0, 1}.

    Raises
    ------
    ValueError
        If ``dataset_len`` is odd (the classes could not be equally sized).
    """

    def convert_to_cartesian(angles, radius, x_dim):
        """Map ``x_dim - 1`` hyperspherical angles plus a radius to a Cartesian point.

        The mapping is x_i = r·cos(a_i)·∏_{j<i} sin(a_j) for i < x_dim - 1,
        and the last coordinate is r·∏_j sin(a_j). The squares telescope, so
        the point's Euclidean norm equals ``radius``.
        """
        point = []

        for i in range(x_dim - 1):
            # np.prod of an empty list is 1.0, so x_0 = r * cos(a_0).
            x_i = radius * np.cos(angles[i]) * np.prod([np.sin(angles[j]) for j in range(0, i)])
            point.append(x_i)

        # last dim of the point
        x_d = radius * np.prod([np.sin(angles[j]) for j in range(x_dim - 1)])
        point.append(x_d)

        return point

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if dataset_len % 2 != 0:
        raise ValueError(f"dataset_len must be even to split into 2 equal classes, got {dataset_len}")

    np.random.seed(seed)

    radii = [1, 2]  # class c lives on the sphere of radius radii[c]
    num_points_per_class = dataset_len // 2
    x = []
    y = []

    for c, radius in enumerate(radii):
        for _ in range(num_points_per_class):
            # use hyperspherical coordinates to generate the points
            angles = np.random.uniform(low=0, high=2 * np.pi, size=x_dim - 1)
            x.append(convert_to_cartesian(angles, radius, x_dim))
            y.append(c)

    x = np.array(x)
    y = np.array(y)
    return x, y


# Registry of synthetic datasets, keyed by the `type` name used throughout this
# module. Each value is `[generator, dataset_shape]`, where
# `generator(dataset_len=...)` returns `(x, y)` and `dataset_shape` is
# `[input_dim, num_classes]`.
GEN_DATA = dict(
    line=[gen_data_line, [2, 2]],
    ellipse=[gen_data_ellipse, [2, 2]],
    hypersphere=[gen_data_hypersphere, [10, 2]],
)


class SynthDatasetWrapper(BaseDataset):
    """Synthetic binary-classification dataset backed by an on-disk pickle cache.

    The raw ``(x, y)`` arrays come from the generator registered under ``type``
    in ``GEN_DATA``. They are cached under ``<dataset_path>/datasets``.
    ``bake_dataset`` then optionally adds Gaussian noise and one-hot encodes
    the labels, samples rows, and optionally shuffles a fraction of the labels.
    """

    def __init__(
        self,
        dataset_path: Path,
        is_train: bool,
        noise_scale: float,
        one_hot_y: bool,
        dataset_len: int,
        alpha_shuffle: float = 0.0,
        type: str = "line",
        **kwargs,
    ):
        # All arguments are forwarded unchanged; presumably BaseDataset.__init__
        # calls `bake_dataset` with them (not visible here — TODO confirm).
        super().__init__(
            dataset_path=dataset_path,
            is_train=is_train,
            noise_scale=noise_scale,
            one_hot_y=one_hot_y,
            dataset_len=dataset_len,
            alpha_shuffle=alpha_shuffle,
            type=type,
            **kwargs,
        )

    def bake_dataset(
        self,
        dataset_path: Path,
        is_train: bool,
        noise_scale: float,
        one_hot_y: bool,
        dataset_len: int = 1000,
        alpha_shuffle: float = 0.0,
        type: str = "line",
        **kwargs,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor | torch.LongTensor]:
        """Load (generating and caching if needed) and post-process the dataset.

        Parameters
        ----------
        dataset_path
            Root directory; the cache lives in ``dataset_path / "datasets"``.
        is_train
            Split flag; only used to name the cache file.
        noise_scale
            Std of the Gaussian noise added to ``x``; ``0.0`` disables noise.
        one_hot_y
            If True, labels are one-hot encoded (2 classes) and returned as a
            float tensor; otherwise they are returned as a long tensor.
        dataset_len, optional
            Number of samples, by default 1000.
        alpha_shuffle, optional
            Fraction passed to ``incremental_shuffle`` for the labels; ``0.0``
            disables it.
        type, optional
            Key into ``GEN_DATA``, by default "line".

        Returns
        -------
            ``(x, y)`` tensors.

        Raises
        ------
        KeyError
            If ``type`` is not a registered dataset.
        """
        if type not in GEN_DATA:
            raise KeyError(f"Unknown synthetic dataset type {type!r}; expected one of {sorted(GEN_DATA)}")

        ds_path = dataset_path / "datasets"
        cache_file = ds_path / f"synth_{type}_{is_train}_{dataset_len}.pkl"

        # gen the dataset if not done already
        if not cache_file.exists():
            ds_path.mkdir(parents=True, exist_ok=True)
            data = GEN_DATA[type][0](dataset_len=dataset_len)
            with cache_file.open(mode="wb") as f:
                pickle.dump(data, f)

        # pickle.load can execute arbitrary code: only use cache directories
        # whose contents are trusted (normally written by the branch above).
        with cache_file.open(mode="rb") as f:
            x, y = pickle.load(f)

        # Seed only once the cache is in place. The generators reseed and
        # consume the global RNG, so seeding before them would make the noise,
        # subset and shuffle below depend on whether the cache already existed.
        np.random.seed(42)

        # NOTE(review): the generators in this module seed NumPy with a fixed
        # value regardless of `is_train`. As a result, the test split's raw
        # points coincide with the leading rows of the train split (for
        # "line"/"ellipse" the test set is exactly the first `dataset_len` train
        # points). Consider a split-dependent seed.

        # add noise
        if noise_scale > 0.0:
            noise = noise_scale * np.random.normal(size=x.shape)
            x = (x + noise).astype(np.float32)

        # one hot
        if one_hot_y:
            y = self.onehot_vector(y, 2)

        # sample `dataset_len` rows without replacement (a permutation when the
        # cache holds exactly `dataset_len` rows)
        subset_indices = np.random.choice(len(x), size=dataset_len, replace=False)
        x = x[subset_indices]
        y = y[subset_indices]

        # shuffle a fraction of the labels
        if alpha_shuffle > 0.0:
            y = self.incremental_shuffle(y, alpha_shuffle)

        # convert to tensor
        x_tensor = torch.Tensor(x)
        if one_hot_y:
            # if one_hot = True => MSE loss => float tensor
            return x_tensor, torch.Tensor(y).type(torch.FloatTensor)
        # if one_hot = False => CE loss => long tensor
        return x_tensor, torch.Tensor(y).type(torch.LongTensor)


def load_datasets(
    dataset_path: Path,
    batch_size: int = 512,
    batch_size_test: int = 512,
    noise_scale: float = 0.0,
    one_hot_encode_y: bool = True,
    alpha_shuffle: float = 0.0,  # 0.0 = no shuffle, 1.0 = shuffle 100% of the data
    train_len: int = 1000,
    test_len: int = 500,
    num_workers: int = 0,
    type: str = "line",
    **kwargs,
) -> tuple[BaseDataset, BaseDataset, DataLoader, DataLoader, list]:
    """Get the SynthDataset train/test splits and their dataloaders.

    Parameters
    ----------
    dataset_path
        Path to the dataset root (the cache lives in ``dataset_path / "datasets"``)
    batch_size, optional
        Train set batch size, by default 512
    batch_size_test, optional
        Test set batch size, by default 512
    noise_scale, optional
        Sigma value for the Gaussian noise added to the train set X, by default 0.0.
        The test set is always noise-free.
    one_hot_encode_y, optional
        Boolean flag to one-hot encode y, by default True
    alpha_shuffle, optional
        The fraction of train labels that will be shuffled, by default 0.0.
        Test labels are never shuffled.
    train_len, optional
        Length of the train dataset, by default 1000
    test_len, optional
        Length of the test dataset, by default 500
    num_workers, optional
        Number of dataloader workers, by default 0
    type, optional
        Name of the synthetic dataset, by default "line".
        Possible types: "line", "ellipse", "hypersphere".

    Returns
    -------
        Returns the values in the following order: `train_dataset`, `test_dataset`,
        `train_dataloader`, `test_dataloader`, `dataset_dims`, where
        `dataset_dims` is a fresh `[input_dim, num_classes]` list.

    Raises
    ------
    KeyError
        If ``type`` is not a registered dataset.
    """
    # Fail fast with a clear message before any dataset work is done.
    if type not in GEN_DATA:
        raise KeyError(f"Unknown synthetic dataset type {type!r}; expected one of {sorted(GEN_DATA)}")

    train_dataset = SynthDatasetWrapper(
        dataset_path, True, noise_scale, one_hot_encode_y, train_len, alpha_shuffle, type, **kwargs
    )
    test_dataset = SynthDatasetWrapper(
        dataset_path, False, 0.0, one_hot_encode_y, test_len, 0.0, type, **kwargs
    )

    g = torch.Generator()
    g.manual_seed(0)

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        worker_init_fn=seed_worker,
        generator=g,
    )
    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size_test,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        worker_init_fn=seed_worker,
        generator=g,
    )

    # Return a copy so callers cannot mutate the shared GEN_DATA registry entry.
    return train_dataset, test_dataset, train_dataloader, test_dataloader, list(GEN_DATA[type][1])
