import argparse
import pathlib
import time

import numpy as np

import crafter

import os
import sys
import io
import cv2
import json
import time
import torch
import random
import datetime
import csv

# Make the repository root (the parent of this script's directory) importable,
# so the project-local `lib.` and `src.` imports further down resolve when this
# file is run directly as a script. It must stay above those imports.
project_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
print(f"Project path: {project_path}")
if project_path not in sys.path:
    sys.path.append(project_path)

import traceback
import itertools

# import numpy as np
import regex as re
import matplotlib.pyplot as plt
from openai import OpenAI  # Add OpenAI import

from PIL import Image
from glob import glob
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from huggingface_hub import login

from collections import deque
from transformers import AutoTokenizer
from lib.crafter_custom.crafter.language_wrapper import CrafterLanguageWrapper

from pytz import timezone


import lib.crafter_custom.crafter.random_init_wrapper as random_init_wrapper

try:
    import pygame
except ImportError:
    print("Please install the pygame package to use the GUI.")
    raise
from PIL import Image
import imageio  # Add imageio import for GIF creation

import logging

import src.system_prompts as system_prompts


def load_model(
    model_name="meta-llama/Llama-3.1-8B-Instruct",
    lora_path=None,
    gpu_num="0",  # for specifying GPU number
):
    """Load a causal language model (optionally with a LoRA adapter) and its tokenizer.

    The base weights are loaded in float16 with ``device_map="auto"`` and SDPA
    attention. When ``lora_path`` is given, the adapter stored there is wrapped
    around the base model via ``PeftModel.from_pretrained``.

    Args:
        model_name: Hugging Face model id or local path for the base weights
            and the tokenizer.
        lora_path: Optional path to LoRA adapter weights; ``None`` returns the
            plain base model.
        gpu_num: Currently unused; device placement is left to
            ``device_map="auto"``.

    Returns:
        tuple: ``(model, tokenizer)``. The tokenizer pads on the left and uses
        its EOS token as the pad token.
    """
    net = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        use_auth_token=True,
        attn_implementation="sdpa",
    )
    cfg = net.config
    cfg.use_cache = False
    cfg.pretraining_tp = 1

    # Log the module -> device placement chosen by device_map="auto".
    if hasattr(net, "hf_device_map"):
        logging.info("Model distribution across devices:")
        for module_name, placement in net.hf_device_map.items():
            logging.info(f"{module_name}: {placement}")

    if lora_path is not None:
        print(f"\n\n===loading from {lora_path}===\n\n")
        net = PeftModel.from_pretrained(net, lora_path)
    else:
        print("\n\n===loading base model===\n\n")

    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    # NOTE(review): for tokenizers that honor this flag, every encoded prompt
    # gets a trailing EOS, which is unusual at inference time -- confirm intent.
    tok.add_eos_token = True

    return net, tok


def create_history(
    obs, user_history, assistant_history, system_prompt, post_step_is_valid
):
    """Build the chat message list sent to the model for the current step.

    The list starts with the system prompt. Every observation except the last
    becomes a "### Previous Observation" user turn, followed by an assistant
    turn holding the action taken after it. The last observation becomes a
    "### Current Observation" user turn with the output-format instruction
    appended, plus an invalid-action warning when ``post_step_is_valid`` is
    falsy.

    Args:
        obs: Current observation (not used; kept for call compatibility).
        user_history: Indexable sequence of observation strings, oldest first.
        assistant_history: Indexable sequence of action strings, oldest first.
        system_prompt: Content of the leading system message.
        post_step_is_valid: Whether the previous action passed the post-step
            validity check.

    Returns:
        list[dict]: Chat messages with "role" and "content" keys.
    """
    answer_format = (
        "Only return a single valid action name (e.g. move_left, place_table, "
        "make_wood_pickaxe) with no additional explanation or formatting.\n"
        "Always output one action per turn until the episode ends."
    )
    n_turns = max(len(assistant_history), len(user_history))

    messages = [{"role": "system", "content": system_prompt}]
    for idx in range(n_turns - 1):
        messages.append(
            {"role": "user", "content": "### Previous Observation\n" + user_history[idx]}
        )
        messages.append({"role": "assistant", "content": assistant_history[idx]})

    if n_turns:
        last = n_turns - 1
        warning = (
            ""
            if post_step_is_valid
            else "\n\nThe previous action was invalid. Try a different one."
        )
        messages.append(
            {
                "role": "user",
                "content": "### Current Observation\n"
                + user_history[last]
                + "\n\n"
                + answer_format
                + warning,
            }
        )
    return messages


