from math import sqrt
from typing import Literal, overload

import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import torch
from matplotlib.animation import ArtistAnimation
from matplotlib.axes import Axes
from matplotlib.patches import Rectangle
from torch import Tensor
from torch._prims_common import DeviceLikeType
from torch.nn.functional import one_hot
from tqdm import tqdm

from mdp.mdp_controller import MDPRandomController
from mdp.mdp_dataset import MDPDataset, MDPDatasetTorch
from mdp.mdp_env import BaseMDP

# Module-wide default device: CUDA when available, otherwise torch's default (None).
device = None
if torch.cuda.is_available():
    device = "cuda"

# Size of the discrete action set and number of goal cells placed in each environment.
n_actions = 5
n_goals_in_env = 2

# Row k is the (coord 0, coord 1) displacement applied to the state by action k.
transit_kernel = torch.tensor(
    [[1, 0], [-1, 0], [0, 1], [0, -1], [0, 0]],
    dtype=torch.float,
    device=device,
)

# One-hot encodings of the actions, listed in the same order as the rows of `transit_kernel`.
a_down, a_up, a_right, a_left, a_noop = (
    one_hot(torch.tensor(action_id, device=device), n_actions) for action_id in range(n_actions)
)


class DarkroomEnv(BaseMDP):
    """Batch of square grid-world ("darkroom") environments stepped in lockstep.

    Each environment is a ``square_len x square_len`` grid. A state is a pair of
    integer grid coordinates; every episode starts at ``(0, 0)`` (see `reset`).
    Actions are one-hot vectors over five moves whose displacements are the rows
    of the module-level ``transit_kernel``; moves that would leave the grid are
    clipped at the border.

    ``reward_map[e, i, j]`` is the clean reward for entering cell ``(i, j)`` of
    environment ``e``. When built through `sample`, each environment has
    ``n_goals_in_env`` distinct goal cells rewarded ``1, 2, ...`` in goal order,
    and every other cell is rewarded 0.

    NOTE(review): ``n_envs``, ``n_steps``, ``attacker``, ``corrupted_steps`` and
    ``deploy`` are used here but not defined in this class; they are presumably
    provided by ``BaseMDP``.
    """

    # Device on which `reset` / `sample_states` allocate state tensors.
    device: DeviceLikeType | None

    n_states: int
    # Size of the one-hot action vector (5).
    n_actions: int
    # Clean reward per cell, indexed as [env, coord 0, coord 1].
    reward_map: Tensor
    # Current agent coordinates, one row per environment.
    states: Tensor
    # Number of `step` calls since the last `reset`.
    current_step: int
    variance: float
    # One-hot optimal action per cell, indexed as [env, coord 0, coord 1, action].
    optimal_actions: Tensor

    def env_name(self) -> str:
        """Return the identifier of this environment family."""
        return "darkroom"

    @classmethod
    def sample(cls, n_envs: int, n_steps: int, square_len: int, device=None, *, pbar_desc=None) -> "DarkroomEnv":
        """Draw a batch of environments with uniformly random, distinct goal cells.

        Args:
            n_envs: Number of environments in the batch.
            n_steps: Episode length, forwarded to the constructor.
            square_len: Side length of the grid; must be greater than 1 and give
                more than ``n_goals_in_env`` cells.
            device: Device for all tensors created here.
            pbar_desc: Optional prefix for the progress-bar description.

        Returns:
            A new environment batch whose optimal actions lead to the last
            (highest-reward) goal of each environment.
        """
        assert square_len * square_len > n_goals_in_env
        assert square_len > 1

        # Every (coord 0, coord 1) cell in row-major order: shape (square_len**2, 2).
        # Serves both as the pool of candidate goal cells and as the list of states.
        coords = torch.arange(square_len, device=device)
        all_states = torch.stack(
            [coords.repeat_interleave(square_len), coords.repeat(square_len)],
            dim=1,
        )
        n_cells = all_states.shape[0]

        # An independent random permutation of the cells per environment; its first
        # `n_goals_in_env` entries are that environment's (distinct) goal cells.
        shuffle_indices = torch.rand(n_envs, n_cells, device=device).argsort(dim=1)
        goals = all_states[shuffle_indices[:, :n_goals_in_env]]  # (n_envs, n_goals_in_env, 2)

        arange_n_envs = torch.arange(n_envs, device=device)
        reward_map = torch.zeros((n_envs, square_len, square_len), device=device)
        for goal_num in range(n_goals_in_env):
            # Later goals are worth more: goal k pays k + 1.
            reward_map[arange_n_envs, goals[:, goal_num, 0], goals[:, goal_num, 1]] = goal_num + 1

        optimal_actions = torch.zeros((n_envs, square_len, square_len, n_actions), device=device)

        desc = (f"{pbar_desc} - " if pbar_desc is not None else "") + "Getting Optimal Actions"
        for env in tqdm(range(n_envs), desc=desc):
            # The optimal policy heads for the last goal, i.e. the one with the largest reward.
            target_goal = goals[env, -1]
            for state in all_states:
                optimal_actions[env, state[0], state[1]] = get_optimal_action(state, target_goal)

        return cls(reward_map, n_envs, n_steps, square_len, goals, optimal_actions, device)

    def __init__(self, reward_map: Tensor, n_envs: int, n_steps: int, square_len: int, goals: Tensor, optimal_actions: Tensor, device: DeviceLikeType | None = None):
        """Store the environment description and reset the agent positions.

        Args:
            reward_map: Clean reward per cell, indexed ``[env, coord 0, coord 1]``.
            n_envs: Number of environments in the batch.
            n_steps: Episode length.
            square_len: Side length of the grid.
            goals: Goal coordinates per environment.
            optimal_actions: One-hot optimal action per cell, indexed
                ``[env, coord 0, coord 1, action]``.
            device: Device on which state tensors are allocated.
        """
        action_dim = 5
        super().__init__(n_envs, n_steps, square_len**2, action_dim)
        self.square_len = square_len
        self.n_actions = action_dim
        self.goals = goals
        self.reward_map = reward_map
        self.optimal_actions = optimal_actions

        self.device = device

        self.reset()

    def reset(self) -> Tensor:
        """Restart the episode with every agent at cell ``(0, 0)`` and return the states."""
        self.current_step = 0
        states = torch.zeros((self.n_envs, 2), dtype=torch.int, device=self.device)
        self.states = states

        return states

    def sample_states(self) -> Tensor:
        """Return uniformly random grid coordinates, one pair per environment.

        Does not modify ``self.states``.
        """
        states = torch.randint(0, self.square_len, (self.n_envs, 2), dtype=torch.int, device=self.device)
        return states

    @overload
    def get_rewards(self, next_states: Tensor, actions: Tensor, type: Literal["corrupted", "clean"]) -> Tensor: ...
    @overload
    def get_rewards(self, next_states: Tensor, actions: Tensor, type: Literal["both"]) -> tuple[Tensor, Tensor]: ...
    def get_rewards(self, next_states: Tensor, actions: Tensor, type: Literal["corrupted", "clean", "both"] = "corrupted") -> tuple[Tensor, Tensor] | Tensor:
        """Look up the reward for entering `next_states`, optionally attacker-corrupted.

        Args:
            next_states: 2-D tensor with one coordinate pair per environment.
            actions: 3-D tensor of actions; only forwarded to the attacker.
            type: ``"clean"`` returns the reward-map value; ``"corrupted"`` returns
                the clean reward plus the attacker's perturbation on corrupted
                steps; ``"both"`` returns ``(corrupted, clean)``.

        Returns:
            Reward tensor(s) of shape ``(n_envs, 1)``.

        Raises:
            ValueError: If a corrupted reward is requested but no attacker is set.
        """
        assert next_states.ndim == 2 and actions.ndim == 3, f"{next_states.ndim} != 2 or {actions.ndim} != 3"

        env_idx = torch.arange(self.n_envs, device=self.reward_map.device)
        # Advanced indexing returns a fresh tensor, so the in-place update below
        # never touches `self.reward_map`.
        r = self.reward_map[env_idx, next_states[:, 0], next_states[:, 1]][:, None]

        if type == "clean":
            return r

        if self.attacker is None:
            raise ValueError("`attacker` is unset, please deploy with `attacker` or get type='clean' rewards only.")

        assert self.corrupted_steps is not None

        r_original = r.clone()

        r_poisoned = self.attacker.get_reward(r_original, next_states.unsqueeze(1), actions).squeeze(2)
        # `r` always has a single column here, so only the corruption mask of the
        # current step applies.
        # assumes `corrupted_steps` is indexed as [env, step] — TODO confirm against BaseMDP
        corrupted_steps = self.corrupted_steps[:, self.current_step][:, None]
        r += r_poisoned * corrupted_steps

        if type == "both":
            return r, r_original

        return r

    def step(self, actions: Tensor) -> tuple[Tensor, Tensor, bool]:
        """Advance every environment by one action.

        Args:
            actions: One-hot actions, one per environment (last axis has one
                entry per row of ``transit_kernel``).

        Returns:
            ``(r, r_original, done)``: the possibly corrupted reward, the clean
            reward, and whether the episode has reached ``n_steps``. Without an
            attacker both reward entries are the same clean tensor.

        Raises:
            RuntimeError: If called after the episode has already ended.
        """
        if self.current_step >= self.n_steps:
            raise RuntimeError(f"Episode has already ended (current_step exceeds n_steps={self.n_steps}).")

        with_attacker = self.attacker is not None

        # e -- envs
        # a -- actions
        # c -- coords (state_dim)
        add_to_state = torch.einsum("e...a,ac->e...c", actions, transit_kernel).int()
        if self.states.ndim == 2 and add_to_state.ndim == 3:
            add_to_state = add_to_state.squeeze(1)

        # Build a new tensor instead of updating in place: `reset()` hands out the
        # very tensor stored in `self.states`, and an in-place add would silently
        # change the caller's copy of the initial states.
        self.states = (self.states + add_to_state).clip(0, self.square_len - 1)

        reward_actions = actions[:, None, :]
        if with_attacker:
            r, r_original = self.get_rewards(self.states, reward_actions, "both")
        else:
            r = self.get_rewards(self.states, reward_actions, "clean")
            r_original = r

        self.current_step += 1
        done = self.current_step == self.n_steps

        return r, r_original, done

    def get_optimal_actions_per_state(self, states: Tensor, *, pbar_desc=None) -> Tensor:
        """Look up the stored optimal action for every visited state.

        Args:
            states: Integer tensor of shape ``(n_envs, n_steps, 2)`` holding grid
                coordinates.
            pbar_desc: Accepted for backward compatibility only; the lookup is a
                single batched gather, so no progress bar is shown.

        Returns:
            Tensor of shape ``(self.n_envs, n_steps, self.n_actions)`` on
            ``states.device``. Rows beyond ``states.shape[0]`` stay zero.

        Raises:
            IndexError: If `states` holds more environments than this batch.
        """
        n_envs, n_steps, _ = states.shape
        if n_envs > self.n_envs:
            raise IndexError(f"`states` holds {n_envs} environments but this batch only has {self.n_envs}.")

        optimal_actions = torch.zeros((self.n_envs, n_steps, self.n_actions), device=states.device)

        # Index on the lookup table's device; the slice assignment below copies the
        # result to `states.device` and to the output dtype.
        lookup_device = self.optimal_actions.device
        env_idx = torch.arange(n_envs, device=lookup_device)[:, None]
        first_coord = states[..., 0].to(device=lookup_device, dtype=torch.long)
        second_coord = states[..., 1].to(device=lookup_device, dtype=torch.long)

        optimal_actions[:n_envs] = self.optimal_actions[env_idx, first_coord, second_coord]

        return optimal_actions

    def visualize_dataset(
        self,
        dataset: MDPDataset | MDPDatasetTorch,
        *,
        attacker_weights: Tensor | None = None,
        advantages: Tensor | None = None,
        values: Tensor | None = None,
        title: str | None = None,
        env_idx: int | None = 0,
    ) -> ArtistAnimation:
        """Animate the trajectories stored in `dataset` on top of the reward map.

        Updates the global matplotlib ``rcParams`` (font settings). When more than
        one environment is animated, each animation is also saved to
        ``darkroom{env}.mp4`` in the working directory.

        Args:
            dataset: Dataset providing ``states``, ``actions``, ``rewards`` and
                ``rewards_original`` indexed by environment, plus ``n_states``.
            attacker_weights: Optional per-cell attacker scores; the argmax over
                the last axis minus one is printed in each cell.
            advantages: Optional per-state, per-action values printed around each
                cell centre, indexed ``[env, flat_state, action]``.
            values: Optional per-state values printed in each cell, indexed
                ``[env, flat_state]``.
            title: Optional figure title.
            env_idx: Environment to animate; ``None`` animates all of them.

        Returns:
            The animation of the last environment processed.

        Raises:
            ValueError: If there is no environment to animate.
        """
        plt.rcParams.update({"font.size": 14})
        plt.rcParams["svg.fonttype"] = "none"
        plt.rcParams["font.family"] = "Latin Modern Math"

        whitecolor = torch.tensor([1.0, 1, 1], device=self.device)
        goalcolor = (1, 0.85, 0)
        agentcolor = (0, 0, 0)
        # Colour offset that turns white into the goal colour at full reward.
        goal_tint = torch.tensor(goalcolor, device=self.device) - whitecolor

        square_len = int(sqrt(dataset.n_states))

        assert dataset.n_states == square_len**2, f"{dataset.n_states} != {square_len}**2"

        anim_envs = [env_idx] if env_idx is not None else range(dataset.states.shape[0])
        if not anim_envs:
            # Without this guard the final `return anim` would fail with an
            # UnboundLocalError.
            raise ValueError("`dataset` contains no environments to visualize.")

        for env in anim_envs:
            reward_map = self.reward_map[env]
            states = dataset.states[env]
            actions = dataset.actions[env]
            rewards = dataset.rewards[env]
            rewards_original = dataset.rewards_original[env]

            fig, ax = plt.subplots(1, 1, figsize=(6, 7))  # type: ignore
            ax: Axes
            ax.set_xticks([])
            ax.set_yticks([])
            fig.tight_layout()

            # Blend from white (reward 0) to the goal colour (maximum reward).
            grid = whitecolor + reward_map[..., None] / reward_map.max() * goal_tint
            ax.imshow(grid.numpy(force=True), origin="lower", interpolation="none", extent=(0, square_len, 0, square_len))
            ax.set_ylim(0, square_len)
            ax.set_xlim(0, square_len)
            ax.grid(True, which="both")

            if title is not None:
                ax.set_title(title + "\n")

            legend_elements = [
                plt.Line2D([0], [0], marker="o", color="w", label="Agent", markerfacecolor=agentcolor, markersize=10),  # type: ignore[arg-type]
                plt.Line2D([0], [0], marker="o", color="w", label="Goals", markerfacecolor=goalcolor, markersize=10),  # type: ignore[arg-type]
            ]
            ax.legend(handles=legend_elements, loc="lower right", framealpha=0.5)

            # Put coordinate 0 == 0 at the top so the first state coordinate grows downwards.
            ax.invert_yaxis()

            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.xaxis.set_minor_locator(plticker.MultipleLocator(base=1.0))
            ax.yaxis.set_minor_locator(plticker.MultipleLocator(base=1.0))

            colors = {-1: "blue", 0: "gray", 1: "blue"}
            if attacker_weights is not None:
                # Shift the argmax so that the middle class maps to 0.
                attack = attacker_weights[env].argmax(-1) - 1
                for i in range(square_len):
                    for j in range(square_len):
                        ax.text(
                            j + 0.5,
                            i + 0.5,
                            f"{attack[i, j]:+.0f}",
                            horizontalalignment="center",
                            verticalalignment="center",
                            color=colors[int(attack[i, j].item())],
                            fontsize=10,
                        )

            if advantages is not None:
                # (x, y) offset from the cell centre at which each action's value is drawn.
                action_dec = torch.tensor([[0, -0.35], [0, 0.35], [0.35, 0], [-0.35, 0], [0, -0.05]])

                for i in range(square_len):
                    for j in range(square_len):
                        for k in range(n_actions):
                            ax.text(
                                j + 0.5 + action_dec[k, 0].item(),
                                i + 0.5 + action_dec[k, 1].item(),
                                f"{advantages[env, i*square_len + j, k]:+.2f}",
                                horizontalalignment="center",
                                verticalalignment="center",
                                color="green",
                                fontsize=8,
                            )

            if values is not None:
                for i in range(square_len):
                    for j in range(square_len):
                        ax.text(
                            j + 0.5,
                            i + 0.55,
                            f"{values[env, i*square_len + j]:.1f}",
                            horizontalalignment="center",
                            verticalalignment="center",
                            color="red",
                            fontsize=9,
                        )

            # Animation starts here: one group of artists per time step.

            anim_artists = []

            scale = 1.5
            # Per-action arrow direction; the first component is negated when drawn (see `dy` below).
            action_vectors = [(-1, 0), (1, 0), (0, 1), (0, -1), (0, 0)]
            total_reward = 0
            total_reward_original = 0

            for t, (state, action, reward, reward_original) in enumerate(zip(states, actions, rewards, rewards_original)):
                action_vector = action_vectors[torch.argmax(action)]

                # The state is (coord 0, coord 1) while matplotlib wants (x, y), hence the reversal.
                agent_rect = Rectangle(tuple(state.tolist()[::-1]), 1, 1, facecolor=agentcolor)
                ax.add_patch(agent_rect)

                action_arrow = ax.arrow(
                    state[1].item() + 0.5,
                    state[0].item() + 0.5,
                    action_vector[1] * 0.4,
                    -action_vector[0] * 0.4,
                    head_width=0.2,
                    head_length=0.2,
                    fc="b",
                    ec="b",
                )

                total_reward += reward.item()
                total_reward_original += reward_original.item()

                timestep = ax.text(0, -0.2, f"t = {t}", fontsize=8 * scale)
                info = ax.text(
                    1.3,
                    -0.2,
                    f"state=[{state[0].item():.0f}, {state[1].item():.0f}], r†={reward.item(): .1f}, R†={total_reward: .1f}",
                    fontsize=8 * scale,
                )
                info2 = ax.text(5.8, -0.2, f"r̃={reward_original.item(): .1f}, R={total_reward_original: .1f}", fontsize=8 * scale)

                anim_artists.append((agent_rect, action_arrow, timestep, info, info2))

            anim = ArtistAnimation(fig, anim_artists, interval=100)

            if len(anim_envs) > 1:
                anim.save(f"darkroom{env}.mp4")
                print(f"Saved 'darkroom{env}.mp4'.")

        return anim


