from collections import OrderedDict, defaultdict

import numpy as np
import torch

def cat_dict_numpy(dicts):
    """Concatenate a sequence of same-keyed dicts of arrays along axis 0.

    Values are joined along their *existing* first axis (``np.concatenate``),
    so every value must already be at least 1-D. All values for a given key
    must also agree on their trailing dimensions. To stack values along a new
    leading axis instead, use ``cat_dict_numpy_rollout``.

    Args:
        dicts: non-empty sequence of dicts that all share the same key set.

    Returns:
        OrderedDict mapping each key, in the first dict's key order, to the
        concatenated array.

    Raises:
        AssertionError: if an element is not a dict or its keys differ from
            those of the first element.
    """
    assert isinstance(dicts[0], dict), "expected a sequence of dicts"
    keys = dicts[0].keys()

    for d in dicts:
        # keys-view equality is set-like: same key set, order may differ.
        assert isinstance(d, dict) and d.keys() == keys, (
            "all dicts must be dicts with the same keys"
        )

    # NOTE: concatenates along the existing axis 0 rather than adding a new
    # leading axis; the stacking variant is cat_dict_numpy_rollout.
    return OrderedDict(
        (k, np.concatenate([d[k] for d in dicts], axis=0)) for k in keys
    )

def cat_dict_numpy_rollout(dicts):
    """Stack a sequence of same-keyed dicts along a new leading axis.

    Each value gets a fresh axis 0 (``np.expand_dims``) before the values are
    concatenated. For a key whose values all share one shape, the result
    therefore has shape ``(len(dicts), *value.shape)``. Scalars become a 1-D
    array of length ``len(dicts)``.

    Args:
        dicts: non-empty sequence of dicts that all share the same key set.

    Returns:
        OrderedDict mapping each key, in the first dict's key order, to the
        stacked array.

    Raises:
        AssertionError: if an element is not a dict or its keys differ from
            those of the first element.
    """
    first = dicts[0]
    assert isinstance(first, dict)
    ref_keys = first.keys()

    for entry in dicts:
        assert isinstance(entry, dict) and entry.keys() == ref_keys

    stacked = OrderedDict()
    for key in ref_keys:
        # Prepend a length-1 axis so concatenation along axis 0 stacks entries.
        expanded = [np.expand_dims(entry[key], axis=0) for entry in dicts]
        stacked[key] = np.concatenate(expanded, axis=0)
    return stacked

def cat_dict_tensor(dicts):
    """Concatenate a sequence of same-keyed OrderedDicts of tensors along dim 0.

    Args:
        dicts: non-empty sequence of ``OrderedDict`` instances that all share
            the same key set. Plain ``dict`` inputs are rejected.

    Returns:
        OrderedDict mapping each key, in the first element's key order, to
        ``torch.cat`` of that key's tensors along dim 0.

    Raises:
        AssertionError: if an element is not an ``OrderedDict`` or its keys
            differ from those of the first element.
    """
    assert isinstance(dicts[0], OrderedDict), "expected a sequence of OrderedDicts"
    keys = dicts[0].keys()

    for d in dicts:
        # keys-view equality is set-like: same key set, order may differ.
        assert isinstance(d, OrderedDict) and d.keys() == keys, (
            "all elements must be OrderedDicts with the same keys"
        )

    return OrderedDict((k, torch.cat([d[k] for d in dicts], dim=0)) for k in keys)
