"""
MLP World Model - 简单的基于MLP的世界模型
根据当前state和action预测下一步state，不依赖action chunk
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple
import numpy as np


def soft_clamp(x: torch.Tensor, _min=None, _max=None):
    """Smoothly bound ``x`` using softplus instead of a hard clip.

    Unlike ``torch.clamp``, the softplus-based bound is smooth, so gradients
    still flow for values beyond the limits.

    Args:
        x: Tensor to bound.
        _min: Optional lower bound (anything that broadcasts against ``x``).
            Skipped when ``None``.
        _max: Optional upper bound (anything that broadcasts against ``x``).
            Skipped when ``None``.

    Returns:
        The softly bounded tensor; ``x`` itself when both bounds are ``None``.
    """
    bounded = x
    # The upper bound is applied first, then the lower bound on that result.
    if _max is not None:
        gap_to_max = _max - bounded
        bounded = _max - F.softplus(gap_to_max)
    if _min is not None:
        gap_to_min = bounded - _min
        bounded = _min + F.softplus(gap_to_min)
    return bounded


class MLPBlock(nn.Module):
    """One MLP stage: Linear -> ReLU -> (optional Dropout) -> LayerNorm."""

    def __init__(self, input_dim: int, output_dim: int, dropout: float = 0.1):
        """
        Args:
            input_dim: Size of the incoming feature dimension.
            output_dim: Size of the produced feature dimension.
            dropout: Dropout probability; no dropout layer is created when
                this is not greater than 0.
        """
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.activation = nn.ReLU()
        self.layer_norm = nn.LayerNorm(output_dim)
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    def forward(self, x):
        """Apply the block to ``x`` and return the normalized features."""
        hidden = self.activation(self.fc(x))
        if self.dropout is not None:
            hidden = self.dropout(hidden)
        # Normalization is applied last, after activation and dropout.
        return self.layer_norm(hidden)


class MLPWorldModel(nn.Module):
    """
    Simple MLP-based one-step world model.

    Given the current state and a single action, predicts the next state.

    Features:
    - optional symlog transform of the regression target
    - optional per-dimension log-variance head (uncertainty estimate)
    - optional residual prediction (predict ``next_obs - obs``)
    """

    def __init__(
        self,
        # Observation and action dimensions
        obs_dim: int,
        action_dim: int,
        # MLP parameters
        hidden_dims: Optional[list] = None,
        dropout: float = 0.1,
        # Training parameters
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-5,
        grad_clip: Optional[float] = 100.0,
        # Loss configuration
        use_symlog: bool = True,
        use_var: bool = True,
        use_residual: bool = True,
        # Device
        device: str = "cuda:0"
    ):
        """
        Args:
            obs_dim: Observation dimension.
            action_dim: Action dimension.
            hidden_dims: Widths of the hidden MLP layers. ``None`` selects
                the default ``[256, 256, 256]``.
            dropout: Dropout probability used in every MLP block.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            grad_clip: Max gradient norm; ``None`` disables clipping.
            use_symlog: Whether the regression target is symlog-transformed.
            use_var: Whether a log-variance head is predicted.
            use_residual: Whether the model predicts ``next_obs - obs``
                instead of ``next_obs``.
            device: Device string passed to ``torch.device``.
        """
        super().__init__()

        # Avoid a shared mutable default; copy so later caller-side mutation
        # of the passed list cannot change the stored configuration.
        if hidden_dims is None:
            hidden_dims = [256, 256, 256]
        hidden_dims = list(hidden_dims)

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.device = torch.device(device)

        # Architecture configuration (kept so save() can record it)
        self.hidden_dims = hidden_dims
        self.dropout = dropout

        # Training configuration
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.use_symlog = use_symlog
        self.use_var = use_var
        self.use_residual = use_residual

        # Build the MLP trunk
        prev_dim = obs_dim + action_dim
        layers = []
        for hidden_dim in hidden_dims:
            layers.append(MLPBlock(prev_dim, hidden_dim, dropout))
            prev_dim = hidden_dim
        self.mlp = nn.Sequential(*layers)

        # Output layer: mean, plus log-variance when use_var is enabled
        output_dim = obs_dim * 2 if use_var else obs_dim
        self.output_layer = nn.Linear(prev_dim, output_dim)

        # Trainable upper / lower bounds for the predicted log-variance
        if use_var:
            self.register_parameter(
                "max_logvar",
                nn.Parameter(torch.ones(obs_dim, device=self.device) * 0.5, requires_grad=True)
            )
            self.register_parameter(
                "min_logvar",
                nn.Parameter(torch.ones(obs_dim, device=self.device) * -10, requires_grad=True)
            )

        # Move all parameters to the target device
        self.to(self.device)

        # Optimizer (created after .to() so it references the moved parameters)
        self.optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=1e-8
        )

        # When True, update() also reports the gradient norm
        self.logging = True

    def symlog(self, x: torch.Tensor) -> torch.Tensor:
        """Symlog transform: ``sign(x) * log(|x| + 1)``.

        Uses ``log1p`` so that small magnitudes are not rounded away when
        1 is added in floating point.
        """
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def symexp(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of symlog: ``sign(x) * (exp(|x|) - 1)``.

        Uses ``expm1`` to keep precision for small magnitudes.
        """
        return torch.sign(x) * torch.expm1(torch.abs(x))

    def forward(
        self,
        obs: torch.Tensor,
        action: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass.

        Args:
            obs: [B, obs_dim] current observation.
            action: [B, action_dim] action.

        Returns:
            mean: [B, obs_dim] predicted mean (in target space, i.e. possibly
                a symlog-transformed residual).
            logvar: [B, obs_dim] predicted log-variance when ``use_var`` is
                True, otherwise ``None``.
        """
        # Concatenate observation and action: [B, obs_dim + action_dim]
        x = torch.cat([obs, action], dim=-1)

        features = self.mlp(x)
        output = self.output_layer(features)  # [B, obs_dim*2] or [B, obs_dim]

        if not self.use_var:
            return output, None

        mean, logvar = torch.chunk(output, 2, dim=-1)
        # Softly bound the log-variance between the learned limits
        logvar = soft_clamp(logvar, self.min_logvar, self.max_logvar)
        return mean, logvar

    def compute_loss(
        self,
        obs: torch.Tensor,
        next_obs: torch.Tensor,
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the training loss.

        Args:
            obs: [B, obs_dim] current observation.
            next_obs: [B, obs_dim] next observation.
            actions: [B, action_dim] action.

        Returns:
            loss: Scalar total loss tensor.
            metrics: Dict of Python floats. Always contains ``total_loss``
                and ``raw_mse``; additionally ``mse_loss`` and ``var_loss``
                when ``use_var`` is True.
        """
        # Target is either the state delta or the next state itself
        target = next_obs - obs if self.use_residual else next_obs
        if self.use_symlog:
            target = self.symlog(target)

        mean, logvar = self.forward(obs, actions)

        squared_error = torch.pow(mean - target, 2)
        raw_mse = squared_error.mean()

        if self.use_var:
            # Inverse-variance weighted error plus a log-variance term
            inv_var = torch.exp(-logvar)
            mse_loss = (squared_error * inv_var).mean()
            var_loss = logvar.mean()
            total_loss = mse_loss + var_loss
        else:
            total_loss = raw_mse

        metrics = {
            'total_loss': total_loss.item(),
            'raw_mse': raw_mse.item(),
        }
        if self.use_var:
            metrics['mse_loss'] = mse_loss.item()
            metrics['var_loss'] = var_loss.item()

        return total_loss, metrics

    def update(
        self,
        obs: torch.Tensor,
        next_obs: torch.Tensor,
        actions: torch.Tensor
    ) -> Dict[str, float]:
        """
        Run one optimization step.

        Args:
            obs: [B, obs_dim] current observation.
            next_obs: [B, obs_dim] next observation.
            actions: [B, action_dim] action.

        Returns:
            metrics: Metrics from ``compute_loss``; also ``grad_norm`` when
                gradient clipping is enabled and ``self.logging`` is True.
        """
        loss, metrics = self.compute_loss(obs, next_obs, actions)

        self.optimizer.zero_grad()
        loss.backward()

        if self.grad_clip is not None:
            # clip_grad_norm_ returns the total norm measured before clipping
            grad_norm = nn.utils.clip_grad_norm_(
                self.parameters(),
                self.grad_clip
            )
            if self.logging:
                metrics['grad_norm'] = grad_norm.item()

        self.optimizer.step()

        return metrics

    @torch.no_grad()
    def predict(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor
    ) -> torch.Tensor:
        """
        Single-step next-state prediction (inference).

        The forward pass runs in eval mode; the train/eval mode the module
        was in before the call is restored afterwards.

        Args:
            obs: [B, obs_dim] current observation.
            actions: [B, action_dim] action.

        Returns:
            next_obs: [B, obs_dim] predicted next state.
        """
        was_training = self.training
        self.eval()
        try:
            mean, _ = self.forward(obs, actions)

            # Undo the symlog transform applied to the training target
            if self.use_symlog:
                mean = self.symexp(mean)

            # A residual model outputs a delta that is added to the input
            next_obs = obs + mean if self.use_residual else mean
        finally:
            # Restore the caller's mode instead of forcing train mode
            self.train(was_training)

        return next_obs

    @torch.no_grad()
    def predict_all(
        self,
        obs: torch.Tensor,
        action_sequence: torch.Tensor
    ) -> torch.Tensor:
        """
        Multi-step rollout prediction (inference).

        Repeatedly applies the single-step prediction, feeding each predicted
        state back in as the next input. The train/eval mode the module was
        in before the call is restored afterwards.

        Args:
            obs: [B, obs_dim] initial observation.
            action_sequence: [B, T, action_dim] action sequence.

        Returns:
            next_obs_sequence: [B, T, obs_dim] predicted state trajectory
                (empty along T when the action sequence is empty).
        """
        batch_size, seq_len, _ = action_sequence.shape

        # torch.stack rejects an empty list, so handle T == 0 explicitly
        if seq_len == 0:
            return obs.new_empty((batch_size, 0, obs.shape[-1]))

        was_training = self.training
        self.eval()
        try:
            predictions = []
            current_obs = obs  # [B, obs_dim]
            for t in range(seq_len):
                current_obs = self.predict(current_obs, action_sequence[:, t, :])
                predictions.append(current_obs)
        finally:
            self.train(was_training)

        return torch.stack(predictions, dim=1)  # [B, T, obs_dim]

    def save(self, filepath: str):
        """Save model weights, optimizer state and configuration to ``filepath``."""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': {
                'obs_dim': self.obs_dim,
                'action_dim': self.action_dim,
                'use_symlog': self.use_symlog,
                'use_var': self.use_var,
                'use_residual': self.use_residual,
                # Architecture settings needed to rebuild an identical network
                'hidden_dims': list(self.hidden_dims),
                'dropout': self.dropout,
            }
        }
        torch.save(checkpoint, filepath)
        print(f"模型已保存至: {filepath}")

    def load(self, filepath: str):
        """Load model weights and optimizer state from ``filepath``.

        The model must already have been constructed with an architecture
        matching the checkpoint; the stored ``config`` is not applied here.
        """
        # NOTE(security): torch.load relies on pickle; only load checkpoint
        # files from trusted sources.
        checkpoint = torch.load(filepath, map_location=self.device)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"模型已从 {filepath} 加载")

