from typing import Dict

import torch
import torch.nn as nn
from torch import Tensor

from models.layers import get_activation_fn


class GraphConv(nn.Module):
    """Single graph-convolution step: neighbourhood mixing followed by a linear map.

    The edge matrix is symmetrised, halved and given unit self-loops before it
    is multiplied with the node features; the result is passed through one
    ``nn.Linear`` projection.
    """

    def __init__(self, in_dim: int, out_dim: int, identity_init: bool = False):
        """
        Args:
            in_dim: size of the incoming node features.
            out_dim: size of the projected node features.
            identity_init: when true the projection starts as an identity map
                with zero bias instead of a random initialisation.
        """
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self._reset_parameters(identity_init)

    def _reset_parameters(self, identity_init):
        """(Re-)initialise the projection weights and bias."""
        weight, bias = self.linear.weight, self.linear.bias
        if not identity_init:
            # Random start: Xavier-uniform weights, standard-normal bias.
            nn.init.xavier_uniform_(weight)
            nn.init.normal_(bias)
            return
        # Identity start: the layer initially only mixes neighbours.
        nn.init.eye_(weight)
        nn.init.zeros_(bias)

    def forward(self, edges: torch.Tensor, feat: torch.Tensor) -> torch.Tensor:
        """
        Args:
            edges: shape: [bs, n, n]
            feat: shape: [bs, n, dim]

        Returns:
            Projected node features, shape: [bs, n, out_dim]
        """
        # Make the edge weights symmetric: A + A^T.
        symmetric = edges + edges.transpose(1, 2)
        # Unit self-loops; the [n, n] identity broadcasts over the batch.
        self_loops = torch.eye(
            symmetric.shape[-1], dtype=symmetric.dtype, device=symmetric.device
        )
        # [bs, n, n] x [bs, n, dim] -> [bs, n, dim]
        mixed = torch.bmm(symmetric / 2 + self_loops, feat)
        return self.linear(mixed)


class Layer(nn.Module):
    """One GNN block: graph convolution, then activation, then layer norm."""

    def __init__(
        self,
        emb_dim: int,
        activation: str,
        identity_init: bool = False,
        no_activation: bool = False
    ):
        """
        Args:
            emb_dim: feature width; the convolution keeps it unchanged.
            activation: name handed to ``get_activation_fn`` to obtain the
                non-linearity (ignored when ``no_activation`` is set).
            identity_init: forwarded to ``GraphConv``.
            no_activation: use ``nn.Identity`` in place of the activation.
        """
        super().__init__()
        self.g_conv = GraphConv(emb_dim, emb_dim, identity_init)
        self.norm = nn.LayerNorm(emb_dim)
        if no_activation:
            self.activation = nn.Identity()
        else:
            self.activation = get_activation_fn(activation)

    def forward(self, edges: torch.Tensor, feat: torch.Tensor):
        """Apply the block to ``feat`` using the graph described by ``edges``."""
        convolved = self.g_conv(edges, feat)
        activated = self.activation(convolved)
        # Normalisation is applied after the activation, not before it.
        return self.norm(activated)


