from torch import nn
from utils.utils import *
from src.models.ssn_heteroconv import SSNHeteroConv
from torch_geometric.nn.conv.gat_conv import GATConv


class FeatureProjector(nn.Module):
    """Propagate node features down the dimension hierarchy, from ``max_dim`` to 0.

    For each dimension ``d`` in ``max_dim, max_dim - 1, ..., 1`` the module owns:

    * one ``SSNHeteroConv`` wrapping a single ``ConvLayer`` for the edge type
      ``(str(d), "b_a", str(d - 1))``;
    * one ``LayerNorm(hidden_dim)`` applied to that conv's output;
    * one ``Linear(2 * hidden_dim, hidden_dim)`` projector that fuses the
      normalised message with the current features of dimension ``d - 1``.

    Index ``i`` of ``conv_layers`` / ``normalizers`` / ``projectors`` handles
    the step from dimension ``max_dim - i`` to ``max_dim - i - 1``.

    Args:
        hidden_dim: Output width of every conv layer and projector.
        max_dim: Highest dimension; dimension keys are the strings
            ``"0"`` .. ``str(max_dim)``.
        accelerator: Device passed to ``.to()`` for every submodule.
        ConvLayer: Convolution class (or factory) called as
            ``ConvLayer(in_channels=..., out_channels=hidden_dim, **conv_kwargs)``.
        conv_kwargs: Extra keyword arguments forwarded to every ``ConvLayer``.
    """

    def __init__(
        self,
        hidden_dim: int,
        max_dim: int,
        accelerator: str,
        ConvLayer: nn.Module,
        conv_kwargs: dict,
    ):
        super().__init__()

        self.max_dim = max_dim

        # One projector / normalizer per downward step (index i: max_dim - i -> max_dim - i - 1).
        self.projectors = nn.ModuleList(
            [
                nn.Linear(2 * hidden_dim, hidden_dim).to(accelerator)
                for _ in range(max_dim)
            ]
        )

        self.normalizers = nn.ModuleList(  # TODO change normalizers
            [nn.LayerNorm(hidden_dim).to(accelerator) for _ in range(max_dim)]
        )

        # GATConv gets a (src, dst) pair of -1s, other layers a single -1.
        # NOTE(review): presumably this relies on PyG's lazy in_channels=-1 and on
        # GATConv using separate source/target projections when given a tuple —
        # confirm against the PyG version in use.
        in_channels = (-1, -1) if self._needs_pair_in_channels(ConvLayer) else -1

        # One conv per downward step, built from the highest dimension to 1 so that
        # conv_layers[i] maps dimension max_dim - i onto max_dim - i - 1.
        self.conv_layers = nn.ModuleList(
            [
                SSNHeteroConv(
                    {
                        (f"{d}", "b_a", f"{d - 1}"): ConvLayer(
                            in_channels=in_channels,
                            out_channels=hidden_dim,
                            **conv_kwargs,
                        )
                    },
                    unique=True,
                ).to(accelerator)
                for d in range(max_dim, 0, -1)
            ]
        )

    @staticmethod
    def _needs_pair_in_channels(conv_layer) -> bool:
        """Return True when ``conv_layer`` produces ``GATConv`` layers.

        Classes are checked with ``issubclass``, so no throwaway layer is
        allocated and constructors with required arguments do not break the
        check. Non-class factories (e.g. ``functools.partial``) fall back to
        probing one small instance, exactly as before.
        """
        if isinstance(conv_layer, type):
            return issubclass(conv_layer, GATConv)
        return isinstance(conv_layer(1, 1), GATConv)

    def forward(self, x_dict: dict, edge_index_dict: dict):
        """Propagates features iteratively from higher dimensions to lower dimensions.

        Step ``i`` runs ``conv_layers[i]`` (dimension ``max_dim - i`` ->
        ``max_dim - i - 1``), normalises the result, concatenates it with the
        current features of the lower dimension and projects back to
        ``hidden_dim``. Steps run in order, so each step reads the features
        already rewritten by the previous step.

        The projector expects ``2 * hidden_dim`` inputs and the normalised
        message is ``hidden_dim`` wide, so every ``x_dict[str(d)]`` for
        ``d < max_dim`` must already be ``hidden_dim`` wide.

        Args:
            x_dict: Mapping from dimension key (``"0"`` .. ``str(max_dim)``)
                to node features. Updated IN PLACE.
            edge_index_dict: Edge indices forwarded unchanged to each conv.

        Returns:
            The same ``x_dict`` object, with keys ``"0"`` .. ``str(max_dim - 1)``
            replaced by their projected features.
        """
        for i in range(self.max_dim):
            low_dim = str(self.max_dim - 1 - i)
            # The conv output goes straight into LayerNorm/torch.cat, so with
            # unique=True SSNHeteroConv is expected to return a single tensor.
            message = self.normalizers[i](self.conv_layers[i](x_dict, edge_index_dict))
            fused = torch.cat([message, x_dict[low_dim]], dim=1)
            x_dict[low_dim] = self.projectors[i](fused)
        return x_dict
