import pickle
import os
from pathlib import Path
import torch
import numpy as np

class LeRobotDatasetWithPklReader:
    """Wrap a LeRobotDataset and attach per-task pkl data to every item.

    Task strings in ``meta.tasks`` are expected to look like
    ``"<prompt>;<path/to/info.pkl>"`` where the path is relative to the
    dataset root. Every returned item always carries the keys ``"prompt"``,
    ``"pkl_data"`` and ``"pkl_path"``, so items from different tasks share a
    single schema:

    * ``prompt``   -- text before the first ``;`` (the whole task string when
      it contains no ``;``).
    * ``pkl_path`` -- text between the first and second ``;`` with surrounding
      whitespace removed, or ``None`` when the task string names no path.
    * ``pkl_data`` -- the unpickled object, or ``None`` when there is no path
      or the file could not be read.

    Security note: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only use this wrapper on datasets whose pkl files you trust.
    """

    def __init__(self, original_dataset):
        """Store the wrapped dataset.

        Args:
            original_dataset: The original LeRobotDataset instance. It must
                support ``len()`` and integer indexing (returning a dict with
                a ``"task_index"`` entry), and expose ``root`` and
                ``meta.tasks``.
        """
        self.dataset = original_dataset

    def read_pkl_data(self, pkl_path):
        """Load a pickled object stored relative to the dataset root.

        Args:
            pkl_path (str): Path of the pkl file, relative to ``dataset.root``.

        Returns:
            The unpickled object, or ``None`` if ``pkl_path`` is empty/None,
            the file does not exist, or loading it fails. Failures are printed
            rather than raised (deliberate best-effort loading).
        """
        if not pkl_path:
            # An empty path would resolve to the root directory itself.
            return None

        try:
            # Path() so that a str root works as well as a pathlib.Path root.
            full_path = Path(self.dataset.root) / pkl_path

            # is_file() rather than exists(): a directory is not a pkl file.
            if not full_path.is_file():
                print(f"警告: pkl文件不存在 - {full_path}")
                return None

            # SECURITY: pickle.load may execute arbitrary code from the file.
            with open(full_path, "rb") as f:
                return pickle.load(f)

        except Exception as e:  # unpickling can raise almost any exception type
            print(f"读取pkl文件时出错: {e}")
            return None

    def get_item_with_pkl(self, idx):
        """Fetch item ``idx`` and add the task prompt and its pkl data.

        Args:
            idx (int): Index into the wrapped dataset.

        Returns:
            The wrapped dataset's item dict (mutated in place), extended with
            the ``"prompt"``, ``"pkl_data"`` and ``"pkl_path"`` keys described
            on the class.
        """
        item = self.dataset[idx]

        task_index = int(item["task_index"])
        task_string = self.dataset.meta.tasks[task_index]
        parts = task_string.split(";")

        task_prompt = parts[0]
        # Strip so "prompt; info.pkl" resolves to "info.pkl"; an empty path
        # part ("prompt;") is treated the same as no path at all.
        info_pkl_path = (parts[1].strip() or None) if len(parts) >= 2 else None

        pkl_data = self.read_pkl_data(info_pkl_path) if info_pkl_path else None

        # Always set all three keys so every item has the same schema.
        item["prompt"] = task_prompt
        item["pkl_data"] = pkl_data
        item["pkl_path"] = info_pkl_path
        return item

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``get_item_with_pkl(idx)``."""
        return self.get_item_with_pkl(idx)

# 使用示例
def example_usage():
    """Show how to wrap an existing LeRobotDataset (documentation only).

    This function intentionally does nothing at runtime. The intended usage,
    which needs a real ``LeRobotDataset`` instance, is::

        dataset = LeRobotDataset(repo_id="your_repo_id")
        extended_dataset = LeRobotDatasetWithPklReader(dataset)
        item = extended_dataset[0]
        if "pkl_data" in item and item["pkl_data"] is not None:
            print("Pkl data:")
            print(item["pkl_data"])
            print("Pkl path:")
            print(item["pkl_path"])
            print("Task prompt:")
            print(item["prompt"])
    """
    return None

# 如果你想要直接修改原始的LeRobotDataset类，可以这样做：
def modify_lerobot_dataset_class():
    """Print a patched ``LeRobotDataset.__getitem__`` that also loads pkl data.

    The printed snippet is meant to be pasted into the ``LeRobotDataset``
    class if you prefer editing the original class over using the wrapper.
    It assumes the target module imports ``pickle`` and that ``self.root`` is
    a ``pathlib.Path``.

    The snippet computes ``task.split(";")`` once and checks that a path part
    exists *before* indexing it. Tasks without a pkl path therefore get
    ``pkl_path=None`` instead of raising ``IndexError``. The file is loaded on
    a best-effort basis.
    """

    # Source of the modified __getitem__ method (printed, not executed here).
    modified_getitem_code = '''
    def __getitem__(self, idx) -> dict:
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            video_frames = self._query_videos(query_timestamps, ep_idx)
            item = {**video_frames, **item}

        if self.image_transforms is not None:
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                item[cam] = self.image_transforms(item[cam])

        task_index = int(item["task_index"])
        parts = self.meta.tasks[task_index].split(";")
        task_prompt = parts[0]
        # Check the length before indexing: tasks without ";" have no pkl path.
        info_pkl_path = (parts[1].strip() or None) if len(parts) >= 2 else None

        # 读取pkl文件
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        pkl_data = None
        if info_pkl_path is not None:
            try:
                full_pkl_path = self.root / info_pkl_path
                if full_pkl_path.is_file():
                    with open(full_pkl_path, 'rb') as f:
                        pkl_data = pickle.load(f)
            except Exception as e:
                print(f"读取pkl文件时出错: {e}")

        # 将pkl数据添加到返回的字典中
        item["pkl_data"] = pkl_data
        item["pkl_path"] = info_pkl_path

        return {**item, "prompt": task_prompt}
    '''

    print("修改后的__getitem__方法代码:")
    print(modified_getitem_code)

if __name__ == "__main__":
    # Run the (documentation-only) usage example, then print the patched
    # __getitem__ snippet.
    for _demo in (example_usage, modify_lerobot_dataset_class):
        _demo()