
import collections
import numpy as np
import torch
import torch.nn as nn


def separate_no_decay(module,
                      name_blacklist=None,
                      blacklist_weight_modules=(
                              nn.LayerNorm,
                              nn.Embedding,
                              nn.BatchNorm2d,
                              nn.GroupNorm)):
    """Split ``module``'s parameters into weight-decay / no-weight-decay lists.

    Each parameter is classified by the submodule that directly owns it:

    * names ending in ``"bias"`` -> no decay;
    * names ending in ``"weight"`` on a module that is an instance of one of
      ``blacklist_weight_modules`` -> no decay;
    * any other ``"weight"`` -> decay;
    * everything else (e.g. a bare ``nn.Parameter`` attribute) -> no decay.

    Args:
        module: the ``nn.Module`` whose parameters are split.
        name_blacklist: optional list of substrings. Any parameter whose
            fully-qualified name (as reported by ``module.named_parameters()``,
            e.g. ``"encoder.0.weight"``) contains one of them is left out of
            both returned lists. ``None`` means no exclusions.
        blacklist_weight_modules: module types whose ``weight`` is never decayed.

    Returns:
        ``(decay, no_decay)``: two lists of parameters, each ordered by the
        sorted fully-qualified parameter name.

    Raises:
        AssertionError: if a parameter is put in both groups or in neither.
    """
    if name_blacklist is None:
        name_blacklist = []

    def _is_blacklisted(full_name):
        # Test the *full* parameter name so this filter agrees exactly with the
        # one applied to ``param_dict`` below. Testing only the module name let
        # e.g. a top-level ``nn.Parameter`` named in the blacklist slip into a
        # group and then fail the ``param_dict`` lookup with a KeyError.
        return any(name in full_name for name in name_blacklist)

    decay = set()
    no_decay = set()
    seen = set()
    for mn, m in module.named_modules():
        # recurse=False: only parameters owned directly by ``m``; nested ones
        # are handled when named_modules() reaches their owner.
        for pn, p in m.named_parameters(recurse=False):
            # A parameter shared by several modules (e.g. tied weights) is
            # listed only once by module.named_parameters(), under its first
            # name in this same traversal order; classify it under that name
            # only, otherwise the second alias is missing from ``param_dict``.
            if id(p) in seen:
                continue
            seen.add(id(p))

            fpn = f"{mn}.{pn}" if mn else pn  # full param name
            if _is_blacklisted(fpn):
                continue

            if pn.endswith("bias"):
                # all biases will not be decayed
                no_decay.add(fpn)
            elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                # weights of blacklist modules will NOT be weight decayed
                no_decay.add(fpn)
            elif pn.endswith("weight"):
                decay.add(fpn)
            else:
                no_decay.add(fpn)

    # validate that we considered every (non-blacklisted) parameter
    param_dict = {pn: p for pn, p in module.named_parameters()
                  if not _is_blacklisted(pn)}
    inter_params = decay & no_decay
    union_params = decay | no_decay

    assert len(inter_params) == 0, (
            "parameters %s made it into both decay/no_decay sets!"
            % (str(inter_params),)
    )
    assert len(param_dict.keys() - union_params) == 0, (
            "parameters %s were not separated into either decay/no_decay set!"
            % (str(param_dict.keys() - union_params),)
    )

    decay = [param_dict[pn] for pn in sorted(decay)]
    no_decay = [param_dict[pn] for pn in sorted(no_decay)]

    return decay, no_decay


