from __future__ import annotations
from typing import Any, List, Literal, Optional
import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import QuantileTransformer
from tsbench.experiments.metrics import Performance


class PerformanceTransformer(TransformerMixin):
    """
    The performance transformer transforms performances into model outputs for supervised learning
    as well as model outputs back into performance objects.

    Internally, performances are first encoded into a NumPy array by a `PerformanceEncoder` and then,
    optionally, passed through a scikit-learn `QuantileTransformer`.
    """

    def __init__(
        self, transform: Optional[Literal["quantile"]] = None, metrics: Optional[List[str]] = None
    ):
        """
        Args:
            transform: The kind of transform to apply to all performance metrics. If `None`, no
                transform is applied. Only `"quantile"` is recognized; any other value is treated
                like `None`.
            metrics: The performance metrics to transform. If `None`, all performance metrics are
                transformed. If provided, only these metrics are encoded, and all other metrics
                seen during fitting are set to NaN when converting model outputs back into
                performance objects.
        """
        self.encoder = PerformanceEncoder(metrics)

        # The pipeline mixes the custom encoder with scikit-learn transformers.
        steps: List[Any] = [self.encoder]
        if transform == "quantile":
            steps.append(QuantileTransformer())

        self.pipeline = make_pipeline(*steps)

    @property
    def feature_names_(self) -> List[str]:
        """
        Returns the feature names for the columns of the transformed performance objects.

        Only available after the transformer has been fitted.
        """
        return self.encoder.feature_names_

    @property
    def features_names_(self) -> List[str]:
        """
        Alias of `feature_names_`, kept for backwards compatibility.
        """
        return self.encoder.feature_names_

    def fit(self, y: List[Performance]) -> PerformanceTransformer:
        """
        Uses the provided performances to fit the performance transformer.

        Args:
            y: The performance objects.

        Returns:
            The fitted transformer itself.
        """
        self.pipeline.fit(y)
        return self

    def transform(self, y: List[Performance]) -> np.ndarray:
        """
        Transforms the provided performance objects into a NumPy array according to the fitted
        transformer.

        Args:
            y: The performance objects.

        Returns:
            An array of shape [N, K] of transformed performance objects (N: the number of
                performance objects, K: number of performance metrics).
        """
        return self.pipeline.transform(y)

    def inverse_transform(self, y: np.ndarray) -> List[Performance]:
        """
        Transforms the provided NumPy array back into performance objects according to the fitted
        transformer.

        Args:
            y: A NumPy array of shape [N, K] of performances (N: number of performances, K: number
                of performance metrics).

        Returns:
            The performance objects.
        """
        return self.pipeline.inverse_transform(y)


# -------------------------------------------------------------------------------------------------
# pylint: disable=missing-class-docstring,missing-function-docstring


class PerformanceEncoder:
    """
    Encodes performance objects as a NumPy array with one column per selected metric, and decodes
    such arrays back into performance objects.
    """

    def __init__(self, metrics: Optional[List[str]] = None):
        """
        Args:
            metrics: The metrics to encode. If `None`, all metrics found during fitting are encoded.
        """
        self.metrics = metrics
        # Both are set by `fit`: all metric columns of the fitted performances, and the subset of
        # them that is actually encoded (in column order of the encoded array).
        self.all_feature_names_: List[str]
        self.feature_names_: List[str]

    def fit(self, X: List[Performance], _y: Any = None) -> PerformanceEncoder:
        """
        Records the metric names of the provided performances.

        Args:
            X: The performance objects.
            _y: Ignored; present for scikit-learn compatibility.

        Returns:
            The fitted encoder itself.

        Raises:
            ValueError: If any of the requested metrics is not present in the performances.
        """
        df = Performance.to_dataframe(X)
        self.all_feature_names_ = list(df.columns)
        if self.metrics is None:
            self.feature_names_ = list(df.columns)
        else:
            # Raise explicitly instead of using `assert` so the check also runs under `python -O`.
            missing = [m for m in self.metrics if m not in df.columns]
            if missing:
                raise ValueError(
                    f"Metrics {missing} are not available; expected a subset of "
                    f"{list(df.columns)}."
                )
            # Copy so that later mutation of the caller's list does not affect the fitted encoder.
            self.feature_names_ = list(self.metrics)
        return self

    def transform(self, X: List[Performance], _y: Any = None) -> np.ndarray:
        """
        Returns the values of the fitted metrics for the provided performances, with columns
        ordered like `feature_names_`.

        Args:
            X: The performance objects.
            _y: Ignored; present for scikit-learn compatibility.
        """
        df = Performance.to_dataframe(X)
        return df[self.feature_names_].to_numpy()

    def inverse_transform(self, X: np.ndarray, _y: Any = None) -> List[Performance]:
        """
        Converts an array whose columns are ordered like `feature_names_` back into performance
        objects, one per row. Metrics seen during fitting but not encoded are set to NaN.

        Args:
            X: The encoded performances.
            _y: Ignored; present for scikit-learn compatibility.

        Returns:
            The performance objects.
        """
        encoded = set(self.feature_names_)
        # Iterate the fitted column list (not a set difference) for a deterministic column order.
        df = pd.DataFrame(X, columns=self.feature_names_).assign(
            **{col: np.nan for col in self.all_feature_names_ if col not in encoded}
        )
        return [Performance.from_dict(row.to_dict()) for _, row in df.iterrows()]
