from typing import Optional, Union

import numpy as np
import torch
from numpy.random import RandomState
from sklearn.utils import check_random_state


class Dataset:
    """In-memory dataset backed by torch tensors, with optional label noise.

    Features are stored as a float32 tensor. Targets are stored as float32
    when the input ``y`` has a floating dtype (treated as regression) and as
    int64 (``torch.long``) otherwise (treated as classification).

    ``meta`` collects derived information:

    * ``"task"``: ``"regression"``, ``"binary_classif"`` or
      ``"multi_class_classif"``.
    * ``"classes"``: sorted distinct labels, for classification tasks only.
    * ``"original_y"`` and ``"noisy_idx"``: set by :meth:`nosify`.
    """

    def __init__(
        self,
        X: np.ndarray,
        y: np.ndarray,
        random_state: Optional[Union[RandomState, int]] = None,
    ):
        """Build the dataset and infer its task type.

        Args:
            X: Feature matrix; converted to a float32 tensor.
            y: Targets (any array-like). A floating dtype is treated as
                regression; any other dtype as class labels.
            random_state: Seed or ``RandomState`` used for all noise
                sampling; normalized with sklearn's ``check_random_state``.

        Raises:
            ValueError: If the task type cannot be inferred from ``y``.
        """
        self.meta = {}
        self.random_state = check_random_state(random_state)
        self.format_dataset(X, y)
        self.infer_task_type()

    def format_dataset(self, X, y):
        """Convert ``X`` and ``y`` to tensors and store them on the instance.

        ``X`` becomes float32. ``y`` becomes float32 if its NumPy dtype is
        floating, otherwise int64 (``torch.long``).
        """
        self.X = torch.tensor(X, dtype=torch.float32)
        # Accept lists and other array-likes, not only ndarrays.
        y = np.asarray(y)
        if np.issubdtype(y.dtype, np.floating):
            self.y = torch.tensor(y, dtype=torch.float32)
        else:
            self.y = torch.tensor(y, dtype=torch.long)

    def infer_task_type(self):
        """Set ``meta["task"]`` (and ``meta["classes"]`` for classification).

        Float targets -> ``"regression"``. Non-float targets with exactly two
        distinct values -> ``"binary_classif"``; with more than two ->
        ``"multi_class_classif"``.

        Raises:
            ValueError: If non-float targets have fewer than two distinct
                values.
        """
        y_np = self.y.numpy()
        if np.issubdtype(y_np.dtype, np.floating):
            self.meta["task"] = "regression"
            return

        unique_labels = np.unique(y_np)
        if len(unique_labels) < 2:
            raise ValueError(
                "Unable to determine task type: labels may not fit binary, multi-class, or regression task types."
            )
        if len(unique_labels) == 2:
            self.meta["task"] = "binary_classif"
        else:
            self.meta["task"] = "multi_class_classif"
        # Remember the observed label set so label flipping only produces
        # labels that actually occur in the data.
        self.meta["classes"] = unique_labels

    def get_len(self):
        """Return the number of samples (rows of ``X``)."""
        return len(self.X)

    def get_num_features(self):
        """Return the number of feature columns (``X.shape[1]``)."""
        return self.X.shape[1]

    def get_item(self, index):
        """Return the ``(features, target)`` pair at ``index``."""
        return self.X[index], self.y[index]

    def nosify(self, noise_type: str, noise_ratio: float, **kwargs):
        """Replace ``self.y`` with a copy in which a fraction of targets is noisy.

        Args:
            noise_type: ``"flip"`` (classification: replace a label with a
                different observed class) or ``"gauss"`` (regression: add
                zero-mean Gaussian noise).
            noise_ratio: Fraction of samples to corrupt, in ``[0, 1]``. The
                number of corrupted samples is ``int(noise_ratio * len)``.
            **kwargs: Forwarded to the Gaussian step (``std_dev``, default
                0.1). Ignored for ``"flip"``.

        The targets before this call are stored in ``meta["original_y"]`` and
        the corrupted row indices in ``meta["noisy_idx"]``. Nothing is
        recorded if the call is rejected.

        Raises:
            ValueError: If ``noise_type`` is unknown, incompatible with the
                task type, or ``noise_ratio`` is outside ``[0, 1]``.
        """
        # Validate everything before touching meta, so a rejected call cannot
        # overwrite a previously recorded "original_y".
        if noise_type == "flip":
            if self.meta["task"] not in ["binary_classif", "multi_class_classif"]:
                raise ValueError(
                    f"Cannot flip labels for task type: {self.meta['task']}"
                )
        elif noise_type == "gauss":
            if self.meta["task"] != "regression":
                raise ValueError(
                    f"Cannot apply Gaussian noise to task type: {self.meta['task']}"
                )
        else:
            raise ValueError(f"Unsupported noise type: {noise_type}")

        # Written as a negated range check so NaN is rejected too.
        if not 0.0 <= noise_ratio <= 1.0:
            raise ValueError(f"noise_ratio must be in [0, 1], got {noise_ratio}")

        self.meta["original_y"] = self.y
        if noise_type == "flip":
            self._flip_labels(noise_ratio)
        else:
            self._add_gaussian_noise(noise_ratio, **kwargs)

    def _sample_noisy_indices(self, noise_ratio: float) -> np.ndarray:
        """Draw ``int(noise_ratio * len)`` distinct row indices; record them in ``meta["noisy_idx"]``."""
        num_samples = self.get_len()
        num_noisy = int(noise_ratio * num_samples)
        indices = self.random_state.choice(num_samples, num_noisy, replace=False)
        self.meta["noisy_idx"] = indices
        return indices

    def _flip_labels(self, noise_ratio: float):
        """Replace a random ``noise_ratio`` fraction of labels with a different class."""
        indices_to_flip = self._sample_noisy_indices(noise_ratio)
        y_np = self.y.numpy().copy()

        classes = self.meta.get("classes")
        if classes is None:
            # Fallback for instances whose meta predates the "classes" key.
            classes = np.unique(y_np)

        if self.meta["task"] == "binary_classif":
            # Swap between the two observed labels; works for any label pair,
            # not only {0, 1}.
            low, high = classes
            y_np[indices_to_flip] = np.where(y_np[indices_to_flip] == low, high, low)

        elif self.meta["task"] == "multi_class_classif":
            for idx in indices_to_flip:
                # Draw uniformly among the observed classes other than the
                # current label (classes is sorted, like np.setdiff1d output).
                candidates = classes[classes != y_np[idx]]
                y_np[idx] = self.random_state.choice(candidates)

        self.y = torch.tensor(y_np, dtype=self.y.dtype)

    def _add_gaussian_noise(self, noise_ratio: float, **kwargs):
        """Add N(0, std_dev**2) noise to a random ``noise_ratio`` fraction of targets.

        ``std_dev`` is read from ``kwargs`` and defaults to 0.1.
        """
        std_dev = kwargs.get("std_dev", 0.1)
        indices_to_noise = self._sample_noisy_indices(noise_ratio)
        y = self.y.numpy().copy()
        # One draw per target value, so multi-output targets of shape (n, k)
        # work as well; for 1-D targets this is the same draw as size=k.
        noise = self.random_state.normal(
            loc=0, scale=std_dev, size=(len(indices_to_noise),) + y.shape[1:]
        )
        y[indices_to_noise] += noise

        self.y = torch.tensor(y, dtype=self.y.dtype)