def recursive_dict_list_tuple_apply(x, type_func_dict):
    """Rebuild a nested dict/list/tuple, transforming every leaf by its type.

    Containers are walked recursively and rebuilt with the same kind:
    ``OrderedDict`` stays ``OrderedDict``, any other dict becomes a plain
    ``dict``, tuples stay tuples and lists stay lists. Each leaf is passed to
    the function of the first entry in ``type_func_dict`` (in iteration order)
    whose key type it is an instance of.

    Args:
        x: the nested structure (or a single leaf).
        type_func_dict: mapping ``{type: callable}``. It must not contain
            ``list``, ``tuple`` or ``dict``, because those are the container
            types handled here.

    Returns:
        The rebuilt structure with transformed leaves.

    Raises:
        NotImplementedError: if a leaf matches none of the types.
    """
    for container_type in (list, tuple, dict):
        assert (container_type not in type_func_dict)

    if isinstance(x, dict):
        if isinstance(x, collections.OrderedDict):
            rebuilt = collections.OrderedDict()
        else:
            rebuilt = dict()
        for key, value in x.items():
            rebuilt[key] = recursive_dict_list_tuple_apply(value, type_func_dict)
        return rebuilt

    if isinstance(x, (list, tuple)):
        children = [recursive_dict_list_tuple_apply(item, type_func_dict) for item in x]
        return tuple(children) if isinstance(x, tuple) else children

    for leaf_type, handler in type_func_dict.items():
        if isinstance(x, leaf_type):
            return handler(x)
    raise NotImplementedError(
        'Cannot handle data type %s' % str(type(x)))


