"""
简化的 Replay Buffer 用于序列数据
基于 robobase 的实现，支持 world model 和 policy 训练
"""

import numpy as np
import torch
from torch.utils.data import Dataset
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import pickle


class SequentialReplayBuffer(Dataset):
    """
    Replay buffer of fixed-length sequences for world-model and policy training.

    ``add_episode`` cuts each episode into overlapping windows of ``seq_len``
    steps (stride 1) and stores them in a ring buffer holding at most
    ``capacity`` windows; once full, the oldest window is overwritten.

    Each sample (``__getitem__``) or batch (``sample``) is a dict of float32
    tensors, with T == seq_len:
    - obs: [T, obs_dim] observations
    - next_obs: [T, obs_dim] the same observations shifted by one step
    - actions: [T, action_dim]
    - rewards: [T]
    - terminals: [T] termination flags (cast to 0.0 / 1.0)
    - rgb_obs / next_rgb_obs: [T, num_cameras, C, H, W], only when use_rgb

    Windows are stored as slices of the arrays passed to ``add_episode``; for
    numpy inputs these slices are views, so those arrays should not be mutated
    after being added.
    """

    def __init__(
        self,
        capacity: int,
        obs_dim: int = 0,
        action_dim: int = 0,
        seq_len: int = 16,
        use_rgb: bool = False,
        rgb_shape: Optional[Tuple[int, int, int, int]] = None,  # (num_cameras, C, H, W)
        device: str = "cuda:0"
    ):
        """
        Args:
            capacity: maximum number of stored windows (must be > 0).
            obs_dim: low-dimensional observation size (stored/saved as metadata).
            action_dim: action size (stored/saved as metadata).
            seq_len: window length in steps (must be > 0).
            use_rgb: whether RGB observations are stored alongside obs.
            rgb_shape: RGB observation shape (num_cameras, C, H, W); metadata.
            device: stored but not used by this class; returned tensors are
                created with torch.from_numpy and therefore live on the CPU.

        Raises:
            ValueError: if capacity or seq_len is not positive.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")

        self.capacity = capacity
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.seq_len = seq_len
        self.use_rgb = use_rgb
        self.rgb_shape = rgb_shape
        self.device = device

        # Per-field storage; index k of every list belongs to the same window.
        self.obs_buffer = []
        self.next_obs_buffer = []
        self.action_buffer = []
        self.reward_buffer = []
        self.terminal_buffer = []
        # RGB storage is always created so that load() can switch use_rgb on.
        self.rgb_buffer = []
        self.next_rgb_buffer = []

        self.size = 0  # number of stored windows
        self.ptr = 0   # next slot to write; wraps modulo capacity

    def _fields(self) -> List[Tuple[str, list]]:
        """Return (output key, storage list) pairs in a fixed order; RGB last."""
        fields = [
            ('obs', self.obs_buffer),
            ('next_obs', self.next_obs_buffer),
            ('actions', self.action_buffer),
            ('rewards', self.reward_buffer),
            ('terminals', self.terminal_buffer),
        ]
        if self.use_rgb:
            fields += [
                ('rgb_obs', self.rgb_buffer),
                ('next_rgb_obs', self.next_rgb_buffer),
            ]
        return fields

    def add_episode(
        self,
        obs: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        terminals: Optional[np.ndarray] = None,
        rgb_obs: Optional[np.ndarray] = None
    ):
        """
        Slice one episode into ``seq_len`` windows and store them.

        Args:
            obs: [T+1, obs_dim] observations, including the final state.
            actions: [T, action_dim] actions; T is taken from len(actions).
            rewards: [T] rewards.
            terminals: [T] termination flags; defaults to False everywhere
                except True at the last step.
            rgb_obs: [T+1, num_cameras, C, H, W] RGB observations; required
                when use_rgb is True, ignored otherwise.

        Episodes with fewer than ``seq_len`` steps add nothing.

        Raises:
            ValueError: if use_rgb is True but rgb_obs is None, or if obs /
                rgb_obs have fewer than T+1 steps, or rewards / terminals fewer
                than T steps.
        """
        ep_len = len(actions)

        # Too short for a single window: nothing to store. Returning here also
        # avoids terminals[-1] on an empty episode.
        if ep_len < self.seq_len:
            return

        # Storing obs windows without their RGB counterparts would shift
        # rgb_buffer relative to obs_buffer, pairing samples incorrectly.
        if self.use_rgb and rgb_obs is None:
            raise ValueError("use_rgb=True but rgb_obs was not provided")

        # next_obs windows read one step past the last action, so obs needs T+1
        # steps; shorter arrays would silently yield truncated windows.
        if len(obs) < ep_len + 1:
            raise ValueError(f"obs must have at least {ep_len + 1} steps, got {len(obs)}")
        if self.use_rgb and len(rgb_obs) < ep_len + 1:
            raise ValueError(f"rgb_obs must have at least {ep_len + 1} steps, got {len(rgb_obs)}")
        if len(rewards) < ep_len:
            raise ValueError(f"rewards must have at least {ep_len} steps, got {len(rewards)}")
        if terminals is not None and len(terminals) < ep_len:
            raise ValueError(f"terminals must have at least {ep_len} steps, got {len(terminals)}")

        if terminals is None:
            terminals = np.zeros(ep_len, dtype=bool)
            terminals[-1] = True

        buffers = [buf for _, buf in self._fields()]
        for start in range(ep_len - self.seq_len + 1):
            stop = start + self.seq_len
            # Same order as _fields(): obs, next_obs, actions, rewards, terminals.
            window = [
                obs[start:stop],
                obs[start + 1:stop + 1],
                actions[start:stop],
                rewards[start:stop],
                terminals[start:stop],
            ]
            if self.use_rgb:
                window += [rgb_obs[start:stop], rgb_obs[start + 1:stop + 1]]

            if len(self.obs_buffer) < self.capacity:
                # Still filling up: grow every field list by one.
                for buf, value in zip(buffers, window):
                    buf.append(value)
                self.size += 1
            else:
                # Full: overwrite the slot at ptr (the oldest window).
                for buf, value in zip(buffers, window):
                    buf[self.ptr] = value

            self.ptr = (self.ptr + 1) % self.capacity

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """
        Return window ``idx`` as a dict of float32 tensors.

        Keys: obs, next_obs, actions, rewards, terminals, plus rgb_obs and
        next_rgb_obs when use_rgb is True.

        NOTE: torch.from_numpy shares memory with the stored array and
        .float() does not copy data that is already float32, so for float32
        inputs the returned tensors alias the buffer; clone before in-place ops.
        """
        return {key: torch.from_numpy(buf[idx]).float() for key, buf in self._fields()}

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Sample ``batch_size`` windows uniformly with replacement.

        Args:
            batch_size: number of windows (must be > 0).

        Returns:
            Dict with the same keys as ``__getitem__``; every tensor has a
            leading batch dimension.

        Raises:
            ValueError: if the buffer is empty or batch_size is not positive.
        """
        if self.size == 0:
            raise ValueError("Buffer 为空，无法采样")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        indices = np.random.randint(0, self.size, size=batch_size)
        # np.stack copies, so batches never alias the stored windows.
        return {
            key: torch.from_numpy(np.stack([buf[i] for i in indices])).float()
            for key, buf in self._fields()
        }

    def save(self, filepath: str):
        """Pickle the buffer contents and configuration to ``filepath``."""
        data = {
            'obs_buffer': self.obs_buffer,
            'next_obs_buffer': self.next_obs_buffer,
            'action_buffer': self.action_buffer,
            'reward_buffer': self.reward_buffer,
            'terminal_buffer': self.terminal_buffer,
            'size': self.size,
            'ptr': self.ptr,
            'capacity': self.capacity,
            'obs_dim': self.obs_dim,
            'action_dim': self.action_dim,
            'seq_len': self.seq_len,
            'use_rgb': self.use_rgb,
            'rgb_shape': self.rgb_shape,
        }

        if self.use_rgb:
            data['rgb_buffer'] = self.rgb_buffer
            data['next_rgb_buffer'] = self.next_rgb_buffer

        with open(filepath, 'wb') as f:
            pickle.dump(data, f)
        print(f"Buffer 已保存至: {filepath}")

    def load(self, filepath: str):
        """
        Replace this buffer's contents and configuration with a file written
        by ``save``.

        NOTE(security): pickle.load can execute arbitrary code; only load
        files from a trusted source.
        """
        with open(filepath, 'rb') as f:
            data = pickle.load(f)

        self.obs_buffer = data['obs_buffer']
        self.next_obs_buffer = data['next_obs_buffer']
        self.action_buffer = data['action_buffer']
        self.reward_buffer = data['reward_buffer']
        self.terminal_buffer = data['terminal_buffer']
        self.size = data['size']
        self.ptr = data['ptr']
        self.capacity = data['capacity']
        self.obs_dim = data['obs_dim']
        self.action_dim = data['action_dim']
        self.seq_len = data['seq_len']
        self.use_rgb = data.get('use_rgb', False)
        self.rgb_shape = data.get('rgb_shape', None)

        if self.use_rgb:
            self.rgb_buffer = data['rgb_buffer']
            self.next_rgb_buffer = data['next_rgb_buffer']
        else:
            # Drop RGB data from any previous configuration so it cannot be
            # mistaken for data aligned with the loaded windows.
            self.rgb_buffer = []
            self.next_rgb_buffer = []

        print(f"Buffer 已从 {filepath} 加载，包含 {self.size} 个样本")


