import os
import shutil  # Added import for shutil
import numpy as np
import pandas as pd
import zarr
import numcodecs
import simplejpeg
from PIL import Image
import functools
from typing import List
from multiprocessing import Pool, cpu_count

def _assert_shape(arr: np.ndarray, expected_shape: tuple[int | None, ...]):
    """Asserts that the shape of an array matches the expected shape."""
    assert len(arr.shape) == len(expected_shape), (arr.shape, expected_shape)
    for dim, expected_dim in zip(arr.shape, expected_shape):
        if expected_dim is not None:
            assert dim == expected_dim, (arr.shape, expected_shape)

class JpegCodec(numcodecs.abc.Codec):
    """numcodecs codec that stores every chunk as a standalone JPEG image.

    Chunks handed to ``encode`` must be ``uint8`` arrays shaped
    ``(1, H, W, 3)``, i.e. the array is chunked one frame at a time along
    its leading axis. Compression is lossy; ``quality`` is forwarded to
    ``simplejpeg.encode_jpeg``.
    """

    codec_id = "pi_jpeg"

    def __init__(self, quality: int = 95):
        super().__init__()
        # JPEG quality used for every encoded chunk.
        self.quality = quality

    def encode(self, buf):
        """Compress one ``(1, H, W, 3)`` uint8 chunk into JPEG bytes."""
        _assert_shape(buf, (1, None, None, 3))
        assert buf.dtype == "uint8"
        frame = buf[0]  # drop the length-1 chunk axis; simplejpeg wants (H, W, 3)
        return simplejpeg.encode_jpeg(frame, quality=self.quality)

    def decode(self, buf, out=None):
        """Decompress JPEG bytes back into a chunk-shaped array.

        *out*, when supplied, is passed to simplejpeg as the output buffer.
        The decoded frame is returned with a leading length-1 axis added,
        mirroring the layout that ``encode`` accepts.
        """
        frame = simplejpeg.decode_jpeg(buf, buffer=out)
        return np.expand_dims(frame, 0)

@functools.lru_cache(maxsize=None)
def register_codecs():
    """Register ``JpegCodec`` with numcodecs under its ``codec_id``.

    Memoized, so only the first call performs the registration and later
    calls are no-ops.
    """
    numcodecs.register_codec(JpegCodec)


# Register at import time so the "pi_jpeg" codec is available wherever this
# module has been imported.
register_codecs()

def read_images(image_dir: str, file_pattern: str, num_images: int) -> np.ndarray:
    """Load numbered frames from *image_dir* into one stacked RGB array.

    Frame ``idx`` is read from ``os.path.join(image_dir, file_pattern.format(idx))``
    for every ``idx`` in ``range(num_images)``. Missing files are reported and
    skipped, so the result may hold fewer than *num_images* frames.

    Args:
        image_dir: Directory containing the frame files.
        file_pattern: ``str.format`` template that receives the frame index.
        num_images: Number of frame indices to try, starting at 0.

    Returns:
        A ``uint8`` array of shape ``(N, H, W, 3)`` with ``N <= num_images``,
        or an empty ``(0, 0, 0, 3)`` array when no frame was found.

    Raises:
        ValueError: from ``np.stack`` if the frames differ in size.
    """
    frames = []
    for idx in range(num_images):
        filename = os.path.join(image_dir, file_pattern.format(idx))
        if not os.path.exists(filename):
            print(f"Warning: {filename} does not exist.")
            continue
        # The context manager closes the file handle immediately. When
        # reading many frames this avoids running out of file descriptors.
        # convert("RGB") yields exactly three uint8 channels for any source
        # mode (grayscale, RGBA, palette, CMYK). Slicing with [..., :3] would
        # instead truncate the width of a 2-D grayscale array.
        with Image.open(filename) as img:
            frames.append(np.array(img.convert("RGB")))
    if not frames:
        return np.empty((0, 0, 0, 3), dtype=np.uint8)
    return np.stack(frames)

