import torch
import geoopt
from torch import nn
from torch.nn import functional as F


class EncoderLayer(nn.Module):
    """Linear head producing stacked (alpha, beta) pairs plus a separate gamma.

    One ``nn.Linear`` maps ``feature_dim`` inputs to ``3 * latent_dim``
    outputs. Those outputs are cut into three equal, contiguous chunks along
    the last axis, in the order alpha, beta, gamma.
    """

    def __init__(self, args, feature_dim) -> None:
        super().__init__()

        # Only ``args.latent_dim`` is read from the config object.
        self.latent_dim = args.latent_dim
        self.feature_dim = feature_dim

        # Attribute name ``variational`` is kept stable: it determines the
        # state_dict keys (``variational.weight`` / ``variational.bias``).
        out_features = 3 * self.latent_dim
        self.variational = nn.Linear(self.feature_dim, out_features)

    def forward(self, feature):
        """Project ``feature`` and split the result into (alpha, beta) and gamma.

        Args:
            feature: tensor whose last dimension equals ``feature_dim``
                (required by the ``nn.Linear`` projection).

        Returns:
            A tuple ``(params, gamma)``. ``params`` has shape
            ``(..., latent_dim, 2)``, with alpha at index 0 and beta at
            index 1 of the new last axis. ``gamma`` has shape
            ``(..., latent_dim)``.
        """
        projected = self.variational(feature)
        chunk_sizes = [self.latent_dim] * 3
        alpha, beta, gamma = torch.split(projected, chunk_sizes, dim=-1)
        params = torch.stack((alpha, beta), dim=-1)
        return params, gamma


class VanillaEncoderLayer(nn.Module):
    """Module that simply delegates to a single ``EncoderLayer``."""

    def __init__(self, args, feature_dim) -> None:
        super().__init__()
        # Stored under the attribute name ``encoder``; renaming it would
        # change the submodule prefix in state_dict keys.
        self.encoder = EncoderLayer(args, feature_dim)

    def forward(self, feature):
        """Run the wrapped ``EncoderLayer`` on ``feature`` and return its output unchanged."""
        encoded = self.encoder(feature)
        return encoded


class VanillaDecoderLayer(nn.Module):
    """Parameter-free layer that merges the last two axes of its input.

    No weights are created, and ``args`` is accepted only for interface
    compatibility with the other layers.
    """

    def __init__(self, args) -> None:
        super().__init__()

    def forward(self, z):
        """Flatten the last two dimensions of ``z`` into one.

        Args:
            z: tensor of shape ``(..., a, b)``.

        Returns:
            Tensor of shape ``(..., a * b)``, in row-major order, so
            ``out[..., i * b + j] == z[..., i, j]``. A 0-D or 1-D input is
            flattened to 1-D, as before.
        """
        if z.dim() < 2:
            # Preserve the original behaviour for inputs without two trailing axes.
            return z.reshape(-1)
        # flatten() computes the merged size explicitly, unlike
        # reshape(..., -1). That reshape raises on zero-element tensors
        # (e.g. an empty batch of shape (0, a, b)) because -1 is ambiguous.
        return z.flatten(start_dim=-2)

