from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy
from multiprocessing.context import get_spawning_popen
from pathlib import Path

from tensordict import (
    is_tensor_collection,
    LazyStackedTensorDict,
    TensorDict,
    TensorDictBase,
)
from torch import multiprocessing as mp
from torch.utils._pytree import LeafSpec, tree_flatten, tree_map, tree_unflatten
from torchrl.data.replay_buffers.storages import _flip_list, LazyMemmapStorage
from torchrl.data.replay_buffers.utils import _is_int, INT_CLASSES


class SharedLazyMemmap(LazyMemmapStorage):
    """Storage keeping every item as its own memory-mapped TensorDict on disk.

    Each stored item is written to ``scratch_dir / "<index>"`` and read back
    from that same directory by :meth:`get`. Reads and writes are serialized
    through a ``torch.multiprocessing`` re-entrant lock held in ``self.lock``.

    Args:
        max_size: forwarded to ``LazyMemmapStorage.__init__``.
        scratch_dir: directory under which the per-item folders are created;
            also forwarded to ``LazyMemmapStorage.__init__``.
    """

    _storage: LazyStackedTensorDict

    # Upper bound on the number of threads used for batched disk reads/writes.
    _MAX_WORKERS = 32

    def __init__(self, max_size, scratch_dir):
        super().__init__(max_size=max_size, scratch_dir=scratch_dir)
        # Stored as a Path (after the parent constructor ran) so per-item
        # folders can be built with the ``/`` operator.
        self.scratch_dir = Path(scratch_dir)
        self.lock = mp.RLock()

    def _memmap_item(self, idx, item):
        """Write ``item`` to ``scratch_dir/<idx>`` and return the unlocked memmap copy."""
        return item.memmap(self.scratch_dir / f"{int(idx)}").unlock_()  # type: ignore

    def _load_item(self, idx):
        """Load the item previously written to ``scratch_dir/<idx>``."""
        return TensorDict.load_memmap(self.scratch_dir / f"{int(idx)}")  # type: ignore

    def _place(self, idx, item) -> None:
        """Overwrite slot ``idx`` of the lazy stack, or append when ``idx`` is past the end."""
        slots = self._storage.tensordicts
        if idx < len(slots):
            slots[idx] = item
        else:
            # Shortcutting lazy memmap append which checks the lock status
            slots.append(item)  # type: ignore

    def set(
        self,
        cursor,
        data,
    ):
        """Write ``data`` at position(s) ``cursor``.

        Args:
            cursor: a single integer-like index, or an iterable of indices
                paired element-wise with the items of ``data``.
            data: the item (integer cursor) or batch of items (iterable
                cursor). A ``list`` is first passed through ``_flip_list``.

        Raises:
            RuntimeError: if ``data`` is a list and ``_flip_list`` fails.
        """
        with self.lock:
            if isinstance(data, list):
                # flip list
                try:
                    data = _flip_list(data)
                except Exception as err:
                    raise RuntimeError(
                        "Stacking the elements of the list resulted in "
                        "an error. "
                        f"Storages of type {type(self)} expect all elements of the list "
                        f"to have the same tree structure. If the list is compact (each "
                        f"leaf is itself a batch with the appropriate number of elements) "
                        f"consider using a tuple instead, as lists are used within `extend` "
                        f"for per-item addition."
                    ) from err

            if not self.initialized:
                # Initialize from a single item: the data itself for an integer
                # cursor, otherwise the first element of the batch.
                if isinstance(cursor, INT_CLASSES):
                    self._init(data)
                elif is_tensor_collection(data):
                    self._init(data[0])
                else:
                    self._init(tree_map(lambda x: x[0], data))

            if _is_int(cursor):
                # Always persist to disk, including when appending: ``get``
                # reads exclusively from ``scratch_dir/<index>``.
                self._place(cursor, self._memmap_item(cursor, data))
            else:
                indices = list(cursor)  # type: ignore
                with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
                    # Disk writes run concurrently; results come back in
                    # submission order so slots are filled in cursor order.
                    written = executor.map(self._memmap_item, indices, data)
                    for idx, item in zip(indices, written):
                        self._place(idx, item)
            self._get_new_len(data, cursor)

    def _init(self, data) -> None:
        """Create the backing lazy stack from a single example item.

        Resolves ``self.device`` from ``data`` when it is ``"auto"``, writes a
        clone of ``data`` to ``scratch_dir/0`` and wraps it in a
        ``LazyStackedTensorDict`` stacked along dim 0.

        Raises:
            RuntimeError: if the resolved device is not CPU, or if ``data`` is
                not a tensor collection.
        """
        if self.device == "auto":
            self.device = data.device
        if self.device.type != "cpu":  # type: ignore
            raise RuntimeError(f"Unsupported device: {self.device}")
        if not is_tensor_collection(data):
            raise RuntimeError("Unsupported data type.")

        first = data.clone().to(self.device).memmap(self.scratch_dir / "0").unlock_()  # type: ignore
        self._storage = LazyStackedTensorDict(first, stack_dim=0)
        self.initialized = True

    def get(self, index):
        """Load the item(s) at ``index`` from disk.

        Args:
            index: a single integer-like index, or an iterable of indices.

        Returns:
            The TensorDict loaded from ``scratch_dir/<index>`` for an integer
            index; otherwise a ``LazyStackedTensorDict`` (stack dim 0) of the
            items loaded for each index, in iteration order.
        """
        with self.lock:
            if _is_int(index):
                return self._load_item(index)
            with ThreadPoolExecutor(max_workers=self._MAX_WORKERS) as executor:
                loaded = list(executor.map(self._load_item, index))  # type: ignore
            return LazyStackedTensorDict(*loaded, stack_dim=0)

    def __getstate__(self):
        """Return a shallow copy of the instance dict for pickling.

        When ``get_spawning_popen()`` returns ``None``, the ``_len_value``
        entry is removed and the current ``self._len`` is stored under
        ``"len__context"`` instead.
        NOTE(review): presumably the parent's ``__setstate__`` rebuilds
        ``_len_value`` from ``len__context`` -- confirm against torchrl.
        """
        state = copy(self.__dict__)
        if get_spawning_popen() is None:
            current_len = self._len
            del state["_len_value"]
            state["len__context"] = current_len
        return state
