import torch

from abc import ABC, abstractmethod
from torch.nn import Module, ModuleList, LSTM
from torch.nn.functional import mse_loss, relu
from torch_geometric.nn import GATConv, GCNConv, GCN2Conv, Linear
from torch_geometric.utils import add_self_loops


class BaseModel(Module, ABC):
    """Encoder -> stack of hidden layers -> decoder, with per-series normalization.

    The encoder and decoder are linear maps applied along dim 1 of the input
    (via a permute), so ``in_channels`` must equal ``x.size(1)`` and the model
    emits ``out_channels`` entries along that dimension. Subclasses define what
    one hidden layer does by implementing :meth:`apply_layer`.

    Args:
        in_channels: Size of dim 1 of the input.
        out_channels: Size of dim 1 of the output; stored as ``pred_len``.
        hidden_channels: Width of the encoder output / decoder input.
        num_hidden: Number of hidden layers.
        param_sharing: If true, every hidden layer is the *same* module
            instance, so all layers share parameters.
        layerfun: Zero-argument factory that builds one hidden layer.
        edge_weights: Optional tensor with one weight per edge of a single
            graph, or ``None``.

    ``forward`` reads ``self.edge_index``, which this class never assigns.
    NOTE(review): presumably set by a subclass or by the caller - confirm.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, num_hidden, param_sharing, layerfun, edge_weights):
        super().__init__()
        self.pred_len = out_channels
        self.encoder = Linear(in_channels, hidden_channels, weight_initializer="kaiming_uniform")
        self.decoder = Linear(hidden_channels, out_channels, weight_initializer="kaiming_uniform")
        if param_sharing:
            # One module object repeated num_hidden times -> shared weights.
            self.layers = ModuleList(num_hidden * [layerfun()])
        else:
            self.layers = ModuleList([layerfun() for _ in range(num_hidden)])
        self.edge_weights = edge_weights
        if self.edge_weights is not None:
            # "mean" lets add_self_loops derive loop weights from the existing
            # edge weights; if those are all zero that would yield 0, so a
            # constant 1.0 is used instead.
            self.loop_fill_value = 1.0 if (self.edge_weights == 0).all() else "mean"

    def forward(self, x, x_mark_enc, x_dec, x_mark_dec, mask=None, evo_tracking=False):
        """Run the model on ``x``.

        Args:
            x: 3-D input tensor whose dim 1 has size ``in_channels``.
                NOTE(review): presumably (batch, seq_len, num_nodes) - confirm.
            x_mark_enc, x_dec, x_mark_dec, mask: Accepted for signature
                compatibility; not used here.
            evo_tracking: If true, return early with the decoded output (not
                yet permuted back nor de-normalized) together with the list of
                detached hidden states after the encoder and after each layer.

        Returns:
            Tensor with ``pred_len`` entries along dim 1, rescaled with the
            input's per-series mean/std; or ``(decoded, evolution)`` when
            ``evo_tracking`` is true.
        """
        # Normalization from Non-stationary Transformer: zero mean, unit std along dim 1.
        means = x.mean(1, keepdim=True).detach()
        x = x - means
        stdev = torch.sqrt(torch.var(x, dim=1, keepdim=True, unbiased=False) + 1e-5)
        # Out-of-place on purpose: torch.var saved `x` for its backward pass,
        # so an in-place `/=` breaks autograd whenever the input requires grad.
        x = x / stdev

        x = x.permute(0, 2, 1)
        edge_index, edge_weights = self._prepare_graph(x.device)

        x_0 = self.encoder(x)
        evolution = [x_0.detach()] if evo_tracking else None

        x = x_0
        for layer in self.layers:
            x = self.apply_layer(layer, x, x_0, edge_index, edge_weights)
            if evo_tracking:
                evolution.append(x.detach())
        x = self.decoder(x)

        if evo_tracking:
            return x, evolution
        x = x.permute(0, 2, 1)

        # De-Normalization from Non-stationary Transformer: stdev/means have
        # size 1 along dim 1 and broadcast over the pred_len output steps.
        return x * stdev + means

    def _prepare_graph(self, device):
        """Return the ``(edge_index, edge_weights)`` pair handed to every hidden layer."""
        edge_index = getattr(self, "edge_index", None)
        if self.edge_weights is None:
            # Graph-free subclasses may leave edge_index unset entirely.
            if edge_index is None:
                return None, None
            # Without configured weights: all-zero weights and no self-loops.
            return edge_index, torch.zeros(edge_index.size(1), device=device)
        # Tile the single-graph weights once per graph contained in edge_index
        # (presumably a batch of identical graphs concatenated - confirm).
        num_graphs = edge_index.size(1) // len(self.edge_weights)
        edge_weights = torch.cat(num_graphs * [self.edge_weights], dim=0).to(device)
        edge_weights = edge_weights.abs()  # relevant when edge weights are learned
        return add_self_loops(edge_index, edge_weights, fill_value=self.loop_fill_value)

    @abstractmethod
    def apply_layer(self, layer, x, x_0, edge_index, edge_weights):
        """Apply one hidden ``layer`` to ``x``; ``x_0`` is the encoder output."""
        pass


class MLP(BaseModel):
    """Graph-free baseline whose hidden layers are ReLU-activated linear maps.

    Args:
        in_channels: Size of dim 1 of the model input.
        hidden_channels: Width of every hidden layer.
        num_hidden: Number of hidden layers.
        param_sharing: If true, all hidden layers share one set of weights.
        out_channels: Size of dim 1 of the model output (``pred_len``).
            Defaults to ``in_channels`` when omitted.
    """

    def __init__(self, in_channels, hidden_channels, num_hidden, param_sharing, out_channels=None):
        if out_channels is None:
            out_channels = in_channels

        def layer_gen():
            return Linear(hidden_channels, hidden_channels, weight_initializer="kaiming_uniform")

        # Pass by keyword so the arguments cannot drift out of order relative
        # to BaseModel's signature (out_channels is its second parameter).
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            hidden_channels=hidden_channels,
            num_hidden=num_hidden,
            param_sharing=param_sharing,
            layerfun=layer_gen,
            edge_weights=None,  # no graph weights -> no self-loops are added
        )

    def apply_layer(self, layer, x, x_0, edge_index, edge_weights):
        """Apply ``layer`` followed by ReLU; ``x_0`` and the graph arguments are ignored."""
        return relu(layer(x))











