from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from typing import Any, Dict, List, Optional, Set

import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin
from sklearn.feature_extraction import DictVectorizer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline, make_union
from sklearn.preprocessing import MinMaxScaler, PowerTransformer, StandardScaler

from tsbench.config import Config
from tsbench.config.model.models import SeasonalNaiveModelConfig
from tsbench.experiments.tracking import Tracker


class ConfigTransformer(TransformerMixin):
    """
    The config transformer transforms a configuration (model + dataset) into a real-valued vector.

    The vector is the concatenation of the outputs of every enabled feature encoder, in the order
    model features, dataset statistics, Seasonal Naïve performance, catch22 features.
    """

    def __init__(
        self,
        add_model_features: bool = True,
        add_dataset_statistics: bool = True,
        add_seasonal_naive_performance: bool = False,
        add_catch22_features: bool = False,
        tracker: Optional[Tracker] = None,
    ):
        """
        Args:
            add_model_features: Whether a one-hot encoding of the model type as well as model
                hyperparameters should be added.
            add_dataset_statistics: Whether simple dataset statistics ought to be added.
            add_seasonal_naive_performance: Whether to add the nCRPS performance of Seasonal Naïve.
                Requires the tracker to be set.
            add_catch22_features: Whether a dataset's catch22 features ought to be added.
            tracker: An optional tracker to obtain the performance of Seasonal Naïve.

        Raises:
            AssertionError: If no feature group is enabled, or if the Seasonal Naïve performance
                is requested without providing a tracker.
        """
        assert any(
            [
                add_model_features,
                add_dataset_statistics,
                add_seasonal_naive_performance,
                add_catch22_features,
            ]
        ), "ConfigTransformer must be given at least some group of features."
        assert (
            not add_seasonal_naive_performance or tracker is not None
        ), "Tracker must be set if seasonal naive performance is used."

        # The order in which encoders are appended determines the column order of the output and
        # of `feature_names_`.
        self.encoders = []
        if add_model_features:
            self.encoders.append(ModelEncoder())
        if add_dataset_statistics:
            self.encoders.append(DatasetStatisticsEncoder())
        # The explicit `tracker is not None` check repeats the assertion above so that the guard
        # still holds when assertions are disabled.
        if add_seasonal_naive_performance and tracker is not None:
            self.encoders.append(SeasonalNaivePerformanceEncoder(tracker))
        if add_catch22_features:
            self.encoders.append(DatasetCatch22Encoder())

        self.pipeline = make_union(*self.encoders)

    @property
    def feature_names_(self) -> List[str]:
        """
        Returns the feature names for the columns of transformed configurations.

        The names of all encoders are concatenated in the order in which the encoders were added.
        """
        return [f for e in self.encoders for f in e.feature_names_]

    def fit(self, X: List[Config], _y: Any = None) -> ConfigTransformer:
        """
        Uses the provided configurations to fit the transformer pipeline.

        Args:
            X: The input configurations.
            _y: Ignored. Accepted so that the transformer can be fitted through interfaces which
                pass a target along with the inputs (e.g. `fit_transform(X, y)`).

        Returns:
            The fitted transformer itself.
        """
        self.pipeline.fit(X)
        return self

    def transform(self, X: List[Config]) -> np.ndarray:
        """
        Transforms the given configurations according to the fitted transformer pipeline.

        Args:
            X: The input configurations.

        Returns:
            A NumPy array of shape [N, D]. N is the number of input configurations, D the dimension
                of the vectorized representation.
        """
        return self.pipeline.transform(X)


# -------------------------------------------------------------------------------------------------
# pylint: disable=missing-class-docstring,missing-function-docstring


class Encoder(ABC):
    """
    Base class of all feature encoders. Subclasses must expose the names of the columns they
    produce via `feature_names_`.
    """

    # `abc.abstractproperty` is deprecated; stacking `property` on top of `abstractmethod` is the
    # supported way of declaring an abstract read-only property.
    @property
    @abstractmethod
    def feature_names_(self) -> List[str]:
        """
        Returns the names of the features generated by this encoder.
        """


class NanImputer:
    """
    Stateless transformer which round-trips a list of dictionaries through a pandas data frame.

    After the round trip, every returned record carries the union of all keys; entries which a
    record did not originally provide hold pandas' missing-value marker.
    """

    def fit(self, _X: List[Dict[str, Any]], _y: Any = None) -> NanImputer:
        # Nothing is learned from the data.
        return self

    def transform(self, X: List[Dict[str, Any]], _y: Any = None) -> List[Dict[str, Any]]:
        # One data frame row per input record, converted back to one dictionary per row.
        return pd.DataFrame(X).to_dict(orient="records")


class Selector:
    """
    Stateless transformer which filters the keys of dictionaries, either by keeping only the keys
    in `use` or by dropping the keys in `ignore`.
    """

    def __init__(self, use: Optional[Set[str]] = None, ignore: Optional[Set[str]] = None):
        exactly_one_given = bool(use) != bool(ignore)
        assert exactly_one_given, "One of `use` or `ignore` must be set."

        self.use = use if use else set()
        self.ignore = ignore if ignore else set()

    def fit(self, _X: List[Dict[str, Any]], _y: Any = None) -> Selector:
        # Nothing is learned from the data.
        return self

    def transform(self, X: List[Dict[str, Any]], _y: Any = None) -> List[Dict[str, Any]]:
        filtered = []
        for item in X:
            filtered.append({key: value for key, value in item.items() if self._keeps(key)})
        return filtered

    def _keeps(self, key: str) -> bool:
        # An empty `use` or `ignore` set imposes no restriction.
        if self.use and key not in self.use:
            return False
        if self.ignore and key in self.ignore:
            return False
        return True


