from typing import Iterator, Tuple, Any
from pathlib import Path

import glob
import numpy as np
import tensorflow_datasets as tfds
import tqdm

class ExampleDataset(tfds.core.GeneratorBasedBuilder):
    """DatasetBuilder that converts HDF5 robot episodes into a TFDS dataset.

    Each ``episode_*.hdf5`` file found in the first directory listed in
    ``self.tasks`` becomes one example. An example holds a sequence of
    ``steps`` (camera images, look-ahead camera images, state, action and a
    fixed language instruction) plus the path of the file it came from.
    """

    VERSION = tfds.core.Version('1.0.0')
    RELEASE_NOTES = {
        '1.0.0': 'Initial release.',
    }

    # Cameras read from every episode. These names must stay in sync with the
    # observation keys declared in `_info()`.
    _CAMERAS = ('cam_head', 'cam_left_wrist', 'cam_right_wrist')

    # Number of frames to look ahead for the `next_<camera>` images. The index
    # is clamped to the last frame of the episode.
    # NOTE(review): 10 follows the look-ahead used by the earlier per-frame
    # conversion code in this file -- confirm it is the intended horizon.
    _NEXT_FRAME_OFFSET = 10

    # Instruction attached to every step of every episode.
    _LANGUAGE_INSTRUCTION = "fold the towel"

    def __init__(self, *args, **kwargs):
        """Initialize the builder and register the raw-data directories."""
        super().__init__(*args, **kwargs)

        # Only the first entry's "name" is read by `_generate_examples`; the
        # "compressed" and "filter" flags are currently not consulted.
        self.tasks = [
            {"name": "/home/work3/data/aloha/fold_towel",
             "compressed": True, "filter": False},
        ]

    def _info(self) -> tfds.core.DatasetInfo:
        """Dataset metadata (homepage, citation,...)."""

        def image_feature(doc: str) -> tfds.features.Image:
            # All cameras share the same resolution and encoding.
            return tfds.features.Image(
                shape=(480, 640, 3), dtype=np.uint8, encoding_format='jpeg',
                doc=doc,
            )

        return self.dataset_info_from_configs(
            features=tfds.features.FeaturesDict({
                'steps': tfds.features.Dataset({
                    'observation': tfds.features.FeaturesDict({
                        'cam_head': image_feature('Static camera image.'),
                        'cam_left_wrist': image_feature(
                            'Left wrist camera image.'),
                        'cam_right_wrist': image_feature(
                            'Right wrist camera image.'),
                        'next_cam_head': image_feature(
                            'Next static camera image.'),
                        'next_cam_left_wrist': image_feature(
                            'Next left wrist camera image.'),
                        'next_cam_right_wrist': image_feature(
                            'Next right wrist camera image.'),
                    }),
                    'language_instruction': tfds.features.Text(
                        doc='Language Instruction.'
                    ),
                    'action': tfds.features.Tensor(shape=(14,), dtype=np.float32, ),
                    'state': tfds.features.Tensor(shape=(14,), dtype=np.float32, ),
                }),
                'episode_metadata': tfds.features.FeaturesDict({
                    'file_path': tfds.features.Text(
                        doc='Path to the original data file.'
                    ),
                }),
            }))

    def _split_generators(self, dl_manager):
        """Define the splits; all episodes go into a single 'train' split."""
        del dl_manager  # Data is read from local disk; nothing to download.
        return {
            'train': self._generate_examples(-1, spare=0)
        }

    @staticmethod
    def _load_camera_frames(ep, camera: str) -> np.ndarray:
        """Return every frame of one camera from an open HDF5 episode.

        A 4-D dataset is treated as raw frames and read as-is. Anything else
        is treated as a sequence of encoded image buffers that are decoded
        with OpenCV and converted from BGR to RGB.
        """
        dataset = ep[f"observations/images/{camera}"]
        if dataset.ndim == 4:
            # Uncompressed: load all frames into memory at once.
            return dataset[:]

        # Imported lazily so OpenCV is only required for compressed episodes.
        import cv2
        return np.array([
            cv2.cvtColor(cv2.imdecode(encoded, 1), cv2.COLOR_BGR2RGB)
            for encoded in dataset
        ])

    @classmethod
    def _load_episode(
        cls, ep_path: Path
    ) -> tuple[dict[str, np.ndarray], np.ndarray, np.ndarray]:
        """Load the images, joint state and actions of one HDF5 episode.

        Returns:
            A tuple ``(images, state, action)`` where ``images`` maps each
            name in ``_CAMERAS`` to its frame array, and ``state`` / ``action``
            are float32 arrays read from ``observations/qpos`` and ``action``.

        Raises:
            KeyError: If the file lacks one of the expected datasets or
                cameras.
        """
        import h5py

        with h5py.File(ep_path, "r") as ep:
            # Cast explicitly: the Tensor features are declared as float32.
            state = ep["observations/qpos"][:].astype(np.float32)
            action = ep["action"][:].astype(np.float32)
            # Read exactly the declared cameras so the example keys always
            # match the feature spec, whatever else the file contains.
            images = {
                camera: cls._load_camera_frames(ep, camera)
                for camera in cls._CAMERAS
            }
        return images, state, action

    def _generate_examples(self, num_ep, spare=0, start=0) -> Iterator[Tuple[str, Any]]:
        """Yield one ``(key, example)`` pair per HDF5 episode.

        Args:
            num_ep: Accepted for backward compatibility; currently ignored.
            spare: Accepted for backward compatibility; currently ignored.
            start: Accepted for backward compatibility; currently ignored.

        Yields:
            ``("trajectory_<index>", example)`` where ``example`` matches the
            feature structure declared in `_info()`. Episodes with no usable
            frames are skipped.
        """
        raw_dir = Path(self.tasks[0]["name"])
        hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))

        processed_trajectories = 0
        for ep_idx, ep_path in enumerate(tqdm.tqdm(hdf5_files)):
            images, state, action = self._load_episode(ep_path)

            # Only emit frames for which every stream has data; this replaces
            # a per-step try/except that silently dropped failing frames.
            lengths = [len(state), len(action)]
            lengths.extend(len(frames) for frames in images.values())
            num_frames = min(lengths)
            if num_frames != max(lengths):
                print(
                    f"Episode {ep_path} has mismatched stream lengths "
                    f"{lengths}; using the first {num_frames} frames"
                )

            steps = []
            for num in range(num_frames):
                # Look-ahead frame, clamped so the tail of the episode reuses
                # the final frame.
                next_num = min(num + self._NEXT_FRAME_OFFSET, num_frames - 1)
                observation = {}
                for camera, frames in images.items():
                    observation[camera] = frames[num]
                    observation[f"next_{camera}"] = frames[next_num]
                steps.append({
                    'observation': observation,
                    'action': action[num],
                    'language_instruction': self._LANGUAGE_INSTRUCTION,
                    'state': state[num],
                })

            if not steps:
                print(f"Skipped trajectory {ep_idx} (no valid steps)")
                continue

            processed_trajectories += 1
            yield f"trajectory_{ep_idx}", {
                'steps': steps,
                'episode_metadata': {
                    'file_path': str(ep_path),
                },
            }

        print(f"Total processed {processed_trajectories} trajectories")