"""
Convert RLBench pick-and-place dataset to NPZ format for diffusion policy training.

Extracts REACH and CARRY phases as SEPARATE 64-step trajectories.
These are the phases learned by the diffusion policy.

Phase layout in raw data:
  - REACH: steps 0-63 (64 steps) - LEARNED
  - DESCEND: steps 64-71 (8 steps) - hard-coded
  - GRASP: steps 72-79 (8 steps) - hard-coded
  - LIFT: steps 80-87 (8 steps) - hard-coded
  - CARRY: steps 88-151 (64 steps) - LEARNED
  - DESCEND_RELEASE: steps 152-159 (8 steps) - hard-coded
  - RELEASE: steps 160-167 (8 steps) - hard-coded

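Output format (NPZ arrays: states, actions, traj_lengths):
  - Each valid episode produces 2 separate trajectories: REACH (64 steps) and CARRY (64 steps)
  - Trajectories are concatenated along the time axis; traj_lengths records each one's length
  - e.g. 80 episodes -> 160 trajectories, each 64 steps
  - States: 22 dims (joint_pos, joint_vel, gripper, ee_pos, ee_quat)
  - Actions: 8 dims (joint_pos_target + gripper)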
"""

import os
import pickle
import numpy as np
from absl import app, flags
from tqdm import tqdm

# Command-line flags; their values are read in main().
FLAGS = flags.FLAGS
# Dataset root; process_rlbench_data searches for an "episodes" folder beneath it.
flags.DEFINE_string("data_path", None, "Path to RLBench dataset (e.g., /path/to/stack_blocks/variation0)")
# Destination NPZ file.
flags.DEFINE_string("output_path", None, "Output NPZ file path")
# Caps the number of *valid* episodes converted; any value <= 0 (default -1) means no cap.
flags.DEFINE_integer("max_episodes", -1, "Maximum number of episodes to process (-1 for all)")

# Neither path has a usable default, so both must be supplied on the command line.
flags.mark_flag_as_required("data_path")
flags.mark_flag_as_required("output_path")

# Step count of every phase, listed in execution order.
# Must match the phase configuration in dataset_generator_pick_place_cp.py.
PHASE_STEPS = dict(
    reach=64,
    descend=8,
    grasp=8,
    lift=8,
    carry=64,
    descend_release=8,
    release=8,
)

# Half-open [start, end) step windows of the two learned phases.
REACH_START = 0
REACH_END = REACH_START + PHASE_STEPS["reach"]  # 64
# CARRY starts once the hard-coded descend, grasp and lift phases are over.
CARRY_START = REACH_END + sum(PHASE_STEPS[name] for name in ("descend", "grasp", "lift"))  # 88
CARRY_END = CARRY_START + PHASE_STEPS["carry"]  # 152


def extract_state_from_obs(obs):
    """
    Flatten one RLBench observation into a single state vector.

    Components, concatenated in this order:
      - joint_positions   (7 dims)
      - joint_velocities  (7 dims)
      - gripper_open      (1 dim, 0.0 or 1.0)
      - gripper_pose      (7 dims: position xyz, then quaternion xyzw)
    Total: 22 dims. The caller is responsible for casting to float32.
    """
    pose = obs.gripper_pose
    components = (
        obs.joint_positions,
        obs.joint_velocities,
        np.array([obs.gripper_open], dtype=np.float32),
        pose[:3],   # end-effector position
        pose[3:],   # end-effector orientation
    )
    return np.concatenate(components)


def extract_action_from_obs(obs):
    """
    Return the joint target vector that produced this observation, as float32.

    Uses ``obs.misc["joint_position_action"]`` when ``obs.misc`` is a dict holding
    a non-None value for that key. Otherwise falls back to the measured joint
    positions followed by the gripper-open flag.
    """
    misc = getattr(obs, "misc", None)
    recorded = misc.get("joint_position_action") if isinstance(misc, dict) else None
    if recorded is not None:
        return np.array(recorded, dtype=np.float32)

    # No recorded target: approximate it with the observed joint state + gripper flag.
    joints = np.array(obs.joint_positions, dtype=np.float32)
    gripper = np.array([obs.gripper_open], dtype=np.float32)
    return np.concatenate([joints, gripper])


def load_episode(episode_path):
    """
    Unpickle ``<episode_path>/low_dim_obs.pkl``.

    Returns the unpickled demo object, or None if that file does not exist.
    NOTE: pickle can execute arbitrary code while loading; only use on trusted datasets.
    """
    pkl_path = os.path.join(episode_path, "low_dim_obs.pkl")
    if os.path.exists(pkl_path):
        with open(pkl_path, "rb") as handle:
            return pickle.load(handle)
    return None


_ACTION_DIM = 8  # joint position targets + gripper command


def _find_episodes_dir(data_path):
    """Return the first existing episodes directory under ``data_path``.

    Candidates are tried in order: ``train/episodes``, ``variation0/episodes``,
    then ``episodes``.

    Raises:
        ValueError: if none of the candidates exists.
    """
    for candidate in (
        os.path.join(data_path, "train", "episodes"),
        os.path.join(data_path, "variation0", "episodes"),
        os.path.join(data_path, "episodes"),
    ):
        if os.path.exists(candidate):
            return candidate
    raise ValueError(f"Could not find episodes folder in {data_path}")


def _episode_sort_key(name):
    """Order "episodeN" folders by N numerically (episode2 before episode10)."""
    suffix = name[len("episode"):]
    if suffix.isdecimal():
        return (0, int(suffix), name)
    # Non-numeric suffixes sort after all numbered episodes, by name.
    return (1, 0, name)


def _extract_episode(demo, ep_folder):
    """Extract the REACH and CARRY trajectories of one demo.

    Both phases are built and validated before anything is returned. The caller
    therefore stores either both trajectories of an episode or neither, so an
    invalid CARRY phase can never leave an unpaired REACH trajectory in the data.

    Args:
        demo: indexable sequence of observations with at least CARRY_END + 1 entries.
        ep_folder: episode folder name, used only in warning messages.

    Returns:
        [(reach_states, reach_actions), (carry_states, carry_actions)], float32
        arrays of shape (phase_len, state_dim) and (phase_len, 8). Returns None,
        after printing a warning, if any step's action is not an 8-dim vector.
    """
    phases = (
        ("REACH", REACH_START, REACH_END),
        ("CARRY", CARRY_START, CARRY_END),
    )
    trajectories = []
    for phase_name, start, end in phases:
        steps = range(start, end)
        # State at step t is paired with the action recorded on observation t + 1.
        states = np.array([extract_state_from_obs(demo[t]) for t in steps], dtype=np.float32)
        actions = [extract_action_from_obs(demo[t + 1]) for t in steps]

        # Validate per-step shapes before stacking. Mixed lengths would otherwise
        # make np.array raise instead of cleanly skipping the episode.
        shapes = sorted({np.shape(a) for a in actions})
        if shapes != [(_ACTION_DIM,)]:
            tqdm.write(
                f"Warning: Expected {_ACTION_DIM}-dim action, got shape(s) {shapes} "
                f"in {ep_folder} {phase_name}. Skipping."
            )
            return None
        trajectories.append((states, np.array(actions, dtype=np.float32)))
    return trajectories


def process_rlbench_data(data_path, max_episodes=-1):
    """
    Process RLBench pick-and-place dataset, extracting REACH and CARRY as separate trajectories.

    Each valid episode contributes exactly two consecutive trajectories (REACH, then
    CARRY). An episode that cannot be loaded, is too short, or has a phase with
    non-8-dim actions contributes nothing and is counted as skipped.

    Args:
        data_path: dataset root; see _find_episodes_dir for the layouts searched.
        max_episodes: stop after this many valid episodes; <= 0 means no limit.

    Returns:
        states: (total_timesteps, state_dim) float32 array
        actions: (total_timesteps, action_dim) float32 array
        traj_lengths: (num_trajectories,) int32 array - one entry per phase trajectory

    Raises:
        ValueError: if no episodes folder is found or no episode is valid.
    """
    episodes_path = _find_episodes_dir(data_path)

    # Numeric ordering, so that max_episodes selects episode0..episodeN-1 and the
    # output follows episode index order.
    episode_folders = sorted(
        (
            d for d in os.listdir(episodes_path)
            if os.path.isdir(os.path.join(episodes_path, d)) and d.startswith("episode")
        ),
        key=_episode_sort_key,
    )

    print(f"Found {len(episode_folders)} episodes in {episodes_path}")
    print(f"Extracting REACH (steps {REACH_START}-{REACH_END-1}) and CARRY (steps {CARRY_START}-{CARRY_END-1}) as SEPARATE trajectories")

    all_states = []
    all_actions = []
    traj_lengths = []
    episode_count = 0
    skipped_count = 0

    for ep_folder in tqdm(episode_folders, desc="Loading episodes"):
        if max_episodes > 0 and episode_count >= max_episodes:
            break

        episode_path = os.path.join(episodes_path, ep_folder)
        demo = load_episode(episode_path)
        if demo is None:
            # tqdm.write keeps warnings from corrupting the progress bar.
            tqdm.write(f"Warning: Failed to load {episode_path}")
            skipped_count += 1
            continue

        # The action for step t is read from demo[t + 1], so the last CARRY step
        # (CARRY_END - 1) needs observation index CARRY_END to exist.
        if len(demo) < CARRY_END + 1:
            tqdm.write(f"Warning: Episode {ep_folder} has only {len(demo)} steps, need {CARRY_END + 1}. Skipping.")
            skipped_count += 1
            continue

        trajectories = _extract_episode(demo, ep_folder)
        if trajectories is None:
            skipped_count += 1
            continue

        for states, actions in trajectories:
            all_states.append(states)
            all_actions.append(actions)
            traj_lengths.append(len(states))
        episode_count += 1

    if episode_count == 0:
        raise ValueError("No valid episodes found!")

    # Concatenate all trajectories along the time axis
    states = np.concatenate(all_states, axis=0)
    actions = np.concatenate(all_actions, axis=0)
    traj_lengths = np.array(traj_lengths, dtype=np.int32)

    print("\nDataset statistics:")
    print(f"  Original episodes: {episode_count}")
    print(f"  Skipped episodes: {skipped_count}")
    print(f"  Total trajectories: {len(traj_lengths)} ({episode_count} REACH + {episode_count} CARRY)")
    print(f"  Total timesteps: {len(states)}")
    print(f"  State shape: {states.shape}")
    print(f"  Action shape: {actions.shape}")
    print(f"  Steps per trajectory: {traj_lengths[0]} (64 steps each)")

    return states, actions, traj_lengths


def main(argv):
    """Convert the dataset at --data_path and save it to --output_path as an NPZ file.

    The NPZ holds three arrays: ``states``, ``actions`` and ``traj_lengths``, as
    returned by process_rlbench_data. ``argv`` is passed in by absl and not used.
    """
    del argv  # Unused.
    data_path = FLAGS.data_path
    output_path = FLAGS.output_path
    max_episodes = FLAGS.max_episodes

    # np.savez_compressed appends ".npz" to string paths that lack it. Apply the
    # same rule up front so every message reports the file actually written.
    if not output_path.endswith(".npz"):
        output_path += ".npz"

    print(f"Processing RLBench pick-and-place data from: {data_path}")
    print(f"Output file: {output_path}")
    print("Extracting REACH and CARRY as separate 64-step trajectories")

    # Process data
    states, actions, traj_lengths = process_rlbench_data(data_path, max_episodes)

    # Save to NPZ, creating the parent directory if the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.savez_compressed(
        output_path,
        states=states,
        actions=actions,
        traj_lengths=traj_lengths,
    )

    print(f"\nSaved dataset to {output_path}")
    print(f"State dim: {states.shape[1]}")
    print(f"Action dim: {actions.shape[1]}")


# Script entry point: hand control to absl's app.run, which invokes main().
if __name__ == "__main__":
    app.run(main)