# -------------------------------------------------------------------------------------------------


class ModelEncoder(Encoder):
    """
    Encodes the model part of a configuration, using the dictionary returned by
    `config.model.asdict()`.
    """

    def __init__(self):
        self.model_vectorizer = DictVectorizer(dtype=np.float32, sparse=False, sort=True)
        self.hp_vectorizer = DictVectorizer(dtype=np.float32, sparse=False, sort=True)

        # First branch: only the "model" entry is vectorized.
        model_branch = make_pipeline(
            Selector(use={"model"}),
            self.model_vectorizer,
        )
        # Second branch: all remaining entries are vectorized, mean-imputed and standardized.
        hyperparameter_branch = make_pipeline(
            Selector(ignore={"model"}),
            self.hp_vectorizer,
            SimpleImputer(strategy="mean"),
            StandardScaler(),
        )
        # The dictionaries pass through the `NanImputer` before being handed to both branches.
        self.pipeline = make_pipeline(
            NanImputer(),
            make_union(model_branch, hyperparameter_branch),
        )

    @property
    def feature_names_(self) -> List[str]:
        # Same order as the branches of the union: model columns first.
        model_names = self.model_vectorizer.feature_names_
        hyperparameter_names = self.hp_vectorizer.feature_names_
        return model_names + hyperparameter_names

    def fit(self, X: List[Config], _y: Any = None) -> ModelEncoder:
        self.pipeline.fit(self._model_dicts(X))
        return self

    def transform(self, X: List[Config], _y: Any = None) -> np.ndarray:
        return self.pipeline.transform(self._model_dicts(X))

    @staticmethod
    def _model_dicts(X: List[Config]) -> List[Dict[str, Any]]:
        # One dictionary per configuration, describing its model.
        return [config.model.asdict() for config in X]


class DatasetStatisticsEncoder(Encoder):
    """
    Encodes the dataset part of a configuration, using the dictionary returned by
    `config.dataset.stats()`.
    """

    def __init__(self):
        self.vectorizer = DictVectorizer(dtype=np.float32, sparse=False, sort=True)

        # The "integer_dataset" entry is vectorized on its own and is not rescaled.
        integer_branch = make_pipeline(
            Selector(use={"integer_dataset"}),
            DictVectorizer(dtype=np.float32, sparse=False, sort=True),
        )
        # All remaining statistics are vectorized, min-max scaled and power-transformed.
        statistics_branch = make_pipeline(
            Selector(ignore={"integer_dataset"}),
            self.vectorizer,
            MinMaxScaler(),  # required for numerical stability
            PowerTransformer(),
        )
        self.pipeline = make_union(integer_branch, statistics_branch)

    @property
    def feature_names_(self) -> List[str]:
        # Same order as the branches of the union: "integer_dataset" comes first.
        names = ["integer_dataset"]
        return names + self.vectorizer.feature_names_

    def fit(self, X: List[Config], _y: Any = None) -> DatasetStatisticsEncoder:
        self.pipeline.fit(self._dataset_stats(X))
        return self

    def transform(self, X: List[Config], _y: Any = None) -> np.ndarray:
        return self.pipeline.transform(self._dataset_stats(X))

    @staticmethod
    def _dataset_stats(X: List[Config]) -> List[Dict[str, Any]]:
        # One statistics dictionary per configuration.
        return [config.dataset.stats() for config in X]


class SeasonalNaivePerformanceEncoder(Encoder):
    """
    Encodes a configuration by the standardized performance which the tracker reports for a
    `SeasonalNaiveModelConfig` on the configuration's dataset.
    """

    def __init__(self, tracker: Tracker):
        self.tracker = tracker
        self.scaler = StandardScaler()

    @property
    def feature_names_(self) -> List[str]:
        return ["seasonal_naive_ncrps"]

    def fit(self, X: List[Config], _y: Any = None) -> SeasonalNaivePerformanceEncoder:
        performances = self._get_performance_array(X)
        self.scaler.fit(performances)
        return self

    def transform(self, X: List[Config], _y: Any = None) -> np.ndarray:
        performances = self._get_performance_array(X)
        return self.scaler.transform(performances)

    def _get_performance_array(self, X: List[Config]) -> np.ndarray:
        # Look up one value per configuration, in input order.
        losses = []
        for config in X:
            baseline = Config(SeasonalNaiveModelConfig(), config.dataset)
            performance = self.tracker.get_performance(baseline)
            losses.append(performance.mean_weighted_quantile_loss.mean)
        # Append an axis so that each configuration occupies one row of the result.
        return np.array(losses)[:, np.newaxis]


class DatasetCatch22Encoder(Encoder):
    """
    Encodes the dataset of a configuration by the mean of the values returned by
    `dataset.catch22()`, vectorized and power-transformed.
    """

    def __init__(self):
        self.vectorizer = DictVectorizer(dtype=np.float32, sparse=False, sort=True)
        self.pipeline = make_pipeline(self.vectorizer, PowerTransformer())

    @property
    def feature_names_(self) -> List[str]:
        return self.vectorizer.feature_names_

    def fit(self, X: List[Config], _y: Any = None) -> DatasetCatch22Encoder:
        self.pipeline.fit(self._feature_dicts(X))
        return self

    def transform(self, X: List[Config], _y: Any = None) -> np.ndarray:
        return self.pipeline.transform(self._feature_dicts(X))

    def _feature_dicts(self, X: List[Config]) -> List[Dict[str, Any]]:
        # Compute the features once per distinct dataset, then fan them out to the
        # configurations in input order.
        per_dataset = {}
        for dataset in {config.dataset for config in X}:
            per_dataset[dataset] = dataset.catch22().mean().to_dict()
        return [per_dataset[config.dataset] for config in X]
