"""
配置文件
定义训练 world model 的所有配置参数
"""

from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Dict, Any


@dataclass
class ModelConfig:
    """模型配置"""
    # 维度
    obs_dim: int = 32  # 低维观测维度（需要根据实际数据设置）
    action_dim: int = 16  # 动作维度（需要根据实际数据设置）
    
    # RGB 观测配置
    use_pixels: bool = False  # 是否使用 RGB 观测
    image_channels: int = 3  # 图像通道数
    image_size: Tuple[int, int] = (84, 84)  # 图像尺寸
    num_cameras: int = 1  # 相机数量
    
    # 视觉编码器配置
    use_dinov2: bool = False  # 是否使用 DINOv2 编码器（替代 SimpleConvEncoder）
    dinov2_model_type: str = 'dinov2_vits14'  # DINOv2 模型类型: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14
    dinov2_visual_feature_dim: int = 64  # DINOv2 输出特征维度（每帧）
    dinov2_mlp_hidden_dims: List[int] = field(default_factory=lambda: [256, 64])  # DINOv2 MLP 隐藏层维度
    dinov2_use_cls_token: bool = True  # DINOv2 是否使用 CLS token
    dinov2_dropout: float = 0.0  # DINOv2 MLP dropout
    
    # SADM 架构
    hidden_dim: int = 256
    rnn_num_layers: int = 3
    dropout: float = 0.1
    
    # 训练选项
    use_symlog: bool = True  # 使用 symlog 变换
    use_var: bool = False  # 预测方差
    use_residual: bool = False  # 使用残差预测
    framestack: int = 1  # 帧堆叠数量（1表示不堆叠，3表示堆叠3帧）

@dataclass
class PolicyModelConfig:
    """策略模型配置"""
    # 维度
    obs_dim: int = 32  # 低维观测维度（需要根据实际数据设置）
    action_dim: int = 16  # 动作维度（需要根据实际数据设置）
    
    # Safe-IQL 相关参数
    k_low: int = 1  # k 值的下限（降采样率的最小值）
    k_high: int = 2  # k 值的上限（降采样率的最大值）
    epsilon: float = 0.01  # 安全阈值，用于判断 transition 是否安全

@dataclass
class TrainingConfig:
    """训练配置"""
    # 基础设置
    seed: int = 42
    device: str = "cuda:0"
    num_gpus: int = 1
    
    # 训练超参数
    batch_size: int = 256
    seq_len: int = 10  # 序列长度
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    grad_clip: Optional[float] = 100.0
    
    # 训练步数
    num_train_steps: int = 10000
    eval_every_steps: int = 1000
    log_every_steps: int = 1000
    save_every_steps: int = 1000
    
    num_workers: int = 4
    pin_memory: bool = True
    
@dataclass
class DataConfig:
    """数据配置"""
    # 数据路径
    data_dir: str = "/path/to/demo/data"  # 需要设置实际路径
    
    # 环境名称: 'libero', 'pickle', 'hdf5'
    env_name: str = "libero"
    # 数据加载器参数（根据环境不同而不同）
    # Libero: dataset_names (多任务列表), use_rgb, camera_names, use_state
    # Pickle: obs_keys, action_key, reward_key, terminal_key, rgb_keys
    # HDF5: 无需额外参数（每个文件一个 episode）
    loader_kwargs: Dict[str, Any] = field(default_factory=dict)
    
    # 数据集设置
    num_demos: int = 100  # 使用的 demo 数量
    train_ratio: float = 0.9  # 训练集比例
    
    # 数据预处理
    normalize_obs: bool = True  # 是否标准化观测
    normalize_action: bool = True  # 是否标准化动作
    use_absolute_actions: bool = False  # 是否使用绝对动作

    # Replay buffer
    buffer_size: int = 1000000
    nstep: int = 1
    gamma: float = 0.99


@dataclass
class LogConfig:
    """日志配置"""
    # 输出目录
    output_dir: str = "./exp_output"
    exp_name: str = "world_model"
    
    # 日志工具
    use_wandb: bool = True
    use_tensorboard: bool = False
    
    # wandb 设置
    wandb_project: str = "adads_world_model"
    wandb_entity: Optional[str] = None
    wandb_name: Optional[str] = None
    wandb_mode: str = "online"


