from typing import Iterator, Tuple, Any
from pathlib import Path

import glob
import numpy as np
import tensorflow_datasets as tfds


class ExampleDataset(tfds.core.GeneratorBasedBuilder):
    """DatasetBuilder for example dataset.

    Every example is one language-annotated trajectory. A trajectory is a
    contiguous range of per-frame ``episode_XXXXXXX.npz`` files whose bounds
    and instruction text come from the language annotation file loaded in
    ``__init__``.
    """

    VERSION = tfds.core.Version('1.0.0')
    RELEASE_NOTES = {
        '1.0.0': 'Initial release.',
    }

    # Directory that holds the per-frame ``episode_XXXXXXX.npz`` files.
    # Override in a subclass to point the builder at another data root.
    DATA_ROOT = "/home/work3/data/task_ABC_D_random1/training"
    # Language annotation file, relative to DATA_ROOT.
    LANG_ANNOTATION_FILE = "lang_annotations/auto_lang_ann.npy"
    # Number of frames to look ahead for the ``next_*`` observations.
    NEXT_FRAME_OFFSET = 10

    def __init__(self, *args, **kwargs):
        """Create the builder and load the language annotations.

        Sets ``self.prompts`` (one instruction per trajectory),
        ``self.start_end_list`` (one ``(start, end)`` frame-index pair per
        trajectory) and ``self.tasks`` (source directory description).
        """
        super().__init__(*args, **kwargs)
        annotation_path = Path(self.DATA_ROOT) / self.LANG_ANNOTATION_FILE
        # The annotation file holds a single pickled object, which `.item()`
        # unwraps; it is indexed below as a nested dict. `allow_pickle=True`
        # executes pickled code, so only point this at trusted files.
        language_file = np.load(str(annotation_path), allow_pickle=True).item()
        self.prompts = language_file['language']['ann']
        self.start_end_list = language_file['info']['indx']

        self.tasks = [
            {"name": self.DATA_ROOT,
             "compressed": True, "filter": False},
        ]

    def find_interval_index(self, start_end_list, npz_filename):
        """Locate the trajectory interval that contains a frame file.

        Args:
            start_end_list: Iterable of ``(start, end)`` frame-index pairs;
                both bounds are treated as inclusive.
            npz_filename: File name (or path) that ends in
                ``_<digits>.<ext>``, e.g. ``episode_0000123.npz``.

        Returns:
            ``(frame_number, interval_index)`` for the first interval that
            contains the frame, or ``None`` if no interval matches.

        Raises:
            ValueError: If the trailing name component is not an integer.
        """
        # int() accepts leading zeros directly. Stripping them first turned an
        # all-zero index ("0000000") into "" and made int() raise.
        digits = npz_filename.split('_')[-1].split('.')[0]
        num = int(digits)
        for idx, (start, end) in enumerate(start_end_list):
            if start <= num <= end:
                return num, idx
        return None

    def _info(self) -> tfds.core.DatasetInfo:
        """Dataset metadata (homepage, citation,...)."""
        return self.dataset_info_from_configs(
            features=tfds.features.FeaturesDict({
                'steps': tfds.features.Dataset({
                    'observation': tfds.features.FeaturesDict({
                        'image': tfds.features.Image(
                            shape=(200, 200, 3), dtype=np.uint8, encoding_format='jpeg',
                            doc='Static camera image.'
                        ),
                        'wrist_image': tfds.features.Image(
                            shape=(84, 84, 3), dtype=np.uint8, encoding_format='jpeg',
                            doc='Wrist camera image.'
                        ),
                        'next_image': tfds.features.Image(
                            shape=(200, 200, 3), dtype=np.uint8, encoding_format='jpeg',
                            doc='Next static camera image.'
                        ),
                        'next_wrist_image': tfds.features.Image(
                            shape=(84, 84, 3), dtype=np.uint8, encoding_format='jpeg',
                            doc='Next wrist camera image.'
                        ),
                    }),
                    'language_instruction': tfds.features.Text(
                        doc='Language Instruction.'
                    ),
                    'action': tfds.features.Tensor(shape=(7,), dtype=np.float32, ),
                }),
                'episode_metadata': tfds.features.FeaturesDict({
                    'file_path': tfds.features.Text(
                        doc='Path to the original data file.'
                    ),
                }),
            }))

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Define data splits."""
        return {
            # -1 means "use every trajectory".
            'train': self._generate_examples(-1, spare=0),
        }

    def _load_step(self, data_dir, frame_idx, next_frame_idx, instruction):
        """Build one step dict from the current frame and its look-ahead frame."""
        current_path = data_dir / f"episode_{frame_idx:07d}.npz"
        next_path = data_dir / f"episode_{next_frame_idx:07d}.npz"
        # NOTE(review): allow_pickle=True is kept from the original code; the
        # arrays read here may not need it. Only use with trusted data.
        with np.load(current_path, allow_pickle=True) as current_data:
            with np.load(next_path, allow_pickle=True) as next_data:
                # Copy out of the archives so nothing refers to them once
                # the files are closed.
                action = current_data["rel_actions"][:7].astype(np.float32)
                rgb_static = current_data["rgb_static"].copy()
                rgb_gripper = current_data["rgb_gripper"].copy()
                next_rgb_static = next_data["rgb_static"].copy()
                next_rgb_gripper = next_data["rgb_gripper"].copy()

        return {
            'observation': {
                'image': rgb_static,
                'wrist_image': rgb_gripper,
                'next_image': next_rgb_static,
                'next_wrist_image': next_rgb_gripper,
            },
            'action': action,
            'language_instruction': instruction,
        }

    def _parse_trajectory(self, trajectory_start, trajectory_end, trajectory_idx, data_dir):
        """Assemble one trajectory sample, or return ``None`` if it has no steps."""
        instruction = self.prompts[trajectory_idx]
        episode = []

        # The look-ahead frame is clamped to the last frame that is iterated,
        # so the final NEXT_FRAME_OFFSET steps all share the same "next"
        # images instead of being dropped.
        # NOTE(review): range() excludes `trajectory_end`, whereas
        # find_interval_index treats the same bounds as inclusive. Confirm
        # whether the final frame of each trajectory is meant to be skipped.
        last_frame = trajectory_end - 1
        for num in range(trajectory_start, trajectory_end):
            next_num = min(num + self.NEXT_FRAME_OFFSET, last_frame)
            try:
                step = self._load_step(data_dir, num, next_num, instruction)
            except Exception as e:
                # Deliberate best effort: a missing or corrupt frame is
                # reported and skipped so one bad file does not abort the
                # whole build. The episode then has a gap at this index.
                print(f"Error processing {num}: {e}")
                continue
            episode.append(step)

        if not episode:
            return None

        return {
            'steps': episode,
            'episode_metadata': {
                'file_path': str(data_dir / f"trajectory_{trajectory_idx}")
            }
        }

    def _generate_examples(self, num_ep, spare=0, start=0) -> Iterator[Tuple[str, Any]]:
        """Generator of examples for each split.

        Args:
            num_ep: Maximum number of trajectories to yield; ``-1`` yields all.
            spare: Unused; kept so existing callers keep working.
            start: Index of the first trajectory to consider; earlier
                trajectories are skipped.

        Yields:
            ``(key, sample)`` pairs where ``key`` is ``"trajectory_<idx>"``
            and ``sample`` matches the features declared in ``_info``.
        """
        # Frame files are addressed directly by index, so the directory is
        # never listed. The previous glob + sort + dict was built but unused.
        data_dir = Path(self.tasks[0]["name"])

        processed_trajectories = 0
        for trajectory_idx, (trajectory_start, trajectory_end) in enumerate(self.start_end_list):
            if num_ep != -1 and processed_trajectories >= num_ep:
                break

            if trajectory_idx < start:
                continue

            print(f"Processing trajectory {trajectory_idx}: [{trajectory_start}, {trajectory_end}]")

            sample = self._parse_trajectory(
                trajectory_start, trajectory_end, trajectory_idx, data_dir)
            if sample is None:
                print(f"Skipped trajectory {trajectory_idx} (no valid steps)")
                continue

            yield f"trajectory_{trajectory_idx}", sample
            processed_trajectories += 1

        print(f"Total processed {processed_trajectories} trajectories")