import os

import cv2
import gymnasium as gym
import numpy as np
import vizdoom as vzd

from tianshou.env import ShmemVectorEnv

try:
    import envpool
except ImportError:
    envpool = None


def normal_button_comb():
    """Enumerate the button vectors for the non-battle scenarios.

    Each action is a list of three floats: one "move forward" flag followed by
    a two-slot "turn left/right" pair in which at most one slot is set.
    Which slot maps to which in-game button is decided by the scenario
    config, not here (presumably the cfg's ``available_buttons`` order --
    verify against the cfg files).

    Returns:
        A list of 6 actions (2 forward states x 3 turn states), ordered with
        the forward flag varying slowest.
    """
    forward_states = ([0.0], [1.0])
    turn_states = ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0])
    # List concatenation builds a fresh list per action, so no two actions
    # share storage with each other or with the templates above.
    return [move + turn for move in forward_states for turn in turn_states]


def battle_button_comb():
    """Enumerate the button vectors for the "battle" scenarios.

    Each action is a list of eight floats laid out as::

        [forward/backward pair, strafe left/right pair, turn left/right pair,
         attack flag, speed flag]

    Every pair has at most one slot set. Which slot maps to which in-game
    button is decided by the scenario config, not here (presumably the cfg's
    ``available_buttons`` order -- verify against the cfg files).

    Returns:
        A list of 108 actions (3 x 3 x 3 x 2 x 2). The enumeration order is
        attack (slowest), speed, strafe, forward/backward, turn (fastest).
    """
    exclusive_pair = ([0.0, 0.0], [0.0, 1.0], [1.0, 0.0])
    on_off = ([0.0], [1.0])

    # The ``for`` clauses run outermost-first, so their order fixes the index
    # of every action; the concatenation order fixes the slot layout.
    return [
        move + strafe + turn + fire + run
        for fire in on_off
        for run in on_off
        for strafe in exclusive_pair
        for move in exclusive_pair
        for turn in exclusive_pair
    ]


class Env(gym.Env):
    """Frame-stacking wrapper around a ViZDoom ``DoomGame``.

    The observation is a rolling stack of the most recent resized screen
    frames; the action is an index into a fixed table of button vectors; the
    reward is shaped from changes in HEALTH, KILLCOUNT and AMMO2.

    Args:
        cfg_path: Path of the ViZDoom scenario config to load. If it contains
            the substring ``"battle"`` the larger action table and the
            symmetric health reward are used.
        frameskip: Number of game tics each chosen action is repeated for.
        res: Observation shape, used as ``(stack, height, width)``.
        save_lmp: When true, every episode is recorded to
            ``lmps/episode_<n>.lmp``.
    """

    def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False):
        super().__init__()
        self.save_lmp = save_lmp
        self.health_setting = "battle" in cfg_path
        if save_lmp:
            os.makedirs("lmps", exist_ok=True)
        self.res = res
        self.skip = frameskip
        # NOTE(review): the space is declared float32 while the buffer that
        # reset()/step() return is uint8 -- confirm which one is intended.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=res, dtype=np.float32
        )
        self.game = vzd.DoomGame()
        self.game.load_config(cfg_path)
        self.game.init()
        if self.health_setting:
            self.available_actions = battle_button_comb()
        else:
            self.available_actions = normal_button_comb()
        self.action_num = len(self.available_actions)
        self.action_space = gym.spaces.Discrete(self.action_num)
        self.spec = gym.envs.registration.EnvSpec("vizdoom-v0")
        # Number of episodes started so far; also used to name .lmp files.
        self.count = 0

    def get_obs(self):
        """Push the current screen frame onto the end of the frame stack.

        Does nothing when the game has no state to report, so the stack then
        keeps the frames it already had.
        """
        state = self.game.get_state()
        if state is None:
            return
        obs = state.screen_buffer
        # Drop the oldest frame by shifting the stack one slot towards 0.
        self.obs_buffer[:-1] = self.obs_buffer[1:]
        # cv2.resize takes (width, height); res is (stack, height, width).
        # Assumes a 2-D (single-channel) screen buffer -- TODO confirm in cfg.
        self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2]))

    def reset(self, *, seed=None, options=None):
        """Start a new episode and return the initial frame stack.

        Args:
            seed: Optional seed. When given it is forwarded to
                ``gym.Env.reset`` and to ``DoomGame.set_seed`` before the new
                episode starts. ``None`` leaves the RNG state untouched.
            options: Accepted for Gymnasium API compatibility; ignored.

        Returns:
            The uint8 frame stack of shape ``res``: zeros except for the last
            slot, which holds the first frame when the game reports a state.
        """
        # Accepting the Gymnasium keywords keeps ``env.reset(seed=...)`` from
        # raising TypeError; calling ``reset()`` with no arguments is unchanged.
        super().reset(seed=seed)
        if seed is not None:
            self.game.set_seed(seed)
        if self.save_lmp:
            self.game.new_episode(f"lmps/episode_{self.count}.lmp")
        else:
            self.game.new_episode()
        self.count += 1
        self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
        self.get_obs()
        # Baselines for the per-step reward deltas computed in step().
        self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        return self.obs_buffer

    def step(self, action):
        """Apply one action for ``frameskip`` tics and compute the reward.

        Args:
            action: Index into ``self.available_actions``.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is the internal frame
            stack (the same array object is updated in place on later steps),
            ``reward`` is the sum of the health, kill and ammo terms below,
            and ``info`` may carry ``"TimeLimit.truncated"``.
        """
        self.game.make_action(self.available_actions[action], self.skip)
        reward = 0.0
        self.get_obs()
        health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        if self.health_setting:
            # Battle scenarios: health gains and losses both count.
            reward += health - self.health
        elif health > self.health:  # positive health reward only for d1/d2
            reward += health - self.health
        self.health = health
        killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        reward += 20 * (killcount - self.killcount)
        self.killcount = killcount
        ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        # Ammo counts in both directions: pickups add, shots subtract.
        reward += ammo2 - self.ammo2
        self.ammo2 = ammo2
        done = False
        info = {}
        if self.game.is_player_dead() or self.game.get_state() is None:
            done = True
        # NOTE(review): if get_state() is None whenever the episode has
        # finished, this branch can never run -- confirm against ViZDoom docs.
        elif self.game.is_episode_finished():
            done = True
            info["TimeLimit.truncated"] = True
        return self.obs_buffer, reward, done, info

    def render(self):
        """No-op; this wrapper does no rendering of its own."""
        pass

    def close(self):
        """Shut down the underlying ViZDoom game instance."""
        self.game.close()


