import copy

import numpy as np

# from env.qnyh_small.buff import FangShiBuffs
from env.qnyh_small.buff import TankBuffs
from env.qnyh_small.buff import ShooterBuffs
# from env.qnyh_small.skill import FangShiSkills
from env.qnyh_small.skill import TankSkills
from env.qnyh_small.skill import ShooterSkills


class Role(object):
    """Base class for a combatant in the small qnyh duel environment.

    Holds the shared combat stats (HP, attack, defense, crit, parry), a 2-D
    position, and the per-role skill and buff lists. Subclasses implement
    ``refresh_status``, ``do_skill`` and ``do_move``; ``step`` dispatches an
    integer action to one of them.
    """

    def __init__(
            self,
            max_hp,
            initial_atk,
            initial_def,
            position,
            carry_skill,
            skills,
            buffs,
            crit,
            crit_rate,
            parry,
            parry_rate,
            print_game_log=False,
    ):
        """Store the role's base stats.

        Args:
            max_hp: maximum (and starting) HP.
            initial_atk: base attack; the subclasses in this file pass a
                ``[low, high]`` pair and draw uniformly between the two.
            initial_def: base defense.
            position: mutable ``[x, y]`` position, updated in place by ``move``.
            carry_skill: indices of the skills this role is allowed to use.
            skills: list of skill objects indexed by action id.
            buffs: list of buff objects indexed by buff id.
            crit: crit damage multiplier.
            crit_rate: probability of a crit.
            parry: flat damage subtracted on a successful parry.
            parry_rate: probability of a parry.
            print_game_log: print a textual log of what happens.
        """
        self.max_HP = max_hp
        self.now_HP = max_hp
        self.initial_ATK = initial_atk
        self.now_ATK = initial_atk
        self.initial_DEF = initial_def
        self.now_DEF = initial_def
        self.position = position
        self.skills = skills
        self.buffs = buffs
        self.carry_skill = carry_skill
        self.crit = crit
        self.crit_rate = crit_rate
        self.parry = parry
        self.parry_rate = parry_rate
        self.reward = 0
        self.print_game_log = print_game_log

    def move(self, x, y, d):
        """Move ``d`` units along the direction ``(x, y)`` and clamp to the arena.

        The arena is the rectangle ``|x| <= 4``, ``|y| <= 6``. A negative ``d``
        moves in the opposite direction. A zero direction vector (e.g. both
        roles standing on the same spot) leaves the position unchanged instead
        of producing NaN coordinates.

        Returns:
            True if the position had to be clamped to a boundary, else False.
        """
        out = False
        norm = np.hypot(x, y)
        # Without this guard, x * d / 0 yields NaN, and NaN never fails the
        # bounds checks below, so the role would be stuck at NaN for good.
        if norm > 0:
            self.position[0] += x * d / norm
            self.position[1] += y * d / norm
        if np.abs(self.position[0]) > 4:
            out = True
            self.position[0] = 4 * np.sign(self.position[0])
        if np.abs(self.position[1]) > 6:
            out = True
            self.position[1] = 6 * np.sign(self.position[1])
        if out and self.print_game_log:
            print('out of bounds')
        return out

    def get_damage(self, out_atk, damage, mul_hurt, reduce):
        """Apply an incoming hit to this role and return the HP actually lost.

        Args:
            out_atk: the attacker's outgoing attack (callers pass 0.0 on a miss).
            damage: the skill's damage coefficient.
            mul_hurt: additive damage bonus; the multiplier used is
                ``max(0.25, 1 + mul_hurt)``.
            reduce: multiplicative damage-reduction factor.

        Returns:
            The damage dealt, capped at the remaining HP (HP never goes below 0).
        """
        # NOTE(review): crit / crit_rate are read from ``self``, i.e. the role
        # being hit (callers do ``enemy.get_damage(...)``). If crits are meant
        # to be the attacker's stat, the caller should supply them — confirm.
        if np.random.uniform() < self.crit_rate:
            crit = self.crit
        else:
            crit = 1
        if np.random.uniform() < self.parry_rate:
            parry = self.parry
        else:
            parry = 0
        # At least 2% of the outgoing attack goes through before parry.
        hurt = max(out_atk * 0.02, ((out_atk - self.now_DEF) * crit) * 0.15 * max(0.25, 1 + mul_hurt)) - parry
        real_damage = max(hurt * damage, 0) * reduce
        if real_damage < self.now_HP:
            self.now_HP -= real_damage
        else:
            real_damage = self.now_HP
            self.now_HP = 0
        return real_damage

    def use_skill(self, action):
        """Return True if ``action`` is a skill index rather than a movement action."""
        if isinstance(self, Tank):
            return action <= 12 or action >= 16
        elif isinstance(self, Shooter):
            return action <= 14 or action >= 18
        return False

    def is_not_dizzy(self, action):
        """Return True if this role may act this turn (not stunned).

        Tank: stunned while any of buffs 9, 10, 25 is active.
        Shooter: actions 18-24 are always allowed; otherwise stunned while any
        of buffs 7, 18, 19 is active.
        """
        if isinstance(self, Tank):
            return self.buffs[9].now_time <= 0 and \
                   self.buffs[10].now_time <= 0 and \
                   self.buffs[25].now_time <= 0
        elif isinstance(self, Shooter):
            return 18 <= action <= 24 or (self.buffs[7].now_time <= 0 and
                                          self.buffs[18].now_time <= 0 and
                                          self.buffs[19].now_time <= 0)
        else:
            return True

    def buff_is_dizzy(self, buff):
        """Return True if buff id ``buff`` is one of this role's stun buffs."""
        if isinstance(self, Tank):
            return buff in [9, 10, 25]
        elif isinstance(self, Shooter):
            return buff in [7, 18, 19]
        else:
            return False

    def is_not_immune(self):
        """Return True unless this is a Tank with buff 4 active."""
        if isinstance(self, Tank):
            return self.buffs[4].now_time <= 0
        else:
            return True

    def is_silent(self, action):
        """Return True if ``action`` is blocked by a silence buff.

        Tank: buff 23 blocks actions <= 12.
        Shooter: buff 10 blocks actions <= 14 or >= 25.
        """
        if isinstance(self, Tank):
            return self.buffs[23].now_time > 0 and action <= 12
        elif isinstance(self, Shooter):
            return self.buffs[10].now_time > 0 and (action <= 14 or action >= 25)
        else:
            return False

    def refresh_status(self, time_buff1, time_buff2, enemy):
        """Tick cooldowns/buffs and recompute stats; implemented by subclasses."""
        raise NotImplementedError

    def step(self, action, enemy, distance):
        """Execute one action for this role.

        A silenced skill action is replaced by skill 0. A skill is only cast
        when it is off cooldown, carried, and ``distance`` is within its range.

        Returns:
            ``(reward, punish)`` where ``punish`` is True when a skill action
            failed (in cooldown, not carried, or out of range).
        """
        self.reward = 0
        if self.is_not_dizzy(action):  # not stunned
            if self.use_skill(action):  # a skill was chosen
                if self.is_silent(action):  # silenced: fall back to skill 0
                    action = 0
                skill = self.skills[action]
                if skill.now_cd <= 0 and skill.carry:  # off cooldown and carried
                    if skill.distance >= distance:  # within range
                        self.do_skill(action, skill, enemy, distance)
                        return self.reward, False
                    else:
                        if self.print_game_log:
                            print(f'\tSkill {skill.name} distance not enough')
                        return self.reward, True
                else:
                    if self.print_game_log:
                        if skill.now_cd <= 0:
                            print(f'\tSkill {skill.name} is not carried')
                        else:
                            print(f'\tSkill {skill.name} in cd')
                    return self.reward, True
            else:
                self.do_move(action, enemy, distance)
                return self.reward, False
        else:
            if self.print_game_log:
                print('\tBeing dizzy...')
            return self.reward, False

    def do_skill(self, action, skill, enemy, distance):
        """Resolve a skill cast; implemented by subclasses."""
        raise NotImplementedError

    def do_move(self, action, enemy, distance):
        """Resolve a movement/idle action; implemented by subclasses."""
        raise NotImplementedError

    def do_buff(self, buff_id, self_buff, success, enemy):
        """Apply the buffs listed in ``buff_id``.

        Args:
            buff_id: buff indices to apply.
            self_buff: parallel list of flags; True puts the buff on this
                role, False puts it on ``enemy``.
            success: whether the skill landed; enemy buffs are skipped on a miss.
            enemy: the opposing role; it is skipped entirely while it is immune.
        """
        for bid, to_self in zip(buff_id, self_buff):
            if to_self:
                # Self-applied buffs get one more tick than buffs put on the enemy.
                self.buffs[bid].now_time = self.buffs[bid].initial_time + 1
                if self.print_game_log:
                    print(f'\tAdd buff to self: {self.buffs[bid].name}')
                continue

            if not (success and enemy.is_not_immune()):
                continue
            # Stun-type buffs only land with probability ``buff.value``.
            if enemy.buff_is_dizzy(bid) and not np.random.uniform() < enemy.buffs[bid].value:
                if self.print_game_log:
                    print('\tAdd dizziness failed!')
                continue
            enemy.buffs[bid].now_time = enemy.buffs[bid].initial_time
            if self.print_game_log:
                print(f'\tAdd buff to opponent: {enemy.buffs[bid].name}')

