# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#           http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

if os.environ.get("TFT_SCRIPTING", False):
    from torch.nn import LayerNorm
else:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm


class MaybeLayerNorm(nn.Module):
    """Layer normalisation that degenerates to a no-op for 1-wide outputs.

    Args:
        output_size: Width of the tensor to normalise. When truthy it takes
            precedence over ``hidden_size``; a value of exactly 1 selects
            ``nn.Identity`` instead of a LayerNorm.
        hidden_size: Width used when ``output_size`` is falsy (``None``/0).
        eps: Epsilon forwarded to ``LayerNorm``.
    """

    def __init__(self, output_size, hidden_size, eps):
        super().__init__()
        width = output_size or hidden_size
        # Normalising a single feature collapses it to the learned bias, which
        # is presumably why a 1-wide explicit output is left unnormalised.
        use_identity = bool(output_size) and width == 1
        self.ln = nn.Identity() if use_identity else LayerNorm(width, eps=eps)

    def forward(self, x):
        """Apply the selected normalisation (or identity) to ``x``."""
        return self.ln(x)


class GLU(nn.Module):
    """Gated linear unit: project to twice the output width, then gate.

    The projection result is split in half along the last dimension into
    ``(value, gate)`` and combined as ``value * sigmoid(gate)`` by ``F.glu``.

    Args:
        hidden_size: Last-dimension width of the input.
        output_size: Last-dimension width of the gated output.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        # One projection that holds both the value half and the gate half.
        self.lin = nn.Linear(hidden_size, 2 * output_size)

    def forward(self, x: Tensor) -> Tensor:
        return F.glu(self.lin(x))


class GRN(nn.Module):
    """Gated Residual Network block.

    Computes::

        h   = ELU(lin_a(a) + lin_c(c))      # lin_c term only when c is given
        h   = Dropout(lin_i(h))
        out = layer_norm(GLU(h) + skip(a))

    where ``skip`` is ``out_proj`` when ``output_size`` is set and the
    identity otherwise.

    Args:
        input_size: Last-dimension width of the primary input ``a``.
        hidden_size: Width of the internal projections.
        output_size: Width of the result. When falsy the output width is
            ``hidden_size`` and the residual is ``a`` itself, so ``input_size``
            must then equal ``hidden_size``.
        context_hidden_size: Width of the optional context ``c``. When
            ``None`` no context projection is created and ``forward`` must be
            called without ``c``.
        dropout: Dropout probability applied after ``lin_i``.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size=None,
        context_hidden_size=None,
        dropout=0,
    ):
        super().__init__()

        self.layer_norm = MaybeLayerNorm(output_size, hidden_size, eps=1e-3)
        self.lin_a = nn.Linear(input_size, hidden_size)
        if context_hidden_size is not None:
            self.lin_c = nn.Linear(context_hidden_size, hidden_size, bias=False)
        self.lin_i = nn.Linear(hidden_size, hidden_size)
        self.glu = GLU(hidden_size, output_size if output_size else hidden_size)
        self.dropout = nn.Dropout(dropout)
        # Projects the residual to output_size; None means identity residual.
        self.out_proj = (
            nn.Linear(input_size, output_size) if output_size else None
        )

    def forward(self, a: Tensor, c: Optional[Tensor] = None):
        x = self.lin_a(a)
        if c is not None:
            ctx = self.lin_c(c)
            # A context with one fewer dimension than the input (e.g. a
            # (B, H) static context against a (B, T, H) temporal input) is
            # broadcast over dim 1. A context of equal rank is added as-is;
            # unsqueezing it would broadcast into a spurious extra dimension.
            if ctx.dim() == x.dim() - 1:
                ctx = ctx.unsqueeze(1)
            x = x + ctx
        x = F.elu(x)
        x = self.lin_i(x)
        x = self.dropout(x)
        x = self.glu(x)
        y = a if self.out_proj is None else self.out_proj(a)
        x = x + y
        x = self.layer_norm(x)
        return x


