"""
数据加载和预处理模块

数据流程：
1. 环境特定的 DatasetLoader 读取原始数据
2. 转换为统一的 episode 格式
3. 加载到 DemoReplayBuffer
"""

import numpy as np
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from abc import ABC, abstractmethod
import pickle
import h5py
from scipy.spatial.transform import Rotation as R
from utils import convert_delta_to_absolute_actions
from PIL import Image

# ==================== 统一的 Episode 数据格式 ====================
# 每个 episode 是一个字典，包含：
#   - 'obs': np.ndarray [T+1, obs_dim]     - 低维观测（可选）
#   - 'rgb_obs': np.ndarray [T+1, num_cameras, C, H, W] - RGB 观测（可选）
#   - 'actions': np.ndarray [T, action_dim] - 动作序列
#   - 'rewards': np.ndarray [T]            - 奖励序列
#   - 'terminals': np.ndarray [T]          - 终止标志（bool）
#   - 'info': dict                         - 其他信息（可选）
# ==================================================================


class BaseDatasetLoader(ABC):
    """
    Abstract base class for dataset loaders.

    Subclasses must implement:
    - load_episodes(): load and return a list of episodes
    """

    def __init__(self, data_path: str, **kwargs):
        """
        Store the data location and verify that it exists.

        Args:
            data_path: Location of the data (a directory or a file).
            **kwargs: Subclass-specific options; accepted but unused here.

        Raises:
            ValueError: If ``data_path`` does not exist.
        """
        resolved = Path(data_path)
        self.data_path = resolved
        if not resolved.exists():
            raise ValueError(f"数据路径不存在: {self.data_path}")

    @abstractmethod
    def load_episodes(self, num_episodes: Optional[int] = None) -> List[Dict[str, np.ndarray]]:
        """
        Load episodes in the unified episode format.

        Args:
            num_episodes: How many episodes to load (None means all).

        Returns:
            A list of episode dictionaries.
        """

    def get_data_info(self) -> Dict[str, Any]:
        """
        Summarise the dataset by inspecting a single episode.

        Requests one episode via ``load_episodes(num_episodes=1)`` and reads
        the dimensions from it. Note that ``avg_ep_length`` is therefore the
        length of that one episode, not an average over the dataset.

        Returns:
            A dict with ``action_dim``, ``avg_ep_length``, ``has_obs`` and
            ``has_rgb``; ``obs_dim`` / ``rgb_shape`` are included only when
            the episode carries ``'obs'`` / ``'rgb_obs'``.

        Raises:
            ValueError: If no episode could be loaded.
        """
        sample = self.load_episodes(num_episodes=1)
        if len(sample) == 0:
            raise ValueError("数据集为空")

        first = sample[0]
        first_actions = first['actions']
        summary = {
            'action_dim': first_actions.shape[-1],
            'avg_ep_length': len(first_actions),
        }

        has_obs = 'obs' in first
        if has_obs:
            summary['obs_dim'] = first['obs'].shape[-1]
        summary['has_obs'] = has_obs

        has_rgb = 'rgb_obs' in first
        if has_rgb:
            # Drop the time axis: what remains is (num_cameras, C, H, W).
            summary['rgb_shape'] = first['rgb_obs'].shape[1:]
        summary['has_rgb'] = has_rgb

        return summary


# ==================== Libero 数据加载器 ====================