def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num):
    """Build the single, training and test environments for a ViZDoom task.

    Uses envpool when it imported successfully, otherwise falls back to the
    local ``Env`` wrapper run inside ``ShmemVectorEnv`` workers.

    Args:
        task: Scenario name. It is turned into the envpool task id, or into
            the config path ``maps/<task>.cfg`` for the fallback.
        frame_skip: Number of tics each action is repeated for.
        res: Observation shape; ``res[0]`` is the frame-stack depth passed to
            envpool, and the whole tuple is passed to ``Env``.
        save_lmp: Whether the test environments should record episodes.
        seed: Seed applied to the training and test environments.
        training_num: Number of parallel training environments.
        test_num: Requested number of parallel test environments; capped at
            one less than the CPU count (but never below one by the cap).

    Returns:
        ``(env, train_envs, test_envs)``. With envpool, ``env`` is the same
        object as ``train_envs``.
    """
    # os.cpu_count() may return None, and "count - 1" is 0 on a single-CPU
    # host; neither case should crash or leave zero test environments.
    cpu_count = os.cpu_count()
    if cpu_count is not None:
        test_num = min(max(cpu_count - 1, 1), test_num)
    if envpool is not None:
        # e.g. "d3_battle" -> "D3Battle-v1"
        task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1"
        lmp_save_dir = "lmps/" if save_lmp else ""
        # Per-variable [coefficient on increase, coefficient on decrease];
        # these mirror the reward terms of the local Env wrapper.
        # NOTE(review): exact envpool semantics of the pair -- confirm in docs.
        reward_config = {
            "KILLCOUNT": [20.0, -20.0],
            "HEALTH": [1.0, 0.0],
            "AMMO2": [1.0, -1.0],
        }
        if "battle" in task:
            reward_config["HEALTH"] = [1.0, -1.0]
        env = train_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            seed=seed,
            num_envs=training_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
        # Only the test environments are given the recording directory.
        test_envs = envpool.make_gymnasium(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            lmp_save_dir=lmp_save_dir,
            seed=seed,
            num_envs=test_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
    else:
        cfg_path = f"maps/{task}.cfg"
        env = Env(cfg_path, frame_skip, res)
        # The lambdas close over function arguments only (no loop variable),
        # so every factory builds an identically configured environment.
        train_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)]
        )
        test_envs = ShmemVectorEnv(
            [
                lambda: Env(cfg_path, frame_skip, res, save_lmp)
                for _ in range(test_num)
            ]
        )
        train_envs.seed(seed)
        test_envs.seed(seed)
    return env, train_envs, test_envs


if __name__ == '__main__':
    # Manual smoke test: repeat action 0 for 4000 steps on the battle map,
    # then dump part of the last observation to an image for inspection.
    # Swap in "maps/D1_basic.cfg" to exercise the smaller action table.
    env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))
    try:
        print(env.available_actions)
        action_num = env.action_space.n
        obs = env.reset()
        print(env.spec.reward_threshold)
        print(obs.shape, action_num)
        for _ in range(4000):
            obs, rew, done, info = env.step(0)
            if done:
                env.reset()
        print(obs.shape, rew, done)
        # (stack, H, W) -> (H, W, stack); the first three stacked frames
        # become the three channels of the written image.
        cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])
    finally:
        # Always shut the game down, even if the loop above raised.
        env.close()
