import argparse
import pathlib
import time

import numpy as np

import crafter

import os
import sys
import io
import cv2
import json
import time
import torch
import random
import datetime
import csv

# Make the repository root (the parent of this script's directory) importable,
# so the ``lib.`` and ``src.`` imports further down resolve when this file is
# executed directly as a script.
project_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
print(f"Project path: {project_path}")
if not any(entry == project_path for entry in sys.path):
    sys.path.append(project_path)

import traceback
import itertools

# import numpy as np
import regex as re
import matplotlib.pyplot as plt
from openai import OpenAI  # Add OpenAI import

from PIL import Image
from glob import glob
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from huggingface_hub import login

from collections import deque
from transformers import AutoTokenizer
from lib.crafter_custom.crafter.language_wrapper import CrafterLanguageWrapper

from pytz import timezone

import lib.crafter_custom.crafter.random_init_wrapper as random_init_wrapper
from random_init_wrapper import RandomInitWrapper

try:
    import pygame
except ImportError:
    print("Please install the pygame package to use the GUI.")
    raise
from PIL import Image
import imageio  # Add imageio import for GIF creation

import logging

import src.system_prompts as system_prompts


def load_model(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    lora_path=None,
    gpu_num="0",
):
    """Load a causal language model (optionally LoRA-adapted) and its tokenizer.

    Args:
        model_name: Hugging Face model id (or local path) of the base model.
        lora_path: Optional path to PEFT adapter weights that are loaded on top
            of the base model with ``PeftModel.from_pretrained``. ``None``
            keeps the plain base model.
        gpu_num: Unused; device placement is left to ``device_map="auto"``.
            Kept so existing callers keep working.

    Returns:
        ``(model, tokenizer)``: the model cast to ``torch.float16`` and a
        tokenizer that left-pads using its EOS token as the pad token.
    """
    pretrained_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
        "use_auth_token": True,
        "attn_implementation": "sdpa",
    }
    base_model = AutoModelForCausalLM.from_pretrained(model_name, **pretrained_kwargs)
    base_model.config.use_cache = False
    base_model.config.pretraining_tp = 1

    # With device_map="auto" the layers may be spread over several devices;
    # record where each module ended up.
    if hasattr(base_model, "hf_device_map"):
        logging.info("Model distribution across devices:")
        for module_name, device in base_model.hf_device_map.items():
            logging.info(f"{module_name}: {device}")

    if lora_path is not None:
        print(f"\n\n===loading from {lora_path}===\n\n")
        model = PeftModel.from_pretrained(base_model, lora_path)
    else:
        print("\n\n===loading base model===\n\n")
        model = base_model

    # NOTE(review): weights are loaded as bfloat16 above and then cast to
    # float16 here; fp16 has a narrower range than bf16 — confirm intended.
    model = model.to(torch.float16)

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = True

    return model, tokenizer


def create_history(obs, user_history, assistant_history, system_prompt):
    """Build the chat message list for the plain (no-reasoning) agent prompt.

    The system prompt comes first. Every earlier turn becomes a
    "Previous Observation" user message followed by the assistant's action.
    The final turn becomes a "Current Observation" user message that also
    carries the single-action output instruction.

    Args:
        obs: Unused; kept for signature compatibility with
            ``create_history_reasoning``.
        user_history: Indexable sequence of observation strings, oldest first.
        assistant_history: Indexable sequence of previous action strings.
        system_prompt: Content of the system message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    output_instruction = """Only return a single valid action name (e.g. move_left, place_table, make_wood_pickaxe) with no additional explanation or formatting.