class LiberoDatasetLoader(BaseDatasetLoader):
    """
    Loader for Libero demonstrations stored in RLDS/TFDS format (raw format).

    Expected RLDS layout (see convert_libero_data_to_lerobot.py):
    - datasets are read through tensorflow_datasets
    - episode["steps"] yields the steps of one episode
    - each step contains:
        - observation: {
            - image: [256, 256, 3] (agentview)
            - wrist_image: [256, 256, 3] (eye_in_hand)
            - state: [8] (proprioception)
          }
        - action: [7] (delta ee position + orientation + gripper)
        - language_instruction: str

    Several task suites can be loaded in one pass:
    - libero_10_no_noops
    - libero_goal_no_noops
    - libero_object_no_noops
    - libero_spatial_no_noops
    """

    # Task suites loaded when the caller does not name any.
    DEFAULT_DATASET_NAMES = (
        "libero_10_no_noops",
        "libero_goal_no_noops",
        "libero_object_no_noops",
        "libero_spatial_no_noops",
    )

    def __init__(
        self,
        data_path: str,
        dataset_names: Optional[List[str]] = None,
        use_rgb: bool = False,
        camera_names: Optional[List[str]] = None,
        use_state: bool = True,
        **kwargs
    ):
        """
        Args:
            data_path: RLDS data directory (the tensorflow_datasets data dir).
            dataset_names: Dataset name(s) to load. Accepts a list or tuple of
                names, or a single name. When None, the four default Libero
                task suites are used (there is no auto-detection).
            use_rgb: Whether to load RGB observations.
            camera_names: Observation keys to read images from; defaults to
                ['image', 'wrist_image'].
            use_state: Whether to load the low-dimensional observation['state'].
            **kwargs: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If ``data_path`` does not exist (from the base class).
            ImportError: If tensorflow_datasets is not installed.
        """
        super().__init__(data_path)

        # Imported here rather than at module level so that the module can be
        # imported without tensorflow_datasets when only other loaders are used.
        import tensorflow_datasets as tfds
        self.tfds = tfds

        self.use_rgb = use_rgb
        self.camera_names = camera_names or ['image', 'wrist_image']
        self.use_state = use_state

        if dataset_names is None:
            self.dataset_names = list(self.DEFAULT_DATASET_NAMES)
        elif isinstance(dataset_names, (list, tuple)):
            # A tuple used to be wrapped as a single "name", which made
            # tfds.load fail; treat it like a list instead.
            self.dataset_names = list(dataset_names)
        else:
            # A single name.
            self.dataset_names = [dataset_names]

        print("Libero 数据加载器初始化:")
        print(f"  数据目录: {self.data_path}")
        print(f"  数据集: {self.dataset_names}")
        print(f"  使用 RGB: {self.use_rgb}")
        print(f"  使用状态: {self.use_state}")

    def load_episodes(self, num_episodes: Optional[int] = None) -> List[Dict[str, np.ndarray]]:
        """
        Load Libero demos from RLDS/TFDS, walking the datasets in order.

        Only episodes that convert successfully count towards ``num_episodes``.

        Args:
            num_episodes: Maximum number of episodes to load in total across
                all datasets (None means all).

        Returns:
            A list of episodes in the unified format.
        """
        episodes: List[Dict[str, np.ndarray]] = []

        def quota_reached() -> bool:
            return num_episodes is not None and len(episodes) >= num_episodes

        for dataset_name in self.dataset_names:
            if quota_reached():
                break

            print(f"\n加载数据集: {dataset_name}")
            dataset = self.tfds.load(
                dataset_name,
                data_dir=str(self.data_path),
                split="train"
            )

            dataset_episodes = 0
            for episode in dataset:
                if quota_reached():
                    break

                converted_episode = self._convert_episode_from_rlds(episode)
                if converted_episode is not None:
                    episodes.append(converted_episode)
                    dataset_episodes += 1

            print(f"  从 {dataset_name} 加载了 {dataset_episodes} 个 episodes")

        print(f"\n总共从 Libero 数据集加载了 {len(episodes)} 个 episodes")
        return episodes

    def _convert_episode_from_rlds(self, episode) -> Optional[Dict[str, np.ndarray]]:
        """
        Convert one RLDS episode into the unified episode format.

        Args:
            episode: RLDS episode; ``episode["steps"]`` must provide
                ``as_numpy_iterator()``.

        Returns:
            The episode dict, or None when the episode has no steps. Keys:
            'obs' [T+1, state_dim] (only when ``use_state``), 'actions',
            'rewards' (all ones), 'terminals' (True on the last step only),
            'rgb_obs' [T+1, num_cameras, C, H, W] (only when ``use_rgb`` and
            at least one camera image was found), and 'info' with the task
            string when the first step has a language instruction.

        Raises:
            ValueError: If ``use_state`` is set and a step has no 'state'.
        """
        steps = list(episode["steps"].as_numpy_iterator())
        if not steps:
            return None

        T = len(steps)
        episode_dict = {}

        # Low-dimensional state observations.
        if self.use_state:
            states = []
            for step in steps:
                observation = step["observation"]
                if "state" not in observation:
                    raise ValueError("没有状态观测")
                states.append(observation["state"])

            obs_array = np.array(states, dtype=np.float32)  # [T, state_dim]
            # Repeat the final state so there are T+1 observations for T actions.
            episode_dict['obs'] = np.concatenate([obs_array, obs_array[-1:]], axis=0)

        # Actions are kept exactly as stored in the dataset; no delta-to-absolute
        # conversion is applied here.
        actions = np.array([step["action"] for step in steps], dtype=np.float32)

        terminals = np.zeros(T, dtype=bool)
        terminals[-1] = True  # the last step ends the episode
        episode_dict.update({
            'actions': actions,
            # The RLDS data may carry no explicit rewards, so every step gets 1.
            'rewards': np.ones(T, dtype=np.float32),
            'terminals': terminals,
        })

        # RGB observations.
        if self.use_rgb:
            rgb_list = []
            for step in steps:
                observation = step["observation"]
                frames = []
                for cam_name in self.camera_names:
                    # NOTE(review): cameras absent from a step are skipped
                    # silently, so 'rgb_obs' can end up with fewer cameras
                    # (or fewer frames) than requested — confirm intended.
                    if cam_name not in observation:
                        continue
                    rgb = observation[cam_name]
                    if len(rgb.shape) == 3:
                        rgb = np.transpose(rgb, (2, 0, 1))  # [H, W, C] -> [C, H, W]
                    frames.append(rgb)
                if frames:
                    rgb_list.append(np.stack(frames, axis=0))  # [num_cameras, C, H, W]

            if rgb_list:
                rgb_obs = np.array(rgb_list, dtype=np.uint8)
                # Repeat the final frame to match the T+1 observation convention.
                episode_dict['rgb_obs'] = np.concatenate([rgb_obs, rgb_obs[-1:]], axis=0)

        # Task description (optional), taken from the first step.
        first_step = steps[0]
        if "language_instruction" in first_step:
            instruction = first_step["language_instruction"]
            if isinstance(instruction, bytes):
                instruction = instruction.decode()
            episode_dict['info'] = {'task': instruction}

        return episode_dict

