#!/usr/bin/env python3
"""
Experiment 1b: FE Calibration-Only Evaluation for LaNE Tasks
=============================================================
Use n demo trajectories to calibrate FE without any finetuning for LaNE robosuite tasks.
Tests zero-shot adaptation capability of Function Encoder on robosuite environments.

Based on:
- exp1_fe_calibration_only.py: New FE calibration approach using CalibrationManager
- eval_lane_old.py: LaNE task evaluation logic for robosuite environments
"""

import argparse
import glob
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import imageio
import numpy as np
import robosuite as suite
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from peft import LoraConfig, get_peft_model
from robosuite.controllers import load_controller_config
from robosuite.utils.placement_samplers import UniformRandomSampler
from tqdm import tqdm

# Add project paths
sys.path.append(".")
sys.path.append("experiments/robot")

from transformers import AutoConfig, AutoImageProcessor, AutoModelForVision2Seq, AutoProcessor

from experiments.robot.openvla_utils import (
    check_model_logic_mismatch,
    model_is_on_hf_hub,
    update_auto_map,
)
from experiments.robot.robot_utils import (
    invert_gripper_action,
    normalize_gripper_action,
    set_seed_everywhere,
)
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
from prismatic.models.action_heads import FunctionEncoderActionHead
from prismatic.models.backbones.llm.prompting import PurePromptBuilder
from prismatic.training.train_utils import get_current_action_mask, get_next_actions_mask
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.calibration_buffer import CalibrationManager
from prismatic.vla.constants import ACTION_DIM, NUM_ACTIONS_CHUNK
from prismatic.vla.datasets.datasets import RLDSBatchTransform
from prismatic.vla.materialize import get_vla_dataset_and_collator

# When True, the first six action dims are scaled to [-1, 1] with the demo q01/q99 stats
# before calibration, and FE predictions are mapped back to the original scale at rollout time.
NORMALIZE_ACTION = True


