"""
配置文件
定义训练 world model 的所有配置参数
"""

from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Dict, Any


@dataclass
class ModelConfig:
    """模型配置"""
    # 维度
    obs_dim: int = 32  # 低维观测维度（需要根据实际数据设置）
    action_dim: int = 16  # 动作维度（需要根据实际数据设置）
    
    # SADM 架构
    hidden_dim: int = 256
    rnn_num_layers: int = 3
    dropout: float = 0.1
    
    # 训练选项
    use_symlog: bool = True  # 使用 symlog 变换
    use_var: bool = False  # 预测方差
    use_residual: bool = False  # 使用残差预测
    framestack: int = 1  # 帧堆叠数量（1表示不堆叠，3表示堆叠3帧）

@dataclass
class TrainingConfig:
    """训练配置"""
    # 基础设置
    seed: int = 42
    device: str = "cuda:0"
    num_gpus: int = 1
    
    # 训练超参数
    batch_size: int = 256
    seq_len: int = 10  # 序列长度
    learning_rate: float = 1e-4
    weight_decay: float = 1e-5
    grad_clip: Optional[float] = 100.0
    
    # 训练步数
    num_train_steps: int = 10000
    eval_every_steps: int = 1000
    log_every_steps: int = 1000
    save_every_steps: int = 1000
    
    num_workers: int = 4
    pin_memory: bool = True
    
@dataclass
class DataConfig:
    """数据配置"""
    # 数据路径
    data_dir: str = "/path/to/demo/data"  # 需要设置实际路径
    
    # 环境名称: 'libero', 'pickle'
    env_name: str = "libero"
    
    # 数据加载器参数（根据环境不同而不同）
    # Libero: dataset_names (多任务列表), use_rgb, camera_names, use_state
    # Pickle: obs_keys, action_key, reward_key, terminal_key, rgb_keys
    loader_kwargs: Dict[str, Any] = field(default_factory=dict)
    
    # 数据集设置
    num_demos: int = 100  # 使用的 demo 数量
    train_ratio: float = 0.9  # 训练集比例
    
    # 数据预处理
    normalize_obs: bool = True  # 是否标准化观测
    normalize_action: bool = True  # 是否标准化动作
    use_absolute_actions: bool = False  # 是否使用绝对动作


@dataclass
class LogConfig:
    """日志配置"""
    # 输出目录
    output_dir: str = "./exp_output"
    exp_name: str = "world_model"
    
    # 日志工具
    use_wandb: bool = True
    use_tensorboard: bool = False
    
    # wandb 设置
    wandb_project: str = "adads_world_model"
    wandb_entity: Optional[str] = None
    wandb_name: Optional[str] = None
    wandb_mode: str = "online"


@dataclass
class PolicyConfig:
    """总配置"""
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)
    log: LogConfig = field(default_factory=LogConfig)
    
    def to_dict(self):
        """转换为字典"""
        return {
            'model': self.model.__dict__,
            'training': self.training.__dict__,
            'data': self.data.__dict__,
            'log': self.log.__dict__
        }
    
    @classmethod
    def from_dict(cls, config_dict):
        """从字典创建配置"""
        return cls(
            model=ModelConfig(**config_dict.get('model', {})),
            training=TrainingConfig(**config_dict.get('training', {})),
            data=DataConfig(**config_dict.get('data', {})),
            log=LogConfig(**config_dict.get('log', {}))
        )


# 默认配置
def get_default_config() -> PolicyConfig:
    """获取默认配置"""
    return PolicyConfig()



# Libero 环境的配置示例
def get_libero_policy_config(
    data_dir: str = "./libero_demos",
    dataset_names: List[str] = None,
    use_rgb: bool = True,
    camera_names: List[str] = None,
    use_state: bool = True,
    use_absolute_action: bool = False,
    **kwargs
) -> PolicyConfig:
    """
    获取 Libero 环境的配置（支持多任务）
    
    Args:
        data_dir: RLDS 数据目录路径（tensorflow_datasets 的数据目录）
        dataset_names: 数据集名称列表（支持多个 task suite），例如:
            ['libero_10_no_noops', 'libero_goal_no_noops', 'libero_object_no_noops', 'libero_spatial_no_noops']
            如果为 None，则使用所有默认的 task suites
        use_rgb: 是否使用 RGB 观测
        camera_names: 相机名称列表，默认 ['image', 'wrist_image']
            对应 RLDS 中的 observation['image'] 和 observation['wrist_image']
        use_state: 是否使用状态观测（observation['state']）
        **kwargs: 其他配置参数
    """
    config = PolicyConfig()
    
    # Libero 的默认维度
    # state: [8] (proprioception)
    # 如果不使用 RGB 观测，则只使用低维状态
    config.model.obs_dim = None  # 加载数据后自动推断
    config.model.action_dim = 7  # Libero 使用 7 维动作
    
    # RGB 配置
    if use_rgb:
        config.model.use_pixels = True
    
    # 数据配置
    config.data.env_name = "libero"
    config.data.data_dir = data_dir
    
    # 默认使用所有常见的 Libero task suites
    if dataset_names is None:
        dataset_names = [
            "libero_10_no_noops",
            "libero_goal_no_noops",
            "libero_object_no_noops",
            "libero_spatial_no_noops",
        ]
    
    config.data.loader_kwargs = {
        "dataset_names": dataset_names,
        "use_rgb": use_rgb,
        "camera_names": camera_names or ['image', 'wrist_image'],
        "use_state": use_state,
        "use_absolute_action": use_absolute_action,
    }
    
    # 日志配置
    config.log.wandb_project = "model_libero"

    config.log.exp_name = f"model_pixel{use_rgb}"
    
    return config
