import abc
from dataclasses import dataclass

import numpy as np
import pandas as pd
import reservoirpy as rpy
from datafold import TSCDataFrame
from reservoirpy.model import Model
from reservoirpy.nodes import Reservoir, Ridge

rpy.verbosity(0)  # set reservoirpy's global verbosity level to 0 (quiet) for all models below
from sklearn.pipeline import Pipeline
from swimnetworks import Dense

from computational_experiments.real_world.src.data_utils import Normalizer
from kirnn import KIRNN

# Largest int64 value; used as the exclusive upper bound in `rng.integers(MAX_SEED)`
# when drawing seeds for the randomly initialized layers/reservoirs below.
MAX_SEED = np.iinfo(np.int64).max


@dataclass
class BaseModel(abc.ABC):
    """Shared train / predict / evaluate scaffolding for the forecasting models.

    A ``Normalizer`` is fitted on the training data. All data handed to the
    subclass hooks ``_train`` and ``_predict`` is normalized with it, and
    predictions are mapped back to the original scale.

    Attributes:
        target: Column whose errors ``compute_metric`` reports.
        n_features: Number of input features (consumed by subclasses).
        time_delay: Number of leading rows used as initial state before the
            first prediction; ``compute_metric`` drops this many rows from the
            ground truth.
        horizon: Prediction horizon (consumed by subclasses).
        rng: Random generator subclasses draw seeds from.
    """

    target: str
    n_features: int
    time_delay: int
    horizon: int
    rng: np.random.Generator

    # Always replaced in __post_init__; declared as a defaulted field so the
    # generated __init__ signature stays unchanged.
    _normalizer: Normalizer = None

    def __post_init__(self):
        self._normalizer = Normalizer()

    def train(self, train_df: pd.DataFrame, val_df: pd.DataFrame):
        """Fit the normalizer on ``train_df``, then train on normalized data.

        ``val_df`` is normalized with the statistics of ``train_df`` so that no
        information from the validation set leaks into the normalization.
        """
        self._normalizer.fit(train_df)
        train_df = self._normalizer.transform(train_df)
        val_df = self._normalizer.transform(val_df)

        self._train(train_df, val_df)

    def predict(self, df: pd.DataFrame) -> pd.DataFrame:
        """Predict on ``df`` and return predictions in the original scale."""
        normalized_df = self._normalizer.transform(df)
        normalized_predictions = self._predict(normalized_df)
        predictions = self._normalizer.transform(normalized_predictions, inverse=True)
        return predictions

    def compute_metric(
            self, predicted_df: pd.DataFrame, true_df: pd.DataFrame, metric: str = "abs"
    ) -> pd.Series:
        """Compute the per-sample error of the ``target`` column.

        The first ``time_delay`` rows of ``true_df`` are dropped because no
        prediction exists for them.

        Args:
            predicted_df: Predictions as returned by ``predict``.
            true_df: Ground truth covering the same rows as the input that was
                passed to ``predict``.
            metric: ``"abs"`` for the signed error ``pred - true``. ``"rel"``
                for the signed error divided by ``max(|true|, |pred|)``; where
                both values are 0 the prediction is exact and the relative
                error is reported as 0.

        Returns:
            Series with one error value per predicted row.

        Raises:
            ValueError: If the target columns have different shapes, or if
                ``metric`` is unknown.
        """
        # .iloc makes dropping the warm-up rows positional for every index
        # dtype (plain integer slicing can become label-based).
        true_values = true_df[self.target].iloc[self.time_delay:]
        predicted_values = predicted_df[self.target]

        if true_values.shape != predicted_values.shape:
            raise ValueError(
                "True and predicted values have different shape: "
                f"{true_values.shape=}, "
                f"{predicted_values.shape=}"
            )

        if metric == "abs":
            return predicted_values - true_values
        if metric == "rel":
            error = predicted_values - true_values
            scale = np.maximum(np.abs(true_values), np.abs(predicted_values))
            # scale == 0 only when true == pred == 0 (an exact prediction);
            # report 0 instead of the NaN produced by 0 / 0.
            return (error / scale).mask(scale == 0, 0.0)

        raise ValueError(f"Unknown metric: {metric}.")

    @abc.abstractmethod
    def _train(self, train_df: pd.DataFrame, val_df: pd.DataFrame):
        """Train the underlying model on already-normalized data."""

    @abc.abstractmethod
    def _predict(self, df: pd.DataFrame):
        """Return normalized predictions for already-normalized ``df``."""


