from collections import defaultdict
from collections.abc import Mapping
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch.utils.data
import torch_geometric
import torch_sparse
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Batch, Dataset
from torch_geometric.data.data import BaseData
from torch_geometric.data.datapipes import DatasetAdapter
from torch_geometric.data.on_disk_dataset import OnDiskDataset
from torch_geometric.data.storage import BaseStorage
from torch_geometric.typing import SparseTensor, TensorFrame, torch_frame
from torch_geometric.utils import is_sparse, is_torch_sparse_tensor
from torch_geometric.utils.sparse import cat

# Default fill value written into padded positions of floating-point tensors
# (chosen via ``torch.is_floating_point`` in the padding helpers below).
# NOTE: ``_dense_padded_collate`` derives masks by comparing the batched tensor
# against the fill value, so genuine entries equal to it are flagged as padding.
FLOAT_PADDING_VALUE = 1e-8
# Default fill value written into padded positions of all other dtypes.
NON_FLOAT_PADDING_VALUE = -1


def _pad_or_truncate_rows(val, size, padding_value):
    """Return ``val`` resized to exactly ``size`` entries along dim 0."""
    num_rows = val.size(0)
    if num_rows >= size:
        return val[:size]
    padding = torch.full(
        (size - num_rows,) + tuple(val.shape[1:]),
        padding_value,
        dtype=val.dtype,
        device=val.device,
    )
    return torch.cat([val, padding], dim=0)


def _pad_rows(values, padding_value, max_size):
    """Pad ``values`` along dim 0 and give each result a leading dim of 1."""
    if max_size is None:
        # No fixed size requested: pad everything to the longest item.
        stacked = torch.nn.utils.rnn.pad_sequence(
            values, batch_first=True, padding_value=padding_value
        )
        return [row.unsqueeze(0) for row in stacked]
    return [
        _pad_or_truncate_rows(val, max_size, padding_value).unsqueeze(0)
        for val in values
    ]


def _dense_pad_tensor(
    key,
    values,
    float_padding_value: float = FLOAT_PADDING_VALUE,
    non_float_padding_value: int = NON_FLOAT_PADDING_VALUE,
    max_size: Optional[int] = None,
) -> Tuple[Any, Any]:
    """Pad a list of dense tensors to a common length.

    Each returned tensor has a new leading dimension of size 1, so the caller
    can concatenate the list along dim 0 to obtain a batch.

    Args:
        key: Attribute name the tensors belong to. When it is a string that
            contains ``"edge_index"`` and the first tensor is 2-D with a first
            dimension of 2, padding is applied along dim 1 instead of dim 0
            (only for dtypes other than ``torch.float``).
        values: Non-empty list of tensors; the first one decides the dtype,
            the padding value and which padding strategy is used.
        float_padding_value: Fill value used when the first tensor is
            floating point.
        non_float_padding_value: Fill value used otherwise.
        max_size: If given, every tensor is padded or truncated to exactly
            this length; otherwise tensors are padded to the longest one.

    Returns:
        ``(values, mask)``. ``mask`` is only computed for ``uint8``/``bool``
        inputs with at least one dimension (``True`` where an entry is not
        padding) and is ``None`` in every other case.
    """
    elem = values[0]
    dtype = elem.dtype
    padding_value = (
        float_padding_value
        if torch.is_floating_point(elem)
        else non_float_padding_value
    )

    # Scalars have nothing to pad; they only get the batch dimension.
    if elem.dim() == 0:
        return [value.unsqueeze(0) for value in values], None

    if dtype == torch.float:
        return _pad_rows(values, padding_value, max_size), None

    # uint8/bool are padded as int64 (so the fill value is applied to a wider
    # dtype) and cast back once the mask has been extracted.
    needs_widening = dtype in (torch.uint8, torch.bool)
    if needs_widening:
        values = [value.long() for value in values]

    # ``isinstance`` guard: keys of nested mappings are forwarded here and are
    # not necessarily strings; ``"edge_index" in <int>`` would raise TypeError.
    is_edge_index = (
        isinstance(key, str)
        and "edge_index" in key
        and elem.dim() == 2
        and elem.shape[0] == 2
    )
    if is_edge_index:
        # Pad along the second dimension: transpose, pad along dim 0, then
        # swap the last two dimensions back.
        transposed = [val.permute(1, 0) for val in values]
        values = [
            padded.permute(0, 2, 1)
            for padded in _pad_rows(transposed, padding_value, max_size)
        ]
    else:
        values = _pad_rows(values, padding_value, max_size)

    mask = None
    if needs_widening:
        mask = torch.cat([(value != padding_value) for value in values], dim=0)
        # Reset padded slots to 0 before casting back to the original dtype.
        for value in values:
            value[value == padding_value] = 0
        values = [value.to(dtype) for value in values]

    return values, mask


