import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_scatter import scatter
from torch_geometric.utils import to_dense_batch

from models.layers.compgcn_conv import CompGCNConv

# Knowledge-conditioned Feature Modulation
class KFM(nn.Module):
    """Knowledge-conditioned Feature Modulation.

    Builds node features for a graph by mean-pooling the token states of the
    mentions that point at each node, runs one ``CompGCNConv`` layer over the
    graph, and turns the resulting node embeddings into a per-token scale
    (``gamma``) and shift (``beta``). Token states are then modulated as
    ``(1 + gamma) * hidden_states + beta``; tokens without a linked node get
    ``gamma = beta = 0`` and pass through unchanged.
    """

    def __init__(self, args):
        """Create the relation embedding, graph layer and modulation heads.

        Args:
            args: configuration object. Reads ``device``, ``hidden_size`` and
                ``label_map`` (only ``len(label_map)`` is used, as the number
                of relation ids).
        """
        super().__init__()
        self.args = args
        self.device = args.device
        self.hidden_size = args.hidden_size
        # Width of the relation (edge-type) embeddings given to CompGCNConv.
        relation_dim = 128
        # One learned vector per relation id; edge_attr values index into it.
        self.relation_embed = torch.nn.Embedding(
            len(args.label_map),
            relation_dim,
        )
        print(f"The number of relations: {len(args.label_map)}")
        self.compgcn = CompGCNConv(self.hidden_size, self.hidden_size, relation_dim)
        # Head producing the multiplicative modulation term (gamma).
        self.weight_linear = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )
        # Head producing the additive modulation term (beta).
        self.bias_linear = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

    def fill_batch(self, graph_embeds, B):
        """Pad ``graph_embeds`` to ``B`` batch rows and prepend a zero slot.

        Missing trailing batch rows (e.g. when the last batch was not
        appended) are filled with zeros. A zero slot is then prepended along
        dim 1 so that index 0 can stand for "no corresponding node".

        Args:
            graph_embeds: 3-D tensor of shape (b, L, D) with ``b <= B``.
            B: target batch size.

        Returns:
            Tensor of shape (B, L + 1, D) with ``graph_embeds``' dtype and
            device.

        Raises:
            RuntimeError: from ``torch.cat`` if ``graph_embeds`` already has
                more than ``B`` rows.
        """
        cur_B, L, D = graph_embeds.shape
        if cur_B < B:
            # One concatenation instead of appending a zero row per iteration.
            pad = graph_embeds.new_zeros((B - cur_B, L, D))
            graph_embeds = torch.cat([graph_embeds, pad], dim=0)

        # Zero padding in front to cover the NO-corresponding-node case.
        front = graph_embeds.new_zeros((B, 1, D))
        return torch.cat([front, graph_embeds], dim=1)

    def forward(
        self,
        hidden_states,
        nodes,
        edge_index,
        edge_attr,
        batch,
        mention_positions,
    ):
        """Modulate ``hidden_states`` with graph-derived scale and shift terms.

        Args:
            hidden_states: token states of shape B x L x D.
            nodes: not used by this method; kept for interface compatibility.
            edge_index: edges as node-id pairs, expected as E x 2
                (list/array/tensor); transposed to 2 x E for the graph layer.
            edge_attr: E relation ids, each < ``len(args.label_map)``.
            batch: not used by this method; kept for interface compatibility.
            mention_positions: B x N integer tensor; entry ``[b, i]`` is the
                batch-global node id linked to token ``i`` of example ``b``,
                or -1 when that token has no node. Only the first N tokens of
                each sequence are read and modulated.

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        device = hidden_states.device
        mention_positions = mention_positions.to(device)

        # B x N
        B, N = mention_positions.shape
        # B x N x D -> B*N x D. Detached: the graph branch does not send
        # gradients back into the token states it pools as node features.
        knowledge_states = hidden_states[:, :N, :].detach().reshape(B * N, -1)
        # B*N
        flat_positions = mention_positions.reshape(B * N)

        # reshape(-1, 2) keeps an empty edge list well-formed (2 x 0 after .t()).
        edge_index = torch.as_tensor(
            edge_index, dtype=torch.long, device=device
        ).reshape(-1, 2).t()
        edge_attr = torch.as_tensor(edge_attr, dtype=torch.long, device=device)

        # Size the node table so that every mentioned node AND every edge
        # endpoint has a row; otherwise scatter would size the output by the
        # largest mention only and edges could index past its end.
        num_nodes = int(flat_positions.max()) + 1 if flat_positions.numel() > 0 else 0
        if edge_index.numel() > 0:
            num_nodes = max(num_nodes, int(edge_index.max()) + 1)

        # N' x D. Positions are shifted by +1 so -1 ("no node") lands in row 0,
        # which is then dropped; nodes with no mention get zero features.
        node_embeds = scatter(
            knowledge_states,
            flat_positions + 1,
            dim=0,
            dim_size=num_nodes + 1,
            reduce='mean',
        )[1:, :]

        edge_embeds = self.relation_embed(edge_attr)

        # N' x D
        out = self.compgcn(node_embeds, edge_index, edge_embeds)
        # (1 + N') x D: leading zero row is gathered for "no node" (-1 + 1 = 0).
        out = torch.cat([out.new_zeros(1, self.hidden_size), out], dim=0)
        # B x N x D
        graph_embeds = out[mention_positions + 1]
        # B x N x 1: zeroes gamma/beta for tokens without a node.
        graph_mask = (mention_positions != -1).unsqueeze(-1).float()

        activated = F.relu(graph_embeds)
        gamma = self.weight_linear(activated) * graph_mask
        beta = self.bias_linear(activated) * graph_mask

        # B x N x D -> B x L x D (positions >= N get no modulation).
        gamma_full = torch.zeros_like(hidden_states)
        gamma_full[:, :N, :] = gamma
        beta_full = torch.zeros_like(hidden_states)
        beta_full[:, :N, :] = beta

        return (1 + gamma_full) * hidden_states + beta_full