class SWIM(BaseModel):
    """Forecaster based on a KIRNN whose dictionary is a sampled ``Dense`` layer."""

    _model: KIRNN = None

    def __init__(
            self,
            target: str,
            n_features: int,
            time_delay: int,
            horizon: int,
            rng: np.random.Generator,
            layer_width: int,
            activation: str,
            regularization_scale: float,
            **__,
    ):
        """Build the KIRNN model.

        Args:
            layer_width: Width of the sampled hidden ``Dense`` layer.
            activation: Activation of the hidden layer; also selects its
                parameter sampler.
            regularization_scale: Passed to the KIRNN as ``rcond``.
            **__: Extra hyper-parameters are accepted and ignored, so all
                models can be constructed from one shared config.
        """
        super().__init__(
            target=target,
            n_features=n_features,
            time_delay=time_delay,
            horizon=horizon,
            rng=rng,
        )

        steps = [
            (
                "hidden",
                Dense(
                    layer_width=layer_width,
                    activation=activation,
                    parameter_sampler=activation,
                    random_seed=rng.integers(MAX_SEED),
                ),
            ),
        ]
        network_dictionary = Pipeline(steps=steps)
        # Passing 'time_delay-1' to have exactly 'time_delay' timestamps as initial state.
        kirnn_model = KIRNN(
            dictionary=network_dictionary,
            n_features_in=self.n_features,
            rcond=regularization_scale,
            time_delay=time_delay - 1,
        )
        self._model = kirnn_model

    def _train(self, train_df: pd.DataFrame, _: pd.DataFrame):
        """Fit the KIRNN on the (normalized) training data; validation data is unused.

        A ``LinAlgError`` during fitting is reported and otherwise ignored,
        which leaves the model unfitted.
        NOTE(review): confirm callers tolerate an unfitted model.
        """
        train_tsc = self._pandas_to_tsc(train_df)
        try:
            self._model.fit(train_tsc)
        except np.linalg.LinAlgError as e:
            print(f"LinAlgError occurred: {e}")

    def _predict(self, df: pd.DataFrame) -> pd.DataFrame:
        """Forecast ``df`` in chunks that each restart from observed data.

        Each chunk uses ``time_delay`` observed rows as the initial state and
        predicts up to ``horizon`` following rows. The next chunk starts
        ``horizon`` rows later. The result is indexed by ``df.index[time_delay:]``.

        Raises:
            ValueError: If ``horizon`` is smaller than 1. With such a value the
                chunk loop would never advance.
        """
        if self.horizon < 1:
            raise ValueError(f"horizon must be at least 1, got {self.horizon}.")

        output_index = df.index[self.time_delay:]
        if len(output_index) == 0:
            # Too few rows to predict anything; np.vstack would fail on [].
            return pd.DataFrame(columns=df.columns, index=output_index, dtype=float)

        tsc_df = self._pandas_to_tsc(df)
        chunk_size = self.time_delay + self.horizon

        chunk_start = 0
        predictions = []
        while chunk_start + self.time_delay < len(tsc_df):
            chunk = tsc_df[chunk_start: chunk_start + chunk_size]
            delayed = chunk[: self.time_delay]
            t_eval = chunk.time_values()[self.time_delay:]
            chunk_predictions = self._model.predict(delayed, time_values=t_eval)
            # The first prediction is the last value of the delayed chunk.
            predictions.append(chunk_predictions[1:])
            chunk_start += self.horizon

        # np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
        return pd.DataFrame(np.vstack(predictions), columns=df.columns, index=output_index)

    def _pandas_to_tsc(self, df: pd.DataFrame):
        """Wrap ``df`` as a single time series (id 0) in a ``TSCDataFrame``.

        The input is copied, so ``df`` itself is not modified.
        """
        df = df.copy()
        df["id"] = 0  # identify this as the only time series (assign id=0)
        df.set_index(["id", df.index], inplace=True)
        return TSCDataFrame(df)