class TFTEmbedding(nn.Module):
    """Embed every TFT input group into ``hidden_size``-wide vectors.

    Each categorical variable gets its own ``nn.Embedding``. Each continuous
    variable ``f`` is embedded as ``value * W[f] + b[f]``, where ``W`` and
    ``b`` are learnable and have shape ``(num_vars, hidden_size)``. The
    embeddings of one group are stacked along dim -2, so a group yields a
    tensor of shape ``(..., num_vars_in_group, hidden_size)``.

    ``config`` must provide the ``*_inp_lens`` lists (one vocabulary size per
    categorical variable), the ``*_inp_size`` continuous-variable counts,
    ``temporal_target_size`` and ``hidden_size``. An empty list or a zero
    count disables the corresponding group (its modules are set to ``None``).
    """

    def __init__(self, config):
        super().__init__()
        self.s_cat_inp_lens = config.static_categorical_inp_lens
        self.t_cat_k_inp_lens = config.temporal_known_categorical_inp_lens
        self.t_cat_o_inp_lens = config.temporal_observed_categorical_inp_lens
        self.s_cont_inp_size = config.static_continuous_inp_size
        self.t_cont_k_inp_size = config.temporal_known_continuous_inp_size
        self.t_cont_o_inp_size = config.temporal_observed_continuous_inp_size
        self.t_tgt_size = config.temporal_target_size

        self.hidden_size = config.hidden_size

        # There are 7 types of input:
        # 1. Static categorical
        # 2. Static continuous
        # 3. Temporal known a priori categorical
        # 4. Temporal known a priori continuous
        # 5. Temporal observed categorical
        # 6. Temporal observed continuous
        # 7. Temporal observed targets (time series observed so far)
        #
        # NOTE: parameters are registered and initialised in a fixed order;
        # changing it alters parameters() order (optimizer checkpoints) and
        # the random initialisation sequence.

        self.s_cat_embed = (
            nn.ModuleList(
                [nn.Embedding(n, self.hidden_size) for n in self.s_cat_inp_lens]
            )
            if self.s_cat_inp_lens
            else None
        )
        self.t_cat_k_embed = (
            nn.ModuleList(
                [
                    nn.Embedding(n, self.hidden_size)
                    for n in self.t_cat_k_inp_lens
                ]
            )
            if self.t_cat_k_inp_lens
            else None
        )
        self.t_cat_o_embed = (
            nn.ModuleList(
                [
                    nn.Embedding(n, self.hidden_size)
                    for n in self.t_cat_o_inp_lens
                ]
            )
            if self.t_cat_o_inp_lens
            else None
        )

        # Uninitialised storage; filled by xavier_normal_ at the end of __init__.
        self.s_cont_embedding_vectors = (
            nn.Parameter(torch.Tensor(self.s_cont_inp_size, self.hidden_size))
            if self.s_cont_inp_size
            else None
        )
        self.t_cont_k_embedding_vectors = (
            nn.Parameter(torch.Tensor(self.t_cont_k_inp_size, self.hidden_size))
            if self.t_cont_k_inp_size
            else None
        )
        self.t_cont_o_embedding_vectors = (
            nn.Parameter(torch.Tensor(self.t_cont_o_inp_size, self.hidden_size))
            if self.t_cont_o_inp_size
            else None
        )
        self.t_tgt_embedding_vectors = nn.Parameter(
            torch.Tensor(self.t_tgt_size, self.hidden_size)
        )

        self.s_cont_embedding_bias = (
            nn.Parameter(torch.zeros(self.s_cont_inp_size, self.hidden_size))
            if self.s_cont_inp_size
            else None
        )
        self.t_cont_k_embedding_bias = (
            nn.Parameter(torch.zeros(self.t_cont_k_inp_size, self.hidden_size))
            if self.t_cont_k_inp_size
            else None
        )
        self.t_cont_o_embedding_bias = (
            nn.Parameter(torch.zeros(self.t_cont_o_inp_size, self.hidden_size))
            if self.t_cont_o_inp_size
            else None
        )
        self.t_tgt_embedding_bias = nn.Parameter(
            torch.zeros(self.t_tgt_size, self.hidden_size)
        )

        if self.s_cont_embedding_vectors is not None:
            torch.nn.init.xavier_normal_(self.s_cont_embedding_vectors)
        if self.t_cont_k_embedding_vectors is not None:
            torch.nn.init.xavier_normal_(self.t_cont_k_embedding_vectors)
        if self.t_cont_o_embedding_vectors is not None:
            torch.nn.init.xavier_normal_(self.t_cont_o_embedding_vectors)
        torch.nn.init.xavier_normal_(self.t_tgt_embedding_vectors)

    def _apply_embedding(
        self,
        cat: Optional[Tensor],
        cont: Optional[Tensor],
        cat_emb: Optional[nn.ModuleList],
        cont_emb: Tensor,
        cont_bias: Tensor,
    ) -> Optional[Tensor]:
        """Embed one input group and stack its variables along dim -2.

        Args:
            cat: Integer codes whose last dim indexes the categorical
                variables, or ``None`` if the group has no categorical input.
            cont: Continuous values whose last dim indexes the variables,
                or ``None``.
            cat_emb: One embedding module per categorical variable; only used
                when ``cat`` is given.
            cont_emb: ``(num_cont, hidden_size)`` per-variable vectors; only
                used when ``cont`` is given.
            cont_bias: ``(num_cont, hidden_size)`` per-variable biases.

        Returns:
            A single ``(..., n_cat + n_cont, hidden_size)`` tensor with the
            categorical embeddings first, or ``None`` when both ``cat`` and
            ``cont`` are ``None``.
        """
        e_cat = (
            torch.stack(
                [embed(cat[..., i]) for i, embed in enumerate(cat_emb)], dim=-2
            )
            if cat is not None
            else None
        )
        if cont is not None:
            # The line below is equivalent to the following einsums:
            # e_cont = torch.einsum('btf,fh->btfh', cont, cont_emb)  # temporal
            # e_cont = torch.einsum('bf,fh->bfh', cont, cont_emb)    # static
            # i.e. each scalar scales its own variable's embedding vector.
            e_cont = torch.mul(cont.unsqueeze(-1), cont_emb)
            e_cont = e_cont + cont_bias
        else:
            e_cont = None

        if e_cat is not None and e_cont is not None:
            return torch.cat([e_cat, e_cont], dim=-2)
        elif e_cat is not None:
            return e_cat
        elif e_cont is not None:
            return e_cont
        else:
            return None

    def forward(
        self, x: Dict[str, Tensor]
    ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Tensor]:
        """Embed all input groups found in ``x``.

        ``x`` may contain ``s_cat``, ``s_cont``, ``k_cat``, ``k_cont``,
        ``o_cat`` and ``o_cont``; missing keys are treated as absent groups.
        ``target`` is mandatory.

        Returns:
            ``(static, known, observed, target)`` embeddings. The first three
            are ``None`` when their group has no input. Static inputs are
            taken from timestep 0 only.
        """
        # temporal/static categorical/continuous known/observed input
        s_cat_inp = x.get("s_cat", None)
        s_cont_inp = x.get("s_cont", None)
        t_cat_k_inp = x.get("k_cat", None)
        t_cont_k_inp = x.get("k_cont", None)
        t_cat_o_inp = x.get("o_cat", None)
        t_cont_o_inp = x.get("o_cont", None)
        t_tgt_obs = x["target"]  # Has to be present

        # Static inputs are expected to be equal for all timesteps
        # For memory efficiency there is no assert statement
        s_cat_inp = s_cat_inp[:, 0, :] if s_cat_inp is not None else None
        s_cont_inp = s_cont_inp[:, 0, :] if s_cont_inp is not None else None

        s_inp = self._apply_embedding(
            s_cat_inp,
            s_cont_inp,
            self.s_cat_embed,
            self.s_cont_embedding_vectors,
            self.s_cont_embedding_bias,
        )
        t_known_inp = self._apply_embedding(
            t_cat_k_inp,
            t_cont_k_inp,
            self.t_cat_k_embed,
            self.t_cont_k_embedding_vectors,
            self.t_cont_k_embedding_bias,
        )
        t_observed_inp = self._apply_embedding(
            t_cat_o_inp,
            t_cont_o_inp,
            self.t_cat_o_embed,
            self.t_cont_o_embedding_vectors,
            self.t_cont_o_embedding_bias,
        )

        # Temporal observed targets
        # t_observed_tgt = torch.einsum('btf,fh->btfh', t_tgt_obs, self.t_tgt_embedding_vectors)
        t_observed_tgt = torch.matmul(
            t_tgt_obs.unsqueeze(3).unsqueeze(4),
            self.t_tgt_embedding_vectors.unsqueeze(1),
        ).squeeze(3)
        t_observed_tgt = t_observed_tgt + self.t_tgt_embedding_bias

        return s_inp, t_known_inp, t_observed_inp, t_observed_tgt


