import gym
import random
import copy
import numpy as np

from env.multi_agent_env import MultiAgentEnv
from env.grid_world.plot import PlotTrace


class GridShooting(MultiAgentEnv):
    """Two-player grid duel: agents move, shoot along rows/columns and collect a star.

    Agent 0 ("left") starts at ``[0, 0]`` and agent 1 ("right") at
    ``[size - 1, size - 1]``.  Each frame both agents pick one of 9 actions:

    * ``0``        - stay in place,
    * ``1``..``4`` - move +axis0, +axis1, -axis0, -axis1 (clamped to the grid),
    * ``5``..``8`` - shoot in the same four directions.  A shot is only allowed
      while the agent's cooldown is 0; otherwise it is replaced by ``0``.

    A shot hits when the enemy shares the other coordinate and lies at or
    beyond the shooter in the shot direction (sharing a cell also counts); a
    hit ends the episode.  Stepping on the star gives ``STAR_REWARD`` and moves
    the star to one of the other two fixed star cells.  Every frame also costs
    ``MOVE_REWARD``.
    """

    SHOOTING_REWARD = 10
    STAR_REWARD = 1
    MOVE_REWARD = -1

    # Actions with index >= this value are shooting actions; smaller ones move.
    _FIRST_SHOOT_ACTION = 5

    def __init__(
        self,
        size=9,
        max_cd=None,
        log_print=False,
        max_frames=100,
    ):
        """Create the environment.

        Args:
            size: side length of the square grid.
            max_cd: per-team shooting cooldown ``[left, right]``; defaults to
                ``[5, 5]``.
            log_print: if True, record per-frame positions and plot them with
                ``PlotTrace`` whenever an episode ends.
            max_frames: frame count after which the episode ends (default 100,
                the previously hard-coded limit).
        """
        if max_cd is None:
            max_cd = [5, 5]
        self.size = size
        self.max_cd = max_cd
        self.max_frames = max_frames

        self.log_print = log_print
        self.trace = []
        self.count = 0  # counter passed to PlotTrace.plot, bumped per plotted episode

        # Observation: own pos, enemy pos and star pos (each two concatenated
        # ``size``-long one-hots) plus own and enemy cooldown fractions.
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(6 * self.size + 2,)
        )
        self.action_space = gym.spaces.Discrete(9)

        self.frame_no = 0
        self.agents = {}
        self.actions = {}
        self.star_pos = None

        self.star_idx = None
        self.create_star()
        self.reset()

    def reset(self):
        """Start a new episode and return ``{"left": obs0, "right": obs1}``.

        Agents go back to opposite corners with zero cooldown; the star keeps
        its current cell.  The logged trace is restarted so that each plotted
        episode contains only its own frames.
        """
        self.frame_no = 0
        self.agents = {
            0: {"pos": [0, 0], "cd": 0},
            1: {"pos": [self.size - 1, self.size - 1], "cd": 0},
        }
        obs = {"left": self.get_state(0), "right": self.get_state(1)}
        # Previously the trace accumulated across episodes, so every plot
        # re-drew all earlier episodes and memory grew without bound.
        self.trace = []
        if self.log_print:
            self.trace.append([
                ['green', self.agents[0]["pos"] + [0]],
                ['green', self.agents[1]["pos"] + [0]],
                ['yellow', self.star_pos + [0]],
            ])
        return obs

    def create_star(self):
        """Put the star on one of three fixed cells.

        The cells are index 0 = grid centre, 1 = ``[0, size-1]`` and
        2 = ``[size-1, 0]``.  The first call picks any of them; later calls
        always move the star to one of the *other* two cells.
        """
        if self.star_idx is not None:
            # ``if self.star_idx:`` treated index 0 (the centre) as "unset",
            # which let the star respawn in place under the collecting agent.
            candidates = [idx for idx in range(3) if idx != self.star_idx]
            self.star_idx = random.choice(candidates)
        else:
            self.star_idx = random.randint(0, 2)

        if self.star_idx == 0:
            centre = self.size // 2
            self.star_pos = [centre, centre]
        elif self.star_idx == 1:
            self.star_pos = [0, self.size - 1]
        else:
            self.star_pos = [self.size - 1, 0]

    def _encode_pos(self, pos):
        """Return a ``size``-long one-hot of ``pos[0]`` followed by one of ``pos[1]``."""
        encoded = [0.0] * (2 * self.size)
        encoded[pos[0]] = 1.0
        encoded[pos[1] + self.size] = 1.0
        return encoded

    def _cd_fraction(self, team_id):
        """Return ``team_id``'s remaining cooldown divided by its own max (0.0 if max <= 0)."""
        max_cd = self.max_cd[team_id]
        if max_cd <= 0:
            # With a non-positive max the cooldown is always clamped to 0.
            return 0.0
        return self.agents[team_id]["cd"] / float(max_cd)

    def get_state(self, team_id, action_mask=False):
        """Build the flat observation for ``team_id`` (0 = left, 1 = right).

        Layout (length ``6 * size + 2``)::

            [own pos | enemy pos | star pos | own cd fraction | enemy cd fraction]

        Args:
            team_id: 0 or 1.
            action_mask: if True, also return ``{"legal_actions": ...}``.
                ``legal_actions`` maps the movement actions 0-4 to 1 while the
                agent's cooldown is running, and is ``None`` otherwise.

        Returns:
            ``state`` (list of floats), or ``(state, extra_info)`` when
            ``action_mask`` is True.
        """
        enemy_team_id = team_id ^ 1

        state = []
        state.extend(self._encode_pos(self.agents[team_id]["pos"]))
        state.extend(self._encode_pos(self.agents[enemy_team_id]["pos"]))
        state.extend(self._encode_pos(self.star_pos))
        # Each cooldown is normalised by *its own* team's max; the enemy's used
        # to be divided by ours, which is wrong when max_cd differs per team.
        state.append(self._cd_fraction(team_id))
        state.append(self._cd_fraction(enemy_team_id))

        if not action_mask:
            return state
        # Agent dicts are keyed "pos"/"cd"; the old ``[2]`` lookup always raised.
        if self.agents[team_id]["cd"] > 0:
            # Shooting is unavailable while cooling down: only moves are legal.
            legal_actions = {i: 1 for i in range(self._FIRST_SHOOT_ACTION)}
        else:
            legal_actions = None
        return state, {"legal_actions": legal_actions}

    def check_shooting(self, team_id, shooting_action):
        """Return True if ``team_id``'s shot ``shooting_action`` (5-8) hits the enemy.

        Shot directions mirror movement actions 1-4:
        5 = +axis0, 6 = +axis1, 7 = -axis0, 8 = -axis1.  Any other value misses.
        """
        direction = shooting_action - self._FIRST_SHOOT_ACTION
        shooter = self.agents[team_id]["pos"]
        target = self.agents[team_id ^ 1]["pos"]

        # direction -> (axis the shot travels along, True if towards larger values)
        shot_dirs = {0: (0, True), 1: (1, True), 2: (0, False), 3: (1, False)}
        if direction not in shot_dirs:
            return False
        axis, forward = shot_dirs[direction]
        other_axis = 1 - axis
        # The enemy must be on the same line as the shot.
        if shooter[other_axis] != target[other_axis]:
            return False
        if forward:
            return shooter[axis] <= target[axis]
        return shooter[axis] >= target[axis]

    def move(self, team_id, action):
        """Apply movement ``action`` to ``team_id``, clamped to the grid.

        0 = stay, 1 = +axis0, 2 = +axis1, 3 = -axis0, 4 = -axis1; other values
        leave the agent in place.
        """
        pos = self.agents[team_id]["pos"]
        last = self.size - 1
        if action == 1:
            pos[0] = min(pos[0] + 1, last)
        elif action == 2:
            pos[1] = min(pos[1] + 1, last)
        elif action == 3:
            pos[0] = max(pos[0] - 1, 0)
        elif action == 4:
            pos[1] = max(pos[1] - 1, 0)

    def step(self, actions):
        """Advance the game by one frame.

        Order: blocked shots are demoted to "stay", shots are resolved,
        cooldowns tick down, agents move, then the star pickup is checked.

        Args:
            actions: ``{"left": int, "right": int}`` with values in ``[0, 8]``.
                The dict is not modified.

        Returns:
            ``(obs, rewards, dones, infos)``.  ``obs``, ``rewards`` and
            ``infos`` are keyed by ``"left"``/``"right"``, and
            ``dones == {"__all__": done}``.  ``infos`` carries ``"result"`` and
            ``"win"`` only on the final frame.
        """
        self.frame_no += 1
        done = False
        rewards = [self.MOVE_REWARD, self.MOVE_REWARD]  # every frame costs MOVE_REWARD
        colors = ["green", "green"]  # a team is drawn "red" in the trace when hit

        # Work on a local copy; the original rewrote the caller's dict in place.
        acts = [actions["left"], actions["right"]]
        for team in (0, 1):
            # Shooting is not allowed while the team's cooldown is running.
            if acts[team] >= self._FIRST_SHOOT_ACTION and self.agents[team]["cd"] > 0:
                acts[team] = 0

        # Resolve shots on the pre-move positions (both teams fire simultaneously).
        for team in (0, 1):
            if acts[team] < self._FIRST_SHOOT_ACTION:
                continue
            if self.check_shooting(team, acts[team]):
                done = True
                rewards[team] += self.SHOOTING_REWARD
                rewards[team ^ 1] -= self.SHOOTING_REWARD
                colors[team ^ 1] = "red"
            self.agents[team]["cd"] = self.max_cd[team]

        # Tick cooldowns (a shot fired this frame therefore leaves max_cd - 1).
        for team in (0, 1):
            self.agents[team]["cd"] = max(self.agents[team]["cd"] - 1, 0)

        for team in (0, 1):
            if acts[team] < self._FIRST_SHOOT_ACTION:
                self.move(team, acts[team])

        # Star pickup: both agents may collect it on the same frame.
        recreate_star = False
        for team in (0, 1):
            if self.agents[team]["pos"] == self.star_pos:
                recreate_star = True
                rewards[team] += self.STAR_REWARD
        if recreate_star:
            self.create_star()

        if self.frame_no >= self.max_frames:
            done = True

        if self.log_print:
            self.trace.append([
                [colors[0], self.agents[0]["pos"] + [0]],
                [colors[1], self.agents[1]["pos"] + [0]],
                ['yellow', self.star_pos + [0]],
            ])

        infos = {"left": {}, "right": {}}
        if done:
            # NOTE(review): the outcome compares only this frame's rewards, so on
            # a timeout whoever grabbed the star on the final frame "wins".
            # Confirm this is intended rather than a cumulative or draw rule.
            if rewards[0] == rewards[1]:
                results = ("draw", "draw")
            elif rewards[0] > rewards[1]:
                results = ("win", "lose")
            else:
                results = ("lose", "win")
            for key, result in zip(("left", "right"), results):
                infos[key]["result"] = result
                infos[key]["win"] = 1 if result == "win" else 0
            if self.log_print:
                plot = PlotTrace(self.size)
                plot.plot(self.trace, self.count)
                self.count += 1

        obs = {"left": self.get_state(0), "right": self.get_state(1)}
        rew = {"left": rewards[0], "right": rewards[1]}
        return obs, rew, {"__all__": done}, infos


def test():
    """Smoke-test GridShooting with uniformly random actions for both players.

    Runs 300 frames with ``log_print=True`` (so a trace plot is produced at the
    end of every episode), resetting whenever an episode finishes.
    """
    env = GridShooting(log_print=True)
    obs = env.reset()
    print("obs:", obs)
    # Expected spaces: Discrete(9) actions, Box of shape (6 * size + 2,).
    print("====:", env.action_space, env.observation_space)
    for _ in range(300):
        act0 = env.action_space.sample()
        act1 = env.action_space.sample()
        obs, rew, done, info = env.step({"left": act0, "right": act1})
        if done["__all__"]:
            obs = env.reset()
            print("done:", done)
    env.close()


if __name__ == "__main__":
    test()
