"""
Process pick-and-place dataset: segment into REACH and CARRY sub-trajectories,
extract (state, action) pairs, and normalize.

Input: Dataset generated by dataset_generator_pick_place_cp.py
Output: Segmented and normalized NPZ files for diffusion policy training

Each sub-trajectory contains:
- 64 timesteps of (state, action) pairs
- Associated canonical parameters (angle, dist_frac, pos_frac) normalized to [0, 1]

Phase configuration:
- REACH: steps [0, 64) - learned Bezier curve from home to pregrasp
- CARRY: steps [88, 152) - learned Bezier curve from lift to prerelease
"""

import os
import pickle
import numpy as np
from absl import app, flags
from tqdm import tqdm

# Shared absl flag registry; main() reads data_path, output_path and split from it.
FLAGS = flags.FLAGS

# Default path matches dataset_generator_pick_place_cp.py output
DEFAULT_DATA_PATH = "/scratch4/workspace/placeholder-hdp1/dppo/rlbench_pick_place_data/stack_blocks/variation0"

# --data_path: dataset root; process_dataset() reads <data_path>/<split>/episodes.
flags.DEFINE_string("data_path", DEFAULT_DATA_PATH,
    "Path to the dataset (e.g., .../stack_blocks/variation0).")
# --output_path: where the NPZ/NPY outputs are written; None means <data_path>/processed.
flags.DEFINE_string("output_path", None,
    "Output directory. Defaults to data_path/processed.")
# --split: used both as the input sub-directory name and as the output file prefix.
# NOTE: this is a free-form string flag, so values other than train/eval are not rejected here.
flags.DEFINE_string("split", "train",
    "Which split to process: train or eval.")

# Phase configuration (must match dataset_generator_pick_place_cp.py)
# Maps phase name -> number of timesteps in that phase.
# Insertion order is significant: get_phase_indices() accumulates these counts
# in dict order to derive each phase's [start, end) step range.
PHASE_STEPS = {
    "reach": 64,           # Phase 1: steps [0, 64)
    "descend": 8,          # Phase 2: steps [64, 72)
    "grasp": 8,            # Phase 3: steps [72, 80)
    "lift": 8,             # Phase 4: steps [80, 88)
    "carry": 64,           # Phase 5: steps [88, 152)
    "descend_release": 8,  # Phase 6: steps [152, 160)
    "release": 8,          # Phase 7: steps [160, 168)
}

# Compute phase indices
def get_phase_indices(phase_steps=None):
    """
    Map each phase name to its half-open ``(start, end)`` timestep range.

    Phases are laid out back to back in the mapping's iteration order: the
    first phase starts at step 0 and every later phase starts where the
    previous one ended.

    Args:
        phase_steps: Ordered mapping of phase name -> number of steps.
            Defaults to the module-level ``PHASE_STEPS`` when omitted.

    Returns:
        dict mapping phase name -> ``(start, end)`` with ``end`` exclusive.
    """
    if phase_steps is None:
        phase_steps = PHASE_STEPS

    indices = {}
    start = 0
    for phase, steps in phase_steps.items():
        end = start + steps
        indices[phase] = (start, end)
        start = end  # next phase begins where this one stops
    return indices

# Half-open [start, end) step ranges per phase, derived once at import time from PHASE_STEPS.
PHASE_INDICES = get_phase_indices()
# Only the two long phases are extracted as sub-trajectories by process_dataset().
REACH_START, REACH_END = PHASE_INDICES["reach"]  # [0, 64)
CARRY_START, CARRY_END = PHASE_INDICES["carry"]  # [88, 152)

# Robot bounds for normalization (same as compute_normalization_robot_bounds.py)
# These fixed limits (not dataset statistics) are what get_normalization_bounds()
# uses to scale states and actions.
# Panda robot joint limits
# One entry per arm joint (7 joints); presumably radians -- confirm against the robot spec.
JOINT_POS_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718, -2.8973, -0.0175, -2.8973], dtype=np.float32)
JOINT_POS_MAX = np.array([2.8973, 1.7628, 2.8973, -0.0698, 2.8973, 3.7525, 2.8973], dtype=np.float32)

