from typing import Any, Dict
from functools import partial

import torch
import torch.nn.functional as activations
import torchmetrics

class Registry:
    """Name-to-object lookup table with optional deferred construction.

    Two registration forms are supported:

    * ``register(name, obj)`` stores ``obj`` directly.
    * ``@register(name)`` (i.e. ``register(name)`` used as a decorator)
      stores the decorated callable as a zero-argument factory. The factory
      is called the first time ``get(name)`` is requested and its return
      value is cached for all later lookups.

    The most recent registration for a name always wins, regardless of which
    form was used for the earlier or the later registration.
    """

    def __init__(self) -> None:
        # Objects that are ready to be returned by ``get``.
        self._registry: Dict[str, Any] = {}
        # Zero-argument factories that have not necessarily been called yet.
        self._pending_registration: Dict[str, Any] = {}

    def register(self, name: str, obj: Any = None):
        """Register an object (function, class) or defer initialization.

        Args:
            name: Key under which the entry is stored.
            obj: Object to store directly. When omitted (or ``None``), a
                decorator is returned instead; ``None`` itself therefore
                cannot be registered directly.

        Returns:
            ``None`` for a direct registration. For a deferred registration,
            a decorator that records the decorated callable as the factory
            for ``name`` and returns that callable unchanged.
        """
        if obj is None:
            # Defer the registration (lazy initialization)
            def wrapper(fn):
                self._pending_registration[name] = fn
                # Drop any object already stored for this name; otherwise
                # ``get`` would keep returning it and ignore the new factory.
                self._registry.pop(name, None)
                return fn
            return wrapper

        self._registry[name] = obj
        # A direct registration supersedes any factory left over for the name.
        self._pending_registration.pop(name, None)

    def get(self, name: str) -> Any:
        """Retrieve an object by its name, lazy-load if necessary.

        Args:
            name: Key passed to ``register``.

        Returns:
            The stored object, or the cached result of calling the deferred
            factory registered for ``name``.

        Raises:
            KeyError: If ``name`` has neither a stored object nor a pending
                factory.
        """
        if name not in self._registry:
            try:
                factory = self._pending_registration[name]
            except KeyError:
                raise KeyError(
                    f"'{name}' is not registered or pending registration."
                ) from None
            # Called outside the ``try`` so that a KeyError raised by the
            # factory itself propagates unchanged.
            self._registry[name] = factory()
        return self._registry[name]

    def __contains__(self, name) -> bool:
        """Return True if ``name`` is stored directly or has a pending factory."""
        return name in self._registry or name in self._pending_registration

"""
ACTIVATIONS
===========
"""
# Activation functions from torch.nn.functional (imported as `activations`).
# Each registry key is identical to the attribute name it is looked up under.
ACT_REGISTRY = Registry()
for _act_name in (
    'relu',
    'sigmoid',
    'tanh',
    'softmax',
    'leaky_relu',
    'elu',
    'gelu',
    'silu',
):
    ACT_REGISTRY.register(_act_name, getattr(activations, _act_name))
# Keep the loop variable out of the module namespace.
del _act_name

"""
LOADERS
=======
"""
# Registry of dataset loader objects, keyed by a short dataset name.
# Each entry is imported from its own module under the `loaders` package and
# registered directly (not deferred) right after its import.
# NOTE(review): the registry is created before the loader imports run and the
# imports are interleaved with registration; keep this order unless it is
# confirmed that no `loaders.*` module imports this module back — TODO confirm.
LOADER_REGISTRY = Registry()
from loaders.kchains import kchain_loaders
LOADER_REGISTRY.register('kchain', kchain_loaders)
from loaders.h2o_forces import h2o_loaders
LOADER_REGISTRY.register('h2o', h2o_loaders)
from loaders.box_forces import box_loaders
LOADER_REGISTRY.register('box', box_loaders)
from loaders.polystyrene import polystyrene_loaders
LOADER_REGISTRY.register('polystyrene', polystyrene_loaders)
from loaders.qm9 import qm9_loaders
LOADER_REGISTRY.register('qm9', qm9_loaders)
# The 'qm7x' key maps to the object named `qm7_loaders` (no trailing "x").
from loaders.qm7x import qm7_loaders
LOADER_REGISTRY.register('qm7x', qm7_loaders)
from loaders.md17 import md17_loaders
LOADER_REGISTRY.register('md17', md17_loaders)

"""
MODEL
=====
"""
# Registry instance for the MODEL section; it holds no entries when created.
MODEL_REGISTRY = Registry()
