from math import ceil
from typing import Any, Literal, overload
from warnings import warn

import torch
from skimage.transform import resize
from torch import Tensor
from torch._prims_common import DeviceLikeType
from torch.nn.functional import one_hot
from torch.utils.data import Dataset as TorchDataset
from torchvision.transforms import transforms

# Width of the per-step info vector stored by MDPDatasetImages.
MW_INFO_DIM = 2
# Channel-first (C, H, W) shape of one processed image, as stored per step.
MW_IMAGE_SHAPE = (3, 25, 25)
# Channel-last (H, W, C) target shape handed to skimage's ``resize``: the same
# image geometry as MW_IMAGE_SHAPE with the channel axis moved to the end.
MW_RESIZE_IMG_SHAPE = (*MW_IMAGE_SHAPE[1:], MW_IMAGE_SHAPE[0])


def process_miniworld_images(images: Any) -> Tensor:
    """Resize, convert and normalize a batch of images into one float32 tensor.

    Every image is first resized (with anti-aliasing) to ``MW_RESIZE_IMG_SHAPE``;
    the resized images are then converted to channel-first tensors by
    ``ToTensor`` and normalized with the ImageNet per-channel mean/std. The
    results are stacked along a new leading axis.

    Args:
        images: iterable of images accepted by ``skimage.transform.resize``
            (presumably channel-last arrays — TODO confirm with the env wrapper).

    Returns:
        A float32 tensor of shape ``(N, C, H, W)``.

    Raises:
        RuntimeError: from ``torch.stack`` when ``images`` is empty.
    """
    normalize = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    # Two passes: resize everything, then convert + normalize each result.
    resized_frames = [resize(frame, MW_RESIZE_IMG_SHAPE, anti_aliasing=True) for frame in images]
    frame_tensors: list[Tensor] = list(map(normalize, resized_frames))  # type: ignore
    return torch.stack(frame_tensors).float()


class MDPDatasetTorch(TorchDataset):
    n_envs: int
    n_steps: int
    context_len: int
    n_states: int
    state_dim: int
    action_dim: int

    _states: Tensor
    _actions: Tensor
    _rewards: Tensor
    _rewards_original: Tensor

    _optimal_actions: Tensor | None
    _query_states: Tensor
    _optimal_query_actions: Tensor

    device: DeviceLikeType | None
    shuffle: bool

    def __init__(
        self,
        dataset: "MDPDataset",
        optimal_actions: Tensor | None,
        query_states: Tensor,
        optimal_query_actions: Tensor,
        shuffle: bool = True,
        *,
        convert_states_onehot: int | None = None,
        context_len: int | None = None,
        device: DeviceLikeType | None = None,
    ):
        super().__init__()

        self.device = device
        self.shuffle = shuffle

        self.n_envs = dataset.n_envs
        self.n_steps = dataset.step_ptr
        self.context_len = context_len if context_len is not None else self.n_steps
        self.n_states = dataset.n_states
        self.state_dim = dataset.state_dim if convert_states_onehot is None else convert_states_onehot
        self.action_dim = dataset.action_dim
        self._states = dataset._states.detach().to(device)
        self._query_states = query_states.detach().to(device)
        self._optimal_query_actions = optimal_query_actions.detach().to(device)
        if convert_states_onehot is not None:
            self._states = one_hot(self._states.squeeze(2).long(), self.state_dim)
            self._query_states = one_hot(self._query_states.squeeze(2).long(), self.state_dim)
        self._actions = dataset._actions.detach().to(device)
        self._rewards = dataset._rewards.detach().to(device)
        self._rewards_original = dataset._rewards_original.detach().to(device)

        if optimal_actions is not None:
            self._optimal_actions = optimal_actions.detach().to(device)
        else:
            self._optimal_actions = None

    def __len__(self) -> int:
        multiplier = ceil(self.n_steps / self.context_len)  # FIXME: ceil? are you sure?
        return self.n_envs * multiplier

    def __getitem__(self, index: int, *, return_perm: bool = False) -> tuple[Tensor, ...]:
        if self.shuffle:
            perm = torch.randperm(self.context_len, device=self.device)
        else:
            perm = torch.arange(self.context_len, device=self.device)

        context_selector = index // self.n_envs
        context_start = context_selector * self.context_len
        context_end = context_start + self.context_len
        index_new = index - context_selector * self.n_envs

        context_states = self._states[index_new, context_start:context_end, :]
        context_actions = self._actions[index_new, context_start:context_end, :]
        context_next_states = self._states[index_new, context_start:context_end, :].roll(-1, dims=0)
        context_next_states[-1, :] = self._query_states[index_new]
        context_rewards = self._rewards[index_new, context_start:context_end, None]

        optimal_actions = self._optimal_query_actions[index_new, :].unsqueeze(0).repeat(self.context_len, 1)

        context = torch.cat((context_states, context_actions, context_next_states, context_rewards), dim=1)[perm]
        query_line = torch.zeros((1, context.shape[-1]), device=self.device)
        query_line[0, : self.state_dim] = self._query_states[None, index_new]

        x = torch.cat((query_line, context), dim=0)

        if return_perm:
            return x, optimal_actions, perm

        return x, optimal_actions

    @property
    def states(self) -> Tensor:
        return self._states

    @property
    def actions(self) -> Tensor:
        return self._actions

    @property
    def rewards(self) -> Tensor:
        return self._rewards

    @property
    def rewards_original(self) -> Tensor:
        return self._rewards_original

    @property
    def query_states(self) -> Tensor:
        return self._query_states


