import argparse
import pathlib
import time
import multiprocessing as mp
from functools import partial

import numpy as np

import crafter

import os
import sys
import io
import cv2
import json
import time
import torch
import random
import datetime
import csv

# Add the project directory to the Python path
project_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
print(f"Project path: {project_path}")
if project_path not in sys.path:
    sys.path.append(project_path)

import traceback
import itertools

# import numpy as np
import regex as re
import matplotlib.pyplot as plt

from PIL import Image
from glob import glob
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from huggingface_hub import login

from collections import deque
from transformers import AutoTokenizer
from lib.crafter_custom.crafter.language_wrapper import CrafterLanguageWrapper

from pytz import timezone

import lib.crafter_custom.crafter.random_init_wrapper as random_init_wrapper
from random_init_wrapper import RandomInitWrapper

import lib.crafter_custom.crafter.eval_wrapper as eval_wrapper
from eval_wrapper import EvalWrapper

try:
    import pygame
except ImportError:
    print("Please install the pygame package to use the GUI.")
    raise
from PIL import Image
import imageio  # Add imageio import for GIF creation

import logging

from openai import OpenAI  # Add OpenAI import
import anthropic
import google.generativeai as genai

import src.system_prompts as system_prompts


def load_model(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    lora_path=None,
):
    """Load a causal language model and its tokenizer.

    The base model is loaded in float16 with ``device_map="auto"`` and SDPA
    attention. When ``lora_path`` is given, the base model is wrapped with the
    LoRA adapter stored at that path.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        lora_path: Optional path to LoRA weights; ``None`` returns the bare
            base model.

    Returns:
        tuple: ``(model, tokenizer)``. The tokenizer pads on the left and
        reuses the EOS token as the padding token.
    """
    backbone = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
        use_auth_token=True,
        # attn_implementation="flash_attention",
        attn_implementation="sdpa",
    )

    # Report where each sub-module was placed when a device map is available.
    if hasattr(backbone, "hf_device_map"):
        print("Model distribution across devices:")
        for module_name, device in backbone.hf_device_map.items():
            print(f"{module_name}: {device}")

    backbone.config.use_cache = False
    backbone.config.pretraining_tp = 1

    # Optionally attach the LoRA checkpoint on top of the base weights.
    if lora_path is not None:
        print(f"\n\n===loading from {lora_path}===\n\n")
        model = PeftModel.from_pretrained(backbone, lora_path)
    else:
        print("\n\n===loading base model===\n\n")
        model = backbone

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.add_eos_token = True

    return model, tokenizer


def create_history_with_instruction(
    obs, user_history, assistant_history, system_prompt, instruction, post_step_is_valid
):
    """Build the chat message list for the action-only prompting mode.

    Args:
        obs: Current observation (not used by this function).
        user_history: Sequence of observation texts, oldest first.
        assistant_history: Sequence of past action texts, oldest first.
        system_prompt: Base system prompt; ``instruction`` is appended to it.
        instruction: Task description appended to the system prompt.
        post_step_is_valid: When falsy, a hint that the previous action was
            invalid is appended to the final user message.

    Returns:
        list[dict]: Messages with ``role``/``content`` keys. Every turn except
        the last contributes a user/assistant pair; the last turn contributes
        only the user message holding the current observation.
    """
    messages = [{"role": "system", "content": system_prompt + instruction}]

    turn_count = max(len(assistant_history), len(user_history))
    final_turn = turn_count - 1

    for turn in range(turn_count):
        if turn != final_turn:
            # Completed turn: the observation followed by the action taken.
            messages.append(
                {
                    "role": "user",
                    "content": "### Previous Observation\n" + user_history[turn],
                }
            )
            messages.append({"role": "assistant", "content": assistant_history[turn]})
            continue

        if post_step_is_valid:
            retry_hint = ""
        else:
            retry_hint = "\n\nThe previous action was invalid. Try a different one."

        answer_format = (
            "Only return a single valid action name (e.g. move_left, place_table, make_wood_pickaxe) with no additional explanation or formatting.\n"
            "Always output one action per turn until the episode ends."
        )
        current_message = "".join(
            ["### Current Observation\n", user_history[turn], "\n\n", answer_format, retry_hint]
        )
        messages.append({"role": "user", "content": current_message})

    return messages


