from typing import Literal, overload

import matplotlib.pyplot as plt
import torch
from matplotlib.animation import ArtistAnimation
from matplotlib.axes import Axes
from matplotlib.patches import Ellipse, Rectangle
from torch import Tensor
from torch._prims_common import DeviceLikeType
from tqdm import tqdm

from mdp.mdp_controller import MDPRandomController
from mdp.mdp_dataset import MDPDataset, MDPDatasetTorch
from mdp.mdp_env import BaseMDP

# Device for the module-level tensors below: CUDA when available, otherwise
# None (torch then uses its default device).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = None

# State displacement produced by each action: action 0 moves one state to the
# left (-1), action 1 moves one state to the right (+1).
transit_kernel = torch.tensor([-1, 1], dtype=torch.float, device=device)
# One action per entry of the transition kernel (i.e. 2).
n_actions = len(transit_kernel)

# Reward bonus used by ChainEnv.sample when shaping the sampled reward tables.
handicap = 1


class ChainEnv(BaseMDP):
    """Batch of ``n_envs`` independent chain MDPs.

    States ``0 .. n_states - 1`` lie on a line. Action 0 moves one state to the
    left and action 1 one state to the right (module-level ``transit_kernel``);
    a move past either end leaves the agent where it is. Each environment has its
    own mean-reward table ``reward_kernel[env, state, action]``, and observed
    rewards carry additive Gaussian noise.
    """

    device: DeviceLikeType | None

    n_states: int
    n_actions: int
    reward_kernel: Tensor  # (n_envs, n_states, n_actions) mean rewards, as built by `sample`
    states: Tensor  # (n_envs, 1) integer state of every environment
    current_step: int
    variance: float  # multiplies standard-normal reward noise (so it acts as a std)

    optimal_actions: Tensor  # (n_envs, n_states, n_actions), one-hot as built by `sample`

    def env_name(self) -> str:
        """Short identifier of this environment family."""
        return "chain"

    @classmethod
    def sample(cls, n_envs: int, n_steps: int, n_states: int = 5, variance: float = 0.3, device=None) -> "ChainEnv":
        """Draw ``n_envs`` random chain environments together with their optimal policies.

        Every environment starts from a standard-normal reward table. Let
        ``(s_max, a_max)`` be its largest entry and ``s_swap = s_max + dir(a_max)`` the
        state that ``a_max`` leads to. The second-largest reward is swapped into
        ``(s_swap, 1 - a_max)``, the move that leads back to ``s_max``. Both top rewards
        then receive ``handicap``. The optimal behaviour is therefore to walk towards
        ``s_max`` and bounce between ``s_max`` and ``s_swap``. If ``a_max`` points off the
        chain, it keeps the agent at ``s_max``. In that case repeating it is optimal and
        the rewards are left untouched.

        Args:
            n_envs: number of environments to sample.
            n_steps: episode length.
            n_states: chain length.
            variance: reward-noise scale (see ``get_rewards``).
            device: device for the reward and policy tensors.

        Returns:
            A ``ChainEnv`` with the shaped rewards and a one-hot ``optimal_actions``
            tensor of shape ``(n_envs, n_states, n_actions)``.
        """
        reward_kernel = torch.randn((n_envs, n_states, n_actions), device=device)

        optimal_actions = torch.zeros((n_envs, n_states, n_actions), dtype=torch.int, device=device)

        for env in tqdm(range(n_envs), desc=f"Generating Environments - {cls.__name__}"):
            # A view into reward_kernel: element writes below update it in place.
            current_r = reward_kernel[env]

            # Locate the two largest rewards (as Python ints for plain indexing).
            sort = current_r.reshape(-1).sort(descending=True)
            val_2max = sort.values[1].item()
            s_max, a_max = (int(i) for i in torch.unravel_index(sort.indices[0], current_r.shape))
            pos_2max = tuple(int(i) for i in torch.unravel_index(sort.indices[1], current_r.shape))
            # The move that undoes a_max, taken from the state a_max leads to.
            s_swap = s_max + int(transit_kernel[a_max].item())
            a_swap = 1 - a_max

            # Default policy: head towards s_max, i.e. go left (action 0) at s_max and
            # above, go right (action 1) below it.
            optimal_actions[env, s_max:, 0] = 1
            optimal_actions[env, :s_max, 1] = 1

            if s_swap >= n_states:
                # a_max is "right" at the last state: it keeps the agent in place, so
                # taking it forever is optimal.
                optimal_actions[env, s_max, 0] = 0
                optimal_actions[env, s_max, 1] = 1
                continue
            if s_swap < 0:
                # a_max is "left" at state 0: already the default choice above.
                continue

            # Make the two top rewards consecutive by moving the second-largest one
            # into (s_swap, a_swap). This is a no-op if it is already there.
            current_r[pos_2max] = current_r[s_swap, a_swap].item()
            current_r[s_swap, a_swap] = val_2max

            # Boost the two top rewards. After the swap, the second-largest lives at
            # (s_swap, a_swap); pos_2max now holds the displaced value.
            current_r[s_max, a_max] += handicap
            current_r[s_swap, a_swap] += handicap

            if a_max == 1:
                # s_swap = s_max + 1 lies to the right, so at s_max the optimal move
                # is right instead of the default left.
                optimal_actions[env, s_max, 0] = 0
                optimal_actions[env, s_max, 1] = 1

        return ChainEnv(reward_kernel, n_envs, n_steps, n_states, variance, device=device, optimal_actions=optimal_actions)

    def __init__(self, reward_kernel: Tensor, n_envs: int, n_steps: int, n_states: int, variance: float, device: DeviceLikeType | None = None, *, optimal_actions: Tensor):
        """Wrap a given reward table and optimal policy; use ``sample`` for random instances.

        The environment is reset immediately (all environments at state 0).
        """
        super().__init__(n_envs, n_steps, n_states, n_actions)
        self.n_states = n_states
        self.n_actions = n_actions
        self.reward_kernel = reward_kernel
        self.variance = variance
        self.optimal_actions = optimal_actions

        self.device = device

        self.reset()

    def reset(self) -> Tensor:
        """Start a new episode with every environment at state 0.

        Returns:
            The ``(n_envs, 1)`` integer state tensor (the same object as ``self.states``).
        """
        self.current_step = 0
        states = torch.zeros((self.n_envs, 1), dtype=torch.int, device=self.device)
        self.states = states

        return states

    def sample_states(self) -> Tensor:
        """Return uniformly random states of shape ``(n_envs, 1)`` without changing ``self.states``."""
        states = torch.randint(0, self.n_states, (self.n_envs, 1), dtype=torch.int, device=self.device)
        return states

    @overload
    def get_rewards(self, states: Tensor, actions: Tensor, type: Literal["corrupted", "clean"] = "corrupted") -> Tensor: ...
    @overload
    def get_rewards(self, states: Tensor, actions: Tensor, type: Literal["both"]) -> tuple[Tensor, Tensor]: ...
    def get_rewards(self, states: Tensor, actions: Tensor, type: Literal["corrupted", "clean", "both"] = "corrupted") -> tuple[Tensor, Tensor] | Tensor:
        """Compute noisy (and optionally attacker-corrupted) rewards.

        Args:
            states: integer states, shape ``(n_envs, n_steps)``.
            actions: action weights (presumably one-hot), shape ``(n_envs, n_steps, n_actions)``.
            type: ``"clean"`` for the noisy environment reward, ``"corrupted"`` for that
                reward plus ``attacker.get_reward(...)`` on the steps flagged in
                ``corrupted_steps``, or ``"both"`` for ``(corrupted, clean)``.

        Returns:
            Reward tensor(s) of shape ``(n_envs, n_steps)``.

        Raises:
            ValueError: if a corrupted reward is requested but no attacker is set.
        """
        assert states.ndim == 2 and actions.ndim == 3

        # Mean reward of every action at every given state: (envs, steps, actions).
        reward_at_states = self.reward_kernel[torch.arange(states.shape[0])[:, None], states, :]
        # e -- envs
        # s -- steps
        # a -- actions
        r = torch.einsum("esa,esa->es", reward_at_states, actions)

        # Observation noise; `variance` scales a standard normal.
        r += torch.randn(r.shape, device=r.device) * self.variance

        if type == "clean":
            return r

        if self.attacker is None:
            raise ValueError("`attacker` is unset, please first call `set_attacker` or get type='clean' rewards only.")

        assert self.corrupted_steps is not None

        r_original = r
        r_poisoned = self.attacker.get_reward(r_original, states.unsqueeze(1), actions)
        # A single-step query (as issued by `step`) uses only the current step's mask column.
        corrupted_mask = self.corrupted_steps[:, self.current_step : self.current_step + 1] if r.shape[1] == 1 else self.corrupted_steps
        # Corrupt a copy: adding in place to `r` would also overwrite `r_original`.
        r = r.clone()
        r += r_poisoned * corrupted_mask

        if type == "both":
            return r, r_original

        return r

    def step(self, actions: Tensor, with_attacker: bool):
        """Advance every environment by one step.

        Args:
            actions: action weights of shape ``(n_envs, n_actions)`` (presumably one-hot).
            with_attacker: also compute the attacker-corrupted reward.

        Returns:
            ``(r, r_original, done)`` if ``with_attacker`` else ``(r, done)``. Rewards have
            shape ``(n_envs, 1)``; ``done`` becomes True after ``n_steps`` steps.

        Raises:
            RuntimeError: if called after the episode has ended.
        """
        if self.current_step >= self.n_steps:
            raise RuntimeError(f"Episode has already ended (current_step exceeds n_steps={self.n_steps}).")

        if with_attacker:
            r, r_original = self.get_rewards(self.states, actions[:, None, :], "both")
        else:
            r = self.get_rewards(self.states, actions[:, None, :], "clean")

        # Move by the chosen action's displacement, staying on the chain.
        self.states[:, 0] += torch.einsum("e...a,a->e...", actions, transit_kernel).int()
        self.states.clamp_(0, self.n_states - 1)

        self.current_step += 1

        done = self.current_step == self.n_steps

        if with_attacker:
            return r, r_original, done
        else:
            return r, done

    @staticmethod
    def _state_indices(states: Tensor) -> Tensor:
        """Convert states of shape ``(n_envs, T, d)`` into integer indices of shape ``(n_envs, T)``.

        One-hot states (``d > 1``) are decoded with argmax; index states (``d == 1``) are squeezed.
        """
        if states.shape[2] > 1:
            states = states.argmax(dim=-1)
        else:
            states = states.squeeze(dim=-1)
        return states.int()

    def get_optimal_actions_per_state(self, states: Tensor) -> Tensor:
        """Look up the optimal action of each given state.

        Args:
            states: shape ``(n_envs, T, d)``; one-hot if ``d > 1``, otherwise state indices.

        Returns:
            Float tensor ``(n_envs, T, n_actions)`` of ``optimal_actions`` rows.
        """
        state_indices = self._state_indices(states)
        return self.optimal_actions[torch.arange(self.n_envs)[:, None], state_indices, :].float()

    def visualize_dataset(self, dataset: MDPDataset | MDPDatasetTorch, attacker_weights=None) -> ArtistAnimation:
        """Animate the trajectory of environment 0 in ``dataset`` on its chain.

        Each state is drawn as a circle. The blue arrow points in the state's optimal
        direction and the black arrow in the other one. The rewards of the left
        (action 0) and right (action 1) moves are printed above and below the chain.
        Each frame shows:

        * the agent (black square);
        * its chosen direction (red arrow);
        * the step's reward and original reward;
        * the running return;
        * the fraction of steps whose action matched the optimal one.

        Note: this updates the global ``plt.rcParams`` font settings.

        Args:
            dataset: trajectories collected from this environment batch.
            attacker_weights: unused.

        Returns:
            An ``ArtistAnimation`` with one frame per step, 100 ms apart.
        """
        plt.rcParams.update({"font.size": 14})
        plt.rcParams["svg.fonttype"] = "none"
        plt.rcParams["font.family"] = "Latin Modern Math"

        n_states = dataset.n_states

        # Visited state per (env, step), and this environment's optimal action there.
        state_indices = self._state_indices(dataset.states)
        optimal_actions_per_step = self.get_optimal_actions_per_state(dataset.states)

        env = 0

        scale = 3
        state_margin = 0.2
        arrow_margin = 0
        arrow_offset = -0.5 + state_margin + arrow_margin
        arrow_length = 2 * (state_margin + arrow_margin)
        agent_size = 0.2
        fig, ax = plt.subplots(figsize=(scale * (n_states + 2 * state_margin + arrow_margin), scale))  # type: ignore[arg-type]
        ax: Axes

        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(-arrow_length, n_states + arrow_length)
        ax.set_ylim(0, 1)

        state_size = 1 - 2 * state_margin
        arrowprops = {"head_length": 0.075, "head_width": 0.05, "length_includes_head": True}
        textprops = {"fontsize": 8 * scale, "horizontalalignment": "center"}

        for i in range(n_states):
            state = Ellipse((i + 0.5, 0.5), state_size, state_size, facecolor="lightgray", edgecolor="black", linewidth=2)
            ax.add_patch(state)

        for i in range(n_states):
            # +1 if moving right is optimal at state i, -1 if moving left is.
            dir_opt = -1 + 2 * self.optimal_actions[env, i, 1].item()
            ax.arrow(i + 0.5 + dir_opt * arrow_offset, 0.5 + dir_opt * (0.5 - 2 * state_margin), -dir_opt * arrow_length, 0, **arrowprops, color="black")
            ax.arrow(i + 0.5 - dir_opt * arrow_offset, 0.5 - dir_opt * (0.5 - 2 * state_margin), dir_opt * arrow_length, 0, **arrowprops, color="tab:blue")

            ax.text(i, 1 - 2 * state_margin + 0.05, f"{self.reward_kernel[env, i, 0]:.2f}", **textprops)
            ax.text(i + 1, 2 * state_margin - 0.05, f"{self.reward_kernel[env, i, 1]:.2f}", **textprops, verticalalignment="top")

        fig.tight_layout()

        # Animation starts here

        anim_artists = []
        total_reward = 0.0
        correct_actions = 0

        for t in range(dataset.n_steps):
            agent_state = state_indices[env, t].item()
            agent_dir = -1 + 2 * dataset.actions[env, t, 1].item()

            agent_rect = Rectangle((agent_state + 0.5 - agent_size * 0.5, 0.5 - agent_size / 2), agent_size, agent_size, facecolor="black")
            agent_arrow = ax.arrow(agent_state + 0.5, 0.5, agent_dir * arrow_length / 2, 0, **arrowprops, color="tab:red")
            ax.add_patch(agent_rect)

            total_reward += dataset.rewards[env, t].item()
            correct_actions += 1 if optimal_actions_per_step[env, t, 1].item() == dataset.actions[env, t, 1].item() else 0

            timestep = ax.text(-state_margin, state_margin, f"$t={t}$", fontsize=8 * scale)
            info = ax.text(
                -arrow_length + 0.05,
                0.05,
                (
                    f"state={agent_state:.0f}, opt=[{optimal_actions_per_step[env, t, 0].item():.0f}, {optimal_actions_per_step[env, t, 1].item():.0f}], "
                    f"r={dataset.rewards[env, t].item():.2f}, r̃={dataset.rewards_original[env, t].item():.2f}, R={total_reward:.2f}, "
                    f"acc={correct_actions / (t + 1) * 100:.1f}%"
                ),
            )

            anim_artists.append((agent_rect, agent_arrow, timestep, info))

        anim = ArtistAnimation(fig, anim_artists, interval=100)

        return anim


def generate_trajectories(n_envs: int, n_steps: int, n_states: int, variance: float, device: DeviceLikeType | None = None, *, pbar_desc: str | None = None) -> MDPDatasetTorch:
    """Sample random chain environments and roll out a random controller in them.

    Args:
        n_envs: number of environments sampled in parallel.
        n_steps: episode length.
        n_states: chain length of every environment.
        variance: reward-noise scale forwarded to ``ChainEnv.sample``.
        device: device for the sampled tensors (``None`` uses torch's default).
        pbar_desc: optional progress-bar description forwarded to ``deploy``.

    Returns:
        The dataset produced by ``deploy`` (called with ``omit_optimal_actions=False``).
    """
    # Construct the controller before sampling the environments, keeping the
    # original order in case either constructor draws random numbers.
    controller = MDPRandomController(n_envs, n_steps, n_states, 1, n_actions, device)
    chain_envs = ChainEnv.sample(n_envs, n_steps, n_states, variance, device=device)
    return chain_envs.deploy(controller, omit_optimal_actions=False, pbar_desc=pbar_desc)