# ==================== 真机 HDF5 数据加载器 ====================

class HDF5DatasetLoader(BaseDatasetLoader):
    """
    Loader for real-robot data stored as HDF5, one file per episode.

    Each file is expected to contain:
    - action: [T, 14]
    - observations/qpos: [T, 14]
    - observations/eef_factory: [T, 14]

    The state is qpos and eef_factory concatenated along the last axis.
    Rewards are 1 on the last step and 0 elsewhere.
    """

    def __init__(
        self,
        data_path: str,
        **kwargs
    ):
        """
        Args:
            data_path: Directory holding the .hdf5 (or .h5) files, one
                episode per file.
            **kwargs: ``single_arm`` (bool, default False) — when true, only
                columns 7 onward of action / qpos / eef_factory are kept.

        Raises:
            ValueError: If ``data_path`` does not exist (from the base class).
        """
        super().__init__(data_path)

        self.single_arm = kwargs.get('single_arm', False)
        print("HDF5 数据加载器初始化:")
        print(f"  数据目录: {self.data_path}")

    def load_episodes(self, num_episodes: Optional[int] = None) -> List[Dict[str, np.ndarray]]:
        """
        Load HDF5 episodes from the data directory in sorted filename order.

        Files that cannot be converted are skipped and do not count towards
        ``num_episodes``; loading continues with the next file until enough
        valid episodes have been collected or the files run out.

        Args:
            num_episodes: Maximum number of valid episodes to return
                (None means all).

        Returns:
            A list of episodes in the unified format.

        Raises:
            ValueError: If the directory has no .hdf5 or .h5 files.
        """
        hdf5_files = sorted(self.data_path.glob("*.hdf5"))
        if not hdf5_files:
            # Fall back to the .h5 extension only when no .hdf5 file exists.
            hdf5_files = sorted(self.data_path.glob("*.h5"))

        if not hdf5_files:
            raise ValueError(f"未找到 HDF5 文件: {self.data_path}")

        print(f"找到 {len(hdf5_files)} 个 HDF5 文件")

        episodes = []
        for hdf5_file in hdf5_files:
            # Check the quota against loaded episodes rather than truncating
            # the file list, so an invalid file does not use up a slot.
            if num_episodes is not None and len(episodes) >= num_episodes:
                break
            episode = self._load_single_episode(hdf5_file)
            if episode is not None:
                episodes.append(episode)

        print(f"从 HDF5 加载了 {len(episodes)} 个 episodes")
        return episodes

    def _read_array(self, dataset) -> np.ndarray:
        """Read an HDF5 dataset as float32, keeping columns 7+ when single_arm is set."""
        data = dataset[:, 7:] if self.single_arm else dataset
        return np.array(data, dtype=np.float32)

    def _load_single_episode(self, hdf5_file: Path) -> Optional[Dict[str, np.ndarray]]:
        """
        Load one HDF5 episode file.

        Args:
            hdf5_file: Path of the HDF5 file.

        Returns:
            The episode dict with 'obs', 'actions', 'rewards' and 'terminals',
            or None (after printing a warning) when 'action', 'observations',
            'observations/qpos' or 'observations/eef_factory' is missing or
            the action sequence is empty.
        """
        with h5py.File(hdf5_file, 'r') as f:
            if 'action' not in f:
                print(f"警告: {hdf5_file} 中没有 'action' 数据集，跳过")
                return None

            actions = self._read_array(f['action'])
            T = len(actions)
            if T == 0:
                print(f"警告: {hdf5_file} 中动作序列为空，跳过")
                return None

            obs_group = f.get('observations')
            if obs_group is None:
                print(f"警告: {hdf5_file} 中没有 'observations' 组，跳过")
                return None

            # Both state components are required; read them in a fixed order.
            parts = []
            for key in ('qpos', 'eef_factory'):
                if key not in obs_group:
                    print(f"警告: {hdf5_file} 中没有 'observations/{key}'，跳过")
                    return None
                parts.append(self._read_array(obs_group[key]))

            # state = [qpos, eef_factory] -> [T, qpos_dim + eef_dim]
            # NOTE(review): assumes qpos / eef_factory have the same length T
            # as the actions; this is not checked here.
            states = np.concatenate(parts, axis=-1)

            # Repeat the final state so there are T+1 observations for T actions.
            obs = np.concatenate([states, states[-1:]], axis=0)

            # Sparse reward: 1 on the final step, 0 elsewhere.
            rewards = np.zeros(T, dtype=np.float32)
            rewards[-1] = 1.0

            # Only the final step is terminal.
            terminals = np.zeros(T, dtype=bool)
            terminals[-1] = True

            return {
                'obs': obs,
                'actions': actions,
                'rewards': rewards,
                'terminals': terminals,
            }



