# Copyright authors of
# FlowState: Sampling Rate Equivariant Time Series Forecasting

"""PyTorch FlowState model."""

from dataclasses import dataclass
from torch._tensor import Tensor
from typing import Optional, Tuple, Union
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

from configuration_flowstate import FlowStateConfig


# Module-level logger (transformers' logging wrapper).
logger = logging.get_logger(__name__)

# Config class name passed to `replace_return_docstrings` when generating `Returns:` sections.
_CONFIG_FOR_DOC = "FlowStateConfig"


# Identifiers of pretrained FlowState checkpoints; no entries are listed in this module.
FLOWSTATE_PRETRAINED_MODEL_ARCHIVE_LIST = []


# Shared class docstring, prepended to FlowState model classes via `add_start_docstrings`.
FLOWSTATE_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    The model implements FlowState from the paper "FlowState: Sampling Rate Equivariant Time Series Forecasting".

    Parameters:
        config ([`FlowStateConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring describing the forward inputs of the FlowState model.
FLOWSTATE_INPUTS_DOCSTRING = r"""

    Args:
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a forecasting task, this denotes the history/past time series values.
            For univariate time series, `num_input_channels` dimension should be 1. For multivariate time series, it is
            greater than 1.
            If `batch_first=False`, the shape of `past_values` is `(seq_length, batch_size, num_input_channels)`.
        batch_first (`bool`):
            Indicates whether the `batch_size` or the `seq_length` is the first dimension of `past_values`.
        scale_factor (`float`):
            The scaling factor to adjust the parameter `Delta` of the S5 block and the Functional Basis Decoder.
        prediction_length (`int`, *optional*):
            Number of time steps to forecast for a forecasting task. Also known as the Forecast Horizon.
            If not provided, or < 0, one forecasting patch is returned.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class FlowStateCausalRevIN(nn.Module):
    def __init__(
        self,
        eps=1e-5,
        with_missing=False,
        causal=True,
        sinh=False,
    ):
        """
        Causal RevIN implementation to enable parallel predictions during training of FlowState.

        Statistics are computed along dim 1 of the input (assumed to be the time axis). In causal mode they are
        running statistics, so position t is normalized with the mean/stdev of positions 0..t only.

        :param eps: lower bound for the (running) variance, for numerical stability
        :param with_missing (bool): whether contiguous patch masking (CPM) is used; the last input channel is then a
            missing-value mask (1 = missing) that is excluded from the value statistics. NOTE(review): the causal
            mean sums the raw values, which assumes missing values are zero-filled — confirm with the caller.
        :param causal (bool): use running statistics (True) or one set of statistics per series (False; only
            implemented together with ``with_missing=True``)
        :param sinh (bool): apply ``asinh`` after normalization and ``sinh`` before denormalization
        """
        super().__init__()
        self.causal = causal
        self.eps = eps
        self.missing = with_missing
        self.sinh = sinh

    def forward(self, x, mode: str):
        """
        :param x: tensor to (de)normalize
        :param mode: "norm" computes statistics from ``x`` and normalizes it, "transform" normalizes ``x`` with the
            statistics stored by a previous "norm" call, "denorm" maps (predicted) values back to the original scale
        :raises NotImplementedError: for any other mode
        """
        if mode == "norm":
            self._get_statistics(x)
            return self._forward_transform(x)
        if mode == "transform":
            return self._forward_transform(x)
        if mode == "denorm":
            if self.sinh:
                x = torch.sinh(x)
            return self._denormalize(x)
        raise NotImplementedError

    def _forward_transform(self, x):
        """Normalize ``x`` with the stored statistics, followed by ``asinh`` when enabled."""
        x = self._normalize(x)
        return torch.asinh(x) if self.sinh else x

    def _get_statistics(self, x):
        if not self.causal:
            if self.missing:
                # mask x with nan so that missing positions are ignored by nanmean
                x_ = torch.where(x[..., -1] == 1, torch.nan, x[..., 0])
                self.mean = torch.nanmean(x_, dim=1).unsqueeze(-1)
                self.stdev = (
                    torch.sqrt(torch.clamp(torch.nanmean((x_ - self.mean) ** 2, dim=1), min=self.eps))
                    .unsqueeze(-1)
                    .unsqueeze(-1)
                )
                self.mean.unsqueeze_(-1)
                return
            else:
                raise NotImplementedError()
        if self.missing:
            # number of observed (non-missing) values up to each position
            n = torch.cumsum(1 - x[..., -1], dim=1).unsqueeze(-1)
            n = torch.where(n == 0, 1.0, n)  # no div by zero
        else:
            n = torch.arange(1, x.shape[1] + 1, device=x.device).unsqueeze(-1)
        self.mean = (torch.cumsum(x, dim=1) / n).detach()
        mask = 1.0 if not self.missing else 1 - x[:, :, 1:]
        self.stdev = torch.sqrt(torch.clamp(torch.cumsum(((x - self.mean) * mask) ** 2, 1) / n, min=self.eps)).detach()
        if self.missing:
            # drop the statistics computed for the mask channel itself
            self.mean = self.mean[..., :-1]
            self.stdev = self.stdev[..., :-1]

    def set_statistics(self, mean, stdev):
        self.mean = mean
        self.stdev = stdev

    def _align_statistics(self, x):
        """
        Reshape the stored statistics (in place) so they broadcast against ``x``, whose leading dim indexes the
        positions. Because the attributes are overwritten, a fresh "norm" call is needed before aligning again.
        """
        if self.causal:
            # keep the statistics of the last x.shape[0] positions and move that axis to the front
            self.stdev = self.stdev[:, -x.shape[0] :].transpose(0, 1).unsqueeze(2)
            self.mean = self.mean[:, -x.shape[0] :].transpose(0, 1).unsqueeze(2)
        else:
            self.stdev = self.stdev.unsqueeze(0)
            self.mean = self.mean.unsqueeze(0)

    def _normalize(self, x):
        if x.ndim == 4:
            self._align_statistics(x)
        if x.shape[-1] == self.mean.shape[-1] + 1:  # with missing and not target
            # NOTE: modifies x in place; the trailing mask channel is left untouched
            x[..., :-1] = (x[..., :-1] - self.mean) / self.stdev
            # apply mask again after normalization
            x[..., :-1] *= 1 - x[..., -1].unsqueeze(-1)
        else:
            x = (x - self.mean) / self.stdev
        return x

    def _denormalize(self, x):
        """Invert ``_normalize``: scale by the stored stdev and add the stored mean back."""
        self._align_statistics(x)
        if x.ndim == 5:  # quantile predictions
            return x * self.stdev.unsqueeze(-2) + self.mean.unsqueeze(-2)
        # Bug fix: the mean was previously not added back on this path.
        return x * self.stdev + self.mean


class FlowStateEmbedding(nn.Module):
    def __init__(
        self,
        n_inputs: int,
        embedding_feature_dim: int,
        with_missing: bool,
    ):
        """
        Linear input embedding layer with an optional learned embedding for missing values.

        Args:
            n_inputs (int): Number of value channels projected by the linear layer. With ``with_missing`` the input
                carries one extra trailing mask channel that is not counted here.
            embedding_feature_dim (int): The embedding dimension
            with_missing (bool): If True, positions whose mask channel is 1 are replaced by a learned
                ``nan_embedding`` vector
        """
        super().__init__()
        self.embed = nn.Linear(n_inputs, embedding_feature_dim, bias=True)
        self.with_missing = with_missing
        if with_missing:
            initial = torch.randn(1, 1, embedding_feature_dim) / embedding_feature_dim
            self.nan_embedding = nn.Parameter(data=initial, requires_grad=True)

    def forward(self, x):
        """
        Args:
            x `torch.FloatTensor` of shape `(seq_length, batch_size, num_input_channels)`: The normalized context
                time series; with missing values enabled the last channel is the missing-value mask

        Returns:
            `torch.FloatTensor` of shape `(seq_length, batch_size, embedding_feature_dim)`
        """
        if not self.with_missing:
            return self.embed(x)
        values, is_missing = x[..., :-1], x[..., -1:]
        # blend: observed positions keep their projection, missing ones get the learned embedding
        return self.embed(values) * (1 - is_missing) + self.nan_embedding * is_missing


# https://github.com/goroda/PyTorchPoly/blob/master/poly.py adapted for handling batches
def FlowStateLegendreBasis(x, degree):
    """
    Legendre basis functions used in the Functional Basis Decoder.

    Evaluates P_0(x), ..., P_degree(x) with Bonnet's recursion
    (k + 1) P_{k+1}(x) = (2k + 1) x P_k(x) - k P_{k-1}(x).

    Args:
        x (`torch.FloatTensor` of any shape `(*S)`): Evaluation points, e.g. `(prediction_length,)` for a shared
            time grid or `(batch_size, prediction_length)` for a per-sample time grid.
        degree (`int`): The highest polynomial degree to evaluate.

    Returns:
        `torch.FloatTensor` of shape `(*S, degree + 1)` with the dtype and device of `x`; entry `[..., k]` holds
        `P_k(x)`.

    Raises:
        ValueError: if `degree` is negative.
    """
    if degree < 0:
        raise ValueError(f"degree must be non-negative, got {degree}.")
    polys = [torch.ones_like(x)]
    if degree > 0:
        polys.append(x)
        for k in range(1, degree):
            polys.append(((2 * k + 1) * x * polys[k] - k * polys[k - 1]) / (k + 1))
    # stacking on the last dim supports any number of leading (batch) dimensions
    return torch.stack(polys, dim=-1)


def FlowStateFourierBasis(x, degree):
    """
    Fourier basis functions used in the Functional Basis Decoder.

    Column 0 is the constant 1; for k = 1, ..., degree // 2 the columns 2k - 1 and 2k hold
    sin(2*pi*k*x) and cos(2*pi*k*x).

    Args:
        x (`torch.FloatTensor` of any shape `(*S)`): Evaluation points, e.g. `(prediction_length,)` for a shared
            time grid or `(batch_size, prediction_length)` for a per-sample time grid.
        degree (`int`): Number of non-constant basis functions; must be a non-negative even integer.

    Returns:
        `torch.FloatTensor` of shape `(*S, degree + 1)` with the dtype and device of `x`.

    Raises:
        ValueError: if `degree` is negative or odd (an odd degree cannot be split into sin/cos pairs).
    """
    if degree < 0 or degree % 2 != 0:
        raise ValueError(f"Fourier basis degree must be a non-negative even integer, got {degree}.")
    retvar = torch.ones(*x.shape, degree + 1, dtype=x.dtype, device=x.device)
    if degree > 0:
        # build the frequencies directly in x's dtype so float64 inputs also work
        freqs = 2 * torch.pi * torch.arange(1, degree // 2 + 1, dtype=x.dtype, device=x.device)
        angles = x.unsqueeze(-1) * freqs  # (*S, degree // 2)
        retvar[..., 1::2] = torch.sin(angles)
        retvar[..., 2::2] = torch.cos(angles)
    return retvar


@dataclass
class FlowStateEncoderOutput(ModelOutput):
    """
    Output of `FlowStateEncoder`, with the hidden states of every layer.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(num_positions, batch_size, embedding_feature_dim)`):
            Output of the last S5 layer. The last layer only keeps positions from `min_context - 1` onwards,
            so `num_positions = seq_len - min(min_context, seq_len) + 1` (a single position when
            `seq_len == min_context`).
        hidden_states (sequence of `torch.FloatTensor`, one per encoder layer):
            Output of each layer, in order. All entries but the last have shape
            `(seq_len, batch_size, embedding_feature_dim)`; the last entry is `last_hidden_state`.
            `FlowStateEncoder` passes a list here.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class FlowStateDecoderOutput(ModelOutput):
    """
    Output of `FlowStateFunctionalBasisDecoder`.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(*lead, len(quantiles), prediction_length, 1)`):
            Basis coefficients of every quantile evaluated on the (scale-factor dependent) prediction time grid.
            `*lead` are the leading dimensions of the encoder's `last_hidden_state`, i.e.
            `(num_positions, batch_size)` for `FlowStateEncoder` outputs. No denormalization is applied here.
        hidden_states (list with one `torch.FloatTensor` of shape `(*lead, len(quantiles), decoder_dim)`):
            Basis coefficients produced by the decoder's linear layer, before they are evaluated on the time grid.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class FlowStateModelOutput(ModelOutput):
    """
    Base class for the outputs of `FlowStateModel`, with potential hidden states.

    NOTE(review): the shapes below follow the original documentation and the encoder/decoder code; confirm them
    against `FlowStateModel.forward`.

    Args:
        last_hidden_state (`torch.FloatTensor`, presumably of shape `(num_positions, batch_size, len(quantiles), prediction_length, 1)`):
            Final output of the model, after denormalization.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the encoder and of the decoder.
        embedded_input (`torch.FloatTensor`, presumably of shape `(seq_len, batch_size, num_channels)`):
            Presumably the (normalized) values fed into the embedding layer — TODO confirm.
        embedded_output (`torch.FloatTensor`, presumably of shape `(seq_len, batch_size, embedding_feature_dim)`):
            Presumably the output of the embedding layer, i.e. the input of the encoder — TODO confirm.
        backbone_hidden_state (`torch.FloatTensor`):
            Last hidden state at the output of the backbone (encoder) before it is passed to the decoder;
            `FlowStateEncoder` produces shape `(num_positions, batch_size, embedding_feature_dim)`.
        decoder_hidden_state (`torch.FloatTensor`):
            Last hidden state of the decoder; presumably the decoder output before denormalization — TODO confirm.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    embedded_input: torch.FloatTensor = None
    embedded_output: torch.FloatTensor = None
    backbone_hidden_state: torch.FloatTensor = None
    decoder_hidden_state: torch.FloatTensor = None


class FlowStateS5Block(nn.Module):
    """
    Diagonalized S5 state-space block used by every FlowState encoder layer.

    The continuous-time SSM (Lambda, B_tilde, C_tilde, D, Delta) is discretized on every call. A per-sample
    `scale_factor` multiplies the step sizes Delta of the first `int(ssm_invariant_fraction * P)` state
    dimensions (the remaining ones use a factor of 1). The linear recurrence is then evaluated as a causal
    convolution via FFT.

    Args:
        config: model configuration; reads `embedding_feature_dim` (H), `encoder_state_dim` (P),
            `encoder_num_hippo_blocks`, `ssm_invariant_fraction`, `discretization`, `min_context` and, when no
            sequence length is passed to `get_discretized`, `context_length`, `patch` and `stride`.
        last (bool): if True, only outputs from position `min_context - 1` onwards are returned.
    """

    def __init__(self, config, last=False):
        super().__init__()
        self.last = last
        # Architecture switches. They are fixed here (not read from `config`); only output gating is enabled.
        self.real = False  # s4d-real style real-valued SSM instead of the diagonalized HiPPO-N one
        self.mamba_like_flow = False  # gates computed from a joint input projection
        self.input_gating = False
        self.output_gating = True
        self.selective_delta = False  # input-dependent step sizes
        self.config = config
        self.scalar_a_real = False  # one shared real part for all eigenvalues
        if self.mamba_like_flow:
            self.in_proj = nn.Linear(config.embedding_feature_dim, 4 * config.embedding_feature_dim)  # z, x, b_gate, c_gate
        if self.input_gating:
            self.input_gate = nn.Linear(config.embedding_feature_dim, config.embedding_feature_dim)
        if self.output_gating:
            self.output_gate = nn.Linear(config.embedding_feature_dim, config.embedding_feature_dim)
        if self.selective_delta:
            self.delta_gate = nn.Linear(config.embedding_feature_dim, config.encoder_state_dim, bias=False)
            dt_init_std = self.config.encoder_state_dim**-0.5
            nn.init.uniform_(self.delta_gate.weight, -dt_init_std, dt_init_std)  # from Mamba
        self.init_params()

    def init_params(self):
        """Initialize the continuous-time SSM parameters (Lambda, B_tilde, C_tilde, D, log_Delta)."""
        with torch.no_grad():
            state_dim = self.config.encoder_state_dim
            H = self.config.embedding_feature_dim
            encoder_num_hippo_blocks = self.config.encoder_num_hippo_blocks
            if state_dim % encoder_num_hippo_blocks != 0:
                raise ValueError("encoder_state_dim has to be divisible by encoder_num_hippo_blocks.")
            block_size = int(state_dim / encoder_num_hippo_blocks)
            # blockwise Hippo-N diagonalized
            n = torch.sqrt(2 * (torch.arange(block_size) + 1.0) + 1)
            A = -torch.outer(n, n) / 2
            A = -0.5 * torch.eye(block_size) + torch.triu(A) - torch.tril(A)
            if self.real:
                A = -torch.arange(2, block_size + 2) * torch.eye(block_size)  # s4d-real
            Lambda, V = torch.linalg.eig(torch.block_diag(*encoder_num_hippo_blocks * [A]))
            # Re(Lambda) is stored as a log so that -exp(log_Lambda_real) stays negative during training.
            if self.scalar_a_real:
                self.log_Lambda_real = torch.nn.Parameter(torch.log(-Lambda.real.mean()))
            else:
                self.log_Lambda_real = torch.nn.Parameter(torch.log(-Lambda.real))
            if self.real:
                self.Lambda_imag = torch.tensor(0.)  # not a trainable parameter, just set to zero for simplicity
            else:
                self.Lambda_imag = torch.nn.Parameter(Lambda.imag)

            B = 0.5 / torch.sqrt(torch.tensor(H)) * (torch.randn(state_dim, H) + torch.randn(state_dim, H) * 1.0j)
            C = (
                0.5
                / torch.sqrt(torch.tensor(state_dim))
                * (torch.randn(H, state_dim) + torch.randn(H, state_dim) * 1.0j)
            )
            # Express B and C in the eigenbasis of A: B_tilde = V^-1 B, C_tilde = C V.
            if self.real:
                self.B_tilde_r = nn.Parameter(B.real)  # for real case (from s4d-real) diagonalization not necessary
                self.B_tilde_i = torch.tensor(0.)
            else:
                self.B_tilde_r = nn.Parameter((V.inverse() @ B).real)
                self.B_tilde_i = nn.Parameter((V.inverse() @ B).imag)
            if self.real:
                self.C_tilde_r = nn.Parameter(C.real)
                self.C_tilde_i = torch.tensor(0.)
            else:
                self.C_tilde_r = nn.Parameter((C @ V).real)
                self.C_tilde_i = nn.Parameter((C @ V).imag)
            self.D = nn.Parameter(torch.randn(H))
            # step sizes Delta are initialized log-uniformly in [0.001, 0.1]
            log_min, log_max = torch.log(torch.tensor(0.001)), torch.log(torch.tensor(0.1))
            log_Delta = log_min + (log_max - log_min) * torch.rand(state_dim)
            if self.selective_delta:
                self.log_Delta = nn.Parameter((log_Delta.exp() + torch.log(-torch.expm1(-log_Delta.exp()))))  # inverse of softplus
            else:
                self.log_Delta = nn.Parameter(log_Delta)

    def get_discretized(self, scale_factor, L=None, delta=None):
        """Discretize the diagonalized, continuous-time linear SSM.

        Args:
            scale_factor (`torch.Tensor` of shape `(b,)`): per-sample multiplier for the step sizes; applied to the
                first `int(ssm_invariant_fraction * P)` state dimensions only, the remaining ones use 1.
            L (int, optional): sequence length; derived from `context_length`, `patch` and `stride` when omitted.
            delta (`torch.Tensor`, optional): step sizes; `exp(log_Delta)` when omitted.

        Returns:
            kernel (complex, `(b, L, P)` for the non-selective variant): `kernel[:, j] = exp(log_Lambda_bar * (L-1-j))`
                with `log_Lambda_bar = scale * Lambda * Delta`.
            B_bar: discretized input matrix of shape `(b, P, H)`; for the selective variant with the default
                discretization a tuple `(exp(log_Lambda_bar) - 1, 1 / Lambda)` that `forward` combines with B_tilde.
        """
        if L is None:
            L = int((self.config.context_length + 1 - self.config.patch) / self.config.stride - 1e-9) + 1
        if delta is None:
            delta = torch.exp(self.log_Delta)  # selective delta
        lambda_ = -torch.exp(self.log_Lambda_real) + 1j * self.Lambda_imag
        device = lambda_.device
        B_tilde = self.B_tilde_r + 1.0j * self.B_tilde_i
        scale_factor = scale_factor.repeat((lambda_.shape[0], 1)).T  # (b, P)
        scale_factor[:, int(self.config.ssm_invariant_fraction * scale_factor.shape[1]):] = 1.
        log_Lambda_bar = scale_factor.unsqueeze(1) * lambda_[None, None, :] * delta
        if self.selective_delta:
            kernel = F.pad(torch.cumsum(log_Lambda_bar[:, 1:], dim=1), pad=(0, 0, 1, 0))
            kernel = kernel.flip(1).exp()
        else:
            kernel = torch.einsum("bLd,L->bLd", log_Lambda_bar, torch.arange(L - 1, -1, -1, device=device)).exp()
        if self.config.discretization == "euler" or self.config.discretization == "trap":
            B_bar = (scale_factor * self.log_Delta.exp())[..., None] * B_tilde
        elif self.config.discretization == "bilinear":
            B_bar = ((1 - 0.5 * scale_factor * self.log_Delta.exp() * lambda_) ** -1)[..., None] * B_tilde
        else:
            # zero-order hold: B_bar = (Lambda_bar - 1) / Lambda * B_tilde, with kernel[:, -2] == Lambda_bar
            if self.selective_delta:
                B_bar = (log_Lambda_bar.exp() - 1.0, 1 / lambda_)
            else:
                B_bar = (1 / lambda_ * (kernel[:, -2] - 1.0))[..., None] * B_tilde

        return kernel, B_bar

    def apply_ssm_kern_ff(self, Bu_elements, kernel):
        """Run the discretized SSM over a sequence via FFT convolution.

        Args:
            Bu_elements (complex `(L, B, P)`): projected input sequence B_bar @ u.
            kernel (complex `(b, L, P)` or `(L, P)`): powers of Lambda_bar as returned by `get_discretized`.

        Returns:
            complex `(L', B, P)` SSM states: all L positions, or for the last block only those from
            `min_context - 1` onwards.
        """
        l, b, d = Bu_elements.shape
        if kernel.ndim == 2:
            kernel = kernel.unsqueeze(0)
        # zero-pad to 2 * l so the FFT product is a linear (not circular) convolution along time
        kff = torch.fft.fft(kernel.transpose(0, 1).flip(dims=(0,)), n=2 * l, dim=0)
        buff = torch.fft.fft(Bu_elements, n=2 * l, dim=0)
        o = torch.fft.ifft(kff * buff, n=2 * l, dim=0)[:l]

        if self.last and self.config.min_context == Bu_elements.shape[0]:  # only last hidden state will be used
            return o[-1].unsqueeze(0)
        elif self.last:
            return o[min(self.config.min_context, o.shape[0]) - 1 :]
        else:
            return o

    def forward(self, input_sequences: Tensor, scale_factor: Tensor):
        """Compute the S5 layer output for an input sequence.

        Args:
            input_sequences: batch of input feature sequences `(L, B, H)`.
            scale_factor: per-sample step-size multiplier `(B,)`.

        Returns:
            S5 layer outputs `(L', B, H)`, where `L' == L` unless this is the last block.
        """
        L, b, d = input_sequences.shape  # unpacking also requires a 3-D input
        if self.input_gating:
            input_sequences = F.sigmoid(self.input_gate(input_sequences)) * input_sequences

        if self.mamba_like_flow:
            z, input_sequences, ingate, outgate = self.in_proj(input_sequences).chunk(4, dim=-1)
            input_sequences = input_sequences * F.sigmoid(ingate)  # more like LSTM than Mamba

        if self.selective_delta:
            delta = F.softplus(self.delta_gate(input_sequences) + self.log_Delta).transpose(0, 1)
        else:
            delta = torch.exp(self.log_Delta)

        kernel, B_bar = self.get_discretized(scale_factor, L=input_sequences.shape[0], delta=delta)
        if self.selective_delta:
            if self.real:
                raise NotImplementedError("Combination of selective and real not implemented")
            Bu_elements = torch.einsum("n,bln,nm,lbm->lbn", B_bar[1], B_bar[0], self.B_tilde_r + 1.0j * self.B_tilde_i, input_sequences + 0.j)
        else:
            if self.real:
                Bu_elements = torch.einsum("bnm,lbm->lbn", B_bar.real, input_sequences)
                kernel = kernel.real  # imag is zero
            else:
                Bu_elements = torch.einsum("bnm,lbm->lbn", B_bar.real, input_sequences) + torch.einsum("bnm,lbm->lbn", B_bar.imag, input_sequences) * 1.j

        if self.config.discretization == "trap":
            # correction term for second order discretization
            Bu_elements = 0.5 * Bu_elements + 0.5 * kernel.unsqueeze(0)[..., -2, :] * torch.cat([Bu_elements[None, 0], Bu_elements[:-1]], dim=0)
        xs = self.apply_ssm_kern_ff(Bu_elements, kernel)
        if self.last:
            # keep the skip/gating inputs aligned with the trimmed SSM output
            input_sequences = input_sequences[min(self.config.min_context, input_sequences.shape[0]) - 1 :]
            if self.mamba_like_flow:
                outgate = outgate[-input_sequences.shape[0]:]
                z = z[-input_sequences.shape[0]:]
        # Compute SSM output sequence: Re(C_tilde @ x) = C_r Re(x) - C_i Im(x)
        if self.real:
            xs = torch.einsum("hn,...bn->...bh", self.C_tilde_r, xs.real)
        else:
            xs = torch.einsum("hn,...bn->...bh", self.C_tilde_r, xs.real) + torch.einsum("hn,...bn->...bh", self.C_tilde_i, -xs.imag)

        if self.mamba_like_flow:
            xs = xs * F.sigmoid(outgate)
        if self.output_gating:
            xs = F.sigmoid(self.output_gate(xs)) * xs
        xs += torch.einsum("h,...bh->...bh", self.D, input_sequences)

        if self.mamba_like_flow:
            # Bug fix: the SiLU(z) gate used to be computed and discarded instead of applied.
            xs = xs * F.silu(z)
        return xs


class FlowStateS5Layer(nn.Module):
    """
    Residual FlowState encoder layer.

    The input passes through an S5 block, a SELU, a self-gated linear unit and a LayerNorm, and the
    result is added to a skip connection. When `last` is True the S5 block (and therefore the skip
    path) only keeps positions from `config.min_context - 1` onwards.

    Args:
        config: model configuration (reads `embedding_feature_dim` and, for the last layer, `min_context`).
        last (bool): whether this is the final encoder layer.
        ssm (bool): accepted for interface compatibility; not used.
    """

    def __init__(self, config, last=False, ssm=True):
        super().__init__()
        width = config.embedding_feature_dim
        self.config = config
        self.last = last
        self.ssm = FlowStateS5Block(config, last=last)
        self.out = nn.Linear(width, width)
        self.norm = nn.LayerNorm(width)

    def _residual(self, x):
        """Return the skip-connection input, trimmed like the S5 output on the last layer."""
        if self.last and x.ndim != 2:
            first_kept = min(self.config.min_context, x.shape[0]) - 1
            return x[first_kept:]
        return x

    def forward(self, x, scale_factor):
        residual = self._residual(x)

        # S5 mixing, then a SELU and a self-gated linear unit
        hidden = F.selu(self.ssm(x, scale_factor))
        gate = F.sigmoid(self.out(hidden))
        hidden = hidden * gate

        # normalize the branch output before adding the skip connection
        return residual + self.norm(hidden)


class FlowStatePreTrainedModel(PreTrainedModel):
    """
    Base class that plugs FlowState into the Hugging Face `PreTrainedModel` machinery
    (config class, checkpoint prefix and main input name).
    """

    # Weight initialization
    config_class = FlowStateConfig
    base_model_prefix = "model"
    main_input_name = "past_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize weights.

        FlowState modules initialize their own parameters in their constructors (e.g. the S5 blocks in
        `init_params`), so this hook is not expected to be reached. It only reports the call through the module
        logger instead of printing to stdout, so users can control it with transformers' verbosity settings.
        For training, this would be the place to initialize the parameters of FlowState.
        """
        logger.warning("Should not reach here, all parameters should have been initialized!")


class FlowStateEncoder(FlowStatePreTrainedModel):
    """
    Encoder for FlowState: a stack of `config.encoder_num_layers` S5 layers applied to the embedded context.
    Only the final layer is built with `last=True`, so its output is trimmed to the positions from
    `config.min_context - 1` onwards.

    Args:
        config (`FlowStateConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: FlowStateConfig):
        # Let the config run its one-time preprocessing if it has not been done yet.
        if config.init_processing is False:
            config.check_and_init_preprocessing()

        super().__init__(config)

        self.use_return_dict = config.use_return_dict

        self.layers = nn.ModuleList(
            [
                FlowStateS5Layer(config, last=(i == config.encoder_num_layers - 1), ssm=True)
                for i in range(config.encoder_num_layers)
            ]
        )

    @replace_return_docstrings(output_type=FlowStateEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        encoder_inputs: torch.Tensor,
        scale_factor: Union[float, torch.Tensor] = 1.0,
    ) -> Union[Tuple, FlowStateEncoderOutput]:
        r"""
        Run the embedded context through all S5 layers.

        Args:
            encoder_inputs (`torch.FloatTensor` of shape `(seq_len, batch_size, embedding_feature_dim)`):
                Embedded context values of the time series.
            scale_factor (`float` or `torch.Tensor`, *optional*, defaults to 1.0):
                Multiplier for the S5 step sizes. A Python number or 0-d tensor is broadcast to one value per
                sample; a tensor of shape `(batch_size,)` gives an individual factor per sample.

        Returns:
        """
        # Normalize the scale factor to one value per sample.
        if not isinstance(scale_factor, torch.Tensor) or scale_factor.ndim == 0:
            scale_factor = torch.ones(encoder_inputs.shape[1], device=encoder_inputs.device) * scale_factor

        output = encoder_inputs
        all_hidden_states = []
        for layer in self.layers:
            output = layer(output, scale_factor=scale_factor)
            all_hidden_states.append(output)

        return FlowStateEncoderOutput(last_hidden_state=output, hidden_states=all_hidden_states)


class FlowStateFunctionalBasisDecoder(FlowStatePreTrainedModel):
    """
    This is the Functional Basis Decoder (FBD) of FlowState.

    A linear layer maps every encoder output to `decoder_dim` basis coefficients per quantile. The basis
    functions (Legendre for `*legs*`/`fixed`, Fourier for `four`) are sampled on a time grid whose spacing
    is proportional to the scale factor and combined with the coefficients to produce the forecast.

    Args:
        config (`FlowStateConfig`, *required*):
            Configuration.

    Raises:
        ValueError: for an unknown `decoder_type`, or an even `decoder_dim` with the Fourier decoder.
    """

    def __init__(self, config):
        # Let the config run its one-time preprocessing if it has not been done yet.
        if config.init_processing is False:
            config.check_and_init_preprocessing()

        super().__init__(config)

        n_out = config.decoder_dim
        decoder_type = config.decoder_type.lower()
        if "legs" in decoder_type or config.decoder_type == "fixed":
            # interval of the time axis on which the basis is sampled
            self.range = [-1.0, 0.95]
            if decoder_type == "hlegs":
                self.range = [0.0, 0.95]
            self.basis_f = lambda t: FlowStateLegendreBasis(t, n_out - 1)
        elif decoder_type == "four":
            if n_out % 2 == 0:
                raise ValueError("Fourier decoder must have odd dimension.")
            self.range = [0.0, 1.0]
            self.basis_f = lambda t: FlowStateFourierBasis(t, n_out - 1)
        else:
            raise ValueError("Unknown decoder: " + str(config.decoder_type))
        n_lin = n_out * len(config.quantiles)
        self.pred_dist = config.decoder_patch_len
        n = config.embedding_feature_dim
        self.config = config
        self.lin = nn.Linear(n, n_lin)

    def get_t(self, scale, pred_dim, device):
        """
        Time grid `range[0] + k * dt` for `k = 1..pred_dim`, with
        `dt = scale * (range[1] - range[0]) / decoder_patch_len`.
        """
        dt = scale * (self.range[1] - self.range[0]) / self.pred_dist
        t = self.range[0] + torch.arange(1, pred_dim + 1, dtype=torch.float, device=device) * dt
        return t

    def get_kernel(self, sampling_factor, target_points, device):
        """
        Sample the basis functions on the (scaled) time grid.

        Args:
            sampling_factor (float or `torch.Tensor`): grid scale; a tensor of shape `(batch_size,)` gives one grid
                per sample.
            target_points (int): number of time points to sample.
            device: device of the returned tensor.

        Returns:
            `torch.FloatTensor` of shape `(target_points, decoder_dim)` or, for per-sample factors,
            `(batch_size, target_points, decoder_dim)`, divided by 4. In training mode the time points are jittered
            with Gaussian noise of std `decoder_noise * dt` and clamped to `range[1]`.
        """
        with torch.no_grad():
            if isinstance(sampling_factor, torch.Tensor) and sampling_factor.ndim > 0:
                # individual factor per sample
                t = torch.stack([self.get_t(sf, target_points, device) for sf in sampling_factor], dim=0).to(device)
            else:
                t = self.get_t(sampling_factor, target_points, device)

            if self.training:
                delta = (self.range[1] - self.range[0]) * sampling_factor / self.pred_dist
                # as_tensor avoids copy-constructing (and warning) when delta is already a tensor
                delta = torch.as_tensor(delta, device=device).unsqueeze(-1)
                t = torch.clamp(t + torch.randn_like(t) * self.config.decoder_noise * delta, max=self.range[1])
            t = torch.clamp(t, min=self.range[0])
            f = self.basis_f(t)
        return f.float() / 4.0

    @replace_return_docstrings(output_type=FlowStateDecoderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        encoder_output: FlowStateEncoderOutput,
        prediction_length: int,
        scale_factor: Union[float, torch.Tensor] = 1.0,
    ) -> FlowStateDecoderOutput:
        """
        The FBD receives the `encoder_output` (final output of the encoder), a `prediction_length` and the
        `scale_factor`. It first linearly maps the encoder's `last_hidden_state` to basis coefficients (per
        quantile), then samples the continuous basis functions at `prediction_length` time points whose spacing is
        scaled by `scale_factor`, and finally combines coefficients and sampled basis functions by matrix
        multiplication to produce the (time-)scaled prediction for the Forecasting Horizon.

        Args:
            encoder_output (FlowStateEncoderOutput): The output of the encoder.
            prediction_length (int): Number of time steps to forecast for a forecasting task. Also known as the Forecast Horizon.
            scale_factor (float or torch.Tensor, optional): Scaling factor of the time grid; a tensor of shape
                `(batch_size,)` gives an individual factor per sample. Defaults to 1.

        Raises:
            ValueError: if `prediction_length` or `scale_factor` is None.

        Returns:
            FlowStateDecoderOutput: The final outputs of the decoder.
        """

        if prediction_length is None or scale_factor is None:
            raise ValueError("Provide valid scale factor and Nr. of target points")

        values = self.lin(encoder_output.last_hidden_state)
        # (..., len(quantiles) * decoder_dim) -> (..., len(quantiles), decoder_dim)
        hidden_state = values.view((*values.shape[:-1], -1, self.config.decoder_dim))
        W = self.get_kernel(scale_factor, prediction_length, hidden_state.device)
        if W.ndim == 2:
            # one time grid shared by all samples
            return FlowStateDecoderOutput(
                last_hidden_state=torch.einsum("...h,ph->...p", hidden_state, W).unsqueeze(-1),
                hidden_states=[hidden_state],
            )
        else:
            # one time grid per sample (batch dim b)
            return FlowStateDecoderOutput(
                last_hidden_state=torch.einsum("...bqh,bph->...bqp", hidden_state, W).unsqueeze(-1),
                hidden_states=[hidden_state],
            )


@add_start_docstrings(
    "The FlowState Model for time-series forecasting.",
    FLOWSTATE_START_DOCSTRING,
)
class FlowStateModel(FlowStatePreTrainedModel):
    """
    FlowState backbone: causal RevIN normalization, input embedding, S5 encoder and a linear or
    functional-basis decoder, optionally combined with an S5 trend expert.
    """

    def __init__(self, config: FlowStateConfig):
        """
        Args:
            config (`FlowStateConfig`):
                Model configuration. If `config.init_processing` is `False`,
                `config.check_and_init_preprocessing()` is run before the modules are built.

        Raises:
            NotImplementedError: If `config.trend_expert` is not one of `None`, `'FlowState_noscale'` or `'S5'`.
        """
        if config.init_processing is False:
            config.check_and_init_preprocessing()

        super().__init__(config)

        self.config = config

        # One value channel. The missing-value indicator channel is presumably handled inside the
        # embedding via `with_missing`, which is why it is not counted here.
        n_inputs = 1
        # RevIN is causal for every CPM mode except 'balanced_naive'.
        self.norm = FlowStateCausalRevIN(
            with_missing=config.with_missing,
            sinh=config.sinh,
            causal=self.config.cpm_mode != 'balanced_naive',
        )
        self.embed = FlowStateEmbedding(n_inputs, config.embedding_feature_dim, with_missing=config.with_missing)

        if config.trend_expert is None or config.trend_expert == 'FlowState_noscale':
            self.trend_expert = None
        elif config.trend_expert == 'S5':
            # The trend expert is a separate 3-layer encoder built from a copy of the config,
            # followed by its own linear quantile head.
            trend_config = copy.deepcopy(self.config)
            trend_config.encoder_num_layers = 3  # TODO trend expert config?
            self.trend_decoder = torch.nn.Linear(config.embedding_feature_dim, len(config.quantiles))
            self.trend_expert = FlowStateEncoder(trend_config)
        else:
            raise NotImplementedError(f"Trend expert option: {config.trend_expert} not implemented.")
        self.encoder = FlowStateEncoder(config)
        if config.decoder_type == 'lin':
            self.decoder = torch.nn.Linear(config.embedding_feature_dim, len(config.quantiles))
        else:
            self.decoder = FlowStateFunctionalBasisDecoder(config)

        trainable_paras = sum(p.numel() for p in self.parameters() if p.requires_grad)
        decoder_paras = sum(p.numel() for p in self.decoder.parameters() if p.requires_grad)
        logger.info(f" Total Number of parameters: {trainable_paras * 1e-3}k")
        logger.info(
            f"Number of decoder parameters: {decoder_paras * 1e-3}k ({100 * decoder_paras / trainable_paras:.2f}%)"
        )

    def _add_trend(self, flowstate_out, trend_out):
        """
        Add the trend expert's per-position prediction to every step of each forecast patch.

        Args:
            flowstate_out: 5-D decoder output, unpacked as `(n_preds, batch, n_quantiles, patch_len, last)`.
            trend_out: 3-D trend decoder output. Row `p + step` is added to step `step` of prediction `p`,
                so it must have at least `n_preds + patch_len - 1` rows.

        Returns:
            Tensor with the same shape as `flowstate_out`.
        """
        n_preds, _, _, patch_len, _ = flowstate_out.shape
        _n_trend, _batch, _n_quants = trend_out.shape  # unpacking also enforces a 3-D trend_out
        per_step = [
            flowstate_out[:, :, :, step] + trend_out[step:n_preds + step, :, :, None]
            for step in range(patch_len)
        ]
        # stack puts the step dimension first; move it back to position 3
        return torch.stack(per_step).permute(1, 2, 3, 0, 4)

    @add_start_docstrings_to_model_forward(FLOWSTATE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlowStateModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        return_dict: Optional[bool] = None,
        scale_factor: Optional[float] = None,
        batch_first: Optional[bool] = None,
        mask_n: Optional[int] = None,
        normalize_inputs: Optional[bool] = None,
    ) -> FlowStateModelOutput:
        r"""
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a forecasting task, this denotes the history/past time series values.
            For univariate time series, `num_input_channels` dimension should be 1. So far only univariate forecasts are supported.
            When `config.with_missing` is set, NaN values are treated as missing.
            If `batch_first=False`, the shape of `past_values` is `(seq_length, batch_size, num_input_channels)`
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        scale_factor (`float`, *optional*):
            The scaling factor to adjust the parameter `Delta` of the S5 block and the Functional Basis Decoder.
            Defaults to `config.scale_factor`.
        batch_first (`bool`, *optional*):
            Indicates whether the `batch_size` or the `seq_length` is the first dimension of `past_values`.
            Defaults to `config.batch_first`.
        mask_n (`int`, *optional*):
            When contiguous patch masking (CPM) is used during prediction, `mask_n` is the number of unknown
            (masked) steps appended after `past_values`. Only applied when `config.with_missing` is set.
            Defaults to 0.
        normalize_inputs (`bool`, *optional*):
            Currently not used; the inputs are always normalized.
        Returns:
            FlowStateModelOutput: The prediction of FlowState; it is denormalized only in evaluation mode.

        """
        if batch_first is None:
            batch_first = self.config.batch_first
        if scale_factor is None:
            scale_factor = self.config.scale_factor
        if mask_n is None:
            mask_n = 0

        if past_values.dim() != 3:
            raise ValueError(
                "`past_values` must have 3 dimensions of shape `(sequence_length, batch_size, num_input_channels)`."
            )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Internally everything runs time-first: (seq_length, batch_size, num_input_channels).
        time_first_past_values = torch.transpose(past_values, 1, 0) if batch_first else past_values

        L, batch, n_ch = time_first_past_values.shape
        if n_ch > 1:
            raise RuntimeError("past_values may only contain a single variate / channel.")

        if self.config.with_missing:
            # NaNs are missing values: zero-fill them and append a 0/1 missing-indicator channel.
            mask = torch.where(time_first_past_values.isnan(), torch.ones_like(time_first_past_values), 0)
            time_first_past_values = torch.nan_to_num(time_first_past_values, 0.0)
            time_first_past_values = torch.cat((time_first_past_values, mask), dim=-1)
            if mask_n > 0:
                # Append `mask_n` future steps flagged as missing (value 0, indicator 1). Built with
                # new_zeros/new_ones so they share the input's dtype and device.
                apdx = torch.cat(
                    (
                        time_first_past_values.new_zeros((mask_n, batch, n_ch)),
                        time_first_past_values.new_ones((mask_n, batch, 1)),
                    ),
                    dim=-1,
                )
                time_first_past_values = torch.cat((time_first_past_values, apdx), dim=0)
        # NOTE(review): without `config.with_missing`, `mask_n` is silently ignored above — confirm intended.

        # Normalize the inputs (the norm module is called with batch-first tensors)
        time_first_past_values = self.norm(time_first_past_values.transpose(0, 1), "norm").transpose(0, 1)

        # Embed the inputs
        encoder_inputs = self.embed(time_first_past_values)

        # Execute encoder
        encoder_output = self.encoder(encoder_inputs, scale_factor=scale_factor)
        if isinstance(encoder_output, tuple):
            encoder_output = FlowStateEncoderOutput(*encoder_output)

        # Execute decoder
        if self.config.decoder_type == 'lin':
            decoder_output = FlowStateDecoderOutput(
                last_hidden_state=self.decoder(encoder_output.last_hidden_state).unsqueeze(-1).unsqueeze(-1),
                hidden_states=[encoder_output.last_hidden_state],
            )
        else:
            # A 'fixed' decoder always samples its basis at scale 1.
            dec_scale_factor = scale_factor if self.config.decoder_type != "fixed" else 1.0
            decoder_output = self.decoder(
                encoder_output=encoder_output,
                prediction_length=(
                    self.config.target_points
                    if self.training
                    else int(self.config.decoder_patch_len / dec_scale_factor)
                ),
                scale_factor=dec_scale_factor,
            )
        if isinstance(decoder_output, tuple):
            decoder_output = FlowStateDecoderOutput(*decoder_output)

        # Optionally execute the trend expert
        if self.trend_expert is not None:
            # Extend the embedded sequence by `patch_len - 1` missing ("nan") steps so every step of the
            # last forecast patch has a trend value (see `_add_trend`).
            attach_n = decoder_output.last_hidden_state.shape[3] - 1
            attach_x = self.embed(torch.tensor([0., 1.], device=past_values.device))  # nan embedding
            attach_x = attach_x.repeat((attach_n, batch, 1))
            self.trend_expert.config.min_context = (
                self.config.min_context if self.config.min_context != 0 or attach_n <= 0 else L
            )
            # TODO: optionally use the lfiltered seasonal input
            trend_output = self.trend_expert(torch.cat((encoder_inputs, attach_x), dim=0))
            if isinstance(trend_output, tuple):
                trend_output = FlowStateEncoderOutput(*trend_output)
            trend_decoder_output = self.trend_decoder(trend_output.last_hidden_state)
            decoder_output.last_hidden_state = self._add_trend(decoder_output.last_hidden_state, trend_decoder_output)

        # Denormalize only during evaluation; training works on normalized values.
        if not self.training:
            pred = self.norm(decoder_output.last_hidden_state, "denorm")
        else:
            pred = decoder_output.last_hidden_state

        hidden_states = encoder_output.hidden_states + decoder_output.hidden_states

        if not return_dict:
            return (
                pred,
                hidden_states,
                past_values,
                encoder_inputs,
                encoder_output.last_hidden_state,
                decoder_output.last_hidden_state,
            )

        return FlowStateModelOutput(
            last_hidden_state=pred,  # TODO: Why Last hidden state? not forecast/ prediction
            hidden_states=hidden_states,
            embedded_input=past_values,
            embedded_output=encoder_inputs,
            backbone_hidden_state=encoder_output.last_hidden_state,
            decoder_hidden_state=decoder_output.last_hidden_state,
        )


@dataclass
class FlowStateForPredictionOutput(ModelOutput):
    """
    Output type of [`FlowStateForPrediction`].

    Args:
        loss (`torch.FloatTensor` of shape `()`, *optional*):
            Total loss. `FlowStateForPrediction.forward` currently does not compute a loss and always sets
            this to `None`.
        prediction_outputs (`torch.FloatTensor`):
            Forecast produced by `FlowStateForPrediction.forward`, always batch-first. Its shape depends on
            `prediction_type`: `(batch_size, len(quantiles), prediction_length, num_input_channels)` for
            `"quantile"` and `(batch_size, prediction_length, num_input_channels)` for `"mean"` and `"median"`.
        backbone_hidden_state (`torch.FloatTensor`):
            Last hidden state at the output of the backbone (encoder) before passing through the decoder.
            NOTE(review): shape not established here; presumably time-first
            `(seq_length, batch_size, embedding_feature_dim)` — TODO confirm.
        decoder_hidden_state (`torch.FloatTensor`):
            Last hidden state of the decoder, i.e. the tensor that is denormalized into the prediction in
            evaluation mode.
        hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
            Hidden states of the encoder followed by those of the decoder.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    backbone_hidden_state: torch.FloatTensor = None
    decoder_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class FlowStateForPrediction(FlowStatePreTrainedModel):
    r"""
    `FlowState` for forecasting application.

    Wraps [`FlowStateModel`] and forecasts arbitrary horizons: one forward pass uses contiguous patch
    masking (CPM) to cover as much of the horizon as possible, and any remainder is filled by
    quantile-wise autoregressive steps.

    Args:
        config (`FlowStateConfig`, *required*):
            Configuration.
    """

    def __init__(self, config: FlowStateConfig):
        config.check_and_init_preprocessing()
        super().__init__(config)

        self.config = config

        self.model = FlowStateModel(config)

    def _quantile_autoregressive_expand(self, context: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """
        Append a prediction to its context so that the next autoregressive step starts from every quantile.

        Args:
            context: `(batch, seq_len, n_ch)` on the first step, or `(batch * n_quants, seq_len, n_ch)` when
                the history was already expanded by a previous call.
            pred: `(batch, n_quants, pred_len, n_ch)` quantile prediction.

        Returns:
            `(batch * n_quants, seq_len + pred_len, n_ch)`; row `b * n_quants + q` holds the history of
            series `b` continued with its quantile-`q` forecast.
        """
        batch_, _, _ = context.shape
        batch, n_quants, _, _ = pred.shape
        # in the first step one history gets expanded, afterwards the already expanded history is reused
        assert batch_ == batch or batch_ == batch * n_quants
        if batch_ == batch:
            context = context.unsqueeze(1).repeat(1, n_quants, 1, 1)
        else:
            context = context.unflatten(0, (batch, n_quants))
        new_context = torch.cat((context, pred), dim=2)  # b, q, l, d
        return torch.flatten(new_context, end_dim=1)

    def _quantile_autoregressive_contract(self, new_pred: FlowStateModelOutput) -> torch.Tensor:
        """
        Contract an `n_quantiles x n_quantiles` prediction back to a regular `n_quantiles` prediction.

        `new_pred.last_hidden_state` is `(batch * n_quants, n_quants, pred_len, n_ch)`: one quantile forecast per
        expanded history (see `_quantile_autoregressive_expand`). All `n_quants * n_quants` samples of a series
        are pooled and the configured quantiles are recomputed from them.

        Returns:
            Tensor of shape `(batch, len(config.quantiles), pred_len, n_ch)`.
        """
        batch2, n_quants, _, _ = new_pred.last_hidden_state.shape
        assert batch2 % n_quants == 0
        batch = batch2 // n_quants
        # unflatten in the same way `_quantile_autoregressive_expand` flattened
        preds = torch.unflatten(new_pred.last_hidden_state, 0, (batch, n_quants))
        preds = preds.flatten(1, 2)
        levels = torch.tensor(self.config.quantiles, device=preds.device, dtype=preds.dtype)
        quantile_pred = torch.quantile(preds, levels, dim=1)
        return quantile_pred.transpose(0, 1)  # torch.quantile moves the quantile dimension first

    def _combine_cpm_predictions(self, pred, pred_len) -> Tensor:
        """
        Stitch overlapping contiguous-patch-masking (CPM) predictions into one forecast.

        `pred` has shape `(n_preds, batch, n_quants, patch_len, n_ch)`; prediction `i` is assumed to cover
        horizon steps `i .. i + patch_len - 1`. The non-overlapping predictions `0, patch_len, 2 * patch_len, ...`
        are concatenated, then the uncovered tail of the last prediction is appended, so that
        `n_preds - 1 + patch_len` steps are covered. The result is truncated to `pred_len` steps.

        Returns:
            Tensor of shape `(batch, n_quants, min(pred_len, n_preds - 1 + patch_len), n_ch)`.
        """
        n_preds, _, _, patch_len, _ = pred.shape
        last_pred = pred[-1]
        # Steps covered by the last prediction but not by the stitched patches. `n_preds % patch_len - 1`
        # would be -1 whenever n_preds is a multiple of patch_len and drop patch_len - 1 valid steps.
        rest = (n_preds - 1) % patch_len
        stitched = torch.cat([pred[ix] for ix in range(0, n_preds, patch_len)], dim=2)
        if rest > 0:
            stitched = torch.cat((stitched, last_pred[:, :, -rest:]), dim=2)
        return stitched[:, :, :pred_len]

    def _cpm_model_step(self, past_values, remaining, scale_factor, return_dict, max_decoder_patch_len, max_mpi_mask):
        """
        Run one CPM forward pass of the backbone for (up to) `remaining` forecast steps.

        Computes the number of masked steps, trims the batch-first `past_values` so that context plus masked
        steps fit the context budget, sets `self.model.config.min_context`, runs `self.model` and stitches
        its per-position predictions.

        Returns:
            Tuple `(trimmed_past_values, model_output)`. `model_output.last_hidden_state` is
            `(batch, n_quants, steps, n_ch)` with `steps <= remaining`.
        """
        if self.config.decoder_type == 'lin':
            mask_n = min(max_mpi_mask, remaining)
        else:
            # one decoder patch of `max_decoder_patch_len` steps needs no masked positions
            mask_n = min(max_mpi_mask, max(0, remaining - max_decoder_patch_len))
        # NOTE(review): if mask_n exceeds the context budget, max_context is <= 0 and the slice below does not
        # trim as intended (it keeps everything or drops the *oldest* part) — confirm intended.
        max_context = min(16 * 1024, int(self.config.context_length / scale_factor)) - mask_n
        past_values = past_values[:, -max_context:]
        if mask_n > 0:
            # min context from which to start predicting
            self.model.config.min_context = (
                past_values.shape[1] + 1 if self.config.decoder_type == 'lin' else past_values.shape[1]
            )
        else:
            self.model.config.min_context = 0
        model_output = self.model(
            past_values,
            return_dict=return_dict,
            batch_first=True,
            scale_factor=scale_factor,
            mask_n=mask_n,
        )
        if isinstance(model_output, tuple):
            model_output = FlowStateModelOutput(*model_output)
        model_output.last_hidden_state = self._combine_cpm_predictions(model_output.last_hidden_state, remaining)
        return past_values, model_output

    @add_start_docstrings_to_model_forward(FLOWSTATE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlowStateForPredictionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        future_values: Optional[torch.Tensor] = None,
        past_observed_mask: Optional[torch.Tensor] = None,
        future_observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
        batch_first: Optional[bool] = None,
        scale_factor: Optional[float] = None,
        prediction_length: Optional[int] = None,
        prediction_type: Optional[str] = None,
    ) -> FlowStateForPredictionOutput:
        r"""
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a forecasting task, this denotes the history/past time series values.
            Only univariate time series are supported, so `num_input_channels` must be 1.
            If `batch_first=False`, the shape of `past_values` is `(seq_length, batch_size, num_input_channels)`
        future_values: (`torch.FloatTensor`, *optional*):
            currently not used.
        past_observed_mask: (`torch.FloatTensor`, *optional*):
            currently not used.
        future_observed_mask: (`torch.FloatTensor`, *optional*):
            currently not used.
        output_hidden_states: (`bool`, *optional*):
            currently not used.
        return_loss: (`bool`, *optional*):
            currently not used.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        batch_first (`bool`, *optional*):
            Indicates whether the `batch_size` or the `seq_length` is the first dimension of `past_values`.
            The outputs are batch-first in either case.
        scale_factor (`float`, *optional*):
            The scaling factor to adjust the parameter `Delta` of the S5 block and the Functional Basis Decoder
        prediction_length (`int`, *optional*):
            Number of time steps to forecast for a forecasting task. Also known as the Forecast Horizon.
            Defaults to `config.prediction_length`. If `-1`, one decoder patch is forecast.
        prediction_type (`str`, *optional*):
            Indicates the desired return type of the model. Can be one of:
            quantile: The predictions for all predicted quantiles are returned
            mean: An approximate mean computed from the predicted quantiles is returned
            median: The median of the predicted quantiles is returned

        Returns:
            FlowStateForPredictionOutput: The final denormalized prediction of FlowState.
        """
        if batch_first is None:
            batch_first = self.config.batch_first
        if scale_factor is None:
            scale_factor = self.config.scale_factor
        if prediction_length is None:
            prediction_length = self.config.prediction_length
        if prediction_type is None:
            prediction_type = self.config.prediction_type
        if not hasattr(self.config, 'cpm_max'):
            self.config.cpm_max = 1e8

        if past_values.dim() != 3:
            raise ValueError(
                "`past_values` must have 3 dimensions of shape `(sequence_length, batch_size, num_input_channels)`."
            )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # A 'fixed' decoder ignores the scale factor.
        dec_scale_factor = scale_factor if self.config.decoder_type != "fixed" else 1.0
        max_decoder_patch_len = int(self.config.decoder_patch_len / dec_scale_factor + 1e-6)
        max_mpi_mask = 1e8  # int(self.config.cpm_max / scale_factor)
        if prediction_length == -1:
            # NOTE(review): for the 'lin' decoder this yields the float 1e8, which appears to fail later as a
            # slice index / tensor size — confirm whether -1 is supported for 'lin'.
            prediction_length = max_decoder_patch_len if self.config.decoder_type != 'lin' else max_mpi_mask

        # From here on past_values is batch-first: [batch_size x seq_length x num_input_channels]
        if not batch_first:
            past_values = past_values.transpose(0, 1)

        # Multi-patch (CPM) inference for the whole horizon
        past_values, model_output = self._cpm_model_step(
            past_values, prediction_length, scale_factor, return_dict, max_decoder_patch_len, max_mpi_mask
        )

        # Autoregressive prediction steps if the CPM pass did not reach the full horizon: every step continues
        # each quantile path separately and contracts the quantile x quantile forecasts again.
        step_pred = model_output.last_hidden_state
        while model_output.last_hidden_state.shape[2] < prediction_length:
            remaining = prediction_length - model_output.last_hidden_state.shape[2]
            past_values = self._quantile_autoregressive_expand(past_values, step_pred)
            past_values, step_output = self._cpm_model_step(
                past_values, remaining, scale_factor, return_dict, max_decoder_patch_len, max_mpi_mask
            )
            step_pred = self._quantile_autoregressive_contract(step_output)
            model_output.last_hidden_state = torch.cat((model_output.last_hidden_state, step_pred), dim=2)

        if prediction_type == "quantile":
            # ensure correct order in quantiles
            # NOTE(review): torch.quantile interpolates between the sorted quantile predictions instead of only
            # reordering them; `torch.sort(..., dim=1).values` would fix the order without changing values.
            levels = torch.tensor(
                self.config.quantiles, device=past_values.device, dtype=model_output.last_hidden_state.dtype
            )
            model_output.last_hidden_state = torch.quantile(model_output.last_hidden_state, levels, 1).transpose(0, 1)
        elif prediction_type == "mean":
            # calculate an approximate mean from quantiles: weights peak at the median and are normalized
            quant_prob = 0.5 - (0.5 - torch.tensor(self.config.quantiles)).abs()
            quant_prob /= quant_prob.sum()  # normalize quantile weights
            quant_prob = quant_prob.view(1, -1, 1, 1).to(past_values.device)
            model_output.last_hidden_state = (model_output.last_hidden_state * quant_prob).sum(dim=1)
        elif prediction_type == "median":
            if 0.5 not in self.config.quantiles:
                raise RuntimeError("Median requested but not part of the quantiles.")
            ix = self.config.quantiles.index(0.5)
            model_output.last_hidden_state = model_output.last_hidden_state[:, ix, :]
        else:
            raise RuntimeError("Unknown prediction_type detected. Should be one of ['quantile', 'mean', 'median']")

        loss_val = None
        # NOTE: outputs are returned batch-first even when `batch_first=False` was passed.
        if not return_dict:
            return (
                loss_val,
                model_output.last_hidden_state,
                model_output.backbone_hidden_state,
                model_output.decoder_hidden_state,
                model_output.hidden_states,
            )

        return FlowStateForPredictionOutput(
            loss=loss_val,
            # quantile: [batch_size x n_quantiles x prediction_length x num_input_channels],
            # mean/median: [batch_size x prediction_length x num_input_channels]
            prediction_outputs=model_output.last_hidden_state,
            backbone_hidden_state=model_output.backbone_hidden_state,
            decoder_hidden_state=model_output.decoder_hidden_state,
            hidden_states=model_output.hidden_states,
        )
