import shutil
import numpy as np

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds
import tyro
import cv2
import os
import pickle
import random
# Name of the output dataset; also used as the Hugging Face Hub repo id.
REPO_NAME: str = "csu_vla/error_cot_beat_the_buzz"


def _read_rgb(path: str, size: tuple[int, int] = (256, 256)) -> np.ndarray:
    """Load the image at *path* as an RGB array resized to *size*.

    Args:
        path: Image file readable by OpenCV.
        size: Target ``(width, height)``, in the order ``cv2.resize`` expects.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise show up
            later as an opaque error from ``cv2.cvtColor``.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # OpenCV decodes to BGR; the dataset stores RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return cv2.resize(image, size)


def _remove_empty_image_dirs(episode_dir: str) -> None:
    """Best-effort removal of an episode's image directories when they are empty.

    This script never deletes the source PNGs, so a directory is only removed
    if it was already empty. Errors are printed and ignored so that one bad
    episode does not abort the whole conversion.
    """
    for dir_name in ("overhead_rgb", "wrist_rgb"):
        image_dir = os.path.join(episode_dir, dir_name)
        try:
            if os.path.isdir(image_dir) and not os.listdir(image_dir):
                os.rmdir(image_dir)
        except OSError as e:
            print(f"Error removing empty directories in {os.path.basename(episode_dir)}: {e}")


def main(raw_data_paths: tuple[str, ...] = (), seed: int | None = None):
    """Convert raw episode recordings into a LeRobot dataset at ``LEROBOT_HOME / REPO_NAME``.

    Every sub-directory of each path in *raw_data_paths* is treated as one
    episode and must contain:

    * ``info.pkl``: a pickled dict with at least ``"states"`` (a sequence of
      per-step state vectors) and ``"desc"`` (a non-empty sequence of task
      descriptions);
    * ``overhead_rgb/{step}.png`` and ``wrist_rgb/{step}.png`` for each step.

    For step ``t``, the frame's ``state`` is ``states[t]`` and its ``actions``
    is ``states[t + 1]`` (the next state is used as the action target). An
    episode with ``T`` states therefore yields ``T - 1`` frames, and episodes
    with fewer than two states are skipped. The episode's task string is one
    randomly chosen description, followed by ``";"`` and the path of its
    ``info.pkl``.

    Any existing dataset at the output location is deleted first.

    Args:
        raw_data_paths: Directories whose sub-directories are episodes.
            Defaults to empty, in which case nothing is converted.
        seed: Optional seed for choosing each episode's task description,
            which makes the output reproducible. ``None`` (the default) uses
            an unseeded RNG.

    Raises:
        FileNotFoundError: If an episode image cannot be read.
        ValueError: If an episode's ``"desc"`` is empty.
    """
    output_path = LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)

    # Create the LeRobot dataset and define the features to store.
    # OpenPi assumes that proprio is stored in `state` and actions in `actions`.
    # LeRobot assumes that the dtype of image data is `image`.
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        # Write to exactly the directory that was cleared above.
        root=output_path,
        robot_type="panda",
        fps=10,
        features={
            "image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "state": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["state"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    if not raw_data_paths:
        print("No raw data paths given; nothing to convert.")

    rng = random.Random(seed)
    for raw_data_path in raw_data_paths:
        # Sort for a deterministic episode order; os.listdir order is arbitrary.
        for episode_name in sorted(os.listdir(raw_data_path)):
            episode_dir = os.path.join(raw_data_path, episode_name)
            if not os.path.isdir(episode_dir):
                continue  # Ignore stray files that sit next to episode directories.

            pkl_path = os.path.join(episode_dir, "info.pkl")
            # NOTE: pickle.load can execute arbitrary code; only convert trusted data.
            with open(pkl_path, "rb") as f:
                info_dict = pickle.load(f)

            states = info_dict["states"]
            descriptions = info_dict["desc"]
            if len(states) < 2:
                # Each frame needs a (state, next state) pair, so this episode
                # would contribute no frames.
                print(f"Skipping {episode_dir}: needs at least 2 states, got {len(states)}")
                continue
            if len(descriptions) == 0:
                raise ValueError(f"{pkl_path} has no task descriptions in 'desc'")

            for step in range(len(states) - 1):
                dataset.add_frame(
                    {
                        "image": _read_rgb(os.path.join(episode_dir, "overhead_rgb", f"{step}.png")),
                        "wrist_image": _read_rgb(os.path.join(episode_dir, "wrist_rgb", f"{step}.png")),
                        # Cast to match the float32 features declared above.
                        "state": np.asarray(states[step], dtype=np.float32),
                        # The next state serves as the action target.
                        "actions": np.asarray(states[step + 1], dtype=np.float32),
                    }
                )

            # The task string holds one randomly chosen description plus the
            # source pickle path, separated by ";".
            dataset.save_episode(task=f"{rng.choice(descriptions)};{pkl_path}")

            # OPTIONAL: tidy up image directories that are (already) empty.
            _remove_empty_image_dirs(episode_dir)

    # Finalize the dataset on disk without computing normalization statistics.
    dataset.consolidate(run_compute_stats=False)

if __name__ == "__main__":
    # tyro builds a command-line interface from main()'s signature, so any
    # parameters main() declares become CLI flags.
    tyro.cli(main)