import copy
import random

import numpy as np

from abc import ABC, abstractmethod
from typing import Tuple, List, Dict, Union, Any

from envs.multiagentenv import MultiAgentEnv

#-----------------------------------------------------------------------
#  There are n agents in the quarry. Each agent's action set consists of
#  place_explosive, move [left, right], or no-op (QuarryEnv below
#  currently exposes only no-op and place_explosive). An explosive
#  detonates *d* steps after it is placed. When all n explosives
#  explode at the same time before the time limit, the agents get the
#  optimal reward; in all other cases they get suboptimal rewards.
#-----------------------------------------------------------------------


class GameObject(ABC):
    """Abstract base class for every entity placed on the quarry grid.

    Concrete subclasses must override both ``reset`` and ``move``; the
    class itself cannot be instantiated.
    """

    @abstractmethod
    def reset(self):
        """Restore the object to its initial state."""
        raise NotImplementedError()

    @abstractmethod
    def move(self):
        """Move the object on the grid (semantics defined by subclasses)."""
        raise NotImplementedError()


class Agent(GameObject):
    """An agent standing on the quarry grid.

    Args:
        init_pos: starting grid coordinates (the environment passes
            ``[0, column]``). The object is stored as given, not copied.
        seed: stored on the instance; not used by this class.
        agent_id: index of this agent within the environment.
    """

    def __init__(self,
                 init_pos: List[int],
                 seed: int = 0,
                 agent_id: int = 0) -> None:
        self.init_pos = init_pos
        self.seed = seed
        self.agent_id = agent_id
        # Two independent copies so that updating one position can never
        # alias the other or the caller's ``init_pos``.
        self.curr_pos, self.prev_pos = (
            np.array(copy.deepcopy(init_pos)) for _ in range(2)
        )

    def reset(self):
        """Normalise ``init_pos`` to an array and restore both positions to it."""
        start = np.array(copy.deepcopy(self.init_pos))
        self.init_pos = start
        self.curr_pos = start.copy()
        self.prev_pos = start.copy()

    def move(self):
        """Not implemented for agents."""
        raise NotImplementedError


class Explosive(GameObject):
    """A timed explosive placed on the grid by one agent.

    The fuse advances by one tick per ``update`` call; ``check_bomb`` is
    true exactly on the tick where the elapsed ticks equal ``duration``.

    Args:
        init_pos: grid coordinates of the explosive (stored as given).
        seed: stored on the instance; not used by this class.
        start_time: environment time step at which the explosive was placed.
        duration: number of ``update`` ticks until detonation.
        agent_id: id of the agent that placed the explosive.
    """

    def __init__(self,
                 init_pos: List[int],
                 seed: int = 0,
                 start_time: int = 0,
                 duration: int = 0,
                 agent_id: int = 0) -> None:
        self.init_pos = init_pos
        self.curr_pos = np.array(copy.deepcopy(init_pos))
        self.prev_pos = np.array(copy.deepcopy(init_pos))
        self.seed = seed
        self.agent_id = agent_id  # which agent put the explosive
        self.start_time = start_time
        self.timer = 0
        self.duration = duration

    def update(self):
        """Advance the fuse by one tick."""
        self.timer += 1

    def check_bomb(self):
        """Return True iff the elapsed ticks equal ``duration``.

        Uses equality, so the explosive reports detonation on exactly one
        tick; callers are expected to check after every ``update``.
        """
        return self.timer == self.duration

    def get_time_left(self):
        """Return the number of ticks remaining until detonation."""
        return self.duration - self.timer

    def reset(self):
        """Restore positions and restart the fuse from zero."""
        self.init_pos = np.array(copy.deepcopy(self.init_pos))
        self.curr_pos = np.array(copy.deepcopy(self.init_pos))
        self.prev_pos = np.array(copy.deepcopy(self.init_pos))
        # Previously the timer was left untouched, so a reset explosive kept
        # its elapsed fuse (and could never fire again once past ``duration``).
        self.timer = 0

    def move(self):
        """Not implemented for explosives."""
        raise NotImplementedError


class Rock(GameObject):
    """The stone that the agents try to blast.

    Args:
        init_pos: grid coordinates of the rock.
        seed: stored on the instance; not used by this class.
    """

    def __init__(self,
                 init_pos: List[int],
                 seed: int = 0) -> None:
        self.init_pos, self.curr_pos = np.array(init_pos), np.array(init_pos)
        self.seed = seed
        self.blasted = False

    def reset(self):
        """Restore the current position and clear the ``blasted`` flag."""
        origin = np.array(copy.deepcopy(self.init_pos))
        self.init_pos = origin
        self.curr_pos = origin.copy()
        self.blasted = False

    def move(self):
        """Not implemented for rocks."""
        raise NotImplementedError


