from collections.abc import Mapping, MutableMapping, Sequence

import numpy as np
import torch


def index_collated(collated, indices):
    """Select the same ``indices`` from every indexable leaf of a collated batch.

    The structure is walked with ``torch.utils._pytree.tree_map``, so standard
    containers (list/tuple/dict) are preserved and only their leaves are
    passed to the indexing helper.

    Args:
        collated: A tensor/array or a (nested) container of them, e.g. the
            output of ``collate_concat``.
        indices: Any index accepted by ``torch.Tensor.__getitem__`` /
            ``np.ndarray.__getitem__`` (int, slice, list, boolean mask, ...).

    Returns:
        The same structure with each tensor/array leaf replaced by
        ``leaf[indices]``. Strings, bytes and other scalar leaves are returned
        unchanged.
    """

    def _index(elem):
        if isinstance(elem, (torch.Tensor, np.ndarray)):
            return elem[indices]
        # str/bytes are Sequences, but np.array(<str>) is 0-d and cannot be
        # indexed, so treat them as scalar leaves (as apply_to_collated does).
        elif isinstance(elem, Sequence) and not isinstance(elem, (str, bytes)):
            # Only reached for sequence types the pytree does not flatten itself.
            return np.array(elem)[indices]
        elif isinstance(elem, Mapping):
            # Only reached for mapping types the pytree does not flatten itself.
            return {key: _index(value) for key, value in elem.items()}
        else:
            return elem

    return torch.utils._pytree.tree_map(_index, collated)


def apply_to_collated(collated, fn):
    """Apply ``fn`` to every tensor or array leaf of a collated batch.

    Args:
        collated: A tensor/array or a (nested) container of tensors, arrays
            and other values.
        fn: Callable invoked on each ``torch.Tensor`` / ``np.ndarray`` leaf.

    Returns:
        A structure with the same layout where each tensor/array leaf is
        replaced by ``fn(leaf)``. Strings, bytes and all other leaves are
        returned unchanged.
    """
    array_types = (torch.Tensor, np.ndarray)
    text_types = (str, bytes)

    def _visit(node):
        if isinstance(node, array_types):
            return fn(node)
        is_container = isinstance(node, Sequence) and not isinstance(node, text_types)
        if is_container:
            # Rebuild with the node's own type, transforming children recursively.
            return type(node)(_visit(child) for child in node)
        if isinstance(node, Mapping):
            return {k: _visit(v) for k, v in node.items()}
        return node

    return torch.utils._pytree.tree_map(_visit, collated)


def collate_concat(batch):
    """
    Custom collate function that concatenates tensors/arrays along the first
    dimension while preserving nested dictionary and list structures.

    Unlike stacking, samples are joined along their existing leading
    dimension. Per-sample tensors/arrays must therefore agree on every
    dimension except the first.

    Args:
        batch (list): A non-empty batch of data samples (tensors, arrays,
            mappings, or sequences thereof). The structure of ``batch[0]``
            drives the recursion.

    Returns:
        Collated batch with the same structure as the first sample:
        - tensors are joined with ``torch.concat``;
        - arrays are joined with ``np.concatenate``;
        - mappings become plain ``dict``s;
        - sequences keep their type (including namedtuples);
        - any other leaf type is returned as the batch of per-sample values.

    Raises:
        IndexError: If ``batch`` is empty.
        RuntimeError: If sequence samples at the same position differ in length.
    """
    elem = batch[0]

    if isinstance(elem, torch.Tensor):
        # Concatenate tensors along the first dimension
        return torch.concat(batch, dim=0)

    elif isinstance(elem, np.ndarray):
        # Concatenate arrays along the first dimension
        return np.concatenate(batch, axis=0)

    elif isinstance(elem, Mapping):
        # Recursively apply to dictionary values; keys come from the first sample.
        return {key: collate_concat([d[key] for d in batch]) for key in elem}

    elif isinstance(elem, Sequence) and not isinstance(elem, (str, bytes)):
        # zip() would silently truncate to the shortest sample and drop data.
        if any(len(sample) != len(elem) for sample in batch):
            raise RuntimeError("each element in list of batch should be of equal size")
        # Recursively apply to list/tuple elements, position by position.
        collated = [collate_concat(samples) for samples in zip(*batch)]
        if isinstance(elem, tuple) and hasattr(elem, "_fields"):
            # namedtuple constructors take fields positionally, not one iterable.
            return type(elem)(*collated)
        return type(elem)(collated)

    else:
        # For other types, return as is (default behavior)
        return batch


