import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import GraphConv
from gatv3.dgl import GATv1Layer, GATv2Layer


class GAT(nn.Module):
    """Stack of gatv3 graph layers with residual connections for node classification.

    ``forward`` runs, in order:
    1. a linear input encoder, then ReLU and input dropout;
    2. ``n_layers`` times: graph layer -> residual add of the previous layer's
       pre-norm output -> BatchNorm -> activation -> dropout;
    3. a final linear classifier.

    Args:
        node_feats: Input feature size (``in_features`` of ``node_encoder``).
        n_classes: Output size of the final ``pred_linear`` layer.
        n_layers: Number of graph layers.
        n_heads: Forwarded as ``heads=`` to every graph layer.
        n_hidden: Width of the encoder output, of every graph layer's in/out
            arguments, and of each ``BatchNorm1d``. NOTE(review): BatchNorm is
            applied to the *flattened* layer output, so the gatv3 layers must
            emit ``n_hidden`` features after flattening across heads.
            Confirm against gatv3.
        activation: Callable applied after each BatchNorm. It is called as
            ``activation(h, inplace=True)``, so it must accept ``inplace``
            (e.g. ``F.relu``).
        dropout: Dropout probability applied after every graph layer.
        input_drop: Dropout probability applied after the input encoder.
        type: Layer family, one of ``'GCN'``, ``'GAT'``, ``'GAT2'``. The name
            shadows the builtin but is kept for keyword-argument callers.
        convolve: Forwarded to ``'GAT'``/``'GAT2'`` layers (``'GCN'`` forces True).
        lambda_policy: Forwarded to ``'GAT'``/``'GAT2'`` layers (``'GCN'`` forces None).
        share_weights_value: Forwarded to every graph layer.
        share_weights_score: Forwarded to every graph layer.

    Raises:
        NotImplementedError: If ``type`` is not a supported layer family.
    """

    # Layer families accepted by the ``type`` constructor argument.
    _CONV_TYPES = ('GCN', 'GAT', 'GAT2')

    def __init__(self, node_feats, n_classes, n_layers, n_heads, n_hidden, activation, dropout,
                 input_drop, type, convolve, lambda_policy, share_weights_value, share_weights_score
    ):
        super().__init__()

        # Fail fast, before building anything, with a message naming the bad value.
        if type not in self._CONV_TYPES:
            raise NotImplementedError(
                f"Unsupported layer type {type!r}; expected one of {self._CONV_TYPES}"
            )

        self.type = type
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_hidden = n_hidden
        self.n_classes = n_classes

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        self.node_encoder = nn.Linear(node_feats, n_hidden)

        # All layers are built with in = out = n_hidden, so successive outputs
        # can be summed by the residual connection in ``forward``.
        in_hidden = out_hidden = n_hidden
        for _ in range(n_layers):
            # Same per-layer creation order as before (conv, then norm) so
            # seeded parameter initialisation is reproducible.
            self.convs.append(self._make_conv(
                type, in_hidden, out_hidden, n_heads, convolve, lambda_policy,
                share_weights_value, share_weights_score,
            ))
            self.norms.append(nn.BatchNorm1d(out_hidden))

        self.pred_linear = nn.Linear(out_hidden, n_classes)

        self.input_drop = nn.Dropout(input_drop)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    @staticmethod
    def _make_conv(layer_type, in_hidden, out_hidden, n_heads, convolve, lambda_policy,
                   share_weights_value, share_weights_score):
        """Build one graph layer of family ``layer_type``."""
        # Keyword arguments identical for every layer family.
        shared = dict(
            heads=n_heads,
            bias=True,
            share_weights_score=share_weights_score,
            share_weights_value=share_weights_value,
        )
        if layer_type == 'GCN':
            # 'GCN' is built as a GATv2Layer with gcn_mode=True; the caller's
            # convolve / lambda_policy are deliberately overridden.
            return GATv2Layer(in_hidden, out_hidden, convolve=True, lambda_policy=None,
                              gcn_mode=True, **shared)
        if layer_type == 'GAT':
            return GATv1Layer(in_hidden, out_hidden, convolve=convolve,
                              lambda_policy=lambda_policy, **shared)
        if layer_type == 'GAT2':
            return GATv2Layer(in_hidden, out_hidden, convolve=convolve,
                              lambda_policy=lambda_policy, **shared)
        raise NotImplementedError(f"Unsupported layer type {layer_type!r}")

    @property
    def lmbda(self):
        """Per-layer ``lmbda`` values as floats.

        For the GAT families the values are read from each layer's ``lmbda``
        attribute. For ``'GCN'`` a list of ``0.0`` (one per layer) is returned.
        """
        return [float(l.lmbda) for l in self.convs] if 'GAT' in self.type else [0.] * len(self.convs)

    @property
    def lmbda2(self):
        """Per-layer ``lmbda2`` values as floats.

        For the GAT families the values are read from each layer's ``lmbda2``
        attribute. For ``'GCN'`` a list of ``1.0`` (one per layer) is returned.
        """
        return [float(l.lmbda2) for l in self.convs] if 'GAT' in self.type else [1.] * len(self.convs)

    def forward(self, subgraphs):
        """Compute class scores.

        Args:
            subgraphs: Either one graph, reused by every layer, or a list with
                (at least) one graph per layer. Input features are read from
                ``subgraphs[0].srcdata["feat"]``.

        Returns:
            The output of ``pred_linear`` (``n_classes`` columns).
        """
        if not isinstance(subgraphs, list):
            subgraphs = [subgraphs] * self.n_layers

        h = subgraphs[0].srcdata["feat"]
        h = self.node_encoder(h)
        h = F.relu(h, inplace=True)
        h = self.input_drop(h)

        h_last = None
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            # Indexing (not zip) keeps an IndexError if too few graphs are given.
            h = conv(subgraphs[i], h).flatten(1, -1)

            if h_last is not None:
                # Out-of-place add: ``h += ...`` would mutate the layer output
                # (or a flatten view of it), which autograd rejects if that
                # tensor was saved for backward. Slicing to h.shape[0] rows
                # presumably relies on DGL blocks listing destination nodes
                # first among source nodes — confirm for custom graphs.
                h = h + h_last[: h.shape[0], :]

            # Residual source is the pre-norm output; BatchNorm returns a new
            # tensor, so the in-place activation below does not touch h_last.
            h_last = h

            h = norm(h)
            h = self.activation(h, inplace=True)
            h = self.dropout(h)

        return self.pred_linear(h)