def _dense_padded_collate(
    key: str,
    values: List[Any],
    data_list: List[BaseData],
    stores: List[BaseStorage],
    float_padding_value: float = FLOAT_PADDING_VALUE,
    non_float_padding_value: int = NON_FLOAT_PADDING_VALUE,
    max_size: Optional[int] = None,
) -> Tuple[Any, Any]:
    """Collate one attribute across a batch into a dense, padded value.

    Dispatches on the type of ``values[0]``:

    * dense ``torch.Tensor``: padded via ``_dense_pad_tensor`` and concatenated
      along dim 0; returns ``(tensor, mask)``.
    * ``int`` / ``float``: returns ``(torch.tensor(values), None)``.
    * ``Mapping``: recurses per key; returns ``(value_dict, mask_dict)``.
    * non-empty, non-string sequence whose first item is a tensor or
      ``SparseTensor``: recurses per position; returns
      ``(value_list, mask_list)``.
    * anything else: returns ``(values, None)`` unchanged.

    Args:
        key: Attribute name, forwarded to ``_dense_pad_tensor``.
        values: Non-empty list holding the attribute of every graph.
        data_list: Only forwarded to recursive calls.
        stores: Only forwarded to recursive calls.
        float_padding_value: Fill value for floating-point tensors.
        non_float_padding_value: Fill value for all other tensors.
        max_size: Optional fixed length to pad/truncate tensors to.

    Raises:
        NotImplementedError: For nested tensors, ``TensorFrame`` objects and
            sparse tensors.

    Note:
        When ``_dense_pad_tensor`` does not supply a mask, the mask is
        ``value != padding_value``; real entries equal to the fill value are
        therefore reported as padding.
    """
    elem = values[0]

    if isinstance(elem, torch.Tensor) and not is_sparse(elem):
        # Reject nested tensors before any padding is attempted, so the
        # documented error is what the caller actually sees.
        if getattr(elem, "is_nested", False):
            raise NotImplementedError(
                "Dense padding collation for nested tensors is not supported (tested) yet."
            )

        padding_value = (
            float_padding_value
            if torch.is_floating_point(elem)
            else non_float_padding_value
        )
        values, mask = _dense_pad_tensor(
            key,
            values,
            float_padding_value=float_padding_value,
            non_float_padding_value=non_float_padding_value,
            max_size=max_size,
        )

        out = None
        if torch.utils.data.get_worker_info() is not None:
            # Inside a DataLoader worker: concatenate directly into a
            # shared-memory buffer.
            numel = sum(value.numel() for value in values)
            if torch_geometric.typing.WITH_PT20:
                storage = elem.untyped_storage()._new_shared(
                    numel * elem.element_size(), device=elem.device
                )
            elif torch_geometric.typing.WITH_PT112:
                storage = elem.storage()._new_shared(numel, device=elem.device)
            else:
                storage = elem.storage()._new_shared(numel)
            # Every padded item contributes one row, so the leading dimension
            # is the number of items actually being concatenated.
            largest = max(values, key=lambda value: value.numel())
            shape = [len(values)] + list(largest.shape[1:])
            out = elem.new(storage).resize_(shape)

        value = torch.cat(values, dim=0, out=out)
        if mask is None:
            mask = value != padding_value

        return value, mask

    elif isinstance(elem, TensorFrame):
        raise NotImplementedError(
            "Dense padding collation for TensorFrames is not supported (tested) yet."
        )

    elif is_sparse(elem):
        raise NotImplementedError(
            "Dense padding collation for SparseTensors is not supported (tested) yet."
        )

    elif isinstance(elem, (int, float)):
        return torch.tensor(values), None

    elif isinstance(elem, Mapping):
        value_dict, mask_dict = {}, {}
        for sub_key in elem.keys():
            value_dict[sub_key], mask_dict[sub_key] = _dense_padded_collate(
                sub_key,
                [v[sub_key] for v in values],
                data_list,
                stores,
                float_padding_value=float_padding_value,
                non_float_padding_value=non_float_padding_value,
                max_size=max_size,
            )
        return value_dict, mask_dict

    elif (
        isinstance(elem, Sequence)
        and not isinstance(elem, str)
        and len(elem) > 0
        and isinstance(elem[0], (torch.Tensor, SparseTensor))
    ):
        value_list, mask_list = [], []
        for i in range(len(elem)):
            value, mask = _dense_padded_collate(
                key,
                [v[i] for v in values],
                data_list,
                stores,
                float_padding_value=float_padding_value,
                non_float_padding_value=non_float_padding_value,
                max_size=max_size,
            )
            value_list.append(value)
            mask_list.append(mask)
        return value_list, mask_list

    else:
        # Unsupported types are passed through as a plain Python list.
        return values, None