class LaNEFECalibrationEvaluator:
    """Evaluator for FE calibration-only experiments on LaNE tasks."""

    def __init__(
        self,
        vla_path: str = "openvla/openvla-7b",
        fe_checkpoint_path: Optional[str] = None,
        fe_basis_functions: int = 32,
        n_continuous_actions: int = 6,
        device: str = "cuda",
        seed: int = 42,
        use_lora: bool = True,
        lora_rank: int = 32,
    ):
        """Seed the run, record the configuration, and load the VLA with its FE action head.

        Args:
            vla_path: HF Hub repo id or local path of the base VLA model.
            fe_checkpoint_path: Optional checkpoint directory passed on to `_load_model`.
            fe_basis_functions: `k` passed to `FunctionEncoderActionHead`.
            n_continuous_actions: `n_continuous` passed to `FunctionEncoderActionHead`.
            device: Device the models are moved to.
            seed: Seed handed to `set_seed_everywhere`.
            use_lora: Whether the VLA is wrapped with LoRA adapters before loading weights.
            lora_rank: LoRA rank used when `use_lora` is set.
        """
        set_seed_everywhere(seed)

        # Filled in by `load_lane_demo_dataset`; required by the (un)normalization helpers.
        self.demo_action_stats: Optional[Dict[str, Dict]] = None
        self.seed = seed
        self.device = device
        self.n_continuous_actions = n_continuous_actions
        self.fe_basis_functions = fe_basis_functions

        # `_load_model` reads `self.device` and the FE settings stored above.
        print(f"Loading VLA model from: {vla_path}")
        loaded = self._load_model(vla_path, fe_checkpoint_path, use_lora, lora_rank)
        self.vla, self.processor, self.fe_action_head = loaded

    def _load_model(
        self, vla_path: str, fe_checkpoint_path: Optional[str], use_lora: bool, lora_rank: int
    ) -> Tuple[nn.Module, object, nn.Module]:
        """Load the base VLA, build the FE action head, and optionally restore a checkpoint.

        Args:
            vla_path: HF Hub repo id (downloaded via `snapshot_download`) or a local model path.
            fe_checkpoint_path: Checkpoint directory, or a falsy value to skip checkpoint loading.
            use_lora: Wrap the VLA with PEFT LoRA adapters before any checkpoint weights are loaded.
            lora_rank: LoRA rank `r`; `lora_alpha` is `min(lora_rank, 16)`.

        Returns:
            `(vla, processor, fe_action_head)`, with `vla` and `fe_action_head` switched to eval mode.

        Raises:
            RuntimeError: If the LoRA checkpoint contains keys the wrapped model does not expect.
            FileNotFoundError: If a step was inferred but the FE action head weight file is missing.
        """

        # Handle HuggingFace Hub models
        if model_is_on_hf_hub(vla_path):
            vla_download_path = snapshot_download(repo_id=vla_path)
            vla_path = vla_download_path
        else:
            # Register models if loading locally
            AutoConfig.register("openvla", OpenVLAConfig)
            AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
            AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
            AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)

        # Update config
        # NOTE(review): both helpers are project utilities operating on the (now local) model
        # directory; their exact effects are not visible from this file.
        update_auto_map(vla_path)
        check_model_logic_mismatch(vla_path)

        # Load processor and VLA
        processor = AutoProcessor.from_pretrained(vla_path, trust_remote_code=True)
        vla = AutoModelForVision2Seq.from_pretrained(
            vla_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(self.device)

        # Set number of images in input
        vla.vision_backbone.set_num_images_in_input(1)

        # Get model dimensions
        vla_module = vla.module if hasattr(vla, "module") else vla
        llm_dim = vla_module.llm_dim
        # NOTE: this local deliberately shadows the module-level ACTION_DIM import inside this method.
        ACTION_DIM = 7  # 6 DoF + gripper

        # Create Function Encoder Action Head
        print(f"Creating FunctionEncoderActionHead with k={self.fe_basis_functions}")
        fe_action_head = FunctionEncoderActionHead(
            input_dim=llm_dim * ACTION_DIM,
            hidden_dim=llm_dim,
            action_dim=ACTION_DIM,
            k=self.fe_basis_functions,
            n_continuous=self.n_continuous_actions,
        ).to(self.device, dtype=torch.bfloat16)

        # Apply LoRA if requested
        # This must happen before checkpoint loading so the LoRA parameter names exist on `vla`.
        if use_lora:
            print(f"Applying LoRA with rank {lora_rank}")
            lora_config = LoraConfig(
                r=lora_rank,
                lora_alpha=min(lora_rank, 16),
                lora_dropout=0.0,
                target_modules="all-linear",
                init_lora_weights="gaussian",
            )
            vla = get_peft_model(vla, lora_config)
            vla.print_trainable_parameters()

        # Load checkpoint if provided (following exp1 format)
        if fe_checkpoint_path:
            ckpt_dir = Path(fe_checkpoint_path)
            print(f"Loading FE checkpoint directory: {ckpt_dir}")

            # Try to infer step from directory name or contained files
            step = None
            # Parse from directory name like "checkpoint-40000"
            name_parts = ckpt_dir.name.split("-")
            if len(name_parts) > 1 and name_parts[-1].isdigit():
                step = int(name_parts[-1])
            if step is None:
                # Fallback: scan for fe_action_head--*_checkpoint.pt
                matches = list(ckpt_dir.glob("fe_action_head--*_checkpoint.pt"))
                if matches:
                    # Only the first match is inspected.
                    stem = matches[0].stem  # e.g., fe_action_head--40000_checkpoint
                    try:
                        step = int(stem.split("--")[1].split("_")[0])
                    except Exception:
                        step = None

            if step is None:
                # Best effort: nothing is loaded, so the freshly constructed action head and the
                # un-finetuned (optionally LoRA-wrapped) VLA are returned as-is.
                print("Could not infer checkpoint step; skipping checkpoint load.")
            else:
                # Load VLA state (prefer LoRA-only state if present)
                def _add_default_to_lora_keys(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
                    """Adjust PEFT LoRA key names to include `.default` for compatibility with PEFT>=0.10."""
                    new_state = {}
                    for k, v in state_dict.items():
                        k_new = k
                        # Keys that already contain ".default" anywhere are left untouched.
                        if "lora_A" in k and ".default" not in k:
                            k_new = k.replace("lora_A", "lora_A.default")
                        if "lora_B" in k_new and ".default" not in k_new:
                            k_new = k_new.replace("lora_B", "lora_B.default")
                        new_state[k_new] = v
                    return new_state

                vla_lora_path = ckpt_dir / f"vla_lora--{step}_checkpoint.pt"
                vla_full_path = ckpt_dir / f"vla--{step}_checkpoint.pt"
                if vla_lora_path.exists():
                    print(f"Loading VLA LoRA weights: {vla_lora_path}")
                    vla_state = torch.load(vla_lora_path, map_location=self.device)
                    vla_state = _add_default_to_lora_keys(vla_state)
                    # Non-strict load: the file holds only adapter weights, so missing base keys are expected.
                    missing, unexpected = vla.load_state_dict(vla_state, strict=False)
                    if missing:
                        pass  # It is ok because we are loading only LoRA weights
                    if unexpected:
                        # Unexpected keys mean the checkpoint does not match this model/LoRA layout.
                        print(f"[LoRA] Unexpected keys after load: {len(unexpected)}:")
                        for k in unexpected:
                            print(f"  {k}")
                        raise RuntimeError("Unexpected keys in LoRA checkpoint.")
                elif vla_full_path.exists():
                    print(f"Loading full VLA weights: {vla_full_path}")
                    vla_state = torch.load(vla_full_path, map_location=self.device)
                    vla.load_state_dict(vla_state, strict=True)
                else:
                    # Not fatal: evaluation continues with the VLA weights loaded so far.
                    print("No VLA weight file found in checkpoint directory.")

                # Load FE action head weights
                # Unlike the VLA weights, the action head file is mandatory once a step is known.
                fe_path = ckpt_dir / f"fe_action_head--{step}_checkpoint.pt"
                if fe_path.exists():
                    print(f"Loading FE action head weights: {fe_path}")
                    fe_state = torch.load(fe_path, map_location=self.device)
                    fe_action_head.load_state_dict(fe_state, strict=True)
                else:
                    raise FileNotFoundError(f"No FE action head weight file found in checkpoint directory: {fe_path}")

                # Load calibration state (dataset coefficients)
                calib_path = ckpt_dir / f"calibration_state--{step}.pt"
                if calib_path.exists():
                    print(f"Loading calibration state: {calib_path}")
                    calibration_state = torch.load(calib_path, map_location=self.device)
                    dataset_coeffs = calibration_state.get("dataset_coefficients", {})
                    for ds_name, coeffs in dataset_coeffs.items():
                        # Entries that are not dicts with both "l1" and "l2" are silently skipped.
                        if isinstance(coeffs, dict) and "l1" in coeffs and "l2" in coeffs:
                            fe_action_head.set_dataset_coefficients(ds_name, coeffs["l1"], coeffs["l2"])
                else:
                    print("No calibration state file found in checkpoint directory.")

        # Inference only: put both modules in eval mode before returning them.
        vla.eval()
        fe_action_head.eval()

        return vla, processor, fe_action_head

    def create_robosuite_env(self, task_name: str):
        """Build the robosuite environment for `task_name` with the LaNE camera/controller setup.

        Args:
            task_name: One of "door", "lift", "pick_place_can", "stack".

        Returns:
            The environment returned by `suite.make`, with the object placement adjusted
            for the "lift" and "door" tasks.

        Raises:
            ValueError: If `task_name` is not one of the supported tasks.
        """
        # task name -> (robosuite environment name, episode horizon in control steps)
        task_specs = {
            "door": ("Door", 200),
            "lift": ("Lift", 100),
            "pick_place_can": ("PickPlaceCan", 300),
            "stack": ("Stack", 200),
        }

        spec = task_specs.get(task_name)
        if spec is None:
            raise ValueError(f"Unknown task: {task_name}. Available tasks: {list(task_specs.keys())}")
        env_name, horizon = spec

        # Panda arm driven by an OSC_POSE controller, no joint initialization noise,
        # two 512x512 cameras, 10 Hz control.
        controller_config = load_controller_config(default_controller="OSC_POSE")
        env = suite.make(
            env_name=env_name,
            robots="Panda",
            controller_configs=controller_config,
            initialization_noise=None,
            camera_names=["frontview", "robot0_eye_in_hand"],
            camera_heights=512,
            camera_widths=512,
            control_freq=10,
            horizon=horizon,
        )

        if task_name == "lift":
            # Replace the sampler: x is fixed at -0.12, y varies within +/-0.025, no rotation.
            env.placement_initializer = UniformRandomSampler(
                name="ObjectSampler",
                x_range=(-0.12, -0.12),
                y_range=(-0.025, 0.025),
                rotation=(0, 0),
                ensure_object_boundary_in_range=False,
                ensure_valid_placement=True,
                reference_pos=np.array((0.0, 0.0, 0.8)),
                z_offset=0.01,
            )
        elif task_name == "door":
            # Keep the existing sampler but pin its x position and rotation to single values.
            door_sampler = env.placement_initializer
            door_sampler.x_range = (0.02, 0.02)
            door_sampler.rotation = (-1.85, -1.85)

        return env

    def load_lane_demo_dataset(self, task_name: str, num_demos: int = 1) -> List:
        """Load demonstration trajectories for LaNE tasks."""
        print(f"Loading {num_demos} demonstrations for task: {task_name}")

        # Define demo paths based on LaNE format
        demo_folder_map = {
            "door": "./LaNE/demo/robosuite_door/1",
            "lift": "./LaNE/demo/robosuite_lift/1",
            "pick_place_can": "./LaNE/demo/robosuite_pick_place_can/10",
            "stack": "./LaNE/demo/robosuite_stack/10",
        }

        demo_folder = demo_folder_map.get(task_name)
        if not demo_folder:
            raise ValueError(f"Unknown task: {task_name}")

        demonstrations = []

        # Check if demo files exist
        demo_file = os.path.join(demo_folder, "0_*.pt")
        demo_files = glob.glob(demo_file)

        assert len(demo_files) == 1, "Only one demo file is supported"

        # Load actual demo files
        demo_file = demo_files[0]  # Take the first (and likely only) demo file
        print(f"Loading demo from: {demo_file}")

        # Load the pytorch data file
        payload = torch.load(demo_file, map_location="cpu", weights_only=False)
        obs_list, _next_obs_list, action_list, _reward_list, _not_done_list = payload

        # Load demo boundaries
        demo_starts = np.load(os.path.join(demo_folder, "demo_starts.npy"))
        demo_ends = np.load(os.path.join(demo_folder, "demo_ends.npy"))

        print(f"Total steps: {len(obs_list)}")

        # Extract individual demonstrations
        for i in range(min(num_demos, len(demo_starts))):
            start_idx = demo_starts[i]
            end_idx = demo_ends[i]

            demo_images = obs_list[start_idx:end_idx]  # Shape: (T, 6, H, W)
            demo_actions = action_list[start_idx:end_idx]  # Shape: (T, 7)

            # Convert images to format expected by VLA (T, H, W, C)
            # LaNE format: (T, 6, H, W) -> (T, H, W, 6)
            demo_images = demo_images.transpose(0, 2, 3, 1)

            demo_data = {"images": demo_images, "actions": demo_actions, "length": len(demo_images)}
            demonstrations.append(demo_data)

            print(f"Demo {i + 1}: {len(demo_images)} steps")

        # Compute q01/q99 normalization stats across all demos (exclude gripper)
        all_actions = []
        for d in demonstrations:
            all_actions.append(d["actions"])
        if len(all_actions) > 0:
            actions_arr = np.concatenate(all_actions, axis=0).astype(np.float32)
            q01 = np.quantile(actions_arr, 0.01, axis=0)
            q99 = np.quantile(actions_arr, 0.99, axis=0)
            mask = np.array([True] * 6 + [False], dtype=bool)
            self.demo_action_stats = {"action": {"q01": q01, "q99": q99, "mask": mask}}
            print("Computed demo action stats (q01/q99) for normalization")
        else:
            raise ValueError("No demonstration actions found to compute normalization stats.")

        return demonstrations

    def _normalize_action_with_demo_stats(self, action: np.ndarray) -> np.ndarray:
        """Normalize action to [-1, 1] using q01/q99 for first 6 dims; leave gripper unchanged."""
        assert self.demo_action_stats is not None and "action" in self.demo_action_stats
        stats = self.demo_action_stats["action"]
        q01 = stats["q01"].astype(np.float32)
        q99 = stats["q99"].astype(np.float32)
        mask = stats.get("mask", np.array([True] * 6 + [False], dtype=bool))
        out = action.astype(np.float32).copy()
        denom = (q99 - q01) + 1e-8
        out[mask] = 2.0 * (np.clip(out[mask], q01[mask], q99[mask]) - q01[mask]) / denom[mask] - 1.0
        return out

    def _unnormalize_action_with_demo_stats(self, action: np.ndarray) -> np.ndarray:
        """Unnormalize from [-1, 1] using q01/q99 for first 6 dims; leave gripper unchanged."""
        assert self.demo_action_stats is not None and "action" in self.demo_action_stats
        stats = self.demo_action_stats["action"]
        q01 = stats["q01"].astype(np.float32)
        q99 = stats["q99"].astype(np.float32)
        mask = stats.get("mask", np.array([True] * 6 + [False], dtype=bool))
        out = action.astype(np.float32).copy()
        out[mask] = 0.5 * (out[mask] + 1.0) * (q99[mask] - q01[mask]) + q01[mask]
        return out

    def _get_dataset_and_collator(self):
        """Return the collator produced by `get_vla_dataset_and_collator`.

        The dataset and action tokenizer that come back alongside it are discarded; callers
        only need the collator to batch transformed samples.
        """
        image_processor = self.processor.image_processor
        loader_kwargs = {
            # NOTE(review): literal placeholder path, not a resolved directory — confirm it is intended.
            "data_root_dir": Path("DATA_ROOT_DIR"),
            "data_mix": "viola",  # closest match
            "image_transform": image_processor.apply_transform,
            "tokenizer": self.processor.tokenizer,
            "prompt_builder_fn": PurePromptBuilder,
            "default_image_resolution": (3, 224, 224),
            "padding_side": "right",
            "predict_stop_token": True,
            "shuffle_buffer_size": 1_000,
            "train": True,
            "episodic": False,
            "image_aug": False,
        }
        _unused_dataset, _unused_tokenizer, collate_fn = get_vla_dataset_and_collator(**loader_kwargs)
        return collate_fn

    def calibrate_fe(
        self, demonstrations: List[Dict], task_description: str, task_name: str, disable_rotation: bool = False
    ):
        """Fill a `CalibrationManager` buffer from the demos and calibrate the FE head on it.

        Every demo step becomes one transformed sample: the first three image channels are
        resized to 224x224 and cropped, and the action is placed in the first row of an
        otherwise zero action chunk. The resulting coefficients are registered under the
        dataset name ``lane_<task_name>``.

        Args:
            demonstrations: Dicts with "images" (T, H, W, C) uint8 and "actions" (T, 7).
            task_description: Language instruction attached to every sample.
            task_name: Task identifier; selects the dataset name and the disabled dims.
            disable_rotation: If True, set `disable_dims` on the manager ([0, 3, 4, 5] for
                "lift"/"door", [3, 4, 5] otherwise) before calibrating.
        """
        print(f"Calibrating FE with {len(demonstrations)} demonstrations...")

        # Same preprocessing pipeline as pretraining (mirrors pretrain_fe.py)
        collator = self._get_dataset_and_collator()
        tokenizer = self.processor.tokenizer
        batch_transform = RLDSBatchTransform(
            action_tokenizer=ActionTokenizer(tokenizer),
            base_tokenizer=tokenizer,
            image_transform=self.processor.image_processor.apply_transform,
            prompt_builder_fn=PurePromptBuilder,
            predict_stop_token=True,
        )

        dataset_name = f"lane_{task_name}"

        # Size the buffer so that every demo step fits.
        buffer_size = sum(len(demo["actions"]) for demo in demonstrations)
        calib_manager = CalibrationManager(dataset_names=[dataset_name], buffer_size=buffer_size)

        instruction = task_description.encode("utf-8")
        for demo in demonstrations:
            frames = demo["images"]  # Shape: (T, H, W, 6)
            demo_actions = demo["actions"]  # Shape: (T, 7)

            for t in range(len(frames)):
                # Frontview RGB lives in the first three channels
                rgb = frames[t, :, :, :3].copy()
                assert rgb.dtype == np.uint8, f"Image {t} has unexpected dtype {rgb.dtype}"

                # Same resize + crop as applied to live observations at evaluation time
                rgb = cv2.resize(rgb, (224, 224))
                rgb = rgb[32:-64, 48:-48]

                step_action = demo_actions[t].astype(np.float32).copy()
                # Gripper: flip sign and rescale [-1, 1] -> [0, 1]
                step_action[-1] = (1 - step_action[-1]) / 2
                if NORMALIZE_ACTION:
                    step_action = self._normalize_action_with_demo_stats(step_action)

                # Only the first row of the chunk carries the real action; the rest stay zero.
                action_chunk = np.zeros((NUM_ACTIONS_CHUNK, ACTION_DIM), dtype=np.float32)
                action_chunk[0] = step_action.astype(np.float32)

                # Minimal RLDS-like step accepted by RLDSBatchTransform
                sample = batch_transform(
                    {
                        "observation": {"image_primary": [rgb]},  # numpy (H, W, 3), uint8
                        "task": {"language_instruction": instruction},
                        "action": action_chunk,
                        "dataset_name": dataset_name,
                    }
                )
                calib_manager.add_training_sample(sample, dataset_name)

        print(f"Added {len(calib_manager.buffers[dataset_name])} samples to calibration buffer")

        # The manager handles hidden-state extraction and the CVX solve
        calib_manager.eval_mode = True
        if not disable_rotation:
            calib_manager.disable_dims = None
        elif task_name in ("lift", "door"):
            calib_manager.disable_dims = [0, 3, 4, 5]
        else:
            calib_manager.disable_dims = [3, 4, 5]
        calib_manager.calibrate_all_datasets(self.fe_action_head, self.vla, collator)

        print(f"FE calibration completed for {dataset_name}")

    def evaluate_task(
        self,
        task_name: str,
        task_description: str,
        num_episodes: int = 20,
        max_steps: int = 1000,
        save_videos: bool = False,
        video_dir: Optional[str] = None,
        disable_rotation: bool = False,
    ) -> Dict:
        """Roll out the calibrated FE policy on a LaNE robosuite task and report success.

        An episode counts as a success when the environment signals `done` with a positive
        reward on that step. The environment is always closed, even if a rollout raises.

        Args:
            task_name: Task identifier accepted by `create_robosuite_env`.
            task_description: Language instruction given to the model at every step.
            num_episodes: Number of rollouts.
            max_steps: Upper bound on policy steps per episode (the environment horizon
                usually ends the episode first).
            save_videos: Save one MP4 per episode; only effective when `video_dir` is set.
            video_dir: Existing directory that receives the videos.
            disable_rotation: Zero the rotation dims (plus task-specific overrides for
                "lift" and "door") before stepping the environment.

        Returns:
            Dict with "task_name", "task_description", "num_episodes", "total_successes",
            "success_rate" (0.0 when no episode was run) and per-episode "episode_results".
            Each episode's "steps" is the number of policy actions executed in that episode.
        """
        env = self.create_robosuite_env(task_name)
        try:
            print(f"\nEvaluating task: {task_name}")
            print(f"Task description: {task_description}")
            print(f"Number of episodes: {num_episodes}")

            # Evaluation metrics
            total_episodes = 0
            total_successes = 0
            episode_results = []

            # RLDS preprocessing components, built once for all episodes
            dataset_name = f"lane_{task_name}"
            collator = self._get_dataset_and_collator()
            batch_transform = RLDSBatchTransform(
                action_tokenizer=ActionTokenizer(self.processor.tokenizer),
                base_tokenizer=self.processor.tokenizer,
                image_transform=self.processor.image_processor.apply_transform,
                prompt_builder_fn=PurePromptBuilder,
                predict_stop_token=True,
            )

            # Loop invariants: vision token count, encoded instruction, video switch
            vision_backbone = self.vla.vision_backbone
            num_patches = vision_backbone.get_num_patches() * vision_backbone.get_num_images_in_input()
            instruction = task_description.encode("utf-8")
            record_video = bool(save_videos and video_dir)

            for episode_idx in tqdm(range(num_episodes), desc="Episodes"):
                obs = env.reset()

                if task_name == "door":
                    # Scripted warm-up before the policy takes over: seven fixed +z steps
                    # with the gripper command held at -1.
                    for _ in range(7):
                        obs, _, _, _ = env.step([0, 0, 1, 0, 0, 0, -1])

                t = 0
                episode_success = False
                episode_frames = []  # Full-resolution frames, only filled when recording

                while t < max_steps:
                    # Vertically flipped frontview frame (OpenVLA image convention)
                    frame = obs["frontview_image"][::-1].copy()

                    # Resize to 224x224, then apply the same crop used for the calibration demos
                    img = cv2.resize(frame, (224, 224))
                    img = img[32:-64, 48:-48]

                    action = self._predict_env_action(
                        img, instruction, dataset_name, batch_transform, collator, num_patches
                    )

                    # Normalize and invert gripper action
                    action = normalize_gripper_action(action, binarize=True)
                    action = invert_gripper_action(action)

                    if disable_rotation:
                        if task_name == "lift":
                            action[0] = 0
                            action[3:6] = 0
                        elif task_name == "door":
                            gripper_pos = np.array(
                                env.sim.data.site_xpos[env.sim.model.site_name2id("gripper0_grip_site")]
                            )
                            action[0] = 0
                            action[3:6] = 0
                            # Hand-tuned overrides: push +x / +z while the gripper site is
                            # below these position thresholds.
                            if gripper_pos[0] < -0.12:
                                action[0] = 0.1
                            if gripper_pos[2] < 0.915:
                                action[2] = 0.1
                        else:
                            action[3:6] = 0

                    if record_video:
                        episode_frames.append(frame)

                    obs, reward, done, _info = env.step(action.tolist())
                    # Count the step before the `done` check so the terminal step is included.
                    t += 1

                    if done:
                        if reward > 0:
                            episode_success = True
                            total_successes += 1
                        break

                # Best-effort video export: a failed write must not abort the evaluation.
                if record_video and episode_frames:
                    try:
                        video_filename = (
                            f"episode_{episode_idx + 1:03d}_{'success' if episode_success else 'fail'}.mp4"
                        )
                        video_path = os.path.join(video_dir, video_filename)

                        imageio.mimsave(video_path, episode_frames, fps=30, quality=10, macro_block_size=1)

                        print(f"Saved video: {video_path}")
                    except Exception as e:
                        print(f"Failed to save video for episode {episode_idx + 1}: {e}")

                total_episodes += 1
                episode_results.append(
                    {
                        "episode": episode_idx + 1,
                        "success": episode_success,
                        "steps": t,
                    }
                )

                if (episode_idx + 1) % 10 == 0:
                    current_success_rate = total_successes / total_episodes
                    print(
                        f"Episodes {episode_idx + 1}: Success rate = {current_success_rate:.3f}"
                        f" ({total_successes}/{total_episodes})"
                    )

            # Guard against num_episodes == 0, which previously raised ZeroDivisionError.
            final_success_rate = total_successes / total_episodes if total_episodes else 0.0

            results = {
                "task_name": task_name,
                "task_description": task_description,
                "num_episodes": num_episodes,
                "total_successes": total_successes,
                "success_rate": final_success_rate,
                "episode_results": episode_results,
            }

            print(f"\nFinal Success Rate: {final_success_rate:.3f} ({total_successes}/{total_episodes})")
        finally:
            # Release the simulator even when a rollout fails part-way.
            env.close()

        return results

    def _predict_env_action(
        self, img, instruction: bytes, dataset_name: str, batch_transform, collator, num_patches: int
    ) -> np.ndarray:
        """Run one VLA forward pass on `img` and decode the first action of the FE-predicted chunk.

        Args:
            img: Preprocessed uint8 RGB image for the current observation.
            instruction: UTF-8 encoded language instruction.
            dataset_name: Name whose calibrated coefficients the FE head should use.
            batch_transform: `RLDSBatchTransform` used to tokenize the step.
            collator: Collator turning the transformed sample into a batch of size 1.
            num_patches: Number of vision tokens preceding the text tokens.

        Returns:
            float32 action clipped to [-1, 1] and, when `NORMALIZE_ACTION` is set, mapped back
            to the demo action scale.

        Raises:
            ValueError: If no action-token positions are found in the labels.
        """
        # Zero action chunk: only needed so the transform creates the action token positions.
        rlds_step = {
            "observation": {"image_primary": [img]},
            "task": {"language_instruction": instruction},
            "action": np.zeros((NUM_ACTIONS_CHUNK, ACTION_DIM), dtype=np.float32),
            "dataset_name": dataset_name,
        }
        batched_inputs = collator([batch_transform(rlds_step)])

        # Move to device
        input_ids = batched_inputs["input_ids"].to(self.device)
        attention_mask = batched_inputs["attention_mask"].to(self.device)
        pixel_values = batched_inputs["pixel_values"].to(self.device)
        labels = batched_inputs["labels"].to(self.device)

        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            outputs = self.vla(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                labels=labels,
                output_hidden_states=True,
            )

            # Drop the vision tokens and the final position to align hidden states with labels[:, 1:].
            text_hidden_states = outputs.hidden_states[-1][:, num_patches:-1]

            # Use masks from labels to locate action tokens
            action_token_mask = get_current_action_mask(labels[:, 1:]) | get_next_actions_mask(labels[:, 1:])
            if not action_token_mask.any():
                raise ValueError("No action tokens found in labels")

            actions_hidden_states = (
                text_hidden_states[action_token_mask]
                .reshape(1, NUM_ACTIONS_CHUNK * ACTION_DIM, -1)
                .to(torch.bfloat16)
            )
            predicted_actions, _ = self.fe_action_head.predict_action(actions_hidden_states, [dataset_name])

        # First action of the chunk; `.float()` because NumPy cannot represent bfloat16 tensors.
        action = predicted_actions[0, 0, :].float().cpu().numpy().astype(np.float32)
        action = action.clip(-1, 1)
        if NORMALIZE_ACTION:
            # Unnormalize first 6 dims back to original scale
            action = self._unnormalize_action_with_demo_stats(action)
        return action

    def run_experiment(
        self,
        task_name: str,
        num_demos: int,
        num_episodes: int = 50,
        output_dir: str = "runs/exp1b_lane_fe_calibration",
        save_videos: bool = False,
        disable_rotation: bool = False,
    ) -> Dict:
        """Load demos, calibrate the FE head, evaluate the task, and write the results as JSON.

        Args:
            task_name: One of "door", "lift", "pick_place_can", "stack".
            num_demos: Number of demonstrations used for calibration.
            num_episodes: Number of evaluation rollouts.
            output_dir: Directory that receives the results file and, optionally, the videos.
            save_videos: Save one video per episode under `output_dir`.
            disable_rotation: Forwarded to calibration and evaluation; also appends
                "_no_rotation" to the output names.

        Returns:
            The evaluation results dict, extended with experiment metadata.
        """
        # Language instruction per task
        task_description = {
            "door": "Open the door.",
            "lift": "Lift the red block.",
            "pick_place_can": "Pick up the red can and place it in the far right bin.",
            "stack": "Stack the red blocks on top of the green block.",
        }[task_name]

        # Demos first: loading them also populates the normalization stats calibration needs.
        demonstrations = self.load_lane_demo_dataset(task_name, num_demos)
        self.calibrate_fe(demonstrations, task_description, task_name, disable_rotation=disable_rotation)

        out_root = Path(output_dir)
        variant_suffix = "_no_rotation" if disable_rotation else ""

        # The video directory has to exist before the rollouts start writing into it.
        video_dir = None
        if save_videos:
            video_path = out_root / f"videos_{task_name}_{num_demos}demos{variant_suffix}"
            video_path.mkdir(parents=True, exist_ok=True)
            video_dir = str(video_path)

        results = self.evaluate_task(
            task_name=task_name,
            task_description=task_description,
            num_episodes=num_episodes,
            save_videos=save_videos,
            video_dir=video_dir,
            disable_rotation=disable_rotation,
        )

        # Experiment metadata
        results.update(
            {
                "experiment": "FE Calibration-Only (LaNE)",
                "num_demos_calibration": num_demos,
                "fe_basis_functions": self.fe_basis_functions,
                "n_continuous_actions": self.n_continuous_actions,
                "timestamp": datetime.now().isoformat(),
            }
        )

        # Persist the results
        out_root.mkdir(parents=True, exist_ok=True)
        results_path = out_root / f"{task_name}_n{num_demos}_eps{num_episodes}{variant_suffix}.json"
        with open(results_path, "w") as handle:
            json.dump(results, handle, indent=2)

        print(f"\nResults saved to: {results_path}")

        return results


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the LaNE FE calibration-only evaluation."""
    parser = argparse.ArgumentParser(description="FE Calibration-Only Evaluation for LaNE Tasks")

    # Task configuration
    parser.add_argument(
        "--task_name",
        type=str,
        required=True,
        choices=["door", "lift", "pick_place_can", "stack"],
        help="Name of the robosuite task",
    )
    parser.add_argument(
        "--num_demos",
        type=int,
        default=1,
        help="Number of demonstration trajectories for calibration",
    )
    parser.add_argument(
        "--num_episodes",
        type=int,
        default=20,
        help="Number of episodes for evaluation",
    )

    # Model configuration
    parser.add_argument(
        "--vla_path",
        type=str,
        default="openvla/openvla-7b",
        help="Path to base VLA model",
    )
    parser.add_argument(
        "--fe_checkpoint_path",
        type=str,
        default=(
            "runs/fe_pretrain_1x_gram_reg_95pm_c6/"
            "fe-pretrain+oxe_magic_soup_plus_minus+k16+b12+lr-0.0001+lora-r32/checkpoint-40000"
        ),
        help="Path to FE checkpoint directory",
    )
    parser.add_argument(
        "--fe_basis_functions",
        type=int,
        default=16,
        help="Number of FE basis functions",
    )
    parser.add_argument(
        "--n_continuous_actions",
        type=int,
        default=6,
        help="Number of continuous actions",
    )

    # LoRA configuration
    # `--use_lora` is kept for backward compatibility, but with store_true and default=True it
    # can never turn LoRA off; `--no_lora` writes False to the same destination.
    parser.add_argument(
        "--use_lora",
        dest="use_lora",
        action="store_true",
        default=True,
        help="Use LoRA for efficient fine-tuning",
    )
    parser.add_argument(
        "--no_lora",
        dest="use_lora",
        action="store_false",
        help="Disable LoRA (load the VLA without adapters)",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=32,
        help="LoRA rank",
    )

    # Experiment configuration
    parser.add_argument(
        "--output_dir",
        type=str,
        default="eval-results/exp1b_lane_fe_calibration",
        help="Output directory for results",
    )
    parser.add_argument(
        "--save_videos",
        action="store_true",
        help="Save episode videos",
    )
    parser.add_argument(
        "--disable_rotation",
        action="store_true",
        help=("Disable rotation in action space"),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use",
    )
    return parser


def main():
    """Parse the command line, build the evaluator, run one experiment, and return its results."""
    args = _build_arg_parser().parse_args()

    # Initialize evaluator
    evaluator = LaNEFECalibrationEvaluator(
        vla_path=args.vla_path,
        fe_checkpoint_path=args.fe_checkpoint_path,
        fe_basis_functions=args.fe_basis_functions,
        n_continuous_actions=args.n_continuous_actions,
        device=args.device,
        seed=args.seed,
        use_lora=args.use_lora,
        lora_rank=args.lora_rank,
    )

    # Run experiment
    return evaluator.run_experiment(
        task_name=args.task_name,
        num_demos=args.num_demos,
        num_episodes=args.num_episodes,
        output_dir=args.output_dir,
        save_videos=args.save_videos,
        disable_rotation=args.disable_rotation,
    )


# Run the evaluation CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
