import cv2
import numpy as np
import platform
import time
import keyboard  # 这个库用来发送按键命令
import os
import ray
import mss
from typing import Dict

if platform.system() == "Windows":
    import pygetwindow as gw
else:
    print("pygetwindow not imported, not Windows system.")
    gw = None


def get_game_window(title):
    """Look up a desktop window (e.g. the Hollow Knight game window) by its title.

    Returns the first entry of ``gw.getWindowsWithTitle(title)``, or ``None``
    when no window matches.
    """
    matches = gw.getWindowsWithTitle(title)
    if not matches:
        return None
    return matches[0]

def capture_window_image(window, sct, border=(15, 64, 16, 16)):
    """Capture an image of the given window and return it as a NumPy array.

    Args:
        window: Window object exposing ``restore()``, ``activate()`` and a
            ``box`` with ``left``/``top``/``width``/``height``. A falsy value
            (e.g. ``None``) short-circuits the capture.
        sct: Screen grabber exposing ``grab(bbox)``, where ``bbox`` is a
            ``(x1, y1, x2, y2)`` tuple in screen coordinates.
        border: ``(left, top, right, bottom)`` pixel margins trimmed from the
            window box to drop the window frame / title bar. The default
            reproduces the previously hard-coded values.

    Returns:
        ``np.array`` of whatever ``sct.grab`` returned, or ``None`` when
        ``window`` is falsy. NOTE(review): with an ``mss`` grabber this is
        presumably a (height, width, 4) BGRA array - confirm against mss docs.
    """
    if not window:
        return None

    # The grab is a plain screen-region capture, so the window is un-minimised
    # and brought to the foreground on every call before reading its box.
    window.restore()
    window.activate()

    box = window.box
    trim_left, trim_top, trim_right, trim_bottom = border
    # Convert (left, top, width, height) to (x1, y1, x2, y2) and shrink it by
    # the border margins in a single step.
    region = (
        box.left + trim_left,
        box.top + trim_top,
        box.left + box.width - trim_right,
        box.top + box.height - trim_bottom,
    )
    return np.array(sct.grab(region))

def _iter_lines_reversed(file, chunk_size):
    """Yield the newline-terminated lines of a binary file, last line first.

    The file is read backwards in blocks of ``chunk_size`` bytes. Each yielded
    value is a ``bytes`` line including its trailing ``b'\\n'``. Text after the
    final newline (an unterminated, possibly still-being-written line) is never
    yielded.
    """
    chunk_size = max(1, chunk_size)  # a non-positive size would never advance
    file.seek(0, os.SEEK_END)
    position = file.tell()
    pending = b''          # start-less tail of a line whose beginning is in an earlier block
    skip_fragment = True   # the piece after the file's last newline is not a complete line
    while position > 0:
        step = min(chunk_size, position)
        position -= step
        file.seek(position)
        pieces = (file.read(step) + pending).split(b'\n')
        # pieces[0] may be cut mid-line by the block boundary; keep it for the
        # next (earlier) block. Every other piece starts right after a newline.
        pending = pieces[0]
        for piece in reversed(pieces[1:]):
            if skip_fragment:
                skip_fragment = False
                continue
            yield piece + b'\n'
    # Start of file reached: `pending` is the first line, complete only if at
    # least one newline was seen.
    if not skip_fragment:
        yield pending + b'\n'


def tail(filename, n=1, chunk_size=4096):
    """Return the last ``n`` complete lines of a file, newest first.

    Args:
        filename: Path of the file to read (opened in binary mode).
        n: Maximum number of lines to return.
        chunk_size: Number of bytes read per backward step.

    Returns:
        A list of decoded lines (each keeps its line terminator), ordered from
        the last line of the file backwards. Fewer than ``n`` entries are
        returned when the file has fewer complete lines; ``n <= 0`` yields an
        empty list. If the file contains no newline-terminated line at all,
        its whole content (possibly ``''``) is returned as a single entry so
        that callers indexing ``[0]`` keep their historical behaviour.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If a line is not valid UTF-8.
    """
    if n <= 0:
        return []
    with open(filename, 'rb') as file:
        lines_found = []
        for raw_line in _iter_lines_reversed(file, chunk_size):
            lines_found.append(raw_line.decode())
            if len(lines_found) == n:
                break
        if not lines_found:
            file.seek(0)
            lines_found.append(file.read().decode())
        return lines_found