def dense_padded_collate(
    cls,
    data_list: List[BaseData],
    follow_batch: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    max_size: Optional[int] = None,
) -> Tuple[BaseData, Mapping]:
    """Collate ``data_list`` into one ``cls`` object with padded attributes.

    Args:
        cls: Class of the object to build. If it differs from the class of
            the first data object, it is instantiated with ``_base_cls`` set
            to that class.
        data_list: Data objects to batch; any non-list/tuple iterable is
            materialised into a list first.
        follow_batch: Normalised to a set but not otherwise read here.
        exclude_keys: Attribute names to leave out of the result.
        max_size: Forwarded to ``_dense_padded_collate``.

    Returns:
        ``(out, mask_dict)``. Masks of stores whose ``_key`` is ``None`` are
        stored as ``mask_dict[attr]``; all others as ``mask_dict[key][attr]``.
    """
    if not isinstance(data_list, (list, tuple)):
        data_list = list(data_list)

    template = data_list[0]
    template_cls = template.__class__
    out = cls() if cls == template_cls else cls(_base_cls=template_cls)

    # Give the output the same store layout as the first data object.
    out.stores_as(template)

    follow_batch = set(follow_batch or [])
    exclude_keys = set(exclude_keys or [])

    # Group the stores of all data objects by their store key.
    stores_by_key = defaultdict(list)
    for data in data_list:
        for store in data.stores:
            stores_by_key[store._key].append(store)

    masks = defaultdict(dict)
    for out_store in out.stores:
        store_key = out_store._key
        group = stores_by_key[store_key]
        for attr in group[0].keys():
            if attr in exclude_keys:
                continue

            per_graph = [store[attr] for store in group]

            if attr == "num_nodes":
                # Keep the per-graph counts and expose their total.
                out_store._num_nodes = per_graph
                out_store.num_nodes = sum(per_graph)
            elif attr != "ptr":
                collated, mask = _dense_padded_collate(
                    attr, per_graph, data_list, group, max_size=max_size
                )
                out_store[attr] = collated
                if mask is not None:
                    destination = masks if store_key is None else masks[store_key]
                    destination[attr] = mask

    return out, masks