class GNN(nn.Module):
    """Graph network over a learned per-code embedding table.

    Every forward pass starts from the embedding table itself (one row per
    code), runs it through ``num_layers`` ``Layer`` blocks, optionally applies
    a final linear projection and pools the per-code features into one vector
    per batch element, weighted by ``nodes``.

    When ``inc_embedding`` is given, the base embedding, the layers and the
    projection are frozen and an additional trainable embedding table with
    ``inc_embedding`` rows is created; ``forward`` then works on one
    ``TASK_SIZE``-row window of the concatenated tables selected by ``task``.
    """

    # Number of consecutive codes forming one task window in ``forward``.
    TASK_SIZE = 256

    def __init__(
        self,
        num_codes: int,
        embed_dim: int,
        num_layers: int,
        identity_init: bool = None,
        embedding_init: bool = False,
        activation: str = "relu",
        no_fc: bool = False,
        inc_embedding: int = None,
        vertex_weight: Tensor = None,
        debug: bool = False
    ):
        """
        Args:
            num_codes: number of rows in the base embedding table.
            embed_dim: embedding / feature width.
            num_layers: number of stacked ``Layer`` blocks.
            identity_init: forwarded to every ``Layer``; ``None`` means False.
            embedding_init: when exactly ``True`` and ``vertex_weight`` is
                given, ``initialize_embedding`` is applied after the reset.
            activation: activation name forwarded to every ``Layer``.
            no_fc: replace the final linear projection by ``nn.Identity``.
            inc_embedding: if not ``None``, freeze the existing parameters and
                add a new trainable embedding with this many rows.
            vertex_weight: weights used by ``initialize_embedding``.
            debug: additionally return the final features in ``info``.
        """
        super().__init__()
        self.num_codes = num_codes
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.no_fc = no_fc
        self.identity_init = False if identity_init is None else identity_init
        self.debug = debug

        self.inc_embedding = inc_embedding
        self.embedding = nn.Embedding(num_embeddings=num_codes, embedding_dim=embed_dim)
        layers = [Layer(embed_dim, activation, self.identity_init) for _ in range(num_layers)]
        self.layers = nn.ModuleList(layers)
        self.fc = nn.Identity() if no_fc else nn.Linear(embed_dim, embed_dim)
        if self.inc_embedding is not None:
            # Freeze everything that already exists. Going through
            # ``parameters()`` also covers ``no_fc=True``, where ``self.fc`` is
            # an ``nn.Identity`` without ``weight``/``bias`` attributes.
            for module in (self.embedding, self.fc, self.layers):
                for param in module.parameters():
                    param.requires_grad = False
            self.new_embedding = nn.Embedding(num_embeddings=self.inc_embedding, embedding_dim=embed_dim)
        self._reset_parameters()
        if embedding_init is True and vertex_weight is not None:
            self.initialize_embedding(vertex_weight)

    def _reset_parameters(self):
        """Initialise the projection (if any) and the embedding table(s)."""
        if not self.no_fc:
            nn.init.normal_(self.fc.weight)
            nn.init.zeros_(self.fc.bias)
        nn.init.trunc_normal_(self.embedding.weight)
        if self.inc_embedding is not None:
            nn.init.trunc_normal_(self.new_embedding.weight)

    def initialize_embedding(self, vertex_weight: Tensor):
        """Shift selected embedding entries by a large constant.

        Entries of ``vertex_weight`` that reach the threshold mark positions
        in the leading columns of the embedding table; those positions are
        increased in place by ``embedding_correct_eps``.

        Args:
            vertex_weight: 2-D tensor whose transpose must have ``num_codes``
                rows and at most ``embed_dim`` columns (required by the
                concatenation below), i.e. shape ``[k, num_codes]``.
        """
        # simple methods
        vertex_weight_threshold = 0.05
        embedding_correct_eps = 250
        weight = self.embedding.weight

        v_mask = (vertex_weight >= vertex_weight_threshold)
        # All-False padding so the mask spans the full embedding width; it is
        # created on the mask's device so the concatenation cannot mix devices.
        padding_shape = (self.num_codes, self.embed_dim - v_mask.shape[0])
        padding = torch.zeros(padding_shape, dtype=torch.bool, device=v_mask.device)
        mask = torch.cat((v_mask.T, padding), dim=1).to(weight.device)

        # ``no_grad`` permits the in-place edit of a leaf parameter and, unlike
        # toggling ``requires_grad``, keeps a frozen embedding frozen.
        with torch.no_grad():
            weight[mask] += embedding_correct_eps

    def forward(self, nodes: torch.Tensor, edges: torch.Tensor, task: int = None):
        """
        Args:
            nodes: [bs, n]
            edges: [bs, n, n]
            task: optional task index; when given, only the window
                ``[task * TASK_SIZE, (task + 1) * TASK_SIZE)`` of ``nodes`` and
                ``edges`` is used. Required when ``inc_embedding`` is set.

        Returns:
            ``(feat_fc, info)`` where ``feat_fc`` is the pooled feature of
            shape [bs, dim] and ``info`` maps ``"feat_<i>"`` to the input of
            layer ``i`` (plus ``"feat"`` and ``"feat_fc"`` when ``debug``).

        Raises:
            TypeError: if ``inc_embedding`` is set but ``task`` is ``None``.
        """
        # code embedding [bs, n, dim]
        bs = nodes.shape[0]
        info: Dict[str, torch.Tensor] = dict()

        if task is None:
            window = None
        else:
            start = task * self.TASK_SIZE
            window = slice(start, start + self.TASK_SIZE)

        # [n, dim] -> [bs, n, dim]
        if self.inc_embedding is None:
            feat = self.embedding.weight.expand(bs, -1, -1)
        else:
            if window is None:
                raise TypeError(
                    "`task` must be an integer when the model was built with `inc_embedding`"
                )
            table = torch.cat((self.embedding.weight, self.new_embedding.weight), dim=0)
            feat = table[window, :].expand(bs, -1, -1)

        if window is not None:
            # Restrict the graph and the pooling weights to the task window.
            edges = edges[:, window, window]
            nodes = nodes[:, window]

        for l_id, layer in enumerate(self.layers):
            info[f"feat_{l_id}"] = feat
            feat = layer(edges, feat)

        feat_fc = self.fc(feat)
        # weighted average pooling: scale by the node weights, then mean over n
        feat_fc = (feat_fc * nodes[..., None]).mean(dim=1)

        if self.debug:
            info["feat"] = feat
            info["feat_fc"] = feat_fc

        return feat_fc, info