def copy_clicked_point_csv(episode_path, output_base_dir):
    """Copy an episode's ``clicked_point.csv`` into the output tree.

    The file lands at
    ``<output_base_dir>/<tissue>/<task>/<timestamp>_clicked_point.csv``.
    Here ``tissue`` is the first component of *episode_path* that starts with
    ``tissue_``, and ``task`` and ``timestamp`` are the two components that
    follow it. Nothing happens if the CSV is absent, or if the path lacks a
    non-empty ``tissue_*/task/timestamp`` run.
    """
    source_csv = os.path.join(episode_path, 'clicked_point.csv')
    if not os.path.exists(source_csv):
        return

    components = os.path.normpath(episode_path).split(os.sep)
    tissue_pos = next(
        (pos for pos, name in enumerate(components) if name.startswith('tissue_')),
        None,
    )
    if tissue_pos is None:
        return

    # Expected layout: .../tissue_X/task_name/timestamp
    triple = components[tissue_pos:tissue_pos + 3]
    if len(triple) != 3 or not all(triple):
        return
    tissue_name, task_name, timestamp = triple

    target_dir = os.path.join(output_base_dir, tissue_name, task_name)
    os.makedirs(target_dir, exist_ok=True)
    target_csv = os.path.join(target_dir, f"{timestamp}_clicked_point.csv")
    shutil.copy2(source_csv, target_csv)
    print(f"Copied {source_csv} to {target_csv}")

def _dataframe_to_structured(df):
    """Convert *df* into a NumPy structured array with one field per column.

    Each field keeps its column's own dtype. Rows are pulled with
    ``itertuples``, which works column by column. ``df.to_numpy()`` would
    first upcast a mixed int/float frame to float64 and silently round
    int64 values above 2**53 (e.g. nanosecond timestamps).
    """
    return np.array(
        list(df.itertuples(index=False, name=None)),
        dtype=[(col, df[col].dtype.str) for col in df.columns],
    )


def _write_episode_zarr(zarr_path, camera_images, kinematics_data):
    """Write camera frame stacks and kinematics into a fresh Zarr group.

    Args:
        zarr_path: Directory for the group; it is created (or overwritten).
        camera_images: ``(dataset_name, images)`` pairs. Empty stacks are
            skipped. The others are stored one frame per chunk, compressed
            with ``JpegCodec(quality=90)``.
        kinematics_data: Structured array stored as the ``kinematics`` dataset.
    """
    os.makedirs(os.path.dirname(zarr_path), exist_ok=True)
    zarr_store = zarr.open_group(zarr_path, mode='w')
    compressor = JpegCodec(quality=90)

    for cam_name, images in camera_images:
        if images.size == 0:
            continue
        image_store = zarr_store.create_dataset(
            cam_name,
            shape=images.shape,
            # One frame per chunk: JpegCodec expects (1, H, W, 3) chunks.
            chunks=(1, *images.shape[1:]),
            dtype='uint8',
            compressor=compressor,
        )
        image_store[:] = images

    zarr_store.create_dataset(
        'kinematics',
        data=kinematics_data,
        dtype=kinematics_data.dtype,
    )


def _zip_zarr_store(zarr_path):
    """Zip the Zarr directory at *zarr_path* and then delete the directory.

    The archive contains the directory's contents at its root. It is written
    next to the directory, named after *zarr_path* minus a trailing ``.zarr``
    suffix plus ``.zip``. Returns the archive path reported by
    ``shutil.make_archive``.
    """
    # removesuffix drops exactly ".zarr". rstrip('.zarr') would also strip
    # any trailing '.', 'z', 'a', 'r' characters from the episode name.
    archive_base = zarr_path.removesuffix('.zarr')
    zip_path = shutil.make_archive(archive_base, 'zip', zarr_path)
    shutil.rmtree(zarr_path)
    return zip_path


