import abc
from dataclasses import dataclass
from climate.src.climate_utils import Normalizer
from climate.src.lstm_utils import WindowGenerator, FeedBack, compile_and_fit
from swim_rnn import RNN
import numpy as np
from swimnetworks import Dense
from sklearn.pipeline import Pipeline
import pandas as pd
from datafold import TSCDataFrame


# Largest signed 64-bit integer; passed to ``rng.integers`` as the upper
# bound when drawing a random seed for the sampled layers.
MAX_SEED = np.iinfo(np.dtype(np.int64)).max

@dataclass
class BaseModel(abc.ABC):

    target: str = None
    time_delay: int = None
    horizon: int = None 
    rng: np.random.Generator = None

    _normalizer: Normalizer = None

    def __post_init__(self):
        self._normalizer = Normalizer()

    def train(self, train_df: pd.DataFrame, val_df: pd.DataFrame):
        self._normalizer.fit(train_df)
        train_df = self._normalizer.transform(train_df)
        val_df = self._normalizer.transform(val_df)

        self._train(train_df, val_df)

    def predict(self, df: pd.DataFrame) -> pd.DataFrame:
        normalized_df = self._normalizer.transform(df)
        normalized_predictions = self._predict(normalized_df)
        predictions = self._normalizer.transform(normalized_predictions,
                                                 inverse=True)
        return predictions

    def compute_metric(self, predicted_df: pd.DataFrame, true_df: pd.DataFrame, metric: str = "abs") -> pd.Series:
        true_values = true_df[self.target][self.time_delay:]
        predicted_values = predicted_df[self.target]

        if true_values.shape != predicted_values.shape:
            raise ValueError("True and predicted values have different shape: "
                             f"{true_values.shape=}, "
                             f"{predicted_values.shape=}")

        match metric:
            case "abs":
                error = predicted_values - true_values
                return error
            case "rel":
                error = predicted_values - true_values
                true_abs = np.abs(true_values)
                predicted_abs = np.abs(predicted_values)
                return error / np.maximum(true_abs, predicted_abs)
            case _:
                raise ValueError(f"Unknown metric: {metric}.")      

    @abc.abstractmethod
    def _train(self, train_df: pd.DataFrame, val_df: pd.DataFrame):
        pass 

    @abc.abstractmethod
    def _predict(self, df: pd.DataFrame):
        pass 


class SWIM(BaseModel):
    """Forecaster built from a single sampled ``Dense`` layer inside an ``RNN``.

    The hidden layer comes from ``swimnetworks`` and is wrapped in a
    scikit-learn ``Pipeline`` that serves as the dictionary of the
    ``swim_rnn.RNN`` model.
    """

    _model: RNN = None

    def __init__(self, target: str,
                 time_delay: int,
                 horizon: int,
                 rng: np.random.Generator,
                 layer_width: int,
                 activation: str,
                 regularization_scale: float,
                 **__
    ):
        """Build the sampled network and the surrounding RNN.

        Args:
            target: Column evaluated by ``compute_metric``.
            time_delay: Number of timestamps used as the initial state.
            horizon: Number of timestamps predicted per chunk.
            rng: Generator used to draw the seed of the hidden layer.
            layer_width: Width of the hidden ``Dense`` layer.
            activation: Passed to ``Dense`` both as the activation and as
                the parameter sampler.
            regularization_scale: Passed to ``RNN`` as its second argument.
            **__: Extra keyword arguments are accepted and ignored.
        """
        super().__init__(target=target,
                         time_delay=time_delay,
                         horizon=horizon,
                         rng=rng)

        hidden_layer = Dense(
            layer_width=layer_width,
            activation=activation,
            parameter_sampler=activation,
            random_seed=rng.integers(MAX_SEED),
        )
        network_dictionary = Pipeline(steps=[("hidden", hidden_layer)])
        # Passing 'time_delay-1' to have exactly 'time_delay' timestamps as initial state.
        self._model = RNN(network_dictionary, regularization_scale,
                          time_delay=time_delay - 1)

    def _train(self, train_df: pd.DataFrame, _: pd.DataFrame):
        """Fit the RNN on the training frame; validation data is unused."""
        train_tsc = self._pandas_to_tsc(train_df)
        self._model.fit(train_tsc)

    def _predict(self, df: pd.DataFrame) -> pd.DataFrame:
        """Predict ``df`` chunk by chunk, advancing ``horizon`` rows at a time.

        Each chunk spans ``time_delay + horizon`` rows: the first
        ``time_delay`` rows seed the model and the timestamps of the
        remaining rows are the evaluation times. The last chunk may be
        shorter.

        Returns:
            Predictions with the columns of ``df``, indexed by
            ``df.index[time_delay:]``.

        Raises:
            ValueError: If no chunk can be built from ``df``.
        """
        tsc_df = self._pandas_to_tsc(df)
        chunk_size = self.time_delay + self.horizon

        chunks = []
        # A bounded range (instead of an open-ended while loop) cannot spin
        # forever when ``horizon`` is not positive.
        for chunk_start in range(0, len(tsc_df) - self.time_delay, self.horizon):
            chunk = tsc_df[chunk_start:chunk_start+chunk_size]
            delayed = chunk[:self.time_delay]
            t_eval = chunk.time_values()[self.time_delay:]
            chunk_predictions = self._model.predict(delayed, time_values=t_eval)
            # The first prediction is the last value of the delayed chunk.
            chunks.append(chunk_predictions[1:])

        if not chunks:
            raise ValueError(
                f"No prediction chunk could be built from {len(tsc_df)} rows "
                f"with time_delay={self.time_delay} and horizon={self.horizon}."
            )

        # np.vstack is the non-deprecated name of np.row_stack.
        stacked = np.vstack(chunks)
        predictions_df = pd.DataFrame(stacked, columns=df.columns)
        predictions_df.set_index(df.index[self.time_delay:], inplace=True)
        return predictions_df

    def _pandas_to_tsc(self, df: pd.DataFrame):
        """Wrap a copy of ``df`` in a ``TSCDataFrame`` holding one series.

        A constant ``"id"`` column is added and combined with the existing
        index into a two-level index. The caller's frame is not modified.
        """
        df = df.copy()
        df["id"] = 0  # identify this as the only time series (assign id=0)
        df.set_index(["id", df.index], inplace=True)
        return TSCDataFrame(df)
    
