from typing import Any, overload

import torch
from matplotlib.animation import ArtistAnimation
from torch import Tensor
from torch._prims_common import DeviceLikeType
from tqdm import tqdm

from args import AdversarialTrainingConfig
from mdp.mdp_attacker import MDPAttacker
from mdp.mdp_dataset import MDPDataset, MDPDatasetTorch


class MDPController:
    """Interface for an agent that chooses actions in an MDP environment.

    Only ``sample_actions`` must be provided by subclasses. The history hooks
    (``clear_dataset`` / ``append``), the learning hook (``update``) and
    ``reinitialize`` are no-ops in this base class.
    """

    n_envs: int
    n_steps: int
    state_dim: int
    n_states: int
    n_actions: int
    device: DeviceLikeType | None  # BaseMDP.deploy allocates its MDPDataset on this device

    def __init__(self, n_envs: int, n_steps: int, state_dim: int, n_states: int, n_actions: int, device: DeviceLikeType | None = None):
        # Stored verbatim; assignment order mirrors the attribute order used elsewhere.
        self.n_envs, self.n_steps, self.n_states, self.state_dim, self.n_actions, self.device = (
            n_envs,
            n_steps,
            n_states,
            state_dim,
            n_actions,
            device,
        )

    def sample_actions(self, states: Tensor) -> Tensor:
        """Return the actions to take for ``states``. Must be overridden."""
        raise NotImplementedError

    def clear_dataset(self) -> None:
        """Forget any transitions gathered so far (no-op in the base class)."""
        return None

    def append(self, states: Tensor, actions: Tensor, rewards: Tensor, states_next: Tensor, rewards_original: Tensor) -> None:
        """Record one batch of transitions into the controller's own history (no-op in the base class).

        Subclasses may keep this history for in-context use, training, etc.
        """
        return None

    def update(self, dataset: MDPDatasetTorch, adv_train_config: AdversarialTrainingConfig) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Update the controller from ``dataset``.

        Returns:
            A pair ``(per-update records, summary dict)``; both are empty in the base class.
        """
        records: list[dict[str, Any]] = []
        summary: dict[str, Any] = {}
        return records, summary

    def reinitialize(self) -> None:
        """Hook to reinitialize the controller; does nothing in the base class."""
        return None


class BaseMDP:
    """Base class for a batched MDP environment that a ``MDPController`` is deployed in.

    Subclasses implement ``env_name``, ``reset``, ``step``, ``sample_states``,
    ``get_optimal_actions_per_state`` and ``visualize_dataset``. ``deploy`` runs a
    controller for ``n_steps`` steps, collecting the transitions. It can optionally
    install an ``MDPAttacker`` together with a random mask of corrupted steps.
    """

    n_envs: int
    n_steps: int
    n_states: int
    action_dim: int
    states: Tensor  # current states; ``deploy`` reads (clones) it before and after every ``step``
    attacker: MDPAttacker | None  # (re)assigned on every ``deploy``; None means no attacker
    corrupted_steps: Tensor | None  # bool mask of shape (n_envs, n_steps), built by ``_set_attacker``

    def env_name(self) -> str:
        """Return the environment's name (``deploy`` special-cases ``"chain"``)."""
        raise NotImplementedError()

    def __init__(self, n_envs: int, n_steps: int, n_states: int, action_dim: int):
        self.n_envs = n_envs
        self.n_steps = n_steps
        self.n_states = n_states
        self.action_dim = action_dim

    @staticmethod
    def _sample_bernoulli(p: float, num_samples: int, device: DeviceLikeType | None) -> Tensor:
        """Draw ``num_samples`` i.i.d. booleans that are True with probability ``p``, placed on ``device``."""
        weights = torch.tensor([1 - p, p])
        # torch.multinomial only accepts floating-point weights; an integer probability
        # (e.g. p=0 or p=1) would otherwise produce an int64 tensor and raise.
        if not weights.is_floating_point():
            weights = weights.to(torch.get_default_dtype())
        return torch.multinomial(weights, num_samples, replacement=True).to(dtype=torch.bool, device=device)

    def _set_attacker(self, attacker: MDPAttacker | None, eps_episodes: float, eps_steps: float) -> None:
        """Install ``attacker`` and sample which steps it may corrupt.

        Each environment is selected with probability ``eps_episodes`` and, independently,
        each of its steps with probability ``eps_steps``. A step is marked in
        ``self.corrupted_steps`` only when both are selected. With ``attacker=None`` the
        attacker is cleared and ``corrupted_steps`` is left untouched.
        """
        self.attacker = attacker
        if attacker is None:
            return

        device = attacker.device
        # The step mask is drawn before the episode mask; swapping the draws would change
        # the sampled masks for a given RNG seed.
        corrupted_steps_all = self._sample_bernoulli(eps_steps, self.n_envs * self.n_steps, device).reshape(self.n_envs, self.n_steps)
        corrupted_envs = self._sample_bernoulli(eps_episodes, self.n_envs, device).reshape(self.n_envs, 1)
        # Broadcast the per-environment selection over all of that environment's steps.
        self.corrupted_steps = corrupted_steps_all & corrupted_envs

    def sample_states(self) -> Tensor:
        """Return query states; ``deploy`` indexes the result as ``[:, None, :]``. Must be overridden."""
        raise NotImplementedError()

    def reset(self) -> Tensor:
        """Reset the environments. Must be overridden."""
        raise NotImplementedError()

    def step(self, actions: Tensor, **kwargs) -> tuple[Tensor, Tensor, bool]:
        """Advance all environments by one step. Must be overridden.

        ``deploy`` forwards its extra keyword arguments here and unpacks the result as
        ``(rewards, rewards_original, _)``.
        """
        raise NotImplementedError()

    @overload
    def deploy(
        self,
        controller: MDPController,
        *,
        omit_optimal_actions: bool = True,
        clear_dataset: bool = True,
        context_len: int | None = None,
        pbar_desc: str | None = None,
        save_video: bool = False,
        force_show_progress: bool = False,
        **kwargs,
    ) -> MDPDatasetTorch: ...

    @overload
    def deploy(
        self,
        controller: MDPController,
        attacker: MDPAttacker | None,
        eps_episodes: float,
        eps_steps: float,
        *,
        omit_optimal_actions: bool = True,
        clear_dataset: bool = True,
        context_len: int | None = None,
        pbar_desc: str | None = None,
        save_video: bool = False,
        force_show_progress: bool = False,
        **kwargs,
    ) -> MDPDatasetTorch: ...

    def deploy(
        self,
        controller: MDPController,
        attacker: MDPAttacker | None = None,
        eps_episodes: float | None = None,
        eps_steps: float | None = None,
        *,
        omit_optimal_actions: bool = True,
        clear_dataset: bool = True,
        context_len: int | None = None,
        pbar_desc: str | None = None,
        save_video: bool = False,
        force_show_progress: bool = False,
        **kwargs,
    ) -> MDPDatasetTorch:
        """Deploy a controller in the environment with corruption. Returns the trajectories of the deployment.

        Args:
            controller: Queried with ``sample_actions`` every step and fed every transition via ``append``.
            attacker: Optional attacker installed via ``_set_attacker``; ``None`` disables it.
            eps_episodes: Probability that an environment is selected for corruption (required with ``attacker``).
            eps_steps: Probability that a step is selected for corruption (required with ``attacker``).
            omit_optimal_actions: If False, also compute optimal actions for every visited state.
            clear_dataset: Call ``controller.clear_dataset()`` before deploying.
            context_len: Forwarded to ``MDPDataset.finalize``.
            pbar_desc: Prefix for the progress-bar label; also forwarded when computing optimal actions.
            save_video: Not used by this base implementation.
            force_show_progress: Show a progress bar even for small runs (by default only shown when
                ``n_envs >= 10000`` and ``n_steps >= 100``).
            **kwargs: Forwarded to ``self.step`` on every step.

        Returns:
            The result of ``MDPDataset.finalize`` over the collected transitions, the optional optimal
            actions, freshly sampled query states and their optimal actions.

        Raises:
            ValueError: If ``attacker`` is given without both ``eps_episodes`` and ``eps_steps``.
        """
        # Validate before any side effect (environment reset, clearing the controller's history).
        if attacker is not None and (eps_episodes is None or eps_steps is None):
            raise ValueError("eps_episodes and eps_steps must be set")

        self.reset()
        if clear_dataset:
            controller.clear_dataset()
        if attacker is None:
            self.attacker = None
        else:
            self._set_attacker(attacker, eps_episodes, eps_steps)

        dataset = MDPDataset(self.n_envs, self.n_steps, self.n_states, self.states.shape[-1], self.action_dim, controller.device)

        # Only large runs get a progress bar, unless explicitly forced.
        steps = range(self.n_steps)
        if force_show_progress or (self.n_envs >= 10000 and self.n_steps >= 100):
            desc = (f"{pbar_desc} " if pbar_desc is not None else "") + "Deploy - " + controller.__class__.__name__
            steps = tqdm(steps, desc=desc)

        for _ in steps:
            # Clone so later in-place updates of self.states by step() don't alias the stored transition.
            states = self.states.clone()
            actions = controller.sample_actions(states)

            rewards, rewards_original, _ = self.step(actions, **kwargs)
            states_next = self.states.clone()

            controller.append(states, actions, rewards, states_next, rewards_original)
            dataset.append(states, actions, rewards, states_next, rewards_original)

        if omit_optimal_actions:
            optimal_actions = None
        else:
            optimal_actions = self.get_optimal_actions_per_state(dataset.states.int(), pbar_desc=pbar_desc)
        query_states = self.sample_states()
        # Insert a singleton axis at dim 1 for the per-state helper, then squeeze it back out.
        optimal_query_actions = self.get_optimal_actions_per_state(query_states[:, None, :].int()).squeeze(1)

        convert_states_onehot = self.n_states if self.env_name() == "chain" else None  # TODO: leaky abstr.

        return dataset.finalize(optimal_actions, query_states, optimal_query_actions, context_len=context_len, convert_states_onehot=convert_states_onehot)

    def get_optimal_actions_per_state(self, states: Tensor, *, pbar_desc: str | None = None) -> Tensor:
        """Return the optimal action(s) for each entry of ``states``. Must be overridden."""
        raise NotImplementedError()

    def visualize_dataset(self, dataset: MDPDataset) -> ArtistAnimation:
        """Render ``dataset`` as an animation. Must be overridden."""
        raise NotImplementedError()