# Joint velocity limits
# Applied symmetrically as [-JOINT_VEL_MAX, +JOINT_VEL_MAX]; presumably rad/s -- confirm.
JOINT_VEL_MAX = np.array([2.1750, 2.1750, 2.1750, 2.1750, 2.6100, 2.6100, 2.6100], dtype=np.float32)

# Workspace bounds
# Per-axis (x, y, z) limits for the end-effector position; presumably metres -- confirm.
EE_POS_MIN = np.array([0.0, -0.6, 0.0], dtype=np.float32)
EE_POS_MAX = np.array([1.0, 0.6, 1.6], dtype=np.float32)


def extract_state_from_obs(obs):
    """
    Flatten one RLBench observation into a single float32 state vector.

    Components are concatenated in this order (expected sizes in brackets):
    - joint_positions   [7]
    - joint_velocities  [7]
    - gripper_open      [1]  (0.0 or 1.0)
    - gripper_pose[:3]  [3]  end-effector position xyz
    - gripper_pose[3:]  [4]  end-effector quaternion xyzw
    giving 22 values in total.
    """
    joints = obs.joint_positions
    velocities = obs.joint_velocities
    # The gripper flag is a scalar; wrap it so it can be concatenated.
    gripper_flag = np.array([obs.gripper_open], dtype=np.float32)
    ee_pose = obs.gripper_pose
    ee_position, ee_quaternion = ee_pose[:3], ee_pose[3:]

    pieces = (joints, velocities, gripper_flag, ee_position, ee_quaternion)
    return np.concatenate(pieces).astype(np.float32)


def extract_action_from_obs(obs):
    """
    Return the joint-target action associated with this observation.

    Prefers the value stored under ``obs.misc["joint_position_action"]``.
    When ``obs`` has no dict-valued ``misc`` or that entry is missing/None,
    falls back to the measured joint positions followed by the gripper-open
    flag. The result is always a float32 array.
    """
    misc = getattr(obs, "misc", None)
    recorded = misc.get("joint_position_action", None) if isinstance(misc, dict) else None
    if recorded is not None:
        return np.array(recorded, dtype=np.float32)

    # No recorded command: approximate it with measured joints + gripper state.
    measured_joints = np.array(obs.joint_positions, dtype=np.float32)
    gripper_flag = np.array([obs.gripper_open], dtype=np.float32)
    return np.concatenate([measured_joints, gripper_flag])


def get_normalization_bounds():
    """
    Build per-dimension min/max bounds for states and actions from robot limits.

    Returns:
        (obs_min, obs_max, action_min, action_max) where the observation
        bounds are 22-D float32 arrays and the action bounds are 8-D
        (7 joints + 1 gripper).
    """
    obs_dim = 22
    obs_min = np.zeros(obs_dim, dtype=np.float32)
    obs_max = np.zeros(obs_dim, dtype=np.float32)

    # (slice into the state vector, lower bound, upper bound)
    layout = (
        (slice(0, 7), JOINT_POS_MIN, JOINT_POS_MAX),     # joint positions
        (slice(7, 14), -JOINT_VEL_MAX, JOINT_VEL_MAX),   # joint velocities
        (slice(14, 15), 0.0, 1.0),                       # gripper open
        (slice(15, 18), EE_POS_MIN, EE_POS_MAX),         # EE position
        (slice(18, 22), -1.0, 1.0),                      # EE quaternion
    )
    for span, lower, upper in layout:
        obs_min[span] = lower
        obs_max[span] = upper

    # Actions are joint targets plus a gripper flag in [0, 1].
    # Concatenating with a plain Python list promotes these to float64.
    gripper_lower, gripper_upper = [0.0], [1.0]
    action_min = np.concatenate([JOINT_POS_MIN, gripper_lower])
    action_max = np.concatenate([JOINT_POS_MAX, gripper_upper])

    return obs_min, obs_max, action_min, action_max


def normalize_to_minus1_1(x, x_min, x_max):
    """
    Linearly rescale ``x`` so that ``x_min`` maps to -1 and ``x_max`` maps to +1.

    Values outside ``[x_min, x_max]`` are not clipped. Any dimension whose
    range ``x_max - x_min`` is below 1e-6 is divided by 1.0 instead, to avoid
    a division by (near) zero.
    """
    span = x_max - x_min
    safe_span = np.where(span < 1e-6, 1.0, span)
    offset = x - x_min
    return 2.0 * offset / safe_span - 1.0