class DemoReplayBuffer:
    """
    Replay buffer built from demonstration episodes.

    Wraps a ``SequentialReplayBuffer`` and, when ``normalize`` is True,
    z-scores low-dim observations and actions (per dimension) before storing
    them. The statistics used are exposed through ``get_statistics()``.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        seq_len: int,
        capacity: int = 1000000,
        normalize: bool = True,
        use_rgb: bool = False,
        rgb_shape: Optional[Tuple[int, int, int, int]] = None
    ):
        """
        Args:
            obs_dim: observation size.
            action_dim: action size.
            seq_len: window length used by the underlying buffer.
            capacity: maximum number of stored windows.
            normalize: whether to z-score obs and actions before storing.
            use_rgb: whether episodes carry RGB observations ('rgb_obs').
            rgb_shape: RGB observation shape (num_cameras, C, H, W).
        """
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.seq_len = seq_len
        self.normalize = normalize
        self.use_rgb = use_rgb

        self.buffer = SequentialReplayBuffer(
            capacity=capacity,
            obs_dim=obs_dim,
            action_dim=action_dim,
            seq_len=seq_len,
            use_rgb=use_rgb,
            rgb_shape=rgb_shape
        )

        # Normalisation statistics; None until computed or loaded.
        self.obs_mean = None
        self.obs_std = None
        self.action_mean = None
        self.action_std = None

    def _compute_statistics(self, demo_episodes: List[Dict[str, np.ndarray]]):
        """Set per-dimension mean/std of obs and actions over all given episodes."""
        all_obs_concat = np.concatenate([ep['obs'] for ep in demo_episodes], axis=0)
        all_actions_concat = np.concatenate([ep['actions'] for ep in demo_episodes], axis=0)

        self.obs_mean = np.mean(all_obs_concat, axis=0)
        # Small epsilon avoids division by zero for constant dimensions.
        self.obs_std = np.std(all_obs_concat, axis=0) + 1e-8
        self.action_mean = np.mean(all_actions_concat, axis=0)
        self.action_std = np.std(all_actions_concat, axis=0) + 1e-8

        print("数据统计信息:")
        print(f"  观测均值: {self.obs_mean[:5]}... (显示前5维)")
        print(f"  观测标准差: {self.obs_std[:5]}...")
        print(f"  动作均值: {self.action_mean[:5]}...")
        print(f"  动作标准差: {self.action_std[:5]}...")

    def add_demos(
        self,
        demo_episodes: List[Dict[str, np.ndarray]]
    ):
        """
        Add demonstration episodes to the buffer.

        Args:
            demo_episodes: list of episodes, each a dict with:
                - 'obs': [T+1, obs_dim]
                - 'actions': [T, action_dim]
                - 'rewards': [T] (optional, defaults to all 1.0)
                - 'terminals': [T] (optional, defaults to True at the last step)
                - 'rgb_obs': [T+1, num_cameras, C, H, W] (when use_rgb)

        When ``normalize`` is True and no statistics are set yet (neither from
        an earlier call nor from ``load``), they are computed from this batch.
        Existing statistics are reused, so every stored window is normalised
        with the same statistics that ``get_statistics()`` returns. Set
        ``obs_mean`` to None beforehand to force recomputation.
        """
        needs_stats = self.obs_mean is None or self.action_mean is None
        if self.normalize and demo_episodes and needs_stats:
            self._compute_statistics(demo_episodes)

        for episode in demo_episodes:
            obs = episode['obs'].copy()
            actions = episode['actions'].copy()

            if 'rewards' in episode:
                rewards = episode['rewards'].copy()
            else:
                # Demos without rewards are treated as successful: 1.0 per step.
                rewards = np.ones(len(actions), dtype=np.float32)

            terminals = episode.get('terminals', None)
            rgb_obs = episode.get('rgb_obs', None) if self.use_rgb else None

            if self.normalize:
                obs = (obs - self.obs_mean) / self.obs_std
                actions = (actions - self.action_mean) / self.action_std

            self.buffer.add_episode(obs, actions, rewards, terminals, rgb_obs)

        print(f"已添加 {len(demo_episodes)} 个 episodes，buffer 大小: {len(self.buffer)}")

    def __len__(self):
        """Number of stored windows."""
        return len(self.buffer)

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """Sample a batch of (already normalised) windows from the buffer."""
        return self.buffer.sample(batch_size)

    def get_buffer(self):
        """Return the underlying SequentialReplayBuffer."""
        return self.buffer

    def get_statistics(self) -> Dict[str, np.ndarray]:
        """Return the normalisation statistics (values are None if unset)."""
        return {
            'obs_mean': self.obs_mean,
            'obs_std': self.obs_std,
            'action_mean': self.action_mean,
            'action_std': self.action_std
        }

    def save(self, save_dir: str):
        """Save the buffer (buffer.pkl) and, if normalising, statistics.pkl."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        self.buffer.save(str(save_dir / 'buffer.pkl'))

        if self.normalize:
            stats = self.get_statistics()
            with open(save_dir / 'statistics.pkl', 'wb') as f:
                pickle.dump(stats, f)
            print(f"统计信息已保存至: {save_dir / 'statistics.pkl'}")

    def load(self, save_dir: str):
        """
        Load the buffer and, if normalising and present, its statistics.

        NOTE(security): uses pickle.load; only load directories from a
        trusted source.
        """
        save_dir = Path(save_dir)

        self.buffer.load(str(save_dir / 'buffer.pkl'))
        # Keep this wrapper's configuration in line with the loaded buffer so
        # later add_demos calls pass RGB data consistently.
        self.obs_dim = self.buffer.obs_dim
        self.action_dim = self.buffer.action_dim
        self.seq_len = self.buffer.seq_len
        self.use_rgb = self.buffer.use_rgb

        if self.normalize and (save_dir / 'statistics.pkl').exists():
            with open(save_dir / 'statistics.pkl', 'rb') as f:
                stats = pickle.load(f)
            self.obs_mean = stats['obs_mean']
            self.obs_std = stats['obs_std']
            self.action_mean = stats['action_mean']
            self.action_std = stats['action_std']
            print("统计信息已加载")