Always output one action per turn until the episode ends."""
    turns = max(len(assistant_history), len(user_history))

    messages = [{"role": "system", "content": system_prompt}]
    for idx in range(turns - 1):
        messages.append(
            {"role": "user", "content": "### Previous Observation\n" + user_history[idx]}
        )
        messages.append({"role": "assistant", "content": assistant_history[idx]})

    if turns > 0:
        latest_observation = user_history[turns - 1]
        messages.append(
            {
                "role": "user",
                "content": "### Current Observation\n"
                + latest_observation
                + "\n\n"
                + output_instruction,
            }
        )
    return messages


def create_history_reasoning(
    obs, user_history, assistant_history, system_prompt, wrong_action=False
):
    """Build chat messages that ask the model to reason before acting.

    Earlier turns become "Observation:" user messages, each followed by the
    assistant's previous reply. The final turn is a "Current Observation:"
    message. It combines ``obs["text"]["short_term_context"]``, the latest
    user-history entry and the reasoning/action output-format instructions.

    Args:
        obs: Current observation; only ``obs["text"]["short_term_context"]``
            is read, and only when there is at least one turn.
        user_history: Indexable sequence of observation strings, oldest first.
        assistant_history: Indexable sequence of previous assistant replies.
        system_prompt: Content of the system message.
        wrong_action: If true, append an instruction asking the model to
            output a wrong action for this step.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    reasoning_instructions = """You are an AI agent following a step-by-step reasoning approach to decide on the best action to take at each step.
You must explicitly state your reasoning before providing an action.
Your output should be formatted as:
reasoning : [your step-by-step thought process]
action : [one action from the provided action list]
You always have to output one of the above actions at a time and no other text.
You always have to output an action until the episode terminates."""
    wrong_action_instruction = "\n\nYou must output a wrong action at this step. You always have to output one of the above actions at a time and no other text"

    turns = max(len(assistant_history), len(user_history))
    messages = [{"role": "system", "content": system_prompt}]

    for idx in range(turns - 1):
        messages.append({"role": "user", "content": "Observation:\n" + user_history[idx]})
        messages.append({"role": "assistant", "content": assistant_history[idx]})

    if turns > 0:
        pieces = [
            "Current Observation:\n",
            obs["text"]["short_term_context"],
            "\n",
            user_history[turns - 1],
            "\n\n",
            reasoning_instructions,
        ]
        if wrong_action:
            pieces.append(wrong_action_instruction)
        messages.append({"role": "user", "content": "".join(pieces)})

    return messages


def create_prompt_without_special_tokens(history):
    """Render chat messages as plain ``Role: content`` lines.

    Only the ``system``, ``user`` and ``assistant`` roles are emitted; messages
    with any other role are skipped. Each message ends with a newline.
    """
    role_labels = {"system": "System", "user": "User", "assistant": "Assistant"}
    return "".join(
        f"{role_labels[message['role']]}: {message['content']}\n"
        for message in history
        if message["role"] in role_labels
    )


def generate_prompt_and_tokenize_llama(history, tokenizer):
    """Render ``history`` with the tokenizer's chat template and tokenize it.

    For Llama-3 style templates (recognised by the ``<|start_header_id|>``
    marker) the bare assistant header is appended exactly as before. Other
    templates, such as the Qwen models that ``main`` also routes here, use
    their own assistant marker. For those the template is re-rendered with
    ``add_generation_prompt=True``, so the model is not cued with a foreign
    Llama header.

    Args:
        history: List of ``{"role", "content"}`` chat messages.
        tokenizer: A Hugging Face tokenizer that provides
            ``apply_chat_template``.

    Returns:
        ``(model_inputs, input_ids)``: the batch encoding (moved to ``"cuda"``)
        and its ``input_ids``.
    """
    llama3_assistant_header = "<|start_header_id|>assistant<|end_header_id|>"

    prompt = tokenizer.apply_chat_template(history, tokenize=False)
    if "<|start_header_id|>" in prompt:
        prompt += llama3_assistant_header
    else:
        prompt = tokenizer.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True
        )

    # NOTE(review): the rendered prompt may already contain the BOS token and
    # the tokenizer may add another one; kept as-is so prompts match earlier
    # runs / fine-tuning data — confirm.
    model_inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        max_length=2**13,  # 8K
        padding=True,
        truncation=True,
    ).to("cuda")
    tokenized_chat = model_inputs.input_ids
    return model_inputs, tokenized_chat


