from typing import List, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F


class Last(nn.Module):
    """Readout that keeps only the final entry of the embedding list."""

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Return ``node_embeddings[-1]`` unchanged."""
        final_layer = node_embeddings[-1]
        return final_layer


class All(nn.Module):
    """Readout that concatenates every entry of the embedding list."""

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Concatenate all embeddings along the last dimension, in list order."""
        joined = torch.cat(node_embeddings, -1)
        return joined


class LayerAttention(nn.Module):
    """Combine a list of embeddings using one learnable scalar weight per entry.

    Args:
        nlayers: Number of layers. ``forward`` weights ``nlayers + 1`` embeddings,
            or ``nlayers`` when ``skip_first`` is set (after dropping entry 0).
        norm_weights: If True, the raw weights go through a softmax, so they
            are positive and sum to one.
        skip_first: If True, the first entry of the embedding list is ignored.
        sum_embeddings: If True, the weighted embeddings are summed element-wise
            (they must share a shape). Otherwise they are concatenated along
            the last dimension.
    """

    def __init__(self, nlayers: int, norm_weights: bool = True, skip_first: bool = False, sum_embeddings: bool = False):
        super().__init__()
        self.norm_weights = norm_weights
        self.skip_first = skip_first
        self.sum_embeddings = sum_embeddings
        # Number of embeddings that forward() actually weights.
        self.length = nlayers if skip_first else nlayers + 1
        self.weight = nn.Parameter(torch.empty(self.length))
        self.reset_parameters()

    def reset_parameters(self):
        """Set every weight to 1, so all entries start with equal weight."""
        nn.init.ones_(self.weight)

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Weight the first ``self.length`` embeddings and sum or concatenate them.

        The caller's list is never modified.

        Raises:
            IndexError: if fewer than ``self.length`` embeddings remain
                after the optional skip.
        """
        if self.skip_first:
            # Slice rather than delete in place: mutating the caller's list
            # would drop an entry for every other consumer of that list, and
            # again on each repeated call with the same list.
            node_embeddings = node_embeddings[1:]
        if self.norm_weights:
            att = F.softmax(self.weight, dim=0)
        else:
            att = self.weight
        weighted_embs = [att[i] * node_embeddings[i] for i in range(self.length)]
        if self.sum_embeddings:
            return torch.sum(torch.stack(weighted_embs), dim=0)
        return torch.cat(weighted_embs, dim=-1)


class PPRWeighted(nn.Module):
    """Concatenate embeddings scaled by fixed geometric (PPR-style) coefficients.

    Entry ``k < nlayers`` is scaled by ``alpha * (1 - alpha) ** k``, and the final
    entry by ``(1 - alpha) ** nlayers``. Together the ``nlayers + 1`` coefficients
    sum to 1.

    Args:
        nlayers: Number of layers. ``nlayers + 1`` coefficients are created.
        alpha: Decay / restart parameter. Presumably in (0, 1), but this is
            not validated here.
    """

    def __init__(self, nlayers: int, alpha: float):
        super().__init__()
        coeffs = [alpha * (1 - alpha) ** k for k in range(nlayers)]
        coeffs.append((1 - alpha) ** nlayers)
        # A buffer, so the weights follow the module across .to()/.cuda().
        # Non-persistent, so the state_dict matches the unregistered-tensor
        # layout and existing checkpoints still load.
        self.register_buffer("weight", torch.tensor(coeffs), persistent=False)

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Scale each embedding by its coefficient and concatenate on the last dim."""
        weighted_embs = [self.weight[i] * node_embeddings[i] for i in range(len(node_embeddings))]
        return torch.cat(weighted_embs, dim=-1)