def normalize_canonical_params(params):
    """
    Rescale a canonical parameter triple so every entry lies in [0, 1].

    ``params`` unpacks to ``(angle, dist_frac, pos_frac)``:
    - angle is divided by 2π (so [0, 2π] maps onto [0, 1])
    - dist_frac and pos_frac are passed through unchanged (already [0, 1])

    Returns a float32 array of shape (3,).
    """
    angle, dist_frac, pos_frac = params
    full_turn = 2 * np.pi
    return np.array([angle / full_turn, dist_frac, pos_frac], dtype=np.float32)


def load_episode(episode_path):
    """
    Unpickle the demo stored as ``low_dim_obs.pkl`` inside an episode directory.

    Returns the unpickled object, or None when the pickle file does not exist.
    """
    pkl_path = os.path.join(episode_path, "low_dim_obs.pkl")
    if os.path.exists(pkl_path):
        # NOTE: pickle executes arbitrary code on load; only use on trusted datasets.
        with open(pkl_path, "rb") as handle:
            return pickle.load(handle)
    return None


def extract_subtraj_states_actions(demo, start_idx, end_idx):
    """
    Collect aligned (state, action) pairs for steps ``[start_idx, end_idx)``.

    For each step ``t`` the state is read from ``demo[t]`` and the action from
    ``demo[t + 1]`` (the observation produced by that action). Steps for which
    ``demo[t + 1]`` does not exist are dropped, so fewer than
    ``end_idx - start_idx`` rows may be returned.

    Returns:
        states: (num_steps, state_dim) float32 array
        actions: (num_steps, action_dim) float32 array
    """
    # Stop early enough that demo[t + 1] is always a valid index.
    stop = min(end_idx, len(demo) - 1)
    pairs = [
        (extract_state_from_obs(demo[t]), extract_action_from_obs(demo[t + 1]))
        for t in range(start_idx, stop)
    ]
    states = [state for state, _ in pairs]
    actions = [action for _, action in pairs]

    return np.array(states, dtype=np.float32), np.array(actions, dtype=np.float32)


def _episode_sort_key(dir_name):
    """Sort key ordering episode directories by trailing integer (episode2 before episode10)."""
    stem = dir_name.rstrip("0123456789")
    digits = dir_name[len(stem):]
    if digits:
        return (0, int(digits), dir_name)
    # Names without a numeric suffix sort after all numbered episodes.
    return (1, 0, dir_name)


def _stack_phase(states, actions, params, lengths):
    """Pack the per-trajectory lists of one phase into the flat array dict that gets saved."""
    return {
        'states': np.concatenate(states, axis=0),
        'actions': np.concatenate(actions, axis=0),
        'canonical_params': np.array(params),  # (num_trajs, 3)
        'traj_lengths': np.array(lengths, dtype=np.int32),
    }


