import numpy as np
import torch
import os
import random
import h5py
import sys
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R
from torchvision import transforms, utils
from tqdm import tqdm
import json
import time
import copy
import re
import zarr
from typing import Union, Optional, Tuple
import numcodecs
import numpy as np
import simplejpeg
import functools
import matplotlib.pyplot as plt
from random import randint
from PIL import Image
import multiprocessing
from collections import defaultdict
from typing import Dict, Any
from gr00t.data.transform import ComposedModalityTransform
from gr00t.data.embodiment_tags import EmbodimentTag
from gr00t.data.schema import (
    DatasetMetadata,
    DatasetStatisticalValues,
    LeRobotModalityMetadata,
    LeRobotStateActionMetadata,
)
# Placeholder per-statistic value lists; every statistic starts out empty.
dataset_statistics = {
    stat_name: [] for stat_name in ("min", "max", "mean", "std", "q01", "q99")
}

# Modality layout for the two PSM (patient side manipulator) arms: each arm
# owns a 7-wide slice [start, end) of the flat state / action vectors.
dvrk_modality = {
    "state": {
        "psm1": {"start": 0, "end": 7},
        "psm2": {"start": 7, "end": 14},
    },
    "action": {
        "psm1": {"original_key": "actions", "start": 0, "end": 7},
        "psm2": {"original_key": "actions", "start": 7, "end": 14},
    },
    "video": {
        "main": {"original_key": "observation.video.left"},
        "endo_psm1": {"original_key": "observation.video.endo_psm1"},
        "endo_psm2": {"original_key": "observation.video.endo_psm2"},
    },
    "annotation": {
        "human.task_description": {"original_key": "task_index"},
    },
}

# Video stream info: every camera stream shares the same 224x224x3
# (height, width, channel) layout, so one entry is generated per stream key.
dvrk_video_info = {
    "fps": 30,
    **{
        camera_key: {
            "dtype": "video",
            "shape": [224, 224, 3],
            "names": ["height", "width", "channel"],
        }
        for camera_key in (
            "observation.video.left",
            "observation.video.endo_psm1",
            "observation.video.endo_psm2",
        )
    },
}

def save_obs_debug_imgs(obs: Dict[str, Any], targeting_type: str, idx: int):
    """
    Dump one PNG per video entry of ``obs`` for visual debugging.

    For every key containing "video", element ``value[0]`` (presumably the
    first frame — confirm against the caller) is cast to uint8, converted
    from RGB to BGR for OpenCV, and written to
    ``debug_imgs_complete/<targeting_type>/<idx>_<key>.png``.
    """
    out_dir = os.path.join("debug_imgs_complete", targeting_type)
    os.makedirs(out_dir, exist_ok=True)

    video_items = ((name, frames) for name, frames in obs.items() if "video" in name)
    for name, frames in video_items:
        frame_rgb = frames[0].astype(np.uint8)
        # cv2.imwrite expects BGR channel order.
        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, f"{idx}_{name}.png"), frame_bgr)

def apply_color_jitter(image, brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08):
    """
    Randomly jitter brightness, saturation, hue and contrast of an RGB image.

    Factors are drawn with Python's ``random`` module (in the order
    brightness, saturation, hue, contrast), so seed ``random`` for
    reproducible output.

    Args:
        image: Input RGB image (H, W, 3). Non-uint8 input is cast with
            ``astype(np.uint8)`` without rescaling, so float images must
            already be in [0, 255].
        brightness: HSV value-channel scale is drawn from
            ``U(1 - brightness, 1 + brightness)``.
        contrast: Contrast factor is drawn from ``U(1 - contrast, 1 + contrast)``;
            pixels are blended toward the image's mean gray level.
        saturation: HSV saturation scale is drawn from
            ``U(1 - saturation, 1 + saturation)``.
        hue: Hue shift is drawn from ``U(-hue, hue)`` as a fraction of the
            full hue circle.

    Returns:
        Augmented RGB uint8 image with the same shape as ``image``.
    """
    if image.dtype != np.uint8:
        image = image.astype(np.uint8)

    # Brightness, saturation and hue are adjusted in HSV space. For 8-bit
    # images OpenCV stores H in [0, 180) and S/V in [0, 255].
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)

    brightness_factor = 1.0 + random.uniform(-brightness, brightness)
    hsv[:, :, 2] *= brightness_factor

    saturation_factor = 1.0 + random.uniform(-saturation, saturation)
    hsv[:, :, 1] *= saturation_factor

    hue_shift = random.uniform(-hue, hue) * 180  # OpenCV uses 0-180 for hue
    hsv[:, :, 0] = (hsv[:, :, 0] + hue_shift) % 180

    # Clip HSV values to their valid 8-bit ranges before converting back.
    hsv[:, :, 0] = np.clip(hsv[:, :, 0], 0, 180)
    hsv[:, :, 1] = np.clip(hsv[:, :, 1], 0, 255)
    hsv[:, :, 2] = np.clip(hsv[:, :, 2], 0, 255)

    # Convert back to RGB.
    augmented = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)

    # Contrast: blend each pixel toward the mean gray level (as torchvision's
    # adjust_contrast does). Multiplying pixels by the factor alone would
    # only rescale intensity, i.e. act as a second brightness change.
    contrast_factor = 1.0 + random.uniform(-contrast, contrast)
    gray_mean = float(cv2.cvtColor(augmented, cv2.COLOR_RGB2GRAY).mean())
    blended = augmented.astype(np.float32) * contrast_factor + (1.0 - contrast_factor) * gray_mean
    return np.clip(blended, 0, 255).astype(np.uint8)