# ==================== 数据集工具函数 ====================

def analyze_episodes(episodes: List[Dict[str, np.ndarray]]) -> Dict[str, Any]:
    """
    Print and return summary statistics for a list of episodes.

    Dimensions and the presence of 'obs' / 'rgb_obs' are taken from the first
    episode; value ranges and means are computed over all episodes.

    Args:
        episodes: Episodes in the unified format. Each needs 'actions' and
            'rewards'; 'obs' and 'rgb_obs' are optional.

    Returns:
        A dict with 'num_episodes', 'action_dim', 'obs_dim' (0 without 'obs'),
        'has_obs', 'has_rgb', 'rgb_shape' (None without 'rgb_obs') and
        'avg_ep_length'.

    Raises:
        ValueError: If ``episodes`` is empty.
    """
    num_episodes = len(episodes)
    if num_episodes == 0:
        # Fail early with a clear message instead of an opaque numpy
        # reduction error half-way through the report.
        raise ValueError("数据集为空")

    first = episodes[0]
    has_obs = 'obs' in first
    has_rgb = 'rgb_obs' in first
    action_dim = first['actions'].shape[-1]
    obs_dim = first['obs'].shape[-1] if has_obs else 0
    # Drop the time axis: what remains is (num_cameras, C, H, W).
    rgb_shape = first['rgb_obs'].shape[1:] if has_rgb else None

    separator = "=" * 60
    print("\n" + separator)
    print("数据集分析")
    print(separator)

    print(f"\n总 episodes 数量: {num_episodes}")

    # Episode lengths.
    ep_lengths = [len(ep['actions']) for ep in episodes]
    avg_ep_length = np.mean(ep_lengths)
    print("\nEpisode 长度:")
    print(f"  平均: {avg_ep_length:.1f}")
    print(f"  最小: {np.min(ep_lengths)}")
    print(f"  最大: {np.max(ep_lengths)}")

    # Dimensions.
    print("\n维度:")
    print(f"  动作维度: {action_dim}")

    if has_obs:
        print(f"  低维观测维度: {obs_dim}")

        all_obs = np.concatenate([ep['obs'] for ep in episodes], axis=0)
        print("\n低维观测统计:")
        print(f"  范围: [{np.min(all_obs):.3f}, {np.max(all_obs):.3f}]")
        print(f"  均值: {np.mean(all_obs):.3f}")

    if has_rgb:
        print(f"  RGB 形状: {rgb_shape}")

    # Action statistics.
    all_actions = np.concatenate([ep['actions'] for ep in episodes], axis=0)
    print("\n动作统计:")
    print(f"  范围: [{np.min(all_actions):.3f}, {np.max(all_actions):.3f}]")
    print(f"  均值: {np.mean(all_actions):.3f}")

    # Reward statistics.
    # NOTE(review): the "success rate" line is total reward per episode, so it
    # only reads as a rate when each episode's rewards sum to at most 1.
    all_rewards = np.concatenate([ep['rewards'] for ep in episodes], axis=0)
    total_reward = np.sum(all_rewards)
    print("\n奖励统计:")
    print(f"  总奖励: {total_reward:.1f}")
    print(f"  成功率: {total_reward / num_episodes:.2%}")

    print(separator + "\n")

    return {
        'num_episodes': num_episodes,
        'action_dim': action_dim,
        'obs_dim': obs_dim,
        'has_obs': has_obs,
        'has_rgb': has_rgb,
        'rgb_shape': rgb_shape,
        'avg_ep_length': avg_ep_length,
    }


