from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import MixedLinear, MixedDropout


class GNNs(nn.Module):
    """Feed-forward feature transform followed by a pluggable propagation step.

    The node attributes are pushed through a stack of linear layers with ReLU
    activations, producing per-node "local" logits.  Those logits are handed,
    together with the requested node indices, to ``propagation``; the result
    is returned as log-probabilities (``log_softmax`` over the last dim).

    Args:
        nfeatures: Input size of the first linear layer.
        nclasses: Output size of the last linear layer.
        hiddenunits: Sizes of the hidden layers.  Must contain at least one
            entry, because the first and last elements are indexed directly.
        drop_prob: Dropout probability passed to ``MixedDropout``.  A value of
            ``0`` disables dropout entirely.
        propagation: Module invoked as ``propagation(local_logits, idx)``.
        bias: Whether the linear layers carry a bias term.

    Attributes:
        fcs: ``nn.ModuleList`` holding the linear layers, input layer first.
        reg_params: Parameters of the first layer only.
            NOTE(review): presumably consumed by the training code for weight
            regularisation -- confirm against the caller.
    """

    def __init__(self, nfeatures: int, nclasses: int, hiddenunits: List[int], drop_prob: float,
                 propagation: nn.Module, bias: bool = False):
        super().__init__()

        # The first layer is a MixedLinear (project helper from ``utils``;
        # presumably accepts sparse as well as dense input -- TODO confirm).
        # Every following layer is a plain nn.Linear.
        fcs = [MixedLinear(nfeatures, hiddenunits[0], bias=bias)]
        fcs.extend(
            nn.Linear(n_in, n_out, bias=bias)
            for n_in, n_out in zip(hiddenunits, hiddenunits[1:])
        )
        fcs.append(nn.Linear(hiddenunits[-1], nclasses, bias=bias))
        self.fcs = nn.ModuleList(fcs)

        self.reg_params = list(self.fcs[0].parameters())

        self.dropout = self._make_dropout(drop_prob)
        self.act_fn = nn.ReLU()

        self.propagation = propagation

    @staticmethod
    def _make_dropout(drop_prob: float) -> nn.Module:
        """Return the dropout module for ``drop_prob`` (a no-op module for 0)."""
        if drop_prob == 0:
            # nn.Identity instead of a lambda: a lambda stored on the module
            # cannot be pickled, which breaks ``torch.save(model)`` and other
            # pickle-based copies of the whole model.
            return nn.Identity()
        return MixedDropout(drop_prob)

    def _transform_features(self, attr_matrix: torch.Tensor) -> torch.Tensor:
        """Map the attribute matrix to per-row logits with the linear stack.

        Dropout is applied to the raw input and again right before the final
        layer; the intermediate hidden layers are not followed by dropout.
        The final layer has no activation.
        """
        layer_inner = self.act_fn(self.fcs[0](self.dropout(attr_matrix)))
        for fc in self.fcs[1:-1]:
            layer_inner = self.act_fn(fc(layer_inner))
        return self.fcs[-1](self.dropout(layer_inner))

    def forward(self, attr_matrix: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities produced by transform + propagation.

        Args:
            attr_matrix: Attribute matrix fed to the first layer.
            idx: Indices forwarded unchanged to ``propagation``.

        Returns:
            ``log_softmax`` (over the last dimension) of the propagated logits.
        """
        local_logits = self._transform_features(attr_matrix)
        # NOTE: normalising ``local_logits`` before propagation was tried in
        # two variants -- dividing by the overall ``torch.norm`` of the logits,
        # and standardising each column with its mean/variance over dim 0
        # (eps 1e-8).  Both are currently disabled; logits are passed as-is.
        final_logits = self.propagation(local_logits, idx)
        return F.log_softmax(final_logits, dim=-1)