def create_history_reasoning_with_instruction(
    obs, user_history, assistant_history, system_prompt, instruction
):
    """Build the chat message list for the reasoning-then-action prompting mode.

    The first line of ``system_prompt`` that starts with ``"Your goal:"`` is
    replaced by ``"Your goal : <instruction>"``. The modified system prompt is
    also printed to stdout.

    Args:
        obs: Current observation (not used by this function).
        user_history: Sequence of observation texts, oldest first.
        assistant_history: Sequence of past assistant outputs, oldest first.
        system_prompt: System prompt containing a ``"Your goal:"`` line.
        instruction: Goal text inserted verbatim into the system prompt.

    Returns:
        list[dict]: Messages with ``role``/``content`` keys. Past turns are
        labelled ``Observation_<k>`` / ``Action_<k>`` where ``k`` is the
        (non-positive) offset relative to the current turn.
    """
    ### TODO : fix history structure
    goal_line = f"Your goal : {instruction}"
    # A callable replacement inserts the goal literally. With a plain string,
    # re.sub would interpret backslashes in the instruction as escapes or
    # group references and either raise or corrupt the text.
    modified_system_prompt = re.sub(
        r"Your goal:.*", lambda _match: goal_line, system_prompt, count=1
    )

    history = [{"role": "system", "content": modified_system_prompt}]

    print("modified_system_prompt : \n", modified_system_prompt)

    response_format = """You are an AI agent following a step-by-step reasoning approach to decide on the best action to take at each step.
You must explicitly state your reasoning before providing an action.

Your output should be formatted as:
reasoning : [your step-by-step thought process] // You explicitly state your reasoning before providing an action.
action : [one action from the provided action list] // You always have to output one of the above actions at a time and no other text.
"""

    turn_count = max(len(assistant_history), len(user_history))
    for i in range(turn_count):
        if i == turn_count - 1:
            # Current turn: observation plus the required output format.
            history.append(
                {
                    "role": "user",
                    "content": "### Current observation\n"
                    + user_history[i]
                    + response_format,
                }
            )
        else:
            # Past turn: number it relative to the current observation.
            offset = i - len(user_history) + 1
            history.append(
                {
                    "role": "user",
                    "content": f"### Observation_{offset}\n" + user_history[i],
                }
            )
            history.append(
                {
                    "role": "assistant",
                    "content": f"### Action_{offset}\n" + assistant_history[i],
                }
            )
    return history


def generate_prompt_and_tokenize_llama(history, tokenizer):
    """Render ``history`` with the chat template and tokenize it for generation.

    Args:
        history: List of chat messages with ``role``/``content`` keys.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.

    Returns:
        tuple: ``(model_inputs, input_ids)`` where ``model_inputs`` is the
        tokenizer output moved to ``"cuda"`` and ``input_ids`` is its
        ``input_ids`` attribute.
    """
    # NOTE(review): this header looks Llama-3 specific, yet Qwen model names are
    # routed through this function as well — confirm it suits those templates.
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>"

    # Render the conversation as text (with special tokens) and open the
    # assistant turn so the model continues from there.
    rendered = tokenizer.apply_chat_template(history, tokenize=False)
    prompt = rendered + assistant_header

    tokenizer_options = {
        "return_tensors": "pt",
        "max_length": 2**13,  # 8K
        "padding": True,
        "truncation": True,
    }
    model_inputs = tokenizer([prompt], **tokenizer_options).to("cuda")
    return model_inputs, model_inputs.input_ids


def generate_response_llama(model, tokenizer, history):
    """Greedily generate one reply for ``history`` with a local HF model.

    Args:
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        history: List of chat messages with ``role``/``content`` keys.

    Returns:
        str: The newly generated text (prompt tokens removed), lower-cased
        and stripped of surrounding whitespace.
    """
    model_inputs, prompt_ids = generate_prompt_and_tokenize_llama(
        history, tokenizer
    )
    mask = model_inputs["attention_mask"]

    with torch.cuda.amp.autocast():
        sequences = model.generate(
            prompt_ids,
            do_sample=False,
            max_new_tokens=256,
            # temperature=0.1,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=mask,
        )

        # Keep only the continuation: drop the prompt prefix of each sequence.
        completions = []
        for prompt_tokens, full_tokens in zip(prompt_ids, sequences):
            completions.append(full_tokens[len(prompt_tokens) :])

        decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
        reply = decoded[0].lower().strip()
    return reply