class Tank(Role):
    """Tank role with fixed stats, charge/pull skills and debuff cleansing.

    Action layout (see ``Role.use_skill``): actions ``<= 12`` and ``>= 16``
    are skills; 13 = move away, 14 = move closer, 15 = idle.
    """

    def __init__(self, carry_skill, position, print_game_log):
        """Create a tank; ``carry_skill`` lists the skill indices it may use."""
        super(Tank, self).__init__(
            max_hp=107951.0,
            initial_atk=[36943.0, 36943.0],
            initial_def=31038.0,
            crit=2.02,
            crit_rate=0.33,
            parry=1626,
            parry_rate=0.4,
            position=position,
            carry_skill=carry_skill,
            skills=TankSkills().get_skills(),
            buffs=TankBuffs().get_buffs(),
            print_game_log=print_game_log,
        )
        for i in self.carry_skill:
            self.skills[i].carry = True

    def refresh_status(self, time_buff1, time_buff2, enemy):
        """Tick cooldowns and buffs, then recompute current attack and defense.

        Args:
            time_buff1: stored in ``buffs[15].value``; ``do_skill`` passes it to
                ``get_damage`` as the additive damage bonus.
            time_buff2: stored in ``buffs[16].value``; scales attack by
                ``1 + value`` while buff 16 is active.
            enemy: unused here; kept for interface parity with ``Shooter``.
        """
        for skill in self.skills:
            skill.refresh()
        for buff in self.buffs:
            buff.refresh()

        self.buffs[15].value = time_buff1  # damage calculation
        self.buffs[16].value = time_buff2  # attack calculation
        self.now_DEF = self.initial_DEF
        self.now_ATK = copy.deepcopy(self.initial_ATK)

        if self.buffs[5].now_time > 0:  # defense up
            self.now_DEF += self.buffs[5].value
        if self.buffs[14].now_time > 0:  # defense down
            self.now_DEF -= self.buffs[14].value
        if self.buffs[1].now_time > 0:
            # Trade a fraction of defense for attack: 65% of the lost defense
            # is added to both attack bounds.
            po_jun = self.now_DEF * self.buffs[1].value
            self.now_DEF -= po_jun
            self.now_ATK[0] += 0.65 * po_jun
            self.now_ATK[1] += 0.65 * po_jun

        # Flat attack bonuses first, then the percentage bonus from buff 16.
        for i in [0, 13]:
            if self.buffs[i].now_time > 0:
                self.now_ATK[0] += self.buffs[i].value
                self.now_ATK[1] += self.buffs[i].value
        if self.buffs[16].now_time > 0:
            self.now_ATK[0] *= 1 + self.buffs[16].value
            self.now_ATK[1] *= 1 + self.buffs[16].value

        # Percentage attack reductions.
        for i in [3, 7, 11]:
            if self.buffs[i].now_time > 0:
                self.now_ATK[0] *= 1 - self.buffs[i].value
                self.now_ATK[1] *= 1 - self.buffs[i].value

    def do_skill(self, action, skill, enemy, distance):
        """Resolve skill ``action`` against ``enemy``.

        Handles the movement side effects (charge / pull), cleanses, cooldown,
        damage and reward, then applies buffs. Against a ``Shooter`` several
        skills apply Shooter-specific buff indices instead of
        ``skill.buff_id``.
        """
        if self.print_game_log:
            print(f'\tTank released skill: {skill.name}')
        # 96.7% chance that the effect on the enemy lands.
        if np.random.uniform() < 0.967:
            success = True
        else:
            success = False
            if self.print_game_log:
                print('\tMissed!')

        if action == 9:  # charge: move (distance - 1) towards the enemy
            x = enemy.position[0] - self.position[0]
            y = enemy.position[1] - self.position[1]
            d = distance - 1
            self.move(x, y, d)

        if action == 10 and success and enemy.is_not_immune():  # pull the enemy (distance - 1) towards us
            x = self.position[0] - enemy.position[0]
            y = self.position[1] - enemy.position[1]
            d = distance - 1
            enemy.move(x, y, d)

        # Cleanses: action 7 removes up to 3 debuffs, action 8 up to 9.
        # TODO (left by the original author without details).
        if action == 7:
            self.debuff_clean(3)
        if action == 8:
            self.debuff_clean(9)

        skill.now_cd = skill.initial_cd

        # The attack roll is drawn even on a miss so the RNG stream stays aligned.
        atk = np.random.uniform(self.now_ATK[0], self.now_ATK[1])
        out_atk = skill.e1 * atk + skill.e2 if success else 0.0
        reduce = 1

        # buff: immune or damage reduction
        if isinstance(enemy, Tank):
            for i in [4, 6, 8]:
                if enemy.buffs[i].now_time > 0:
                    reduce *= 1 - enemy.buffs[i].value

        temp_damage = enemy.get_damage(out_atk, skill.damage, self.buffs[15].value, reduce)
        if self.print_game_log:
            print(f'\tDamage: {round(temp_damage, 4)}, enemy hp remained: {enemy.now_HP}')

        self.reward += temp_damage / self.max_HP * 50.0

        buff_id = skill.buff_id
        self_buff = [self.buffs[bid].myself for bid in buff_id]
        if isinstance(enemy, Shooter):
            if action == 17:
                buff_id = [15]
                self_buff = [False]
            elif action == 4:
                buff_id = [16]
                self_buff = [False]
            elif action == 6:
                buff_id = [17]
                self_buff = [False]
            elif action == 9:
                buff_id = [0, 18]
                self_buff = [True, False]
            elif action == 10:
                buff_id = [19]
                self_buff = [False]
            elif action == 11:
                buff_id = [20]
                self_buff = [False]

        self.do_buff(buff_id, self_buff, success, enemy)

    def do_move(self, action, enemy, distance):
        """Handle non-skill actions: 13 = move away, 14 = move closer, 15 = idle."""
        if action == 13:  # move away
            if self.print_game_log:
                print('\tAway')
            x = self.position[0] - enemy.position[0]
            y = self.position[1] - enemy.position[1]
            self.move(x, y, 5)
        elif action == 14:  # move closer, at most 5 units
            if self.print_game_log:
                print('\tCloser')
            x = enemy.position[0] - self.position[0]
            y = enemy.position[1] - self.position[1]
            d = min(distance - 1, 5)
            self.move(x, y, d)
        elif action == 15:  # idle
            if self.print_game_log:
                print('\tIdle')

    def debuff_clean(self, amount):
        """Clear up to ``amount`` distinct active debuffs, chosen at random.

        The candidates are the buff indices 2, 3, 11, 14, 17, 18, 19, 20, 21
        and 22 whose ``now_time`` is positive. Each pick is removed from the
        candidate list, so one debuff is never counted twice. At most
        ``min(amount, number of active candidates)`` debuffs are cleared.
        """
        active = [bid for bid in (2, 3, 11, 14, 17, 18, 19, 20, 21, 22)
                  if self.buffs[bid].now_time > 0]
        for _ in range(min(amount, len(active))):
            picked = active.pop(np.random.randint(0, len(active)))
            self.buffs[picked].now_time = 0