class QuarryEnv(MultiAgentEnv):
    """Two agents placing explosives in a quarry, aiming for the optimal reward.

    Layout (S: stone, A: agents; agent 0 on the left, agent 1 on the right)::

         ___________________
        |_A_|___|_S_|___|_A_|

    Every step each agent picks ``ACTION_NO_OP`` or ``ACTION_BOMB``. Placing
    an explosive starts a fuse whose length depends on the agent (see
    ``action_durations``). Per-step reward:

    * exactly one explosive detonates      -> ``reward_one_explosion`` (terminates)
    * ``n_agents`` explosives detonate     -> ``reward_n_explosion`` (terminates, ``info['won']``)
    * last step reached with no detonation -> ``reward_no_explosion`` (terminates)
    * otherwise                            -> ``reward_no_op``
    """
    # Action indices; index 0 is a masked-out placeholder (see _init_agent_action_sapce).
    ACTION_NO_OP = 1
    ACTION_BOMB = 2

    # Cell codes used in observations and in the state vector.
    AGENT_FLAG = 3
    STONE_FLAG  = 7
    EXPLOSIVE_FLAG = 11
    GRID_FLAG  = 0

    def __init__(self,
                 seed=None,
                 map_name='2agents',
                 episode_limit=9,
                 reward_no_op=-0.1,
                 reward_one_explosion=5.4,
                 reward_n_explosion=15,
                 reward_no_explosion=-15,
                 n_agents=2,
                 grid_size=5,
                 async_agents='all',
                 max_delay=8,
                 state_last_action=True,
                 explosive_in_place_replace=True,  # if new explosive committed, replace or use the exist one
                 end_position=False,  # agent starts from the both of the corridor
                 policy_override_rule='no',
                 policy_override_prob=1):
        """Build the environment.

        Args:
            seed: stored on the instance and forwarded to the game objects.
                Note that ``self.seed`` (an instance attribute) shadows the
                ``seed`` method.
            episode_limit: number of steps before the episode is cut off.
            reward_*: rewards described in the class docstring.
            n_agents: number of agents; the layout, durations and observations
                below are written for exactly 2.
            grid_size: corridor width; must be odd.
            max_delay: reported via ``get_env_info()['max_delay']``.
            state_last_action: if True the state includes the last joint
                action (one-hot per agent); otherwise an all ``-1`` matrix.
            explosive_in_place_replace: see the NOTE(review) in ``step``.
            end_position: must be False (checked in ``_create_agents``).
            policy_override_rule: ``'no'`` disables ``policy_override``.
        """
        # NOTE: this instance attribute shadows the ``seed`` method below.
        self.seed                 = seed
        self.map_name             = map_name
        self.episode_limit        = episode_limit
        self.reward_no_op         = reward_no_op
        self.reward_one_explosion = reward_one_explosion
        self.reward_n_explosion   = reward_n_explosion
        self.reward_no_explosion  = reward_no_explosion
        self.n_agents             = n_agents
        self.grid_size            = grid_size
        self.async_agents         = async_agents
        self.max_delay            = max_delay
        self.end_position         = end_position
        self.state_last_action    = state_last_action
        self.policy_override_rule = policy_override_rule
        self.policy_override_prob = policy_override_prob

        self.explosive_in_place_replace = explosive_in_place_replace

        self.delay_step_max = self.max_delay_step = max_delay
        self.unit_dim = grid_size

        assert grid_size % 2 != 0, f"The grid_size must be an odd number, current value is {grid_size}"

        # Fuse length (in env steps) of each action, keyed by agent id.
        self.action_durations = {
            0: {
                self.ACTION_NO_OP: 0,
                self.ACTION_BOMB: 8,
            },
            1: {
                self.ACTION_NO_OP: 0,
                self.ACTION_BOMB: 4,
            }
        }

        self._init_agent_action_sapce()

        # last actions, one-hot encoding (all -1 before the first step)
        self.last_action = self._initial_last_action()
        self.agents = self._create_agents()
        self.rock = Rock(init_pos=[0, grid_size // 2], seed=self.seed)
        self.curr_t = 0
        self.agent_explosives = self._empty_explosives()
        self.debug_episode_actions = []
        self._init_state_obs()

    def _init_agent_action_sapce(self):
        """Set the action-mask template: 0 masked out, 1 no-op, 2 place explosive."""
        self._available_actions = [0, 1, 1]

    def _initial_last_action(self):
        """Return the placeholder last-action matrix (all ``-1``)."""
        return np.zeros((self.n_agents, self.n_actions)) - 1

    def _empty_explosives(self):
        """Return a mapping from each agent id to ``None`` (no live explosive)."""
        return {agent.agent_id: None for agent in self.agents}

    def _agent_column(self, agent_id):
        """Column of agent ``agent_id``: two cells left (id 0) or right (other ids) of the stone."""
        centre = self.grid_size // 2
        return centre - 2 if agent_id == 0 else centre + 2

    def _create_agents(self):
        """Create the two agents, two cells to either side of the stone."""
        assert self.end_position is False, "self.end_position=True is not supported in this environment."
        return [
            Agent(init_pos=[0, self._agent_column(agent_id)],
                  seed=self.seed,
                  agent_id=agent_id)
            for agent_id in (0, 1)
        ]

    def step(self, actions):
        """Advance the environment by one time step.

        Args:
            actions: one action per agent; each entry must be ``int()``-
                convertible and equal ``ACTION_NO_OP`` or ``ACTION_BOMB``.

        Returns:
            ``(reward, terminated, info)``. ``info['won']`` is 1 when
            ``n_agents`` explosives detonated this step and ``info['one_hit']``
            is 1 when exactly one did.
        """
        assert len(actions) == self.n_agents, f"len(actions) ({len(actions)}) != self.n_agents ({self.n_agents})"

        actions = [int(act) for act in actions]
        # Validate every action before mutating any state, so that an invalid
        # action leaves the environment untouched (and out-of-range values are
        # reported by this assertion rather than an IndexError from np.eye).
        for agent_id, action in enumerate(actions):
            assert action in {self.ACTION_NO_OP, self.ACTION_BOMB}, \
                f'action out of range, agent_id: {agent_id}, action: {int(action)}'

        reward, terminated, info = self.reward_no_op, False, {'won': 0, 'one_hit': 0}

        self.last_action = np.eye(self.n_actions)[np.array(actions)]
        self.debug_episode_actions.append(tuple(actions))

        for agent_id, action in enumerate(actions):
            if action != self.ACTION_BOMB:
                continue
            # NOTE(review): a new explosive is only placed when this flag is
            # True AND the agent has no live explosive (an existing one is
            # kept). With the flag False no explosive is ever placed, which
            # does not match the parameter's comment -- confirm the intended
            # semantics before relying on explosive_in_place_replace=False.
            if self.explosive_in_place_replace and self.agent_explosives[agent_id] is None:
                self.agent_explosives[agent_id] = Explosive(
                    init_pos=[0, self._agent_column(agent_id)],
                    seed=self.seed,
                    start_time=self.curr_t,
                    duration=self.action_durations[agent_id][action],
                    agent_id=agent_id)

        # Tick every live fuse (including one placed this step).
        for explosive in self.agent_explosives.values():
            if explosive is not None:
                explosive.update()

        count = self._check_explosion()
        if count == 1:
            # 1. One explosive explodes
            info['one_hit'] = 1
            terminated = True
            reward = self.reward_one_explosion
        elif count == self.n_agents:
            # 2. All explosives explode together
            info['won'] = 1
            terminated = True
            reward = self.reward_n_explosion

        # 3. Time limit
        if self.curr_t == self.episode_limit - 1:
            terminated = True
            if count == 0:
                reward = self.reward_no_explosion

        self.curr_t += 1
        self._update_obs()
        self._update_state()
        return reward, terminated, info

    def _check_explosion(self):
        """Count explosives detonating this tick and remove them from play."""
        count = 0
        for agent_id, explosive in self.agent_explosives.items():
            if explosive is None:
                continue
            if explosive.check_bomb():
                count += 1
                # Only the value is replaced (no key added/removed), which is
                # safe while iterating the dict.
                self.agent_explosives[agent_id] = None
        return count

    def reset(self):
        """Start a new episode and return ``(obs, state)``.

        Clears the clock, live explosives, the last joint action and the
        per-episode debug action log, and resets every game object.
        """
        self.curr_t = 0
        self.agent_explosives = self._empty_explosives()
        # _init_state_obs embeds last_action into the state; without this the
        # first state of an episode carried the previous episode's last action.
        self.last_action = self._initial_last_action()
        # Per-episode log; previously it grew without bound across episodes.
        self.debug_episode_actions = []
        for agent in self.agents:
            agent.reset()
        self.rock.reset()
        self._init_state_obs()
        return self.get_obs(), self.get_state()

    def _init_state_obs(self):
        """Rebuild ``_state`` and ``_obs`` for the start of an episode.

        ``_state`` is ``[codes of the 5 displayed cells, last-action matrix]``
        (always 5 cells, independent of ``grid_size``). ``_obs[i]`` holds the
        three cells from agent ``i`` to the stone, in left-to-right order.
        """
        self._state = [
            np.array([self.AGENT_FLAG, self.GRID_FLAG, self.STONE_FLAG, self.GRID_FLAG, self.AGENT_FLAG]),
            self.last_action if self.state_last_action else self._initial_last_action(),
        ]
        self._obs = [
            [self.AGENT_FLAG, self.GRID_FLAG, self.STONE_FLAG],
            [self.STONE_FLAG, self.GRID_FLAG, self.AGENT_FLAG],
        ]

    def _update_obs(self):
        """Mark the middle cell of an agent's observation when it has a live explosive.

        The flag is not cleared when the explosive detonates; any detonation
        terminates the episode and ``reset`` rebuilds the observations.
        """
        for agent_id, explosive in self.agent_explosives.items():
            if explosive is not None:
                self._obs[agent_id][1] = self.EXPLOSIVE_FLAG

    def _update_state(self):
        """Refresh the state from the current observations and last action."""
        self._state[0] = np.array(
            [self.AGENT_FLAG, self._obs[0][1], self.STONE_FLAG, self._obs[1][1], self.AGENT_FLAG],
            dtype='float64')
        self._state[1] = self.last_action if self.state_last_action else self._initial_last_action()

    def get_obs(self):
        """ Returns all agent observations in a list """
        return copy.deepcopy(self._obs)

    def get_obs_agent(self, agent_id):
        """ Returns observation for agent_id """
        return copy.deepcopy(self._obs[agent_id])

    def get_obs_size(self):
        """ Returns the shape of the observation """
        return len(self._obs[0])

    def get_state(self):
        """Return the flattened state: cell codes followed by the last-action matrix."""
        # concatenate allocates a new array, so no defensive deepcopy is needed.
        return np.concatenate([np.ravel(part) for part in self._state])

    def get_state_size(self):
        """ Returns the shape of the state"""
        return int(sum(np.prod(np.shape(part)) for part in self._state))

    def get_avail_actions(self):
        """Return one action mask per agent (independent copies)."""
        return [list(self._available_actions) for _ in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """ Returns the available actions for agent_id (a copy of the mask) """
        return list(self._available_actions)

    def get_total_actions(self):
        """ Returns the total number of actions an agent could ever take """
        return len(self._available_actions)

    def render(self):
        """Print the corridor, showing 'E' next to an agent with a live explosive."""
        flags = [
            '_' if explosive is None else 'E'
            for agent_id, explosive in self.agent_explosives.items()
        ]
        print(f""" 
         ___________________
        |_A_|_{flags[0]}_|_S_|_{flags[1]}_|_A_|
        """)

    def policy_override(self, actions, override_policy=True):
        """Optionally override agent actions; currently always a pass-through."""
        if self.policy_override_rule == 'no' or not override_policy:
            return actions
        # TODO: implement override rules; non-'no' rules fall through unchanged.
        return actions

    def get_action_durations(self, actions: List, curr_timestep: int):
        """
        Return action durations
        """
        # NOTE current results are not right (hard-coded for 2 agents and not
        # derived from self.action_durations or the given actions),
        # TODO return the right action duration
        results = [9, 4]
        assert len(results) == len(actions), "len(actions) != n_agents"
        return results

    def close(self):
        """Nothing to release."""
        pass

    def seed(self):
        """Return the environment seed.

        NOTE: ``__init__`` sets the instance attribute ``self.seed``, which
        shadows this method, so it is only reachable as ``QuarryEnv.seed(env)``.
        The previous body did ``raise self.seed``, raising TypeError for any
        int/None seed.
        """
        return self.seed

    @property
    def n_actions(self):
        """Number of action indices, including the masked-out index 0."""
        return len(self._available_actions)

    def _get_max_delay(self):
        return self.max_delay_step

    def get_env_info(self):
        """Return the shape/size summary consumed by the training framework."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit,
                    "max_delay": self._get_max_delay(),
                    "unit_dim": self.unit_dim}
        return env_info

    def get_stats(self):
        """No statistics are tracked; returns an empty dict."""
        return {}