class MDPDataset:
    """A dataset for MDPs which keeps the latest n_steps of transitions in memory"""

    n_envs: int
    n_steps: int
    state_dim: int
    action_dim: int
    device: DeviceLikeType | None

    _states: Tensor
    _actions: Tensor
    _rewards: Tensor
    _rewards_original: Tensor
    step_ptr: int

    _query_states: Tensor

    def __init__(self, n_envs: int, n_steps: int, n_states: int, state_dim: int, action_dim: int, device=None):
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.n_states = n_states
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device

        self.clear()

    def clear(self) -> None:
        self._states = torch.zeros((self.n_envs, self.n_steps, self.state_dim), device=self.device)
        self._actions = torch.zeros((self.n_envs, self.n_steps, self.action_dim), device=self.device)
        self._rewards = torch.zeros((self.n_envs, self.n_steps), device=self.device)
        self._rewards_original = torch.zeros((self.n_envs, self.n_steps), device=self.device)
        self.step_ptr = 0

    @property
    def states(self) -> Tensor:
        return self._states[:, : self.step_ptr, :]

    @states.setter
    def states(self, values: Tensor):
        self._states[:, : self.step_ptr, :] = values

    @property
    def actions(self) -> Tensor:
        return self._actions[:, : self.step_ptr, :]

    @actions.setter
    def actions(self, values: Tensor):
        self._actions[:, : self.step_ptr, :] = values

    @property
    def rewards(self) -> Tensor:
        return self._rewards[:, : self.step_ptr]

    @rewards.setter
    def rewards(self, value: Tensor):
        self._rewards[:, : self.step_ptr] = value

    @property
    def rewards_original(self) -> Tensor:
        return self._rewards_original[:, : self.step_ptr]

    @rewards_original.setter
    def rewards_original(self, value: Tensor):
        self._rewards_original[:, : self.step_ptr] = value

    def append(self, states: Tensor, actions: Tensor, rewards: Tensor, states_next: Tensor, rewards_original: Tensor | None = None, extras: dict[str, Any] = {}):
        if self.step_ptr >= self.n_steps:
            warn(f"Dataset full, rolling to keep only latest {self.n_steps} entries.")
            self._states = self._states.roll(-1, dims=1)
            self._actions = self._actions.roll(-1, dims=1)
            self._rewards = self._rewards.roll(-1, dims=1)
            self._rewards_original = self._rewards_original.roll(-1, dims=1)
            self.step_ptr -= 1

        if states.shape[1] == 1 and self.state_dim != states.shape[1]:
            states = one_hot(states.squeeze(1).long(), self.state_dim)

        if rewards.ndim == 2:
            rewards = rewards.squeeze(dim=1)

        if states.ndim == 1:
            states = states[:, None]

        self._states[:, self.step_ptr, :] = states
        self._actions[:, self.step_ptr, :] = actions
        self._rewards[:, self.step_ptr] = rewards
        if rewards_original is not None:
            if rewards_original.ndim == 2:
                rewards_original = rewards_original.squeeze(dim=1)
            self._rewards_original[:, self.step_ptr] = rewards_original
        self.step_ptr += 1

        self._query_states = states_next

    def get_context_for_transformer(self) -> Tensor:
        padded_context = torch.concat((self._states, self._actions, self._states.roll(-1, dims=1), self._rewards[..., None]), dim=-1)
        context = padded_context[:, : self.step_ptr, :]
        if self.step_ptr != 0:
            context[:, -1, self.state_dim + self.action_dim : 2 * self.state_dim + self.action_dim] = self._query_states
        return context

    def finalize(
        self, optimal_actions: Tensor | None, query_states: Tensor, optimal_query_actions: Tensor, *, convert_states_onehot: int | None = None, context_len: int | None = None
    ) -> MDPDatasetTorch:
        return MDPDatasetTorch(self, optimal_actions, query_states, optimal_query_actions, convert_states_onehot=convert_states_onehot, context_len=context_len, device=self.device)