@dataclass
class Config:
    """总配置"""
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    log: LogConfig = field(default_factory=LogConfig)
    
    def to_dict(self):
        """转换为字典"""
        return {
            'model': self.model.__dict__,
            'training': self.training.__dict__,
            'data': self.data.__dict__,
            'log': self.log.__dict__
        }
    
    @classmethod
    def from_dict(cls, config_dict):
        """从字典创建配置"""
        return cls(
            model=ModelConfig(**config_dict.get('model', {})),
            training=TrainingConfig(**config_dict.get('training', {})),
            data=DataConfig(**config_dict.get('data', {})),
            log=LogConfig(**config_dict.get('log', {}))
        )


@dataclass
class SafeIQLConfig:
    """Safe-IQL 训练配置"""

    # 维度
    control_mode: str = "delta"  # 控制模式："delta" 或 "absolute"
    obs_dim: int = 32  # 低维观测维度（需要根据实际数据设置）
    
    # Safe-IQL 相关参数
    k_low: int = 1  # k 值的下限（降采样率的最小值）
    k_high: int = 2  # k 值的上限（降采样率的最大值）
    epsilon: float = 3.0  # 安全阈值，用于判断 transition 是否安全
    seq_len: int = 10  # 序列长度

    # 网络架构
    hidden_dims: List[int] = field(default_factory=lambda: [256, 256])  # MLP 隐藏层维度
    rnn_hidden_size: int = 256  # RNN 隐藏层大小
    num_rnn_layers: int = 3  # RNN 层数
    
    # 训练超参数
    learning_rate: float = 1e-4  # 学习率
    expectile: float = 0.9  # Expectile 参数（用于 IQL 的 value function 更新）
    gamma: float = 0.1  # 折扣因子
    tau: float = 0.005  # 软更新系数（用于 target network）
    grad_clip_norm: float = 1.0  # 梯度裁剪范数
    
    # 训练设置
    num_epochs: int = 200  # 训练轮数
    batch_size: int = 512  # 批次大小
    
    # 检查点设置
    save_checkpoint: bool = True  # 是否保存检查点
    checkpoint_every_n_epochs: int = 10  # 每 N 个 epoch 保存一次检查点
    save_best_checkpoint: bool = True  # 是否保存最佳检查点
    
    # 评估设置
    eval_every_n_epochs: int = 10  # 每 N 个 epoch 评估一次
    eval_episodes: int = 50  # 评估时使用的 episode 数量
    
    # Dynamics 模型路径
    dynamics_snapshot_path: Optional[str] = None  # Dynamics 模型快照路径（如果为 None，则使用默认路径）


@dataclass
class PolicyConfig:
    """总配置"""
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    log: LogConfig = field(default_factory=LogConfig)
    safe_iql: SafeIQLConfig = field(default_factory=SafeIQLConfig)
    
    def to_dict(self):
        """转换为字典"""
        return {
            'training': self.training.__dict__,
            'data': self.data.__dict__,
            'log': self.log.__dict__,
            'safe_iql': self.safe_iql.__dict__
        }
    
    @classmethod
    def from_dict(cls, config_dict):
        """从字典创建配置"""
        return cls(
            training=TrainingConfig(**config_dict.get('training', {})),
            data=DataConfig(**config_dict.get('data', {})),
            log=LogConfig(**config_dict.get('log', {})),
            safe_iql=SafeIQLConfig(**config_dict.get('safe_iql', {}))
        )

# 默认配置
def get_default_config() -> Config:
    """获取默认配置"""
    return Config()



