from spaghettini import quick_register

from torch.nn import Module


@quick_register
class ModuleDict(Module):
    """Dictionary-like container that exposes named models as attributes.

    Every keyword argument passed to the constructor is stored on the
    instance with ``setattr`` under its keyword name, and the names are
    remembered, in insertion order, in ``self.model_keys``. Values that are
    ``torch.nn.Module`` instances are therefore registered as submodules by
    ``Module.__setattr__``.

    Supports ``container[name]``, ``container[name] = module``,
    ``name in container``, ``len(container)`` and ``container.items()``.
    """

    def __init__(self, **kwargs):
        """Store each ``name=model`` keyword argument as an attribute."""
        super().__init__()
        # Set the models provided as attributes.
        for k, v in kwargs.items():
            setattr(self, k, v)
        # Insertion-ordered list of the names managed by this container.
        self.model_keys = list(kwargs.keys())

    def items(self):
        """Return a generator of ``(name, model)`` pairs in key order."""
        return ((k, getattr(self, k)) for k in self.model_keys)

    def __getitem__(self, item):
        """Return the attribute stored under ``item`` (looked up via ``getattr``)."""
        return getattr(self, item)

    def __len__(self):
        """Return the number of tracked model names."""
        return len(self.model_keys)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``; ``value`` must be a ``Module``.

        Assigning to a key that is already tracked replaces the stored module
        and keeps the key at its original position.
        """
        assert isinstance(value, Module), "value must be a torch.nn.Module"
        setattr(self, key, value)
        # Record the key only once: appending on every assignment would leave
        # duplicates in `model_keys`, inflating `len()` and making `items()`
        # yield the same entry repeatedly after an overwrite.
        if key not in self.model_keys:
            self.model_keys.append(key)

    def __contains__(self, key):
        """Return True if ``key`` is one of the tracked model names."""
        return key in self.model_keys