class ESN(BaseModel):
    """Echo state network: a reservoirpy ``Reservoir`` followed by a ``Ridge`` readout.

    The readout is trained for one-step-ahead prediction. At prediction time
    it is run in closed loop, feeding each forecast back in as the next input.
    """

    warmup_steps: int
    _model: Model = None
    _reservoir: Reservoir = None
    _readout: Ridge = None

    def __init__(
            self,
            target: str,
            n_features: int,
            time_delay: int,
            horizon: int,
            rng: np.random.Generator,
            layer_width: int,
            activation: str,
            regularization_scale: float,
            **__,
    ):
        """Build the reservoir -> readout model.

        Args:
            layer_width: Number of reservoir units.
            activation: Reservoir activation.
            regularization_scale: Passed to the reservoir as ``lr``.
            **__: Extra hyper-parameters are accepted and ignored, so all
                models can be constructed from one shared config.
        """
        super().__init__(
            target=target,
            n_features=n_features,
            time_delay=time_delay,
            horizon=horizon,
            rng=rng,
        )

        # NOTE(review): reservoirpy's `lr` is the leak rate, yet it is fed from
        # `regularization_scale` while the ridge penalty below is fixed at
        # 1e-6. Confirm this mapping is intended.
        self._reservoir = Reservoir(layer_width,
                                    input_scaling=1,
                                    sr=0.1,
                                    lr=regularization_scale,
                                    rc_connectivity=0.5,
                                    input_connectivity=0.5,
                                    activation=activation,
                                    seed=rng.integers(MAX_SEED),
                                    )
        self._readout = Ridge(ridge=1e-6)
        self.warmup_steps = time_delay
        self._model = self._reservoir >> self._readout

    def _train(self, train_df: pd.DataFrame, _: pd.DataFrame):
        """Fit the readout for one-step-ahead prediction; validation data is unused.

        Rows ``0..n-2`` are the inputs and rows ``1..n-1`` the targets. The first
        ``warmup_steps`` steps are passed to reservoirpy as warm-up.
        """
        data = train_df.to_numpy()
        self._model.fit(data[:-1, :], data[1:, :], warmup=self.warmup_steps)

    def _predict(self, df: pd.DataFrame) -> pd.DataFrame:
        """Forecast ``df`` in chunks that each restart from observed data.

        For each chunk the reservoir is reset and warmed up on ``time_delay``
        observed rows. Then up to ``horizon`` rows are forecast in closed
        loop. The next chunk starts ``horizon`` rows later. The result is
        indexed by ``df.index[time_delay:]``.

        Raises:
            ValueError: If ``horizon`` is smaller than 1. With such a value the
                chunk loop would never advance.
        """
        if self.horizon < 1:
            raise ValueError(f"horizon must be at least 1, got {self.horizon}.")

        output_index = df.index[self.time_delay:]
        if len(output_index) == 0:
            # Too few rows to predict anything; np.vstack would fail on [].
            return pd.DataFrame(columns=df.columns, index=output_index, dtype=float)

        chunk_size = self.time_delay + self.horizon
        n_columns = df.shape[-1]

        chunk_start = 0
        predictions = []
        while chunk_start + self.time_delay < len(df):
            # .iloc: chunking is positional regardless of the index dtype.
            chunk = df.iloc[chunk_start: chunk_start + chunk_size]
            delayed = chunk.iloc[: self.time_delay]

            # Warm up the reservoir on observed data, starting from a fresh state.
            y = self._model.run(delayed.to_numpy(), reset=True)

            # The last chunk may be shorter than a full horizon.
            local_horizon = len(chunk) - len(delayed)
            chunk_predictions = np.zeros((local_horizon, n_columns))

            # The readout predicts one step ahead, so its output after the last
            # delayed row is already the forecast for the first row after it.
            chunk_predictions[0, :] = y[-1]
            # Closed loop: each forecast is the input for the next step.
            for i in range(local_horizon - 1):
                chunk_predictions[i + 1, :] = self._model(chunk_predictions[i, :])

            predictions.append(chunk_predictions)
            chunk_start += self.horizon

        # np.row_stack is a deprecated alias of np.vstack (NumPy 2.0).
        return pd.DataFrame(np.vstack(predictions), columns=df.columns, index=output_index)