def create_history_reasoning(obs, user_history, assistant_history, system_prompt):
    """Build the chat message list for the "reasoning : / action :" prompt.

    The list starts with the system prompt. Every observation except the last
    becomes a user turn followed by the assistant turn that came after it.
    The last observation becomes a "### Current observation" user turn, with
    the reasoning-format instruction appended after a blank line.

    Previous turns are numbered relative to the current one as
    ``i - len(user_history) + 1``. For example, with two stored observations
    the turn right before the current one is numbered ``-1``.

    Args:
        obs: Current observation (not used; kept for call compatibility).
        user_history: Indexable sequence of observation strings, oldest first.
        assistant_history: Indexable sequence of action strings, oldest first.
        system_prompt: Content of the leading system message.

    Returns:
        list[dict]: Chat messages with "role" and "content" keys.
    """
    # TODO: fix history structure
    reasoning_instruction = """You are an AI agent following a step-by-step reasoning approach to decide on the best action to take at each step.
You must explicitly state your reasoning before providing an action.

Your output should be formatted as:
reasoning : [your step-by-step thought process] // You explicitly state your reasoning before providing an action.
action : [one action from the provided action list] // You always have to output one of the above actions at a time and no other text.
"""
    history = [{"role": "system", "content": system_prompt}]
    n_turns = max(len(assistant_history), len(user_history))

    for i in range(n_turns):
        if i == n_turns - 1:
            # Separate the observation from the instruction with a blank line
            # (as create_history does) so the two texts are not glued together.
            history.append(
                {
                    "role": "user",
                    "content": "### Current observation\n"
                    + user_history[i]
                    + "\n\n"
                    + reasoning_instruction,
                }
            )
        else:
            num = i - len(user_history) + 1
            history.append(
                {"role": "user", "content": f"### Observation_{num}\n" + user_history[i]}
            )
            history.append(
                {
                    "role": "assistant",
                    "content": f"### Action_{num}\n" + assistant_history[i],
                }
            )

    return history


def generate_prompt_and_tokenize_llama(history, tokenizer):
    """Render a chat history with the tokenizer's chat template and tokenize it.

    Args:
        history: List of {"role": ..., "content": ...} chat messages.
        tokenizer: A Hugging Face tokenizer that defines a chat template.

    Returns:
        tuple: ``(model_inputs, tokenized_chat)``. ``model_inputs`` is the
        tokenizer output moved to CUDA; ``tokenized_chat`` is its ``input_ids``.
    """
    # add_generation_prompt=True appends the template's own assistant header
    # (Llama 3: "<|start_header_id|>assistant<|end_header_id|>\n\n", Qwen:
    # "<|im_start|>assistant\n"), instead of a hard-coded Llama-only string
    # that was wrong for Qwen and omitted Llama's trailing newlines.
    prompt = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )

    # The rendered template already contains whatever special tokens it needs
    # (e.g. Llama 3's <|begin_of_text|>), so don't let the tokenizer add a
    # second BOS.
    # NOTE(review): truncation keeps the tokenizer's truncation_side (right by
    # default), which would cut the generation header off an over-long prompt.
    model_inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        max_length=2**13,  # 8K
        padding=True,
        truncation=True,
        add_special_tokens=False,
    ).to("cuda")
    tokenized_chat = model_inputs.input_ids
    return model_inputs, tokenized_chat


def generate_response_llama(model, tokenizer, history):
    """Greedily generate the model's reply to ``history``.

    The history is rendered and tokenized by
    ``generate_prompt_and_tokenize_llama``. Decoding is greedy, limited to
    256 new tokens, and runs under CUDA autocast. Only the newly generated
    tokens are decoded.

    Args:
        model: A Hugging Face causal LM (or PEFT wrapper) exposing ``generate``.
        tokenizer: The matching tokenizer.
        history: List of {"role", "content"} chat messages.

    Returns:
        str: The decoded reply, lower-cased and whitespace-stripped.
    """
    model_inputs, tokenized_chat = generate_prompt_and_tokenize_llama(
        history, tokenizer
    )
    attention_mask = model_inputs["attention_mask"]

    # torch.autocast(device_type="cuda") is the non-deprecated spelling of
    # torch.cuda.amp.autocast(). `temperature` is not passed: it only affects
    # sampling, and transformers warns when it is set with do_sample=False.
    with torch.autocast(device_type="cuda"):
        generated_ids = model.generate(
            tokenized_chat,
            do_sample=False,
            max_new_tokens=256,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=attention_mask,
        )

    # Drop the prompt tokens so only the reply is decoded.
    new_tokens = [
        output_ids[len(input_ids) :]
        for input_ids, output_ids in zip(tokenized_chat, generated_ids)
    ]
    raw_response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
    return raw_response.lower().strip()