def process_dataset(data_path, split):
    """
    Process dataset and segment into REACH and CARRY sub-trajectories.

    Episodes are read from ``<data_path>/<split>/episodes/episode*`` in
    numeric order, and entry ``i`` of ``<data_path>/<split>/<split>_metadata.npy``
    is paired with the i-th episode directory. This assumes the metadata file
    is stored in episode-number order -- confirm against the generator.
    Episodes that fail to load or are shorter than ``CARRY_END + 1`` steps are
    skipped with a warning. When metadata is missing for an episode, its
    canonical params and end position default to zeros.

    Args:
        data_path: Dataset root containing the split directory.
        split: Split name (sub-directory and metadata file prefix).

    Returns:
        reach_data: dict with states, actions, canonical_params, traj_lengths
        carry_data: dict with states, actions, canonical_params, traj_lengths
        reach_metadata_list: per-trajectory dicts with 'canonical_cp_params'
            (normalized to [0, 1]) and 'end_pos' (pregrasp position)
        carry_metadata_list: same, with 'end_pos' being the prerelease position

    Raises:
        FileNotFoundError: if the episodes directory does not exist.
        ValueError: if no episode directory is found or no episode is usable.
    """
    split_root = os.path.join(data_path, split)
    episodes_path = os.path.join(split_root, "episodes")
    metadata_path = os.path.join(split_root, f"{split}_metadata.npy")

    if not os.path.exists(episodes_path):
        raise FileNotFoundError(f"Episodes path not found: {episodes_path}")

    # Load metadata
    metadata = None
    if os.path.exists(metadata_path):
        # allow_pickle is needed because entries are dicts; only load trusted files.
        metadata = np.load(metadata_path, allow_pickle=True)
        print(f"Loaded metadata: {len(metadata)} entries")
    else:
        print(f"Warning: Metadata not found at {metadata_path}; "
              "canonical params and end positions will default to zeros")

    # Get episode directories. A plain lexicographic sort would place
    # "episode10" before "episode2" and misalign the positional metadata
    # lookup below, so sort on the numeric suffix instead.
    episode_dirs = sorted(
        (
            d for d in os.listdir(episodes_path)
            if os.path.isdir(os.path.join(episodes_path, d)) and d.startswith("episode")
        ),
        key=_episode_sort_key,
    )
    print(f"Found {len(episode_dirs)} episodes")
    if not episode_dirs:
        raise ValueError(f"No episode directories found in {episodes_path}")

    if metadata is not None and len(metadata) != len(episode_dirs):
        print(f"Warning: Metadata has {len(metadata)} entries but {len(episode_dirs)} "
              "episodes were found; episodes without an entry get zero params")

    # Storage for REACH and CARRY sub-trajectories
    reach_states_all = []
    reach_actions_all = []
    reach_params_all = []
    reach_traj_lengths = []
    reach_metadata_list = []  # For training: end_pos, canonical_cp_params

    carry_states_all = []
    carry_actions_all = []
    carry_params_all = []
    carry_traj_lengths = []
    carry_metadata_list = []  # For training: end_pos, canonical_cp_params

    zeros3 = [0, 0, 0]

    for i, ep_dir in enumerate(tqdm(episode_dirs, desc="Processing episodes")):
        ep_path = os.path.join(episodes_path, ep_dir)
        demo = load_episode(ep_path)

        if demo is None:
            print(f"Warning: Failed to load {ep_path}")
            continue

        # One extra observation is required: the action for step t is read from demo[t + 1].
        if len(demo) < CARRY_END + 1:
            print(f"Warning: Demo {ep_dir} has only {len(demo)} steps, expected at least {CARRY_END + 1}")
            continue

        # Get canonical parameters and waypoints from metadata (zeros when unavailable)
        entry = metadata[i] if metadata is not None and i < len(metadata) else {}
        reach_params = np.array(entry.get('reach_cp_params', zeros3), dtype=np.float32)
        carry_params = np.array(entry.get('carry_cp_params', zeros3), dtype=np.float32)
        pregrasp_pos = np.array(entry.get('pregrasp_pos', zeros3), dtype=np.float32)
        prerelease_pos = np.array(entry.get('prerelease_pos', zeros3), dtype=np.float32)

        # Extract REACH sub-trajectory: steps [0, 64)
        reach_states, reach_actions = extract_subtraj_states_actions(demo, REACH_START, REACH_END)
        if len(reach_states) == REACH_END - REACH_START:
            reach_params_norm = normalize_canonical_params(reach_params)
            reach_states_all.append(reach_states)
            reach_actions_all.append(reach_actions)
            reach_params_all.append(reach_params_norm)
            reach_traj_lengths.append(len(reach_states))
            # Metadata for training (format expected by ControlPointSequenceDataset)
            reach_metadata_list.append({
                'canonical_cp_params': reach_params_norm.tolist(),
                'end_pos': pregrasp_pos.tolist(),  # REACH ends at pregrasp
            })

        # Extract CARRY sub-trajectory: steps [88, 152)
        carry_states, carry_actions = extract_subtraj_states_actions(demo, CARRY_START, CARRY_END)
        if len(carry_states) == CARRY_END - CARRY_START:
            carry_params_norm = normalize_canonical_params(carry_params)
            carry_states_all.append(carry_states)
            carry_actions_all.append(carry_actions)
            carry_params_all.append(carry_params_norm)
            carry_traj_lengths.append(len(carry_states))
            # Metadata for training (format expected by ControlPointSequenceDataset)
            carry_metadata_list.append({
                'canonical_cp_params': carry_params_norm.tolist(),
                'end_pos': prerelease_pos.tolist(),  # CARRY ends at prerelease
            })

    # Fail with a readable message instead of np.concatenate's
    # "need at least one array to concatenate".
    if not reach_states_all or not carry_states_all:
        raise ValueError(
            f"No usable episodes in {episodes_path}: every episode was missing "
            f"low_dim_obs.pkl or shorter than {CARRY_END + 1} steps"
        )

    # Stack into arrays
    reach_data = _stack_phase(reach_states_all, reach_actions_all, reach_params_all, reach_traj_lengths)
    carry_data = _stack_phase(carry_states_all, carry_actions_all, carry_params_all, carry_traj_lengths)

    return reach_data, carry_data, reach_metadata_list, carry_metadata_list


