from torch import nn
import torch

from src.lib.nn.utils import maybe_cat_emb
from tsl.nn.blocks import RNN, MLPDecoder, LinearReadout
from tsl.nn.layers import DiffConv

from tsl.nn.layers.base import NodeEmbedding
from tsl.nn.models import BaseModel
from tsl.nn.utils import maybe_cat_exog


class STGNNModel(BaseModel):
    return_type = torch.Tensor

    r""""""

    def __init__(
        self,
        input_size: int,
        horizon: int,
        n_nodes: int,
        hidden_size: int,
        emb_size: int,
        temporal_layers: int,
        gnn_layers: int,
        output_size: int = None,
        exog_size: int = 0,
        activation: str = "elu",
    ):
        super(STGNNModel, self).__init__()

        self.input_size = input_size
        self.output_size = output_size or input_size
        self.horizon = horizon

        self.emb = NodeEmbedding(n_nodes=n_nodes, emb_size=emb_size)

        self.input_encoder = nn.Linear(
            input_size + exog_size + emb_size,
            hidden_size,
        )

        self.temporal_encoder = RNN(
            input_size=hidden_size,
            hidden_size=hidden_size,
            return_only_last_state=True,
            n_layers=temporal_layers,
        )

        self.gnn_layers = nn.ModuleList(
            [
                DiffConv(
                    in_channels=hidden_size,
                    out_channels=hidden_size,
                    add_backward=True,
                    k=1,
                    activation=activation,
                )
                for _ in range(gnn_layers)
            ]
        )

        self.decoder = LinearReadout(
            input_size=2 * hidden_size,
            output_size=self.output_size,
            horizon=horizon,
        )

    def message_passing(self, x, edge_index, edge_weight=None):
        for layer in self.gnn_layers:
            x = layer(
                x=x, edge_index=edge_index, edge_weight=edge_weight, cache_support=True
            )
        return x

    def forward(self, x, edge_index, edge_weight=None, u=None):
        """"""
        x = maybe_cat_exog(x, u)

        emb = self.emb()
        x = maybe_cat_emb(x, emb)

        # temporal encoding
        # weights are shared across levels
        x = self.input_encoder(x)

        x = self.temporal_encoder(x)

        out = self.message_passing(x, edge_index=edge_index, edge_weight=edge_weight)

        out = torch.cat([out, x], dim=-1)

        out = self.decoder(out)
        return out