def get_response_openai(history, model_name="gpt-4o"):
    """
    Query an OpenAI chat model and return its reply text.

    Args:
        history: Chat messages, each a dict with 'role' and 'content' keys.
        model_name: OpenAI chat model to call (default "gpt-4o").

    Returns:
        The reply text with surrounding whitespace stripped. If the API call,
        or reading its result, raises any exception, the error is printed and
        "noop" is returned.

    Raises:
        ValueError: If the OPENAI_API_KEY environment variable is unset or empty.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OPENAI_API_KEY environment variable is not set. Please set your OpenAI API key."
        )

    try:
        completion = OpenAI(api_key=api_key).chat.completions.create(
            model=model_name,
            messages=history,
            temperature=0,
            max_tokens=256,
        )
        print(f"Total tokens used: {completion.usage.total_tokens}")
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as err:
        # Best effort: any API failure degrades to the no-op action.
        print(f"Error calling OpenAI API: {err}")
        return "noop"


def filter_action(raw_response, action_str):
    """Map a free-form model reply to an action index and name.

    Steps, in order:
      1. Normalize the reply: lower-case it and strip surrounding whitespace,
         punctuation, quote and markdown characters.
      2. Match it exactly against ``action_str``.
      3. Otherwise try keyword heuristics: "move" + direction/compass point,
         "make" + material + tool, or "place" + object.
      4. Otherwise accept a bare integer that is a valid index.
      5. Anything else maps to index 0 (noop).

    Args:
        raw_response (str): Text produced by the model.
        action_str (list[str]): Action names indexed by action id. The
            heuristics assume this script's lower-case names (e.g.
            "move_left", "make_wood_pickaxe").

    Returns:
        tuple[int, str]: (action index, action name).
    """
    text = raw_response.lower().strip(" \t\r\n.,;:!'\"`*")

    action = 0  # noop unless something below matches
    if text in action_str:
        action = action_str.index(text)
    elif "move" in text:
        if "left" in text or "west" in text:
            action = action_str.index("move_left")
        elif "right" in text or "east" in text:
            action = action_str.index("move_right")
        elif "up" in text or "north" in text:
            action = action_str.index("move_up")
        elif "down" in text or "south" in text:
            action = action_str.index("move_down")
    elif "make" in text:
        # Check the material for pickaxes as well as swords, so that e.g.
        # "make stone pickaxe" is not mapped to the wood pickaxe.
        if "wood" in text and "pick" in text:
            action = action_str.index("make_wood_pickaxe")
        elif "stone" in text and "pick" in text:
            action = action_str.index("make_stone_pickaxe")
        elif "iron" in text and "pick" in text:
            action = action_str.index("make_iron_pickaxe")
        elif "wood" in text and "sword" in text:
            action = action_str.index("make_wood_sword")
        elif "stone" in text and "sword" in text:
            action = action_str.index("make_stone_sword")
        elif "iron" in text and "sword" in text:
            action = action_str.index("make_iron_sword")
    elif "place" in text:
        if "stone" in text:
            action = action_str.index("place_stone")
        elif "table" in text:
            action = action_str.index("place_table")
        elif "furnace" in text:
            action = action_str.index("place_furnace")
        elif "plant" in text:
            action = action_str.index("place_plant")
    elif text.isdecimal() and int(text) < len(action_str):
        # Range-check so an out-of-range number falls back to noop instead of
        # raising IndexError below.
        action = int(text)

    return action, action_str[action]


def extract_reasoning_and_action(raw_response):
    """Split a "reasoning : <text> action : <action>" reply into its parts.

    Rules:
      - No "action :" marker: the whole stripped reply is returned as the
        action, with empty reasoning. Previously this raised IndexError.
      - "action :" present but no "reasoning :": everything before
        "action :" is the reasoning.
      - A repeated marker: parsing stops at the second occurrence.

    Args:
        raw_response (str): Model output text.

    Returns:
        tuple[str, str]: (reasoning, action), both whitespace-stripped.
    """
    head, marker, tail = raw_response.partition("action :")
    if not marker:
        return "", raw_response.strip()

    action = tail.split("action :", 1)[0].strip()

    _, reasoning_marker, reasoning_tail = head.partition("reasoning :")
    if reasoning_marker:
        reasoning = reasoning_tail.split("reasoning :", 1)[0].strip()
    else:
        reasoning = head.strip()
    return reasoning, action


def check_action_validity_post_step(
    action,
    info_inventory,
    previous_info_inventory,
    previous_player_pos,
    current_player_pos,
    previous_player_facing,
    current_player_facing,
    previous_info_semantic,
    id_to_item,
):
    """
    Check if the action is valid based on post-step conditions.

    NOTE(review): this module defines `check_action_validity_post_step` a
    second time further down. That later definition rebinds the name at import
    time, so this copy is never called. Keep the two in sync or delete one.

    An action is judged by its observable effect across one env step:
      * make_* / do: some inventory entry other than health must have
        increased. As an exception, "do" aimed at a cow, zombie or skeleton is
        always valid.
      * place_*: some item count other than health/food/drink/energy must
        have decreased.
      * move_*: the position or the facing direction must have changed.

    Args:
        action (str): The action taken; a falsy action is always valid.
        info_inventory (dict): Inventory after the step.
        previous_info_inventory (dict): Inventory before the step.
        previous_player_pos (ndarray): Player position before the step.
        current_player_pos (ndarray): Player position after the step.
        previous_player_facing (ndarray): Facing offset before the step.
        current_player_facing (ndarray): Facing offset after the step.
        previous_info_semantic (ndarray): Semantic map before the step, indexed
            as [pos[0], pos[1]].
        id_to_item (Sequence[str]): Maps semantic ids to object/material names.

    Returns:
        bool: True if action is valid, False otherwise
    """
    if not action:
        return True

    name = action.lower()
    is_valid = True

    # make_* / do must gain something. Health is excluded because it can
    # regenerate on its own.
    if "make" in name or name == "do":
        gained = any(
            key != "health"
            and key in previous_info_inventory
            and info_inventory[key] > previous_info_inventory[key]
            for key in info_inventory
        )
        if not gained:
            is_valid = False

        # Hitting a creature need not change the inventory. Only "do" needs the
        # facing tile, and the lookup is bounds-checked: facing off the map
        # would otherwise wrap around (negative index) or raise IndexError.
        if (
            name == "do"
            and previous_player_pos is not None
            and previous_player_facing is not None
        ):
            target_x = int(previous_player_pos[0] + previous_player_facing[0])
            target_y = int(previous_player_pos[1] + previous_player_facing[1])
            if (
                0 <= target_x < previous_info_semantic.shape[0]
                and 0 <= target_y < previous_info_semantic.shape[1]
            ):
                facing_object = id_to_item[previous_info_semantic[target_x, target_y]]
                if facing_object in ("cow", "zombie", "skeleton"):
                    is_valid = True

    # place_* must consume some item (status bars excluded).
    if "place" in name:
        spent = any(
            key not in ("health", "food", "drink", "energy")
            and key in previous_info_inventory
            and info_inventory[key] < previous_info_inventory[key]
            for key in info_inventory
        )
        if not spent:
            is_valid = False

    # move_* must change either position or facing direction.
    if "move" in name:
        if previous_player_pos is not None and current_player_pos is not None:
            if np.array_equal(
                previous_player_pos, current_player_pos
            ) and np.array_equal(previous_player_facing, current_player_facing):
                is_valid = False

    return is_valid


def check_loop(assistant_history_for_loop_detect, action_str, random_state):
    """
    Detect a short repeating cycle at the end of the action history.

    A loop means that, for some period ``k`` in 2..6, the last ``k`` actions
    (the pattern) occur three times back to back at the end of the history.
    The shortest matching period wins. The escape action is a move that does
    not appear in the pattern, drawn with ``random_state.choice``. If every
    move appears in the pattern, any of the four moves is drawn.

    Args:
        assistant_history_for_loop_detect: Iterable (e.g. deque) of recent
            action names, oldest first.
        action_str: List of available action names; the escape action is
            returned as an index into it.
        random_state: NumPy RandomState used for the escape draw, so runs are
            reproducible.

    Returns:
        tuple[bool, int]: (loop detected, escape action index). The index is
        0 when no loop is found.
    """
    history = list(assistant_history_for_loop_detect)
    moves = ["move_left", "move_right", "move_up", "move_down"]

    # The shortest detectable loop (period 2, three repeats) needs 6 actions.
    if len(history) < 6:
        return False, 0

    for period in range(2, 7):
        span = 3 * period
        if len(history) < span:
            continue
        tail = history[-span:]
        pattern = tail[2 * period :]
        repeats = tail[:period] == pattern and tail[period : 2 * period] == pattern
        if not repeats:
            continue
        candidates = [move for move in moves if move not in pattern] or moves
        return True, action_str.index(random_state.choice(candidates))

    return False, 0


def check_action_validity_pre_step(action, info):
    """
    Decide, before stepping the env, whether an action can possibly succeed.

    An action is rejected when the player is asleep. It is also rejected when
    it is place_table, place_stone or place_furnace and the inventory holds
    fewer than 2 wood, 1 stone or 4 stone respectively. A falsy action is
    always accepted without reading ``info``. Rejections are logged at INFO
    level.

    Args:
        action (str): Name of the action about to be taken.
        info (dict): Env info. Reads ``info["sleeping"]`` and, for the place
            actions above, ``info["inventory"]``.

    Returns:
        bool: True if action is valid, False otherwise
    """
    # action name -> (inventory item, minimum count required)
    minimum_cost = {
        "place_table": ("wood", 2),
        "place_stone": ("stone", 1),
        "place_furnace": ("stone", 4),
    }

    ok = True
    if action:
        if info["sleeping"]:
            ok = False
        cost = minimum_cost.get(action)
        if cost is not None:
            item, needed = cost
            if info["inventory"][item] < needed:
                ok = False

    if not ok:
        logging.info("Pre step invalid")

    return ok


def check_action_validity_post_step(
    action,
    info_inventory,
    previous_info_inventory,
    previous_player_pos,
    current_player_pos,
    previous_player_facing,
    current_player_facing,
    previous_info_semantic,
    id_to_item,
):
    """
    Check if the action is valid based on post-step conditions.

    An action is judged by its observable effect across one env step:
      * make_* / do: some inventory entry other than health must have
        increased. As an exception, "do" aimed at a cow, zombie or skeleton is
        always valid.
      * place_*: some item count other than health/food/drink/energy must
        have decreased.
      * move_*: the position or the facing direction must have changed.

    Args:
        action (str): The action taken; a falsy action is always valid.
        info_inventory (dict): Inventory after the step.
        previous_info_inventory (dict): Inventory before the step.
        previous_player_pos (ndarray): Player position before the step.
        current_player_pos (ndarray): Player position after the step.
        previous_player_facing (ndarray): Facing offset before the step.
        current_player_facing (ndarray): Facing offset after the step.
        previous_info_semantic (ndarray): Semantic map before the step, indexed
            as [pos[0], pos[1]].
        id_to_item (Sequence[str]): Maps semantic ids to object/material names.

    Returns:
        bool: True if action is valid, False otherwise
    """
    if not action:
        return True

    name = action.lower()
    is_valid = True

    # make_* / do must gain something. Health is excluded because it can
    # regenerate on its own.
    if "make" in name or name == "do":
        gained = any(
            key != "health"
            and key in previous_info_inventory
            and info_inventory[key] > previous_info_inventory[key]
            for key in info_inventory
        )
        if not gained:
            is_valid = False

        # Hitting a creature need not change the inventory. Only "do" looks at
        # the facing tile. Facing off the map simply leaves the inventory-based
        # verdict in place; a successful make_* at the map edge must stay valid.
        if (
            name == "do"
            and previous_player_pos is not None
            and previous_player_facing is not None
        ):
            target_x = int(previous_player_pos[0] + previous_player_facing[0])
            target_y = int(previous_player_pos[1] + previous_player_facing[1])
            if (
                0 <= target_x < previous_info_semantic.shape[0]
                and 0 <= target_y < previous_info_semantic.shape[1]
            ):
                facing_object = id_to_item[previous_info_semantic[target_x, target_y]]
                if facing_object in ("cow", "zombie", "skeleton"):
                    is_valid = True

    # place_* must consume some item (status bars excluded).
    if "place" in name:
        spent = any(
            key not in ("health", "food", "drink", "energy")
            and key in previous_info_inventory
            and info_inventory[key] < previous_info_inventory[key]
            for key in info_inventory
        )
        if not spent:
            is_valid = False

    # move_* must change either position or facing direction.
    if "move" in name:
        if previous_player_pos is not None and current_player_pos is not None:
            if np.array_equal(
                previous_player_pos, current_player_pos
            ) and np.array_equal(previous_player_facing, current_player_facing):
                is_valid = False

    return is_valid


def post_step_update_csv(
    env, step, is_valid, post_step_is_valid, status_change, achievement_change
):
    """
    Update the validity of the last written row in the CSV file.

    Rewrites the file behind ``env.csv_file`` so that its last data row holds
    the post-step ``status_change`` and ``achievement_change`` values. The
    row's ``valid`` column is set to ``post_step_is_valid`` when the pre-step
    check passed but the post-step check failed.

    Afterwards ``env.csv_file`` and ``env.csv_writer`` are always replaced by a
    fresh append-mode handle and writer, even if the rewrite fails or the file
    holds only a header. Does nothing if ``env`` has no ``csv_file`` attribute.

    Args:
        env: Object carrying ``csv_file`` / ``csv_writer`` (as set up by
            save_data_csv).
        step: Current step index (currently unused).
        is_valid: Result of the pre-step validity check for the last row.
        post_step_is_valid: Result of the post-step validity check.
        status_change: Value written to the last row's "status_change" column.
        achievement_change: Value written to the last row's
            "achievement_change" column.
    """
    if not hasattr(env, "csv_file"):
        return

    filename = env.csv_file.name
    # Close the append handle so the full read/rewrite below sees all data.
    env.csv_file.close()

    # Fall back to the previous writer's columns if the file has no header.
    previous_writer = getattr(env, "csv_writer", None)
    fieldnames = list(getattr(previous_writer, "fieldnames", None) or [])

    try:
        with open(filename, "r", newline="") as file:
            reader = csv.DictReader(file)
            rows = list(reader)
            if reader.fieldnames:
                fieldnames = list(reader.fieldnames)

        if rows:
            last = rows[-1]
            if is_valid and not post_step_is_valid:
                last["valid"] = post_step_is_valid
            last["status_change"] = status_change
            last["achievement_change"] = achievement_change
            with open(filename, "w", newline="") as file:
                writer = csv.DictWriter(file, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(rows)
    finally:
        # Always leave an open append handle behind, so later save_data_csv
        # calls never write to a closed file.
        env.csv_file = open(filename, "a", newline="")
        env.csv_writer = csv.DictWriter(env.csv_file, fieldnames=fieldnames)


def save_data_csv(
    env,
    step,
    history,
    raw_response,
    reasoning,
    action,
    run_timestamp,
    log_name,
    status_change,
    observation_goal,
    player_pos,
    player_facing,
    action_goal,
    is_valid,
    is_sleeping,
):
    """Append one step's record to this run's CSV log and return ``env``.

    On first use the file ``eval/csv/{log_name}_{run_timestamp}.csv`` is opened
    in append mode, creating ``eval/csv`` if needed. The handle and a
    DictWriter are stored on ``env`` as ``csv_file`` / ``csv_writer``, and a
    header row is written only when the file is empty.

    The "status_change" and "achievement_change" columns are written as the
    placeholder "None"; post_step_update_csv fills them in after the env step.
    The ``status_change`` argument itself is not written. Each row is flushed
    immediately.
    """
    record = dict(
        step=step,
        history=history,
        raw_response=raw_response,
        action=action,
        reasoning=reasoning,
        observation_goal=observation_goal,
        action_goal=action_goal,
        player_pos=player_pos,
        player_facing=player_facing,
        status_change="None",  # filled in by the post-step update
        achievement_change="None",  # filled in by the post-step update
        valid=is_valid,
        is_sleeping=is_sleeping,
    )

    if not hasattr(env, "csv_file"):
        csv_path = f"eval/csv/{log_name}_{run_timestamp}.csv"
        os.makedirs(os.path.dirname(csv_path), exist_ok=True)
        env.csv_file = open(csv_path, "a", newline="")
        env.csv_writer = csv.DictWriter(env.csv_file, fieldnames=list(record))
        # An empty file is brand new and needs the header row.
        if env.csv_file.tell() == 0:
            env.csv_writer.writeheader()

    env.csv_writer.writerow(record)
    env.csv_file.flush()
    return env


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--area", nargs=2, type=int, default=(64, 64))
    parser.add_argument("--length", type=int, default=10000)
    parser.add_argument("--health", type=int, default=9)
    parser.add_argument("--record", type=pathlib.Path, default=None)
    parser.add_argument("--episodes", type=int, default=1)
    parser.add_argument("--size", type=int, nargs=2, default=(0, 0))
    parser.add_argument("--window", type=int, nargs=2, default=(600, 600))
    parser.add_argument(
        "--gif_dir", type=str, default="eval/gifs"
    )  # Add argument for GIF directory
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="eval_benchmark",
        choices=list(system_prompts.SYSTEM_PROMPTS.keys()),
        help="System prompt to use for the agent",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",  # meta-llama/Llama-3.1-8B-Instruct, Qwen/Qwen2.5-7B-Instruct
        help="Model name to use for inference",
    )
    parser.add_argument(
        "--lora_path", type=str, default=None, help="Path to LoRA weights"
    )
    parser.add_argument(
        "--log_name",
        type=str,
        default=None,
        help="Log file name. Defaults to current date and time if not provided.",
    )
    parser.add_argument(
        "--random_init",
        type=bool,
        default=False,
        help="Enable random initialization of the environment.",
    )

    args = parser.parse_args()

    # Set default log file name if not provided
    if args.log_name is None:
        args.log_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S.log")

    # Create GIF directory if it doesn't exist
    os.makedirs(args.gif_dir, exist_ok=True)
    os.makedirs("eval/logs", exist_ok=True)
    # os.makedirs("history_goal", exist_ok=True)

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,  # Set the logging level to INFO instead of DEBUG
        format="%(asctime)s - %(levelname)s - %(message)s",  # Log format
        handlers=[
            logging.FileHandler("eval/logs/" + args.log_name),  # Log to a file
            logging.StreamHandler(),  # Also log to console
        ],
    )

    # Create a single timestamp for the entire run
    run_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Get the selected system prompt
    system_prompt = system_prompts.SYSTEM_PROMPTS[args.system_prompt]
    logging.info(f"Using system prompt: {args.system_prompt}")

    action_str = [
        "noop",
        "move_left",
        "move_right",
        "move_up",
        "move_down",
        "do",
        "sleep",
        "place_stone",
        "place_table",
        "place_furnace",
        "place_plant",
        "make_wood_pickaxe",
        "make_stone_pickaxe",
        "make_iron_pickaxe",
        "make_wood_sword",
        "make_stone_sword",
        "make_iron_sword",
    ]

    # init pygame
    pygame.init()
    screen = pygame.display.set_mode(args.window)
    running = True

    size = list(args.size)
    size[0] = size[0] or args.window[0]
    size[1] = size[1] or args.window[1]

    # Create id_to_item mapping
    id_to_item = [0] * 19
    dummyenv = crafter.Env()
    for name, ind in itertools.chain(
        dummyenv._world._mat_ids.items(), dummyenv._sem_view._obj_ids.items()
    ):
        name = (
            str(name)[str(name).find("objects.") + len("objects.") : -2].lower()
            if "objects." in str(name)
            else str(name)
        )
        id_to_item[ind] = name

    ### load finetuned model
    if args.lora_path:
        finetuned_model, tokenizer = load_model(
            model_name=args.model_name, lora_path=args.lora_path
        )
    ### load pretrained llama model
    elif "llama" in args.model_name or "Qwen" in args.model_name:
        pretrained_model, tokenizer = load_model(
            model_name=args.model_name, lora_path=args.lora_path
        )

    # List to track achievement number across episodes
    total_achievements_num = []
    # Dictionary to track achievements across episodes
    achievement_counts = {}

    # Initialize status and achievement changes list
    status_changes = []
    achievement_changes = []

    seed_list = [42 + 10 * i for i in range(args.episodes)]

    for episode in range(args.episodes):
        frames = []

        random_state = np.random.RandomState(args.seed)
        crafter.constants.items["health"]["max"] = args.health
        crafter.constants.items["health"]["initial"] = args.health
        env = crafter.Env(area=args.area, length=args.length, seed=seed_list[episode])
        env = crafter.Recorder(env, args.record)
        env = CrafterLanguageWrapper(env)
        logging.info(f"### Episode {episode} | seed: {seed_list[episode]}")

        obs, reward, done, info = env.reset()
        done = False

        # Initialize FIFO queues for history of observations and actions
        user_history = deque(maxlen=2)
        assistant_history = deque(maxlen=1)
        assistant_history_for_loop_detect = deque(maxlen=20)

        # Initialize previous inventory state and player position
        status_change = []
        previous_inventory = info["inventory"]
        previous_player_pos = info["player_pos"]
        previous_player_facing = info["player_facing"]
        previous_achievement = info["achievements"]
        previous_info_semantic = info["semantic"]
        post_step_is_valid = True
        is_in_loop = False
        escape_loop_action = 0

        step = 0
        achievements = set()
        while True:

            # Rendering
            image = env.render(size)
            if size != args.window:
                image = Image.fromarray(image)
                image = image.resize(args.window, resample=Image.NEAREST)
                image = np.array(image)

            # Add frame to GIF frames list
            frames.append(image.copy())

            if done or step > 1000:
                break

            # update user history
            user_history.append(
                obs["text"]["short_term_context"]
                + "\n\n"
                + obs["text"]["long_term_context"]
            )
            user_history_goal = (
                obs["text"]["long_term_context"]
                + "\n"
                + obs["text"]["short_term_context"]
            )

            ##### 1. generate prompt for current episode w/ history queue #####
            history = create_history(
                obs, user_history, assistant_history, system_prompt, post_step_is_valid
            )

            ##### 1. generate prompt w/ reasoning for current episode w/ history queue #####
            # history = create_history_reasoning(obs, user_history, assistant_history, system_prompt)

            ##### 2. generate action w. model #####
            if "llama" in args.model_name or "Qwen" in args.model_name:
                if args.lora_path:
                    raw_response = generate_response_llama(
                        finetuned_model, tokenizer, history
                    )
                else:
                    raw_response = generate_response_llama(
                        pretrained_model, tokenizer, history
                    )

            ##### 2. generate action w. openai #####
            elif "gpt" in args.model_name:
                try:
                    raw_response = get_response_openai(
                        history, model_name=args.model_name
                    )
                except ValueError as e:
                    logging.error(f"Error calling OpenAI API: {e}")
                    logging.info("Falling back to noop action")
                    raw_response = "noop"

            #### 2-1. extract reasoning and action from raw response #####
            reasoning = ""
            # reasoning, raw_response = extract_reasoning_and_action(raw_response)
            # print(f"step: {step} reasoning : {reasoning}")
            # print(f"step: {step} action : {text_action}")

            ##### 2-2.random action for debugging #####
            # action = random.randint(0, env.action_space.n)

            ##### 3. Handle cases where raw_response isn't an exact match #####
            action, text_action = filter_action(raw_response, action_str)

            ##### 3-1. Execute random action with 0.1 probability (random integer between 1 and 16) #####
            # if random.random() < 0.1:
            #     action = random.randint(0, 16)  # Random integer from 1 to 16 (except noop)
            #     text_action = action_str[action]
            #     logging.info(f"Random action selected: {text_action}")

            ##### 3-2. Execute random move action when fall in loop #####
            if is_in_loop:
                action = escape_loop_action
                is_in_loop = False
                logging.info(f"!! Escape loop !!")

            logging.info(
                f"step{step:<4} raw response : {raw_response:<30} | action : {action_str[action]:<20}"
            )

            # update action history
            assistant_history.append(action_str[action])
            assistant_history_goal = action_str[action]
            assistant_history_for_loop_detect.append(action_str[action])

            # Check pre-step validity
            is_valid = (
                True if step == 0 else check_action_validity_pre_step(text_action, info)
            )

            ##### 4. Save data to csv with status change and goals #####
            save_data_csv(
                env,
                step,
                history,
                raw_response,
                reasoning,
                text_action,
                run_timestamp,
                args.log_name,
                status_change,
                user_history_goal,
                info["player_pos"],
                info["player_facing"],
                assistant_history_goal,
                is_valid,
                info["sleeping"],
            )

            ##### 5. environment step #####
            obs, reward, done, info = env.step(action)
            step += 1

            ##### 6. Identify status change #####
            # Calculate changes in inventory, excluding 'health'
            current_inventory = info["inventory"]
            status_change_dict = (
                {
                    key: current_inventory[key] - previous_inventory.get(key, 0)
                    for key in current_inventory
                }
                if previous_inventory
                else {}
            )
            status_change = [
                key
                for key, change in status_change_dict.items()
                if (key != "health" and change > 0)
                or (key not in ["health", "food", "drink", "energy"] and change < 0)
            ]

            status_changes.append(status_change if status_change else None)

            ##### 7. Identify achievement change #####
            # Calculate changes in inventory, excluding 'health'
            current_achievement = info["achievements"]
            achievement_change_dict = (
                {
                    key: current_achievement[key] - previous_achievement.get(key, 0)
                    for key in current_achievement
                }
                if previous_achievement
                else {}
            )
            achievement_change = [
                key for key, change in achievement_change_dict.items() if change > 0
            ]

            achievement_changes.append(
                achievement_change if achievement_change else None
            )

            ##### 8. Check post-step validity #####
            post_step_is_valid = check_action_validity_post_step(
                text_action,
                info["inventory"],
                previous_inventory,
                previous_player_pos,
                info["player_pos"],
                previous_player_facing,
                info["player_facing"],
                previous_info_semantic,
                id_to_item,
            )
            if not post_step_is_valid:
                logging.info(f"step{step-1} : invalid action")

            # Update CSV if validity changed
            post_step_update_csv(
                env,
                step,
                is_valid,
                post_step_is_valid,
                status_change,
                achievement_change,
            )

            # Check if loop is detected
            is_in_loop, escape_loop_action = check_loop(
                assistant_history_for_loop_detect, action_str, random_state
            )

            # Check achievements
            unlocked = {
                name
                for name, count in info["achievements"].items()
                if count > 0 and name not in achievements
            }

            for name in unlocked:
                achievements |= unlocked
                total = len(info["achievements"].items())
                logging.info(f"Achievement ({len(achievements)}/{total}): {name}")

            # Update previous inventory state
            previous_inventory = info["inventory"]
            previous_achievement = info["achievements"]
            previous_player_pos = info["player_pos"]
            previous_info_semantic = info["semantic"]
            previous_player_facing = info["player_facing"]

        # Create GIF from frames
        gif_path = os.path.join(
            args.gif_dir, f"episode_{args.log_name}_{episode}_{run_timestamp}.gif"
        )
        imageio.mimsave(gif_path, frames, fps=5, loop=True, palettesize=64)
        logging.info(f"GIF saved to {gif_path}")

        # print total achievements
        total = len(info["achievements"].items())
        logging.info(f"Episode{episode} Achievements: {achievements}")
        logging.info(
            f"Episode{episode} Achievement ratio ({len(achievements)}/{total}) = {len(achievements)/total}"
        )
        total_achievements_num.append(len(achievements))

        # Update achievement counts
        achievement_counts.update(
            {a: achievement_counts.get(a, 0) + 1 for a in achievements}
        )

        logging.info(f"Episode{episode} length: {step}")

    # Print achievement counts and percentages across all episodes
    logging.info("\nAchievement statistics across all episodes:")
    for achievement, count in achievement_counts.items():
        percentage = (count / args.episodes) * 100
        logging.info(
            f"{achievement}: {count}/{args.episodes} episodes ({percentage:.1f}%)"
        )

    total = len(info["achievements"].items())
    logging.info(
        f"Average achievement ratio: {np.mean(total_achievements_num)/total:.3f} ± {np.std(total_achievements_num)/total:.3f}"
    )
    pygame.quit()


if __name__ == "__main__":
    # Script entry point: run main() only when this file is executed directly,
    # so importing the module (e.g. to reuse its helpers) has no side effects
    # beyond the module-level setup above.
    main()