class MDPDatasetImagesTorch(MDPDatasetTorch):
    """``MDPDatasetTorch`` that additionally yields the image sequence of each window.

    Items are ``((image_seq, x_nonextstate), optimal_actions)``.
    ``image_seq`` starts with the query image, followed by the window's images
    in the same row order as the transitions. ``x_nonextstate`` is the parent
    item with the next-state columns removed.
    """

    _query_images: Tensor
    _images: Tensor
    _infos: Tensor

    def __init__(
        self,
        dataset: "MDPDatasetImages",
        optimal_actions: Tensor | None,
        query_states: Tensor,
        query_images: Tensor,
        optimal_query_actions: Tensor,
        shuffle: bool = True,
        *,
        context_len: int | None = None,
        device: DeviceLikeType | None = None,
    ):
        """Snapshot ``dataset`` like the parent, plus its image and info buffers.

        Args:
            query_images: one already-processed image per environment, prepended
                to every image sequence (indexed by environment).
            Other arguments are forwarded to ``MDPDatasetTorch``.
        """
        super().__init__(dataset, optimal_actions, query_states, optimal_query_actions, shuffle, context_len=context_len, device=device)
        self._query_images = query_images.detach().to(device)
        self._images = dataset._images.detach().to(device)
        self._infos = dataset._infos.detach().to(device)

    @property
    def infos(self) -> Tensor:
        return self._infos

    @overload
    def __getitem__(self, index: int, *, return_perm: Literal[True]) -> tuple[tuple[Tensor, Tensor], Tensor, Tensor]: ...
    @overload
    def __getitem__(self, index: int, *, return_perm: Literal[False] = False) -> tuple[tuple[Tensor, Tensor], Tensor]: ...
    @overload
    def __getitem__(self, index: int, *, return_perm: bool) -> tuple[tuple[Tensor, Tensor], Tensor] | tuple[tuple[Tensor, Tensor], Tensor, Tensor]: ...
    def __getitem__(self, index: int, *, return_perm: bool = False) -> tuple[tuple[Tensor, Tensor], Tensor] | tuple[tuple[Tensor, Tensor], Tensor, Tensor]:
        """Build item ``index``; see the class docstring for the layout."""
        if index < 0:
            # Resolve negative indices up front so the image window computed
            # below is the same window the parent selects.
            index += len(self)

        x_full, optimal_actions, perm = super().__getitem__(index, return_perm=True)

        # Keep [state | action] and the trailing [reward] column; drop next_state.
        state_action_end = self.state_dim + self.action_dim
        next_state_end = state_action_end + self.state_dim
        x_nonextstate = torch.concat((x_full[..., :state_action_end], x_full[..., next_state_end:]), dim=-1)

        window, env = divmod(index, self.n_envs)
        start = window * self.context_len

        # Same permutation as the transition rows so images stay aligned with them.
        context_images = self._images[env, start : start + self.context_len][perm]

        image_seq = torch.cat([self._query_images[env][None, ...], context_images], dim=0)

        if return_perm:
            return (image_seq, x_nonextstate), optimal_actions, perm

        return (image_seq, x_nonextstate), optimal_actions