def tail_until(filename, last_id, chunk_size=4096):
    """Collect the log lines whose numeric id is greater than ``last_id``.

    Lines are expected to look like ``"<id>: ..."``. The file is scanned from
    the end towards the start and the scan stops at the first line whose id is
    ``<= last_id``. An unterminated trailing line is ignored, and lines whose
    prefix before the first ``':'`` is not an integer (blank lines,
    continuation lines) are skipped.

    Args:
        filename: Path of the log file (opened in binary mode).
        last_id: Highest id that has already been consumed.
        chunk_size: Number of bytes read per backward step.

    Returns:
        ``(lines_found, new_last_id)`` where ``lines_found`` holds the decoded
        new lines, newest first, and ``new_last_id`` is the id of the newest
        one (or ``last_id`` unchanged when nothing new was found).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If a line is not valid UTF-8.
    """
    lines_found = []
    with open(filename, 'rb') as file:
        for raw_line in _iter_lines_reversed(file, chunk_size):
            line = raw_line.decode()
            try:
                current_id = int(line.split(':')[0])
            except ValueError:
                continue
            if current_id <= last_id:
                break
            lines_found.append(line)

    if lines_found:
        new_last_id = int(lines_found[0].split(':')[0])
    else:
        new_last_id = last_id
    return lines_found, new_last_id


class ActionSpace():
    """Discrete action space: an action is an index into ``action_list``."""

    def __init__(self, action_list) -> None:
        self.action_list = action_list
        # Number of available actions; sample() draws from [0, n).
        self.n = len(action_list)

    def sample(self):
        """Return a random action index drawn with ``np.random.randint(self.n)``."""
        return np.random.randint(self.n)


class IndependentActionSpace():
    """Action space made of independent on/off switches, one per listed action."""

    def __init__(self, independent_action_list) -> None:
        self.independent_action_list = independent_action_list
        self.dim = len(independent_action_list)
        # Every dimension is binary: 0 or 1.
        self.choices_per_dim = 2

    def sample(self):
        """Return a random vector of ``dim`` integers in ``[0, choices_per_dim)``."""
        return np.random.randint(0, self.choices_per_dim, self.dim)

    def explain(self, action):
        """Translate an action vector into its key names joined by ``"+"``.

        Only the entries equal to 1 contribute; an all-zero action gives ``""``.
        """
        pressed = [
            self.independent_action_list[index]
            for index in range(self.dim)
            if action[index] == 1
        ]
        return "+".join(pressed)

    def place_holder_action(self):
        """Return the all-zero action as an ``int32`` array of length ``dim``."""
        return np.zeros(self.dim, dtype=np.int32)

class DeathEventRecorder():
    """Records which of a boss fight's target entities have died."""

    # Boss name -> names of the entities that must all die for the fight to count as won.
    boss_name_to_target_entities_dict = {
        "GodTamer": ["Lobster"],
        "HornetProtector": ["Hornet Boss 1"],
        "MegaMossCharger": ["Mega Moss Charger"],
        "MantisLords": ["Mantis Lord", "Mantis Lord S1", "Mantis Lord S2"],
        "MageLord": ["Mage Lord", "Mage Lord Phase2"],
        "Mawlek": ["Mawlek Body"],
        "HKPrime": ["HK Prime"]
    }

    def __init__(self, boss_name):
        self.boss_name = boss_name
        # Stays None until reset() is called; update() needs reset() first.
        self.target_entities_dict: Dict[str, bool] = None

    def update(self, target_name) -> bool:
        """Mark ``target_name`` as dead.

        Returns True once every tracked entity is marked dead, i.e. the
        episode should terminate.
        """
        status = self.target_entities_dict
        status[target_name] = True
        return all(status.values())

    def reset(self):
        """Start a new fight: every target entity of the boss is alive again."""
        targets = DeathEventRecorder.boss_name_to_target_entities_dict[self.boss_name]
        self.target_entities_dict = dict.fromkeys(targets, False)