class VariableSelectionNetwork(nn.Module):
    """Softmax-weighted combination of per-variable GRN transforms.

    ``forward`` expects ``x`` shaped ``(..., num_inputs, hidden_size)``. A
    joint GRN over all flattened variables (optionally conditioned on
    ``context``) produces one selection logit per variable. Each variable is
    also transformed by its own GRN, and the transforms are mixed with the
    softmax weights.

    Returns:
        ``(combined, weights)``: ``combined`` has shape
        ``(..., hidden_size)`` and ``weights`` has shape
        ``(..., num_inputs)``.
    """

    def __init__(self, config, num_inputs):
        super().__init__()
        width = config.hidden_size
        # Selection logits: flattened variables in, one score per variable out.
        self.joint_grn = GRN(
            width * num_inputs,
            width,
            output_size=num_inputs,
            context_hidden_size=width,
        )
        # Independent non-linear transform for every input variable.
        self.var_grns = nn.ModuleList(
            [GRN(width, width, dropout=config.dropout) for _ in range(num_inputs)]
        )

    def forward(self, x: Tensor, context: Optional[Tensor] = None):
        flat = x.reshape(*x.shape[:-2], -1)
        weights = F.softmax(self.joint_grn(flat, c=context), dim=-1)
        per_var = [grn(x[..., idx, :]) for idx, grn in enumerate(self.var_grns)]
        # (..., hidden, num_inputs) with variables along the last axis.
        stacked = torch.stack(per_var, dim=-1)
        # Batched matrix-vector product:
        #   temporal: bthf,btf->bth    static: bhf,bf->bh
        combined = torch.matmul(stacked, weights.unsqueeze(-1)).squeeze(-1)
        return combined, weights


