from typing import List, Literal, Optional
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ProgressBar
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from tsbench.config import Config
from tsbench.experiments.metrics import Performance
from tsbench.experiments.tracking import Tracker
from .base import DatasetFeaturesMixin, Surrogate
from .registry import register_surrogate
from .torch import ListMLELoss, MLPLightningModule
from .transformers import ConfigTransformer, PerformanceTransformer


@register_surrogate("mlp")
class MLPSurrogate(Surrogate, DatasetFeaturesMixin):
    """
    The MLP surrogate predicts a model's performance on a new dataset using an MLP.

    Configurations are encoded into fixed-size feature vectors by a `ConfigTransformer`. An MLP
    then maps each vector to the (transformed) performance metrics. The MLP is trained with either
    a regression (MSE) or a ranking (ListMLE) loss.
    """

    def __init__(
        self,
        objective: Literal["regression", "ranking"] = "regression",
        discount: Optional[Literal["linear", "quadratic"]] = None,
        hidden_layer_sizes: Optional[List[int]] = None,
        weight_decay: float = 0.0,
        dropout: float = 0.0,
        use_simple_dataset_features: bool = True,
        use_seasonal_naive_performance: bool = False,
        use_catch22_features: bool = False,
        predict: Optional[List[str]] = None,
        tracker: Optional[Tracker] = None,
    ):
        """
        Args:
            objective: The training objective of the MLP. With "regression", the MLP is trained
                with a mean-squared-error loss. With "ranking", it is trained with a ListMLE loss;
                training batches carry per-dataset group IDs.
            discount: The discount to apply for the ranking loss. If provided, it focuses on
                correctly predicting the top values. Only used if `objective` is "ranking".
            hidden_layer_sizes: The dimensions of the hidden layers. Defaults to no hidden layers,
                i.e. a linear predictor. Consecutive linear layers are separated by a LeakyReLU
                activation.
            weight_decay: The weight decay to apply during optimization.
            dropout: The dropout probability of dropout layers applied after every activation
                function. No dropout layers are added if this is 0.
            use_simple_dataset_features: Whether to add simple dataset statistics to the input
                features.
            use_seasonal_naive_performance: Whether to use the Seasonal Naïve nCRPS as dataset
                features (documented as requiring the cacher to be set).
            use_catch22_features: Whether to use catch22 features for dataset statistics. The flag
                is forwarded unchanged to the `ConfigTransformer`.
            predict: The metrics to predict. Forwarded to the `PerformanceTransformer`. If not
                provided, its default (all metrics) is used.
            tracker: An optional tracker that can be used to impute latency and number of model
                parameters into model performances.

        Raises:
            ValueError: If `objective` is neither "regression" nor "ranking".
        """
        super().__init__(tracker)

        # Validate the objective eagerly. Otherwise `self.loss` would silently stay unset and
        # `fit` would fail later with an unrelated AttributeError.
        if objective == "regression":
            self.loss = nn.MSELoss()
        elif objective == "ranking":
            self.loss = ListMLELoss(discount=discount)
        else:
            raise ValueError(
                f"Unknown objective {objective!r}; expected 'regression' or 'ranking'."
            )

        self.use_ranking = objective == "ranking"
        # Copy so that later mutation of the caller's list cannot change the architecture.
        self.hidden_layer_sizes = list(hidden_layer_sizes or [])
        self.weight_decay = weight_decay
        self.dropout = dropout

        self.config_transformer = ConfigTransformer(
            add_model_features=True,
            add_dataset_statistics=use_simple_dataset_features,
            add_seasonal_naive_performance=use_seasonal_naive_performance,
            add_catch22_features=use_catch22_features,
            tracker=tracker,
        )
        self.performance_transformer = PerformanceTransformer(
            metrics=predict,
        )

        # Fitted properties
        self.trainer_: pl.Trainer
        self.model_: nn.Module

    def fit(self, X: List[Config], y: List[Performance]) -> None:
        """
        Fits the feature transformers and trains the MLP on the given configurations.

        Args:
            X: The configurations to train on.
            y: The observed performance for each configuration in `X`.

        Raises:
            ValueError: If `X` is empty.
        """
        if not X:
            raise ValueError("Cannot fit the MLP surrogate on an empty list of configurations.")

        # Fit transformers to infer dimensionality
        X_numpy = self.config_transformer.fit_transform(X)
        y_numpy = self.performance_transformer.fit_transform(y)

        input_dim = len(self.config_transformer.feature_names_)
        output_dim = len(self.performance_transformer.features_names_)

        # Initialize model: Linear layers, with LeakyReLU (and optional dropout) in between
        layer_sizes = [input_dim] + self.hidden_layer_sizes + [output_dim]
        layers: List[nn.Module] = []
        for i, (in_size, out_size) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            if i > 0:
                layers.append(nn.LeakyReLU())
                if self.dropout > 0:
                    layers.append(nn.Dropout(self.dropout))
            layers.append(nn.Linear(in_size, out_size))

        self.model_ = nn.Sequential(*layers)

        # Initialize the data. Each configuration also gets the integer ID of its dataset.
        # `dict.fromkeys` keeps first-seen order, so the IDs are reproducible across runs
        # (iterating a set may yield a different order from run to run).
        mapping = {d: i for i, d in enumerate(dict.fromkeys(x.dataset for x in X))}
        train_data = TensorDataset(
            torch.from_numpy(X_numpy).float(),
            torch.from_numpy(y_numpy).float(),
            torch.as_tensor([mapping[x.dataset] for x in X], dtype=torch.long),
        )
        # Full-batch training: the whole training set forms a single batch
        train_loader = DataLoader(train_data, batch_size=X_numpy.shape[0])

        # Train the model
        module = MLPLightningModule(self.model_, self.loss, self.weight_decay)
        self.trainer_ = pl.Trainer(
            max_epochs=1000,
            checkpoint_callback=False,
            logger=False,
            weights_summary=None,
            callbacks=[ProgressBar(refresh_rate=0)],
            gpus=int(torch.cuda.is_available()),
        )
        self.trainer_.fit(module, train_dataloaders=train_loader)

    def _predict(self, X: List[Config]) -> List[Performance]:
        """
        Predicts the performance of the given configurations with the trained MLP.

        Args:
            X: The configurations to predict performance for.

        Returns:
            One predicted performance per configuration, in the order of `X`. The result is empty
            if `X` is empty.
        """
        # A DataLoader cannot be built with batch size 0, so short-circuit empty input
        if not X:
            return []

        # Get data
        X_numpy = self.config_transformer.transform(X)
        test_data = TensorDataset(torch.from_numpy(X_numpy).float())
        test_loader = DataLoader(test_data, batch_size=X_numpy.shape[0])

        # Run prediction
        module = MLPLightningModule(self.model_, self.loss)
        out = self.trainer_.predict(module, test_loader)
        # Concatenate all returned batches and move them to host memory. `.numpy()` fails on
        # CUDA tensors, and the trainer uses a GPU whenever one is available (see `fit`).
        predictions = torch.cat([batch.detach().cpu() for batch in out]).numpy()

        return self.performance_transformer.inverse_transform(predictions)