class MDPDatasetImages(MDPDataset):
    """``MDPDataset`` that also stores one processed image and one info vector per step."""

    _images: Tensor
    _infos: Tensor
    _query_images: Any  # raw (unprocessed) ``images_next`` from the latest append

    def clear(self) -> None:
        """Zero all buffers, including the image and info buffers."""
        super().clear()
        self._images = torch.zeros((self.n_envs, self.n_steps, *MW_IMAGE_SHAPE), device=self.device)
        # Allocated on the dataset's device like every other buffer.
        self._infos = torch.zeros((self.n_envs, self.n_steps, MW_INFO_DIM), device=self.device)

    def append(
        self,
        states: Tensor,
        images: Any,
        actions: Tensor,
        rewards: Tensor,
        states_next: Tensor,
        images_next: Any,
        rewards_original: Tensor | None = None,
        extras: dict[str, Any] | None = None,
    ):
        """Record one transition per environment together with its image and info.

        Args:
            images: raw images of ``states``; processed with ``process_miniworld_images``.
            images_next: raw images of ``states_next``; kept unprocessed for
                :meth:`get_context_for_transformer`.
            extras: must contain ``"infos"``, written into the info buffer.
            Remaining arguments are forwarded to ``MDPDataset.append``.

        Raises:
            KeyError: if ``extras`` has no ``"infos"`` entry (raised before any
                buffer is modified).
        """
        extras = {} if extras is None else extras
        # Fetch/compute everything that can fail before mutating any buffer.
        infos = extras["infos"]
        processed_images = process_miniworld_images(images)

        if self.step_ptr >= self.n_steps:
            # The parent append() rolls its buffers by one step when full; roll
            # these too so images/infos stay aligned with the transitions.
            self._images = self._images.roll(-1, dims=1)
            self._infos = self._infos.roll(-1, dims=1)

        super().append(states, actions, rewards, states_next, rewards_original)

        self._query_images = images_next

        self._images[:, self.step_ptr - 1] = processed_images
        self._infos[:, self.step_ptr - 1] = infos

    def get_context_for_transformer(self) -> tuple[Tensor, Tensor]:
        """Return ``(images, context_nonextstate)`` for the filled steps.

        ``context_nonextstate`` is the parent's context with the next-state
        columns removed. In ``images``, the last step's image is replaced by the
        processed query image. This is done on a copy so the stored image
        history is left intact.
        """
        context_full = super().get_context_for_transformer()

        state_action_end = self.state_dim + self.action_dim
        next_state_end = state_action_end + self.state_dim
        context_nonextstate = torch.concat((context_full[..., :state_action_end], context_full[..., next_state_end:]), dim=2)

        # clone(): slicing returns a view, and writing into it would overwrite
        # the stored image of the last step.
        images = self._images[:, : self.step_ptr].clone()
        if self.step_ptr != 0:
            images[:, -1] = process_miniworld_images(self._query_images)
        return images, context_nonextstate

    def finalize(
        self,
        optimal_actions: Tensor | None,
        query_states: Tensor,
        query_images: Tensor,
        optimal_query_actions: Tensor,
        *,
        context_len: int | None = None,
    ) -> MDPDatasetImagesTorch:
        """Wrap the recorded data in an ``MDPDatasetImagesTorch`` on this dataset's device."""
        return MDPDatasetImagesTorch(self, optimal_actions, query_states, query_images, optimal_query_actions, context_len=context_len, device=self.device)