class StaticCovariateEncoder(nn.Module):
    """Turn static covariate embeddings into the four TFT context vectors.

    The static variables are first combined by a variable selection network.
    The result is then fed through four independent GRNs, one per context
    vector.
    """

    def __init__(self, config):
        super().__init__()
        self.vsn = VariableSelectionNetwork(config, config.num_static_vars)
        # One GRN per context vector, in the order returned by forward().
        self.context_grns = nn.ModuleList(
            GRN(config.hidden_size, config.hidden_size, dropout=config.dropout)
            for _ in range(4)
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return ``(cs, ce, ch, cc)`` context vectors.

        cs: variable selection context
        ce: static enrichment context
        ch: LSTM initial hidden state (h_0) context
        cc: LSTM initial cell state (c_0) context
        """
        static_summary, _ = self.vsn(x)
        contexts = [grn(static_summary) for grn in self.context_grns]
        return contexts[0], contexts[1], contexts[2], contexts[3]


class InterpretableMultiHeadAttention(nn.Module):
    """Interpretable multi-head self-attention used by TFT.

    Every head has its own query and key projection. All heads share a single
    value projection, and the per-head attended values are averaged before
    the output projection.

    ``config`` must provide ``n_head``, ``hidden_size`` (divisible by
    ``n_head``), ``attn_dropout``, ``dropout`` and ``example_length``. The
    last one is the largest sequence length the causal mask supports.
    """

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        assert config.hidden_size % config.n_head == 0
        self.d_head = config.hidden_size // config.n_head
        # Fused projection: n_head query heads, n_head key heads and one value
        # head shared by all heads.
        self.qkv_linears = nn.Linear(
            config.hidden_size, (2 * self.n_head + 1) * self.d_head, bias=False
        )
        self.out_proj = nn.Linear(self.d_head, config.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.out_dropout = nn.Dropout(config.dropout)
        self.scale = self.d_head**-0.5
        # Additive causal mask, shape (1, example_length, example_length):
        # 0 on and below the diagonal, -inf strictly above it.
        self.register_buffer(
            "_mask",
            torch.triu(
                torch.full(
                    (config.example_length, config.example_length),
                    float("-inf"),
                ),
                1,
            ).unsqueeze(0),
        )

    def forward(
        self, x: Tensor, mask_future_timesteps: bool = True
    ) -> Tuple[Tensor, Tensor]:
        """Attend over the time axis of ``x``.

        Args:
            x: ``(batch, time, hidden_size)`` input. With masking enabled,
                ``time`` must not exceed ``config.example_length``.
            mask_future_timesteps: If True, position ``i`` only attends to
                positions ``<= i``.

        Returns:
            ``(out, attn_vec)``. ``out`` has shape ``(batch, time,
            hidden_size)``. ``attn_vec`` holds the per-head attended values,
            shape ``(batch, n_head, time, d_head)``.
        """
        bs, t, _ = x.shape
        qkv = self.qkv_linears(x)
        q, k, v = qkv.split(
            (self.n_head * self.d_head, self.n_head * self.d_head, self.d_head),
            dim=-1,
        )
        q = q.view(bs, t, self.n_head, self.d_head)
        k = k.view(bs, t, self.n_head, self.d_head)
        v = v.view(bs, t, self.d_head)

        # attn_score = torch.einsum('bind,bjnd->bnij', q, k)
        attn_score = torch.matmul(
            q.permute((0, 2, 1, 3)), k.permute((0, 2, 3, 1))
        )
        attn_score.mul_(self.scale)

        if mask_future_timesteps:
            # Slice the precomputed mask so sequences shorter than
            # example_length work too (a full-size mask cannot broadcast).
            attn_score = attn_score + self._mask[:, :t, :t]

        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.attn_dropout(attn_prob)

        # attn_vec = torch.einsum('bnij,bjd->bnid', attn_prob, v)
        attn_vec = torch.matmul(attn_prob, v.unsqueeze(1))
        # Average over heads (shared value projection makes this meaningful).
        m_attn_vec = torch.mean(attn_vec, dim=1)
        out = self.out_proj(m_attn_vec)
        out = self.out_dropout(out)

        return out, attn_vec


class TemporalFusionTransformer(nn.Module):
    """Implementation of https://arxiv.org/abs/1912.09363.

    ``config`` (or ``config.model`` when that attribute exists) supplies the
    embedding/encoder sizes, ``encoder_length``, ``num_historic_vars``,
    ``num_future_vars``, ``dropout`` and ``quantiles``.

    ``forward`` returns quantile forecasts only for the decoder part of the
    window, i.e. timesteps ``encoder_length:``. The output has shape
    ``(batch, time - encoder_length, len(config.quantiles))``.
    """

    def __init__(self, config):
        super().__init__()

        if hasattr(config, "model"):
            config = config.model

        # Number of leading timesteps treated as observed history.
        self.encoder_length = config.encoder_length

        # NOTE: submodule registration order fixes parameters() order and the
        # random-init sequence; keep it stable for checkpoint compatibility.
        self.embedding = TFTEmbedding(config)
        self.static_encoder = StaticCovariateEncoder(config)

        self.history_vsn = VariableSelectionNetwork(
            config, config.num_historic_vars
        )
        self.history_encoder = nn.LSTM(
            config.hidden_size, config.hidden_size, batch_first=True
        )
        self.future_vsn = VariableSelectionNetwork(
            config, config.num_future_vars
        )
        self.future_encoder = nn.LSTM(
            config.hidden_size, config.hidden_size, batch_first=True
        )

        self.input_gate = GLU(config.hidden_size, config.hidden_size)
        self.input_gate_ln = LayerNorm(config.hidden_size, eps=1e-3)

        self.enrichment_grn = GRN(
            config.hidden_size,
            config.hidden_size,
            context_hidden_size=config.hidden_size,
            dropout=config.dropout,
        )
        self.attention = InterpretableMultiHeadAttention(config)
        self.attention_gate = GLU(config.hidden_size, config.hidden_size)
        self.attention_ln = LayerNorm(config.hidden_size, eps=1e-3)

        self.positionwise_grn = GRN(
            config.hidden_size, config.hidden_size, dropout=config.dropout
        )

        self.decoder_gate = GLU(config.hidden_size, config.hidden_size)
        self.decoder_ln = LayerNorm(config.hidden_size, eps=1e-3)

        self.quantile_proj = nn.Linear(
            config.hidden_size, len(config.quantiles)
        )

    def forward(self, x: Dict[str, Tensor]) -> Tensor:
        s_inp, t_known_inp, t_observed_inp, t_observed_tgt = self.embedding(x)

        # Static context
        cs, ce, ch, cc = self.static_encoder(s_inp)
        ch, cc = ch.unsqueeze(0), cc.unsqueeze(0)  # LSTM initial (h_0, c_0)

        # Temporal input: [observed (optional), known, target] over the
        # encoder window; the decoder window only sees known inputs.
        _historical_inputs = [
            t_known_inp[:, : self.encoder_length, :],
            t_observed_tgt[:, : self.encoder_length, :],
        ]
        if t_observed_inp is not None:
            _historical_inputs.insert(
                0, t_observed_inp[:, : self.encoder_length, :]
            )

        historical_inputs = torch.cat(_historical_inputs, dim=-2)
        future_inputs = t_known_inp[:, self.encoder_length :]

        # Encoders: the future LSTM continues from the history LSTM's state.
        historical_features, _ = self.history_vsn(historical_inputs, cs)
        history, state = self.history_encoder(historical_features, (ch, cc))
        future_features, _ = self.future_vsn(future_inputs, cs)
        future, _ = self.future_encoder(future_features, state)
        # Empirical perf workaround (reason unknown). Only valid on CUDA;
        # calling it unconditionally raises on CPU-only setups.
        if history.is_cuda:
            torch.cuda.synchronize()

        # skip connection
        input_embedding = torch.cat(
            [historical_features, future_features], dim=1
        )
        temporal_features = torch.cat([history, future], dim=1)
        temporal_features = self.input_gate(temporal_features)
        temporal_features = temporal_features + input_embedding
        temporal_features = self.input_gate_ln(temporal_features)

        # Static enrichment
        enriched = self.enrichment_grn(temporal_features, c=ce)

        # Temporal self attention
        attn_out, _ = self.attention(enriched, mask_future_timesteps=True)

        # Don't compute historical quantiles
        attn_out = attn_out[:, self.encoder_length :, :]
        temporal_features = temporal_features[:, self.encoder_length :, :]
        enriched = enriched[:, self.encoder_length :, :]

        out = self.attention_gate(attn_out)
        out = out + enriched
        out = self.attention_ln(out)

        # Position-wise feed-forward
        out = self.positionwise_grn(out)

        # Final skip connection
        out = self.decoder_gate(out)
        out = out + temporal_features
        out = self.decoder_ln(out)

        return self.quantile_proj(out)