def get_optimal_action(state: Tensor, goal: Tensor) -> Tensor:
    """Return the one-hot action that moves `state` one step towards `goal`.

    The first coordinate is aligned before the second; once both coordinates
    match, the no-op action is returned. The returned tensor is one of the
    shared module-level action constants, not a copy.
    """
    # For each coordinate: (action when the state is below the goal, action when above).
    moves_per_axis = ((a_down, a_up), (a_right, a_left))

    for axis, (increase, decrease) in enumerate(moves_per_axis):
        if state[axis] < goal[axis]:
            return increase
        if state[axis] > goal[axis]:
            return decrease

    return a_noop


def generate_trajectories(n_envs: int, n_steps: int, square_len: int, device: DeviceLikeType | None = None, *, pbar_desc: str | None = None) -> MDPDatasetTorch:
    """Sample a batch of darkroom environments and roll out a random controller in them.

    Args:
        n_envs: Number of environments to sample.
        n_steps: Episode length.
        square_len: Side length of each grid.
        device: Device passed to both the controller and the environments.
        pbar_desc: Optional progress-bar description prefix.

    Returns:
        The dataset produced by deploying the controller in the sampled environments.
    """
    # NOTE(review): 2 and 5 look like the state dimension and the number of
    # actions — confirm against MDPRandomController's signature.
    random_controller = MDPRandomController(n_envs, n_steps, 2, square_len**2, 5, device)
    sampled_envs = DarkroomEnv.sample(n_envs, n_steps, square_len, device=device, pbar_desc=pbar_desc)

    return sampled_envs.deploy(random_controller, pbar_desc=pbar_desc)
