from typing import overload

import cv2
import gymnasium as gym
import imageio
import miniworld
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor
from torch._prims_common import DeviceLikeType
from tqdm import tqdm

from collect_data import rand_pos_and_dir
from mdp.mdp_attacker import MDPAttacker, MDPGridClassifierBaseAttacker
from mdp.mdp_controller import MDPImageTransformerController
from mdp.mdp_dataset import (
    MDPDatasetImages,
    MDPDatasetImagesTorch,
    process_miniworld_images,
)
from mdp.mdp_env import BaseMDP
from Miniworld.miniworld.envs.oneroom import OneRoomS6FastMulti

# Scale applied to ``info["grid_pos"]`` before rounding it to attacker-grid coordinates.
MW_ATTACK_RESOLUTION = 1
# Grid size handed to the grid-classifier attacker: 7 cells per unit of resolution.
MW_SQUARE_LEN = MW_ATTACK_RESOLUTION * 7
# Number of discrete actions (turn left, turn right, move forward, move back).
MW_N_ACTIONS = 4
# Dimension of the MDP state (two components of the agent's direction vector).
MW_N_STATES = 2

# Counter that gives each saved episode video a unique file name.
video_filename_index = 0

class MiniworldEnv(BaseMDP):
    """Batch of independent MiniWorld ``OneRoomS6FastMulti`` environments exposed as a ``BaseMDP``.

    The state of each sub-environment is the pair of components ``[0, -1]`` of
    the agent's direction vector (``agent.dir_vec``); controllers additionally
    receive rendered observations. When an attacker is installed (see
    ``deploy``), rewards on the corrupted steps are perturbed by it.
    """

    envs: list[OneRoomS6FastMulti]
    optimal_actions: Tensor | None

    def env_name(self) -> str:
        """Return the identifier of this environment family."""
        return "miniworld"

    @classmethod
    def sample(cls, n_envs: int, n_steps: int, device=None, seed=0) -> "MiniworldEnv":
        """Build ``n_envs`` fresh environments with task ids ``8000 + seed + env_idx``.

        Args:
            n_envs: Number of sub-environments.
            n_steps: Episode length; also used as gym's ``max_episode_steps``.
            device: Device for the reward/done tensors produced by ``step``.
            seed: Offset added to every task id.
        """
        envs = []
        for env_idx in range(n_envs):
            env = gym.make("MiniWorld-OneRoomS6FastMultiFourBoxesFixedInit-v0", max_episode_steps=n_steps)  # type: ignore
            env: OneRoomS6FastMulti
            env.unwrapped.set_task(env_id=8000 + seed + env_idx)  # type: ignore
            envs.append(env)
        # ``cls`` so that subclasses constructed through ``sample`` get their own type.
        return cls(envs, n_steps, device=device)

    def __init__(self, envs: list[OneRoomS6FastMulti], n_steps: int, device: DeviceLikeType | None = None):
        super().__init__(len(envs), n_steps, MW_N_STATES, MW_N_ACTIONS)
        self.device = device
        self.envs = envs
        self.optimal_actions = None

    @property
    def task_ids(self) -> Tensor:
        """Task id (``env_id``) of every sub-environment."""
        return torch.tensor([env.unwrapped.env_id for env in self.envs])  # type: ignore

    def _agent_directions(self) -> Tensor:
        """Return the current states: components ``[0, -1]`` of every agent's ``dir_vec``."""
        return torch.tensor([env.unwrapped.agent.dir_vec[[0, -1]].tolist() for env in self.envs])  # type: ignore

    def reset(self) -> Tensor:
        """Reset every sub-environment and return the initial batch of states."""
        self.current_step = 0
        for env in self.envs:
            env.reset()
        self.states = self._agent_directions()

        # Restart frame collection for video recording in the new episode.
        self.frame_step = 0

        return self.states

    def step(self, actions: Tensor, *, save_video: bool = False) -> tuple[Tensor, Tensor, Tensor, bool]:
        """Advance every sub-environment by one action.

        Args:
            actions: Per-env action scores; the executed action is their ``argmax``
                over the last dimension.
            save_video: Record an annotated frame of the last sub-environment; the
                collected frames are written to disk once ``done`` becomes true.

        Returns:
            ``(rewards, rewards_original, infos, done)``: the (possibly attacked)
            rewards, the environment rewards of shape ``(n_envs, 1)``, the rounded
            grid positions scaled by ``MW_ATTACK_RESOLUTION`` of shape
            ``(n_envs, 2)``, and whether any sub-environment reported termination.

        Raises:
            RuntimeError: If called more than ``n_steps`` times since ``reset``.
        """
        if self.current_step >= self.n_steps:
            raise RuntimeError(f"Episode has already ended (current_step exceeds n_steps={self.n_steps}).")

        # actions:
        # turn_left = 0
        # turn_right = 1
        # move_forward = 2
        # move_back = 3 (no-op)
        actions_idx = actions.argmax(dim=-1)

        # The last sub-environment is the one recorded when save_video is set.
        video_env_idx = -1
        video_env = self.envs[video_env_idx] if save_video else None
        opt_a = None

        rewards_original, dones = torch.zeros((2, self.n_envs, 1), device=self.device)
        infos = torch.zeros((self.n_envs, 2), dtype=torch.int)
        for env_idx, (action, env) in enumerate(zip(actions_idx, self.envs)):
            if env is video_env:
                # Optimal action from the pre-step state; only needed for the video overlay.
                # Go through ``unwrapped``: gymnasium wrappers no longer forward attributes.
                opt_a = env.unwrapped.opt_a(None, None, None)  # type: ignore
            # NOTE(review): the truncation flag (4th element) is discarded, so ``done``
            # reflects termination only - confirm this is intended.
            _, rewards_original[env_idx], dones[env_idx], _, info = env.step(int(action))

            infos[env_idx, 0] = round(info["grid_pos"][0] * MW_ATTACK_RESOLUTION)
            infos[env_idx, 1] = round(info["grid_pos"][1] * MW_ATTACK_RESOLUTION)

        self.states = self._agent_directions()

        done = bool(dones.any().item())

        if self.attacker is None:
            rewards = rewards_original
        else:
            assert self.corrupted_steps is not None

            reward_mod = self.attacker.get_reward(rewards_original, infos[:, None, :], actions).squeeze(2)

            if rewards_original.shape[1] == 1:
                # Select this step's corruption mask as a column.
                corrupted_steps = self.corrupted_steps[:, self.current_step][:, None]
            else:
                corrupted_steps = self.corrupted_steps
            rewards = rewards_original + reward_mod * corrupted_steps

        if save_video:
            attacker_rewards = None
            if self.attacker is not None:
                attacker_rewards = self.attacker.reward_values[self.attacker.weights[video_env_idx].argmax(-1)]
            self.save_frames(
                self.envs[video_env_idx],
                int(actions_idx[video_env_idx].item()),
                opt_a,
                rewards_original[video_env_idx].item(),
                rewards[video_env_idx].item(),
                done,
                infos[video_env_idx],
                attacker_rewards,
            )

        # Counted on both the attacked and unattacked paths so the end-of-episode guard holds.
        self.current_step += 1

        return rewards, rewards_original, infos, done

    def save_frames(
        self,
        env: OneRoomS6FastMulti,
        action: int,
        action_optimal: int,
        reward: float,
        reward_poisoned: float,
        done: bool,
        info: Tensor,
        attacker_rewards: Tensor | None = None,
    ) -> None:
        """Append an annotated top-down frame of ``env``; write the video when ``done``.

        The frame list is restarted whenever ``frame_step`` is 0 (set by ``reset``).
        When ``attacker_rewards`` is given, its grid is drawn onto the frame. When
        ``done`` is true, all collected frames are written to
        ``test_video{N}.mp4`` and the module-level counter ``N`` is advanced.
        """
        global video_filename_index

        if self.frame_step == 0:
            self.frames = []

        img = env.unwrapped.render_top(goal_text=True, action=action, action_optimal=action_optimal, reward=reward, reward_poisoned=reward_poisoned, info=info)  # type: ignore

        if attacker_rewards is not None:
            img = self._draw_attacker_rewards(img, attacker_rewards)

        self.frames.append(img)
        self.frame_step += 1

        if done:
            # Context manager guarantees the file is finalized even if a write fails.
            with imageio.get_writer(f"test_video{video_filename_index}.mp4", format="FFMPEG", fps=15) as wr:
                video_filename_index += 1
                for frame in self.frames:
                    wr.append_data(frame)

    @staticmethod
    def _draw_attacker_rewards(img: np.ndarray, attacker_rewards: Tensor) -> np.ndarray:
        """Overlay the attacker's reward grid as text, starting 30 px from the top of ``img``."""
        white = (255, 255, 255)
        red = (255, 0, 0)
        pil_img = Image.fromarray(img)
        font = ImageFont.load_default(6)
        draw = ImageDraw.Draw(pil_img)
        for row_num, row in enumerate(attacker_rewards):
            for col_num, item in enumerate(row):
                # Cells with a non-zero value are highlighted in red.
                draw.text((1 + col_num * 4, 30 + row_num * 5), f"{item:.0f}", fill=white if item == 0.0 else red, font=font)
        return np.array(pil_img)

    @overload
    def deploy(
        self,
        controller: MDPImageTransformerController,
        *,
        clear_dataset: bool = True,
        context_len: int | None = None,
        pbar_desc: str | None = None,
        force_show_progress: bool | None = None,
        **kwargs,
    ) -> MDPDatasetImagesTorch: ...

    @overload
    def deploy(
        self,
        controller: MDPImageTransformerController,
        attacker: MDPAttacker,
        eps_episodes: float,
        eps_steps: float,
        *,
        clear_dataset: bool = True,
        context_len: int | None = None,
        pbar_desc: str | None = None,
        force_show_progress: bool | None = None,
        **kwargs,
    ) -> MDPDatasetImagesTorch: ...

    def deploy(
        self,
        controller: MDPImageTransformerController,
        attacker: MDPAttacker | None = None,
        eps_episodes: float | None = None,
        eps_steps: float | None = None,
        *,
        clear_dataset: bool = True,
        context_len: int | None = None,
        pbar_desc: str | None = None,
        force_show_progress: bool | None = None,
        **kwargs,
    ) -> MDPDatasetImagesTorch:
        """Deploy a controller in the environment with corruption. Returns the trajectories of the deployment.

        Args:
            controller: Policy queried every step; each transition is also appended
                to its own dataset.
            attacker: Optional reward attacker; when given, ``eps_episodes`` and
                ``eps_steps`` must be set and are forwarded to ``_set_attacker``.
            eps_episodes, eps_steps: Corruption parameters for ``_set_attacker``.
            clear_dataset: Clear the controller's dataset before deploying.
            context_len: Forwarded to the dataset's ``finalize``.
            pbar_desc: Prefix of the progress-bar description.
            force_show_progress: Show a progress bar even for small runs.
            **kwargs: Forwarded to ``step`` (e.g. ``save_video``).

        Returns:
            The finalized trajectories, including query states, query images and
            one-hot optimal query actions from ``sample_states``.
        """
        self.reset()
        if clear_dataset:
            controller.clear_dataset()
        if attacker is None:
            self.attacker = None
        else:
            assert eps_episodes is not None and eps_steps is not None, "eps_episodes and eps_steps must be set"
            self._set_attacker(attacker, eps_episodes, eps_steps)

        dataset = MDPDatasetImages(self.n_envs, self.n_steps, self.n_states, self.states.shape[-1], self.action_dim, controller.device)

        # Progress bar only for large runs, unless explicitly forced.
        show_progress = force_show_progress or (self.n_envs >= 10000 and self.n_steps >= 100)
        steps = range(self.n_steps)
        if show_progress:
            desc = (f"{pbar_desc} " if pbar_desc is not None else "") + "Deploy - " + controller.__class__.__name__
            steps = tqdm(steps, desc=desc)

        images = [env.render() for env in self.envs]
        states = self.states.clone()

        for _ in steps:
            actions = controller.sample_actions(states, images)  # .float()

            rewards, rewards_original, infos, _ = self.step(actions, **kwargs)
            states_next = self.states.clone()
            images_next = [env.render() for env in self.envs]

            controller.append(states, images, actions, rewards, states_next, images_next, rewards_original, {"infos": infos})
            dataset.append(states, images, actions, rewards, states_next, images_next, rewards_original, {"infos": infos})

            images = images_next
            states = states_next

        query_states, query_images, optimal_query_actions = self.sample_states()

        return dataset.finalize(None, query_states, query_images, optimal_query_actions, context_len=context_len)

    def visualize_dataset(self, dataset: MDPDatasetImages | MDPDatasetImagesTorch, *, attacker_weights: Tensor | None = None):
        """Not supported for MiniWorld."""
        raise NotImplementedError()

    def sample_states(self) -> tuple[Tensor, Tensor, Tensor]:
        """Move every agent to a random pose and return query data for it.

        Side effect: each sub-environment's agent is re-placed via ``place_agent``.

        Returns:
            ``(states, images, opt_as)``: the agents' direction components
            ``[0, -1]``, the processed observations, and a one-hot
            ``(n_envs, action_dim)`` tensor marking each env's optimal action.
        """
        states = []
        images = []
        opt_as = torch.zeros((self.n_envs, self.action_dim), device=self.device)
        for env_idx, env in enumerate(self.envs):
            init_pos, init_dir = rand_pos_and_dir(env)
            env.unwrapped.place_agent(pos=init_pos, dir=init_dir)  # type: ignore
            states.append(env.unwrapped.agent.dir_vec[[0, -1]].tolist())  # type: ignore

            image = env.unwrapped.render_obs()  # type: ignore
            images.append(image)

            opt_a_idx = env.unwrapped.opt_a(None, None, None)  # type: ignore
            opt_as[env_idx, opt_a_idx] = 1

        images = process_miniworld_images(images)
        return torch.tensor(states, device=self.device), images, opt_as


class MDPMiniworldAttacker(MDPGridClassifierBaseAttacker):
    """Grid-classifier attacker configured for the MiniWorld environment.

    The base attacker is built with a grid of size ``MW_SQUARE_LEN``. The task
    ids of the environments it was trained against are stored on the instance
    so they can be recovered after loading.
    """

    _original_task_ids: Tensor

    def __init__(self, original_task_ids: Tensor, n_envs: int, *, lr: float | None = None, device: DeviceLikeType | None = None) -> None:
        super().__init__(n_envs, MW_SQUARE_LEN, lr=lr, device=device)

        # self.weights.data[:, :, :, 1] -= 0.4  # Don't bias initial attacker towards no attack

        # Kept as a non-trainable Parameter so it can be loaded back later.
        task_ids = original_task_ids.clone().detach().float().to(device=self.device)
        self._original_task_ids = nn.Parameter(task_ids, requires_grad=False)

    def get_dataset_states(self, dataset: MDPDatasetImagesTorch) -> Tensor:
        """Return ``dataset.infos`` cast to int.

        These are presumably the scaled grid positions recorded by
        ``MiniworldEnv.step``; confirm this against the dataset.
        """
        infos = dataset.infos
        return infos.int()