def create_dot_image(episode_path, img_l, target_img_size=(224, 224)):
    """
    Create an RGB image marking the clicked entry and exit points.

    The points come from ``<episode_path without trailing .zip>_clicked_point.csv``,
    read with ``pandas.read_csv``: the first data row is the entry point, the
    second the exit point, each with ``x``/``y`` columns. The entry point is
    drawn as a filled blue circle and the exit point as a filled green circle
    (radius 10 px) on a black canvas the size of ``img_l``; the red channel
    stays zero. This image is intended to be augmented geometrically
    alongside the main image and then overlaid.

    Args:
        episode_path: Path to the episode ``.zip`` file.
        img_l: Reference image; only its height and width are used.
        target_img_size: Unused; kept for backward compatibility.

    Returns:
        An RGB uint8 image of shape (H, W, 3) with blue/green dots on black.
    """
    # Strip only a trailing ".zip" so a ".zip" elsewhere in the path
    # (e.g. inside a directory name) does not truncate it.
    clicked_point_path = episode_path.removesuffix(".zip") + "_clicked_point.csv"
    clicked_point = pd.read_csv(clicked_point_path)
    entry_point = clicked_point.iloc[0]
    exit_point = clicked_point.iloc[1]

    # shape[:2] works for both (H, W) and (H, W, C) reference images.
    orig_height, orig_width = img_l.shape[:2]
    dot_img = np.zeros((orig_height, orig_width, 3), dtype=np.uint8)

    # Drawing happens in BGR order: (255, 0, 0) is blue, (0, 255, 0) is green.
    cv2.circle(dot_img, (int(entry_point.x), int(entry_point.y)), radius=10, color=(255, 0, 0), thickness=-1)
    cv2.circle(dot_img, (int(exit_point.x), int(exit_point.y)), radius=10, color=(0, 255, 0), thickness=-1)

    # Zero the red channel (index 2 in BGR) so only the blue/green marks remain.
    dot_img[:, :, 2] = 0

    return cv2.cvtColor(dot_img, cv2.COLOR_BGR2RGB)

def _assert_shape(arr: np.ndarray, expected_shape: Tuple[Optional[int], ...]):
    """Assert that ``arr.shape`` matches ``expected_shape``.

    Ranks must be equal; a ``None`` entry in ``expected_shape`` matches any
    size along that axis. Raises ``AssertionError`` on mismatch.
    """
    actual = arr.shape
    assert len(actual) == len(expected_shape), f"Expected shape of length {len(expected_shape)}, but got {len(actual)}"
    for got, want in zip(actual, expected_shape):
        if want is None:
            continue  # wildcard axis
        assert got == want, f"Expected dimension {want}, but got {got}"