def dense_padded_from_data_list(
    data_list: List[BaseData],
    follow_batch: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    max_size: Optional[int] = None,
):
    """Build a padded ``Batch`` from ``data_list``.

    Delegates to ``dense_padded_collate`` with ``Batch`` as the target class,
    then records the number of graphs in ``_num_graphs`` and attaches the
    returned masks as ``mask_dict``.
    """
    collated = dense_padded_collate(
        Batch,
        data_list,
        follow_batch=follow_batch,
        exclude_keys=exclude_keys,
        max_size=max_size,
    )
    padded_batch, masks = collated

    padded_batch._num_graphs = len(data_list)
    padded_batch.mask_dict = masks
    return padded_batch


class DensePaddingCollater:
    """Callable that turns a list of samples into one collated batch.

    ``BaseData`` samples are collated with ``dense_padded_from_data_list``;
    tensors, ``TensorFrame`` objects, numbers, strings, mappings, named tuples
    and other sequences each have their own rule (see ``__call__``).
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        max_size: Optional[int] = None,
    ):
        self.dataset = dataset
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.max_size = max_size

    def __call__(self, batch: List[Any]) -> Any:
        """Collate ``batch`` according to the type of its first element.

        Raises:
            TypeError: If the first element has an unsupported type.
        """
        elem = batch[0]

        if isinstance(elem, BaseData):
            return dense_padded_from_data_list(
                batch,
                follow_batch=self.follow_batch,
                exclude_keys=self.exclude_keys,
                max_size=self.max_size,
            )
        if isinstance(elem, torch.Tensor):
            return default_collate(batch)
        if isinstance(elem, TensorFrame):
            return torch_frame.cat(batch, along="row")
        if isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        if isinstance(elem, int):
            return torch.tensor(batch)
        if isinstance(elem, str):
            return batch
        if isinstance(elem, Mapping):
            # Collate each field separately across all samples.
            return {
                field: self([sample[field] for sample in batch])
                for field in elem
            }
        if isinstance(elem, tuple) and hasattr(elem, "_fields"):
            # Named tuple: collate position-wise and rebuild the same type.
            columns = zip(*batch)
            return type(elem)(*(self(column) for column in columns))
        if isinstance(elem, Sequence):
            # Strings were already returned above.
            return [self(column) for column in zip(*batch)]

        raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

    def collate_fn(self, batch: List[Any]) -> Any:
        """Collate ``batch``; for an ``OnDiskDataset`` first pass it through ``multi_get``."""
        if isinstance(self.dataset, OnDiskDataset):
            batch = self.dataset.multi_get(batch)
        return self(batch)


class DensePaddingDataLoader(torch.utils.data.DataLoader):
    """``DataLoader`` whose batches are built by ``DensePaddingCollater``.

    Args:
        dataset: Dataset to load from. For an ``OnDiskDataset`` the underlying
            loader iterates over ``range(len(dataset))`` instead, while the
            collater keeps the original dataset.
        batch_size: Number of samples per batch.
        shuffle: Forwarded to the base ``DataLoader``.
        follow_batch: Stored on the loader and passed to the collater.
        exclude_keys: Stored on the loader and passed to the collater.
        max_size: Stored on the loader and passed to the collater.
        **kwargs: Extra ``DataLoader`` arguments; any ``collate_fn`` entry is
            discarded.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        max_size: Optional[int] = None,
        **kwargs,
    ):
        # The collate function is always the one installed below.
        kwargs.pop("collate_fn", None)

        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys
        self.max_size = max_size

        # The collater receives the original dataset, before any index swap.
        self.collator = DensePaddingCollater(
            dataset,
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
            max_size=max_size,
        )

        source = dataset
        if isinstance(source, OnDiskDataset):
            source = range(len(source))

        super().__init__(
            source,
            batch_size,
            shuffle,
            collate_fn=self.collator.collate_fn,
            **kwargs,
        )
