import gym
import numpy as np

from env.multi_agent_env import MultiAgentEnv
from env.qnyh_small.base_env import Battle


class QnyhSmallSelfplay(MultiAgentEnv):
    """Two-player self-play wrapper around the qnyh_small ``Battle`` game.

    Agent ids are ``"left"`` (Battle player 0) and ``"right"`` (Battle
    player 1). ``observation_space`` and ``action_space`` are plain dicts
    keyed by agent id, sized from each side's race and ``classify``.

    Args:
        print_game_log: forwarded to ``Battle``.
        select_skill: forwarded to ``Battle``.
        print_action_list: if true, print both players' carried skills on
            reset, and the episode's joint-action trace plus step count when
            an episode ends.
        races: mapping ``{"left": race, "right": race}`` where race is one of
            ``"Tank"``, ``"Shooter"`` or ``"FangShi"``; defaults to
            Shooter (left) vs Tank (right).
        classify: forwarded to ``Battle``; also selects the smaller action
            spaces for Tank and Shooter.
        rule_team: forwarded to ``Battle`` (-1, 0 or 1).
        hard_ai: forwarded to ``Battle``.

    Raises:
        KeyError: if ``races`` lacks a ``"left"`` or ``"right"`` entry.
        NotImplementedError: if a race is not one of the supported names.
    """

    # (race, classify) -> (observation vector length, number of discrete actions).
    # NOTE(review): FangShi uses the same sizes whether or not ``classify`` is
    # set; confirm against Battle that no classify variant exists for it.
    _SPACE_SIZES = {
        ("Tank", True): (37, 8),
        ("Tank", False): (37, 13),
        ("Shooter", True): (49, 5),
        ("Shooter", False): (49, 15),
        ("FangShi", True): (61, 17),
        ("FangShi", False): (61, 17),
    }

    def __init__(
        self,
        print_game_log=False,
        select_skill=True,
        print_action_list=False,
        races=None,
        classify=False,
        rule_team=-1,
        hard_ai=False,
    ):
        if races is None:
            races = {"left": "Shooter", "right": "Tank"}  # 'Tank' 'Shooter' 'FangShi'
        self.env = Battle(
            print_game_log=print_game_log,
            select_skill=select_skill,
            races=[races["left"], races["right"]],
            classify=classify,
            rule_team=rule_team,  # -1, 0, 1
            hard_ai=hard_ai,
        )
        self.print_action_list = print_action_list

        # Per-episode debug trace: joint actions sent to Battle (only recorded
        # when print_action_list is set) and the number of steps taken.
        self.action_list = []
        self.i = 0

        self.observation_space = {}
        self.action_space = {}
        for agent_id, race in races.items():
            sizes = self._SPACE_SIZES.get((race, bool(classify)))
            if sizes is None:
                raise NotImplementedError(f"Unsupported race: {race!r}")
            obs_dim, n_actions = sizes
            self.observation_space[agent_id] = gym.spaces.Box(
                low=-np.inf, high=np.inf, shape=(obs_dim,)
            )
            self.action_space[agent_id] = gym.spaces.Discrete(n_actions)

        # Not read within this class; presumably used by callers/subclasses.
        self.step_cnt = 0
        self.skill_record = {0: {}, 1: {}}
        self.special_skills = [8, 11, 12]

    def reset(self):
        """Reset per-episode state and the underlying battle.

        Returns:
            dict mapping ``"left"``/``"right"`` to Battle player 0/1's
            initial observation.
        """
        self.i = 0
        self.action_list = []
        self.step_cnt = 0
        self.skill_record = {0: {}, 1: {}}
        obs = self.env.reset()
        if self.print_action_list:
            print('skills0', self.env.players[0].carry_skill)
            print('skills1', self.env.players[1].carry_skill)
        return {"left": obs[0], "right": obs[1]}

    @staticmethod
    def _result_infos(winner):
        """Return ``(left_info, right_info)`` result dicts for a finished episode.

        ``winner`` is -1 for a draw and 0 for a left win; any other value is
        treated as a right win.
        """
        if winner == -1:  # draw
            left, right = ("draw", 0), ("draw", 0)
        elif winner == 0:
            left, right = ("win", 1), ("lose", 0)
        else:
            left, right = ("lose", 0), ("win", 1)
        return (
            {"result": left[0], "win": left[1]},
            {"result": right[0], "win": right[1]},
        )

    def step(self, actions):
        """Advance the battle by one joint action.

        Args:
            actions: mapping from agent id to a discrete action; an agent
                missing from the mapping sends action 0.

        Returns:
            ``(obs, rew, dones, infos)``. ``obs``, ``rew`` and ``infos`` are
            keyed by ``"left"``/``"right"``, and ``dones`` holds only
            ``"__all__"``. Once the episode is done, each info dict has
            ``"result"`` ("win"/"lose"/"draw") and ``"win"`` (1 or 0);
            before that the info dicts are empty.
        """
        act_list = [actions[k] if k in actions else 0 for k in ("left", "right")]
        if self.print_action_list:
            self.action_list.append(act_list)
        self.i += 1

        observations, rewards, done, winner = self.env.step(act_list)
        infos = {"left": {}, "right": {}}
        if done:
            if self.print_action_list:
                print("action list", self.action_list, winner == 0)
                print(self.i)
            infos["left"], infos["right"] = self._result_infos(winner)

        obs = {"left": observations[0], "right": observations[1]}
        rew = {"left": rewards[0], "right": rewards[1]}
        return obs, rew, {"__all__": done}, infos


def tfest(num_episodes=1000):
    """Smoke-test the environment with uniformly random actions.

    Plays ``num_episodes`` episodes of Tank (left) vs Shooter (right) with
    ``rule_team=0``. Each step samples both agents' actions from their
    action spaces and asserts that every returned observation lies in its
    declared observation space. Prints each episode's final reward and info,
    followed by the overall left-win / right-win / draw tally.

    Args:
        num_episodes: number of episodes to play.
    """
    # Other configurations worth trying: races from 'Tank' / 'Shooter' /
    # 'FangShi' on either side, and hard_ai=True, print_game_log=True or
    # classify=True.
    env = QnyhSmallSelfplay(
        races={"left": "Tank", "right": "Shooter"},
        rule_team=0,
    )
    print("Env spaces:", env.action_space, env.observation_space)

    left_win = 0
    right_win = 0
    draw = 0
    for episode in range(num_episodes):
        env.reset()
        done = {"__all__": False}
        while not done["__all__"]:
            actions = {k: env.action_space[k].sample() for k in ("left", "right")}
            obs, rew, done, info = env.step(actions)
            assert obs["left"] in env.observation_space["left"]
            assert obs["right"] in env.observation_space["right"]

        # The loop always runs at least once, so rew/info hold the final step.
        print(f"Episode {episode}, reward: {rew}, info: {info}")
        if info["left"]["result"] == "win":
            left_win += 1
        elif info["right"]["result"] == "win":
            right_win += 1
        else:
            draw += 1
    print(f"left_win: {left_win}, right_win: {right_win}, draw: {draw}")
    env.close()


if __name__ == "__main__":
    # Run the random-action smoke test only when executed as a script, not on import.
    tfest()