def split_train_val(
    episodes: List[Dict[str, np.ndarray]],
    train_ratio: float = 0.9,
    shuffle: bool = True,
    seed: int = 42
) -> Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]]:
    """
    Split episodes into a training set and a validation set.

    The input sequence is not modified. Shuffling uses a private
    ``np.random.RandomState(seed)``, which yields the same permutation as
    seeding the global numpy RNG with ``seed`` but leaves the global RNG
    state untouched.

    Args:
        episodes: Episodes to split.
        train_ratio: Fraction of episodes assigned to the training set; the
            training size is ``int(len(episodes) * train_ratio)``.
        shuffle: Whether to shuffle before splitting.
        seed: Seed for the shuffle.

    Returns:
        (train_episodes, val_episodes)
    """
    # Work on a shallow copy so the caller's list keeps its order.
    ordered = list(episodes)

    if shuffle:
        # A local generator avoids reseeding numpy's process-wide RNG.
        np.random.RandomState(seed).shuffle(ordered)

    split_idx = int(len(ordered) * train_ratio)
    train_episodes = ordered[:split_idx]
    val_episodes = ordered[split_idx:]

    print(f"训练集: {len(train_episodes)} episodes")
    print(f"验证集: {len(val_episodes)} episodes")

    return train_episodes, val_episodes


def load_episodes_to_buffer(
    loader: BaseDatasetLoader,
    buffer,
    num_episodes: Optional[int] = None
):
    """
    Load episodes with ``loader`` and push them into a replay buffer.

    Args:
        loader: Dataset loader used to read the episodes.
        buffer: Replay buffer (a DemoReplayBuffer instance); its
            ``add_demos`` method receives the loaded episode list.
        num_episodes: Number of episodes to load, forwarded to the loader.

    Returns:
        The list of episodes that was handed to the buffer.
    """
    demos = loader.load_episodes(num_episodes)
    buffer.add_demos(demos)
    return demos


# ==================== 便捷函数 ====================

def create_loader(
    env_name: str,
    data_path: str,
    **kwargs
) -> BaseDatasetLoader:
    """
    Build the dataset loader that matches an environment name.

    Args:
        env_name: Environment name, matched case-insensitively; one of
            'libero' or 'hdf5'.
        data_path: Path to the data.
        **kwargs: Extra arguments forwarded to the loader's constructor.

    Returns:
        The loader instance.

    Raises:
        ValueError: If ``env_name`` is not a supported environment.
    """
    env_key = env_name.lower()

    if env_key == 'libero':
        return LiberoDatasetLoader(data_path, **kwargs)
    if env_key == 'hdf5':
        return HDF5DatasetLoader(data_path, **kwargs)

    raise ValueError(f"未知的环境名称: {env_key}，支持: libero, hdf5")
