import torch
import torch.nn as nn
from utils import create_state
import math
from torch_geometric.data import Data
from torch_scatter import scatter_add
from torch_geometric.utils import add_self_loops


class PE_MPNN(nn.Module):
    """Message-passing network over joint node-feature / positional-encoding states.

    Node features ``data.x`` and positional encodings ``data.p`` are each projected
    by a linear layer, combined into one per-node state by ``create_state``
    (according to ``state_type``), passed through a square linear embedding, then
    through ``num_layers`` message-passing layers, sum-pooled per graph with
    ``scatter_add`` and mapped to ``num_out`` outputs by a two-layer MLP readout.
    """

    def __init__(self, feat_in: int, pos_in: int, feat_hidden: int, pos_hidden: int, num_out: int, num_layers: int,
                 state_type: str, layer_type: str, ent_deg: int, red: int) -> None:
        """Build the embedding, message-passing and readout modules.

        Args:
            feat_in: Input width of ``data.x``.
            pos_in: Input width of ``data.p``.
            feat_hidden: Output width of the node-feature projection.
            pos_hidden: Output width of the positional projection.
            num_out: Output width of the readout MLP.
            num_layers: Number of message-passing layers.
            state_type: ``'concat'`` (state width ``feat_hidden + pos_hidden``) or
                ``'tensor'`` (state width ``feat_hidden * pos_hidden``).
            layer_type: ``'gin'``, ``'gcn'`` or ``'sage'``. ``'transformer'`` passes
                the assert below but ``get_layer`` has no implementation for it and
                raises ``ValueError`` whenever ``num_layers > 0``.
            ent_deg: Forwarded to every layer as its last constructor argument.
                When ``> 0`` the model runs in "reduced" mode: each node state is
                reshaped to a square ``(reduced_size, reduced_size)`` matrix before
                message passing and flattened back before the readout.
            red: Not used by this class.

        Raises:
            ValueError: In reduced mode, if the state width is not a perfect square
                (the reshape in ``forward`` could never succeed).
        """
        super().__init__()
        assert state_type in ['concat', 'tensor']
        assert layer_type in ['gin', 'gcn', 'sage', 'transformer']

        self.feat_hidden, self.pos_hidden = feat_hidden, pos_hidden
        self.reduced = ent_deg > 0
        self.layer_type = layer_type
        self.state_type = state_type
        self.num_layers = num_layers

        # embeddings
        self.pre_embed_h, self.pre_embed_p = nn.Linear(feat_in, feat_hidden), nn.Linear(pos_in, pos_hidden)

        embed_dim = feat_hidden + pos_hidden if state_type == 'concat' else feat_hidden * pos_hidden
        self.embed = nn.Linear(embed_dim, embed_dim)

        # Side length of the square matrix each node state is reshaped into in
        # reduced mode. It is derived from the actual state width rather than from
        # feat_hidden alone, so it is also correct when feat_hidden != pos_hidden.
        # Whenever the old sqrt(2 * feat_hidden) / feat_hidden formulas produced a
        # working reshape, this yields the identical value.
        reduced_size = math.isqrt(embed_dim)
        if self.reduced and reduced_size * reduced_size != embed_dim:
            raise ValueError(
                f'Reduced mode (ent_deg > 0) requires the state width to be a perfect square, got {embed_dim} '
                f'(state_type={state_type!r}, feat_hidden={feat_hidden}, pos_hidden={pos_hidden}).'
            )
        self.reduced_size = reduced_size

        # read out (in reduced mode the flattened state has reduced_size**2 == embed_dim entries)
        self.readout = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.ReLU(),
            nn.Linear(embed_dim // 2, num_out)
        )

        # message passing: in reduced mode every layer works on the square reduced state
        layer_h, layer_p = (reduced_size, reduced_size) if self.reduced else (feat_hidden, pos_hidden)
        self.layers = nn.ModuleList(
            [self.get_layer(layer_type, layer_h, layer_p, layer_h, layer_p, state_type, ent_deg)
             for _ in range(num_layers)]
        )

    def forward(self, data: Data) -> torch.Tensor:
        """Run the model on a (batched) graph.

        Reads ``data.x``, ``data.edge_index``, ``data.p`` and ``data.batch``;
        ``data.edge_attr`` is not used. When ``data.batch`` is ``None`` (a single
        ``Data`` object not collated by a loader) all nodes are treated as one graph.

        Returns:
            The readout output passed through ``torch.squeeze``.
        """
        x, edge_index, p, batch = data.x, data.edge_index, data.p, data.batch

        if self.layer_type == 'gcn':
            # Pass num_nodes explicitly: otherwise it is inferred from edge_index and
            # trailing isolated nodes would not receive a self-loop.
            edge_index = add_self_loops(edge_index, num_nodes=x.size(0))[0]

        x, p = self.pre_embed_h(x.float()), self.pre_embed_p(p.float())
        state = create_state(self.state_type, x, p)
        state = self.embed(state.float())

        if self.reduced:
            state = torch.reshape(state, (state.shape[0], self.reduced_size, self.reduced_size))

        for layer in self.layers:
            state = layer(state, edge_index)

        if batch is None:
            # No batch vector: assign every node to graph 0.
            batch = state.new_zeros(state.shape[0], dtype=torch.long)
        # Sum node states per graph (pooling over dim 0).
        state = scatter_add(state, batch, 0)

        if self.reduced:
            state = torch.flatten(state, start_dim=1)

        out = self.readout(state)
        # NOTE(review): torch.squeeze drops every size-1 dim, so a batch of one graph
        # yields shape (num_out,) -- or a 0-d tensor when num_out == 1 -- instead of
        # (1, num_out). Kept as-is for compatibility with existing callers.
        return torch.squeeze(out)

    def get_layer(self, layer_type, feat_hidden, pos_hidden, feat_out, pos_out, state_type, state_kwargs) -> nn.Module:
        """Instantiate one message-passing layer of the requested type.

        The layer classes are imported lazily so only the selected implementation
        needs to be importable. ``feat_out`` / ``pos_out`` are only consumed by the
        GCN layer; ``state_kwargs`` (``ent_deg`` from ``__init__``) is forwarded to
        every layer as its last positional argument.

        Raises:
            ValueError: If ``layer_type`` is not ``'gin'``, ``'gcn'`` or ``'sage'``.
        """
        if layer_type == 'gin':
            from layers.gin_layer import GINLayer
            return GINLayer(feat_hidden, pos_hidden, state_type, state_kwargs)
        if layer_type == 'gcn':
            from layers.gcn_layer import GCNLayer
            return GCNLayer(feat_hidden, pos_hidden, feat_out, pos_out, state_type, state_kwargs)
        if layer_type == 'sage':
            from layers.sage_conv_layer import SageConvLayer
            return SageConvLayer(feat_hidden, pos_hidden, state_type, state_kwargs)
        raise ValueError(f'Do not recognize layer type {layer_type}.')