class LSTM(BaseModel):
    """Forecaster wrapping the project's ``FeedBack`` network.

    Windows are produced by ``WindowGenerator`` and training is delegated to
    ``compile_and_fit``.
    """

    # Plain class-level defaults (this subclass is not a dataclass itself);
    # ``__init__`` overwrites the first two on every instance.
    learning_rate: float = None
    num_features: int = None
    max_epochs: int = 30
    patience: int = 5

    _window: WindowGenerator = None
    _model: FeedBack = None

    def __init__(self, target: str,
                 time_delay: int,
                 horizon: int,
                 rng: np.random.Generator,
                 layer_width: int,
                 activation: str,
                 regularization_scale: float,
                 num_features: int,
                 **__):
        """Create the window generator and the network.

        Args:
            target: Column evaluated by ``compute_metric``.
            time_delay: Input width of each window.
            horizon: Label width, shift and number of output steps.
            rng: Stored on the instance; not otherwise used here.
            layer_width: Number of units of the network.
            activation: Activation passed to the network.
            regularization_scale: Stored as ``learning_rate``.
            num_features: Number of features the network outputs per step.
            **__: Extra keyword arguments are accepted and ignored.
        """
        super().__init__(target=target,
                         time_delay=time_delay,
                         horizon=horizon,
                         rng=rng)

        # NOTE(review): the shared hyper-parameter is reused as the learning
        # rate for this model - presumably deliberate; confirm with callers.
        self.learning_rate = regularization_scale
        self.num_features = num_features
        self._window = WindowGenerator(input_width=time_delay, label_width=horizon, shift=horizon)
        self._model = FeedBack(units=layer_width, activation=activation, out_steps=horizon, num_features=num_features)

    def _train(self, train_df: pd.DataFrame, val_df: pd.DataFrame):
        """Build windowed datasets from both frames and fit the network."""
        train_dataset = self._window.make_dataset(train_df)
        val_dataset = self._window.make_dataset(val_df)
        # NOTE(review): 'learing_rate' looks like a typo, but it has to match
        # the keyword declared by compile_and_fit - check that signature
        # before renaming it here.
        compile_and_fit(self._model,
                        train_dataset=train_dataset,
                        val_dataset=val_dataset,
                        learing_rate=self.learning_rate,
                        max_epochs=self.max_epochs,
                        patience=self.patience)

    def _predict(self, df: pd.DataFrame):
        """Predict every row of ``df`` after the first ``time_delay`` rows.

        Overlapping window predictions are thinned to every ``horizon``-th
        window and completed with rows from the last window so that exactly
        ``len(df) - time_delay`` rows result.

        Returns:
            Predictions with the columns of ``df``, indexed by
            ``df.index[time_delay:]``.

        Raises:
            RuntimeError: If the assembled predictions do not have the shape
                ``(len(df) - time_delay, num_features)``.
        """
        dataset = self._window.make_dataset(df, shuffle=False)
        predictions = self._model.predict(dataset)
        # predictions.shape -> (n_points - time_delay + 1, horizon, num_features)
        # NOTE(review): the window count above is the original author's note
        # and cannot be verified from this file - confirm against
        # WindowGenerator.make_dataset.
        goal_length = len(df) - self.time_delay
        # Every horizon-th window, starting at window 1, flattened to rows.
        full_horizon = predictions[1::self.horizon].reshape(-1, self.num_features)
        tail_size = goal_length % self.horizon
        # When tail_size is 0 the slice is `-0:`, i.e. `0:`, so the whole last
        # window is appended; the shape check below validates the total.
        tail = predictions[-1][-tail_size:]
        # np.vstack is the non-deprecated name of np.row_stack.
        merged_predictions = np.vstack([full_horizon, tail])

        if merged_predictions.shape != (goal_length, self.num_features):
            raise RuntimeError("Something went wrong with the shapes. "
                               f"goal: {goal_length} timestamps, "
                               f"got: {len(merged_predictions)}.")

        predictions_df = pd.DataFrame(merged_predictions, columns=df.columns)
        predictions_df.set_index(df.index[self.time_delay:], inplace=True)
        return predictions_df
        

        



