import numpy as np
import gym

# from .role import FangShi
from .role import Tank
from .role import Shooter


class Battle(gym.Env):
    def __init__(
        self,
        print_game_log=False,
        select_skill=False,
        races=None,
        classify=False,
        rule_team=-1,  # 0, 1
        fail_reward=0.1,
        hard_ai=False,
    ):
        """Configure the battle and build one placeholder player per team.

        Args:
            print_game_log: When true, progress lines are printed and the flag
                is forwarded to the role constructors.
            select_skill: When true, ``reset`` equips a randomly drawn skill
                set instead of the fixed default loadout.
            races: Race name per team, each ``'Tank'`` or ``'Shooter'``.
                Defaults to ``['Tank', 'Tank']``.
            classify: When true, ``step`` first maps incoming actions through
                the ``translate_action_*`` helpers.
            rule_team: Index (0 or 1) of the team driven by ``enemy_ai``;
                ``-1`` means neither team is scripted.
            fail_reward: Amount added to a team's penalty whenever the role's
                ``step`` flags the chosen action (its second return value).
            hard_ai: Selects the alternative decision branch in ``enemy_ai``.

        Raises:
            ValueError: If ``races`` contains a name other than ``'Tank'`` or
                ``'Shooter'``.
        """
        if races is None:
            races = ['Tank', 'Tank']
        # Fail fast: an unknown race used to be skipped silently, leaving the
        # per-team lists below shorter than `races` and surfacing only later
        # as an IndexError inside reset()/step().
        unsupported = [race for race in races if race not in ('Tank', 'Shooter')]
        if unsupported:
            raise ValueError(
                f"Unsupported race(s) {unsupported}; expected 'Tank' or 'Shooter'"
            )
        self.races = races
        self.print_game_log = print_game_log
        self.select_skill = select_skill
        self.classify = classify
        self.rule_team = rule_team
        self.fail_reward = fail_reward
        self.hard_ai = hard_ai

        # Per-team tables, all indexed by team id.
        self.players = []
        self.small_skills = []  # pool of skill ids reset() samples from
        self.time_buff1 = []    # per-team value passed to refresh_status()
        self.skill_num = []     # number of skill slots exposed by get_state()
        self.buff_num = []      # buff count used by get_state()
        for race in self.races:
            if race == 'Tank':
                self.players.append(Tank(position=[0, 0], carry_skill=[], print_game_log=self.print_game_log))
                self.small_skills.append(np.arange(1, 11, dtype=np.int8))
                self.time_buff1.append(-0.1)
                self.skill_num.append(12)
                self.buff_num.append(16)
            elif race == 'Shooter':
                self.players.append(Shooter(position=[0, 0], carry_skill=[], print_game_log=self.print_game_log))
                self.small_skills.append(np.arange(1, 13, dtype=np.int8))
                self.time_buff1.append(-0.13)
                self.skill_num.append(14)
                self.buff_num.append(14)
            # if race == 'FangShi':
            #     self.players.append(FangShi(position=[0, 0], carry_skill=[]))
            #     self.small_skills.append(np.array(list(range(1, 17)), dtype=np.int8))
            #     self.time_buff1.append(-0.2)
            #     self.skill_num.append(16)
            #     self.buff_num.append(19)

        # Episode state; step() keeps these up to date.
        self.done = False
        self.reward = 0.0
        self.game_time = 0
        self.distance = 8.0
        self.winner = -1
        self.time_buff2 = 0.0
        self.use_skill_fail = [0, 0]
        self.punish = [0.0, 0.0]

        if self.print_game_log:
            print(f"Teams {self.races}")
            if self.rule_team != -1:
                print(f'Team {self.rule_team} use built-in rule')

    def reset(self):
        """Start a new episode and return the initial observations.

        Every player is rebuilt at its starting position with either a
        randomly drawn loadout (``select_skill`` true) or the fixed default
        one, and all episode-level state is restored.

        Returns:
            A two-element list ``[get_state(0), get_state(1)]``.
        """
        self.use_skill_fail = [0, 0]
        for i, race in enumerate(self.races):
            # The random draws happen even when the fixed loadout is used, so
            # the amount of RNG consumed per reset does not depend on
            # `select_skill`.
            select_small_skills = np.random.choice(self.small_skills[i], 5, replace=False).tolist()
            if race == 'Tank':
                # Skills 11 and 12 are each added with probability 0.5.
                if np.random.uniform() < 0.5:
                    select_small_skills.append(11)
                if np.random.uniform() < 0.5:
                    select_small_skills.append(12)
                initial_skills = [5, 6, 8, 9, 10, 11, 12]
                position = [0, -1] if i == 0 else [0, 1]
                carry_skill = select_small_skills if self.select_skill else initial_skills
                self.players[i] = Tank(position=position, carry_skill=carry_skill, print_game_log=self.print_game_log)
                self.time_buff1[i] = -0.1
            elif race == 'Shooter':
                # Skills 13 and 14 are each added with probability 0.5.
                if np.random.uniform() < 0.5:
                    select_small_skills.append(13)
                if np.random.uniform() < 0.5:
                    select_small_skills.append(14)
                initial_skills = [1, 3, 8, 12, 13, 14]
                position = [0, -2] if i == 0 else [0, 2]
                carry_skill = select_small_skills if self.select_skill else initial_skills
                self.players[i] = Shooter(position=position, carry_skill=carry_skill, print_game_log=self.print_game_log)
                self.time_buff1[i] = -0.13
            # if race == 'FangShi':
            #     if np.random.uniform() < 0.5:
            #         select_small_skills.append(15)
            #     if np.random.uniform() < 0.5:
            #         select_small_skills.append(16)
            #     initial_skills = [6, 16]
            #     position = [0, -2] if i == 0 else [0, 2]
            #     carry_skill = select_small_skills if self.select_skill else initial_skills
            #     self.players[i] = FangShi(position=position, carry_skill=carry_skill)
            #     self.time_buff1[i] = -0.2
            #     self.distance = 4.0
        self.time_buff2 = 0.0
        self.done = False
        self.game_time = 0
        self.winner = -1
        # Restore the separation to the value __init__ uses. Without this the
        # first turn of a new episode ran with the distance left over from the
        # end of the previous episode, although the players were re-placed.
        self.distance = 8.0

        obs = [self.get_state(0), self.get_state(1)]
        return obs

    def step(self, actions):
        """Advance the battle by one step (one turn for each team).

        Teams act in descending index order (team 1 first, then team 0).
        ``actions`` is modified in place: entries are replaced by their
        translated skill ids when ``classify`` is set and by the scripted
        choice for ``rule_team``.

        Args:
            actions: Indexable pair holding one action per team.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is
            ``[get_state(0), get_state(1)]``, ``reward`` is a two-element
            list (the shared signed reward minus each team's own penalty),
            ``done`` is the episode-finished flag and ``info`` is the winner
            index (``-1`` while undecided).
        """
        self.reward = 0.0
        self.punish = [0.0, 0.0]

        # Execute rounds for both teams.
        if self.print_game_log:
            print('=' * 50)
        for team in reversed(range(len(self.races))):
            rival = team ^ 1
            if self.print_game_log:
                print(f'Team {team} round:')

            if self.classify:
                race = self.races[team]
                if race == 'Tank':
                    actions[team] = self.translate_action_jiashi(actions[team], self.players[team])
                if race == 'Shooter':
                    actions[team] = self.translate_action_sheshou(
                        actions[team], self.players[team], self.players[rival]
                    )

            # The scripted team's action overrides whatever was passed in.
            if self.rule_team == team:
                actions[team] = self.enemy_ai()
            if self.rule_team == rival:
                actions[rival] = self.enemy_ai()

            # The shared reward is signed: team 1's round result is added,
            # team 0's is subtracted.
            outcome = self.round(self.players[team], self.players[rival], actions[team])
            if team == 1:
                self.reward += outcome
            else:
                self.reward -= outcome

            self.players[team].refresh_status(self.time_buff1[team], self.time_buff2, self.players[rival])

            # Euclidean separation after this team's turn.
            dx = self.players[team].position[0] - self.players[rival].position[0]
            dy = self.players[team].position[1] - self.players[rival].position[1]
            self.distance = np.sqrt(np.square(dx) + np.square(dy))

        # Update games.
        self.game_time += 1
        past_warmup = self.game_time > 30
        if past_warmup and self.game_time % 2 == 0:
            # After step 30 the time-based modifiers grow on every even step.
            for team in (0, 1):
                self.time_buff1[team] += 0.1
            self.time_buff2 += 0.02
        if self.game_time >= 120:
            self.done = True

        # A knocked-out player ends the episode; team 0's loss is checked
        # first, so team 0 is recorded as winner if both players are down.
        if self.players[0].now_HP <= 0:
            self.done = True
            self.winner = 1
            self.reward -= 10
        if self.players[1].now_HP <= 0:
            self.done = True
            self.winner = 0
            self.reward += 10

        obs = [self.get_state(0), self.get_state(1)]
        reward = [self.reward - self.punish[0], -self.reward - self.punish[1]]
        info = self.winner
        return obs, reward, self.done, info

    def get_state(self, team_id):
        """Build the observation vector for ``team_id``.

        Layout of the returned 1-D array:
            1. HP, both ATK components and DEF of player 0, each divided by
               its max/initial value, followed by the same four for player 1
               (this order does not depend on ``team_id``).
            2. One flag per skill slot of ``team_id``: ``1.0`` carried and
               off cooldown, ``0.0`` carried but cooling down, ``-1.0`` not
               carried.
            3. Buff timers of player 0 and then player 1, each capped at 1.0.
               Which buff indices are included depends on the race of
               ``team_id``; for ``'Tank'`` buff 3 is additionally appended
               for both players, scaled by 1/30.
            4. ``floor((game_time - 30) / 2)``, floored at 0.

        Args:
            team_id: Index of the team whose skills (and race-specific buff
                selection) are encoded.

        Returns:
            ``numpy.ndarray`` with the features described above.
        """
        first, second = self.players[0], self.players[1]

        features = []
        for fighter in (first, second):
            features.append(fighter.now_HP / fighter.max_HP)
            features.append(fighter.now_ATK[0] / fighter.initial_ATK[0])
            features.append(fighter.now_ATK[1] / fighter.initial_ATK[1])
            features.append(fighter.now_DEF / fighter.initial_DEF)

        # `legal_actions` mirrors the skill flags (index 0 is always legal).
        # It is not returned at the moment; see the trailing comment on the
        # return statement.
        legal_actions = [True]
        for skill_id in range(1, self.skill_num[team_id] + 1):
            skill = self.players[team_id].skills[skill_id]
            if not skill.carry:
                flag, usable = -1.0, False
            elif skill.now_cd <= 0:
                flag, usable = 1.0, True
            else:
                flag, usable = 0.0, False
            features.append(flag)
            legal_actions.append(usable)

        if self.races[team_id] == 'Tank':
            tracked = [2, 4, 6, 8, 9, 10, 12]
            for fighter in (first, second):
                for buff_id in tracked:
                    features.append(min(1.0, fighter.buffs[buff_id].now_time))
            features.append(first.buffs[3].now_time / 30.0)
            features.append(second.buffs[3].now_time / 30.0)
        else:
            tracked_count = self.buff_num[team_id] - 1
            for fighter in (first, second):
                for buff_id in range(tracked_count):
                    features.append(min(1.0, fighter.buffs[buff_id].now_time))

        features.append(max(np.floor((self.game_time - 30) / 2), 0))
        return np.array(features)  # , self.use_skill_fail, legal_actions

    def round(self, attacker, defender, action):
        """Play one team's turn and return its signed reward contribution.

        The turn has three phases:
            1. Damage-over-time buffs sitting on ``attacker`` are resolved by
               letting ``defender`` execute the matching skill against it;
               those rewards are added to the total.
            2. Auras active on ``attacker`` fire against ``defender``; those
               rewards are subtracted from the total.
            3. ``attacker`` performs ``action``; its reward is subtracted and,
               if the role's ``step`` flags the action, the attacker's team is
               charged ``fail_reward``.

        Attackers that are neither ``Tank`` nor ``Shooter`` do nothing and
        yield 0 (FangShi handling is not implemented).

        Args:
            attacker: Player whose turn it is.
            defender: The opposing player.
            action: Skill id chosen for ``attacker``.

        Returns:
            Sum described above; ``step`` adds it for team 1 and subtracts it
            for team 0.
        """
        reward_all = 0

        if isinstance(attacker, Tank):
            if self.print_game_log:
                print("\t---BUFF---")
            # suffering buff from Tank
            if isinstance(defender, Tank):
                if attacker.buffs[2].now_time > 0:
                    if self.print_game_log:
                        print("\tattacker buff: 受到持续伤害（流火·炎）")
                    reward, _ = defender.step(16, attacker, self.distance)
                    reward_all += reward

            # suffering buff from Shooter
            elif isinstance(defender, Shooter):
                # Buff id on the Tank -> Shooter skill that resolves it. This
                # used to be a 6-element list indexed by the buff id (17..22),
                # which raised IndexError as soon as one of these buffs was
                # active. Ids are paired positionally, mirroring the mapping
                # used in the Shooter-vs-Shooter branch below.
                for buff_id, skill_id in ((17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23)):
                    if attacker.buffs[buff_id].now_time > 0:
                        if self.print_game_log:
                            print(f"\tattacker buff: 受到持续伤害 ({attacker.buffs[buff_id].name})")
                        reward, _ = defender.step(skill_id, attacker, self.distance)
                        reward_all += reward

            # Aura on the attacker fires against the defender.
            if attacker.buffs[12].now_time > 0:
                if self.print_game_log:
                    print("\tattacker buff: 光环（青龙·灭），给敌人伤害并降低其防御力数值")
                reward, _ = attacker.step(17, defender, self.distance)
                reward_all -= reward

            reward_all -= self._perform_action(attacker, defender, action)

        if isinstance(attacker, Shooter):
            if self.print_game_log:
                print("\t---BUFF---")
            # suffering buff from Tank
            if isinstance(defender, Tank):
                if attacker.buffs[16].now_time > 0:
                    if self.print_game_log:
                        print("\tattacker buff: 受到持续伤害（流火·炎）")
                    reward, _ = defender.step(16, attacker, self.distance)
                    reward_all += reward

            # suffering buff from Shooter
            elif isinstance(defender, Shooter):
                for buff_id, skill_id in ((0, 18), (2, 19), (3, 20), (4, 21), (5, 22), (6, 23)):
                    if attacker.buffs[buff_id].now_time > 0:
                        if self.print_game_log:
                            print(f"\tattacker buff: 受到持续伤害 (buff={buff_id})")
                        reward, _ = defender.step(skill_id, attacker, self.distance)
                        reward_all += reward

            # Auras on the attacker fire against the defender. Buff 8 takes
            # precedence over buff 11; buff 9 is independent of both.
            if attacker.buffs[9].now_time > 0:
                if self.print_game_log:
                    print("\tattacker buff: 光环伤害")
                reward, _ = attacker.step(24, defender, self.distance)
                reward_all -= reward
            if attacker.buffs[8].now_time > 0:
                if self.print_game_log:
                    print("\tattacker buff: 光环伤害1")
                reward, _ = attacker.step(27, defender, self.distance)
                reward_all -= reward
            elif attacker.buffs[11].now_time > 0:
                if self.print_game_log:
                    print("\tattacker buff: 光环伤害2")
                reward, _ = attacker.step(25, defender, self.distance)
                reward_all -= reward
                reward, _ = attacker.step(26, defender, self.distance)
                reward_all -= reward

            reward_all -= self._perform_action(attacker, defender, action)

        return reward_all

    def _perform_action(self, attacker, defender, action):
        """Execute ``action`` for ``attacker`` and book a penalty if flagged.

        Returns the reward reported by the role's ``step``; the caller decides
        its sign.
        """
        if self.print_game_log:
            print("\t---ACTION---")
        reward, punish = attacker.step(action, defender, self.distance)
        if punish:
            # Identity, not equality: the question is which slot the attacker
            # occupies, not whether two players happen to compare equal.
            team = 0 if attacker is self.players[0] else 1
            self.punish[team] += self.fail_reward
            self.use_skill_fail[team] += 1
        return reward

    # TODO
    def enemy_ai(self):
        """Choose the scripted action for the team given by ``rule_team``.

        Candidate skills are examined in a fixed order and every later match
        overrides the earlier choice, so the last satisfied rule wins. If the
        selected skill's ``distance`` is smaller than the current separation,
        the choice is replaced by action 14 (Tank) or 16 (Shooter), which the
        original comments describe as moving closer.

        Returns:
            The chosen skill id.

        Raises:
            ValueError: If the scripted player is neither a ``Tank`` nor a
                ``Shooter``.
        """
        me = self.players[self.rule_team]
        foe = self.players[(self.rule_team + 1) % 2]

        def usable(skill_id):
            # Off cooldown and part of the current loadout.
            skill = me.skills[skill_id]
            return skill.now_cd <= 0 and skill.carry

        # Index of each relevant buff in the opponent's buff table; the
        # indices differ with the opponent's class.
        if isinstance(foe, Tank):
            buff_dict = {
                "眩晕（虎贲·冲）": 9,
                "眩晕（贯索·擒）": 10,
                "攻击力百分比降低（司怪·困）": 3,
                "光环（青龙·灭）": 12,
            }
        elif isinstance(foe, Shooter):
            buff_dict = {
                "眩晕（虎贲·冲）": 18,
                "眩晕（贯索·擒）": 19,
                "攻击力百分比降低（司怪·困）": 17,
                "光环（青龙·灭）": 14,
            }

        if isinstance(me, Tank):
            action = 5
            if (foe.buffs[buff_dict["眩晕（虎贲·冲）"]].now_time <= 0
                    and foe.buffs[buff_dict["眩晕（贯索·擒）"]].now_time <= 0):
                # Neither of these two opponent buffs is running: try skill 9,
                # then prefer skill 10 if it is usable too.
                for candidate in (9, 10):
                    if usable(candidate):
                        action = candidate
            elif self.hard_ai:
                if usable(12) and self.game_time >= 60:
                    action = 12

            if self.hard_ai:
                if foe.buffs[buff_dict["攻击力百分比降低（司怪·困）"]].now_time <= 5 and usable(6):
                    action = 6
                if usable(8) and me.buffs[11].now_time:
                    action = 8
                if usable(11) and foe.buffs[buff_dict["光环（青龙·灭）"]].now_time:
                    action = 11
            else:
                # Easy mode reacts to the scripted player's own HP only.
                if usable(12) and me.now_HP < 0.4 * me.max_HP:
                    action = 12
                if usable(11) and me.now_HP < 0.5 * me.max_HP:
                    action = 11

            if me.skills[action].distance < self.distance:  # out of range: move closer
                action = 14

        elif isinstance(me, Shooter):
            action = 0  # default: physical attack
            if usable(3):
                action = 3
            if isinstance(foe, Shooter):
                if usable(1):
                    mark = foe.buffs[0].now_time
                    if mark <= 0 or (mark > 0 and foe.buffs[3].now_time > 0):
                        action = 1
            # Later entries override earlier ones.
            for candidate in (8, 12, 14, 13):
                if usable(candidate):
                    action = candidate
            if me.skills[action].distance < self.distance:  # out of range: move closer
                action = 16
        # if self.config.school == 'FangShi':
        #     action = 6
        #     if self_player.skills[16].now_cd <= 0:
        #         action = 16
        else:
            raise ValueError(type(me))
        return action

    def translate_action_jiashi(self, input_action, atker):
        """Map a coarse Tank action (0-7) to a concrete skill id.

        Fixed mappings: 7 -> 3, 6 -> 12, 5 -> 11, 2 -> 6, 4 -> 4.
        Conditional mappings:
            0 -> the first of skills 5, 1 that ``atker`` carries, else 0.
            1 -> 10 if ``atker`` carries it and it is off cooldown, else 9.
            3 -> 7 if ``atker`` carries it and it is off cooldown, else 8.

        Args:
            input_action: Coarse action index.
            atker: The acting player; its loadout decides the conditional
                mappings.

        Returns:
            The skill id, or ``None`` when ``input_action`` is not one of
            0-7.
        """
        for coarse, skill_id in ((7, 3), (6, 12), (5, 11), (2, 6), (4, 4)):
            if input_action == coarse:
                return skill_id
        if input_action == 0:
            # Look at the acting player's loadout. This used to read
            # self.players[0] unconditionally, so team 1 was handed skills
            # based on what team 0 carried.
            for skill_id in (5, 1):
                if atker.skills[skill_id].carry:
                    return skill_id
            return 0
        if input_action == 1:
            preferred = atker.skills[10]
            return 10 if preferred.carry and preferred.now_cd <= 0 else 9
        if input_action == 3:
            preferred = atker.skills[7]
            return 7 if preferred.carry and preferred.now_cd <= 0 else 8
        return None

    def translate_action_sheshou(self, input_action, atker, defer):
        """Map a coarse Shooter action (0-4) to a concrete skill id.

        Actions 1-4 map directly to skills 8, 12, 13 and 14. Action 0 walks a
        priority list and returns the first skill whose condition holds,
        falling back to 0.

        Args:
            input_action: Coarse action index.
            atker: The acting player (loadout, cooldowns and HP are read).
            defer: The opposing player (its buff timers are read).

        Returns:
            The skill id, or ``None`` when ``input_action`` is not one of
            0-4.
        """
        for coarse, skill_id in ((1, 8), (2, 12), (3, 13), (4, 14)):
            if input_action == coarse:
                return skill_id

        if input_action == 0:
            def ready(skill_id):
                # Carried and off cooldown.
                skill = atker.skills[skill_id]
                return skill.carry and skill.now_cd <= 0

            def expired(buff_id):
                # The given buff is not currently running on the opponent.
                return defer.buffs[buff_id].now_time <= 0

            # Skills 1, 5 and 3 are first tried only while the same-numbered
            # buff is absent from the opponent.
            for skill_id, buff_id in ((1, 0), (5, 5), (3, 3)):
                if ready(skill_id) and expired(buff_id):
                    return skill_id
            # Skill 6 additionally requires buff 5 to be running.
            if ready(6) and expired(6) and defer.buffs[5].now_time:
                return 6
            # Skill 10 is only used below half HP.
            if ready(10) and atker.now_HP / atker.max_HP < 0.5:
                return 10
            # Fallbacks that ignore the opponent's buffs.
            for skill_id in (3, 1):
                if ready(skill_id):
                    return skill_id
            return 0