def _print_phase_summary(label, data, expected_steps):
    """Print trajectory counts and array shapes for one phase's raw data dict."""
    print(f"{label} sub-trajectories:")
    print(f"  Num trajectories: {len(data['traj_lengths'])}")
    print(f"  Total timesteps: {len(data['states'])}")
    print(f"  States shape: {data['states'].shape}")
    print(f"  Actions shape: {data['actions'].shape}")
    print(f"  Canonical params shape: {data['canonical_params'].shape}")
    print(f"  Steps per traj: {data['traj_lengths'][0]} (expected {expected_steps})")


def _print_param_ranges(label, params):
    """Print the min/max of each canonical-parameter column for one phase."""
    print(f"{label}:")
    for column, name in enumerate(("angle", "dist_frac", "pos_frac")):
        values = params[:, column]
        print(f"  {name + ':':<11}min={values.min():.4f}, max={values.max():.4f}")


def _save_normalized_phase(path, states_norm, actions_norm, data):
    """Write one phase's normalized states/actions plus its params and lengths to a compressed NPZ."""
    np.savez_compressed(
        path,
        states=states_norm.astype(np.float32),
        actions=actions_norm.astype(np.float32),
        canonical_params=data['canonical_params'],  # Already [0, 1]
        traj_lengths=data['traj_lengths'],
    )