class Weighted(nn.Module):
    """Concatenate embeddings scaled by fixed weights from a user function.

    ``weight_fn`` is applied to the float32 tensor ``[0, 1, ..., nlayers]``. Its
    output is divided by its own mean, so the resulting weights average to 1.

    Args:
        nlayers: Number of layers. ``weight_fn`` sees ``nlayers + 1`` indices.
        weight_fn: Maps the index tensor to a weight tensor of the same length.
            NOTE(review): a result whose mean is 0 gives inf/nan weights.
    """

    def __init__(self, nlayers: int, weight_fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        xs = torch.arange(nlayers + 1, dtype=torch.float32)
        raw = weight_fn(xs)
        # Out-of-place division, so a tensor returned (or shared) by weight_fn
        # is not modified.
        # Registered as a non-persistent buffer: it moves with .to()/.cuda()
        # while the state_dict stays the same as before.
        self.register_buffer("weight", raw / raw.mean(), persistent=False)

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Scale each embedding by its weight and concatenate on the last dim."""
        weighted_embs = [self.weight[i] * node_embeddings[i] for i in range(len(node_embeddings))]
        return torch.cat(weighted_embs, dim=-1)


class FullAttention(nn.Module):
    """Per-layer, per-feature learnable weighting followed by concatenation.

    Holds a ``(nlayers + 1, emb_size)`` weight matrix. Row ``i`` is multiplied
    element-wise with embedding ``i``.

    Args:
        emb_size: Width of the learnable weight rows.
        nlayers: The weight has ``nlayers + 1`` rows.
        norm_weights: If True, apply a softmax over the layer axis (dim 0) and
            rescale by the number of rows. Each feature's weights then average
            to 1 across layers.
    """

    def __init__(self, emb_size: int, nlayers: int, norm_weights: bool = True):
        super().__init__()
        self.norm_weights = norm_weights
        self.weight = nn.Parameter(torch.empty(nlayers + 1, emb_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the weights. With softmax normalization this starts every weight at 1."""
        nn.init.zeros_(self.weight)

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Weight each embedding by its row and concatenate along the last dim."""
        scores = self.weight
        if self.norm_weights:
            n_rows = scores.shape[0]
            scores = F.softmax(scores, dim=0) * n_rows
        scaled = [scores[idx] * emb for idx, emb in enumerate(node_embeddings)]
        return torch.cat(scaled, dim=-1)


class FullProjection(nn.Module):
    """Concatenate all embeddings and apply a bias-free learned linear map.

    The weight is stored as ``(in_features, output_size)`` with
    ``in_features = (nlayers + 1) * emb_size``, and applied as ``x @ weight``.

    Args:
        emb_size: Width of each embedding in the list.
        nlayers: The list is expected to hold ``nlayers + 1`` embeddings.
        output_size: Width of the projected output.
    """

    def __init__(self, emb_size: int, nlayers: int, output_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty((nlayers + 1) * emb_size, output_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init scaled by the projection's true input width."""
        # kaiming_uniform_ reads fan_in from size(1) and fan_out from size(0).
        # This weight is laid out (in, out), the transpose of nn.Linear, so
        # "fan_out" here is the real number of inputs. The default "fan_in"
        # would wrongly scale by output_size.
        nn.init.kaiming_uniform_(self.weight, mode="fan_out")

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Project the last-dim concatenation of ``node_embeddings``."""
        node_emb_cat = torch.cat(node_embeddings, dim=-1)
        return node_emb_cat @ self.weight


class MLP(nn.Module):
    """Concatenate all embeddings and pass them through a two-layer MLP.

    The hidden layer keeps the concatenated width ``(nlayers + 1) * emb_size``.
    The two linear layers (both with bias) are separated by a LeakyReLU.

    Args:
        emb_size: Width of each embedding in the list.
        nlayers: The list is expected to hold ``nlayers + 1`` embeddings.
        output_size: Width of the final output.
    """

    def __init__(self, emb_size: int, nlayers: int, output_size: int):
        super().__init__()
        width = emb_size * (nlayers + 1)
        stages = [
            nn.Linear(width, width, bias=True),
            nn.LeakyReLU(),
            nn.Linear(width, output_size, bias=True),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, node_embeddings: List[torch.Tensor]):
        """Apply the MLP to the last-dim concatenation of ``node_embeddings``."""
        return self.mlp(torch.cat(node_embeddings, dim=-1))
