from typing import Any, Literal, overload

import torch
from torch import Tensor
from torch._prims_common import DeviceLikeType
from torch.nn.functional import one_hot
from tqdm import tqdm

from args import AdversarialTrainingConfig
from bandit2.bandit_attacker import BanditAttacker
from bandit2.bandit_dataset import BanditDataset, BanditDatasetTorch


class BanditController:
    """Base class for policies that choose actions for a batch of bandit environments.

    Only :meth:`sample_actions` has to be overridden. The remaining methods are
    optional hooks whose default implementations do nothing.
    """

    n_envs: int
    n_steps: int
    n_actions: int
    device: DeviceLikeType | None

    def __init__(self, n_envs: int, n_steps: int, n_actions: int, device: DeviceLikeType | None = None):
        """Store the batch dimensions and the device new tensors should be created on."""
        self.n_envs, self.n_steps, self.n_actions = n_envs, n_steps, n_actions
        self.device = device

    def sample_actions(self) -> Tensor:
        """Return the actions for the current step; must be provided by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def clear_dataset(self) -> None:
        """Hook for discarding any collected data; a no-op by default."""
        return None

    def append(self, actions: Tensor, rewards: Tensor, rewards_original: Tensor) -> None:
        """Hook for recording one step of interaction; a no-op by default."""
        return None

    def update(self, dataset: BanditDatasetTorch, adv_train_config: AdversarialTrainingConfig) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Hook for training on ``dataset``; the default does nothing.

        Returns:
            An empty list and an empty dict.
        """
        no_entries: list[dict[str, Any]] = []
        no_summary: dict[str, Any] = {}
        return no_entries, no_summary


class BanditBiasedRandomController(BanditController):
    """Random controller whose per-environment action distribution is skewed toward one arm.

    For every environment, ``probs`` is a mixture of a Dirichlet(1, ..., 1)
    sample and a one-hot distribution on a randomly chosen arm. The
    distribution is redrawn each time :meth:`clear_dataset` is called.
    """

    probs: Tensor

    def __init__(self, n_envs: int, n_steps: int, n_actions: int, device: str | torch.device | int | None = None):
        super().__init__(n_envs, n_steps, n_actions, device)

        # Draw the initial action distribution for every environment.
        self.clear_dataset()

    def clear_dataset(self) -> None:
        """Resample ``self.probs``, a ``(n_envs, n_actions)`` tensor of action probabilities."""
        per_env = (self.n_envs,)

        # Mixing weight drawn from {0.0, 0.1, ..., 0.9}; kept as a column so it broadcasts over arms.
        bias_weight = torch.randint(0, 10, per_env, device=self.device).div(10)[:, None]

        # The arm that receives the extra probability mass in each environment.
        favoured_arm = torch.randint(0, self.n_actions, per_env, device=self.device)
        favoured_onehot = one_hot(favoured_arm, self.n_actions).float()

        flat_concentration = torch.ones((self.n_envs, self.n_actions), device=self.device)
        background = torch.distributions.Dirichlet(flat_concentration).sample()

        self.probs = (1 - bias_weight) * background + bias_weight * favoured_onehot

    def sample_actions(self) -> Tensor:
        """Sample one arm per environment from ``self.probs`` and return it one-hot encoded as floats."""
        sampler = torch.distributions.Categorical(probs=self.probs)
        return one_hot(sampler.sample(), self.n_actions).float()


class BanditUniformRandomController(BanditController):
    """Controller that picks an arm uniformly at random in every environment."""

    def sample_actions(self) -> Tensor:
        """Return a ``(n_envs, n_actions)`` float tensor holding one one-hot action per environment."""
        arm_per_env = torch.randint(low=0, high=self.n_actions, size=(self.n_envs,), device=self.device)
        encoded = one_hot(arm_per_env, self.n_actions)
        return encoded.float()


