import os, json, time, argparse, yaml


def convert_to_id(action: str, leagl_actions, env_name):
    """Resolve raw player input to one of the ids in ``leagl_actions``.

    Matching happens in two passes:

    1. The input is compared verbatim against every action id and every
       action description, in mapping order; the first hit wins.
    2. Otherwise the upper-cased input is looked up in a table of keyboard
       shortcuts (arrow-key escape sequences, ``S``, ``I``) and the
       resulting description is searched for among the legal actions.

    Args:
        action: Text typed by the player.
        leagl_actions: Mapping of action id -> action description.
        env_name: Accepted for the caller's convenience; not used here.

    Returns:
        The matching action id, or ``None`` (after printing a hint) when the
        input matches nothing.
    """
    candidates = list(leagl_actions.items())

    # Pass 1: the player typed an id or a description exactly.
    for candidate_id, description in candidates:
        if action == candidate_id or action == description:
            return candidate_id

    # Pass 2: keyboard shortcuts. The arrow keys arrive as ANSI escape
    # sequences, which upper-casing leaves untouched.
    shortcut_to_description = {
        "\x1b[A": "<UP>",
        "\x1b[B": "<DOWN>",
        "\x1b[D": "<LEFT>",
        "\x1b[C": "<RIGHT>",
        "S": "<STAY>",
        "I": "<INTERACT>",
    }
    wanted = shortcut_to_description.get(action.upper())
    if wanted is not None:
        for candidate_id, description in candidates:
            if wanted == description:
                return candidate_id

    print("Invalid action. Choose one of the listed actions.")
    return None


def _legal_actions_header(env):
    """Return the single "Legal actions" header line (with key hints) for *env*."""
    # The arrow-key hints list the actions in the same order as the keys
    # (up, down, left, right) so the text matches the real key mapping.
    if env in ('battle_of_the_colors', 'coin_dilemma', 'monster_hunt'):
        return "Legal actions: (可以用键盘上下左右键代替<UP>, <DOWN>, <LEFT>, <RIGHT>; 'S'代替<STAY>)"
    if env == 'atari_pong':
        return "Legal actions: (可以用键盘上下键代替<UP>, <DOWN>; 'S'代替<STAY>)"
    if env == 'overcooked':
        return "Legal actions: (可以用键盘上下左右键代替<UP>, <DOWN>, <LEFT>, <RIGHT>; 'S'代替<STAY>, 'I'代替<INTERACT>)"
    if env == 'breakthrough':
        return "Legal actions: (也可以直接输入动作字符串，比如'ricj')"
    return "Legal actions:"


def _prompt_for_action(legal_actions, env):
    """Keep prompting until the player's input resolves to a legal action id."""
    while True:
        raw_input_text = input("Enter your action with action id: ").strip()
        chosen = convert_to_id(raw_input_text, legal_actions, env)
        # convert_to_id already prints a hint when the input is invalid.
        if chosen is not None:
            return chosen


def main():
    """Run the interactive terminal client for one human player.

    The client reads ``user_terminal_path`` from
    ``configs/exp_configs/human.yaml``, then polls
    ``player_<id>_log.jsonl`` in that directory. Each JSON line is handled
    according to the keys it contains:

    * ``game_end``    -- print the final results and exit.
    * ``action``      -- an already-submitted action; skipped.
    * ``episode_end`` -- print the episode summary and pause briefly.
    * anything else   -- an observation: show the rules and legal actions,
      ask the player for an action (unless ``is_begin`` is set) and append
      the answer to the same log file as a new JSON line.

    Ctrl-C exits the client cleanly.
    """
    parser = argparse.ArgumentParser(description="Human player client for interactive game")
    parser.add_argument("--player", type=int, required=True, help="Player index (e.g., 0 or 1)")
    args = parser.parse_args()
    player_id = args.player

    with open("configs/exp_configs/human.yaml", 'r') as f:
        config = yaml.safe_load(f)

    base_path = config['user_terminal_path']

    log_path = f"{base_path}/player_{player_id}_log.jsonl"
    jpg_path = f"{base_path}/player_{player_id}_latest.jpg"
    print(f"Human player {player_id} client started.")
    print(f"Please open '{jpg_path}' to view the game state animation.")

    # Wait for the log file to be created by the main process.
    while not os.path.exists(log_path):
        print("Waiting for the game to start...")
        time.sleep(1)
    episode_count = 0

    # If the game already started, skip everything that was answered before
    # this client attached.
    with open(log_path, "r") as f:
        lines = f.readlines()
    last_line_index = len(lines)
    if lines:
        try:
            last_entry = json.loads(lines[-1])
        except json.JSONDecodeError:
            last_entry = {}
        if not isinstance(last_entry, dict) or "action" not in last_entry:
            # The final line is not an action, so it may be an observation
            # still awaiting a reply: step back so the loop processes it.
            last_line_index -= 1

    is_end = False
    game_res = None
    try:
        # Poll the log for new entries until the game ends.
        while True:
            with open(log_path, "r") as f:
                lines = f.readlines()
            # Nothing new yet: wait briefly and poll again.
            if len(lines) <= last_line_index:
                time.sleep(0.5)
                continue

            # Process every line appended since the last poll.
            while last_line_index < len(lines):
                raw_line = lines[last_line_index]
                # A final line without a newline may still be mid-write.
                tail_incomplete = (
                    last_line_index == len(lines) - 1
                    and not raw_line.endswith("\n")
                )
                line = raw_line.strip()
                try:
                    data = json.loads(line) if line else None
                except json.JSONDecodeError:
                    data = None

                if not isinstance(data, dict):
                    if tail_incomplete:
                        # Do not consume a half-written line; re-read it on
                        # the next poll instead of dropping it for good.
                        time.sleep(0.5)
                        break
                    # Blank, malformed or non-object line: skip it.
                    last_line_index += 1
                    continue
                last_line_index += 1

                if "game_end" in data:
                    is_end = True
                    game_res = data['results']
                    break

                # Echo of an action that was already submitted.
                if "action" in data:
                    continue

                if "episode_end" in data:
                    print(f"=== Episode {episode_count} End, returns = {data['env_info']}===\n")
                    time.sleep(3)
                    break

                # Anything else is an observation from the main process.
                step = data.get("step", None)
                legal_actions = data.get("legal_actions", {})

                # Step 1 (or a missing step before any episode was seen)
                # marks the start of a new episode.
                if step == 1 or (step is None and episode_count == 0):
                    episode_count += 1
                    print(f"\n=== Episode {episode_count} Start===")

                print(f"{data['rules']}")

                # Exactly one header line, chosen by environment.
                print(_legal_actions_header(data['env']))
                for act_id, desc in legal_actions.items():
                    print(f"  {act_id}: {desc}")

                # The opening observation needs no reply from the player.
                if data['is_begin']:
                    print("Waiting for the next state...")
                    continue

                chosen_action = _prompt_for_action(legal_actions, data['env'])
                # Reported once, and only after a valid action was chosen.
                print(f"Action {chosen_action} submitted. Waiting for the next state...")

                # Append the chosen action to the log as a JSON line.
                action_entry = {
                    "timestamp": time.time(),
                    "step": step,
                    "action": chosen_action
                }
                with open(log_path, "a") as f:
                    json.dump(action_entry, f)
                    f.write("\n")

            if is_end:
                print(f"\n=== Game End, average return = {game_res} ===")
                break

    except KeyboardInterrupt:
        print("\nHuman player client exiting.")
        return

# Start the client only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
