from collections import OrderedDict

import torch


class TensorDict(OrderedDict):
    """Container mainly used for dicts of torch tensors. Extends OrderedDict with pytorch functionality.

    Any attribute that is not found through normal lookup but exists on
    ``torch.Tensor`` is turned into a callable that applies that tensor
    method to every value and returns the results as a new ``TensorDict``.
    """

    def concat(self, other):
        """Return a new TensorDict holding the entries of ``self`` followed by ``other``.

        Values are shared by reference, not copied. When a key occurs in
        both, the value from ``other`` wins while the key keeps the position
        it has in ``self``.

        Args:
            other: Mapping (or anything accepted by ``dict.update``) whose
                entries are merged in. Keys may be of any hashable type.

        Returns:
            A new ``TensorDict``; neither input is modified.
        """
        # update() is used instead of ``TensorDict(self, **other)`` because
        # keyword unpacking raises TypeError for non-string keys.
        merged = TensorDict(self)
        merged.update(other)
        return merged

    def copy(self):
        """Return a shallow copy: a new TensorDict sharing the same value objects."""
        return TensorDict(self)

    def __getattr__(self, name):
        """Resolve ``name`` as a ``torch.Tensor`` method applied to every value.

        Only invoked when normal attribute lookup fails.

        Args:
            name: Name of the attribute being looked up.

        Returns:
            A function ``apply_attr(*args, **kwargs)`` that calls
            ``getattr(value, name)(*args, **kwargs)`` on each value that has
            such an attribute, passes all other values through unchanged, and
            returns the results as a new ``TensorDict``.

        Raises:
            AttributeError: If ``torch.Tensor`` has no attribute ``name``.
        """
        if not hasattr(torch.Tensor, name):
            raise AttributeError(
                '\'TensorDict\' object has no attribute \'{}\''.format(name))

        def apply_attr(*args, **kwargs):
            # Values lacking the attribute (e.g. non-tensor metadata) are
            # kept as they are rather than raising.
            return TensorDict({
                n: getattr(e, name)(*args, **kwargs) if hasattr(e, name) else e
                for n, e in self.items()
            })

        return apply_attr

    def attribute(self, attr: str, *args):
        """Return a TensorDict holding ``getattr(value, attr, *args)`` for every value.

        Unlike the dynamic method dispatch of ``__getattr__``, the attribute
        is read but not called.

        Args:
            attr: Name of the attribute to read from each value.
            *args: Optional default forwarded to ``getattr``; it is used for
                values that do not have the attribute.

        Returns:
            A new ``TensorDict`` with the same keys.
        """
        return TensorDict({n: getattr(e, attr, *args) for n, e in self.items()})

    def apply(self, fn, *args, **kwargs):
        """Return a TensorDict holding ``fn(value, *args, **kwargs)`` for every value.

        Args:
            fn: Callable invoked once per value, with the value as its first
                positional argument.
            *args: Extra positional arguments forwarded to ``fn``.
            **kwargs: Extra keyword arguments forwarded to ``fn``.

        Returns:
            A new ``TensorDict`` with the same keys.
        """
        return TensorDict({n: fn(e, *args, **kwargs) for n, e in self.items()})

    @staticmethod
    def _iterable(a):
        """Return True if ``a`` is a ``TensorDict`` or a ``list``."""
        return isinstance(a, (TensorDict, list))