class EnemyHealthTracker():
    """De-duplicates enemy health-change events.

    Hollow Knight sometimes creates two identical health managers for the same
    object (e.g. in MantisLords), which would double the hitting reward.
    Keeping this bookkeeping in the C# mod takes extra effort, so the last
    known health of every enemy is tracked here instead.
    """

    def __init__(self):
        # Enemy name -> last reported health; created by reset().
        self.health_dict = None

    def update(self, target_name, current_health, delta_health):
        """Return the part of ``delta_health`` that should count towards the reward.

        A report whose ``current_health`` equals the value already stored for
        ``target_name`` is treated as a duplicate and contributes 0. Any other
        report is stored and its ``delta_health`` is returned unchanged.

        TODO: Currently can't handle spawning enemies like Fluke Mother
        """
        known = self.health_dict
        if target_name in known and known[target_name] == current_health:
            return 0
        known[target_name] = current_health
        return delta_health

    def reset(self):
        """Forget every tracked enemy."""
        self.health_dict = {}


@ray.remote(resources={"env_runner": 1})
class HKEnv():
    '''
    Hollow Knight Gym-like Environment
    Default observation space: (711, 1275, 3)

    Declared as a Ray actor that requests one unit of the custom "env_runner"
    resource. The environment acts on the game by sending keyboard events,
    observes it through screenshots of the "Hollow Knight" window, and derives
    reward / termination from new lines found in ``custom_log.log``
    (presumably written by a game mod - the writer is not visible here).
    '''
    def __init__(self, boss_name, obs_size=None, color_convert=True, target_fps=12) -> None:
        '''
        Locate the game window and log file and set up per-boss configuration.

        boss_name: must be a key of DeathEventRecorder.boss_name_to_target_entities_dict
        obs_size: if not None, passed to cv2.resize as the target size of each frame
        color_convert: if True, frames go through cv2.COLOR_BGR2RGB
        target_fps: step() waits so that consecutive steps are at least
                    1 / target_fps seconds apart

        NOTE(review): all validation here uses `assert`, which is stripped
        when Python runs with -O.
        '''
        supprted_bosses = DeathEventRecorder.boss_name_to_target_entities_dict.keys()
        assert boss_name in supprted_bosses, f"Boss name not supported: {boss_name}"
        self.boss_name = boss_name

        self.death_event_recorder = DeathEventRecorder(boss_name)
        self.enemy_health_tracker = EnemyHealthTracker()

        # base attack is related to charm configuration, see:
        # https://hollowknight.fandom.com/wiki/Damage_Values_and_Enemy_Health_(Hollow_Knight)
        # TODO: automate this
        self.attack_normalize_factor = 32 
        
        # base damage to Knight is related to boss difficulty, level 1/2/3 -> 1/2/9999
        # TODO: some boss's base damage is 2
        # TODO: overchage would double the damage
        if boss_name == "HKPrime":
            self.damage_normalize_factor = 2
        else:
            self.damage_normalize_factor = 1

        # base healing is related to charm configuration
        self.healing_normalize_factor = 1

        self.health = None # initialized in reset()

        # action space: one independent on/off switch per key
        self.independent_action_list = ["w", "a", "s", "d", "j", "k", "l", "i"]
        self.action_space = IndependentActionSpace(self.independent_action_list)

        # if the last episode is a win battle
        self.win_last_battle = False

        # game window
        self.game_window = get_game_window("Hollow Knight")
        if not self.game_window:
            assert False, f"Hollow Knight not found."
        self.sct = mss.mss()
        
        # log file
        # NOTE(review): os.getenv('APPDATA') returns None when the variable is
        # unset, which makes .replace() raise AttributeError.
        self.log_file_path = os.getenv('APPDATA').replace('Roaming', 'LocalLow') + '\\Team Cherry\\Hollow Knight' + '\\custom_log.log'
        assert os.path.exists(self.log_file_path), f"Log file not found: {self.log_file_path}"

        # env config
        self.obs_size = obs_size
        self.color_convert = color_convert
        self.target_time_per_frame = 1.0 / target_fps

        # cache the last action, judge if the action is the same as last step
        # used in send_action()
        self.last_action = self.action_space.place_holder_action()
    
    def send_action(self, action, last_action):
        '''
        Press / release only the keys whose state differs from last_action.

        Returns the accumulated penalty: 0.01 for every release of "j" or "i".
        '''
        action_penalty = 0
        for i in range(len(action)):
            if action[i] != last_action[i]:
                if action[i] == 1:
                    keyboard.press(self.independent_action_list[i]) # 0 -> 1
                else:
                    keyboard.release(self.independent_action_list[i]) # 1 -> 0
                    if self.independent_action_list[i] == "j" or self.independent_action_list[i] == "i":
                        action_penalty += 0.01 # small penalty for attack/spell
        return action_penalty

    def release_action(self):
        '''Release every key of the action space, whatever its current state.'''
        for i in range(len(self.independent_action_list)):
            keyboard.release(self.independent_action_list[i])

    def get_last_log_step(self):
        '''Return the integer id (text before the first ':') of the line returned by tail(log, 1).'''
        last_lines = tail(self.log_file_path, 1)
        return int(last_lines[0].split(':')[0])

    def read_last_lines(self):
        '''
        Return the log lines newer than self.last_log_step and advance
        self.last_log_step to the id reported by tail_until().
        '''
        last_lines, last_id = tail_until(self.log_file_path, self.last_log_step)
        self.last_log_step = last_id
        return last_lines

    def reset(self, first_run=False, wait=0): # wait should be 6, coverd by training time, should be specified in training script
        '''
        Restart the fight and return (frame, info).

        first_run: when False, an extra wait + "w" key press is performed first
        wait: seconds to sleep before that extra key press

        The key sequence ("w", "w", "k") presumably navigates the in-game
        menus back into the boss fight - confirm against the game setup. The
        fixed sleeps make this call block for several seconds.
        '''
        # Game interaction part >>>
        self.release_action()
        if not first_run:
            time.sleep(wait) # wait for game to load
            self.game_window.restore()
            self.game_window.activate()
            keyboard.press_and_release("w")
            time.sleep(3)
        self.game_window.restore()
        self.game_window.activate()
        keyboard.press_and_release("w")
        time.sleep(1)
        self.game_window.restore()
        self.game_window.activate()
        keyboard.press_and_release("k")
        time.sleep(5) # to the main scene
        if self.boss_name == "HKPrime":
            time.sleep(10.7) # extra 98 frames for HKPrime
            # TODO: parametrize this

        frame = capture_window_image(self.game_window, self.sct)
        # <<< Game interaction part

        # Environment part >>>
        # Only log lines written after this point are considered by step().
        self.last_log_step = self.get_last_log_step()
        # Hard-coded starting health; presumably the Knight's mask count for
        # the expected save file - TODO confirm.
        self.health = 9
        # NOTE(review): attribute name is misspelled ("epsisode"); it is used
        # consistently within this class.
        self.epsisode_step = 0
        print(f"Resetting from LogID: {self.last_log_step}")

        if self.obs_size is not None:
            frame = cv2.resize(frame, self.obs_size)
        if self.color_convert:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        self.frame_timer = time.time()

        # reset last action cache
        self.last_action = self.action_space.place_holder_action()

        self.death_event_recorder.reset()
        self.enemy_health_tracker.reset()
        # <<< Environment part
            
        return frame,  {"episode_frame_number": self.epsisode_step}
    
    def precise_sleep(self, duration): # accurate sleep, use time.sleep in Ray is not accurate
        '''Busy-wait on time.perf_counter() until `duration` seconds have elapsed.'''
        start_time = time.perf_counter()
        while (time.perf_counter() - start_time) < duration:
            pass
        
    def step(self, action):
        '''
        Advance the environment by one frame.

        action: per-key 0/1 vector, indexed like self.independent_action_list

        Returns (frame, reward, terminate, truncated, info); truncated is
        always False. info carries "episode_frame_number", "win_battle",
        "health" and "life_loss".

        NOTE(review): relies on attributes that only reset() creates
        (frame_timer, last_log_step, epsisode_step), so reset() must be called
        first.
        '''
        self.epsisode_step += 1

        # FPS control: wait out the remainder of the frame budget
        time_diff = time.time()-self.frame_timer
        if time_diff < self.target_time_per_frame:
            self.precise_sleep(self.target_time_per_frame - time_diff)

        self.frame_timer = time.time()

        frame = capture_window_image(self.game_window, self.sct)

        # send action immediately after frame capture
        # NOTE(review): action_penalty is computed but not applied to the
        # reward (see the commented-out line below).
        action_penalty = self.send_action(action, self.last_action)
        
        # read new lines in log file
        last_lines = self.read_last_lines()
        # calc reward
        # reward = -action_penalty
        reward = 0
        terminate = False
        win_battle = False
        life_loss = False
        for line in last_lines: # all the .strip() are removing "\r\n"
            if "[EnemyHealthChange]" in line: # hit enemy
                # The last three '-'-separated fields are name, current health, delta.
                target_name, current_health, delta_health = line.split('-')[-3:] # TODO: handle negative cases
                current_health, delta_health = int(current_health.strip()), int(delta_health.strip())
                # Duplicate reports for the same enemy contribute 0.
                delta_health = self.enemy_health_tracker.update(target_name, current_health, delta_health)
                normalized_delta_health = delta_health / self.attack_normalize_factor
                reward += normalized_delta_health
            elif "[AfterTakeDamage]" in line: # hurt by enemy
                # 32: [AfterTakeDamage]-1-2 -> 2 damage
                amount = int(line.split('-')[-1].strip()) # base damage is 1
                # NOTE(review): normalized_amount is unused here because the
                # damage penalty below is commented out.
                normalized_amount = amount / self.damage_normalize_factor
                # reward -= 4 * normalized_amount
                life_loss = True
                self.health -= amount
            elif "[BeforeAddHealth]" in line: # heal
                amount = int(line.split('-')[-1].strip())
                normalized_amount = amount / self.healing_normalize_factor
                reward += 2 * normalized_amount
                self.health += amount
            elif "[OnReceiveDeathEvent]" in line: # kill enemy
                target_name = line.split('-')[-1].strip()
                all_defeated = self.death_event_recorder.update(target_name)
                if all_defeated:
                    win_battle = True
                    self.win_last_battle = True
                    terminate = True

        # judge living
        if self.health <= 0:
            self.win_last_battle = False
            terminate = True
        truncated = False

        if self.obs_size is not None:
            frame = cv2.resize(frame, self.obs_size)
        if self.color_convert:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        
        if terminate: # prevent meaningless action between episodes
            self.release_action()

        # update last action cache
        self.last_action = action
        return frame, reward, terminate, truncated, {"episode_frame_number": self.epsisode_step, "win_battle": win_battle, "health": self.health, "life_loss": life_loss}
    
    def pause(self):
        '''Focus the game window and send "esc" (same key as resume()).'''
        self.game_window.restore()
        self.game_window.activate()
        keyboard.press_and_release("esc")
    
    def resume(self):
        '''Focus the game window and send "esc" (same key as pause()).'''
        self.game_window.restore()
        self.game_window.activate()
        keyboard.press_and_release("esc")


class LocalAbstractHKEnv():
    """Local, synchronous front-end for a remote ``HKEnv`` Ray actor handle."""

    def __init__(self, env):
        self.env = env
        # The first reset() is forwarded with first_run=True, later ones with False.
        self.first_run = True

    def release_action(self):
        """Ask the remote env to release all keys; the call is not awaited."""
        self.env.release_action.remote()

    def reset(self):
        """Reset the remote env, block for the result and return ``(obs, info)``."""
        pending = self.env.reset.remote(first_run=self.first_run)
        observation, info = ray.get(pending)
        self.first_run = False
        return observation, info

    def step(self, action):
        """Run one remote step with ``action`` and block for its result."""
        pending = self.env.step.remote(action)
        return ray.get(pending)