def map_tensor(x, func):
    """Apply ``func`` to every ``torch.Tensor`` leaf of a nested dict/list/tuple.

    ``None`` leaves are returned unchanged; the container layout is preserved.
    """
    handlers = {
        torch.Tensor: func,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def map_ndarray(x, func):
    """Apply ``func`` to every ``np.ndarray`` leaf of a nested dict/list/tuple.

    ``None`` leaves are returned unchanged; the container layout is preserved.
    """
    handlers = {
        np.ndarray: func,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def map_tensor_ndarray(x, tensor_func, ndarray_func):
    """Map tensor leaves with ``tensor_func`` and array leaves with ``ndarray_func``.

    ``None`` leaves are returned unchanged; the container layout is preserved.
    """
    handlers = {
        torch.Tensor: tensor_func,
        np.ndarray: ndarray_func,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def clone(x):
    """Copy every tensor/array leaf of a nested dict/list/tuple.

    Tensors are copied with ``.clone()`` and arrays with ``.copy()``;
    ``None`` leaves are passed through.
    """
    def _copy_tensor(t):
        return t.clone()

    def _copy_array(a):
        return a.copy()

    def _keep(leaf):
        return leaf

    return recursive_dict_list_tuple_apply(
        x, {torch.Tensor: _copy_tensor, np.ndarray: _copy_array, type(None): _keep})


def detach(x):
    """Detach every tensor leaf of a nested dict/list/tuple from autograd.

    ``None`` leaves are passed through unchanged, matching the other nested
    helpers in this module (previously they raised ``NotImplementedError``).
    Any other leaf type still raises ``NotImplementedError``.
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.detach(),
            type(None): lambda x: x,
        }
    )


def to_batch(x):
    """Prepend a size-1 axis to every tensor/array leaf (``leaf[None, ...]``).

    ``None`` leaves are passed through.
    """
    def _prepend_axis(leaf):
        return leaf[None, ...]

    handlers = {
        torch.Tensor: _prepend_axis,
        np.ndarray: _prepend_axis,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def to_sequence(x):
    """Insert a size-1 axis at position 1 of every tensor/array leaf.

    Equivalent to ``leaf[:, None, ...]``; ``None`` leaves are passed through.
    """
    def _insert_second_axis(leaf):
        return leaf[:, None, ...]

    handlers = {
        torch.Tensor: _insert_second_axis,
        np.ndarray: _insert_second_axis,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def index_at_time(x, ind):
    """Index axis 1 of every tensor/array leaf with ``ind`` (``leaf[:, ind, ...]``).

    ``None`` leaves are passed through.
    """
    def _select(leaf):
        return leaf[:, ind, ...]

    handlers = {
        torch.Tensor: _select,
        np.ndarray: _select,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def unsqueeze(x, dim):
    """Insert a size-1 axis at ``dim`` in every tensor/array leaf.

    Tensors use ``Tensor.unsqueeze``, arrays use ``np.expand_dims``;
    ``None`` leaves are passed through.
    """
    def _unsqueeze_tensor(t):
        return t.unsqueeze(dim=dim)

    def _unsqueeze_array(a):
        return np.expand_dims(a, axis=dim)

    handlers = {
        torch.Tensor: _unsqueeze_tensor,
        np.ndarray: _unsqueeze_array,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def contiguous(x):
    """Make every tensor/array leaf contiguous in memory.

    Tensors use ``Tensor.contiguous()``, arrays use ``np.ascontiguousarray``;
    ``None`` leaves are passed through.
    """
    handlers = {
        torch.Tensor: lambda t: t.contiguous(),
        np.ndarray: lambda a: np.ascontiguousarray(a),
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def to_device(x, device):
    """Move every tensor leaf to ``device`` via ``Tensor.to``.

    ``None`` leaves are passed through; numpy arrays are not handled and
    raise ``NotImplementedError``.
    """
    def _move(t):
        return t.to(device)

    handlers = {
        torch.Tensor: _move,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def to_tensor(x):
    """Convert every numpy array leaf to a tensor with ``torch.from_numpy``.

    Existing tensors and ``None`` leaves are returned unchanged.
    """
    def _keep(leaf):
        return leaf

    handlers = {
        torch.Tensor: _keep,
        np.ndarray: lambda a: torch.from_numpy(a),
        type(None): _keep,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def to_numpy(x):
    """Convert every tensor leaf of a nested dict/list/tuple to a numpy array.

    Tensors are detached from autograd and moved to host memory before the
    conversion. Existing numpy arrays and ``None`` leaves pass through unchanged.
    """
    def f(tensor):
        # ``.cpu()`` returns the tensor itself when it is already in CPU memory,
        # so no device branch is needed. Unlike an ``is_cuda`` check, this also
        # covers other non-CPU devices, whose tensors ``.numpy()`` rejects.
        return tensor.detach().cpu().numpy()

    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: f,
            np.ndarray: lambda x: x,
            type(None): lambda x: x,
        }
    )


def to_list(x):
    """Convert every tensor/array leaf of a nested structure to nested Python lists.

    Tensors are detached and moved to host memory, then converted via numpy's
    ``tolist``. Arrays use ``ndarray.tolist``; ``None`` leaves pass through.
    """
    def f(tensor):
        # ``.cpu()`` is a no-op for CPU tensors and, unlike an ``is_cuda``
        # check, also handles tensors on other non-CPU devices.
        return tensor.detach().cpu().numpy().tolist()

    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: f,
            np.ndarray: lambda x: x.tolist(),
            type(None): lambda x: x,
        }
    )


def to_float(x):
    """Cast every tensor leaf with ``.float()`` and every array leaf to ``np.float32``.

    ``None`` leaves are passed through.
    """
    def _tensor_to_float(t):
        return t.float()

    def _array_to_float(a):
        return a.astype(np.float32)

    handlers = {
        torch.Tensor: _tensor_to_float,
        np.ndarray: _array_to_float,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def to_uint8(x):
    """Cast every tensor leaf with ``.byte()`` and every array leaf to ``np.uint8``.

    ``None`` leaves are passed through.
    """
    def _tensor_to_uint8(t):
        return t.byte()

    def _array_to_uint8(a):
        return a.astype(np.uint8)

    handlers = {
        torch.Tensor: _tensor_to_uint8,
        np.ndarray: _array_to_uint8,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def to_torch(x, device):
    """Convert arrays to tensors, cast them with ``.float()``, and move them to ``device``."""
    as_tensors = to_tensor(x)
    as_float = to_float(as_tensors)
    return to_device(as_float, device)


def to_one_hot_single(tensor, num_class):
    """Convert a tensor of class indices into one-hot encodings.

    Args:
        tensor: tensor of integer class indices, any shape.
        num_class: size of the new trailing one-hot dimension.

    Returns:
        A tensor of the default floating dtype with shape
        ``tensor.shape + (num_class,)``, on ``tensor.device``.
    """
    # Allocate directly on the target device instead of building on the CPU
    # and copying the result over.
    x = torch.zeros(tensor.size() + (num_class,), device=tensor.device)
    index = tensor.unsqueeze(-1)
    if not index.is_floating_point():
        # scatter_ requires int64 indices; widen other integer dtypes
        # (int32, uint8, ...). Float indices are still rejected by scatter_.
        index = index.long()
    x.scatter_(-1, index, 1)
    return x


def to_one_hot(tensor, num_class):
    """One-hot encode every tensor leaf of a nested structure with ``num_class`` classes."""
    def _encode(leaf):
        return to_one_hot_single(leaf, num_class)

    return map_tensor(tensor, func=_encode)


def flatten_single(x, begin_axis=1):
    """Keep dims ``[0, begin_axis)`` of tensor ``x`` and merge the rest into one dim."""
    target_shape = list(x.size()[:begin_axis])
    target_shape.append(-1)
    return x.reshape(*target_shape)


def flatten(x, begin_axis=1):
    """Flatten every tensor leaf of a nested structure from ``begin_axis`` onward.

    Each tensor keeps dims ``[0, begin_axis)`` and has the remaining dims merged
    (see ``flatten_single``). ``None`` leaves pass through unchanged, matching
    the other nested helpers in this module (previously they raised
    ``NotImplementedError``). Any other leaf type still raises
    ``NotImplementedError``.
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
            type(None): lambda x: x,
        }
    )


def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
    """Replace dims ``begin_axis..end_axis`` (inclusive) of ``x`` with ``target_dims``.

    Works for both tensors and numpy arrays via ``.reshape``. Axes must be
    non-negative, and ``target_dims`` must be a tuple or list (it may contain -1).
    """
    assert (begin_axis <= end_axis)
    assert (begin_axis >= 0)
    assert (end_axis < len(x.shape))
    assert (isinstance(target_dims, (tuple, list)))
    shape = x.shape
    head = [shape[i] for i in range(begin_axis)]
    tail = [shape[i] for i in range(end_axis + 1, len(shape))]
    return x.reshape(*(head + list(target_dims) + tail))


def reshape_dimensions(x, begin_axis, end_axis, target_dims):
    """Apply ``reshape_dimensions_single`` to every tensor/array leaf of a nested structure.

    ``None`` leaves are passed through.
    """
    def _reshape(leaf):
        return reshape_dimensions_single(
            leaf, begin_axis=begin_axis, end_axis=end_axis, target_dims=target_dims)

    handlers = {
        torch.Tensor: _reshape,
        np.ndarray: _reshape,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def join_dimensions(x, begin_axis, end_axis):
    """Merge dims ``begin_axis..end_axis`` (inclusive) into one in every tensor/array leaf.

    ``None`` leaves are passed through.
    """
    def _merge(leaf):
        return reshape_dimensions_single(
            leaf, begin_axis=begin_axis, end_axis=end_axis, target_dims=[-1])

    handlers = {
        torch.Tensor: _merge,
        np.ndarray: _merge,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def expand_at_single(x, size, dim):
    """Broadcast the size-1 dim ``dim`` of tensor ``x`` to ``size`` with ``Tensor.expand``.

    No data is copied. Negative ``dim`` is accepted, because it is used as a
    Python list index.
    """
    rank = x.ndimension()
    assert dim < rank
    assert x.shape[dim] == 1
    new_sizes = [-1 for _ in range(rank)]  # -1 keeps the existing size
    new_sizes[dim] = size
    return x.expand(*new_sizes)


def expand_at(x, size, dim):
    """Apply ``expand_at_single(leaf, size, dim)`` to every tensor leaf of a nested structure."""
    def _expand(leaf):
        return expand_at_single(leaf, size, dim)

    return map_tensor(x, _expand)


def unsqueeze_expand_at(x, size, dim):
    """Insert a new dim at ``dim`` in every leaf, then broadcast it to ``size``."""
    return expand_at(unsqueeze(x, dim), size, dim)


def repeat_by_expand_at(x, repeats, dim):
    """Repeat each slice along ``dim`` ``repeats`` times, consecutively.

    Works by inserting a new axis right after ``dim``, broadcasting it to
    ``repeats``, then merging ``dim`` and that new axis back into one.
    """
    new_axis = dim + 1
    broadcast = unsqueeze_expand_at(x, repeats, new_axis)
    return join_dimensions(broadcast, dim, new_axis)


def named_reduce_single(x, reduction, dim):
    """Reduce tensor ``x`` over ``dim`` with the reduction named by ``reduction``.

    ``"sum"``, ``"max"`` (values only) and ``"mean"`` collapse ``dim``;
    ``"flatten"`` merges ``dim`` and all following dims into one.
    """
    assert x.ndimension() > dim
    assert reduction in ["sum", "max", "mean", "flatten"]
    reducers = {
        "flatten": lambda t: flatten(t, begin_axis=dim),
        "max": lambda t: torch.max(t, dim=dim)[0],
        "sum": lambda t: torch.sum(t, dim=dim),
        "mean": lambda t: torch.mean(t, dim=dim),
    }
    return reducers[reduction](x)


def named_reduce(x, reduction, dim):
    """Apply ``named_reduce_single`` to every tensor leaf of a nested structure."""
    def _reduce(leaf):
        return named_reduce_single(leaf, reduction, dim)

    return map_tensor(x, func=_reduce)


def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
    """Pick one entry along ``target_dim`` for each position along ``source_dim``.

    ``indices`` is a 1-D tensor whose length equals ``x.shape[source_dim]``;
    position ``i`` along ``source_dim`` selects index ``indices[i]`` along
    ``target_dim``. The result has ``target_dim`` removed.
    """
    assert len(indices.shape) == 1
    assert x.shape[source_dim] == indices.shape[0]

    # View indices as singleton everywhere except along the source dim.
    view_shape = [1] * x.ndimension()
    view_shape[source_dim] = -1

    # Broadcast to x's shape, keeping the source dim as-is and making the
    # target dim a single entry so gather picks exactly one element there.
    broadcast_shape = list(x.shape)
    broadcast_shape[source_dim] = -1
    broadcast_shape[target_dim] = 1

    gather_index = indices.reshape(*view_shape).expand(*broadcast_shape)
    return x.gather(dim=target_dim, index=gather_index).squeeze(target_dim)


def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
    """Apply ``gather_along_dim_with_dim_single`` to every tensor leaf of a nested structure."""
    def _gather(leaf):
        return gather_along_dim_with_dim_single(leaf, target_dim, source_dim, indices)

    return map_tensor(x, _gather)


def gather_sequence_single(seq, indices):
    """For each row ``i`` along dim 0 of ``seq``, select index ``indices[i]`` along dim 1."""
    source_dim, target_dim = 0, 1
    return gather_along_dim_with_dim_single(
        seq, target_dim=target_dim, source_dim=source_dim, indices=indices)


def gather_sequence(seq, indices):
    """Nested-structure version of ``gather_sequence_single`` (dim 0 source, dim 1 target)."""
    source_dim, target_dim = 0, 1
    return gather_along_dim_with_dim(
        seq, target_dim=target_dim, source_dim=source_dim, indices=indices)


def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None):
    """Pad one array/tensor along its sequence axis.

    Args:
        seq: ``np.ndarray`` or ``torch.Tensor``. The sequence axis is axis 0,
            or axis 1 when ``batched`` is True.
        padding: pair ``(n_begin, n_end)`` giving how many steps to add before
            the first and after the last step.
        batched: if True, axis 0 is treated as a batch axis and padding is
            applied along axis 1.
        pad_same: if True, pad by repeating the first/last step; otherwise
            pad with the constant ``pad_values``.
        pad_values: float fill value; required when ``pad_same`` is False.

    Returns:
        The padded array/tensor, produced by ``np.concatenate`` / ``torch.cat``.
    """
    assert isinstance(seq, (np.ndarray, torch.Tensor))
    assert pad_same or pad_values is not None
    if pad_values is not None:
        assert isinstance(pad_values, float)
    is_numpy = isinstance(seq, np.ndarray)
    repeat_func = np.repeat if is_numpy else torch.repeat_interleave
    concat_func = np.concatenate if is_numpy else torch.cat
    ones_like_func = np.ones_like if is_numpy else torch.ones_like
    seq_dim = 1 if batched else 0

    def _edge_pad(index):
        # Take one step along the *sequence* axis, kept as a length-1 axis.
        # Plain ``seq[[index]]`` would index the batch axis when ``batched``
        # is True, giving wrong pad content (or a concat shape error).
        step = seq[:, [index]] if batched else seq[[index]]
        return step if pad_same else ones_like_func(step) * pad_values

    begin_pad = []
    end_pad = []
    if padding[0] > 0:
        begin_pad.append(repeat_func(_edge_pad(0), padding[0], seq_dim))
    if padding[1] > 0:
        end_pad.append(repeat_func(_edge_pad(-1), padding[1], seq_dim))

    return concat_func(begin_pad + [seq] + end_pad, seq_dim)


def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
    """Apply ``pad_sequence_single`` to every tensor/array leaf of a nested structure.

    ``None`` leaves are passed through.
    """
    def _pad(leaf):
        return pad_sequence_single(leaf, padding, batched, pad_same, pad_values)

    handlers = {
        torch.Tensor: _pad,
        np.ndarray: _pad,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(seq, handlers)


def assert_size_at_dim_single(x, size, dim, msg):
    """Assert that ``x.shape[dim] == size``, using ``msg`` as the assertion message."""
    actual = x.shape[dim]
    assert actual == size, msg


def assert_size_at_dim(x, size, dim, msg):
    """Assert ``leaf.shape[dim] == size`` for every tensor leaf of a nested structure."""
    def _check(leaf):
        return assert_size_at_dim_single(leaf, size, dim, msg)

    map_tensor(x, _check)


def get_shape(x):
    """Replace every tensor/array leaf of a nested structure with its ``.shape``.

    ``None`` leaves are passed through.
    """
    def _shape_of(leaf):
        return leaf.shape

    handlers = {
        torch.Tensor: _shape_of,
        np.ndarray: _shape_of,
        type(None): lambda leaf: leaf,
    }
    return recursive_dict_list_tuple_apply(x, handlers)


def list_of_flat_dict_to_dict_of_list(list_of_dict):
    """Transpose a list of flat dicts into an ``OrderedDict`` of lists.

    Keys appear in order of first occurrence. Each key's list holds that key's
    values in list order; dicts that lack the key contribute nothing, so lists
    may have different lengths.

    Args:
        list_of_dict: a ``list`` of dicts.

    Returns:
        ``collections.OrderedDict`` mapping each key to the list of its values.
    """
    assert isinstance(list_of_dict, list)
    dic = collections.OrderedDict()
    for entry in list_of_dict:
        for k, v in entry.items():
            dic.setdefault(k, []).append(v)
    return dic


def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
    """Flatten nested dicts/lists/tuples into a list of ``(key, leaf)`` pairs.

    Each key joins the path from the root with ``sep``: dict keys are used
    as-is (they must be ``str``), and list/tuple positions become their decimal
    index. A non-container ``d`` yields a single pair.

    Args:
        d: the nested structure (or a leaf).
        parent_key: key accumulated so far.
        sep: separator placed between path components.
        item_key: the component that identifies ``d`` inside its parent.

    Returns:
        List of ``(flattened_key, leaf)`` tuples in traversal order.
    """
    key = parent_key + sep + item_key if len(parent_key) > 0 else item_key

    if isinstance(d, (tuple, list)):
        children = [(str(i), v) for i, v in enumerate(d)]
    elif isinstance(d, dict):
        children = d.items()
    else:
        return [(key, d)]

    flat = []
    for child_key, child in children:
        assert isinstance(child_key, str)
        flat.extend(flatten_nested_dict_list(child, key, sep=sep, item_key=child_key))
    return flat


def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs):
    """Run ``op`` over inputs whose first two dims are merged, then split them again.

    The two leading sizes are read from the first leaf returned by
    ``flatten_nested_dict_list(inputs)``. Other leaves are presumably expected
    to share them (TODO confirm with callers). Dims 0 and 1 of every leaf are
    merged before calling ``op``. The inputs are passed as ``**inputs``,
    ``*inputs`` or a single argument, depending on the flags, and extra
    ``kwargs`` are forwarded. ``activation``, if given, is mapped over the
    tensor outputs. Finally dim 0 of every output is split back into
    ``(batch_size, seq_len)``.
    """
    first_leaf = flatten_nested_dict_list(inputs)[0][1]
    batch_size, seq_len = first_leaf.shape[:2]
    merged = join_dimensions(inputs, 0, 1)

    if inputs_as_kwargs:
        result = op(**merged, **kwargs)
    elif inputs_as_args:
        result = op(*merged, **kwargs)
    else:
        result = op(merged, **kwargs)

    if activation is not None:
        result = map_tensor(result, activation)
    return reshape_dimensions(
        result, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len))