class BaseBanditEnv:
    """Abstract batch of bandit environments that can roll out a controller.

    Subclasses provide :meth:`reset`, :meth:`step`, :meth:`_set_attacker` and
    :meth:`get_optimal_actions`; :meth:`deploy` drives them to collect a dataset.
    """

    n_envs: int
    n_steps: int

    def __init__(self, n_envs: int, n_steps: int):
        self.n_envs = n_envs
        self.n_steps = n_steps

    def reset(self) -> None:
        """Prepare the environments for a new rollout (subclass hook)."""
        raise NotImplementedError()

    def step(self, action: Tensor) -> tuple[Tensor, Tensor, bool]:
        """Advance one step (subclass hook).

        :meth:`deploy` unpacks the result as ``(reward, reward_original, _)``.
        """
        raise NotImplementedError()

    def _set_attacker(self, attacker: BanditAttacker, eps_episodes: float, eps_steps: float) -> None:
        """Install ``attacker`` for the next rollout (subclass hook)."""
        raise NotImplementedError()

    @overload
    def deploy(
        self,
        controller: BanditController,
        *,
        pbar_desc: str | None = None,
        **kwargs,
    ) -> BanditDatasetTorch: ...
    @overload
    def deploy(
        self,
        controller: BanditController,
        attacker: BanditAttacker | None,
        eps_episodes: float,
        eps_steps: float,
        *,
        pbar_desc: str | None = None,
        **kwargs,
    ) -> BanditDatasetTorch: ...
    def deploy(
        self,
        controller: BanditController,
        attacker: BanditAttacker | None = None,
        eps_episodes: float | None = None,
        eps_steps: float | None = None,
        *,
        pbar_desc: str | None = None,
        **kwargs,
    ) -> BanditDatasetTorch:
        """Roll out ``controller`` for ``n_steps`` steps and return the collected dataset.

        Every step is appended to the controller, to the attacker (when one is
        given) and to a fresh ``BanditDataset``, which is finalized with the
        environment's optimal actions and returned.

        Args:
            controller: Policy that supplies the actions. Its ``clear_dataset``
                hook is called before the rollout starts.
            attacker: Optional attacker installed through ``_set_attacker``;
                with ``None`` the environment's ``attacker`` attribute is
                cleared.
            eps_episodes: Forwarded to ``_set_attacker``; required with an attacker.
            eps_steps: Forwarded to ``_set_attacker``; required with an attacker.
            pbar_desc: Optional prefix for the progress-bar description.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The result of ``BanditDataset.finalize`` for this rollout.

        Raises:
            ValueError: If an attacker is given without ``eps_episodes`` and
                ``eps_steps``.
        """
        # Validate first so that a bad call leaves the environment and the
        # controller untouched. An explicit raise (rather than ``assert``) keeps
        # the check active when Python runs with -O.
        if attacker is not None and (eps_episodes is None or eps_steps is None):
            raise ValueError("eps_episodes and eps_steps must be set")

        self.reset()
        controller.clear_dataset()
        dataset = BanditDataset(controller.n_envs, self.n_steps, controller.n_actions, device=controller.device)

        if attacker is None:
            self.attacker = None
        else:
            self._set_attacker(attacker, eps_episodes, eps_steps)
            attacker.clear_dataset()

        # Only show a progress bar for rollouts that are large in both dimensions.
        steps = range(self.n_steps)
        if self.n_envs >= 10000 and self.n_steps >= 100:
            prefix = f"{pbar_desc} " if pbar_desc is not None else ""
            steps = tqdm(steps, desc=prefix + "Deploy - " + controller.__class__.__name__)

        for _ in steps:
            action = controller.sample_actions()

            reward, reward_original, _ = self.step(action)

            controller.append(action, reward, reward_original)
            if attacker is not None:
                attacker.append(action, reward, reward_original)
            dataset.append(action, reward, reward_original)

        return dataset.finalize(self.get_optimal_actions())

    def get_optimal_actions(self) -> Tensor:
        """Return the optimal action of every environment (subclass hook)."""
        raise NotImplementedError()