def rotate_image(image, angle):
    """Rotate ``image`` by ``angle`` degrees about its integer center, keeping its size.

    Uses OpenCV's convention (positive angles rotate counter-clockwise);
    uncovered regions get cv2.warpAffine's default constant border.
    """
    height, width = image.shape[:2]
    rotation = cv2.getRotationMatrix2D((width // 2, height // 2), angle, 1.0)
    return cv2.warpAffine(image, rotation, (width, height))

def shift_image(image, shift_x, shift_y):
    """Translate ``image`` by ``shift_x`` / ``shift_y`` pixels, keeping its size."""
    height, width = image.shape[:2]
    translation = np.array([[1, 0, shift_x], [0, 1, shift_y]], dtype=np.float32)
    return cv2.warpAffine(image, translation, (width, height))


class JpegCodec(numcodecs.abc.Codec):
    """numcodecs codec that stores each image chunk as a JPEG byte string.

    Chunks must be uint8 arrays shaped (1, H, W, 3); the leading length-1
    axis is dropped before compression and restored after decompression.
    """

    codec_id = "pi_jpeg"

    def __init__(self, quality: int = 95):
        super().__init__()
        # JPEG quality handed to simplejpeg.encode_jpeg.
        self.quality = quality

    def encode(self, buf):
        # Validate the single-image layout before compressing.
        _assert_shape(buf, (1, None, None, 3))
        assert buf.dtype == "uint8"
        frame = buf[0]
        return simplejpeg.encode_jpeg(frame, quality=self.quality)

    def decode(self, buf, out=None):
        frame = simplejpeg.decode_jpeg(buf, buffer=out)
        # Re-add the leading axis removed in encode().
        return np.expand_dims(frame, axis=0)

@functools.lru_cache(maxsize=None)
def register_codecs():
    """Register ``JpegCodec`` in the numcodecs codec registry.

    Memoized, so repeated calls perform the registration only once.
    """
    numcodecs.register_codec(JpegCodec)


# Register at import time so the "pi_jpeg" codec id resolves via numcodecs.
register_codecs()

def get_robot_episodes(main_folders):
    """
    Collect the ``.zip`` episode files under one or more dataset roots.

    Expected layout::

        <main_folder>/tissue_<N>/<instruction_folder>/<episode>.zip

    The language instruction is derived from ``<instruction_folder>`` by
    dropping a leading number/underscore prefix, removing all other digits
    and the word ``recovery``, turning underscores into spaces and stripping
    whitespace (e.g. ``01_clip_the_artery_recovery`` -> ``clip the artery``).

    Args:
        main_folders: A single dataset root path, or an iterable of them.

    Returns:
        A list of ``(episode_path, instruction, tissue_index)`` tuples, where
        ``tissue_index`` is the integer following ``tissue_`` in the folder
        name, or ``None`` if there is none. Entries follow ``main_folders``
        order and, inside each folder, ``os.listdir`` order (not sorted).
    """
    # Accept a single root too: iterating a bare string would otherwise walk
    # its characters and call os.listdir on each of them.
    if isinstance(main_folders, (str, os.PathLike)):
        main_folders = [main_folders]

    tissue_pattern = re.compile(r'tissue_(\d+)')
    episode_list = []

    for main_folder in main_folders:
        # List all 'tissue_*' directories in the main folder
        tissue_folders = [
            f for f in os.listdir(main_folder)
            if f.startswith('tissue_') and os.path.isdir(os.path.join(main_folder, f))
        ]

        for tissue_folder in tissue_folders:
            tissue_path = os.path.join(main_folder, tissue_folder)

            tissue_index_match = tissue_pattern.search(tissue_folder)
            tissue_index = int(tissue_index_match.group(1)) if tissue_index_match else None

            # Every sub-directory of a tissue folder is an instruction folder.
            instruction_folders = [
                f for f in os.listdir(tissue_path)
                if os.path.isdir(os.path.join(tissue_path, f))
            ]

            for instr_folder in instruction_folders:
                instr_path = os.path.join(tissue_path, instr_folder)

                # Remove leading numbers and underscores
                instr_name = re.sub(r'^\d+_*', '', instr_folder)
                # Remove digits elsewhere in the string
                instr_name = re.sub(r'\d+', '', instr_name)
                # Remove the word 'recovery' and turn underscores into spaces
                instr_name = instr_name.replace('recovery', '').replace('_', ' ').strip()

                # Episode files are the zipped files inside the instruction folder
                episode_files = [
                    f for f in os.listdir(instr_path)
                    if os.path.isfile(os.path.join(instr_path, f)) and f.endswith('.zip')
                ]

                for episode_file in episode_files:
                    episode_path = os.path.join(instr_path, episode_file)
                    episode_list.append((episode_path, instr_name, tissue_index))

    print(f"Found {len(episode_list)} episodes")

    return episode_list

# Example usage: for testing get_robot_episodes()
# Replace the path below with the actual path(s) to your dataset root folder(s)
# episodes = get_robot_episodes(["base_chole_clipping_cutting/processed_data_zipped_pi/"])

# # Print the list of episodes, their corresponding language instructions, and tissue indices
# for episode_path, instruction, tissue_index in episodes:
#     print(f"Episode Path: {episode_path}")
#     print(f"Language Instruction: {instruction}")
#     print(f"Tissue Index: {tissue_index}\n")

# assert(False)


def resize_with_padding(image, target_width, target_height):
    """
    Resize an image to fit inside target dimensions while preserving aspect
    ratio, centering it on a zero-filled (black) canvas.

    Args:
        image: Input image, either (H, W) or (H, W, C).
        target_width: Desired output width in pixels.
        target_height: Desired output height in pixels.

    Returns:
        Array of shape (target_height, target_width, *image.shape[2:]) with the
        same dtype as ``image``.
    """
    h, w = image.shape[:2]

    target_aspect = target_width / target_height
    aspect = w / h

    # Scale so the limiting side matches the target; clamp to >= 1 pixel so
    # extreme aspect ratios never request a zero-sized resize.
    if aspect > target_aspect:
        # Image is wider than target aspect ratio
        new_w = target_width
        new_h = max(1, int(new_w / aspect))
    else:
        # Image is taller than (or equal to) target aspect ratio
        new_h = target_height
        new_w = max(1, int(new_h * aspect))

    resized = cv2.resize(image, (new_w, new_h))
    # cv2.resize drops a trailing singleton channel axis; restore the
    # original channel layout so it fits the canvas below.
    resized = resized.reshape((new_h, new_w) + image.shape[2:])

    # Canvas keeps the input's channel count and dtype (a fixed 3-channel
    # uint8 canvas would crash on grayscale and truncate float images).
    padded = np.zeros((target_height, target_width) + image.shape[2:], dtype=image.dtype)

    # Center the resized image on the canvas.
    pad_x = (target_width - new_w) // 2
    pad_y = (target_height - new_h) // 2
    padded[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = resized

    return padded

def quat_to_axis_angle_action(action):
    """
    Convert quaternion-based actions to axis-angle (rotation vector) actions.

    Args:
        action: Array-like of shape (..., >=8) laid out as
            [x, y, z, qx, qy, qz, qw, jaw]. The quaternion is scalar-last, the
            default convention of ``scipy.spatial.transform.Rotation.from_quat``
            (presumably matching the ``orientation.x/y/z/w`` header order used
            elsewhere in this file — confirm). Columns beyond index 7 are ignored.

    Returns:
        np.ndarray of shape (..., 7) laid out as [x, y, z, rx, ry, rz, jaw],
        where (rx, ry, rz) is the rotation vector (unit axis * angle in radians).

    Raises:
        ValueError: If ``action`` is a scalar or its last dimension has fewer
            than 8 entries.
    """
    action = np.asarray(action)
    if action.ndim == 0 or action.shape[-1] < 8:
        raise ValueError(
            f"Expected action with a last dimension of at least 8, got shape {action.shape}"
        )

    batch_shape = action.shape[:-1]
    # Flatten all leading dims so scipy sees an (n_actions, 4) quaternion batch.
    flat = action.reshape(-1, action.shape[-1])
    axis_angle_actions = np.zeros((flat.shape[0], 7))

    if flat.shape[0] > 0:  # scipy's Rotation rejects empty batches on some versions
        axis_angle_actions[:, 0:3] = flat[:, 0:3]                               # translation
        axis_angle_actions[:, 3:6] = R.from_quat(flat[:, 3:7]).as_rotvec()      # rotation
        axis_angle_actions[:, 6] = flat[:, 7]                                   # jaw

    return axis_angle_actions.reshape(batch_shape + (7,))

class EpisodicDatasetDvrkGeneric(torch.utils.data.Dataset):
    def __init__(
        self,
        robot_base_dir_list,
        action_horizon = 50,
        cutting_action_pad_size = 10,
        transforms: ComposedModalityTransform | None = None,
        embodiment_tag: EmbodimentTag | None = None,
        batch_size: int = 1,
        downsample_factor: int = 1,
        is_eval: bool = False,
        ):
        """Index every timestep of every dVRK episode under ``robot_base_dir_list``.

        Episode lengths and task metadata are loaded from the per-directory
        JSON-lines files ``meta.json`` / ``task_meta.json`` when all of them
        exist. Otherwise every episode ``.zip`` is opened to measure its length
        and those files are (re)written. A merged per-task summary is written
        to ``task_meta_global.json`` in the first directory.

        Args:
            robot_base_dir_list: Non-empty list of dataset root directories.
            action_horizon: Stored as ``self.action_horizon``.
            cutting_action_pad_size: Stored as ``self.cutting_action_pad_size``.
            transforms: Modality transforms; defaults to an empty
                ``ComposedModalityTransform``.
            embodiment_tag: Its ``.value`` is stored as ``self.tag`` and the tag
                itself is passed to the dataset metadata.
            batch_size: Stored as ``self.batch_size``.
            downsample_factor: Stored as ``self.downsample_factor``.
            is_eval: Stored as ``self.is_eval``.

        Raises:
            ValueError: If the number of episode lengths from the metadata does
                not match the number of episodes found on disk (stale cache).
        """
        super().__init__()
        assert len(robot_base_dir_list) > 0, "robot_base_dir_list must contain at least one directory"
        self.is_eval = is_eval
        self.batch_size = batch_size
        self.robot_base_dir_list = robot_base_dir_list
        self.action_horizon = action_horizon
        self.cutting_action_pad_size = cutting_action_pad_size
        self.fps = 30
        self.downsample_factor = downsample_factor
        self.normalize_actions_func = self.min_max_normalize_actions
        self.single_arm_action_dim = 7
        self.modality_json = dvrk_modality
        self.action_dim = 2 * self.single_arm_action_dim
        self.state_dim = 14
        self.transforms = (
            transforms if transforms is not None else ComposedModalityTransform(transforms=[])
        )
        self.tag = embodiment_tag.value
        self.target_size = (224, 224)

        # Scan the dataset roots once. episode_lengths must stay positionally
        # aligned with episode_list: episode_lengths[i] describes episode_list[i].
        self.episode_list = get_robot_episodes(self.robot_base_dir_list)
        self.episode_lengths = []
        # Dense integer index per unique instruction, in first-seen order.
        self.instruction_to_idx = {}

        # Metadata paths for all directories
        meta_paths = [os.path.join(robot_base_dir, "meta.json") for robot_base_dir in self.robot_base_dir_list]
        task_meta_paths = [os.path.join(robot_base_dir, "task_meta.json") for robot_base_dir in self.robot_base_dir_list]

        # instruction -> {"episode_indices": [...], "tissue_ids": [...]}
        task_meta = {}

        self.metadata = self._get_metadata(embodiment_tag)
        self.set_transforms_metadata(self.metadata)

        if all(os.path.exists(meta_path) and os.path.exists(task_meta_path) for meta_path, task_meta_path in zip(meta_paths, task_meta_paths)):
            print("Metadata files found. Loading from all directories.")
            for meta_path, task_meta_path in zip(meta_paths, task_meta_paths):
                with open(meta_path, "r") as meta_file:
                    meta_data = [json.loads(line) for line in meta_file]
                with open(task_meta_path, "r") as task_meta_file:
                    task_data = [json.loads(line) for line in task_meta_file]

                # Populate episode_lengths and instruction_to_idx from loaded metadata
                self.episode_lengths.extend(entry["length"] for entry in meta_data)
                for entry in task_data:
                    instruction = entry["tasks"][0]
                    if instruction not in self.instruction_to_idx:
                        self.instruction_to_idx[instruction] = len(self.instruction_to_idx)
                    task_entry = task_meta.setdefault(instruction, {"episode_indices": [], "tissue_ids": []})
                    task_entry["episode_indices"].append(entry["episode_index"])
                    task_entry["tissue_ids"].append(entry.get("tissue_id", None))
        else:
            print("Metadata files not found for some directories. Generating metadata...")
            for robot_base_dir in self.robot_base_dir_list:
                # Match on the directory plus a trailing separator so that
                # "data/a" does not also claim episodes under "data/ab".
                dir_prefix = os.path.join(robot_base_dir, "")
                meta_data = []
                task_data = []
                for episode_idx, (episode_path, instruction, tissue_index) in enumerate(tqdm(self.episode_list)):
                    # Only episodes belonging to the current robot_base_dir
                    if not episode_path.startswith(dir_prefix):
                        continue

                    store = zarr.ZipStore(episode_path, mode='r')
                    try:
                        zarr_store = zarr.group(store=store)
                        # Only the length is needed; read the shape instead of
                        # loading the whole kinematics array.
                        episode_len = zarr_store['kinematics'].shape[0]
                    finally:
                        store.close()
                    self.episode_lengths.append(episode_len)

                    # Append metadata for the current episode
                    meta_data.append({
                        "episode_index": episode_idx,
                        "tasks": [instruction],
                        "length": episode_len
                    })

                    # Add instruction to dict if not seen before
                    if instruction not in self.instruction_to_idx:
                        self.instruction_to_idx[instruction] = len(self.instruction_to_idx)
                        task_data.append({
                            "episode_index": episode_idx,
                            "tasks": [instruction],
                            "tissue_id": tissue_index,
                        })

                    task_entry = task_meta.setdefault(instruction, {"episode_indices": [], "tissue_ids": []})
                    task_entry["episode_indices"].append(episode_idx)
                    task_entry["tissue_ids"].append(tissue_index)

                # Save metadata to JSON files for the current directory
                meta_path = os.path.join(robot_base_dir, "meta.json")
                task_meta_path = os.path.join(robot_base_dir, "task_meta.json")
                with open(meta_path, "w") as meta_file:
                    for entry in meta_data:
                        meta_file.write(json.dumps(entry) + "\n")
                print(f"Metadata saved to {meta_path}")
                with open(task_meta_path, "w") as task_meta_file:
                    for entry in task_data:
                        task_meta_file.write(json.dumps(entry) + "\n")
                print(f"Task Metadata saved to {task_meta_path}")

        # A mismatch means the cached metadata no longer describes the episodes
        # on disk; indexing would silently pair lengths with the wrong episodes.
        if len(self.episode_lengths) != len(self.episode_list):
            raise ValueError(
                f"Metadata lists {len(self.episode_lengths)} episode lengths but "
                f"{len(self.episode_list)} episodes were found on disk; delete the "
                "stale meta.json/task_meta.json files to regenerate them."
            )

        # Save task_meta to a global file for later sampling.
        # NOTE: episode indices here refer to episode_list order before the sort below.
        task_meta_global_path = os.path.join(self.robot_base_dir_list[0], "task_meta_global.json")
        with open(task_meta_global_path, "w") as task_meta_global_file:
            json.dump(task_meta, task_meta_global_file)
        print(f"Global Task Metadata saved to {task_meta_global_path}")

        # Sort episodes by task name to group tasks together. The sort is
        # stable, and episode_lengths is permuted with the same order so both
        # lists stay aligned.
        order = sorted(range(len(self.episode_list)), key=lambda i: self.episode_list[i][1])
        self.episode_list = [self.episode_list[i] for i in order]
        self.episode_lengths = [self.episode_lengths[i] for i in order]

        # Flattened (episode_idx, timestep) index plus per-task sample indices
        self.flattened_indices = []
        self.task_to_indices = defaultdict(list)
        for episode_idx, ((_, instruction, _), episode_len) in enumerate(zip(self.episode_list, self.episode_lengths)):
            task_idx = self.instruction_to_idx[instruction]
            for ts in range(episode_len):
                self.task_to_indices[task_idx].append(len(self.flattened_indices))
                self.flattened_indices.append((episode_idx, ts))

        # psm = patient side manipulator
        # qpos = current pose read from the robot
        self.header_name_qpos_psm1 = ["psm1_pose.position.x", "psm1_pose.position.y", "psm1_pose.position.z",
                                "psm1_pose.orientation.x", "psm1_pose.orientation.y", "psm1_pose.orientation.z", "psm1_pose.orientation.w",
                                "psm1_jaw"]

        self.header_name_qpos_psm2 = ["psm2_pose.position.x", "psm2_pose.position.y", "psm2_pose.position.z",
                                "psm2_pose.orientation.x", "psm2_pose.orientation.y", "psm2_pose.orientation.z", "psm2_pose.orientation.w",
                                "psm2_jaw"]

        # sp = setpoint (i.e. when you teleoperate, it generate setpoints for the robot to reach)
        self.header_name_actions_psm1 = ["psm1_sp.position.x", "psm1_sp.position.y", "psm1_sp.position.z",
                                    "psm1_sp.orientation.x", "psm1_sp.orientation.y", "psm1_sp.orientation.z", "psm1_sp.orientation.w",
                                    "psm1_jaw_sp"]

        self.header_name_actions_psm2 = ["psm2_sp.position.x", "psm2_sp.position.y", "psm2_sp.position.z",
                                    "psm2_sp.orientation.x", "psm2_sp.orientation.y", "psm2_sp.orientation.z", "psm2_sp.orientation.w",
                                    "psm2_jaw_sp"]

        self.quat_cp_psm1 = ["psm1_pose.orientation.x", "psm1_pose.orientation.y", "psm1_pose.orientation.z", "psm1_pose.orientation.w"]
        self.quat_cp_psm2 = ["psm2_pose.orientation.x", "psm2_pose.orientation.y", "psm2_pose.orientation.z", "psm2_pose.orientation.w"]
    
    def set_transforms_metadata(self, metadata: DatasetMetadata):
        """Hand ``metadata`` to ``self.transforms`` via its ``set_metadata`` method.

        Transforms that depend on dataset metadata (e.g. normalization
        values) read it from there.
        """
        modality_transforms = self.transforms
        modality_transforms.set_metadata(metadata)

    def _get_metadata(self, embodiment_tag: EmbodimentTag) -> DatasetMetadata:
        """Build the ``DatasetMetadata`` for this dataset.

        State/action entries come from ``self.modality_json`` (validated as
        ``LeRobotModalityMetadata``); video entries take resolution, channels
        and fps from the module-level ``dvrk_video_info``. Statistics are
        placeholders: every min/max/mean/std/q01/q99 entry is a list of zeros
        as long as the corresponding state/action slice.

        Args:
            embodiment_tag: Embodiment tag stored in the returned metadata.

        Returns:
            DatasetMetadata: The metadata for the dataset.
        """
        # 1. Modality metadata

        # 1.1. State and action modalities
        simplified_modality_meta: dict[str, dict] = {}
        le_modality_meta = LeRobotModalityMetadata.model_validate(self.modality_json)
        for modality in ["state", "action"]:
            simplified_modality_meta[modality] = {}
            le_state_action_meta: dict[str, LeRobotStateActionMetadata] = getattr(
                le_modality_meta, modality
            )
            for subkey in le_state_action_meta:
                subkey_meta = le_state_action_meta[subkey]
                state_action_dtype = np.dtype(subkey_meta.dtype)
                simplified_modality_meta[modality][subkey] = {
                    "absolute": subkey_meta.absolute,
                    "rotation_type": subkey_meta.rotation_type,
                    "shape": [subkey_meta.end - subkey_meta.start],
                    # Floating-point slices are continuous; anything else is not.
                    "continuous": bool(np.issubdtype(state_action_dtype, np.floating)),
                }

        # 1.2. Video modalities
        simplified_modality_meta["video"] = {}
        le_info = dvrk_video_info
        for new_key in le_modality_meta.video:
            original_key = le_modality_meta.video[new_key].original_key
            if original_key is None:
                original_key = new_key
            le_video_meta = le_info[original_key]
            names = le_video_meta["names"]
            height = le_video_meta["shape"][names.index("height")]
            width = le_video_meta["shape"][names.index("width")]
            # NOTE(FH): different lerobot dataset versions have different keys for the number of channels and fps
            channels = le_video_meta["shape"][names.index("channel")]
            fps = le_info["fps"]

            simplified_modality_meta["video"][new_key] = {
                "resolution": [width, height],
                "channels": channels,
                "fps": fps,
            }

        # 2. Dataset statistics: zero-filled placeholders sized per slice.
        null_dataset_statistics = {}
        for our_modality in ["state", "action"]:
            null_dataset_statistics[our_modality] = {}
            for subkey in simplified_modality_meta[our_modality]:
                state_action_meta = le_modality_meta.get_key_meta(f"{our_modality}.{subkey}")
                assert isinstance(state_action_meta, LeRobotStateActionMetadata)
                dim = state_action_meta.end - state_action_meta.start
                null_dataset_statistics[our_modality][subkey] = {
                    stat_name: [0.0] * dim
                    for stat_name in ["min", "max", "mean", "std", "q01", "q99"]
                }

        # 3. Full dataset metadata
        metadata = DatasetMetadata(
            statistics=null_dataset_statistics,  # type: ignore
            modalities=simplified_modality_meta,  # type: ignore
            embodiment_tag=embodiment_tag,
        )

        return metadata

    def _process_episodes_parallel(self, num_workers=32):
        """Compute episode lengths in a process pool and index the instructions.

        Side effects:
            self.episode_lengths: list with one length per episode, in
                ``self.episode_list`` order.
            self.instruction_to_idx: each distinct instruction mapped to an
                integer id, assigned in first-seen order.
        """
        chunks = self._split_into_chunks(self.episode_list, num_workers)

        # Pool.imap yields results in submission order, so concatenating the
        # per-chunk length lists lines them back up with self.episode_list.
        with multiprocessing.Pool(processes=num_workers) as pool:
            per_chunk_lengths = list(tqdm(
                pool.imap(self._process_episode_chunk, chunks),
                total=len(chunks),
                desc="Processing episode chunks"
            ))

        self.episode_lengths = [
            length for chunk_lengths in per_chunk_lengths for length in chunk_lengths
        ]

        # Ids are assigned sequentially (outside the pool) so the mapping is
        # deterministic: the next id is always the current dictionary size.
        self.instruction_to_idx = {}
        for _, instruction, _ in self.episode_list:
            self.instruction_to_idx.setdefault(instruction, len(self.instruction_to_idx))

    def _process_episode_chunk(self, episode_chunk):
        """Process a chunk of episodes and return the episode lengths."""
        lengths = []
        
        for episode_path, instruction, _ in episode_chunk:
            # Open the Zarr store and get episode length
            store = zarr.ZipStore(episode_path, mode='r')
            zarr_store = zarr.group(store=store)
            kinematics = zarr_store['kinematics'][:]
            episode_len = len(kinematics)
            lengths.append(episode_len)
            store.close()
        
        return lengths

    @staticmethod
    def _split_into_chunks(lst, num_chunks):
        """Split a list into approximately equal-sized chunks."""
        chunk_size = max(1, len(lst) // num_chunks)
        return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
    
    def __len__(self):
        return len(self.flattened_indices)

    def _retrieve_data(self, index):
        """Load one un-transformed sample for flattened index ``index``.

        Opens the episode's zipped Zarr store, reads the kinematics table and
        one frame from each of the three cameras, builds the downsampled and
        padded relative-action chunk, and packs everything into a modality dict.

        Args:
            index: position in ``self.flattened_indices``; each entry there is
                an ``(episode_idx, start_ts)`` pair.

        Returns:
            dict with keys ``video.main``, ``video.endo_psm1``, ``video.endo_psm2``
            (each with a leading length-1 axis), ``state.psm1`` / ``state.psm2``
            (zeros), ``action.psm1`` / ``action.psm2`` (passed through
            ``self.normalize_actions_func``), ``unnormalized_action.psm1`` /
            ``unnormalized_action.psm2``, and ``annotation.human.task_description``
            (the instruction string).
        """
        # Get episode index and timestep from flattened index
        episode_idx, start_ts = self.flattened_indices[index]
        episode_path, instruction, tissue_id = self.episode_list[episode_idx]

        # Open the Zarr store using ZipStore. The ``finally`` below guarantees
        # it is closed even when a frame read or a helper raises.
        store = zarr.ZipStore(episode_path, mode='r')
        try:
            zarr_store = zarr.group(store=store)

            kinematics = zarr_store['kinematics'][:]
            df = pd.DataFrame(kinematics)
            episode_len = len(df)

            # Clamp the timestep to the last kinematics row.
            if start_ts >= episode_len:
                start_ts = episode_len - 1

            # For cutting tasks the kinematics extend longer than the images
            # (presumably by ``cutting_action_pad_size`` rows), so the image
            # index is capped. ``max(0, ...)`` keeps a short episode from
            # producing a negative index, which would address a frame counted
            # from the end of the episode.
            if instruction in ("go to the cutting position left tube", "go to the cutting position right tube"):
                max_img_idx = max(0, episode_len - self.cutting_action_pad_size - 1)
                img_idx = min(start_ts, max_img_idx)
            else:
                img_idx = min(start_ts, episode_len - 1)

            # get images
            img_l = np.array(zarr_store['left'][img_idx])  # da Vinci endoscope image
            img_lw = np.array(zarr_store['endo_psm2'][img_idx])  # left wrist camera view from PSM2
            img_rw = np.array(zarr_store['endo_psm1'][img_idx])  # right wrist camera view from PSM1

            # Color-jitter augmentation only when not evaluating.
            if not self.is_eval:
                img_l = apply_color_jitter(img_l)
                img_lw = apply_color_jitter(img_lw)
                img_rw = apply_color_jitter(img_rw)

            if "throw" in instruction:
                raw_overlay_img = create_dot_image(episode_path, img_l)
                # Pixels where the dot image is non-zero on any channel.
                nonzero_mask = np.any(raw_overlay_img != 0, axis=-1)
                # Keep the original image and only blend (50/50) the colored
                # dots in where they exist.
                img_l_copy = img_l.copy()
                img_l_copy[nonzero_mask] = cv2.addWeighted(img_l, 0.5, raw_overlay_img, 0.5, 0)[nonzero_mask]
                img_l = img_l_copy

            # Resize/pad each view when its first two dims (rows, cols) differ
            # from self.target_size[0], self.target_size[1].
            if img_l.shape[0] != self.target_size[0] or img_l.shape[1] != self.target_size[1]:
                img_l = resize_with_padding(img_l, self.target_size[1], self.target_size[0])
            if img_lw.shape[0] != self.target_size[0] or img_lw.shape[1] != self.target_size[1]:
                img_lw = resize_with_padding(img_lw, self.target_size[1], self.target_size[0])
            if img_rw.shape[0] != self.target_size[0] or img_rw.shape[1] != self.target_size[1]:
                img_rw = resize_with_padding(img_rw, self.target_size[1], self.target_size[0])

            # read qpos
            qpos_psm1 = df[self.header_name_qpos_psm1].iloc[start_ts].to_numpy()
            qpos_psm2 = df[self.header_name_qpos_psm2].iloc[start_ts].to_numpy()

            # --- Action fetching and downsampling logic ---
            # Fetch action_horizon * downsample_factor raw setpoints (fewer near
            # the episode end), convert to relative actions, keep every
            # downsample_factor-th row, then zero-pad to action_horizon rows.
            fetch_horizon = self.action_horizon * self.downsample_factor
            available_steps = min(fetch_horizon, episode_len - start_ts)

            action_setpoints_psm1 = df[self.header_name_actions_psm1].iloc[start_ts : start_ts + available_steps].to_numpy()
            action_setpoints_psm2 = df[self.header_name_actions_psm2].iloc[start_ts : start_ts + available_steps].to_numpy()

            raw_diff_psm1 = quat_to_axis_angle_action(qpos_psm1, action_setpoints_psm1)
            raw_diff_psm2 = quat_to_axis_angle_action(qpos_psm2, action_setpoints_psm2)

            downsampled_diff_psm1 = raw_diff_psm1[::self.downsample_factor]
            downsampled_diff_psm2 = raw_diff_psm2[::self.downsample_factor]

            final_diff_psm1 = np.zeros((self.action_horizon, self.single_arm_action_dim))
            final_diff_psm2 = np.zeros((self.action_horizon, self.single_arm_action_dim))

            pad_len_psm1 = min(self.action_horizon, downsampled_diff_psm1.shape[0])
            pad_len_psm2 = min(self.action_horizon, downsampled_diff_psm2.shape[0])

            final_diff_psm1[:pad_len_psm1, :] = downsampled_diff_psm1[:pad_len_psm1]
            final_diff_psm2[:pad_len_psm2, :] = downsampled_diff_psm2[:pad_len_psm2]

            # Stack the arms along the column dim: PSM1 columns, then PSM2.
            action = np.column_stack((final_diff_psm1, final_diff_psm2))
            # --- End Action logic ---

            normalized_action = self.normalize_actions_func(action)

            # Current poses are set to zeros (dvrk kinematics unreliable).
            qpos = np.zeros(self.state_dim)

            dataset_dict = {}
            dataset_dict["video.main"] = img_l[np.newaxis, ...]
            dataset_dict["video.endo_psm1"] = img_rw[np.newaxis, ...]
            dataset_dict["video.endo_psm2"] = img_lw[np.newaxis, ...]

            dataset_dict["state.psm1"] = qpos[:7][np.newaxis, ...]
            dataset_dict["state.psm2"] = qpos[7:][np.newaxis, ...]
            dataset_dict["action.psm1"] = normalized_action[:, :self.single_arm_action_dim]
            dataset_dict["action.psm2"] = normalized_action[:, self.single_arm_action_dim:]
            dataset_dict["unnormalized_action.psm1"] = action[:, :self.single_arm_action_dim]
            dataset_dict["unnormalized_action.psm2"] = action[:, self.single_arm_action_dim:]
            dataset_dict["annotation.human.task_description"] = instruction
        finally:
            store.close()

        return dataset_dict
        

    def __getitem__(self, index):
        return self.transforms(self._retrieve_data(index))

    def normalize_actions(self, diffs, stats=None):
        """
        Z-score normalize an action chunk with the dataset mean/std.

        diffs: n_actions x 14 (per arm: delta position [3], delta orientation
            (axis-angle) [3], jaw angle (absolute) [1]; PSM1 columns then PSM2).
        stats: optional dict with "mean" and "std" entries (one value per
            column). Defaults to the module-level ``dataset_statistics``.
        return: (diffs - mean) / (std + 1e-6), computed for every column.

        NOTE(review): every column, including the absolute jaw angle, is
        normalized here. If the jaw should stay absolute, exclude it explicitly.
        """
        stats = dataset_statistics if stats is None else stats
        # The statistics are stored as plain lists. Convert them so
        # ``std + 1e-6`` is element-wise (list + float raises TypeError).
        mean = np.asarray(stats["mean"], dtype=np.float64)
        std = np.asarray(stats["std"], dtype=np.float64)
        return (np.asarray(diffs) - mean) / (std + 1e-6)

    def min_max_normalize_actions(self, diffs, use_percentiles=True, stats=None):
        """
        Normalize actions using min-max scaling to the range [-1, 1].

        Args:
            diffs: n_actions x 14 array (per arm: delta position [3], delta
                orientation (axis-angle) [3], jaw angle (absolute) [1]).
            use_percentiles: If True, use q01/q99 percentiles instead of min/max
                as bounds, to be robust against outliers.
            stats: optional dict with "q01"/"q99"/"min"/"max" entries (one value
                per column). Defaults to the module-level ``dataset_statistics``.

        Returns:
            Array shaped like ``diffs``; float input keeps its dtype, other
            input becomes float64. Columns whose upper bound is not greater than
            their lower bound (no variation) are set to 0.
        """
        stats = dataset_statistics if stats is None else stats
        low_key, high_key = ("q01", "q99") if use_percentiles else ("min", "max")
        min_vals = np.asarray(stats[low_key], dtype=np.float64)
        max_vals = np.asarray(stats[high_key], dtype=np.float64)

        values = np.asarray(diffs)
        # Integer input would truncate scaled values; promote it to float.
        out_dtype = values.dtype if np.issubdtype(values.dtype, np.floating) else np.float64

        has_range = max_vals > min_vals
        # Degenerate columns get a dummy span of 1 so the division is always
        # well defined; np.where then overwrites them with 0.
        span = np.where(has_range, max_vals - min_vals, 1.0)
        # 2 * (x - min) / (max - min) - 1 maps [min, max] onto [-1, 1].
        scaled = 2.0 * (values - min_vals) / (span + 1e-6) - 1.0
        return np.where(has_range, scaled, 0.0).astype(out_dtype)

    def denormalize_action(self, action, stats=None):
        """
        Denormalize the action using the dataset's statistics.

        Inverts the z-scoring ``(x - mean) / (std + 1e-6)``. The same
        ``std + 1e-6`` factor is used so a normalize/denormalize round trip
        reproduces the original values.

        stats: optional dict with "mean" and "std" entries (one value per
            column). Defaults to the module-level ``dataset_statistics``.
        """
        stats = dataset_statistics if stats is None else stats
        mean = np.asarray(stats["mean"], dtype=np.float64)
        std = np.asarray(stats["std"], dtype=np.float64)
        return action * (std + 1e-6) + mean

    

def calculate_dataset_statistics(dataset, output_json_path="dvrk_statistics.json", num_workers=None):
    """
    Calculate normalization statistics for the dataset's actions, incorporating
    the downsampling logic, using parallel processing. Skips loading images
    for efficiency.

    Each episode is handed to ``_process_single_episode_stats`` in a
    ``multiprocessing.Pool``. The padded action rows from all successful
    episodes are pooled, and per-column min/max/mean/std/q01/q99 are computed.
    States are all zeros in this dataset, so they are only counted (for the
    summary line and the empty-data check), never materialized.

    Args:
        dataset: An instance of EpisodicDatasetDvrkGeneric. Reads its
            ``episode_list``, ``episode_lengths``, ``action_horizon``,
            ``downsample_factor``, ``single_arm_action_dim``, ``action_dim``,
            ``state_dim`` and ``header_name_*`` attributes.
        output_json_path: Path to save the JSON statistics
        num_workers: Number of parallel workers. Defaults to cpu_count.

    Returns:
        dict: ``{"dataset_statistics": {...}}`` with one value per action
        column for each statistic, or ``{}`` if no data was collected.
    """
    print(f"Calculating dataset statistics with action_horizon={dataset.action_horizon}, downsample_factor={dataset.downsample_factor}...")

    if num_workers is None:
        num_workers = multiprocessing.cpu_count()
        print(f"Using {num_workers} workers.")

    # Plain, picklable parameters for the worker processes.
    params = {
        'action_horizon': dataset.action_horizon,
        'downsample_factor': dataset.downsample_factor,
        'rel_action_func': quat_to_axis_angle_action,
        'single_arm_action_dim': dataset.single_arm_action_dim,
        'action_dim': dataset.action_dim,
        'state_dim': dataset.state_dim,
        'header_name_qpos_psm1': dataset.header_name_qpos_psm1,
        'header_name_qpos_psm2': dataset.header_name_qpos_psm2,
        'header_name_actions_psm1': dataset.header_name_actions_psm1,
        'header_name_actions_psm2': dataset.header_name_actions_psm2,
    }

    # Each argument tuple: (episode_idx, episode_path, instruction, tissue_id, params)
    pool_args = [
        (idx, path, instr, tid, params)
        for idx, (path, instr, tid) in enumerate(dataset.episode_list)
    ]

    all_actions_nested = []
    with multiprocessing.Pool(processes=num_workers) as pool:
        # Result order does not matter for the statistics, so use
        # imap_unordered to consume episodes as soon as any worker finishes.
        results_iterator = pool.imap_unordered(_process_single_episode_stats, pool_args)
        print("Processing episodes in parallel for action statistics...")
        for result in tqdm(results_iterator, total=len(dataset.episode_list), desc="Processing episodes"):
            if result:  # Empty list means an empty or failed episode.
                all_actions_nested.append(result)

    print("Combining results...")
    all_actions = np.array([action for sublist in all_actions_nested for action in sublist])

    # Count the (all-zero) state samples from the episode lengths rather than
    # building one zero vector per timestep.
    # NOTE(review): this counts every episode in episode_list, including ones
    # the workers failed on; it assumes dataset.episode_lengths is populated
    # and aligned with episode_list.
    print("Calculating total state vectors...")
    horizon_span = dataset.action_horizon * dataset.downsample_factor
    num_state_samples = sum(
        max(0, dataset.episode_lengths[episode_idx] - horizon_span + dataset.downsample_factor)
        for episode_idx in range(len(dataset.episode_list))
    )

    print(f"Collected {num_state_samples} state samples and {len(all_actions)} action samples")

    if num_state_samples == 0 or len(all_actions) == 0:
        print("Error: No data collected. Cannot calculate statistics.")
        return {}

    stats_key = "dataset_statistics"

    print("Calculating final statistics...")
    stats = {
        stats_key: {
            "min": np.min(all_actions, axis=0).tolist(),
            "max": np.max(all_actions, axis=0).tolist(),
            "mean": np.mean(all_actions, axis=0).tolist(),
            "std": np.std(all_actions, axis=0).tolist(),
            "q01": np.percentile(all_actions, 1, axis=0).tolist(),
            "q99": np.percentile(all_actions, 99, axis=0).tolist()
        }
    }

    print(f"\nDataset Statistics ({stats_key}) (Calculated with parallel processing):")
    print(json.dumps(stats[stats_key], indent=2))

    with open(output_json_path, 'w') as f:
        json.dump(stats, f, indent=2)

    print(f"\nStatistics saved to {output_json_path}")

    # Also print the values in a form that can be pasted directly into code.
    print(f"\nAction mean and std for {stats_key}:")
    print(f"mean = np.array({stats[stats_key]['mean']})")
    print(f"std = np.array({stats[stats_key]['std']})")
    print(f"min = np.array({stats[stats_key]['min']})")
    print(f"max = np.array({stats[stats_key]['max']})")

    return stats

# Define the worker function outside the class (or it could be a static method too)
def _process_single_episode_stats(args):
    """Worker: compute the padded, downsampled relative-action rows for one episode.

    Kept at module level so ``multiprocessing`` can pickle it.

    Args:
        args: ``(episode_idx, episode_path, instruction, tissue_id, params)``.
            ``params`` is the dict built by ``calculate_dataset_statistics``
            (action_horizon, downsample_factor, rel_action_func,
            single_arm_action_dim and the four ``header_name_*`` column
            selectors).

    Returns:
        List of 1-D numpy rows, ``action_horizon`` rows per valid start
        timestep; each row is the PSM1 action followed by the PSM2 action.
        The list is empty for a zero-length episode, or if any error occurs
        while processing (the error is printed to stderr).
    """
    episode_idx, episode_path, instruction, tissue_id, params = args
    action_horizon = params['action_horizon']
    downsample_factor = params['downsample_factor']
    rel_action_func = params['rel_action_func']
    single_arm_action_dim = params['single_arm_action_dim']
    header_name_qpos_psm1 = params['header_name_qpos_psm1']
    header_name_qpos_psm2 = params['header_name_qpos_psm2']
    header_name_actions_psm1 = params['header_name_actions_psm1']
    header_name_actions_psm2 = params['header_name_actions_psm2']

    episode_actions = []
    store = None
    try:
        store = zarr.ZipStore(episode_path, mode='r')
        zarr_store = zarr.group(store=store)

        kinematics = zarr_store['kinematics'][:]
        df = pd.DataFrame(kinematics)
        episode_len = len(df)
        if episode_len == 0:
            return episode_actions

        # Select each column group once. The original selected a fresh
        # DataFrame per timestep, which cost O(episode_len) work per iteration.
        qpos_psm1_all = df[header_name_qpos_psm1].to_numpy()
        qpos_psm2_all = df[header_name_qpos_psm2].to_numpy()
        setpoints_psm1_all = df[header_name_actions_psm1].to_numpy()
        setpoints_psm2_all = df[header_name_actions_psm2].to_numpy()

        fetch_horizon = action_horizon * downsample_factor
        # Same start-timestep range as the original valid-length formula.
        valid_length = max(0, episode_len - fetch_horizon + downsample_factor)

        for start_ts in range(valid_length):
            end_ts = min(start_ts + fetch_horizon, episode_len)
            # Pass copies so the cached arrays stay intact even if
            # rel_action_func modifies its inputs in place.
            raw_diff_psm1 = rel_action_func(qpos_psm1_all[start_ts].copy(), setpoints_psm1_all[start_ts:end_ts].copy())
            raw_diff_psm2 = rel_action_func(qpos_psm2_all[start_ts].copy(), setpoints_psm2_all[start_ts:end_ts].copy())

            downsampled_diff_psm1 = raw_diff_psm1[::downsample_factor]
            downsampled_diff_psm2 = raw_diff_psm2[::downsample_factor]

            # Zero-pad each arm up to action_horizon rows.
            final_diff_psm1 = np.zeros((action_horizon, single_arm_action_dim))
            final_diff_psm2 = np.zeros((action_horizon, single_arm_action_dim))
            pad_len_psm1 = min(action_horizon, downsampled_diff_psm1.shape[0])
            pad_len_psm2 = min(action_horizon, downsampled_diff_psm2.shape[0])
            final_diff_psm1[:pad_len_psm1, :] = downsampled_diff_psm1[:pad_len_psm1]
            final_diff_psm2[:pad_len_psm2, :] = downsampled_diff_psm2[:pad_len_psm2]

            combined_padded_action = np.column_stack((final_diff_psm1, final_diff_psm2))
            episode_actions.extend(list(combined_padded_action))
    except Exception as e:
        print(f"Error processing episode {episode_idx} ({episode_path}): {e}", file=sys.stderr)
        return []
    finally:
        if store is not None:
            try:
                store.close()
            except Exception as close_err:
                # Best effort: report, but never mask the result or original error.
                print(f"Warning: failed to close {episode_path}: {close_err}", file=sys.stderr)
    return episode_actions


# Script entry point: build the dataset and compute its action statistics.
if __name__ == "__main__":
    # Point robot_base_dir_list at the dVRK data directory before running.
    print("Initializing dataset...")
    dataset = EpisodicDatasetDvrkGeneric(
        robot_base_dir_list=["path/to/dvrk_data"],
        action_horizon=16,  # Default action horizon
        cutting_action_pad_size=10,
        # transforms=transforms,
        embodiment_tag=EmbodimentTag.NEW_EMBODIMENT,
        downsample_factor=2,  # Example: downsample by 2 (30Hz -> 15Hz)
    )
    print("Dataset initialized.")

    # Calculate action statistics (parallelized over episodes).
    stats = calculate_dataset_statistics(
        dataset,
        output_json_path=f"dvrk_statistics_downsample_{dataset.downsample_factor}_final_suturing.json",
        num_workers=16,  # Number of worker processes
    )