class Shooter(Role):
    """Ranged shooter role.

    Action layout (see ``Role.use_skill``): actions ``<= 14`` and ``>= 18``
    are skills; 15 = move away, 16 = move closer, 17 = idle.
    """

    def __init__(self, carry_skill, position, print_game_log):
        """Create a shooter; ``carry_skill`` lists the skill indices it may use."""
        super().__init__(
            max_hp=66371.0,
            initial_atk=[24794.0, 46706.0],
            initial_def=21503.0,
            crit=2.25,
            crit_rate=0.45,
            parry=0,
            parry_rate=0,
            position=position,
            carry_skill=carry_skill,
            skills=ShooterSkills().get_skills(),
            buffs=ShooterBuffs().get_buffs(),
            print_game_log=print_game_log,
        )
        for idx in self.carry_skill:
            self.skills[idx].carry = True

    def refresh_status(self, time_buff1, time_buff2, enemy):
        """Tick cooldowns and buffs, then recompute attack and skill damage.

        Args:
            time_buff1: stored in ``buffs[13].value`` (passed to ``get_damage``
                as the additive damage bonus).
            time_buff2: stored in ``buffs[14].value``; scales attack by
                ``1 + value`` while buff 14 is active.
            enemy: the opponent; its buffs 0/5 and its HP ratio feed into this
                role's skill damage and buff 7.
        """
        # Skills are refreshed before buffs, as in the rest of the file.
        for ticking in (*self.skills, *self.buffs):
            ticking.refresh()

        dmg_bonus, atk_bonus = self.buffs[13], self.buffs[14]
        dmg_bonus.value = time_buff1
        atk_bonus.value = time_buff2
        self.now_ATK = copy.deepcopy(self.initial_ATK)
        if atk_bonus.now_time > 0:
            scale = 1 + atk_bonus.value
            self.now_ATK[0] *= scale
            self.now_ATK[1] *= scale

        # Reset the dynamically scaled skills to their base damage.
        for idx in (2, 3, 6, 10, 19, 20, 23):
            self.skills[idx].damage = self.skills[idx].initial_damage

        # Against another shooter, active buffs 0 and 5 on the enemy boost
        # these skills (applied in order: buff 0 first, then buff 5).
        if isinstance(enemy, Shooter):
            for marker in (0, 5):
                if enemy.buffs[marker].now_time <= 0:
                    continue
                factor = 1 + self.buffs[marker].value
                for idx in (2, 3, 6, 19, 20, 23):
                    self.skills[idx].damage *= factor

        # Skill 10 hits harder the lower this role's own HP is.
        self.skills[10].damage *= 2 - self.now_HP / self.max_HP
        # Buff 7's value grows as the enemy loses HP (up to +0.48).
        scaled_buff = self.buffs[7]
        scaled_buff.value = scaled_buff.initial_value
        scaled_buff.value += 0.48 * (1 - enemy.now_HP / enemy.max_HP)

    def do_skill(self, action, skill, enemy, distance):
        """Resolve skill ``action`` against ``enemy``.

        Special cases:
          * action 1, while buff 1 is inactive, also force-casts skill 5 via
            ``step`` and then starts buff 1;
          * action 8 knocks the enemy back 4 units when ``distance < 4``;
          * buff 12 on this role (against a Shooter enemy) lengthens the
            cooldown.
        Against a ``Tank`` some skills apply Tank-specific buff indices.
        """
        if self.print_game_log:
            print(f'\tShooter released skill: {skill.name}')
        # 96.7% chance that the effect on the enemy lands.
        success = np.random.uniform() < 0.967
        if not success and self.print_game_log:
            print('\tMissed!')

        if action == 1 and self.buffs[1].now_time <= 0:
            follow_up = self.skills[5]
            was_carried = follow_up.carry
            follow_up.carry = True
            self.step(5, enemy, distance)
            follow_up.carry = was_carried
            self.buffs[1].now_time = self.buffs[1].initial_time

        if action == 8 and success and distance < 4 and enemy.is_not_immune():
            # Knock the enemy back along the shooter->enemy direction.
            enemy.move(enemy.position[0] - self.position[0],
                       enemy.position[1] - self.position[1],
                       4.0)

        slowed = isinstance(enemy, Shooter) and self.buffs[12].now_time > 0
        skill.now_cd = (int(np.ceil(skill.initial_cd / (1 - enemy.buffs[12].value)))
                        if slowed else skill.initial_cd)

        # The attack roll is drawn unconditionally, even on a miss.
        roll = np.random.uniform(self.now_ATK[0], self.now_ATK[1])
        if success:
            out_atk = skill.e1 * roll + skill.e2
        else:
            out_atk = 0.0

        # Damage reduction from the enemy tank's buffs 4, 6 and 8.
        reduce = 1
        if isinstance(enemy, Tank):
            for guard in (4, 6, 8):
                if enemy.buffs[guard].now_time > 0:
                    reduce *= 1 - enemy.buffs[guard].value

        dealt = enemy.get_damage(out_atk, skill.damage, self.buffs[13].value, reduce)
        if self.print_game_log:
            print(f'\tDamage: {round(dealt, 4)}, hp remained: {enemy.now_HP}')
        self.reward += dealt / self.max_HP * 50.0

        buff_id = skill.buff_id
        self_buff = [self.buffs[bid].myself for bid in buff_id]
        if isinstance(enemy, Tank):
            tank_debuff = {26: 24, 27: 23, 8: 25}.get(action)
            if tank_debuff is not None:
                buff_id, self_buff = [tank_debuff], [False]

        self.do_buff(buff_id, self_buff, success, enemy)

    def do_move(self, action, enemy, distance):
        """Handle non-skill actions: 15 = move away, 16 = move closer, 17 = idle."""
        if action == 15:
            if self.print_game_log:
                print('\tAway')
            self.move(self.position[0] - enemy.position[0],
                      self.position[1] - enemy.position[1],
                      5)
        elif action == 16:
            if self.print_game_log:
                print('\tClose')
            self.move(enemy.position[0] - self.position[0],
                      enemy.position[1] - self.position[1],
                      min(distance - 1, 5))
        elif action == 17 and self.print_game_log:
            print('\tIdle')

