# -*- coding: utf-8 -*-
"""
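run_robocerebra_eval.py

Evaluate an OpenVLA-OFT policy on RoboCerebra long-horizon tasks in LIBERO
environments, one step-description segment at a time, with optional resume
and dynamic-distractor modes.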
"""
import json
import logging
import os
import re
import sys
from collections import deque
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
import random

from typing import Optional, Union, Dict, List, Tuple, Any, Sequence

import draccus
import numpy as np
import tqdm
from robosuite import load_controller_config
from libero.libero import benchmark
import libero.libero.envs.bddl_utils as BDDLUtils
from libero.libero.envs import *  # noqa: F403

import h5py
import wandb
import imageio
# Append the grandparent of the working directory ("../..") so that the interpreter can find experiments.robot
sys.path.append("../..")
from experiments.robot.libero.libero_utils import (  # noqa: E402
    get_libero_dummy_action,
    get_libero_env,
    get_libero_image,
    get_libero_wrist_image,
    quat2axisangle,
    DATE_TIME,
    # save_rollout_video,
)
# Root folder for rollout videos: <cwd>/rollouts/<DATE_TIME>, created eagerly at import time.
BASE_DIR = Path.cwd().joinpath("rollouts", DATE_TIME)
BASE_DIR.mkdir(exist_ok=True, parents=True)
from experiments.robot.openvla_utils import (  # noqa: E402
    get_action_head,
    get_noisy_action_projector,
    get_processor,
    get_proprio_projector,
    resize_image_for_policy,
)
from experiments.robot.robot_utils import (  # noqa: E402
    get_action,
    get_image_resize_size,
    get_model,
    invert_gripper_action,
    normalize_gripper_action,
    set_seed_everywhere,
)
from prismatic.vla.constants import NUM_ACTIONS_CHUNK  # noqa: E402

# ★★ List of movable objects - only these are allowed to be moved as distractors ★★
# Consumed by run_task(): the per-step "related" object and the pool of
# "unrelated" objects for dynamic distractors are both drawn from this list.
MOVABLE_OBJECT_LIST = [
    "alphabet_soup",
    "bbq_sauce",
    "butter",
    "chocolate_pudding",
    "cookies",
    "cream_cheese",
    "ketchup",
    "macaroni_and_cheese",
    "milk",
    "orange_juice",
    "popcorn",
    "salad_dressing",
    "new_salad_dressing",
    "tomato_sauce",
    "white_bowl",
    "akita_black_bowl",
    "plate",
    "glazed_rim_porcelain_ramekin",
    "red_coffee_mug",
    "porcelain_mug",
    "white_yellow_mug",
    "chefmate_8_frypan",
    "bowl_drainer",
    "moka_pot",
    "window",
    "faucet",
    "black_book",
    "yellow_book",
    "desk_caddy",
    "wine_bottle",
]

# --------------------------------------------------------------------------------------------------
# Logging
# --------------------------------------------------------------------------------------------------

# Root logger: INFO and above, timestamped, emitted through a single StreamHandler (stderr).
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Module-level logger used by log_message() and setup_logging().
logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------------------------------
# Configuration
# --------------------------------------------------------------------------------------------------

@dataclass
class GenerateConfig:
    """Configuration for a RoboCerebra evaluation run (parsed by ``draccus`` in ``eval_libero``).

    Several string defaults (the checkpoint, ``unnorm_key``, ``root_dir``,
    ``task_suite_name``) look like placeholders and are presumably meant to be
    overridden on the command line.
    """
    # fmt: off
    # ------------------------------------------------------------------
    # Model-specific parameters
    # ------------------------------------------------------------------
    # "openvla" additionally enables the processor (initialize_model) and gripper inversion (process_action).
    model_family: str = "openvla"
    # Must be non-empty (validate_config); an "image_aug" path additionally requires center_crop=True.
    pretrained_checkpoint: Union[str, Path] | None = 'your openvla-oft checkpoint'
    # Either this or use_diffusion makes initialize_model build an action head.
    use_l1_regression: bool = True
    # Also makes initialize_model build the noisy-action projector.
    use_diffusion: bool = False
    # NOTE(review): not read directly in this file; presumably consumed by the model/action helpers.
    num_diffusion_steps: int = 50
    # Forwarded to get_action(use_film=...).
    use_film: bool = False
    # NOTE(review): not read directly in this file; presumably consumed by the model helpers.
    num_images_in_input: int = 2
    # Builds the proprio projector (proprio_dim=8) when True.
    use_proprio: bool = True
    # Must be True for checkpoints whose path contains "image_aug" (validate_config).
    center_crop: bool = True
    # Capacity of the open-loop action queue in run_episode; compared against NUM_ACTIONS_CHUNK.
    num_open_loop_steps: int = 8
    # Overwritten by initialize_model with task_suite_name (or "<task_suite_name>_no_noops").
    unnorm_key: Union[str, Path] = "your task suite name"
    # 8-bit and 4-bit quantization are mutually exclusive (validate_config).
    load_in_8bit: bool = False
    load_in_4bit: bool = False
    # Directory whose sub-directories are the individual tasks to evaluate.
    root_dir: str = "dataset path"

    # ------------------------------------------------------------------
    # LIBERO environment-specific parameters
    # ------------------------------------------------------------------
    # Used to resolve the action un-normalization key and in the run id.
    task_suite_name: str = "your task suite name"
    # Dummy-action steps at episode start (only when the first step starts at frame 0).
    num_steps_wait: int = 15
    # Episodes run per task directory.
    num_trials_per_task: int = 5
    # NOTE(review): not read in the visible part of this file.
    initial_states_path: str = "DEFAULT"
    # NOTE(review): not read in the visible part of this file; run_task hard-codes 256x256 cameras.
    env_img_res: int = 256
    # Env steps per step-description segment.
    switch_steps: int = 150
    resume: bool = False  # Reset the simulator to each step's recorded start state at segment boundaries
    # Start the episode from step_states[1] instead of step_states[0].
    dynamic_shift_description: bool = False
    # Prompt with the whole-task "Task:" line instead of per-step descriptions.
    complete_description: bool = False
    # Selects task_description{suffix}.txt/.json instead of the canonical files.
    task_description_suffix: str = ""
    # Move a distractor object after each resume; requires resume=True (validate_config).
    dynamic: bool = False

    # ------------------------------------------------------------------
    # Utils
    # ------------------------------------------------------------------
    # Appended to the run id as "--<note>".
    run_id_note: Optional[str] = None
    # Directory for the local text log.
    local_log_dir: str = "./experiments/logs"
    # When True, setup_logging starts a W&B run with the entity/project below.
    use_wandb: bool = False
    wandb_entity: str = "your-wandb‑entity"
    wandb_project: str = "your-wandb‑project"
    # Passed to set_seed_everywhere.
    seed: int = 7
    # fmt: on

# --------------------------------------------------------------------------------------------------
# Utility Functions
# --------------------------------------------------------------------------------------------------

def load_actions(json_path: str) -> Dict[str, List[List[str]]]:
    """Load a goal JSON file and normalise its relation entries.

    The file maps an object id to a list of relation entries. Entries with two
    items (``[verb, subject]``) or three items (``[verb, subject, object]``) are
    kept with the verb lower-cased; entries of any other length are dropped.

    Args:
        json_path: Path to the goal JSON file (read as UTF-8).

    Returns:
        Dict mapping each object id to its filtered list of relation entries.
    """
    with open(json_path, "r", encoding="utf-8") as handle:
        raw = json.load(handle)

    def _normalise(entry):
        head, *tail = entry
        return [head.lower(), *tail]

    return {
        key: [_normalise(entry) for entry in entries if len(entry) in (2, 3)]
        for key, entries in raw.items()
    }


def validate_config(cfg: GenerateConfig) -> None:
    """Reject inconsistent configurations before any model or env is built.

    Uses explicit ``raise`` instead of ``assert`` so the checks still run when
    Python is started with ``-O``.

    Raises:
        ValueError: if no checkpoint is given, an ``image_aug`` checkpoint is
            used without ``center_crop``, both 8-bit and 4-bit quantization are
            requested, or ``dynamic`` is enabled without ``resume``.
    """
    if not cfg.pretrained_checkpoint:
        raise ValueError("pretrained_checkpoint must not be None!")
    if "image_aug" in str(cfg.pretrained_checkpoint) and not cfg.center_crop:
        raise ValueError("Expecting center_crop=True because model was trained with image augmentations!")
    if cfg.load_in_8bit and cfg.load_in_4bit:
        raise ValueError("Cannot use both 8‑bit and 4‑bit quantization!")
    # run_task only builds dynamic-distractor info when resume is also enabled.
    if cfg.dynamic and not cfg.resume:
        raise ValueError("dynamic=True requires resume=True!")


# --------------------------------------------------------------------------------------------------
# Model Initialization
# --------------------------------------------------------------------------------------------------

def initialize_model(cfg: GenerateConfig):
    """Build the policy and its optional components from ``cfg``.

    Also resolves the action un-normalization key: ``cfg.task_suite_name`` is used
    when present in ``model.norm_stats``, otherwise ``<task_suite_name>_no_noops``
    if that is present. The resolved key is written back to ``cfg.unnorm_key``.

    Returns:
        (model, action_head, proprio_projector, noisy_action_projector, processor);
        each optional component is None when its feature is disabled.
    """
    model = get_model(cfg)

    proprio_projector = None
    if cfg.use_proprio:
        proprio_projector = get_proprio_projector(cfg, model.llm_dim, proprio_dim=8)

    action_head = None
    if cfg.use_l1_regression or cfg.use_diffusion:
        action_head = get_action_head(cfg, model.llm_dim)

    noisy_action_projector = None
    if cfg.use_diffusion:
        noisy_action_projector = get_noisy_action_projector(cfg, model.llm_dim)

    processor = None
    if cfg.model_family == "openvla":
        processor = get_processor(cfg)

    key = cfg.task_suite_name
    fallback_key = f"{key}_no_noops"
    if key not in model.norm_stats and fallback_key in model.norm_stats:
        key = fallback_key
    assert key in model.norm_stats, f"Action un‑norm key {key} not found!"
    cfg.unnorm_key = key
    return model, action_head, proprio_projector, noisy_action_projector, processor


# --------------------------------------------------------------------------------------------------
# Logging utils
# --------------------------------------------------------------------------------------------------

def setup_logging(cfg: GenerateConfig):
    """Create the local text log for this run and optionally start a W&B run.

    The run id is ``EVAL-<task_suite_name>-<model_family>-<DATE_TIME>``, with
    ``--<run_id_note>`` appended when a note is set.

    Returns:
        (log_file, local_log_filepath, run_id). ``log_file`` is an open text
        handle; it is not closed here.
    """
    run_id = f"EVAL-{cfg.task_suite_name}-{cfg.model_family}-{DATE_TIME}"
    if cfg.run_id_note:
        run_id += f"--{cfg.run_id_note}"
    os.makedirs(cfg.local_log_dir, exist_ok=True)
    local_log_filepath = os.path.join(cfg.local_log_dir, run_id + ".txt")
    # Explicit UTF-8: log lines contain non-ASCII text (e.g. "Δy" in the
    # dynamic-distractor messages), which platform default encodings may reject.
    log_file = open(local_log_filepath, "w", encoding="utf-8")
    logger.info(f"Logging to {local_log_filepath}")
    if cfg.use_wandb:
        wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, name=run_id)
    return log_file, local_log_filepath, run_id


def log_message(msg: str, log_file=None):
    """Emit ``msg`` through the module logger and, if given, append it to ``log_file``.

    The file is flushed after each write.
    """
    logger.info(msg)
    if not log_file:
        return
    log_file.write(msg + "\n")
    log_file.flush()

def save_rollout_video(
    rollout_images,
    idx,
    success,
    task_description,
    log_file=None,
    task_name: str = ""
):
    """Write ``rollout_images`` to an MP4 file (30 fps) and return its path.

    The file is placed in ``BASE_DIR/<task_name>`` (or ``BASE_DIR`` when
    ``task_name`` is empty). Its name combines ``DATE_TIME``, ``idx``, the
    success flag (as 0/1) and a sanitised 50-character prefix of
    ``task_description``.

    Args:
        rollout_images: Frames passed one by one to the imageio writer's ``append_data``.
        idx: Episode/segment identifier embedded in the file name.
        success: Truthy if the rollout succeeded.
        task_description: Free text used in the file name.
        log_file: Optional open text file that also receives the save message.
        task_name: Optional sub-directory of ``BASE_DIR``.

    Returns:
        The MP4 path as a string.
    """
    rollout_dir = BASE_DIR / task_name if task_name else BASE_DIR
    os.makedirs(rollout_dir, exist_ok=True)
    # "/" is replaced as well so a description can never turn into a
    # sub-directory path below rollout_dir.
    processed_task_description = (
        task_description.lower()
        .replace(" ", "_")
        .replace("\n", "_")
        .replace(".", "_")
        .replace("/", "_")[:50]
    )
    video_name = f"{DATE_TIME}--episode={idx}--success={int(success)}--task={processed_task_description}.mp4"
    mp4_path = rollout_dir / video_name
    # The context manager closes the writer even if encoding a frame fails.
    with imageio.get_writer(mp4_path, fps=30) as video_writer:
        for img in rollout_images:
            video_writer.append_data(img)
    print(f"Saved rollout MP4 at path {mp4_path}")
    if log_file is not None:
        log_file.write(f"Saved rollout MP4 at path {mp4_path}\n")
        log_file.flush()
    return str(mp4_path)
# --------------------------------------------------------------------------------------------------
# Observation / action helpers
# --------------------------------------------------------------------------------------------------

def prepare_observation(obs, resize_size):
    """Convert a raw env observation into the policy input dict.

    Returns:
        (observation, img): ``observation`` holds the resized ``full_image`` and
        ``wrist_image`` plus a ``state`` vector built from the end-effector
        position, the axis-angle form of the end-effector quaternion and the
        gripper qpos. ``img`` is the image from ``get_libero_image`` before
        resizing; callers use it for the replay videos.
    """
    agent_view = get_libero_image(obs)
    wrist_view = get_libero_wrist_image(obs)
    full_image = resize_image_for_policy(agent_view, resize_size)
    wrist_image = resize_image_for_policy(wrist_view, resize_size)
    proprio_state = np.concatenate(
        (obs["robot0_eef_pos"], quat2axisangle(obs["robot0_eef_quat"]), obs["robot0_gripper_qpos"])
    )
    observation = dict(full_image=full_image, wrist_image=wrist_image, state=proprio_state)
    return observation, agent_view


def process_action(action, model_family):
    """Post-process a raw policy action before it is sent to the env.

    The action always goes through ``normalize_gripper_action(binarize=True)``.
    For the ``"openvla"`` family it is additionally passed through
    ``invert_gripper_action``.
    """
    normalized = normalize_gripper_action(action, binarize=True)
    return invert_gripper_action(normalized) if model_family == "openvla" else normalized

# ★★ Find the address (addr) of an object's y-axis in qpos ★★
def _find_obj_y_addr(sim, obj_name: str) -> Optional[int]:
    """
    Infers the index of the y-coordinate of the object's position in qpos based on the joint naming convention.
    Returns None if not found (the object will be ignored).
    """
    # Three common naming conventions: <name>_1_joint0 / <name>_joint0 / <name>_joint
    patterns = [f"{obj_name}_1_joint0", f"{obj_name}_joint0", f"{obj_name}_joint"]
    for jn in patterns:
        if jn in sim.model.joint_names:
            qpos_addr = sim.model.get_joint_qpos_addr(jn)[0]  # x
            return qpos_addr + 1  # x,y,z -> take y
    return None


# ★★ Parse task_description(.suffix).json ★★
def _load_step_objects(json_path: str, step_desc: Sequence[str]) -> List[str]:
    """
    Args:
        json_path: Path to task_description{suffix}.json
        step_desc: List obtained from parse_task_description() (with the "Step:" prefix removed)
    Returns:
        A list of object names corresponding to step_desc (if not found, an empty string is used as a placeholder)
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)      # list[dict{step, object}]
    # Create a mapping from 'Step: xxx' to object
    mapping = {item["step"] : item["object"] for item in data if "step" in item}
    objs = []
    for desc in step_desc:
        key = f"Step: {desc}"
        objs.append(mapping.get(key, ""))   # Use an empty string as placeholder if not found
    return objs


# --------------------------------------------------------------------------------------------------
# Episode loop (supports resume)
# --------------------------------------------------------------------------------------------------

def run_episode(
    cfg: GenerateConfig,
    env,
    naming_step_desc: Sequence[str],   # For video naming
    model_step_desc: Sequence[str],    # For model input
    step_states: Sequence[np.ndarray] | None,
    model,
    goal: Any,
    resize_size,
    processor=None,
    action_head=None,
    proprio_projector=None,
    noisy_action_projector=None,
    log_file=None,
    episode_idx: int = 0,
    distractor_info: Optional[Dict[str, Any]] = None,
    task_line: str | None = None,
    task_name: str = "",
    wait_flag=True
) -> Tuple[bool, int, int]:
    """Run a single multi-segment evaluation episode.

    The episode lasts ``cfg.switch_steps * segment_count`` env steps. Segment
    ``i`` prompts the policy with step description ``i``, or with the whole-task
    line when ``cfg.complete_description`` is set.

    At every segment boundary the finished segment is saved as an MP4 plus a
    JSON summary of the objects whose completion rate rose. When resume is
    active, the simulator is then reset to the recorded start state of the new
    step.

    With ``cfg.dynamic`` and ``distractor_info``, an object's y qpos is shifted
    by ±0.05 ten steps after each resume.

    Args:
        cfg: Evaluation config.
        env: Env built by ``run_task``.
        naming_step_desc: Step descriptions used for video names (the suffixed
            file when ``cfg.task_description_suffix`` is set).
        model_step_desc: Canonical step descriptions used as policy prompts.
        step_states: Flattened simulator states, one per step start, or None.
        model, processor, action_head, proprio_projector, noisy_action_projector:
            Policy components forwarded to ``get_action``.
        goal: Goal spec passed to ``env._check_success``; the goal count is
            ``sum(len(v) for v in goal.values())``.
        resize_size: Target size for policy images.
        log_file: Optional open log file.
        episode_idx: Index used in output file names.
        distractor_info: ``{"step_addr", "step_base", "unrel"}`` from ``run_task``,
            or None to disable dynamic distractors.
        task_line: Text after ``"Task:"`` in the canonical description.
        task_name: Sub-directory name for saved videos.
        wait_flag: If True, the first ``cfg.num_steps_wait`` steps use dummy
            actions and the completion baseline is taken after them.

    Returns:
        (all_done, total_agent_subtasks, total_goals). ``all_done`` is forced to
        False when resume is on and a completion was credited to a resume frame.
    """
    # ---- 0. Number of segments: the suffixed description drives segmentation when a suffix is set ----
    if cfg.task_description_suffix:
        segment_count = len(naming_step_desc)
    else:
        segment_count = len(model_step_desc)

    # ---- 0b. Whole-task prompt (only used with cfg.complete_description) ----
    if cfg.complete_description:
        # Only the "Task:" line (prefix removed) is used; empty string if it is missing.
        full_description = task_line or ""
    else:
        full_description = None
    obs = env.reset()

    # Initialised unconditionally: a resume records its timestep here even
    # when dynamic distractors are disabled.
    resume_trigger_step = None
    if cfg.dynamic and distractor_info:
        step_addr_y        = distractor_info["step_addr"]
        step_base_y        = distractor_info["step_base"]
        unrelated_set      = distractor_info["unrel"]
        # NOTE(review): default_rng() is unseeded, so the related/unrelated
        # choice is not reproducible across runs even with cfg.seed.
        rng                = np.random.default_rng()
        toggle_dir         = -1
        seg_mid_moved      = False

    # ---- 1. Initial state for dynamic shift / resume ----
    init_state_idx = 0
    skip_increment = False
    pending_catch_up = False

    if cfg.dynamic_shift_description:
        # Start from step 1's recorded state. skip_increment makes completions
        # detected right after the state jump count as "resume-skipped".
        init_state_idx = 1
        skip_increment = True
        if cfg.resume:
            pending_catch_up = True

    if (cfg.resume or cfg.dynamic_shift_description) and step_states:
        print(init_state_idx)
        env.sim.set_state_from_flattened(step_states[init_state_idx])
        env.skip_pick_quat_once = True
        env.sim.forward(); env._post_process(); env._update_observables(force=True)
        obs = env._get_observations()

    if cfg.num_open_loop_steps != NUM_ACTIONS_CHUNK:
        print(f"WARNING: cfg.num_open_loop_steps ({cfg.num_open_loop_steps}) does not match the NUM_ACTIONS_CHUNK "
              f"({NUM_ACTIONS_CHUNK}) constant defined in prismatic.vla.constants! For best performance (in terms of "
              "both speed and success rate), we recommend executing the full action chunk.")

    # ---- 2. Statistics ----
    total_agent_subtasks = 0
    total_goals = sum(len(v) for v in goal.values()) if goal else 0
    # NOTE(review): total_resume_skipped and just_resumed are tracked but never read in this function.
    total_resume_skipped = 0

    seg_increment_accum = 0
    just_resumed = cfg.dynamic_shift_description
    resume_success_flag = False

    # Open-loop action queue; refilled with a fresh chunk only when empty.
    action_queue: deque[np.ndarray] = deque(maxlen=cfg.num_open_loop_steps)
    replay_images_all: List[np.ndarray] = []
    replay_images_seg: List[np.ndarray] = []

    t = 0
    max_steps = cfg.switch_steps * segment_count

    prev_step_idx = 0

    # Completion baseline. With wait_flag it is taken after the waiting period instead.
    if not wait_flag:
        comp_start_dict, total_completed_prev, _ = env._check_success(goal)
    # ---------------------- Main control loop ----------------------
    while t < max_steps:
        # (a) Initial waiting period: step the env with dummy actions.
        if t < cfg.num_steps_wait and wait_flag:
            obs, _, _, _ = env.step(get_libero_dummy_action(cfg.model_family))
            t += 1
            continue
        if t == cfg.num_steps_wait and wait_flag:
            comp_start_dict, total_completed_prev, _ = env._check_success(goal)

        # (b) Segment index (t < max_steps keeps it below segment_count)
        step_idx = (t // cfg.switch_steps) % segment_count

        # A new segment may move a distractor again.
        if t % cfg.switch_steps == 0:
            seg_mid_moved = False

        # (c) Segment switch: save the finished segment, then optionally resume.
        if t > 0 and step_idx != prev_step_idx:
            seg_success = seg_increment_accum
            log_message(f"[Segment {prev_step_idx}] Subtasks completed = {seg_success}", log_file)
            suffix = "_complete_description" if cfg.complete_description else ""
            name = f"{episode_idx}_step{prev_step_idx}{suffix}"
            mp4_path = save_rollout_video(
                replay_images_seg,
                name,
                success=seg_success > 0,
                task_description=naming_step_desc[prev_step_idx],
                log_file=log_file,
                task_name=task_name,
            )
            # Objects whose completion rate rose during this segment.
            comp_end_dict, _, _ = env._check_success(goal)
            completed_objects = [
                obj for obj, rate in comp_end_dict.items()
                if rate > comp_start_dict.get(obj, 0)
            ]

            json_path = mp4_path.rsplit('.', 1)[0] + ".json"
            with open(json_path, "w", encoding="utf-8") as jf:
                json.dump({
                    "video": os.path.basename(mp4_path),
                    "completed_subtasks": completed_objects,
                    "success": len(completed_objects) > 0
                }, jf, ensure_ascii=False, indent=2)

            # The end of this segment is the baseline for the next one.
            comp_start_dict = comp_end_dict
            replay_images_seg.clear()
            seg_increment_accum = 0
            just_resumed = False

            # Dynamic shift / resume logic
            do_resume_now = False
            if pending_catch_up:
                # dynamic_shift + resume: the episode already started from step 1's
                # state, so entering segment 1 does not reset; later segments do.
                if step_idx == 1:
                    skip_increment = False
                else:
                    do_resume_now = True
                    pending_catch_up = False
            else:
                do_resume_now = cfg.resume and step_states is not None

            if do_resume_now and step_states is not None:
                env.sim.set_state_from_flattened(step_states[step_idx])
                env.sim.forward(); env._post_process(); env._update_observables(force=True)
                obs = env._get_observations()
                env.skip_pick_quat_once = True
                skip_increment = True
                just_resumed = True
                resume_trigger_step = t
                log_message(f"[Resume] Reverting to the initial state of Step {step_idx}", log_file)
            # NOTE(review): action_queue is not cleared here, so actions planned
            # before the switch (old observation/prompt) are still executed after it.

        prev_step_idx = step_idx

        # Dynamic distractor: 10 steps after a resume, shift one object's y qpos.
        if (
            cfg.dynamic
            and distractor_info
            and not seg_mid_moved
            and resume_trigger_step is not None
            and t == resume_trigger_step + 10
        ):
            # 50/50 between the current step's related object (if it has one) and a random unrelated one.
            use_related = rng.random() < 0.5
            if use_related and step_idx < len(step_addr_y) and step_addr_y[step_idx] is not None:
                addr_y = step_addr_y[step_idx]
                base_y = step_base_y[step_idx]
                target_name = "related"
            else:
                # Python's random.choice keeps the (addr, base) tuple intact; NumPy's
                # choice would coerce the tuples to a float64 array.
                addr_y, base_y = random.choice(unrelated_set)
                target_name = "unrelated"

            # Offset from the stored base y; the sign alternates on each move.
            offset = 0.05 * toggle_dir
            env.sim.data.qpos[addr_y] = base_y + offset
            env.sim.forward(); env._post_process(); env._update_observables(force=True)
            obs = env._get_observations()
            log_message(f"[Dynamic] Moved {target_name} object at step {step_idx}, Δy={offset:+.2f}", log_file)
            seg_mid_moved = True
            toggle_dir *= -1
            resume_trigger_step = None

        # (d) Policy inference & env.step
        observation, img = prepare_observation(obs, resize_size)
        replay_images_all.append(img)
        replay_images_seg.append(img)

        if not action_queue:
            # Prompt: the suffixed step text when a suffix is set (and not complete_description);
            # otherwise the whole-task line or the canonical step description.
            if cfg.task_description_suffix != "" and not cfg.complete_description:
                desc = naming_step_desc[step_idx]
            else:
                desc = full_description if cfg.complete_description else model_step_desc[step_idx]

            actions = get_action(
                cfg,
                model,
                observation,
                desc,
                processor,
                action_head,
                proprio_projector,
                noisy_action_projector,
                use_film=cfg.use_film,
            )
            action_queue.extend(actions)
        raw_action = action_queue.popleft()
        obs, _, _, _ = env.step(process_action(raw_action, cfg.model_family).tolist())
        t += 1

        # (e) Credit newly completed subtasks unless they appeared right after a state jump.
        _, total_completed_now, _ = env._check_success(goal)
        diff = total_completed_now - total_completed_prev

        if diff > 0:
            step_no = step_idx + 1
            if skip_increment:
                total_resume_skipped += diff
                resume_success_flag = True
                log_message(f"[Step {step_no}] (Skip) Directly completed {diff} subtasks by resuming at the first frame", log_file)
            else:
                total_agent_subtasks += diff
                seg_increment_accum += diff
                log_message(f"[Step {step_no}] Completed {diff} new subtasks; current segment accumulated {seg_increment_accum}", log_file)

        total_completed_prev = total_completed_now
        if skip_increment:
            skip_increment = False

    # --------------------- Episode termination handling ---------------------
    # The final segment keeps the subtasks it actually completed (not zeroed because of just_resumed).
    seg_success = seg_increment_accum
    seg_no = prev_step_idx + 1
    log_message(f"[Segment {seg_no}] Subtasks counted = {seg_success} (final segment)", log_file)
    suffix = "_complete_description" if cfg.complete_description else ""
    name = f"{episode_idx}_step{prev_step_idx}{suffix}"
    save_rollout_video(
        replay_images_seg,
        name,
        success=seg_success > 0,
        task_description=naming_step_desc[prev_step_idx],
        log_file=log_file,
        # Same task sub-directory as the other segment videos.
        task_name=task_name,
    )

    # ---- Overall success: a completion credited to a resume frame voids the episode ----
    if cfg.resume and resume_success_flag:
        all_done = False
    else:
        _, _, all_done = env._check_success(goal)

    suffix = "_complete_description" if cfg.complete_description else ""
    name = f"{episode_idx}{suffix}"
    mp4_path = save_rollout_video(
        replay_images_all,
        name,
        success=all_done,
        task_description=naming_step_desc[prev_step_idx],
        log_file=log_file,
        task_name=task_name,
    )
    # JSON sidecar for the full-episode video.
    json_path = mp4_path.rsplit('.', 1)[0] + ".json"
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump({
            "video": os.path.basename(mp4_path),
            "episode_success": all_done
        }, jf, ensure_ascii=False, indent=2)
    effective_total_goals = total_goals
    return all_done, total_agent_subtasks, effective_total_goals


# --------------------------------------------------------------------------------------------------
# Task-level loop
# --------------------------------------------------------------------------------------------------

def parse_task_description(txt_path: str) -> Tuple[List[str], List[int]]:
    """Parse a ``task_description*.txt`` file.

    Blank lines are ignored. Every line starting with ``"Step"`` contributes the
    text after its first colon as a step description. If the next non-blank
    line is a ``[start, end]`` bracket, ``start`` is recorded and that bracket
    line is consumed. All other lines are skipped.

    Returns:
        (step_descriptions, start_indices). The lists can differ in length when
        some step lines are not followed by a bracket line.
    """
    bracket_re = re.compile(r"\[\s*(\d+)\s*,\s*(\d+)\s*\]")
    with open(txt_path, "r", encoding="utf-8") as fh:
        stripped = [raw.strip() for raw in fh]
    lines = [text for text in stripped if text]

    descriptions: List[str] = []
    starts: List[int] = []
    cursor = 0
    total = len(lines)
    while cursor < total:
        current = lines[cursor]
        cursor += 1
        if not current.startswith("Step"):
            continue
        descriptions.append(current.split(":", 1)[1].strip())
        if cursor < total:
            match = bracket_re.match(lines[cursor])
            if match is not None:
                starts.append(int(match.group(1)))
                cursor += 1
    return descriptions, starts


def _make_task_env(bddl_file_path: str):
    """Instantiate the env class registered in TASK_MAPPING for the BDDL problem (OSC_POSE, 256x256 cameras)."""
    problem_info = BDDLUtils.get_problem_info(bddl_file_path)
    problem_name = problem_info["problem_name"]
    controller_config = load_controller_config(default_controller="OSC_POSE")
    return TASK_MAPPING[problem_name](
        bddl_file_name=bddl_file_path,
        robots=["Panda"],
        controller_configs=controller_config,
        has_renderer=False,
        has_offscreen_renderer=True,
        camera_names=["agentview", "robot0_eye_in_hand"],
        ignore_done=True,
        use_camera_obs=True,
        reward_shaping=True,
        camera_heights=256,
        camera_widths=256,
        control_freq=20,
    )


def _read_task_line(canon_txt: str) -> str:
    """Return the text after the first line starting with "Task:" in ``canon_txt``, or "" if absent."""
    with open(canon_txt, "r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if ln.startswith("Task:"):
                return ln.split(":", 1)[1].strip()
    return ""


def _build_distractor_info(
    cfg: GenerateConfig,
    env,
    dir_path: str,
    naming_step_desc: Sequence[str],
    log_file=None,
) -> Optional[Dict[str, Any]]:
    """Collect qpos y-addresses and base values for dynamic distractors.

    Per-step "related" objects come from ``task_description{suffix}.json``. The
    "unrelated" pool is every other object in MOVABLE_OBJECT_LIST that has a
    matching joint.

    Returns:
        ``{"step_addr", "step_base", "unrel"}``, or None (after logging a warning)
        when the JSON is missing or there is no related or no unrelated object.
    """
    if cfg.task_description_suffix:
        json_name = f"task_description{cfg.task_description_suffix}.json"
    else:
        json_name = "task_description.json"
    json_path = os.path.join(dir_path, json_name)
    if not os.path.isfile(json_path):
        log_message(f"[WARN] {json_path} not found, dynamic functionality disabled.", log_file)
        return None

    step_objects = _load_step_objects(json_path, naming_step_desc)

    # Per step: y address / base y of the step's movable object, or None.
    step_addr_y: List[Optional[int]] = []
    step_base_y: List[Optional[float]] = []
    for obj_name in step_objects:
        if not obj_name or obj_name not in MOVABLE_OBJECT_LIST:
            step_addr_y.append(None)
            step_base_y.append(None)
            continue
        addr = _find_obj_y_addr(env.sim, obj_name)
        if addr is None:
            log_message(f"[WARN] Could not find joint for {obj_name}, ignoring the related object.", log_file)
            step_addr_y.append(None)
            step_base_y.append(None)
        else:
            step_addr_y.append(addr)
            step_base_y.append(env.sim.data.qpos[addr].copy())

    # Pool of movable objects not named by any step.
    unrelated_addr = []
    for name in MOVABLE_OBJECT_LIST:
        if name in step_objects:
            continue
        addr = _find_obj_y_addr(env.sim, name)
        if addr is not None:
            unrelated_addr.append((addr, env.sim.data.qpos[addr].copy()))

    if any(a is not None for a in step_addr_y) and unrelated_addr:
        return {
            "step_addr": step_addr_y,
            "step_base": step_base_y,
            "unrel": unrelated_addr,
        }
    log_message("[WARN] Insufficient dynamic information (no related or unrelated objects), feature disabled.", log_file)
    return None


def run_task(
    cfg: GenerateConfig,
    dir_path: str,
    bddl_file_path: str,
    goal: Any,
    model,
    resize_size,
    processor=None,
    action_head=None,
    proprio_projector=None,
    noisy_action_projector=None,
    log_file=None,
) -> Tuple[int, int, int, int]:
    """Evaluate one task directory.

    The directory must contain ``demo.hdf5``; its ``data/demo_1/states`` supply
    the per-step start states. It must also contain ``task_description.txt``,
    plus ``task_description{suffix}.txt`` when a suffix is configured.
    ``task_description*.json`` is read for dynamic distractors.

    Returns:
        (episodes, successful_episodes, agent_subtasks_completed, possible_subtasks).
        On every path, including the early "treat all episodes as failures"
        exits, possible subtasks are ``episodes * number_of_goals``.

    Raises:
        FileNotFoundError: if a required task description file is missing.
    """
    total_goals = sum(len(v) for v in goal.values()) if goal else 0
    env = _make_task_env(bddl_file_path)
    try:
        # Recorded simulator states of the first demo.
        h5_path = os.path.join(dir_path, "demo.hdf5")
        with h5py.File(h5_path, "r") as h5f:
            orig_states = h5f["data"]["demo_1"]["states"][()]

        # 1) Step descriptions for video naming (the suffixed file if configured).
        if cfg.task_description_suffix:
            suff_txt = os.path.join(dir_path, f"task_description{cfg.task_description_suffix}.txt")
            if not os.path.isfile(suff_txt):
                raise FileNotFoundError(f"Task description file not found: {suff_txt}")
            naming_step_desc, suff_start_idx = parse_task_description(suff_txt)
        else:
            canon_txt = os.path.join(dir_path, "task_description.txt")
            if not os.path.isfile(canon_txt):
                raise FileNotFoundError(f"Task description file not found: {canon_txt}")
            naming_step_desc, suff_start_idx = parse_task_description(canon_txt)

        # Dynamic distractors are only set up together with resume.
        distractor_info = None
        if cfg.dynamic and cfg.resume:
            distractor_info = _build_distractor_info(cfg, env, dir_path, naming_step_desc, log_file)

        # 2) Canonical descriptions (policy prompts).
        canon_txt = os.path.join(dir_path, "task_description.txt")
        if not os.path.isfile(canon_txt):
            raise FileNotFoundError(f"Task description file not found: {canon_txt}")
        model_step_desc, canon_start_idx = parse_task_description(canon_txt)

        episodes = cfg.num_trials_per_task
        # Failure exits use the same goal-based denominator as a normal run.
        failed_result = (episodes, 0, 0, episodes * total_goals)

        if cfg.task_description_suffix and len(naming_step_desc) == 0:
            log_message(f"[WARN] The suffixed task description file {suff_txt} is empty; all trials will be considered as failures", log_file)
            return failed_result

        task_line = _read_task_line(canon_txt)

        # 3) Suffixed start indices only with resume + suffix; canonical otherwise.
        if cfg.resume and cfg.task_description_suffix:
            start_indices = suff_start_idx
        else:
            start_indices = canon_start_idx

        # 4) Truncate both lists to a common length on mismatch.
        # NOTE(review): without a suffix, run_episode segments by len(model_step_desc),
        # which is not truncated here and may exceed len(step_states).
        if len(naming_step_desc) != len(start_indices):
            min_len = min(len(naming_step_desc), len(start_indices))
            log_message(
                f"[WARN] Number of video naming descriptions ({len(naming_step_desc)}) and indices ({len(start_indices)}) are inconsistent, truncating to {min_len}",
                log_file
            )
            naming_step_desc = naming_step_desc[:min_len]
            start_indices = start_indices[:min_len]

        if not start_indices:
            log_message(
                f"[WARN] No step start indices available, treating all episodes as failures ({os.path.basename(dir_path)})",
                log_file
            )
            return failed_result

        if cfg.dynamic_shift_description and len(start_indices) == 1:
            log_message(
                f"[WARN] dynamic_shift_description=True and only one start index available, treating all episodes as failures ({os.path.basename(dir_path)})",
                log_file
            )
            return failed_result

        # The dummy-action wait only applies when the first step starts at frame 0.
        wait_flag = start_indices[0] == 0
        # 5) Simulator state at the start of each step.
        step_states = [orig_states[idx] for idx in start_indices]

        successes = 0
        task_agent_subtasks = 0
        task_possible_subtasks = 0
        task_name = os.path.basename(dir_path)
        for ep_idx in tqdm.tqdm(range(episodes)):
            succ, ep_subtasks, ep_goals = run_episode(
                cfg,
                env,
                naming_step_desc,
                model_step_desc,
                step_states,
                model,
                goal,
                resize_size,
                processor,
                action_head,
                proprio_projector,
                noisy_action_projector,
                log_file,
                episode_idx=ep_idx,
                distractor_info=distractor_info,
                task_line=task_line,
                task_name=task_name,
                wait_flag=wait_flag
            )
            successes += int(succ)
            task_agent_subtasks += ep_subtasks
            task_possible_subtasks += ep_goals

        return episodes, successes, task_agent_subtasks, task_possible_subtasks
    finally:
        # One env is built per task. Release its simulator/offscreen renderer on
        # every exit path; a failing close must not abort the whole evaluation.
        try:
            env.close()
        except Exception as exc:
            log_message(f"[WARN] env.close() failed: {exc}", log_file)



# --------------------------------------------------------------------------------------------------
# Main entry point
# --------------------------------------------------------------------------------------------------

def _safe_rate(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is zero."""
    return numerator / denominator if denominator else 0.0


def _find_bddl_file(dir_path: str) -> Optional[str]:
    """Return the path of the first ``.bddl`` file in ``dir_path`` (alphabetical order), or None.

    ``os.listdir`` returns entries in arbitrary order, so the listing is sorted to make the
    choice deterministic when a task directory happens to contain more than one ``.bddl`` file.
    """
    for fn in sorted(os.listdir(dir_path)):
        if fn.endswith(".bddl"):
            return os.path.join(dir_path, fn)
    return None


@draccus.wrap()  # Keep only one argument
def eval_libero(cfg: GenerateConfig) -> float:
    """Evaluate the policy on every task directory under ``cfg.root_dir``.

    Each immediate subdirectory of ``cfg.root_dir`` (visited in sorted order) is treated as one
    task. It must contain a ``.bddl`` file; otherwise it is skipped with a warning. It may contain
    an optional ``goal.json``, which is loaded with ``load_actions``. Each task is executed via
    ``run_task``, which returns ``(episodes, successes, agent_subtasks, possible_subtasks)``.
    These are accumulated, and cumulative rates are logged after every task.

    Args:
        cfg: Evaluation configuration (populated by ``draccus.wrap`` when run as a script).

    Returns:
        The overall episode (task) success rate, ``total_successes / total_episodes``,
        or 0.0 if no episodes were run.
    """
    validate_config(cfg)
    root_dir = cfg.root_dir
    set_seed_everywhere(cfg.seed)
    model, action_head, proprio_projector, noisy_action_projector, processor = initialize_model(cfg)
    resize_size = get_image_resize_size(cfg)
    log_file, _, _ = setup_logging(cfg)

    try:
        log_message(f"Starting evaluation: {root_dir}", log_file)

        total_eps = 0
        total_success = 0
        total_agent_subtasks = 0
        total_possible_subtasks = 0

        for subdir in sorted(os.listdir(root_dir)):
            dir_path = os.path.join(root_dir, subdir)
            if not os.path.isdir(dir_path):
                continue
            bddl_file = _find_bddl_file(dir_path)
            if not bddl_file:
                log_message(f"[WARN] .bddl file not found in {dir_path}, skipping.", log_file)
                continue
            goal_json = os.path.join(dir_path, "goal.json")
            goal = load_actions(goal_json) if os.path.isfile(goal_json) else None

            eps, succ, subtasks, possible = run_task(
                cfg,
                dir_path,
                bddl_file,
                goal,
                model,
                resize_size,
                processor,
                action_head,
                proprio_projector,
                noisy_action_projector,
                log_file,
            )
            # Update cumulative metrics
            total_eps += eps
            total_success += succ
            total_agent_subtasks += subtasks
            total_possible_subtasks += possible

            # After each task, report the current cumulative results
            subtask_success_rate = _safe_rate(total_agent_subtasks, total_possible_subtasks)
            episode_success_rate = _safe_rate(total_success, total_eps)
            log_message(
                f"[Task {subdir}] Cumulative subtask success rate: {subtask_success_rate:.2%} "
                f"({total_agent_subtasks}/{total_possible_subtasks}), "
                f"Cumulative overall task success rate: {episode_success_rate:.2%} "
                f"({total_success}/{total_eps})",
                log_file
            )

        # Final report
        subtask_success_rate = _safe_rate(total_agent_subtasks, total_possible_subtasks)
        episode_success_rate = _safe_rate(total_success, total_eps)
        log_message(f"Final subtask success rate: {subtask_success_rate:.2%}", log_file)
        log_message(f"Final overall task success rate: {episode_success_rate:.2%}", log_file)

        # For compatibility with the original interface, still return the overall task success rate
        return episode_success_rate
    finally:
        # NOTE(review): assumes setup_logging returns an open file handle (or None) as its first
        # element — confirm. Closing it flushes buffered log lines even if evaluation raises.
        if log_file is not None and hasattr(log_file, "close"):
            log_file.close()



if __name__ == "__main__":
    # Script entry point. eval_libero is wrapped by draccus.wrap(), which presumably builds
    # its config argument from the command line, so it is called with no arguments here.
    eval_libero()