def process_episode(args):
    """Convert one recorded episode into a zipped Zarr archive.

    Args:
        args: ``(episode_path, output_base_dir)`` tuple. A single argument is
            used so the function can be passed straight to ``Pool.map``.

    The episode directory is read for ``ee_csv.csv``, which supplies one row
    per frame, and the image folders ``left_img_dir``, ``right_img_dir``,
    ``endo_psm1`` and ``endo_psm2``. The result is written to
    ``<output_base_dir>/<episode_path relative to the CWD>.zip``, and any
    ``clicked_point.csv`` is copied alongside it. Any exception is printed
    and the episode is skipped, so one bad episode does not abort a batch.
    """
    episode_path, output_base_dir = args
    try:
        # The CSV defines the frame count (pandas consumes the header line).
        df = pd.read_csv(os.path.join(episode_path, 'ee_csv.csv'))
        num_frames = len(df)

        # (dataset name, image subdirectory, frame filename pattern)
        cameras = [
            ('left', 'left_img_dir', 'frame{:06d}_left.jpg'),
            ('right', 'right_img_dir', 'frame{:06d}_right.jpg'),
            ('endo_psm1', 'endo_psm1', 'frame{:06d}_psm1.jpg'),
            ('endo_psm2', 'endo_psm2', 'frame{:06d}_psm2.jpg'),
        ]
        camera_images = [
            (name, read_images(os.path.join(episode_path, subdir), pattern, num_frames))
            for name, subdir, pattern in cameras
        ]
        # NOTE(review): read_images skips missing frames. A camera stack can
        # then be shorter than the kinematics table, with no index mapping
        # stored between the two.

        kinematics_data = _dataframe_to_structured(df)

        # NOTE(review): relpath is taken against the current working
        # directory, not the scan root. If episodes live outside the CWD, the
        # result contains '../' segments that escape output_base_dir.
        relative_path = os.path.relpath(episode_path)
        zarr_path = os.path.join(output_base_dir, relative_path + '.zarr')

        _write_episode_zarr(zarr_path, camera_images, kinematics_data)
        _zip_zarr_store(zarr_path)

        copy_clicked_point_csv(episode_path, output_base_dir)

    except Exception as e:
        # Deliberate best-effort: report the failure and let the batch continue.
        print(f"########## Failed to parse: {episode_path} due to: ##########\n{e}")

def process_all_episodes(base_dir: str, tissue_indices: List[int], output_base_dir: str,
                         max_processes: int = 16):
    """Discover every episode under the given tissues and convert them in parallel.

    Episodes are the directories matching
    ``<base_dir>/tissue_<idx>/<subtask>/<episode>``. Tissue directories that
    are missing are reported and skipped. Each episode is handed to
    ``process_episode`` in a worker process.

    Args:
        base_dir: Root directory containing the ``tissue_<idx>`` folders.
        tissue_indices: Tissue numbers to scan.
        output_base_dir: Destination root passed through to ``process_episode``.
        max_processes: Upper bound on worker processes (default 16). At least
            one worker is used whenever there is work to do.
    """
    episode_paths = []

    for tissue_idx in tissue_indices:
        tissue_dir = os.path.join(base_dir, f'tissue_{tissue_idx}')
        # isdir rather than exists: a stray file with this name would make
        # os.listdir below raise NotADirectoryError.
        if not os.path.isdir(tissue_dir):
            print(f"Warning: {tissue_dir} does not exist or is not a directory.")
            continue

        # Sorted so episodes are queued in a reproducible order.
        for subtask_name in sorted(os.listdir(tissue_dir)):
            subtask_dir = os.path.join(tissue_dir, subtask_name)
            if not os.path.isdir(subtask_dir):
                continue

            for episode_name in sorted(os.listdir(subtask_dir)):
                episode_dir = os.path.join(subtask_dir, episode_name)
                if os.path.isdir(episode_dir):
                    episode_paths.append((episode_dir, output_base_dir))

    if not episode_paths:
        # Pool(processes=0) raises ValueError, so stop before creating a pool.
        print("No episodes found; nothing to process.")
        return

    num_processes = max(1, min(max_processes, len(episode_paths)))
    print(f"Processing {len(episode_paths)} episodes with {num_processes} processes.")

    with Pool(processes=num_processes) as pool:
        pool.map(process_episode, episode_paths)

# Script entry point: convert tissues 0-10 found under the current directory.
if __name__ == "__main__":
    process_all_episodes(
        base_dir='.',
        tissue_indices=list(range(11)),
        output_base_dir='./processed_suturing_data_zipped',
    )
