import numpy as np
import torch
import os
import random
from torch.utils.data import DataLoader
import pandas as pd
import cv2
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm
import json
import re
import zarr
from typing import Union, Optional, Tuple
import numcodecs
import numpy as np
import simplejpeg
import functools
from random import randint
import multiprocessing
from collections import defaultdict, Counter
from typing import Dict, Any
import torch.distributed as dist
from prismatic.vla.datasets import (
    RLDSBatchTransform,
    StratifiedBatchSampler,
)


def _assert_shape(arr: np.ndarray, expected_shape: Tuple[Optional[int], ...]):
    """Asserts that the shape of an array matches the expected shape."""
    assert len(arr.shape) == len(expected_shape), f"Expected shape of length {len(expected_shape)}, but got {len(arr.shape)}"
    for dim, expected_dim in zip(arr.shape, expected_shape):
        if expected_dim is not None:
            assert dim == expected_dim, f"Expected dimension {expected_dim}, but got {dim}"

def rotate_image(image, angle):
    """Rotate ``image`` about its centre, keeping the original canvas size.

    Args:
        image: Array whose first two axes are (height, width).
        angle: Rotation angle passed straight to ``cv2.getRotationMatrix2D``.

    Returns:
        The warped image, with the same width and height as the input.
    """
    height, width = image.shape[:2]
    pivot = (width // 2, height // 2)
    # Scale of 1.0: pure rotation, no zoom.
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    return cv2.warpAffine(image, rotation, (width, height))

def shift_image(image, shift_x, shift_y):
    """Translate ``image`` by ``(shift_x, shift_y)``, keeping the canvas size.

    Args:
        image: Array whose first two axes are (height, width).
        shift_x: Offset placed in the x-translation slot of the affine matrix.
        shift_y: Offset placed in the y-translation slot of the affine matrix.

    Returns:
        The warped image, with the same width and height as the input.
    """
    height, width = image.shape[:2]
    # 2x3 affine matrix: identity plus a translation column.
    translation = np.array([[1, 0, shift_x], [0, 1, shift_y]], dtype=np.float32)
    return cv2.warpAffine(image, translation, (width, height))

def create_heatmap_image(episode_path, img_l, target_img_size=(224, 224)):
    """
    Build an RGB heatmap that encodes the entry and exit points of an episode.

    The two points are read from ``<episode_path without ".zip...">_clicked_point.csv``
    (row 0 = entry point, row 1 = exit point, columns ``x`` and ``y``).

    Channel layout of the result:
    - Red:   per-pixel x offset to the entry point, min-max rescaled
    - Green: per-pixel y offset to the entry point, min-max rescaled
    - Blue:  d_exit / (d_entry + d_exit), min-max rescaled, so it is highest
             at the entry point and lowest at the exit point

    Args:
        episode_path: Path to the episode archive; used only to locate the CSV
        img_l: Reference image; only its leading (height, width) dims are used
        target_img_size: Accepted for signature compatibility; not used here

    Returns:
        A uint8 array of shape (height, width, 3)
    """
    eps = 1e-6  # keeps every division below away from zero

    def _rescale(channel):
        # Min-max rescale; eps in the denominator keeps results just below 1.
        lo = np.min(channel)
        hi = np.max(channel)
        return (channel - lo) / (hi - lo + eps)

    points = pd.read_csv(episode_path.split(".zip")[0] + "_clicked_point.csv")
    entry_point = points.iloc[0]
    exit_point = points.iloc[1]

    orig_height, orig_width = img_l.shape[:-1]

    # Row-index grid and column-index grid, each of shape (height, width).
    y_grid, x_grid = np.indices((orig_height, orig_width))

    # Offsets are scaled by a constant proportional to the shorter image side.
    normalizing_constant = 250.0 * (min(orig_height, orig_width) / 224.0)

    to_entry_x = x_grid - entry_point.x
    to_entry_y = y_grid - entry_point.y
    dx = to_entry_x / normalizing_constant
    dy = to_entry_y / normalizing_constant

    d_entry = np.sqrt(to_entry_x**2 + to_entry_y**2)
    d_exit = np.sqrt((x_grid - exit_point.x)**2 + (y_grid - exit_point.y)**2)
    heat = d_exit / (d_entry + d_exit + eps)

    channels = np.dstack([_rescale(dx), _rescale(dy), _rescale(heat)])
    return (channels * 255).astype(np.uint8)

def create_dot_image(episode_path, img_l, target_img_size=(224, 224)):
    """
    Draw the entry and exit points of an episode as filled circles on black.

    The points are read from ``<episode_path without ".zip...">_clicked_point.csv``
    (row 0 = entry point, row 1 = exit point, columns ``x`` and ``y``). The
    entry point is drawn in blue and the exit point in green; the red channel
    stays zero.

    Args:
        episode_path: Path to the episode archive; used only to locate the CSV
        img_l: Reference image; only its leading (height, width) dims are used
        target_img_size: Accepted for signature compatibility; not used here

    Returns:
        An RGB uint8 image (H, W, 3): the canvas is drawn in BGR order and
        converted with ``cv2.COLOR_BGR2RGB`` before being returned.
    """
    csv_path = episode_path.split(".zip")[0] + "_clicked_point.csv"
    points = pd.read_csv(csv_path)
    entry_point = points.iloc[0]
    exit_point = points.iloc[1]

    height, width = img_l.shape[:-1]
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    # (point, BGR colour): blue for the entry point, green for the exit point.
    markers = (
        (entry_point, (255, 0, 0)),
        (exit_point, (0, 255, 0)),
    )
    for point, bgr in markers:
        cv2.circle(canvas, (int(point.x), int(point.y)), radius=10, color=bgr, thickness=-1)

    # Index 2 is the red channel while the canvas is still in BGR order.
    canvas[:, :, 2] = 0

    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

def create_throw_prompt(episode_path, img_l=None, target_img_size=(224, 224)):
    """
    Create a text prompt for a throwing task with quantized point coordinates.

    The entry and exit points are read from
    ``<episode_path without ".zip...">_clicked_point.csv`` (row 0 = entry,
    row 1 = exit, columns ``x`` and ``y``). They are mapped into the frame
    obtained by scaling the original image to fit ``target_img_size`` with
    centred padding, normalised to [0, 1], and quantized into 1024 bins that
    are written as zero-padded ``<locNNNN>`` tokens in (y, x, y, x) order.

    Args:
        episode_path: Path to the episode archive; used only to locate the CSV
        img_l: Reference image; only its leading (height, width) dims are
            used. When None, a 480x640 (height x width) image is assumed.
        target_img_size: (width, height) of the padded target frame

    Returns:
        A string prompt containing the location tokens of both points
    """
    clicked_point_path = episode_path.split(".zip")[0] + "_clicked_point.csv"
    # first row is xy entry point, second row is xy exit point
    clicked_point = pd.read_csv(clicked_point_path)
    entry_point = clicked_point.iloc[0]
    exit_point = clicked_point.iloc[1]

    if img_l is not None:
        orig_height, orig_width = img_l.shape[:-1]
    else:
        # Fallback resolution when no reference image is supplied.
        orig_height, orig_width = 480, 640

    target_width, target_height = target_img_size

    if orig_width / orig_height > target_width / target_height:
        # Original image is wider - scale by width, pad height
        scale_factor = target_width / orig_width
        pad_left = 0
        pad_top = (target_height - int(orig_height * scale_factor)) // 2
    else:
        # Original image is taller (or same aspect) - scale by height, pad width
        scale_factor = target_height / orig_height
        pad_left = (target_width - int(orig_width * scale_factor)) // 2
        pad_top = 0

    num_bins = 1024

    def _to_loc(pixel, pad, extent):
        # Normalise to [0, 1], then quantize. The bin index is capped at
        # num_bins - 1: a coordinate of exactly 1.0 would otherwise yield 1024,
        # which is outside the 0000..1023 range of a 1024-bin vocabulary.
        fraction = max(0, min(1, (pixel * scale_factor + pad) / extent))
        return f"{min(num_bins - 1, int(fraction * num_bins)):04d}"

    scaled_entry_x = _to_loc(entry_point.x, pad_left, target_width)
    scaled_entry_y = _to_loc(entry_point.y, pad_top, target_height)
    scaled_exit_x = _to_loc(exit_point.x, pad_left, target_width)
    scaled_exit_y = _to_loc(exit_point.y, pad_top, target_height)

    # Each point is written as y, x, y, x (the same point repeated twice).
    prompt = f"Insert the needle at <loc{scaled_entry_y}><loc{scaled_entry_x}><loc{scaled_entry_y}><loc{scaled_entry_x}> and ensure it exits at <loc{scaled_exit_y}><loc{scaled_exit_x}><loc{scaled_exit_y}><loc{scaled_exit_x}>."

    return prompt


def resize_with_padding(image, target_width, target_height):
    """
    Fit an image inside a target frame, preserving aspect ratio and padding with black.

    Args:
        image: Input image; its first two axes are (height, width)
        target_width: Width of the output frame
        target_height: Height of the output frame

    Returns:
        A uint8 array of shape (target_height, target_width, 3) with the
        resized image centred on a black canvas
    """
    src_h, src_w = image.shape[:2]
    src_ratio = src_w / src_h

    if src_ratio > target_width / target_height:
        # Source is relatively wider: width is the limiting side.
        fit_w = target_width
        fit_h = int(fit_w / src_ratio)
    else:
        # Source is relatively taller (or same aspect): height is the limiting side.
        fit_h = target_height
        fit_w = int(fit_h * src_ratio)

    canvas = np.zeros((target_height, target_width, 3), dtype=np.uint8)

    # Offsets that centre the resized image on the canvas.
    left = (target_width - fit_w) // 2
    top = (target_height - fit_h) // 2
    canvas[top:top + fit_h, left:left + fit_w] = cv2.resize(image, (fit_w, fit_h))

    return canvas


def apply_image_augmentations(image, extra_img=None, scale=0.95, target_size=(224, 224), 
                             brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08, rotation_range=5, is_wrist=False):
    """
    Apply a random crop, rotation and colour jitter to an image.

    Geometric steps (crop, resize back to the original size, rotation) run only
    when ``is_wrist`` is False, and are applied with the same random parameters
    to ``extra_img`` when one is given. Colour jitter is applied to ``image``
    only.

    Args:
        image: Input image (H, W, C); cast to uint8 if it is not already
        extra_img: Optional companion image that receives the same geometric
            transform as ``image`` (never the colour jitter)
        scale: Fraction of each side kept by the random crop
        target_size: Accepted for signature compatibility; not used here
        brightness: Half-width of the random brightness factor around 1.0
        contrast: Half-width of the random contrast factor around 1.0
        saturation: Half-width of the random saturation factor around 1.0
        hue: Half-width of the random hue shift, as a fraction of the 0-180 range
        rotation_range: Half-width of the random rotation angle
        is_wrist: When True, skip the geometric steps entirely

    Returns:
        Tuple ``(augmented, extra_img)``: the jittered uint8 image, and the
        companion image (transformed, untouched, or None).
    """
    if image.dtype != np.uint8:
        image = image.astype(np.uint8)

    if not is_wrist:
        height, width = image.shape[:2]
        crop_h = int(height * scale)
        crop_w = int(width * scale)
        # Random parameters are drawn in a fixed order: top, left, then angle.
        top = random.randint(0, height - crop_h) if height > crop_h else 0
        left = random.randint(0, width - crop_w) if width > crop_w else 0
        rotation_angle = random.uniform(-rotation_range, rotation_range)

        def _crop_resize_rotate(frame):
            # Crop the window, stretch it back to the original size, rotate.
            window = frame[top:top + crop_h, left:left + crop_w]
            restored = cv2.resize(window, (width, height), interpolation=cv2.INTER_LINEAR)
            return rotate_image(restored, rotation_angle)

        image = _crop_resize_rotate(image)
        if extra_img is not None:
            extra_img = _crop_resize_rotate(extra_img)

    # Jitter factors, drawn in the order brightness, saturation, hue, contrast.
    brightness_factor = 1.0 + random.uniform(-brightness, brightness)
    saturation_factor = 1.0 + random.uniform(-saturation, saturation)
    hue_shift = random.uniform(-hue, hue) * 180  # OpenCV uses 0-180 for hue
    contrast_factor = 1.0 + random.uniform(-contrast, contrast)

    # Work in HSV (float32) so each property maps to a single channel.
    hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
    hsv[..., 0] = np.clip((hsv[..., 0] + hue_shift) % 180, 0, 180)
    hsv[..., 1] = np.clip(hsv[..., 1] * saturation_factor, 0, 255)
    hsv[..., 2] = np.clip(hsv[..., 2] * brightness_factor, 0, 255)

    # Back to RGB, then scale all channels for contrast.
    jittered = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
    augmented = np.clip(jittered * contrast_factor, 0, 255).astype(np.uint8)

    return augmented, extra_img

def save_obs_debug_imgs(obs: Dict[str, Any], targeting_type: str, idx: int):
    """
    Dump the image entries of an observation to PNG files for debugging.

    Files are written to ``debug_imgs_complete/<targeting_type>/<idx>_<key>.png``.
    Only entries of ``obs["observation"]`` whose key contains "wrist" or
    "image" are written; an all-zero "wrist_mask" entry is skipped.
    """
    debug_dir = os.path.join("debug_imgs_complete", targeting_type)
    os.makedirs(debug_dir, exist_ok=True)

    for key, value in obs["observation"].items():
        if "wrist" not in key and "image" not in key:
            continue
        # An empty wrist mask carries no information worth saving.
        if key == "wrist_mask" and np.sum(value) == 0:
            continue
        out_path = os.path.join(debug_dir, f"{idx}_{key}.png")
        # cv2.imwrite expects BGR channel order.
        cv2.imwrite(out_path, cv2.cvtColor(value, cv2.COLOR_RGB2BGR))

def quat_to_axis_angle_action(action):
    """
    Replace the quaternion part of each action row with a rotation vector.

    Args:
        action: 2-D array of shape (n_actions, 8). Columns 0-2 are copied
            through as the translation, columns 3-6 are passed to
            ``scipy.spatial.transform.Rotation.from_quat`` (which by default
            reads quaternions scalar-last, i.e. x, y, z, w), and column 7 is
            copied through as the jaw value.

    Returns:
        float64 array of shape (n_actions, 7): [x, y, z, rx, ry, rz, jaw]
    """
    n_actions = action.shape[0]
    rotvecs = R.from_quat(action[:, 3:7]).as_rotvec()  # (n_actions, 3)

    converted = np.zeros((n_actions, 7))
    converted[:, :3] = action[:, :3]   # translation, unchanged
    converted[:, 3:6] = rotvecs        # rotation as axis-angle
    converted[:, 6] = action[:, 7]     # jaw, unchanged

    return converted

class JpegCodec(numcodecs.abc.Codec):
    """numcodecs codec that stores each image chunk as one JPEG.

    ``encode`` requires uint8 chunks of shape (1, H, W, 3); ``decode`` returns
    the decoded image with a leading axis of length 1 added back.
    """
    codec_id = "pi_jpeg"

    def __init__(self, quality: int = 95):
        """Store the JPEG quality setting used by ``encode``."""
        super().__init__()
        self.quality = quality

    def encode(self, buf):
        """Encode a (1, H, W, 3) uint8 chunk as JPEG bytes."""
        _assert_shape(buf, (1, None, None, 3))
        assert buf.dtype == "uint8"
        # Drop the leading chunk axis; the encoder takes a single (H, W, 3) frame.
        frame = buf[0]
        return simplejpeg.encode_jpeg(frame, quality=self.quality)

    def decode(self, buf, out=None):
        """Decode JPEG bytes and return the image with a leading unit axis."""
        decoded = simplejpeg.decode_jpeg(buf, buffer=out)
        return decoded[None, ...]

@functools.cache
def register_codecs():
    """Register ``JpegCodec`` with numcodecs.

    The ``functools.cache`` decorator makes repeated calls no-ops after the
    first one, so the registration runs at most once per process.
    """
    numcodecs.register_codec(JpegCodec)

# Register at import time so the codec is known before any store is opened.
register_codecs()

def get_robot_episodes(main_folders):
    """
    Collect every zipped episode under the given dataset roots.

    Expected layout: ``<main_folder>/tissue_<N>/<instruction_folder>/<episode>.zip``.
    The language instruction is derived from the instruction folder name by
    dropping a leading number prefix, all other digits and the word
    "recovery", and turning underscores into spaces.

    Args:
        main_folders: Iterable of dataset root directories.

    Returns:
        A list of ``(episode_path, instruction, tissue_index)`` tuples, in
        directory-listing order. ``tissue_index`` is None when the tissue
        folder name carries no number.
    """
    episodes = []

    for root in main_folders:
        for tissue_name in os.listdir(root):
            tissue_dir = os.path.join(root, tissue_name)
            if not (tissue_name.startswith('tissue_') and os.path.isdir(tissue_dir)):
                continue

            # Numeric suffix of the tissue folder, if any.
            found = re.search(r'tissue_(\d+)', tissue_name)
            tissue_index = int(found.group(1)) if found else None

            for instr_folder in os.listdir(tissue_dir):
                instr_dir = os.path.join(tissue_dir, instr_folder)
                if not os.path.isdir(instr_dir):
                    continue

                # Folder name -> instruction: strip the leading "<number>_"
                # prefix, then any remaining digits, then "recovery";
                # underscores become spaces and outer whitespace is trimmed.
                cleaned = re.sub(r'^\d+_*', '', instr_folder)
                cleaned = re.sub(r'\d+', '', cleaned)
                instr_name = cleaned.replace('recovery', '').replace('_', ' ').strip()

                for file_name in os.listdir(instr_dir):
                    candidate = os.path.join(instr_dir, file_name)
                    if os.path.isfile(candidate) and file_name.endswith('.zip'):
                        episodes.append((candidate, instr_name, tissue_index))

    print(f"Found {len(episodes)} episodes")

    return episodes

# Example usage: for testing get_robot_episodes()
# Replace '/path/to/your/dataset' with the actual path to your main folder
# episodes = get_robot_episodes("base_chole_clipping_cutting/processed_data_zipped_pi/")

# # Print the list of episodes, their corresponding language instructions, and tissue indices
# for episode_path, instruction, tissue_index in episodes:
#     print(f"Episode Path: {episode_path}")
#     print(f"Language Instruction: {instruction}")
#     print(f"Tissue Index: {tissue_index}\n")

# assert(False)

class EpisodicDatasetDvrkGeneric(torch.utils.data.Dataset):
    def __init__(
        self,
        robot_base_dir_list,
        action_horizon,
        norm_stats,
        batch_transform: RLDSBatchTransform,
        batch_size: int,
        cutting_action_pad_size = 10,
        skip_images = False,
        image_aug = True,
        targeting_strategy = "dot",
        norm_stats_key = "norm_stats",
        ):
        """
        Initialize the DVRK dataset.
        
        Args:
            robot_base_dir_list: List of directories containing robot data
            action_horizon: Number of future actions to predict
            batch_transform: Transform to apply to batches
            batch_size: Batch size
            cutting_action_pad_size: Padding size for cutting actions
            skip_images: Whether to skip loading images
            image_aug: Whether to apply image augmentations
            targeting_strategy: Strategy for visualizing suturing targets:
                "dot": Draws green circles at entry/exit points directly on the image
                "mask": Creates a separate mask image with white circles
                "heatmap": Creates a colored heatmap with gradient from entry to exit points
                "none": No targeting strategy
        """
        super().__init__()
        assert len(robot_base_dir_list) > 0, "There must be at least one robot base directory"
        print(f"\n\nUsing the following datasets: {robot_base_dir_list}\n\n\n")
        self.skip_images = skip_images
        print(f"skip_images: {self.skip_images}")
        self.norm_stats = norm_stats
        self.robot_base_dir_list = robot_base_dir_list
        self.action_horizon = action_horizon
        self.cutting_action_pad_size = cutting_action_pad_size
        self.fps = 30
        self.batch_transform = batch_transform
        self.batch_size = batch_size
        self.image_aug = image_aug
        self.targeting_strategy = targeting_strategy
        self.norm_stats_key = norm_stats_key

        # Create flattened list of all timesteps across episodes
        self.flattened_indices = []
        self.episode_lengths = []
        
        # Open each episode to get its length
        # Dictionary to store unique instructions and their indices
        self.instruction_to_idx = {}
        curr_idx = 0
        
        # Get list of episodes
        self.episode_list = get_robot_episodes(self.robot_base_dir_list)

        # Metadata paths for all directories
        meta_paths = [os.path.join(robot_base_dir, "meta.json") for robot_base_dir in self.robot_base_dir_list]
        task_meta_paths = [os.path.join(robot_base_dir, "task_meta.json") for robot_base_dir in self.robot_base_dir_list]

        # Dictionary to store task meta
        task_meta = {}

        # Check if metadata exists for all directories
        if all(os.path.exists(meta_path) and os.path.exists(task_meta_path) for meta_path, task_meta_path in zip(meta_paths, task_meta_paths)):
            print("Metadata files found. Loading from all directories.")
            for meta_path, task_meta_path in zip(meta_paths, task_meta_paths):
                with open(meta_path, "r") as meta_file:
                    meta_data = [json.loads(line) for line in meta_file]
                with open(task_meta_path, "r") as task_meta_file:
                    task_data = [json.loads(line) for line in task_meta_file]

                # Populate episode_lengths and instruction_to_idx from loaded metadata
                self.episode_lengths.extend([entry["length"] for entry in meta_data])
                for entry in task_data:
                    instruction = entry["tasks"][0]
                    if instruction not in self.instruction_to_idx:
                        self.instruction_to_idx[instruction] = curr_idx
                        curr_idx += 1
                        # Initialize task meta entry if not already
                        if instruction not in task_meta:
                            task_meta[instruction] = {
                                "episode_indices": [],
                                "tissue_ids": []
                            }
                    # Append episode index and tissue ID for the task
                    episode_idx = entry["episode_index"]
                    tissue_id = entry.get("tissue_id", None)  # Add tissue_id if it exists
                    task_meta[instruction]["episode_indices"].append(episode_idx)
                    task_meta[instruction]["tissue_ids"].append(tissue_id)
        else:
            print("Metadata files not found for some directories. Generating metadata...")
            for robot_base_dir in self.robot_base_dir_list:
                meta_data = []
                task_data = []
                for episode_idx, (episode_path, instruction, _) in enumerate(tqdm(self.episode_list)):
                    # Ensure the episode belongs to one of the robot_base_dirs
                    if not any(episode_path.startswith(robot_base_dir) for robot_base_dir in self.robot_base_dir_list):
                        continue
                    # Ensure the episode belongs to the current robot_base_dir
                    if not episode_path.startswith(robot_base_dir):
                        continue

                    store = zarr.ZipStore(episode_path, mode='r')
                    zarr_store = zarr.group(store=store)
                    kinematics = zarr_store['kinematics'][:]
                    episode_len = len(kinematics)
                    self.episode_lengths.append(episode_len)
                    store.close()

                    # Append metadata for the current episode
                    meta_data.append({
                        "episode_index": episode_idx,
                        "tasks": [instruction],
                        "length": episode_len
                    })

                    # Add instruction to dict if not seen before
                    if instruction not in self.instruction_to_idx:
                        self.instruction_to_idx[instruction] = curr_idx
                        curr_idx += 1
                        task_data.append({
                            "episode_index": episode_idx,
                            "tasks": [instruction],
                            "tissue_id": instruction,  # You can set this to any relevant tissue ID
                        })

                    # Append the task's episode index and tissue ID to the task_meta
                    if instruction not in task_meta:
                        task_meta[instruction] = {
                            "episode_indices": [],
                            "tissue_ids": []
                        }
                    task_meta[instruction]["episode_indices"].append(episode_idx)
                    task_meta[instruction]["tissue_ids"].append(instruction)  # This should be the actual tissue ID

                # Save metadata to JSON files for the current directory
                meta_path = os.path.join(robot_base_dir, "meta.json")
                task_meta_path = os.path.join(robot_base_dir, "task_meta.json")

                with open(meta_path, "w") as meta_file:
                    for entry in meta_data:
                        meta_file.write(json.dumps(entry) + "\n")
                print(f"Metadata saved to {meta_path}")
                with open(task_meta_path, "w") as task_meta_file:
                    for entry in task_data:
                        task_meta_file.write(json.dumps(entry) + "\n")
                print(f"Task Metadata saved to {task_meta_path}")

        # Save task_meta to a global file for later sampling
        task_meta_global_path = os.path.join(self.robot_base_dir_list[0], "task_meta_global.json")
        with open(task_meta_global_path, "w") as task_meta_global_file:
            json.dump(task_meta, task_meta_global_file)
        print(f"Global Task Metadata saved to {task_meta_global_path}")

        # Sort episode list by (task name) to group tasks together 
        self.episode_list.sort(key=lambda tup: tup[1])

        # Create flattened index + per-task sample indices
        self.flattened_indices = []
        self.task_to_indices = defaultdict(list)

        for episode_idx, (episode_path, instruction, _) in enumerate(self.episode_list):
            episode_len = self.episode_lengths[episode_idx]
            task_idx = self.instruction_to_idx[instruction]

            for ts in range(episode_len):
                flat_idx = len(self.flattened_indices)
                self.flattened_indices.append((episode_idx, ts))
                self.task_to_indices[task_idx].append(flat_idx)

        # psm = patient side manipulator
        # qpos = current pose read from the robot
        self.header_name_qpos_psm1 = ["psm1_pose.position.x", "psm1_pose.position.y", "psm1_pose.position.z",
                                "psm1_pose.orientation.x", "psm1_pose.orientation.y", "psm1_pose.orientation.z", "psm1_pose.orientation.w",
                                "psm1_jaw"]
        
        self.header_name_qpos_psm2 = ["psm2_pose.position.x", "psm2_pose.position.y", "psm2_pose.position.z",
                                "psm2_pose.orientation.x", "psm2_pose.orientation.y", "psm2_pose.orientation.z", "psm2_pose.orientation.w",
                                "psm2_jaw"]

        # sp = setpoint (i.e. when you teleoperate, it generate setpoints for the robot to reach)
        self.header_name_actions_psm1 = ["psm1_sp.position.x", "psm1_sp.position.y", "psm1_sp.position.z",
                                    "psm1_sp.orientation.x", "psm1_sp.orientation.y", "psm1_sp.orientation.z", "psm1_sp.orientation.w",
                                    "psm1_jaw_sp"]

        self.header_name_actions_psm2 = ["psm2_sp.position.x", "psm2_sp.position.y", "psm2_sp.position.z",
                                    "psm2_sp.orientation.x", "psm2_sp.orientation.y", "psm2_sp.orientation.z", "psm2_sp.orientation.w",
                                    "psm2_jaw_sp"]
        
        self.quat_cp_psm1 = ["psm1_pose.orientation.x", "psm1_pose.orientation.y", "psm1_pose.orientation.z", "psm1_pose.orientation.w"]
        self.quat_cp_psm2 = ["psm2_pose.orientation.x", "psm2_pose.orientation.y", "psm2_pose.orientation.z", "psm2_pose.orientation.w"]

    @property
    def state_dim(self) -> int:
        # Currently 14 (7 for psm1 qpos + 7 for psm2 qpos, although it's zeroed)
        return len(self.header_name_qpos_psm1) + len(self.header_name_qpos_psm2) - 2 # -2 because jaw is single dim

    @property
    def action_dim(self) -> int:
        """Return the width of one action vector (a fixed 14)."""
        # 7 relative dims for psm1 + 7 relative dims for psm2
        return 14

    def __len__(self):
        """Return the number of (episode, timestep) entries in the flattened index."""
        return len(self.flattened_indices)

    def __getitem__(self, index):
        # Get episode index and timestep from flattened index
        data = self.__fetch_idx__(index)
        return self.batch_transform(data)

    def __fetch_idx__(self, index):
        # Get episode index and timestep from flattened index
        episode_idx, start_ts = self.flattened_indices[index]
        episode_path, instruction, tissue_id = self.episode_list[episode_idx]   
        
        # open the Zarr store using ZipStore
        store = zarr.ZipStore(episode_path, mode='r')
        zarr_store = zarr.group(store=store)

        kinematics = zarr_store['kinematics'][:]
        df = pd.DataFrame(kinematics)
        episode_len = len(df)
        img_idx = start_ts

        # for cutting tasks, length of the kinematics data extend longer than images, so image index must be capped
        if start_ts >= episode_len:
            start_ts = episode_len - 1

        # For cutting tasks, clip image index
        if instruction in ["go to the cutting position left tube", "go to the cutting position right tube"]:
            max_img_idx = episode_len - self.cutting_action_pad_size - 1
            img_idx = min(start_ts, max_img_idx)
        else:
            img_idx = min(start_ts, episode_len - 1)

        # get images
        if not self.skip_images:
            img_l = np.array(zarr_store['left'][img_idx])  # da vinci endoscope image
            img_lw = np.array(zarr_store['endo_psm2'][img_idx])  # left wrist camera view image from PSM2
            img_rw = np.array(zarr_store['endo_psm1'][img_idx])  # right wrist camera view image from PSM1

        raw_overlay_img = None  # For "dot" strategy, holds the raw dot image before augmentation
        targeting_img = None   # For "mask" or "heatmap" strategies, holds the raw special image

        if "throw" in instruction and not self.skip_images:
            if self.targeting_strategy == "dot":
                raw_overlay_img = create_dot_image(episode_path, img_l, target_img_size=(224, 224))
            elif self.targeting_strategy == "mask":
                targeting_img = create_dot_image(episode_path, img_l, target_img_size=(224, 224))
            elif self.targeting_strategy == "heatmap":
                targeting_img = create_heatmap_image(episode_path, img_l, target_img_size=(224, 224))

        # read actions and qpos
        action_psm1 = df[self.header_name_actions_psm1].iloc[start_ts:start_ts + self.action_horizon].to_numpy()
        action_psm2 = df[self.header_name_actions_psm2].iloc[start_ts:start_ts + self.action_horizon].to_numpy()

        diff_psm1 = np.zeros((self.action_horizon, 7))
        diff_psm2 = np.zeros((self.action_horizon, 7))

        raw_diff_psm1 = quat_to_axis_angle_action(action_psm1)
        raw_diff_psm2 = quat_to_axis_angle_action(action_psm2)

        # Pad the actions up to the action horizon
        diff_psm1[:raw_diff_psm1.shape[0], :] = raw_diff_psm1
        diff_psm2[:raw_diff_psm2.shape[0], :] = raw_diff_psm2

        # stack the actions along column dim
        action = np.column_stack((diff_psm1, diff_psm2))

        dataset_dict = {
            "dataset_name": "dvrk",
        }

        # Must do image augmentations before overlaying the targeting image
        if self.image_aug:
            # If we have a targeting image, apply augmentations to both images
            if targeting_img is not None:
                img_l, targeting_img = apply_image_augmentations(img_l, extra_img=targeting_img, is_wrist=False)
            else:
                img_l, raw_overlay_img = apply_image_augmentations(img_l, extra_img=raw_overlay_img, is_wrist=False)
                if raw_overlay_img is not None:
                    # Create a mask where the dot image has non-zero pixels
                    nonzero_mask = np.any(raw_overlay_img != 0, axis=-1)
                    # Keep the original image and oznly add the colored dots where they exist
                    img_l_copy = img_l.copy()
                    img_l_copy[nonzero_mask] = cv2.addWeighted(img_l, 0.5, raw_overlay_img, 0.5, 0)[nonzero_mask]
                    img_l = img_l_copy

            img_rw, _ = apply_image_augmentations(img_rw, is_wrist=True)
            img_lw, _ = apply_image_augmentations(img_lw, is_wrist=True)
        
        # If we have a raw overlay image, overlay it on the original image
        elif raw_overlay_img is not None:
            img_l = cv2.addWeighted(img_l, 1.0, raw_overlay_img, 0.5, 0)
        
        if not self.skip_images:
            dataset_dict["observation"] = {
                "image_primary": resize_with_padding(img_l, 224, 224),
                "wrist_right": resize_with_padding(img_rw, 224, 224),
                "wrist_left": resize_with_padding(img_lw, 224, 224),
            }
            if targeting_img is not None:
                # Note: only calling it "wrist*" because that is how additional images are added
                dataset_dict["observation"]["wrist_mask"] = resize_with_padding(targeting_img, 224, 224)
            elif self.targeting_strategy in ["mask", "heatmap"]:
                dataset_dict["observation"]["wrist_mask"] = np.zeros((224, 224, 3), dtype=np.uint8)
        else:
            dataset_dict["observation"] = {
                "image_primary": np.zeros((1, 1, 3)),
                "wrist_right": np.zeros((1, 1, 3)),
                "wrist_left": np.zeros((1, 1, 3)),
            }

        # normalize actions
        action = self.min_max_normalize_actions(action)

        dataset_dict["action"] = action
        dataset_dict["timestamp"] = start_ts / self.fps
        dataset_dict["task"] = {
            "language_instruction": instruction.encode(),
        }      

        store.close()

        return dataset_dict

    def min_max_normalize_actions(self, diffs, use_percentiles=True):
        """
        Normalize actions using min-max scaling to the range [-1, 1].
        
        Args:
            diffs: n_actions x 14 array (delta position [3], delta orientation (axis-angle) [3], 
                jaw angle (absolute) [1]) for both grippers
            use_percentiles: If True, use q01/q99 percentiles instead of min/max to handle outliers
            
        Returns:
            normalized n_actions x 14 array scaled to [-1, 1] range
        """
        normalized = np.zeros_like(diffs)
        
        # Get normalization bounds from dataset statistics
        if use_percentiles:
            # Use q01 and q99 percentiles to be robust against outliers
            min_vals = np.array(self.norm_stats[self.norm_stats_key]["action"]["q01"])
            max_vals = np.array(self.norm_stats[self.norm_stats_key]["action"]["q99"])
        else:
            # Use absolute min and max values
            min_vals = np.array(self.norm_stats[self.norm_stats_key]["action"]["min"])
            max_vals = np.array(self.norm_stats[self.norm_stats_key]["action"]["max"])
        
        # Apply min-max normalization: normalized = 2 * (x - min) / (max - min) - 1
        # This scales values to [-1, 1] instead of [0, 1]
        for i in range(diffs.shape[1]):
            if max_vals[i] > min_vals[i]:
                normalized[:, i] = 2 * (diffs[:, i] - min_vals[i]) / (max_vals[i] - min_vals[i] + 1e-6) - 1
            else:
                normalized[:, i] = 0  # If no variation, set to 0
        
        return normalized
    
    def _process_episodes_parallel(self, num_workers=32):
        """Fill ``self.episode_lengths`` using a process pool and build ``self.instruction_to_idx``.

        ``self.episode_list`` is cut into chunks, every chunk is handed to
        ``_process_episode_chunk`` in a worker process, and the returned
        per-chunk length lists are joined back together in chunk order.

        Args:
            num_workers: Size of the process pool; also passed to
                ``_split_into_chunks`` as the requested chunk count
        """
        chunks = self._split_into_chunks(self.episode_list, num_workers)

        # Workers only report episode lengths; the results must be drained
        # while the pool is still alive, hence the list() inside the block
        with multiprocessing.Pool(processes=num_workers) as pool:
            pending = pool.imap(self._process_episode_chunk, chunks)
            progress = tqdm(pending, total=len(chunks), desc="Processing episode chunks")
            per_chunk_lengths = list(progress)

        # Flatten the per-chunk results into one list
        self.episode_lengths = [
            length for chunk_lengths in per_chunk_lengths for length in chunk_lengths
        ]

        # Instruction ids are assigned in this process, in order of first
        # appearance in the episode list, so the mapping is deterministic
        self.instruction_to_idx = {}
        for _, instruction, _ in self.episode_list:
            self.instruction_to_idx.setdefault(instruction, len(self.instruction_to_idx))

    def _process_episode_chunk(self, episode_chunk):
        """Process a chunk of episodes and return the episode lengths."""
        lengths = []
        
        for episode_path, instruction, _ in episode_chunk:
            # Open the Zarr store and get episode length
            store = zarr.ZipStore(episode_path, mode='r')
            zarr_store = zarr.group(store=store)
            kinematics = zarr_store['kinematics'][:]
            episode_len = len(kinematics)
            lengths.append(episode_len)
            store.close()
        
        return lengths

    @staticmethod
    def _split_into_chunks(lst, num_chunks):
        """Split a list into approximately equal-sized chunks."""
        chunk_size = max(1, len(lst) // num_chunks)
        return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def calculate_dataset_statistics(dataset, output_json_path="dvrk_statistics.json"):
    """
    Calculate normalization statistics for the dataset's actions and states,
    mimicking the data processing (padding, zero states) used during training/compute_norm_stats.
    Skips loading images for efficiency.

    For every timestep of every episode one zero state vector and one
    zero-padded action window of ``dataset.action_horizon`` rows are collected.
    Both arms are converted with the single-argument form of
    ``quat_to_axis_angle_action``, matching how the dataset builds its samples.
    An episode that raises while being processed is reported and skipped as a
    whole, so it contributes neither states nor actions.

    Args:
        dataset: An instance of EpisodicDatasetDvrkGeneric
        output_json_path: Path to save the JSON statistics

    Returns:
        dict: ``{"state": {...}, "action": {...}}`` with per-dimension
        min/max/mean/std/q01/q99 lists, or an empty dict when no data was collected
    """
    print(f"Calculating dataset statistics with action_horizon={dataset.action_horizon}...")

    # States are stored one row per timestep; actions as one 2D block per timestep
    all_states = []
    action_blocks = []

    print(f"Using action_horizon={dataset.action_horizon}")

    horizon = dataset.action_horizon

    # Track progress with tqdm
    for episode_idx, (episode_path, _instruction, _tissue_id) in enumerate(tqdm(dataset.episode_list, desc="Processing episodes")):
        store = None
        try:
            # Open the Zarr store
            store = zarr.ZipStore(episode_path, mode='r')
            zarr_store = zarr.group(store=store)

            # Get kinematics data
            kinematics = zarr_store['kinematics'][:]
            df = pd.DataFrame(kinematics)
            episode_len = len(df)

            if episode_len == 0:
                print(f"Warning: Episode {episode_idx} ({episode_path}) has length 0. Skipping.")
                continue

            # Select the action columns once per episode instead of once per timestep
            setpoints_psm1 = df[dataset.header_name_actions_psm1].to_numpy()
            setpoints_psm2 = df[dataset.header_name_actions_psm2].to_numpy()

            # Collected locally and only committed when the whole episode succeeds,
            # so a failure cannot leave partial or mismatched rows behind
            episode_states = []
            episode_actions = []

            for start_ts in range(episode_len):
                # 1. State handling: append zero vector
                episode_states.append(np.zeros(dataset.state_dim))

                # 2. Action handling: slicing past the end of the episode simply
                # yields fewer rows; the remainder stays zero-padded below.
                # Copies are passed so the helper cannot modify the shared episode array.
                window = slice(start_ts, start_ts + horizon)
                raw_diff_psm1 = quat_to_axis_angle_action(setpoints_psm1[window].copy())
                raw_diff_psm2 = quat_to_axis_angle_action(setpoints_psm2[window].copy())

                # Pad the actions up to the action horizon
                padded_diff_psm1 = np.zeros((horizon, 7))
                padded_diff_psm2 = np.zeros((horizon, 7))
                padded_diff_psm1[:raw_diff_psm1.shape[0], :] = raw_diff_psm1
                padded_diff_psm2[:raw_diff_psm2.shape[0], :] = raw_diff_psm2

                # Combine both arms, shape: (action_horizon, 14)
                episode_actions.append(np.column_stack((padded_diff_psm1, padded_diff_psm2)))

            all_states.extend(episode_states)
            action_blocks.extend(episode_actions)
        except Exception as e:
            print(f"Error processing episode {episode_idx} ({episode_path}): {e}")
            continue  # Skip to next episode
        finally:
            # Close the store of THIS iteration, whatever happened above
            if store is not None:
                store.close()

    # Convert lists to numpy arrays (every padded row counts as one action sample)
    all_states = np.array(all_states)
    all_actions = np.concatenate(action_blocks, axis=0) if action_blocks else np.array([])

    print(f"Collected {len(all_states)} state samples and {len(all_actions)} action samples (including padding)")

    if len(all_states) == 0 or len(all_actions) == 0:
        print("Error: No data collected. Cannot calculate statistics.")
        return {}

    def _summarize(samples):
        """Per-dimension summary of a (n_samples, n_dims) array as plain lists."""
        return {
            "min": np.min(samples, axis=0).tolist(),
            "max": np.max(samples, axis=0).tolist(),
            "mean": np.mean(samples, axis=0).tolist(),
            "std": np.std(samples, axis=0).tolist(),
            "q01": np.percentile(samples, 1, axis=0).tolist(),
            "q99": np.percentile(samples, 99, axis=0).tolist()
        }

    # Calculate statistics using batch NumPy functions
    stats = {
        "state": _summarize(all_states),
        "action": _summarize(all_actions)
    }

    # Print statistics
    print("\nDataset Statistics (Calculated similar to compute_norm_stats):")
    print(json.dumps(stats, indent=2))

    # Save to JSON file
    with open(output_json_path, 'w') as f:
        json.dump(stats, f, indent=2)

    print(f"\nStatistics saved to {output_json_path}")

    # Also print the action mean and std in a format that can be directly copied into the code
    print("\nAction mean and std for code:")
    print(f"mean = np.array({stats['action']['mean']})")
    print(f"std = np.array({stats['action']['std']})")
    print(f"min = np.array({stats['action']['min']})") # Optional: also print min/max if useful
    print(f"max = np.array({stats['action']['max']})")

    return stats

"""
Test the EpisodicDatasetDvrkGeneric class.
"""
# Manual smoke test: builds a dataset, reports its size and computes the
# normalization statistics. The DataLoader part further down is currently
# disabled by the exit() call that precedes it.
if __name__ == "__main__":
    
    # Create an instance of the dataset
    # NOTE(review): "path_to_dataset" is a placeholder and norm_stats is empty here;
    # min_max_normalize_actions indexes norm_stats[norm_stats_key], so loading samples
    # would presumably fail unless the constructor fills it in - confirm.
    dataset = EpisodicDatasetDvrkGeneric(
        robot_base_dir_list=[
            "path_to_dataset"
        ],
        action_horizon = 50,
        norm_stats={},
        batch_transform=None,
        batch_size=8,
        cutting_action_pad_size = 10,
        skip_images=False,
        image_aug=True,
        targeting_strategy="dot"
    )

    # Get total dataset length
    total_length = len(dataset.flattened_indices)
    print(f"Dataset contains {total_length} samples")

    # Calculate statistics with action_horizon=50
    # (writes the JSON file at calculate_dataset_statistics' default output path)
    stats = calculate_dataset_statistics(dataset)
    # The script stops here: everything below this call is never executed as written
    exit()
    
    # Create a DataLoader
    # NOTE(review): dist.get_rank() presumably needs an initialized torch.distributed
    # process group - confirm before re-enabling this part.
    train_batch_sampler = StratifiedBatchSampler(dataset, seed=0, rank=dist.get_rank())
    dataloader = DataLoader(dataset, batch_sampler=train_batch_sampler)

    actions = []
    # Redundant: tqdm is already imported at the top of the file
    from tqdm import tqdm
    
    max_actions = 200
    i = 0
    for data in tqdm(dataloader, desc="Processing actions"):
        # i is checked before it is incremented, so max_actions + 1 batches are collected
        if i > max_actions:
            break
        i += 1
        action = data["action"]
        # Process action tensor before appending
        action = action.detach().cpu().numpy()
        actions.append(action)

    # print mean and std of actions max and min
    # (each value is a single scalar aggregated over every collected batch)
    print(f"Mean: {np.mean(actions)}")
    print(f"Std: {np.std(actions)}")
    print(f"Max: {np.max(actions)}")
    print(f"Min: {np.min(actions)}")