def generate_response_llama(model, tokenizer, history):
    """Greedily generate the model's reply to ``history``.

    Args:
        model: A Hugging Face causal LM with ``generate``.
        tokenizer: The matching tokenizer.
        history: List of ``{"role", "content"}`` chat messages.

    Returns:
        The newly generated text (prompt tokens removed, special tokens
        skipped), lower-cased and stripped of surrounding whitespace.
    """
    model_inputs, tokenized_chat = generate_prompt_and_tokenize_llama(
        history, tokenizer
    )

    # ``torch.cuda.amp.autocast()`` is deprecated; ``torch.autocast("cuda")``
    # is the documented equivalent (same default float16 autocast dtype).
    with torch.autocast(device_type="cuda"):
        generated_ids = model.generate(
            tokenized_chat,
            do_sample=False,  # greedy decoding
            max_new_tokens=256,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=model_inputs["attention_mask"],
        )

    # ``generate`` echoes the prompt; keep only the continuation tokens.
    completions = [
        output_ids[len(input_ids) :]
        for input_ids, output_ids in zip(tokenized_chat, generated_ids)
    ]
    raw_response = tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
    return raw_response.lower().strip()


def get_response_openai(history, model_name="gpt-4o"):
    """Query an OpenAI chat model and return its trimmed reply.

    Token usage of the request is printed to stdout.

    Args:
        history: Chat messages, each a dict with ``role`` and ``content`` keys.
        model_name: OpenAI model identifier (defaults to ``"gpt-4o"``).

    Returns:
        The reply text with surrounding whitespace removed, or ``"noop"`` if
        the API call or reading its result raises any exception.

    Raises:
        ValueError: If the ``OPENAI_API_KEY`` environment variable is unset or
            empty. This is raised before any request is attempted.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OPENAI_API_KEY environment variable is not set. Please set your OpenAI API key."
        )

    try:
        completion = OpenAI(api_key=api_key).chat.completions.create(
            model=model_name,
            messages=history,
            temperature=0,
            max_tokens=256,
        )
        usage = completion.usage
        print(f"Input tokens: {usage.prompt_tokens}")
        print(f"Output tokens: {usage.completion_tokens}")
        print(f"Total tokens used: {usage.total_tokens}")
        return completion.choices[0].message.content.strip()
    except Exception as error:
        # Best-effort: any API failure degrades to the no-op action.
        print(f"Error calling OpenAI API: {error}")
        return "noop"


def filter_action(raw_response, action_str):
    """Map a free-form model reply onto an index into ``action_str``.

    Resolution order:
      1. exact match of ``raw_response``;
      2. exact match after lower-casing and trimming whitespace, quotes,
         backticks, markdown emphasis and trailing punctuation
         (e.g. ``"Sleep."`` -> ``"sleep"``);
      3. keyword heuristics for ``move`` / ``make`` / ``place`` replies;
      4. a bare decimal integer that is a valid index.
    Anything else falls back to index 0 (``noop`` in the default list).

    Args:
        raw_response: Text produced by the model. OpenAI replies are not
            lower-cased by the caller, so case is normalised here.
        action_str: Ordered list of action names. The heuristics look up the
            canonical Crafter names (``move_left``, ``make_wood_pickaxe``, ...).

    Returns:
        ``(action_index, action_name)``.
    """
    if raw_response in action_str:
        action = action_str.index(raw_response)
        return action, action_str[action]

    text = raw_response.strip().lower().strip(" \t\r\n`'\".,:;!*")
    action = 0  # default: noop

    if text in action_str:
        action = action_str.index(text)
    elif "move" in text:
        if "left" in text or "west" in text:
            action = action_str.index("move_left")
        elif "right" in text or "east" in text:
            action = action_str.index("move_right")
        elif "up" in text or "north" in text:
            action = action_str.index("move_up")
        elif "down" in text or "south" in text:
            action = action_str.index("move_down")
    elif "make" in text:
        # Check the higher-tier materials first; a pickaxe with no (or wood)
        # material defaults to the wood pickaxe.
        if "pick" in text:
            if "iron" in text:
                action = action_str.index("make_iron_pickaxe")
            elif "stone" in text:
                action = action_str.index("make_stone_pickaxe")
            else:
                action = action_str.index("make_wood_pickaxe")
        elif "sword" in text:
            if "iron" in text:
                action = action_str.index("make_iron_sword")
            elif "stone" in text:
                action = action_str.index("make_stone_sword")
            elif "wood" in text:
                action = action_str.index("make_wood_sword")
    elif "place" in text:
        if "stone" in text:
            action = action_str.index("place_stone")
        elif "table" in text:
            action = action_str.index("place_table")
        elif "furnace" in text:
            action = action_str.index("place_furnace")
        elif "plant" in text:
            action = action_str.index("place_plant")
    elif text.isdecimal():
        # Only accept indices that exist; anything else stays noop instead of
        # raising IndexError.
        index = int(text)
        if index < len(action_str):
            action = index

    return action, action_str[action]


def extract_reasoning_and_action(raw_response):
    """Split a ``reasoning : ... action : ...`` reply into its two parts.

    For well-formed replies the result matches a split on the markers. The
    action is the text between the first and any second ``action :``. The
    reasoning is the text between the first and any second ``reasoning :``
    that precedes the first ``action :``.

    Missing markers no longer raise IndexError:
      * no ``action :`` -> ``("", raw_response.strip())``, so the caller can
        still fuzzy-match the whole reply as an action;
      * no ``reasoning :`` -> everything before ``action :`` is the reasoning.

    Returns:
        ``(reasoning, action)``, both stripped of surrounding whitespace.
    """
    before_action, action_marker, after_action = raw_response.partition("action :")
    if not action_marker:
        return "", raw_response.strip()

    action = after_action.split("action :")[0].strip()

    head, reasoning_marker, after_reasoning = before_action.partition("reasoning :")
    if reasoning_marker:
        reasoning = after_reasoning.split("reasoning :")[0].strip()
    else:
        reasoning = head.strip()
    return reasoning, action


def check_action_validity_pre_step(action, info):
    """Decide, before stepping the env, whether ``action`` can succeed.

    An empty or falsy action is always treated as valid. Otherwise the action
    is invalid when the player is sleeping, or when it is a placement whose
    resource requirement is not met:
    ``place_table`` needs >= 2 wood, ``place_stone`` >= 1 stone and
    ``place_furnace`` >= 4 stone.

    Args:
        action (str): Name of the action about to be taken.
        info (dict): Current game-state info; ``info["sleeping"]`` is always
            read, and ``info["inventory"]`` only for the placements above.

    Returns:
        bool: True if the action is considered valid, False otherwise.
    """
    if not action:
        return True

    # action name -> (inventory item, minimum count required)
    placement_requirements = {
        "place_table": ("wood", 2),
        "place_stone": ("stone", 1),
        "place_furnace": ("stone", 4),
    }

    violations = [bool(info["sleeping"])]
    requirement = placement_requirements.get(action)
    if requirement is not None:
        item, minimum = requirement
        violations.append(info["inventory"][item] < minimum)

    is_valid = not any(violations)
    if not is_valid:
        logging.info("Pre step invalid")
    return is_valid


def _facing_object_name(semantic, player_pos, player_facing, id_to_item):
    """Return the name of the map cell in front of the player, or "" if unknown.

    "" is returned when any input is missing or the faced cell lies outside
    the semantic map. A negative index would otherwise silently wrap to the
    opposite map edge, and an index past the end would raise IndexError.
    """
    if semantic is None or player_pos is None or player_facing is None:
        return ""
    row = int(player_pos[0] + player_facing[0])
    col = int(player_pos[1] + player_facing[1])
    n_rows, n_cols = np.shape(semantic)[:2]
    if not (0 <= row < n_rows and 0 <= col < n_cols):
        return ""
    return id_to_item[semantic[row, col]]


def check_action_validity_post_step(
    action,
    info_inventory,
    previous_info_inventory,
    previous_player_pos,
    current_player_pos,
    previous_player_facing,
    current_player_facing,
    previous_info_semantic,
    id_to_item,
):
    """Decide, after stepping the env, whether ``action`` had an effect.

    Rules (an empty/falsy action is always valid):
      * ``make*`` / ``do``: invalid unless some inventory entry other than
        ``health`` increased. ``health`` is excluded because it can change
        on its own. A ``do`` aimed at a cow, zombie or skeleton is still
        valid without any gain.
      * ``place*``: invalid unless some non-status entry (anything except
        health/food/drink/energy) decreased, i.e. a resource was consumed.
      * ``move*``: invalid if neither position nor facing changed.

    Args:
        action (str): The action taken.
        info_inventory (dict): Inventory after the step.
        previous_info_inventory (dict): Inventory before the step.
        previous_player_pos (ndarray | None): Player position before the step.
        current_player_pos (ndarray | None): Player position after the step.
        previous_player_facing (ndarray | None): Facing direction before.
        current_player_facing (ndarray | None): Facing direction after.
        previous_info_semantic (ndarray | None): Semantic map before the step,
            indexed as ``[row, col]`` with the same axes as the position.
        id_to_item (Sequence): Maps semantic ids to item/material names.

    Returns:
        bool: True if the action is considered valid, False otherwise.
    """
    if not action:
        return True

    lowered = action.lower()
    is_valid = True

    if "make" in lowered or lowered == "do":
        gained_something = any(
            key != "health"
            and key in previous_info_inventory
            and info_inventory[key] > previous_info_inventory[key]
            for key in info_inventory
        )
        if not gained_something:
            is_valid = False

        # The faced cell only matters for "do", so it is not looked up for
        # "make" actions (avoids needless lookups at the map edge).
        if lowered == "do":
            facing_object = _facing_object_name(
                previous_info_semantic,
                previous_player_pos,
                previous_player_facing,
                id_to_item,
            )
            if facing_object in ("cow", "zombie", "skeleton"):
                is_valid = True

    if "place" in lowered:
        consumed_something = any(
            key not in ("health", "food", "drink", "energy")
            and key in previous_info_inventory
            and info_inventory[key] < previous_info_inventory[key]
            for key in info_inventory
        )
        if not consumed_something:
            is_valid = False

    if (
        "move" in lowered
        and previous_player_pos is not None
        and current_player_pos is not None
    ):
        if np.array_equal(previous_player_pos, current_player_pos) and np.array_equal(
            previous_player_facing, current_player_facing
        ):
            is_valid = False

    return is_valid


def post_step_update_csv(
    env, step, is_valid, post_step_is_valid, status_change, achievement_change
):
    """Patch the most recently written CSV row with post-step results.

    The log file (``env.csv_file``, opened by ``save_data_csv``) is closed and
    read back in full. The last row's ``status_change`` and
    ``achievement_change`` are filled in, and its ``valid`` column is flipped
    to ``post_step_is_valid`` when the pre-step check passed but the
    post-step check failed. The file is then rewritten. Finally the file is
    reopened in append mode and ``env.csv_file`` / ``env.csv_writer`` are
    replaced, so later writes continue to work. This also happens if the
    rewrite fails, and when the file has no data rows yet.

    Args:
        env: Object carrying ``csv_file`` / ``csv_writer``; a no-op if it has
            no ``csv_file`` attribute.
        step: Unused; kept for interface compatibility.
        is_valid: Result of the pre-step validity check.
        post_step_is_valid: Result of the post-step validity check.
        status_change: Value written to the ``status_change`` column.
        achievement_change: Value written to the ``achievement_change`` column.
    """
    if not hasattr(env, "csv_file"):
        return

    filename = env.csv_file.name
    # Fallback columns in case the file has no header yet.
    previous_writer = getattr(env, "csv_writer", None)
    fieldnames = list(getattr(previous_writer, "fieldnames", None) or [])

    # Close the append handle so the full contents are on disk before reading.
    env.csv_file.close()
    try:
        with open(filename, "r", newline="") as file:
            reader = csv.DictReader(file)
            rows = list(reader)
            if reader.fieldnames:
                fieldnames = list(reader.fieldnames)

        if rows:
            last_row = rows[-1]
            if is_valid and not post_step_is_valid:
                last_row["valid"] = post_step_is_valid
            last_row["status_change"] = status_change
            last_row["achievement_change"] = achievement_change

            with open(filename, "w", newline="") as file:
                writer = csv.DictWriter(file, fieldnames=fieldnames)
                writer.writeheader()
                writer.writerows(rows)
    finally:
        # Always restore an append-mode handle for subsequent save_data_csv calls.
        env.csv_file = open(filename, "a", newline="")
        env.csv_writer = csv.DictWriter(env.csv_file, fieldnames=fieldnames)


def save_data_csv(
    env,
    step,
    history,
    raw_response,
    reasoning,
    action,
    run_timestamp,
    log_name,
    status_change,
    observation_goal,
    player_pos,
    player_facing,
    action_goal,
    is_valid,
    is_sleeping,
):
    """Append one step record to this run's CSV log and flush it.

    On first use the file ``log/data_collection/csv/<log_name>_<run_timestamp>.csv``
    is opened in append mode. Its directory is created if needed. The handle
    and a ``csv.DictWriter`` are stored on ``env`` as ``csv_file`` /
    ``csv_writer``, and a header is written if the file is empty.

    ``status_change`` and ``achievement_change`` are written as the
    placeholder string ``"None"``; ``post_step_update_csv`` fills them in once
    the step has been executed. The ``status_change`` argument itself is not
    written.

    Returns:
        The same ``env`` object.
    """
    record = dict(
        step=step,
        history=history,
        raw_response=raw_response,
        action=action,
        reasoning=reasoning,
        observation_goal=observation_goal,
        action_goal=action_goal,
        player_pos=player_pos,
        player_facing=player_facing,
        status_change="None",  # filled in after the post-step check
        achievement_change="None",  # filled in after the post-step check
        valid=is_valid,
        is_sleeping=is_sleeping,
    )

    if not hasattr(env, "csv_file"):
        csv_path = f"log/data_collection/csv/{log_name}_{run_timestamp}.csv"
        os.makedirs(os.path.dirname(csv_path), exist_ok=True)

        env.csv_file = open(csv_path, "a", newline="")
        env.csv_writer = csv.DictWriter(env.csv_file, fieldnames=list(record))
        # Only a brand-new (empty) file needs the header row.
        if env.csv_file.tell() == 0:
            env.csv_writer.writeheader()

    env.csv_writer.writerow(record)
    env.csv_file.flush()
    return env


def _parse_bool_flag(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true, so the value is parsed explicitly here.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Run the LLM agent in Crafter and log per-step data, GIFs and stats.

    For each episode the loop runs until the env reports ``done`` or
    ``--max_steps`` steps have been taken. Each step it:
      * renders the frame and shows it in the pygame window;
      * builds the chat prompt and queries the model (local Llama/Qwen or
        OpenAI GPT);
      * maps the reply to an action;
      * logs the step to CSV, steps the env, and records the post-step
        validity and status/achievement changes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--area", nargs=2, type=int, default=(64, 64))
    parser.add_argument("--length", type=int, default=10000)
    parser.add_argument("--health", type=int, default=9)
    parser.add_argument("--record", type=pathlib.Path, default=None)
    parser.add_argument("--episodes", type=int, default=1)
    parser.add_argument("--size", type=int, nargs=2, default=(0, 0))
    parser.add_argument("--window", type=int, nargs=2, default=(600, 600))
    parser.add_argument(
        "--gif_dir", type=str, default="collect_demo/gifs"
    )  # Directory for per-episode GIFs
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="gpt",
        choices=list(system_prompts.SYSTEM_PROMPTS.keys()),
        help="System prompt to use for the agent",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="gpt-4o",  # meta-llama/Llama-3.2-3B-Instruct or meta-llama/Llama-3.2-11B-Vision-Instruct
        help="Model name to use for inference",
    )
    parser.add_argument(
        "--lora_path", type=str, default=None, help="Path to LoRA weights"
    )
    parser.add_argument(
        "--log_name",
        type=str,
        default=None,
        help="Log file name. Defaults to current date and time if not provided.",
    )
    parser.add_argument(
        "--random_init",
        type=_parse_bool_flag,
        nargs="?",
        const=True,  # a bare `--random_init` means enabled
        default=False,
        help="Enable random initialization of the environment.",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=100,
        help="Maximum number of agent steps per episode.",
    )

    args = parser.parse_args()

    model_name_lower = args.model_name.lower()
    use_local_model = "llama" in model_name_lower or "qwen" in model_name_lower
    # Fail fast instead of hitting an undefined `raw_response` mid-episode.
    if not use_local_model and "gpt" not in model_name_lower:
        parser.error(
            f"unsupported --model_name {args.model_name!r}: expected a Llama/Qwen or GPT model"
        )

    # Create a single timestamp for the entire run (also the default log name).
    run_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if args.log_name is None:
        args.log_name = run_timestamp

    # Create output directories if they don't exist
    os.makedirs(args.gif_dir, exist_ok=True)
    os.makedirs("collect_demo/logs", exist_ok=True)
    os.makedirs("history_goal", exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler("collect_demo/logs/" + args.log_name),  # Log to a file
            logging.StreamHandler(),  # Also log to console
        ],
    )

    # Build a semantic-id -> name table from a throwaway env's private maps.
    id_to_item = [0] * 19
    dummyenv = crafter.Env()
    for name, ind in itertools.chain(
        dummyenv._world._mat_ids.items(), dummyenv._sem_view._obj_ids.items()
    ):
        name = (
            str(name)[str(name).find("objects.") + len("objects.") : -2].lower()
            if "objects." in str(name)
            else str(name)
        )
        id_to_item[ind] = name

    system_prompt = system_prompts.SYSTEM_PROMPTS[args.system_prompt]
    logging.info(f"Using system prompt: {args.system_prompt}")

    crafter.constants.items["health"]["max"] = args.health
    crafter.constants.items["health"]["initial"] = args.health
    env = crafter.Env(area=args.area, length=args.length, seed=args.seed)
    env = crafter.Recorder(env, args.record)
    if args.random_init:
        env = RandomInitWrapper(env)

    env = CrafterLanguageWrapper(env)

    # unwrapping for private attributes
    env._world = env.unwrapped._world

    action_str = [
        "noop",
        "move_left",
        "move_right",
        "move_up",
        "move_down",
        "do",
        "sleep",
        "place_stone",
        "place_table",
        "place_furnace",
        "place_plant",
        "make_wood_pickaxe",
        "make_stone_pickaxe",
        "make_iron_pickaxe",
        "make_wood_sword",
        "make_stone_sword",
        "make_iron_sword",
    ]

    pygame.init()
    screen = pygame.display.set_mode(args.window)

    size = list(args.size)
    size[0] = size[0] or args.window[0]
    size[1] = size[1] or args.window[1]
    window = tuple(args.window)

    # Load the local model once. With --lora_path this is the LoRA-adapted
    # model; the previous code loaded the same adapter model twice.
    if use_local_model:
        model, tokenizer = load_model(
            model_name=args.model_name, lora_path=args.lora_path
        )

    total_achievements_num = []  # number of achievements per episode
    achievement_counts = {}  # achievement -> number of episodes unlocking it
    status_changes = []
    achievement_changes = []
    info = None

    for episode in range(args.episodes):
        frames = []  # frames for this episode's GIF

        obs, reward, done, info = env.reset()
        done = False

        if args.random_init:
            logging.info("Random initialization enabled, initial inventory:")
            for item, value in info["inventory"].items():
                logging.info(f"{item}: {value}")

        # Short prompt-history windows: last 2 observations, last action.
        user_history = deque(maxlen=2)
        assistant_history = deque(maxlen=1)

        status_change = []
        previous_inventory = info["inventory"]
        previous_player_pos = info["player_pos"]
        previous_player_facing = info["player_facing"]
        previous_achievement = info["achievements"]
        previous_info_semantic = info["semantic"]

        step = 0
        achievements = set()
        while step < args.max_steps:
            # Keep the window responsive: pygame needs its event queue serviced.
            pygame.event.pump()

            image = env.render(size)
            if tuple(size) != window:
                image = Image.fromarray(image)
                image = image.resize(window, resample=Image.NEAREST)
                image = np.array(image)

            frames.append(image.copy())

            if done:
                break

            surface = pygame.surfarray.make_surface(image.transpose((1, 0, 2)))
            screen.blit(surface, (0, 0))
            pygame.display.flip()

            user_history.append(
                obs["text"]["short_term_context"]
                + "\n\n"
                + obs["text"]["long_term_context"]
            )
            user_history_goal = (
                obs["text"]["long_term_context"]
                + "\n"
                + obs["text"]["short_term_context"]
            )

            ##### 1. build the prompt from the history windows #####
            history = create_history(
                obs, user_history, assistant_history, system_prompt
            )

            ##### 2. query the model #####
            if use_local_model:
                raw_response = generate_response_llama(model, tokenizer, history)
            else:
                try:
                    raw_response = get_response_openai(
                        history, model_name=args.model_name
                    )
                except ValueError as e:
                    logging.error(f"Error calling OpenAI API: {e}")
                    logging.info("Falling back to noop action")
                    raw_response = "noop"

            logging.info(f"step{step} raw response : {raw_response}  |  ")

            reasoning = ""  # the plain prompt does not request reasoning

            ##### 3. map the reply to an action #####
            action, text_action = filter_action(raw_response, action_str)
            logging.info(f"action : {action_str[action]}")

            assistant_history.append(action_str[action])
            assistant_history_goal = action_str[action]

            is_valid = (
                True if step == 0 else check_action_validity_pre_step(text_action, info)
            )

            ##### 4. log the step (changes are patched in after stepping) #####
            save_data_csv(
                env,
                step,
                history,
                raw_response,
                reasoning,
                text_action,
                run_timestamp,
                args.log_name,
                status_change,
                user_history_goal,
                info["player_pos"],
                info["player_facing"],
                assistant_history_goal,
                is_valid,
                info["sleeping"],
            )

            ##### 5. environment step #####
            obs, reward, done, info = env.step(action)
            step += 1

            ##### 6. inventory/status change (health excluded) #####
            current_inventory = info["inventory"]
            status_change_dict = (
                {
                    key: current_inventory[key] - previous_inventory.get(key, 0)
                    for key in current_inventory
                }
                if previous_inventory
                else {}
            )
            status_change = [
                key
                for key, change in status_change_dict.items()
                if (key != "health" and change > 0)
                or (key not in ["health", "food", "drink", "energy"] and change < 0)
            ]
            status_changes.append(status_change if status_change else None)

            ##### 7. achievement change #####
            current_achievement = info["achievements"]
            achievement_change_dict = (
                {
                    key: current_achievement[key] - previous_achievement.get(key, 0)
                    for key in current_achievement
                }
                if previous_achievement
                else {}
            )
            achievement_change = [
                key for key, change in achievement_change_dict.items() if change > 0
            ]
            achievement_changes.append(
                achievement_change if achievement_change else None
            )

            ##### 8. post-step validity, then patch the CSV row #####
            post_step_is_valid = check_action_validity_post_step(
                text_action,
                info["inventory"],
                previous_inventory,
                previous_player_pos,
                info["player_pos"],
                previous_player_facing,
                info["player_facing"],
                previous_info_semantic,
                id_to_item,
            )
            post_step_update_csv(
                env,
                step,
                is_valid,
                post_step_is_valid,
                status_change,
                achievement_change,
            )

            # Log newly unlocked achievements with a running count.
            total = len(info["achievements"])
            unlocked = {
                name
                for name, count in info["achievements"].items()
                if count > 0 and name not in achievements
            }
            for name in sorted(unlocked):
                achievements.add(name)
                logging.info(f"Achievement ({len(achievements)}/{total}): {name}")

            previous_inventory = info["inventory"]
            previous_achievement = info["achievements"]
            previous_player_pos = info["player_pos"]
            previous_info_semantic = info["semantic"]
            previous_player_facing = info["player_facing"]

        # loop=0 means "loop forever" for GIF writers (loop=True == 1 plays once more).
        gif_path = os.path.join(args.gif_dir, f"episode_{run_timestamp}_{episode}.gif")
        imageio.mimsave(gif_path, frames, fps=5, loop=0, palettesize=64)
        logging.info(f"GIF saved to {gif_path}")

        total = len(info["achievements"])
        logging.info(f"Episode{episode} Achievements: {achievements}")
        logging.info(
            f"Episode{episode} Achievement ratio ({len(achievements)}/{total}) = {len(achievements)/total}"
        )
        total_achievements_num.append(len(achievements))

        achievement_counts.update(
            {a: achievement_counts.get(a, 0) + 1 for a in achievements}
        )

        logging.info(f"Episode{episode} length: {step}")

    # Close the CSV handle opened lazily by save_data_csv.
    if hasattr(env, "csv_file"):
        env.csv_file.close()

    logging.info("\nAchievement statistics across all episodes:")
    for achievement, count in achievement_counts.items():
        percentage = (count / args.episodes) * 100
        logging.info(
            f"{achievement}: {count}/{args.episodes} episodes ({percentage:.1f}%)"
        )

    # Guard against --episodes 0 (no `info`, empty statistics).
    if total_achievements_num and info is not None:
        total = len(info["achievements"])
        logging.info(
            f"Average achievement ratio: {np.mean(total_achievements_num)/total:.3f} ± {np.std(total_achievements_num)/total:.3f}"
        )
    time.sleep(1)  # give the GIF writer a moment before shutting down
    pygame.quit()


if __name__ == "__main__":
    main()
