
import torch

import collections
from collections import OrderedDict

class GetAttr:

    "Inherit from this to have all attr accesses in `self._xtra` passed down to `self.default`"
    # Name of the instance attribute that failed lookups are delegated to; subclasses may override.
    _default='default'

    def _component_attr_filter(self, k):
        "Return True if attribute name `k` should be delegated to the default component."
        # Dunders, the `_xtra` whitelist itself and the default attribute are never delegated;
        # delegating them would recurse back into `__getattr__`.
        if k.startswith('__') or k in ('_xtra', self._default): return False
        xtra = getattr(self, '_xtra', None)
        # With no `_xtra` whitelist, every remaining name is delegated.
        return xtra is None or k in xtra

    def _dir(self):
        "Names on the default component that are delegated (used to extend `dir(self)`)."
        return [k for k in dir(getattr(self, self._default)) if self._component_attr_filter(k)]

    def __getattr__(self, k):
        "Called only when normal lookup fails: forward `k` to the default component if allowed."
        if self._component_attr_filter(k):
            attr = getattr(self, self._default, None)
            if attr is not None: return getattr(attr, k)
        # NOTE(review): the original deliberately disabled `raise AttributeError(k)`, so
        # unresolved names evaluate to None (and `hasattr` is True for them). Behavior is
        # preserved here; confirm whether callers rely on it before restoring the raise.
        return None

    def __dir__(self):
        "Standard attributes plus the delegated names from the default component."
        # Inlined: the `custom_dir` helper previously called here is not defined or imported.
        return object.__dir__(self) + list(self._dir())

    def __setstate__(self, data):
        "Restore pickled state by writing straight into `__dict__`."
        self.__dict__.update(data)



def default_device(use_cuda=True):
    """Return the default torch device (nothing global is set).

    `use_cuda`: True or None - the current CUDA device if CUDA is available, otherwise CPU;
    False - always CPU. Never raises when CUDA is unavailable; it falls back to CPU.
    """
    # None previously fell through as falsy and always gave CPU, contradicting the
    # documented "CUDA if available" meaning; treat it like True.
    if use_cuda is None:
        use_cuda = True
    if use_cuda and torch.cuda.is_available():
        return torch.device('cuda', torch.cuda.current_device())
    return torch.device('cpu')


def to_device(b, device=None, non_blocking=False):
    """Recursively put `b` on `device`.

    `b` may be a tensor (or any object with a callable `.to`) or an arbitrarily nested
    dict / list / tuple / namedtuple of such objects. Leaves without a callable `.to`
    (ints, strings, None, ...) are returned unchanged.
    `device`: target device; defaults to `default_device(use_cuda=True)`.
    `non_blocking`: forwarded to every `.to` call, including nested ones.
    Dicts (and dict subclasses) come back as plain `dict`; lists/tuples keep their type.
    """
    if device is None:
        device = default_device(use_cuda=True)

    def _move(o):
        "Move one node of the nested structure, recursing into containers."
        if isinstance(o, dict):
            return {key: _move(val) for key, val in o.items()}
        if isinstance(o, tuple) and hasattr(o, '_fields'):
            # namedtuple constructors take the fields positionally, not as one iterable
            return type(o)(*(_move(x) for x in o))
        if isinstance(o, (list, tuple)):
            return type(o)(_move(x) for x in o)
        to = getattr(o, 'to', None)
        if callable(to):
            return to(device, non_blocking=non_blocking)
        return o

    return _move(b)