class BanditEnv(BaseBanditEnv):
    """Batch of bandits whose rewards are per-arm means plus Gaussian noise.

    An optional attacker can add a perturbation to the rewards on a random
    subset of (environment, step) pairs.
    """

    original_means: Tensor
    optimal_actions: Tensor
    variance: float
    n_steps: int
    n_actions: int

    # Declared for type checkers; never assigned anywhere in this class.
    poisoned_means: Tensor
    current_step: int
    attacker: BanditAttacker | None
    # Boolean (n_envs, n_steps) mask; only set once ``_set_attacker`` has run.
    corrupted_steps: Tensor

    @classmethod
    def sample(cls, n_envs: int, n_steps: int, n_arms: int = 5, variance: float = 0.3, device=None) -> "BanditEnv":
        """Create environments whose arm means are drawn uniformly from [0, 1)."""
        means = torch.rand((n_envs, n_arms), device=device)
        # Build through ``cls`` so that subclasses get instances of their own type.
        return cls(means, n_steps, variance, device=device)

    def __init__(self, means: Tensor, n_steps: int, variance: float, device=None) -> None:
        """Wrap a ``(n_envs, n_actions)`` tensor of arm means.

        Args:
            means: Mean reward of every arm in every environment.
            n_steps: Length of an episode.
            variance: Scale applied to the standard-normal reward noise.
            device: When given, ``means`` is moved to this device; otherwise it
                stays where it is.
        """
        if device is not None:
            # Honour the requested device instead of silently ignoring it.
            means = means.to(device)

        n_envs = means.shape[0]
        super().__init__(n_envs, n_steps)

        self.attacker = None

        self.original_means = means
        self.variance = variance

        self.optimal_actions = torch.argmax(self.original_means, dim=-1)
        self.n_actions = self.original_means.shape[1]

        self.reset()

    def _set_attacker(self, attacker: BanditAttacker, eps_episodes: float, eps_steps: float) -> None:
        """Install ``attacker`` and draw the mask of corrupted (environment, step) pairs.

        A step is corrupted when both its episode (probability ``eps_episodes``)
        and the step itself (probability ``eps_steps``) are selected.
        """
        self.attacker = attacker

        if attacker is None:
            return

        device = attacker.device

        step_probs = torch.tensor([1 - eps_steps, eps_steps])
        corrupted_steps_all = (
            torch.multinomial(step_probs, self.n_envs * self.n_steps, replacement=True)
            .to(dtype=torch.bool, device=device)
            .reshape(self.n_envs, self.n_steps)
        )

        episode_probs = torch.tensor([1 - eps_episodes, eps_episodes])
        corrupted_episodes = (
            torch.multinomial(episode_probs, self.n_envs, replacement=True)
            .to(dtype=torch.bool, device=device)
            .reshape(self.n_envs, 1)
        )

        # The (n_envs, 1) episode mask broadcasts across the step dimension.
        self.corrupted_steps = corrupted_steps_all * corrupted_episodes

    def reset(self) -> None:
        """Rewind the step counter to the start of the episode."""
        self.current_step = 0

    def step(self, actions: Tensor) -> tuple[Tensor, Tensor, bool]:
        """Play ``actions`` for one step.

        Returns:
            ``(reward, reward_original, done)`` where ``done`` becomes ``True``
            on the last step of the episode.

        Raises:
            RuntimeError: If the episode has already used all ``n_steps`` steps.
        """
        if self.current_step >= self.n_steps:
            raise RuntimeError(f"Episode has already ended (current_step exceeds n_steps={self.n_steps}).")

        r, r_original = self.get_arm_rewards(actions, "both")

        self.current_step += 1
        done = self.current_step == self.n_steps

        return r, r_original, done

    @overload
    def get_arm_rewards(self, actions: Tensor, type: Literal["corrupted", "original"]) -> Tensor: ...
    @overload
    def get_arm_rewards(self, actions: Tensor, type: Literal["both"]) -> tuple[Tensor, Tensor]: ...
    def get_arm_rewards(self, actions: Tensor, type: Literal["corrupted", "original", "both"] = "corrupted"):
        """Draw noisy rewards for ``actions`` and optionally apply the attacker.

        Args:
            actions: Tensor of shape ``(n_envs, ..., n_actions)`` that is
                contracted with the arm means over the last axis.
            type: ``"original"`` returns the clean reward, ``"corrupted"`` the
                reward after the attack, ``"both"`` the pair
                ``(corrupted, original)``.
        """
        # e -- n_envs
        # s -- n_steps (or '...' to be general)
        # a -- n_arms
        r_original = torch.einsum("ea,e...a->e...", self.original_means, actions)

        # Add noise
        # NOTE(review): ``variance`` scales a standard-normal sample directly, so it
        # acts as a standard deviation here -- confirm that this is intended.
        r_original += torch.randn(r_original.shape, device=r_original.device) * self.variance

        if type == "original":
            return r_original

        if self.attacker is None:
            # Without an attacker the corrupted reward is the clean reward.
            return (r_original, r_original) if type == "both" else r_original

        perturbation = self.attacker.get_reward(r_original, actions)
        # One reward per environment -> use the mask column of the current step;
        # otherwise apply the full (n_envs, n_steps) mask.
        mask = self.corrupted_steps[:, self.current_step] if r_original.ndim == 1 else self.corrupted_steps
        r = r_original + perturbation * mask

        if type == "both":
            return r, r_original

        return r

    def get_optimal_actions(self) -> Tensor:
        """Return the arm with the highest mean in every environment, one-hot encoded as floats."""
        optimal_actions_onehot = one_hot(self.optimal_actions, num_classes=self.n_actions)
        return optimal_actions_onehot.to(dtype=torch.float)