def get_response_openai(history, model_name="gpt-4o"):
    """Query an OpenAI chat model and return its reply text.

    Args:
        history: List of dictionaries with 'role' and 'content' keys.
        model_name: Name of the OpenAI chat model to call.

    Returns:
        The stripped reply text, or ``"noop"`` if anything goes wrong while
        calling the API or reading the response.

    Raises:
        ValueError: If the ``OPENAI_API_KEY`` environment variable is unset
            or empty.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        raise ValueError(
            "OPENAI_API_KEY environment variable is not set. Please set your OpenAI API key."
        )

    request = {
        "model": model_name,
        "messages": history,
        "temperature": 0,
        "max_tokens": 256,
    }

    try:
        completion = OpenAI(api_key=api_key).chat.completions.create(**request)
        print(f"Total tokens used: {completion.usage.total_tokens}")
        reply = completion.choices[0].message.content
        return reply.strip()
    except Exception as error:
        # Any failure degrades to the no-op action.
        print(f"Error calling OpenAI API: {error}")
        return "noop"


def get_response_claude(history, model_name="claude-3-opus-20240229"):
    """Query an Anthropic Claude model and return its reply text.

    Args:
        history: List of dictionaries with 'role' and 'content' keys.
        model_name: Name of the Claude model to call.

    Returns:
        The stripped reply text, or ``"noop"`` if the API call fails.

    Raises:
        ValueError: If the ``ANTHROPIC_API_KEY`` environment variable is
            unset or empty.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        raise ValueError("ANTHROPIC_API_KEY environment variable is not set.")

    client = anthropic.Anthropic(api_key=api_key)

    # The system prompt is passed separately from the message list; if several
    # system messages are present, the last one wins.
    system_prompt = ""
    claude_messages = []
    for message in history:
        role = message["role"]
        if role == "system":
            system_prompt = message["content"]
        elif role in ("user", "assistant"):
            text_block = {"type": "text", "text": message["content"]}
            claude_messages.append({"role": role, "content": [text_block]})

    try:
        completion = client.messages.create(
            model=model_name,
            system=system_prompt,
            messages=claude_messages,
            temperature=0.0,
            max_tokens=256,
        )
        first_block = completion.content[0]
        return first_block.text.strip()
    except Exception as error:
        print(f"Error calling Claude API: {error}")
        return "noop"


