from typing import Literal, overload

import torch
from torch import Tensor
from torch._prims_common import DeviceLikeType
from torch.utils.data import Dataset as TorchDataset


class BanditDatasetTorch(TorchDataset):
    n_envs: int
    n_steps: int
    action_dim: int
    device: DeviceLikeType | None
    _actions: Tensor
    _rewards: Tensor
    _rewards_original: Tensor
    step_ptr: int

    _optimal_actions: Tensor
    shuffle: bool

    def __init__(
        self,
        n_envs: int,
        n_steps: int,
        action_dim: int,
        actions: Tensor,
        rewards: Tensor,
        rewards_original: Tensor,
        step_ptr: int,
        optimal_actions: Tensor,
        shuffle: bool = True,
        device: DeviceLikeType | None = None,
    ):
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.action_dim = action_dim
        self.device = device
        self._actions = actions.detach()
        self._rewards = rewards.detach()
        self._rewards_original = rewards_original.detach()
        self.step_ptr = step_ptr

        self._optimal_actions = optimal_actions.detach().float()
        self.shuffle = shuffle

    def __len__(self) -> int:
        return self.n_envs

    @overload
    def __getitem__(self, index, *, return_perm: Literal[True] = True) -> tuple[Tensor, Tensor]: ...
    @overload
    def __getitem__(self, index, *, return_perm: Literal[False] = False) -> tuple[Tensor, Tensor, Tensor]: ...
    def __getitem__(self, index, *, return_perm=False):
        if self.shuffle:
            perm = torch.randperm(self.step_ptr, device=self.device)
        else:
            perm = torch.arange(self.step_ptr, device=self.device)

        context_states = torch.ones((self.step_ptr, 1), device=self.device)
        context_actions = self._actions[index, : self.step_ptr, :][perm]
        context_next_states = torch.ones((self.step_ptr, 1), device=self.device)
        context_rewards = self._rewards[index, : self.step_ptr, None][perm]

        optimal_actions = self._optimal_actions[index]  # .unsqueeze(0).repeat(self.n_steps, 1)

        context = torch.cat((context_states, context_actions, context_next_states, context_rewards), dim=1)
        query_line = torch.zeros((1, context.shape[-1]), device=self.device)

        x = torch.cat((query_line, context), dim=0)

        if return_perm:
            return x, optimal_actions, perm

        return x, optimal_actions

    @property
    def rewards(self) -> Tensor:
        return self._rewards

    @property
    def rewards_original(self) -> Tensor:
        return self._rewards_original

    @property
    def actions(self) -> Tensor:
        return self._actions


class BanditDataset:
    action_dim: int
    n_envs: int
    n_steps: int

    _actions: Tensor
    _rewards: Tensor
    _rewards_original: Tensor

    step_ptr: int

    def __init__(self, n_envs: int, n_steps: int, action_dim: int, device=None):
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.action_dim = action_dim
        self.device = device

        self.clear()

    def clear(self) -> None:
        self._actions = torch.zeros((self.n_envs, self.n_steps, self.action_dim), device=self.device)
        self._rewards = torch.zeros((self.n_envs, self.n_steps), device=self.device)
        self._rewards_original = torch.zeros((self.n_envs, self.n_steps), device=self.device)
        self.step_ptr = 0

    @property
    def actions(self) -> Tensor:
        return self._actions[:, : self.step_ptr, :]

    @property
    def rewards(self) -> Tensor:
        return self._rewards[:, : self.step_ptr]

    @property
    def rewards_original(self) -> Tensor:
        return self._rewards_original[:, : self.step_ptr]

    def append(self, actions: Tensor, rewards: Tensor, rewards_original: Tensor):
        if self.step_ptr >= self.n_steps:
            raise ValueError("Dataset full, cannot append")

        if rewards.ndim == 2:
            rewards = rewards.squeeze(dim=1)

        self._actions[:, self.step_ptr, :] = actions
        self._rewards[:, self.step_ptr] = rewards
        self._rewards_original[:, self.step_ptr] = rewards_original
        self.step_ptr += 1

    def get_context_for_transformer(self, *, with_rewards_original: bool = False) -> Tensor:
        rewards = self._rewards_original[..., None] if with_rewards_original else self._rewards[..., None]

        dummy_states = torch.ones((self.n_envs, self.n_steps, 1), device=self.device)
        padded_context = torch.concat((dummy_states, self._actions, dummy_states, rewards), dim=-1)
        context = padded_context[:, : self.step_ptr, :]
        return context

    def finalize(self, optimal_actions: Tensor, shuffle: bool = True) -> BanditDatasetTorch:
        return BanditDatasetTorch(
            self.n_envs, self.n_steps, self.action_dim, self._actions, self._rewards, self._rewards_original, self.step_ptr, optimal_actions, shuffle, device=self.device
        )