# Libero 环境的配置示例
def get_libero_config(
    data_dir: str = "./libero_demos",
    dataset_names: List[str] = None,
    use_rgb: bool = True,
    camera_names: List[str] = None,
    use_state: bool = True,
    use_absolute_action: bool = False,
    **kwargs
) -> Config:
    """
    获取 Libero 环境的配置（支持多任务）
    
    Args:
        data_dir: RLDS 数据目录路径（tensorflow_datasets 的数据目录）
        dataset_names: 数据集名称列表（支持多个 task suite），例如:
            ['libero_10_no_noops', 'libero_goal_no_noops', 'libero_object_no_noops', 'libero_spatial_no_noops']
            如果为 None，则使用所有默认的 task suites
        use_rgb: 是否使用 RGB 观测
        camera_names: 相机名称列表，默认 ['image', 'wrist_image']
            对应 RLDS 中的 observation['image'] 和 observation['wrist_image']
        use_state: 是否使用状态观测（observation['state']）
        **kwargs: 其他配置参数
    """
    config = Config()
    
    # Libero 的默认维度
    # state: [8] (proprioception)
    # 如果不使用 RGB 观测，则只使用低维状态
    config.model.obs_dim = None  # 加载数据后自动推断
    config.model.action_dim = 7  # Libero 使用 7 维动作
    
    # RGB 配置
    if use_rgb:
        config.model.use_pixels = True
    
    # 数据配置
    config.data.env_name = "libero"
    config.data.data_dir = data_dir
    
    # 默认使用所有常见的 Libero task suites
    if dataset_names is None:
        dataset_names = [
            "libero_10_no_noops",
            "libero_goal_no_noops",
            "libero_object_no_noops",
            "libero_spatial_no_noops",
        ]
    
    config.data.loader_kwargs = {
        "dataset_names": dataset_names,
        "use_state": use_state,
        "use_rgb": use_rgb,
        "camera_names": camera_names or ['image', 'wrist_image'],
        "use_absolute_action": use_absolute_action,
    }
    
    # 日志配置
    config.log.wandb_project = "model_libero"

    config.log.exp_name = f"model_pixel{use_rgb}"
    
    return config