def main(argv):
    """
    Segment, normalize and save the split selected by ``--split``.

    Reads ``<data_path>/<split>``, then writes into the output directory
    (``--output_path`` or ``<data_path>/processed``):
    raw and normalized NPZ files and metadata NPY files for REACH and CARRY,
    a merged REACH+CARRY normalized NPZ with its metadata, and
    ``normalization.npz`` holding the bounds used for scaling.

    Args:
        argv: Positional command-line arguments left over by absl (unused).

    Raises:
        FileNotFoundError: if the requested split directory does not exist.
    """
    del argv  # unused

    split = FLAGS.split
    data_path = FLAGS.data_path

    # Check the split that is actually being processed; a hard-coded "train"
    # check would reject a dataset that only contains the eval split.
    split_dir = os.path.join(data_path, split)
    if not os.path.exists(split_dir):
        raise FileNotFoundError(f"Cannot find data at {split_dir}")

    print(f"Processing data from: {data_path}")
    print(f"Split: {split}")

    # Set output path
    output_path = FLAGS.output_path if FLAGS.output_path else os.path.join(data_path, "processed")
    os.makedirs(output_path, exist_ok=True)
    print(f"Output path: {output_path}")

    def out_file(name):
        return os.path.join(output_path, name)

    # Get normalization bounds
    obs_min, obs_max, action_min, action_max = get_normalization_bounds()

    # Process dataset
    reach_data, carry_data, reach_metadata, carry_metadata = process_dataset(data_path, split)

    # (file-name key, display label, raw data, per-trajectory metadata, expected steps)
    phases = (
        ("reach", "REACH", reach_data, reach_metadata, REACH_END - REACH_START),
        ("carry", "CARRY", carry_data, carry_metadata, CARRY_END - CARRY_START),
    )

    print(f"\n{'='*60}")
    _print_phase_summary("REACH", reach_data, REACH_END - REACH_START)
    print()
    _print_phase_summary("CARRY", carry_data, CARRY_END - CARRY_START)

    # Save raw data
    print("\nSaving raw data...")
    for key, _, data, _, _ in phases:
        np.savez_compressed(out_file(f"{split}_{key}_raw.npz"), **data)

    # Normalize states and actions
    print("\nNormalizing data to [-1, 1]...")
    normalized = {}
    for key, label, data, _, _ in phases:
        states_norm = normalize_to_minus1_1(data['states'], obs_min, obs_max)
        actions_norm = normalize_to_minus1_1(data['actions'], action_min, action_max)
        normalized[key] = (states_norm, actions_norm)
        # Values outside [-1, 1] here mean the data exceeds the configured robot bounds.
        print(f"  {label} states range: [{states_norm.min():.4f}, {states_norm.max():.4f}]")
        print(f"  {label} actions range: [{actions_norm.min():.4f}, {actions_norm.max():.4f}]")

    # Save normalized data
    print("\nSaving normalized data...")
    for key, _, data, _, _ in phases:
        states_norm, actions_norm = normalized[key]
        _save_normalized_phase(out_file(f"{split}_{key}_normalized.npz"), states_norm, actions_norm, data)

    # Save normalization statistics
    np.savez(
        out_file("normalization.npz"),
        obs_min=obs_min,
        obs_max=obs_max,
        action_min=action_min,
        action_max=action_max,
    )

    # Save metadata for training (format expected by ControlPointSequenceDataset)
    # Contains: canonical_cp_params (normalized [0,1]) and end_pos for each trajectory
    print("\nSaving metadata for training...")
    for key, _, _, phase_metadata, _ in phases:
        np.save(out_file(f"{split}_{key}_metadata.npy"), np.array(phase_metadata))

    # Create MERGED dataset (REACH + CARRY combined)
    # This allows training a single unified model
    print("\nCreating merged dataset (REACH + CARRY)...")

    reach_states_norm, reach_actions_norm = normalized["reach"]
    carry_states_norm, carry_actions_norm = normalized["carry"]

    # Merge states and actions (all REACH trajectories first, then all CARRY)
    merged_states = np.concatenate([reach_states_norm, carry_states_norm], axis=0)
    merged_actions = np.concatenate([reach_actions_norm, carry_actions_norm], axis=0)
    merged_traj_lengths = np.concatenate([reach_data['traj_lengths'], carry_data['traj_lengths']])

    # Merge metadata in the same REACH-then-CARRY order as the arrays above
    merged_metadata = reach_metadata + carry_metadata

    # Save merged normalized data
    np.savez_compressed(
        out_file(f"{split}_merged_normalized.npz"),
        states=merged_states.astype(np.float32),
        actions=merged_actions.astype(np.float32),
        traj_lengths=merged_traj_lengths,
    )
    np.save(out_file(f"{split}_merged_metadata.npy"), np.array(merged_metadata))

    print(f"  Merged trajectories: {len(merged_traj_lengths)} (REACH: {len(reach_data['traj_lengths'])}, CARRY: {len(carry_data['traj_lengths'])})")
    print(f"  Merged states shape: {merged_states.shape}")
    print(f"  Merged actions shape: {merged_actions.shape}")

    print(f"\n{'='*60}")
    print("Processing Complete!")
    print(f"{'='*60}")
    print("\nOutput files:")
    print(f"  {split}_reach_raw.npz")
    print(f"  {split}_reach_normalized.npz")
    print(f"  {split}_reach_metadata.npy")
    print(f"  {split}_carry_raw.npz")
    print(f"  {split}_carry_normalized.npz")
    print(f"  {split}_carry_metadata.npy")
    print(f"  {split}_merged_normalized.npz  <-- USE THIS FOR UNIFIED MODEL")
    print(f"  {split}_merged_metadata.npy")
    print("  normalization.npz")
    print(f"\nSaved to: {output_path}")

    # Print canonical params statistics
    print(f"\n{'='*60}")
    print("Canonical Parameters (normalized to [0, 1]):")
    print(f"{'='*60}")
    _print_param_ranges("REACH", reach_data['canonical_params'])
    _print_param_ranges("CARRY", carry_data['canonical_params'])


# Script entry point: hand control to absl, which calls main(argv).
if __name__ == "__main__":
    app.run(main)