# class FangShi(Role):
#     def __init__(self, carry_skill, position):
#         super(FangShi, self).__init__(
#             max_hp=83899.0,
#             initial_atk=[17206.0, 39464.0],
#             initial_def=24042.0,
#             crit=1.98,
#             crit_rate=0.46,
#             parry=0,
#             parry_rate=0,
#             position=position,
#             carry_skill=carry_skill,
#             skills=FangShiSkills().get_skills(),
#             buffs=FangShiBuffs().get_buffs(),
#         )
#         for i in self.carry_skill:
#             self.skills[i].carry = True
#
#     def refresh_status(self, time_buff1, time_buff2, enemy):
#         for skill in self.skills:
#             skill.refresh()
#         for i, buff in enumerate(self.buffs):
#             buff.refresh()
#             if i in [2, 4, 6, 8] and buff.now_time <= 0:
#                 buff.value = buff.initial_value
#         self.buffs[18].value = time_buff1
#         self.buffs[19].value = time_buff2
#         self.now_DEF = self.initial_DEF
#         if self.buffs[12].now_time > 0:
#             self.now_DEF *= 1 + self.buffs[12].value
#         self.now_ATK = copy.deepcopy(self.initial_ATK)
#         if self.buffs[14].now_time > 0:
#             self.now_ATK[0] *= 1 + self.buffs[14].value
#             self.now_ATK[1] *= 1 + self.buffs[14].value
#         if self.buffs[13].now_time > 0:
#             self.now_ATK[0] += 4050.0
#             self.now_ATK[1] += 4050.0
#         self.skills[21].damage = 0.68 * (1 + 0.3 * self.buffs[2].value)
#         self.skills[22].damage = 0.68 * (0.7 + 0.07 * self.buffs[4].value)
#         self.skills[23].e1 = 1.05 + 0.12 * self.buffs[6].value
#         self.skills[24].e1 = 0.8 + 0.08 * self.buffs[8].value
#         if self.now_HP / self.max_HP < 0.4:
#             self.buffs[11].value = 0.55
#         else:
#             self.buffs[11].value = 0
#
#     def do_action(self, action, enemy, distance):
#         self.reward = 0
#         if action <= 16 or action >= 20:  # 使用了技能
#             skill = self.skills[action]
#             if skill.now_cd <= 0 and skill.carry:  # 技能CD好了 且 这个技能被选了
#                 if skill.distance >= distance:  # 距离够
#                     if self.print_game_log:
#                         print("使用了：", skill.name)
#                     if np.random.uniform() < 0.967:  # 对敌效果是否命中
#                         success = True
#                     else:
#                         success = False
#                         if self.print_game_log:
#                             print("失手了")
#                     if action == 13 and self.now_HP / self.max_HP < 0.4:
#                         self.now_HP += 0.2 * self.max_HP
#                     if self.buffs[17].now_time <= 0:  # 是否减攻速
#                         skill.now_cd = skill.initial_cd
#                     else:
#                         skill.now_cd = int(
#                             np.ceil(skill.initial_cd / (1 - enemy.buffs[17].value))
#                         )
#                     atk = np.random.uniform(self.now_ATK[0], self.now_ATK[1])
#                     out_atk = skill.e1 * atk + skill.e2
#                     buff_id = skill.buff_id
#                     if not success:
#                         out_atk = 0.0
#                     reduce = 1
#                     for i in [11, 14]:
#                         if enemy.buffs[i].now_time > 0:
#                             reduce *= 1 - enemy.buffs[i].value
#                     temp_damage = enemy.get_damage(
#                         out_atk, skill.damage, self.buffs[18].value, reduce
#                     )
#                     if self.print_game_log:
#                         print("造成了伤害", round(temp_damage, 4), "，剩余血量", enemy.now_HP)
#                     self.reward += temp_damage / self.max_HP * 50.0
#                     if len(buff_id) > 0:
#                         for i in range(len(buff_id)):
#                             if self.buffs[buff_id[i]].myself:
#                                 self.buffs[buff_id[i]].now_time = (
#                                     self.buffs[buff_id[i]].initial_time + 1
#                                 )
#                                 if self.print_game_log:
#                                     print("给自己上了buff：", self.buffs[buff_id[i]].name)
#                             else:
#                                 if success:
#                                     enemy.buffs[buff_id[i]].now_time = self.buffs[
#                                         buff_id[i]
#                                     ].initial_time
#                                     if self.print_game_log:
#                                         print("给对方上了buff：", self.buffs[buff_id[i]].name)
#                                     if buff_id[i] in [2, 6]:
#                                         if self.print_game_log:
#                                             print(
#                                                 "目前等级：", enemy.buffs[buff_id[i]].value
#                                             )
#                                         enemy.buffs[buff_id[i]].value = min(
#                                             enemy.buffs[buff_id[i]].value + 1, 3
#                                         )
#                                     if buff_id[i] in [4, 8]:
#                                         if self.print_game_log:
#                                             print(
#                                                 "目前等级：", enemy.buffs[buff_id[i]].value
#                                             )
#                                         enemy.buffs[buff_id[i]].value = min(
#                                             enemy.buffs[buff_id[i]].value + 1, 6
#                                         )
#                     return self.reward, False
#                 else:
#                     if self.print_game_log:
#                         print("技能距离不够")
#                     return self.reward, True
#             else:
#                 if self.print_game_log:
#                     print("技能在CD 或 技能没选")
#                 return self.reward, True
#         else:
#             if action == 17:
#                 if self.print_game_log:
#                     print("远离了")
#                 x = self.position[0] - enemy.position[0]
#                 y = self.position[1] - enemy.position[1]
#                 self.move(x, y, 5)
#                 return self.reward, False
#             if action == 18:
#                 if self.print_game_log:
#                     print("靠近了")
#                 x = enemy.position[0] - self.position[0]
#                 y = enemy.position[1] - self.position[1]
#                 d = min(distance - 1, 5)
#                 self.move(x, y, d)
#                 return self.reward, False
#             if action == 19:
#                 if self.print_game_log:
#                     print("发呆了")
#                 return self.reward, False