# Libero 环境的配置示例
def get_libero_policy_config(
    data_dir: str = "./libero_demos",
    dataset_names: List[str] = None,
    use_rgb: bool = False,
    camera_names: List[str] = None,
    use_state: bool = True,
    use_absolute_action: bool = False,
    # Safe-IQL 重要参数
    k_low: int = 1,
    k_high: int = 2,
    epsilon: float = 0.01,
    # Safe-IQL 训练超参数
    learning_rate: float = 3e-4,
    expectile: float = 0.8,
    num_epochs: int = 100,
    batch_size: int = 64,
    gamma: float = 0.99,
    tau: float = 0.005,
    # Dynamics 模型路径
    dynamics_snapshot_path: Optional[str] = None,
    **kwargs
) -> PolicyConfig:
    """
    获取 Libero 环境的配置（支持多任务）
    
    Args:
        data_dir: RLDS 数据目录路径（tensorflow_datasets 的数据目录）
        dataset_names: 数据集名称列表（支持多个 task suite），例如:
            ['libero_10_no_noops', 'libero_goal_no_noops', 'libero_object_no_noops', 'libero_spatial_no_noops']
            如果为 None，则使用所有默认的 task suites
        use_rgb: 是否使用 RGB 观测
        camera_names: 相机名称列表，默认 ['image', 'wrist_image']
            对应 RLDS 中的 observation['image'] 和 observation['wrist_image']
        use_state: 是否使用状态观测（observation['state']）
        
        # Safe-IQL 重要参数
        k_low: k 值的下限（降采样率的最小值）
        k_high: k 值的上限（降采样率的最大值）
        epsilon: 安全阈值，用于判断 transition 是否安全
        
        # Safe-IQL 训练超参数
        learning_rate: 学习率
        expectile: Expectile 参数（用于 IQL 的 value function 更新）
        num_epochs: 训练轮数
        batch_size: 批次大小
        gamma: 折扣因子
        tau: 软更新系数（用于 target network）
        
        # Dynamics 模型路径
        dynamics_snapshot_path: Dynamics 模型快照路径（如果为 None，则使用默认路径）
        
        **kwargs: 其他配置参数（可用于设置 hidden_dims, rnn_hidden_size 等）
    """
    config = PolicyConfig()
    
    # Libero 的默认维度
    # state: [8] (proprioception)
    # 如果不使用 RGB 观测，则只使用低维状态
    config.safe_iql.obs_dim = None  # 加载数据后自动推断
    
    # Safe-IQL 重要参数
    config.safe_iql.k_low = k_low
    config.safe_iql.k_high = k_high
    config.safe_iql.epsilon = epsilon
    
    # Safe-IQL 训练超参数
    config.safe_iql.learning_rate = learning_rate
    config.safe_iql.expectile = expectile
    config.safe_iql.num_epochs = num_epochs
    config.safe_iql.batch_size = batch_size
    config.safe_iql.gamma = gamma
    config.safe_iql.tau = tau
    
    # Dynamics 模型路径
    if dynamics_snapshot_path is not None:
        config.safe_iql.dynamics_snapshot_path = dynamics_snapshot_path
    
    # 从 kwargs 中设置其他 Safe-IQL 参数
    if 'hidden_dims' in kwargs:
        config.safe_iql.hidden_dims = kwargs.pop('hidden_dims')
    if 'rnn_hidden_size' in kwargs:
        config.safe_iql.rnn_hidden_size = kwargs.pop('rnn_hidden_size')
    if 'num_rnn_layers' in kwargs:
        config.safe_iql.num_rnn_layers = kwargs.pop('num_rnn_layers')
    if 'grad_clip_norm' in kwargs:
        config.safe_iql.grad_clip_norm = kwargs.pop('grad_clip_norm')
    if 'checkpoint_every_n_epochs' in kwargs:
        config.safe_iql.checkpoint_every_n_epochs = kwargs.pop('checkpoint_every_n_epochs')
    if 'eval_every_n_epochs' in kwargs:
        config.safe_iql.eval_every_n_epochs = kwargs.pop('eval_every_n_epochs')
    if 'eval_episodes' in kwargs:
        config.safe_iql.eval_episodes = kwargs.pop('eval_episodes')
    
    # 数据配置
    config.data.env_name = "libero"
    config.data.data_dir = data_dir
    
    # 默认使用所有常见的 Libero task suites
    if dataset_names is None:
        dataset_names = [
            "libero_10_no_noops",
            "libero_goal_no_noops",
            "libero_object_no_noops",
            "libero_spatial_no_noops",
        ]
    
    config.data.loader_kwargs = {
        "dataset_names": dataset_names,
        "use_rgb": use_rgb,
        "camera_names": camera_names or ['image', 'wrist_image'],
        "use_state": use_state,
        "use_absolute_action": use_absolute_action,
    }
    
    # 日志配置
    config.log.wandb_project = "policy_libero"
    config.log.exp_name = f"safe_iql_k{k_low}_{k_high}_eps{epsilon}"
    
    # 应用 kwargs 中的其他参数（如 training, log 等）
    for key, value in kwargs.items():
        if hasattr(config.training, key):
            setattr(config.training, key, value)
        elif hasattr(config.log, key):
            setattr(config.log, key, value)
    
    return config


# HDF5 真机数据的配置示例
def get_hdf5_config(
    data_dir: str = "./datasets",
    **kwargs
) -> Config:
    """
    获取 HDF5 真机数据的配置
    
    Args:
        data_dir: HDF5 文件目录路径（包含多个 .hdf5 文件，每个文件一个 episode）
        **kwargs: 其他配置参数
    """
    config = Config()
    
    # HDF5 真机数据的维度
    # state = qpos (14) + eef_factory (14) = 28
    config.model.obs_dim = 28
    config.model.action_dim = 14
    
    # 不使用 RGB 观测
    config.model.use_pixels = False
    
    # 数据配置
    config.data.env_name = "hdf5"
    config.data.data_dir = data_dir
    config.data.loader_kwargs = {}
    # 日志配置
    config.log.wandb_project = "model_hdf5"
    config.log.exp_name = "hdf5_world_model"
    
    # 应用 kwargs 中的其他参数
    for key, value in kwargs.items():
        if hasattr(config.model, key):
            setattr(config.model, key, value)
        elif hasattr(config.training, key):
            setattr(config.training, key, value)
        elif hasattr(config.data, key):
            setattr(config.data, key, value)
        elif hasattr(config.log, key):
            setattr(config.log, key, value)
    
    return config


