from typing import Dict, List, Optional

import torch
import torch.nn as nn

from schnetpack.model import AtomisticModel

__all__ = ["EnVariationalDiffusion"]


class EnVariationalDiffusion(AtomisticModel):
    """
    E(n) Variational Diffusion based on https://proceedings.mlr.press/v162/hoogeboom22a.html

    The model stores a representation network together with optional lists of
    input and output modules. The forward pass is not implemented yet; calling
    the model raises ``NotImplementedError``.
    """

    def __init__(
        self,
        representation: nn.Module,
        input_modules: Optional[List[nn.Module]] = None,
        output_modules: Optional[List[nn.Module]] = None,
        input_dtype_str: str = "float32",
    ):
        """
        Args:
            representation: module stored as ``self.representation``.
            input_modules: modules registered as ``self.input_modules``;
                ``None`` results in an empty module list.
            output_modules: modules registered as ``self.output_modules``;
                ``None`` results in an empty module list.
            input_dtype_str: forwarded unchanged to ``AtomisticModel.__init__``.
        """
        super().__init__(input_dtype_str=input_dtype_str)
        self.representation = representation
        # nn.ModuleList accepts None and then simply starts out empty, so no
        # explicit None check is needed here.
        self.input_modules = nn.ModuleList(input_modules)
        self.output_modules = nn.ModuleList(output_modules)

        # Inherited from AtomisticModel and invoked only after all submodules
        # are registered.
        # NOTE(review): presumably these inspect the registered modules for
        # required derivatives and model outputs -- confirm against the base
        # class.
        self.collect_derivatives()
        self.collect_outputs()

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the model (not implemented yet).

        Args:
            inputs: dictionary mapping property names to tensors.

        Raises:
            NotImplementedError: always, until the forward pass is implemented.
        """
        # Fail loudly instead of silently returning None, which would violate
        # the declared Dict return type and only surface later as an obscure
        # error in the caller.
        raise NotImplementedError(
            f"{type(self).__name__}.forward is not implemented yet."
        )