def flatten_dict(dictionary, parent_key="", separator="_"):
    """Flatten nested mutable mappings into a single-level ``dict``.

    Nested keys are joined with ``separator``, so ``{"a": {"b": 1}}`` becomes
    ``{"a_b": 1}``. Only ``MutableMapping`` values are descended into; every
    other value is kept as-is.

    Args:
        dictionary: Mapping to flatten.
        parent_key: Prefix joined in front of every key of ``dictionary``. An
            empty (falsy) prefix means the keys are used unchanged.
        separator: Text inserted between a parent key and a child key.

    Returns:
        A new flat ``dict``. Empty nested mappings contribute no entries.

    Example:
        >>> flatten_dict({"a": {"b": 1, "c": {"d": 2}}, "e": 3})
        {'a_b': 1, 'a_c_d': 2, 'e': 3}
    """
    flat = {}
    for key, value in dictionary.items():
        # Format instead of "+" so non-string nested keys (e.g. ints) can be joined.
        new_key = f"{parent_key}{separator}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            flat.update(flatten_dict(value, new_key, separator=separator))
        else:
            flat[new_key] = value
    return flat


def to_numpy(x):
    """Convert every ``torch.Tensor`` leaf of a pytree to a NumPy array.

    Tensors are detached from the autograd graph and moved to the CPU first,
    because ``Tensor.numpy()`` refuses tensors that require grad. Non-tensor
    leaves are returned unchanged.

    Args:
        x: A tensor or a (nested) list/tuple/dict structure containing tensors.

    Returns:
        The same structure with tensors replaced by ``np.ndarray``s. For
        tensors already on the CPU the resulting array shares memory with the
        tensor.
    """

    def _move(leaf):
        if isinstance(leaf, torch.Tensor):
            return leaf.detach().cpu().numpy()
        return leaf

    return torch.utils._pytree.tree_map(_move, x)


def to_torch(x):
    """Convert every ``np.ndarray`` leaf of a pytree to a ``torch.Tensor``.

    Arrays are wrapped with ``torch.from_numpy``, so the tensor shares memory
    with the array. ``from_numpy`` rejects arrays with negative strides (as
    produced by ``arr[::-1]`` or ``np.flip``). Such arrays are copied into a
    fresh C-ordered array first and do not share memory with the input.

    Args:
        x: An array or a (nested) list/tuple/dict structure containing arrays.

    Returns:
        The same structure with arrays replaced by tensors; non-array leaves
        are returned unchanged.
    """

    def _move(leaf):
        if isinstance(leaf, np.ndarray):
            if any(stride < 0 for stride in leaf.strides):
                # torch tensors cannot have negative strides; copy() is C-contiguous.
                leaf = leaf.copy()
            return torch.from_numpy(leaf)
        return leaf

    return torch.utils._pytree.tree_map(_move, x)


def to_cuda(x):
    """Move every ``torch.Tensor`` leaf of a pytree to the current CUDA device.

    Args:
        x: A tensor or a (nested) list/tuple/dict structure containing tensors.

    Returns:
        The same structure with each tensor replaced by ``tensor.cuda()``;
        non-tensor leaves are returned unchanged.
    """

    def _to_device(leaf):
        return leaf.cuda() if isinstance(leaf, torch.Tensor) else leaf

    return torch.utils._pytree.tree_map(_to_device, x)
