from typing import Union

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.utils._encode import _check_unknown, _nandict
from sklearn.utils.validation import check_is_fitted, column_or_1d, _num_samples

# Names exported by a star-import of this module: scikit-learn's LabelEncoder
# (re-exported unchanged) and the NaN-aware subclass defined below.
__all__ = ['LabelEncoder', 'LabelEncoderWithNaN']


def _valid_ind(y: pd.Series):
    """
    Helper function to return a boolean mask of valid (non-NaN) entries in a pandas Series.

    Args:
        y (pd.Series): The input series to check for NaN values.

    Returns:
        pd.Series: Boolean mask indicating non-NaN entries.
    """
    return y.fillna('') != ''


def _valid_dtype(dtype):
    if dtype == bool:
        return int
    return dtype


def _map_to_integer(values, uniques):
    """Map values based on its position in uniques."""
    table = _nandict({val: i for i, val in enumerate(uniques)})
    return np.array([table.get(v, np.nan) for v in values])


class LabelEncoderWithNaN(LabelEncoder):
    """
    Extension of the standard LabelEncoder that passes missing values through as NaN.

    Missing entries (as identified by ``_valid_ind``) are skipped while fitting and come out of
    ``transform`` as NaN instead of receiving a class index. All other labels are compared by
    their string representation, so ``classes_`` holds strings; the dtype seen during ``fit`` is
    remembered and re-applied by ``inverse_transform``.

    Attributes:
        classes_ (np.ndarray): Distinct labels as strings, in the order produced by
            ``pd.Series.value_counts`` (most frequent first by default).
        _dtype: dtype of the non-missing labels passed to ``fit``.

    Methods:
        fit(y): Fits the label encoder, ignoring missing values.
        fit_transform(y): Fits the encoder and returns the transformed labels.
        transform(y): Transforms labels to integer codes, with missing values kept as NaN.
        inverse_transform(y): Converts codes back to original labels, with NaN mapped to pd.NA.
    """

    def fit(self, y):
        """
        Fit label encoder, ignoring missing values.

        Args:
            y (pd.Series or array-like): Input labels to fit the encoder on.

        Returns:
            self: Returns an instance of the fitted encoder.
        """
        y = pd.Series(y)
        y = y[_valid_ind(y)]
        # Remember the original dtype so inverse_transform can cast decoded strings back.
        self._dtype = y.dtype
        y = y.astype(_valid_dtype(self._dtype)).astype(str)
        # value_counts yields each distinct label once; the class order is its output order.
        self.classes_ = y.value_counts().index.to_numpy()
        return self

    def fit_transform(self, y):
        """
        Fit label encoder and return the encoded labels.

        Args:
            y (pd.Series or array-like): Input labels to fit the encoder on and transform.

        Returns:
            np.ndarray: Encoded labels, with missing values kept as NaN.
        """
        self.fit(y)
        return self.transform(y)

    def transform(self, y):
        """
        Transform labels to integer codes, with missing values kept as NaN.

        Labels that were not seen during ``fit`` are also encoded as NaN.

        Args:
            y (pd.Series or array-like): Labels to encode.

        Returns:
            np.ndarray: Float array of codes with the same length as ``y``.

        Raises:
            sklearn.exceptions.NotFittedError: If the encoder has not been fitted.
        """
        # Check up front so an unfitted encoder fails with NotFittedError rather than
        # an AttributeError on ``_dtype`` below.
        check_is_fitted(self)
        y = pd.Series(y)
        valid_ind = _valid_ind(y)
        y_ = y[valid_ind].astype(_valid_dtype(self._dtype)).astype(str)
        result = np.full(y.shape, fill_value=np.nan)
        # Force a plain boolean ndarray: a mask backed by a nullable pandas dtype may
        # otherwise convert to an object array, which NumPy rejects as a boolean index.
        result[valid_ind.to_numpy(dtype=bool)] = self._transform(y_)
        return result

    def _transform(self, y):
        """
        Encode already-cleaned labels (no missing values, cast to str).

        Args:
            y (array-like of shape (n_samples,)): Labels to encode.

        Returns:
            np.ndarray: Codes for ``y``; an empty array when ``y`` is empty.
        """
        check_is_fitted(self)
        y = column_or_1d(y, dtype=self.classes_.dtype, warn=True)
        # transform of empty array is empty array
        if _num_samples(y) == 0:
            return np.array([])

        return self._encode(y, uniques=self.classes_)

    @staticmethod
    def _encode(values, *, uniques, check_unknown=True):
        """
        Encode ``values`` as positions in ``uniques``.

        Args:
            values (np.ndarray): Labels to encode.
            uniques (np.ndarray): Known labels; must be sorted for non-string dtypes because
                that branch relies on ``np.searchsorted``.
            check_unknown (bool): For non-string dtypes, whether to verify that every value is
                present in ``uniques`` before encoding.

        Returns:
            np.ndarray: Encoded values.

        Raises:
            ValueError: If unknown labels are detected.
        """
        if values.dtype.kind in "OUS":
            # Object / unicode / bytes labels use a dictionary lookup. _map_to_integer as
            # defined in this module looks labels up with ``.get`` and yields NaN for unknown
            # ones, so the KeyError guard is only a safety net.
            try:
                return _map_to_integer(values, uniques)
            except KeyError as e:
                raise ValueError(f"y contains previously unseen labels: {str(e)}") from e
        else:
            if check_unknown:
                diff = _check_unknown(values, uniques)
                if diff:
                    raise ValueError(f"y contains previously unseen labels: {str(diff)}")
            return np.searchsorted(uniques, values)

    def inverse_transform(self, y: Union[np.ndarray, torch.Tensor]):
        """
        Convert integer codes back to original labels, mapping NaN to ``pd.NA``.

        Codes outside ``[0, n_classes - 1]`` are clipped into that range and fractional codes are
        truncated, so every non-NaN input decodes to some class.

        Args:
            y (np.ndarray, torch.Tensor or array-like): Encoded labels to be converted back.

        Returns:
            np.ndarray: Object array of decoded labels, with ``pd.NA`` where the code was NaN.

        Raises:
            sklearn.exceptions.NotFittedError: If the encoder has not been fitted.
        """
        check_is_fitted(self)
        n_classes = len(self.classes_)
        if n_classes == 0:
            # Nothing was learned during fit, so nothing can be decoded.
            return np.full(len(y), pd.NA, dtype=object)
        if isinstance(y, torch.Tensor):
            y = y.detach().cpu().numpy()
        # np.array always copies, so the clipping below never alters the caller's data;
        # it also lets plain sequences through.
        codes = np.array(y, dtype=float)
        if codes.size == 0:
            # Short-circuit: the parent class returns a float array for empty input,
            # which cannot hold pd.NA.
            return np.array([], dtype=object)
        nan_ind = np.isnan(codes)
        # NaN positions get a placeholder code 0 here and are overwritten with pd.NA below.
        codes = np.clip(np.nan_to_num(codes, nan=0), 0, n_classes - 1).astype(int)
        decoded = np.asarray(super().inverse_transform(codes)).astype(object)
        decoded[nan_ind] = pd.NA
        # Undo the str cast applied in fit/transform (bool goes through int, mirroring
        # _valid_dtype).
        # NOTE(review): ndarray.astype accepts NumPy dtypes only; a fitted pandas extension
        # dtype (e.g. category) would presumably fail here - confirm.
        decoded[~nan_ind] = decoded[~nan_ind].astype(_valid_dtype(self._dtype)).astype(self._dtype)
        return decoded