def get_response_gemini(history, model_name="gemini-1.5-pro-001"):
    """Query a Gemini model and return its reply text.

    The whole conversation is flattened into one text prompt: the system
    prompt first, then each message prefixed with ``"User: "`` or
    ``"Assistant: "``. The prompt is printed to stdout before the call.

    Args:
        history: List of messages in the format
            [{"role": "user"/"assistant"/"system", "content": "text"}]
        model_name: The Gemini model to use (default: gemini-1.5-pro-001)

    Returns:
        str: The stripped reply text, or ``"noop"`` if generation fails.

    Raises:
        ValueError: If the ``GEMINI_API_KEY`` environment variable is unset
            or empty.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY environment variable is not set.")

    genai.configure(api_key=api_key)

    # Flatten the chat into a single prompt string.
    system_prompt = ""
    transcript = []
    for message in history:
        if message["role"] == "system":
            system_prompt = message["content"]
            continue
        speaker = "User: " if message["role"] == "user" else "Assistant: "
        transcript.append(speaker + message["content"])

    prompt = system_prompt + "\n" + "\n".join(transcript)
    print("*** PROMPT ***\n", prompt)

    generation_config = {
        "temperature": 0.0,
        # "top_p": 1.0,
        "max_output_tokens": 256,
    }

    # The same threshold value is applied to every harm category.
    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    safety_settings = [
        {"category": category, "threshold": 3} for category in harm_categories
    ]

    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )

    try:
        reply = model.generate_content([prompt])
        return reply.text.strip()
    except Exception as error:
        print(f"Error calling Gemini API: {error}")
        return "noop"


def filter_action(raw_response, action_str):
    """Map a free-form model reply onto an index in ``action_str``.

    Matching order:
      1. exact match of the reply (as given, then trimmed and lower-cased);
      2. keyword heuristics for ``move`` / ``make`` / ``place`` replies;
      3. a purely numeric reply used directly as the action index, provided
         it is within range;
      4. otherwise index 0 (``noop`` in the action list used by this file).

    Args:
        raw_response: Text returned by the model.
        action_str: List of action names; the position is the action id.

    Returns:
        tuple: ``(action_index, action_name)``.
    """
    # Replies from the hosted APIs are not lower-cased by their wrappers, so
    # normalise here to avoid dropping e.g. "Move_Left" to noop.
    text = raw_response.strip().lower()

    for candidate in (raw_response, text):
        if candidate in action_str:
            action = action_str.index(candidate)
            return action, action_str[action]

    # (keywords, action name); the first matching rule wins.
    move_rules = (
        (("left", "west"), "move_left"),
        (("right", "east"), "move_right"),
        (("up", "north"), "move_up"),
        (("down", "south"), "move_down"),
    )
    make_rules = (
        (("wood", "pick"), "make_wood_pickaxe"),
        (("stone", "pick"), "make_stone_pickaxe"),
        (("iron", "pick"), "make_iron_pickaxe"),
        (("wood", "sword"), "make_wood_sword"),
        (("stone", "sword"), "make_stone_sword"),
        (("iron", "sword"), "make_iron_sword"),
    )
    place_rules = (
        (("stone",), "place_stone"),
        (("table",), "place_table"),
        (("furnace",), "place_furnace"),
        (("plant",), "place_plant"),
    )

    action = 0  # Default to noop if no match found
    matched_name = None
    if "move" in text:
        # Any one direction keyword is enough.
        matched_name = next(
            (name for words, name in move_rules if any(w in text for w in words)),
            None,
        )
    elif "make" in text:
        # Both the material and the tool keyword must be present.
        matched_name = next(
            (name for words, name in make_rules if all(w in text for w in words)),
            None,
        )
    elif "place" in text:
        matched_name = next(
            (name for words, name in place_rules if all(w in text for w in words)),
            None,
        )
    elif text.isdigit():
        try:
            numeric = int(text)
        except ValueError:
            # Some characters satisfy isdigit() but are rejected by int().
            numeric = 0
        # Ignore ids outside the action table instead of raising IndexError.
        if 0 <= numeric < len(action_str):
            action = numeric

    if matched_name is not None:
        action = action_str.index(matched_name)

    return action, action_str[action]


def _facing_object_name(player_pos, player_facing, semantic_map, id_to_item):
    """Return the name of the tile/object in front of the player, or "" if unknown."""
    if (
        player_pos is None
        or player_facing is None
        or semantic_map is None
        or id_to_item is None
    ):
        return ""

    grid = np.asarray(semantic_map)
    row = int(player_pos[0] + player_facing[0])
    col = int(player_pos[1] + player_facing[1])
    # Facing outside the map: there is nothing to look up. Without this check
    # a negative index would silently wrap to the opposite edge and an index
    # past the end would raise IndexError.
    if grid.ndim < 2 or not (0 <= row < grid.shape[0] and 0 <= col < grid.shape[1]):
        return ""

    try:
        return id_to_item[int(grid[row, col])]
    except (IndexError, KeyError):
        return ""


def check_action_validity_post_step(
    action,
    info_inventory,
    previous_info_inventory,
    previous_player_pos,
    current_player_pos,
    previous_player_facing,
    current_player_facing,
    previous_info_semantic,
    id_to_item,
):
    """
    Check if the action is valid based on post-step conditions.

    Rules:
      * ``make_*`` / ``do``: valid only if some inventory entry other than
        ``health`` increased. ``do`` is additionally valid when the player
        was facing a cow, zombie or skeleton before the step.
      * ``place_*``: valid only if some inventory entry other than
        health/food/drink/energy decreased.
      * ``move_*``: invalid when neither position nor facing changed.
      * Any other action (or an empty one) is considered valid.

    Args:
        action (str): The action taken
        info_inventory (dict): Current inventory
        previous_info_inventory (dict): Previous inventory
        previous_player_pos (ndarray): Previous player position
        current_player_pos (ndarray): Current player position
        previous_player_facing (ndarray): Facing direction before the step
        current_player_facing (ndarray): Facing direction after the step
        previous_info_semantic (ndarray): 2-D map of semantic ids before the
            step, indexed with ``[pos[0], pos[1]]``
        id_to_item (sequence): Maps a semantic id to its name; may be None,
            in which case the facing-creature exception is skipped

    Returns:
        bool: True if action is valid, False otherwise
    """
    if not action:
        return True

    action_name = action.lower()
    is_valid = True

    # make/do actions must have increased something in the inventory.
    if "make" in action_name or action_name == "do":
        is_valid = any(
            key != "health"
            and key in previous_info_inventory
            and info_inventory[key] > previous_info_inventory[key]
            for key in info_inventory
        )

        # Exception for do action when player facing creatures: hitting a
        # creature does not necessarily change the inventory.
        if not is_valid and action_name == "do":
            facing_object = _facing_object_name(
                previous_player_pos,
                previous_player_facing,
                previous_info_semantic,
                id_to_item,
            )
            if facing_object in ("cow", "zombie", "skeleton"):
                is_valid = True

    # place actions must have consumed an item (status values do not count).
    if "place" in action_name:
        status_keys = ("health", "food", "drink", "energy")
        consumed = any(
            key not in status_keys
            and key in previous_info_inventory
            and info_inventory[key] < previous_info_inventory[key]
            for key in info_inventory
        )
        if not consumed:
            is_valid = False

    # move actions must have changed the position or at least the facing.
    if "move" in action_name:
        if previous_player_pos is not None and current_player_pos is not None:
            if np.array_equal(
                previous_player_pos, current_player_pos
            ) and np.array_equal(previous_player_facing, current_player_facing):
                is_valid = False

    return is_valid


def extract_reasoning_and_action(raw_response):
    """Split a ``reasoning : ... action : ...`` reply into its two parts.

    Args:
        raw_response: Model output expected to contain the markers
            ``"reasoning :"`` followed by ``"action :"``.

    Returns:
        tuple: ``(reasoning, action)``, both stripped. If either marker is
        missing, or the input is not a string, ``("", "noop")`` is returned.
    """
    try:
        before_action, after_action = raw_response.split("action :")[:2]
        reasoning = before_action.split("reasoning :")[1].strip()
        action = after_action.strip()
    except (ValueError, IndexError, AttributeError, TypeError):
        # ValueError/IndexError: a marker is missing. AttributeError/TypeError:
        # the input is not a string. Anything else should propagate.
        reasoning = ""
        action = "noop"
    return reasoning, action


def check_loop(assistant_history_for_loop_detect, action_str):
    """
    Detect a repeating action pattern at the end of the action history.

    A loop is reported when the most recent actions consist of the same
    pattern of length 2 to 6 repeated three times back to back. Shorter
    patterns are tried first.

    Args:
        assistant_history_for_loop_detect: A deque of recent actions
        action_str: List of action names; the position is the action id

    Returns:
        is_in_loop: Boolean indicating if a loop is detected
        escape_loop_action: Index of a randomly chosen move action that is
            not part of the loop (any move action if all four are part of
            it); 0 when no loop is detected
    """
    recent = list(assistant_history_for_loop_detect)
    repeats = 3

    # The shortest detectable loop (length 2) already needs 6 actions.
    if len(recent) < 2 * repeats:
        return False, 0

    move_actions = ["move_left", "move_right", "move_up", "move_down"]

    for period in range(2, 7):
        window_size = period * repeats
        if len(recent) < window_size:
            continue

        window = recent[-window_size:]
        pattern = recent[-period:]

        # Cut the window into consecutive chunks and compare each to the pattern.
        chunks = (window[k * period : (k + 1) * period] for k in range(repeats))
        if not all(chunk == pattern for chunk in chunks):
            continue

        # Prefer a direction the agent has not been using inside the loop.
        unused_moves = [move for move in move_actions if move not in pattern]
        candidates = unused_moves if unused_moves else move_actions
        return True, action_str.index(random.choice(candidates))

    return False, 0


def _query_model(args, history, pretrained_model, finetuned_model, tokenizer):
    """Return the raw reply of the back end selected by ``args.model_name``."""
    model_name = args.model_name

    if "llama" in model_name or "Qwen" in model_name:
        local_model = finetuned_model if args.lora_path else pretrained_model
        return generate_response_llama(local_model, tokenizer, history)

    if "gpt" in model_name:
        provider, query = "OpenAI", get_response_openai
    elif "claude" in model_name:
        provider, query = "Claude", get_response_claude
    elif "gemini" in model_name:
        provider, query = "Gemini", get_response_gemini
    else:
        raise ValueError(f"Unsupported model name: {model_name!r}")

    try:
        return query(history, model_name=model_name)
    except ValueError as e:
        # Raised by the API wrappers when the API key is missing.
        logging.error(f"Error calling {provider} API: {e}")
        logging.info("Falling back to noop action")
        return "noop"


def run_task_episode(
    task,
    episode,
    args,
    pretrained_model,
    tokenizer,
    finetuned_model=None,
    seed=None,
    id_to_item=None,
    gif_dir=None,
):
    """
    Run a single task-episode combination.

    The episode ends when the environment reports ``done``, when the target
    achievement is unlocked, or once more than 100 steps have been taken.

    Args:
        task: The task to run
        episode: Episode number
        args: Command line arguments
        pretrained_model: The pretrained model
        tokenizer: The tokenizer
        finetuned_model: Optional finetuned model (used when ``args.lora_path``
            is set)
        seed: Seed passed to the environment
        id_to_item: Maps semantic ids to names for the action validity check
        gif_dir: Directory in which a GIF of the episode is saved; falsy
            disables saving

    Returns:
        tuple: (task, episode, success, steps)
            - task: str, the task name
            - episode: int, episode number
            - success: bool, whether the task was completed
            - steps: int, number of steps taken

    Raises:
        ValueError: If ``args.model_name`` matches none of the supported
            back ends.
    """
    # Convert task to achievement format (lowercase and replace spaces with underscores)
    task_name = task.lower().replace(" ", "_")

    action_str = [
        "noop",
        "move_left",
        "move_right",
        "move_up",
        "move_down",
        "do",
        "sleep",
        "place_stone",
        "place_table",
        "place_furnace",
        "place_plant",
        "make_wood_pickaxe",
        "make_stone_pickaxe",
        "make_iron_pickaxe",
        "make_wood_sword",
        "make_stone_sword",
        "make_iron_sword",
    ]
    max_steps = 100

    # Create environment for this process
    env = crafter.Env(area=args.area, length=args.length, seed=seed)
    env = crafter.Recorder(env, args.record)
    env = EvalWrapper(env, task_name=task_name)
    env = CrafterLanguageWrapper(env)

    # Initialize pygame for this process
    pygame.init()
    screen = pygame.display.set_mode(args.window)

    size = list(args.size)
    size[0] = size[0] or args.window[0]
    size[1] = size[1] or args.window[1]

    # Initialize state
    obs, reward, done, info = env.reset()
    done = False

    # Initialize FIFO queues for history of observations and actions
    user_history = deque(maxlen=2)
    assistant_history = deque(maxlen=1)
    assistant_history_for_loop_detect = deque(maxlen=20)

    # Snapshot of the state before the next step. Copies are taken so that
    # later in-place changes by the environment cannot alter the snapshot.
    previous_inventory = dict(info["inventory"])
    previous_player_pos = np.array(info["player_pos"])
    previous_player_facing = np.array(info["player_facing"])
    previous_info_semantic = np.array(info["semantic"])
    post_step_is_valid = True
    is_in_loop = False
    escape_loop_action = 0

    step = 0
    achievements = set()
    success = False

    logging.info(f"########## Starting Task: {task}, Episode: {episode} ##########")

    # Print initial inventory status
    logging.info("Initial Inventory Status:")
    for item, amount in info["inventory"].items():
        if amount > 0:  # Only show items that are present in inventory
            logging.info(f"  {item}: {amount}")

    frames = []
    try:
        while True:

            # Rendering
            image = env.render(size)
            if size != args.window:
                image = Image.fromarray(image)
                image = image.resize(args.window, resample=Image.NEAREST)
                image = np.array(image)

            # Add frame to GIF frames list
            frames.append(image.copy())

            if done or step > max_steps:
                break

            surface = pygame.surfarray.make_surface(image.transpose((1, 0, 2)))
            screen.blit(surface, (0, 0))
            pygame.display.flip()

            # update user history
            user_history.append(
                obs["text"]["short_term_context"]
                + "\n\n"
                + obs["text"]["long_term_context"]
            )

            # Generate prompt and get action
            history = create_history_with_instruction(
                obs,
                user_history,
                assistant_history,
                system_prompts.SYSTEM_PROMPTS[args.system_prompt],
                task,
                post_step_is_valid,
            )
            raw_response = _query_model(
                args, history, pretrained_model, finetuned_model, tokenizer
            )

            logging.info(f"step{step} raw response : {raw_response}  |  ")

            # Process action
            action, text_action = filter_action(raw_response, action_str)

            # Execute the pre-selected escape move when the previous step
            # detected a loop. The action name is updated as well so that the
            # validity check below judges the action that was really executed.
            if is_in_loop:
                action = escape_loop_action
                text_action = action_str[action]
                is_in_loop = False
                logging.info("!! Escape loop !!")

            logging.info(
                f"step{step:<4} raw response : {raw_response:<30} | action : {action_str[action]:<20}"
            )

            # Update history
            assistant_history.append(action_str[action])
            assistant_history_for_loop_detect.append(action_str[action])

            # Environment step
            obs, reward, done, info = env.step(action)
            step += 1

            # Check post-step validity against the state before this step
            post_step_is_valid = check_action_validity_post_step(
                text_action,
                info["inventory"],
                previous_inventory,
                previous_player_pos,
                info["player_pos"],
                previous_player_facing,
                info["player_facing"],
                previous_info_semantic,
                id_to_item,
            )
            if not post_step_is_valid:
                logging.info(f"step{step-1} : invalid action")

            # The state after this step is the "previous" state of the next
            # one; without this refresh every check would compare against the
            # state at reset.
            previous_inventory = dict(info["inventory"])
            previous_player_pos = np.array(info["player_pos"])
            previous_player_facing = np.array(info["player_facing"])
            previous_info_semantic = np.array(info["semantic"])

            # Check if loop is detected
            is_in_loop, escape_loop_action = check_loop(
                assistant_history_for_loop_detect, action_str
            )

            # Check achievements
            unlocked = {
                name
                for name, count in info["achievements"].items()
                if count > 0 and name not in achievements
            }
            achievements |= unlocked
            total = len(info["achievements"])

            for name in unlocked:
                logging.info(f"Achievement ({len(achievements)}/{total}): {name}")

                # Check if task is completed
                if task_name in name:
                    logging.info(f"Task Success! Completed task: {task}")
                    done = True
                    success = True
                    break

    finally:
        try:
            if gif_dir and frames:
                # Create GIF from frames
                gif_path = os.path.join(gif_dir, f"episode_{task}_{episode}.gif")
                imageio.mimsave(gif_path, frames, fps=5, loop=True, palettesize=64)
                logging.info(f"GIF saved to {gif_path}")
        finally:
            # Proper cleanup, even if writing the GIF failed
            pygame.display.quit()
            pygame.quit()
            env.close()  # Make sure environment is properly closed

    # Return task, episode, success status and steps
    return task, episode, success, step


def _str_to_bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no"."""
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """Parse the command line, run every task for the requested number of
    episodes sequentially, and log per-task success statistics."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--area", nargs=2, type=int, default=(64, 64))
    parser.add_argument("--length", type=int, default=10000)
    parser.add_argument("--health", type=int, default=9)
    parser.add_argument("--record", type=pathlib.Path, default=None)
    parser.add_argument("--episodes", type=int, default=1)
    parser.add_argument("--size", type=int, nargs=2, default=(0, 0))
    parser.add_argument("--window", type=int, nargs=2, default=(600, 600))
    parser.add_argument("--gif_dir", type=str, default="")
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="eval_each_task",
        choices=list(system_prompts.SYSTEM_PROMPTS.keys()),
        help="System prompt to use for the agent",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",  # meta-llama/Llama-3.1-8B-Instruct, Qwen/Qwen2.5-7B-Instruct
        help="Model name to use for inference",
    )
    parser.add_argument(
        "--lora_path", type=str, default=None, help="Path to LoRA weights"
    )
    parser.add_argument(
        "--log_name",
        type=str,
        default=None,
        help="Log file name. Defaults to current date and time if not provided.",
    )
    # type=bool would turn any non-empty string (including "False") into True,
    # so the value is parsed explicitly. A bare "--random_init" enables it.
    parser.add_argument(
        "--random_init",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Enable random initialization of the environment.",
    )
    parser.add_argument(
        "--task_range",
        type=int,
        nargs=2,
        default=None,
        help="Specify the start and end indices for tasks to run.",
    )

    args = parser.parse_args()

    # Set default log file name if not provided
    if args.log_name is None:
        args.log_name = datetime.datetime.now().strftime("%Y%m%d_%H%M%S.log")

    # Create run timestamp
    args.run_timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # Create directories
    os.makedirs("history_goal", exist_ok=True)
    os.makedirs("eval/logs", exist_ok=True)
    if args.gif_dir:
        # The GIF writer does not create missing directories itself.
        os.makedirs(args.gif_dir, exist_ok=True)

    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler("eval/logs/" + args.log_name),
            logging.StreamHandler(),
        ],
    )

    # Get system prompt
    logging.info(f"Using system prompt: {args.system_prompt}")

    # Set up environment constants
    crafter.constants.items["health"]["max"] = args.health
    crafter.constants.items["health"]["initial"] = args.health

    # Create id_to_item mapping from a throw-away environment. The table has
    # at least 19 slots (the historical size) and grows if a larger id exists.
    dummyenv = crafter.Env()
    id_pairs = list(
        itertools.chain(
            dummyenv._world._mat_ids.items(), dummyenv._sem_view._obj_ids.items()
        )
    )
    table_size = max([19] + [ind + 1 for _, ind in id_pairs])
    id_to_item = [0] * table_size
    for name, ind in id_pairs:
        # Object classes are reduced to their lower-cased class name; other
        # keys are used via str() as they are.
        name = (
            str(name)[str(name).find("objects.") + len("objects.") : -2].lower()
            if "objects." in str(name)
            else str(name)
        )
        id_to_item[ind] = name

    # Load model once for all episodes. Only the model that will actually be
    # queried is loaded: the LoRA model when --lora_path is given, otherwise
    # the base model.
    pretrained_model = None
    finetuned_model = None
    tokenizer = None

    if "llama" in args.model_name or "Qwen" in args.model_name:
        if args.lora_path:
            finetuned_model, tokenizer = load_model(
                model_name=args.model_name, lora_path=args.lora_path
            )
        else:
            pretrained_model, tokenizer = load_model(
                model_name=args.model_name, lora_path=None
            )

    # Define tasks
    tasks = [
        "Collect wood",
        "Place table",
        "Eat cow",
        "Collect sapling",
        "Collect drink",
        "Make wood pickaxe",
        "Make wood sword",
        "Place plant",
        "Defeat zombie",
        "Collect stone",
        "Place stone",
        "Eat plant",
        "Defeat skeleton",
        "Make stone pickaxe",
        "Make stone sword",
        "Wake up",
        "Place furnace",
        "Collect coal",
        "Collect iron",
        "Make iron pickaxe",
        "Make iron sword",
        "Collect diamond",
    ]

    # Apply task range if specified
    if args.task_range:
        start, end = args.task_range
        tasks = tasks[start:end]

    # Initialize results dictionary
    task_results = {task: {"successes": 0, "steps": []} for task in tasks}

    # One fixed seed per episode index, shared by all tasks
    seed_list = [42 + 10 * i for i in range(args.episodes)]

    # Run tasks sequentially instead of in parallel
    for task in tasks:
        logging.info(f"Starting task: {task}")
        for episode in range(args.episodes):
            logging.info(f"Running episode {episode+1}/{args.episodes}")
            result = run_task_episode(
                task,
                episode,
                args,
                pretrained_model,
                tokenizer,
                finetuned_model,
                seed_list[episode],
                id_to_item,
                gif_dir=args.gif_dir,
            )

            # Process results
            _, _, success, steps = result
            task_results[task]["successes"] += int(success)
            task_results[task]["steps"].append(steps)

            # Log episode result
            logging.info(
                f"Task : {task} | Episode {episode+1} {'successful' if success else 'failed'} after {steps} steps"
            )

        # Log task-specific success and failure counts after each task
        success_count = task_results[task]["successes"]
        failure_count = args.episodes - success_count
        logging.info(
            f"######### Task : {task} | Successes: {success_count}, Failures: {failure_count} ##########"
        )

    # Print task-specific statistics
    logging.info("\n########### Statistics ########### ")
    for task, results in task_results.items():
        # Guard against --episodes 0, which would otherwise divide by zero.
        if args.episodes > 0:
            success_rate = (results["successes"] / args.episodes) * 100
        else:
            success_rate = 0.0
        avg_steps = np.mean(results["steps"]) if results["steps"] else 0.0
        logging.info(f"#### Task : {task} Result ####")
        logging.info(f"  Success Count: {results['successes']}/{args.episodes}")
        logging.info(f"  Success Rate: {success_rate:.1f}%")
        logging.info(f"  Average Steps: {avg_steps:.1f}")

# Script entry point: announce start-up on stdout, then run the evaluation.
if __name__ == "__main__":
    print("Starting evaluation...")
    main()