class BinomialBanditEnv(BaseBanditEnv):
    """Batch of bandits whose rewards are Binomial(``original_n``, p) draws.

    ``p`` is the success probability of the chosen arm. An optional attacker
    can add a perturbation to the rewards on a random subset of
    (environment, step) pairs.
    """

    optimal_actions: Tensor
    original_n: int
    original_ps: Tensor

    n_steps: int
    n_actions: int

    # Declared for type checkers; never assigned anywhere in this class.
    poisoned_means: Tensor
    current_step: int
    attacker: BanditAttacker | None
    # Boolean (n_envs, n_steps) mask; only set once ``_set_attacker`` has run.
    corrupted_steps: Tensor

    @classmethod
    def sample(cls, n_envs: int, n_steps: int, n_arms: int = 5, binomial_n: int = 10, device=None) -> "BinomialBanditEnv":
        """Create environments where one random arm has p=0.9 and all others p=0.8."""
        ps = torch.ones((n_envs, n_arms), device=device) * 0.8
        optimal_actions = torch.randint(0, n_arms, (n_envs,))
        ps[torch.arange(n_envs), optimal_actions] = 0.9
        # Build through ``cls`` so that subclasses get instances of their own type.
        return cls(binomial_n, ps, n_steps, device=device)

    def __init__(self, n: int, ps: Tensor, n_steps: int, device=None) -> None:
        """Wrap a ``(n_envs, n_actions)`` tensor of per-arm success probabilities.

        Args:
            n: Number of Binomial trials behind every reward.
            ps: Success probability of every arm in every environment.
            n_steps: Length of an episode.
            device: When given, ``ps`` is moved to this device; otherwise it
                stays where it is.
        """
        if device is not None:
            # Honour the requested device instead of silently ignoring it.
            ps = ps.to(device)

        n_envs = ps.shape[0]
        super().__init__(n_envs, n_steps)

        self.attacker = None

        self.original_n = n
        self.original_ps = ps

        self.optimal_actions = torch.argmax(self.original_ps, dim=-1)
        self.n_actions = self.original_ps.shape[1]

        self.reset()

    def _set_attacker(self, attacker: BanditAttacker, eps_episodes: float, eps_steps: float) -> None:
        """Install ``attacker`` and draw the mask of corrupted (environment, step) pairs.

        This is the hook ``BaseBanditEnv.deploy`` calls. A step is corrupted
        when both its episode (probability ``eps_episodes``) and the step
        itself (probability ``eps_steps``) are selected.
        """
        self.attacker = attacker

        if attacker is None:
            return

        device = attacker.device

        step_probs = torch.tensor([1 - eps_steps, eps_steps])
        corrupted_steps_all = (
            torch.multinomial(step_probs, self.n_envs * self.n_steps, replacement=True)
            .to(dtype=torch.bool, device=device)
            .reshape(self.n_envs, self.n_steps)
        )

        episode_probs = torch.tensor([1 - eps_episodes, eps_episodes])
        corrupted_episodes = (
            torch.multinomial(episode_probs, self.n_envs, replacement=True)
            .to(dtype=torch.bool, device=device)
            .reshape(self.n_envs, 1)
        )

        # The (n_envs, 1) episode mask broadcasts across the step dimension.
        self.corrupted_steps = corrupted_steps_all * corrupted_episodes

    def set_attacker(self, attacker: BanditAttacker, eps_episodes: float, eps_steps: float) -> None:
        """Public alias of :meth:`_set_attacker`, kept for existing callers."""
        self._set_attacker(attacker, eps_episodes, eps_steps)

    def reset(self) -> None:
        """Rewind the step counter to the start of the episode."""
        self.current_step = 0

    def step(self, actions: Tensor) -> tuple[Tensor, Tensor, bool]:
        """Play ``actions`` for one step.

        Returns:
            ``(reward, reward_original, done)`` where ``done`` becomes ``True``
            on the last step of the episode.

        Raises:
            RuntimeError: If the episode has already used all ``n_steps`` steps.
        """
        if self.current_step >= self.n_steps:
            raise RuntimeError(f"Episode has already ended (current_step exceeds n_steps={self.n_steps}).")

        r, r_original = self.get_arm_rewards(actions, "both")

        self.current_step += 1
        done = self.current_step == self.n_steps

        return r, r_original, done

    @overload
    def get_arm_rewards(self, actions: Tensor, type: Literal["corrupted", "original"]) -> Tensor: ...
    @overload
    def get_arm_rewards(self, actions: Tensor, type: Literal["both"]) -> tuple[Tensor, Tensor]: ...
    def get_arm_rewards(self, actions: Tensor, type: Literal["corrupted", "original", "both"] = "corrupted"):
        """Draw Binomial rewards for ``actions`` and optionally apply the attacker.

        Args:
            actions: Tensor of shape ``(n_envs, ..., n_actions)``; the chosen
                arm is the argmax over the last axis.
            type: ``"original"`` returns the clean reward, ``"corrupted"`` the
                reward after the attack, ``"both"`` the pair
                ``(corrupted, original)``.
        """
        chosen = actions.argmax(dim=-1)
        # Shape the environment index as (n_envs, 1, ...) so it broadcasts against
        # ``chosen``; this also handles actions with extra (e.g. step) dimensions.
        env_idx = torch.arange(self.n_envs, device=self.original_ps.device).reshape((self.n_envs,) + (1,) * (chosen.ndim - 1))
        r_original = torch.distributions.Binomial(self.original_n, self.original_ps[env_idx, chosen]).sample()

        if type == "original":
            return r_original

        if self.attacker is None:
            # Without an attacker the corrupted reward is the clean reward.
            return (r_original, r_original) if type == "both" else r_original

        perturbation = self.attacker.get_reward(r_original, actions)
        # One reward per environment -> use the mask column of the current step;
        # otherwise apply the full (n_envs, n_steps) mask.
        mask = self.corrupted_steps[:, self.current_step] if r_original.ndim == 1 else self.corrupted_steps
        r = r_original + perturbation * mask

        if type == "both":
            return r, r_original

        return r

    def get_optimal_actions(self) -> Tensor:
        """Return the arm with the highest success probability in every environment, one-hot encoded as floats."""
        optimal_actions_onehot = one_hot(self.optimal_actions, num_classes=self.n_actions)
        return optimal_actions_onehot.to(dtype=torch.float)


def generate_bandit_trajectories(
    n_envs: int, n_steps: int, n_arms: int, variance: float, device: DeviceLikeType | None = None, *, pbar_desc: str | None = None
) -> BanditDatasetTorch:
    """Roll out a biased random controller on freshly sampled ``BanditEnv`` environments.

    Args:
        n_envs: Number of environments in the batch.
        n_steps: Number of steps per episode.
        n_arms: Number of arms per bandit.
        variance: Noise scale passed to ``BanditEnv.sample``.
        device: Device used for both the controller and the environments.
        pbar_desc: Optional prefix for the progress-bar description.

    Returns:
        The dataset produced by ``deploy`` for this rollout.
    """
    # The controller is built before the environments, so it draws from the
    # random number generator first.
    behaviour_policy = BanditBiasedRandomController(n_envs, n_steps, n_arms, device)
    bandits = BanditEnv.sample(n_envs=n_envs, n_steps=n_steps, n_arms=n_arms, variance=variance, device=device)
    return bandits.deploy(behaviour_policy, pbar_desc=pbar_desc)