# HDF5 真机数据的策略配置示例
def get_hdf5_policy_config(
    data_dir: str = "./hdf5_demos",
    # Safe-IQL 重要参数
    k_low: int = 1,
    k_high: int = 2,
    epsilon: float = 0.01,
    # Safe-IQL 训练超参数
    learning_rate: float = 3e-4,
    expectile: float = 0.8,
    num_epochs: int = 100,
    batch_size: int = 64,
    gamma: float = 0.99,
    tau: float = 0.005,
    # Dynamics 模型路径
    dynamics_snapshot_path: Optional[str] = None,
    **kwargs
) -> PolicyConfig:
    """
    获取 HDF5 真机数据的策略配置
    
    Args:
        data_dir: HDF5 文件目录路径（包含多个 .hdf5 文件，每个文件一个 episode）
        
        # Safe-IQL 重要参数
        k_low: k 值的下限（降采样率的最小值）
        k_high: k 值的上限（降采样率的最大值）
        epsilon: 安全阈值，用于判断 transition 是否安全
        
        # Safe-IQL 训练超参数
        learning_rate: 学习率
        expectile: Expectile 参数（用于 IQL 的 value function 更新）
        num_epochs: 训练轮数
        batch_size: 批次大小
        gamma: 折扣因子
        tau: 软更新系数（用于 target network）
        
        # Dynamics 模型路径
        dynamics_snapshot_path: Dynamics 模型快照路径（如果为 None，则使用默认路径）
        
        **kwargs: 其他配置参数（可用于设置 hidden_dims, rnn_hidden_size 等）
    """
    config = PolicyConfig()
    
    # HDF5 真机数据的维度
    # state = qpos (14) + eef_rtb (14) = 28
    config.safe_iql.obs_dim = 28
    
    # Safe-IQL 重要参数
    config.safe_iql.k_low = k_low
    config.safe_iql.k_high = k_high
    config.safe_iql.epsilon = epsilon
    
    # Safe-IQL 训练超参数
    config.safe_iql.learning_rate = learning_rate
    config.safe_iql.expectile = expectile
    config.safe_iql.num_epochs = num_epochs
    config.safe_iql.batch_size = batch_size
    config.safe_iql.gamma = gamma
    config.safe_iql.tau = tau
    
    # Dynamics 模型路径
    if dynamics_snapshot_path is not None:
        config.safe_iql.dynamics_snapshot_path = dynamics_snapshot_path
    
    # 从 kwargs 中设置其他 Safe-IQL 参数
    if 'hidden_dims' in kwargs:
        config.safe_iql.hidden_dims = kwargs.pop('hidden_dims')
    if 'rnn_hidden_size' in kwargs:
        config.safe_iql.rnn_hidden_size = kwargs.pop('rnn_hidden_size')
    if 'num_rnn_layers' in kwargs:
        config.safe_iql.num_rnn_layers = kwargs.pop('num_rnn_layers')
    if 'grad_clip_norm' in kwargs:
        config.safe_iql.grad_clip_norm = kwargs.pop('grad_clip_norm')
    if 'checkpoint_every_n_epochs' in kwargs:
        config.safe_iql.checkpoint_every_n_epochs = kwargs.pop('checkpoint_every_n_epochs')
    if 'eval_every_n_epochs' in kwargs:
        config.safe_iql.eval_every_n_epochs = kwargs.pop('eval_every_n_epochs')
    if 'eval_episodes' in kwargs:
        config.safe_iql.eval_episodes = kwargs.pop('eval_episodes')
    
    # 数据配置
    config.data.env_name = "hdf5"
    config.data.data_dir = data_dir
    config.data.loader_kwargs = {}  # HDF5 加载器不需要额外参数
    
    # 日志配置
    config.log.wandb_project = "policy_hdf5"
    config.log.exp_name = f"safe_iql_k{k_low}_{k_high}_eps{epsilon}"
    
    # 应用 kwargs 中的其他参数（如 training, log 等）
    for key, value in kwargs.items():
        if hasattr(config.training, key):
            setattr(config.training, key, value)
        elif hasattr(config.log, key):
            setattr(config.log, key, value)
    
    return config

