"""
World Model 训练器
支持从像素或低维观测输入，预测下一状态
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Tuple, Callable
import numpy as np

from .sadm import SADModel


def soft_clamp(x: torch.Tensor, _min=None, _max=None):
    """Smoothly push ``x`` toward the range bounded by ``_min`` / ``_max``.

    Softplus is used instead of a hard ``torch.clamp`` so that gradients are
    not zeroed out at the bounds. Either bound may be ``None`` to leave that
    side unconstrained. The upper bound is applied before the lower bound.
    """
    out = x
    if _max is not None:
        # Upper side: _max - softplus(_max - x) stays strictly below _max.
        out = _max - F.softplus(_max - out)
    if _min is not None:
        # Lower side: softplus(x - _min) + _min stays strictly above _min.
        out = F.softplus(out - _min) + _min
    return out


class SimpleConvEncoder(nn.Module):
    """Simple CNN encoder for image observations.

    Four stride-2 3x3 conv layers (32 -> 64 -> 128 -> 256 channels, ReLU),
    a flatten, and a linear projection to ``output_dim``. Input pixels are
    cast to float and divided by 255 inside ``forward``.
    """

    def __init__(self, input_channels: int = 3, output_dim: int = 256, image_size: Tuple[int, int] = (256, 256)):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Infer the flattened conv output size by pushing one dummy image of
        # the configured ``image_size`` through the conv stack; images fed to
        # ``forward`` must therefore have this same H x W.
        with torch.no_grad():
            dummy_input = torch.zeros(1, input_channels, image_size[0], image_size[1])
            self.flatten_dim = self.encoder(dummy_input).shape[1]

        self.fc = nn.Linear(self.flatten_dim, output_dim)
        self.output_dim = output_dim

    def forward(self, x):
        """Encode images into feature vectors.

        Args:
            x: [..., C, H, W] image tensor, e.g. [B, C, H, W] or
               [B, V, C, H, W] (V = number of camera views). Values are
               divided by 255 after casting to float.
        Returns:
            features: [..., output_dim], i.e. [B, output_dim] or
                      [B, V, output_dim]; the leading dims of ``x`` are kept.
        """
        lead_shape = x.shape[:-3]
        # Collapse every leading dim (batch, views, ...) into one batch dim so
        # all input ranks share the same normalize -> conv -> fc path.
        flat = x.reshape(-1, *x.shape[-3:]).float() / 255.0
        features = self.fc(self.encoder(flat))
        return features.reshape(*lead_shape, -1)


class WorldModel(nn.Module):
    """
    World model used to train a dynamics model.

    Features:
    - Accepts pixel (RGB) and/or low-dimensional observations as input
    - Uses SADM to predict future low-dimensional states
    - Optional symlog transform of inputs and targets
    - Optional variance prediction (uncertainty estimation)
    - Optional residual (delta) prediction
    - Multi-view fusion for multiple cameras
    """

    def __init__(
        self,
        # Observation and action dimensions
        obs_dim: int,
        action_dim: int,
        # Image observation config
        use_pixels: bool = False,
        image_channels: int = 3,
        image_size: Tuple[int, int] = (84, 84),
        num_cameras: int = 1,
        # Visual encoder config
        use_dinov2: bool = False,
        dinov2_model_type: str = 'dinov2_vits14',
        dinov2_visual_feature_dim: int = 256,
        dinov2_mlp_hidden_dims: Optional[list] = None,
        dinov2_use_cls_token: bool = True,
        dinov2_dropout: float = 0.0,
        # SADM parameters
        hidden_dim: int = 256,
        rnn_num_layers: int = 3,
        dropout: float = 0.1,
        # Training parameters
        learning_rate: float = 1e-4,
        weight_decay: float = 1e-5,
        grad_clip: Optional[float] = 100.0,
        # Loss config
        use_symlog: bool = True,
        use_var: bool = True,
        use_residual: bool = True,
        # Framestack
        framestack: int = 1,
        # Device
        device: str = "cuda:0"
    ):
        """
        Args:
            obs_dim: low-dimensional observation size (also the prediction size).
            action_dim: action size.
            use_pixels: whether image observations are encoded.
            image_channels: image channel count (when framestack is already
                applied this should be C * framestack).
            image_size: image (H, W).
            num_cameras: number of camera views.
            use_dinov2: use the DINOv2 encoder instead of SimpleConvEncoder.
            dinov2_model_type: DINOv2 model variant.
            dinov2_visual_feature_dim: DINOv2 output feature size (per frame).
            dinov2_mlp_hidden_dims: DINOv2 MLP hidden sizes (default [256, 64]).
            dinov2_use_cls_token: whether DINOv2 uses the CLS token.
            dinov2_dropout: DINOv2 MLP dropout.
            hidden_dim: SADM hidden size.
            rnn_num_layers: number of GRU layers in SADM.
            dropout: dropout probability.
            learning_rate: AdamW learning rate.
            weight_decay: AdamW weight decay.
            grad_clip: max gradient norm, or None to disable clipping.
            use_symlog: apply symlog to low-dim inputs and to targets.
            use_var: predict a log-variance and train with a Gaussian NLL.
            use_residual: predict per-step deltas instead of absolute states.
            framestack: number of stacked frames (1 = no stacking).
            device: torch device string.
        """
        super().__init__()

        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.use_pixels = use_pixels
        self.image_channels = image_channels
        self.image_size = image_size
        self.num_cameras = num_cameras
        self.device = torch.device(device)
        self.framestack = framestack
        self.use_dinov2 = use_dinov2

        # Training config
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.use_symlog = use_symlog
        self.use_var = use_var
        self.use_residual = use_residual

        # Pixel encoder (optional). view_fusion is always defined so the
        # attribute exists even when use_pixels is False.
        self.pixel_encoder = None
        self.view_fusion = None
        self.pixel_latent_dim = 0

        if self.use_pixels:
            if use_dinov2:
                # DINOv2 encoder
                from .dinov2_encoder import DINOv2MultiViewEncoder

                if dinov2_mlp_hidden_dims is None:
                    dinov2_mlp_hidden_dims = [256, 64]

                print(f"使用 DINOv2 编码器: {dinov2_model_type}")
                self.pixel_encoder = DINOv2MultiViewEncoder(
                    num_views=num_cameras,
                    visual_feature_dim=dinov2_visual_feature_dim,
                    mlp_hidden_dims=dinov2_mlp_hidden_dims,
                    model_type=dinov2_model_type,
                    use_cls_token=dinov2_use_cls_token,
                    dropout=dinov2_dropout,
                    framestack=framestack
                )
                self.pixel_encoder.to(self.device)
                # DINOv2 output size = visual_feature_dim * framestack; the
                # encoder fuses the views itself, so view_fusion stays None.
                self.pixel_latent_dim = dinov2_visual_feature_dim * framestack
            else:
                # Plain CNN encoder
                self.pixel_encoder = SimpleConvEncoder(
                    input_channels=image_channels,
                    output_dim=256,
                    image_size=image_size
                )
                self.pixel_encoder.to(self.device)
                self.pixel_latent_dim = self.pixel_encoder.output_dim

                # Learned multi-view fusion: concat views -> linear.
                if num_cameras > 1:
                    self.view_fusion = nn.Linear(self.pixel_latent_dim * num_cameras, self.pixel_latent_dim)
                    self.view_fusion.to(self.device)

        # Low-dim observation encoder (optional). Its input is the
        # framestacked low-dim observation.
        self.framestack_obs_dim = obs_dim * framestack
        self.low_dim_encoder = None
        self.low_dim_latent_dim = 0
        if obs_dim > 0:
            self.low_dim_encoder = nn.Sequential(
                nn.Linear(self.framestack_obs_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 128)
            )
            self.low_dim_encoder.to(self.device)
            self.low_dim_latent_dim = 128

        # Encoded feature size = pixel features + low-dim features.
        feat_dim = self.pixel_latent_dim + self.low_dim_latent_dim

        # Dynamics model. NOTE(review): forward() splits forward_all's output
        # into (mean, logvar) halves, so it presumably emits 2 * obs_dim
        # features per step — confirm against SADModel.
        self.dynamics_model = SADModel(
            in_dim=feat_dim,
            out_dim=obs_dim,
            action_dim=action_dim,
            hidden_dim=hidden_dim,
            rnn_num_layers=rnn_num_layers,
            dropout=dropout,
            device=device
        )

        # Learnable soft bounds on the predicted log-variance.
        # NOTE(review): compute_loss adds no regularizer on these bounds
        # (PETS/MBPO-style models add 0.01 * (max_logvar.sum() -
        # min_logvar.sum())) — confirm that is intended.
        if use_var:
            self.register_parameter(
                "max_logvar",
                nn.Parameter(torch.ones(obs_dim, device=self.device) * 0.5, requires_grad=True)
            )
            self.register_parameter(
                "min_logvar",
                nn.Parameter(torch.ones(obs_dim, device=self.device) * -10, requires_grad=True)
            )

        # Optimizer over every parameter registered above.
        self.optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            eps=1e-8
        )

        # When True, update() also reports the gradient norm.
        self.logging = True

    def symlog(self, x: torch.Tensor) -> torch.Tensor:
        """Symlog transform: sign(x) * log(|x| + 1)."""
        return torch.sign(x) * torch.log(torch.abs(x) + 1)

    def symexp(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of symlog: sign(x) * (exp(|x|) - 1)."""
        return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)

    def encode_obs(
        self,
        low_dim_obs: Optional[torch.Tensor] = None,
        rgb_obs: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Encode observations into one feature vector.

        Args:
            low_dim_obs: [B, obs_dim] or [B, obs_dim*framestack] low-dim
                observation (already stacked when framestack > 1).
            rgb_obs: [B, V, C, H, W] or [B, V, C*framestack, H, W] RGB
                observation (already stacked along channels).

        Returns:
            features: [B, feat_dim] concatenation of pixel features (if any)
                followed by low-dim features (if any).

        Raises:
            ValueError: if neither observation type produced features.
        """
        features = []

        if self.use_pixels and rgb_obs is not None:
            pixel_features = self.pixel_encoder(rgb_obs)
            # DINOv2 already fuses views internally. The CNN encoder returns
            # [B, V, D] for multi-view input, which is fused here.
            if not self.use_dinov2 and pixel_features.dim() == 3:
                if self.view_fusion is not None:
                    # Concatenate views, then project: [B, V*D] -> [B, D].
                    pixel_features = self.view_fusion(pixel_features.reshape(pixel_features.shape[0], -1))
                else:
                    # No learned fusion: average over views.
                    pixel_features = pixel_features.mean(dim=1)
            features.append(pixel_features)

        if low_dim_obs is not None and self.low_dim_encoder is not None:
            if self.use_symlog:
                low_dim_obs = self.symlog(low_dim_obs)
            features.append(self.low_dim_encoder(low_dim_obs))

        if len(features) == 0:
            raise ValueError("必须提供至少一种观测（low_dim_obs 或 rgb_obs）")

        return torch.cat(features, dim=-1)

    def forward(
        self,
        low_dim_obs: Optional[torch.Tensor] = None,
        rgb_obs: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Roll the dynamics model out from one initial observation.

        Args:
            low_dim_obs: [B, obs_dim] current low-dim observation.
            rgb_obs: [B, V, C, H, W] current RGB observation.
            actions: [B, T, action_dim] action sequence.

        Returns:
            mean: [B, T, obs_dim] predicted mean (in symlog space when
                use_symlog is on, and a per-step delta when use_residual is on).
            logvar: [B, T, obs_dim] predicted log-variance, soft-clamped to
                [min_logvar, max_logvar] when use_var is on.
        """
        encoded_obs = self.encode_obs(low_dim_obs, rgb_obs)

        model_out = self.dynamics_model.forward_all(encoded_obs, actions)
        mean, logvar = torch.chunk(model_out, 2, dim=-1)

        if self.use_var:
            logvar = soft_clamp(logvar, self.min_logvar, self.max_logvar)

        return mean, logvar

    def compute_loss(
        self,
        low_dim_obs: Optional[torch.Tensor] = None,
        rgb_obs: Optional[torch.Tensor] = None,
        next_low_dim_obs: Optional[torch.Tensor] = None,
        next_rgb_obs: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the training loss for a batch of sequences.

        Args:
            low_dim_obs: [B, T, obs_dim] current low-dim observation sequence.
            rgb_obs: [B, T, V, C, H, W] current RGB observation sequence.
            next_low_dim_obs: [B, T, obs_dim] next-state low-dim sequence (target).
            next_rgb_obs: [B, T, V, C, H, W] next-state RGB sequence
                (accepted for API symmetry; not used by the loss).
            actions: [B, T, action_dim] action sequence.

        Returns:
            loss: scalar total loss.
            metrics: 'total_loss', 'raw_mse', 'raw_mse_step{i}' every 5th
                step, plus 'mse_loss' and 'var_loss' when use_var is on.

        Raises:
            ValueError: if actions or next_low_dim_obs is missing, or if
                use_residual is on and low_dim_obs is missing.
        """
        if actions is None:
            raise ValueError("必须提供 actions 动作序列")
        if next_low_dim_obs is None:
            raise ValueError("必须提供 next_low_dim_obs 作为预测目标")

        # Only the first observation seeds the rollout; later steps are
        # predicted from the action sequence.
        first_low_dim_obs = None if low_dim_obs is None else low_dim_obs[:, 0]
        first_rgb_obs = None if rgb_obs is None else rgb_obs[:, 0]

        # Targets are built from the low-dim observations only.
        if self.use_residual:
            if low_dim_obs is None:
                raise ValueError("使用残差预测时必须提供 low_dim_obs")
            # Per-step delta: next_obs[t] - obs[t].
            # NOTE(review): assumes both tensors have last dim obs_dim; with
            # framestack > 1 and stacked low_dim_obs this would not line up.
            target = next_low_dim_obs - low_dim_obs
        else:
            target = next_low_dim_obs

        if self.use_symlog:
            target = self.symlog(target)

        mean, logvar = self.forward(first_low_dim_obs, first_rgb_obs, actions)

        sq_err = torch.pow(mean - target, 2)
        raw_mse = sq_err.mean()
        # Average over batch and feature dims, keeping one value per step.
        raw_mse_perstep = sq_err.mean(dim=0).mean(dim=-1)

        if self.use_var:
            # Gaussian negative log-likelihood up to constants:
            # (mean - target)^2 / var + log var.
            mse_loss = (sq_err * torch.exp(-logvar)).mean()
            var_loss = logvar.mean()
            total_loss = mse_loss + var_loss
        else:
            total_loss = raw_mse

        metrics = {
            'total_loss': total_loss.item(),
            'raw_mse': raw_mse.item(),
        }
        for i in range(0, len(raw_mse_perstep), 5):
            metrics[f'raw_mse_step{i}'] = raw_mse_perstep[i].item()

        if self.use_var:
            metrics['mse_loss'] = mse_loss.item()
            metrics['var_loss'] = var_loss.item()

        return total_loss, metrics

    def update(
        self,
        low_dim_obs: Optional[torch.Tensor] = None,
        rgb_obs: Optional[torch.Tensor] = None,
        next_low_dim_obs: Optional[torch.Tensor] = None,
        next_rgb_obs: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None
    ) -> Dict[str, float]:
        """
        Run one optimization step.

        Args: same as compute_loss.

        Returns:
            metrics: compute_loss metrics, plus 'grad_norm' (pre-clip total
                norm) when grad_clip is set and self.logging is True.
        """
        loss, metrics = self.compute_loss(
            low_dim_obs, rgb_obs, next_low_dim_obs, next_rgb_obs, actions
        )

        self.optimizer.zero_grad()
        loss.backward()

        if self.grad_clip is not None:
            grad_norm = nn.utils.clip_grad_norm_(
                self.parameters(),
                self.grad_clip
            )
            if self.logging:
                metrics['grad_norm'] = grad_norm.item()

        self.optimizer.step()

        return metrics

    @torch.no_grad()
    def predict(
        self,
        low_dim_obs: Optional[torch.Tensor] = None,
        rgb_obs: Optional[torch.Tensor] = None,
        actions: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Predict future low-dim states (inference, no gradients).

        The model runs in eval mode for the forward pass. Afterwards the
        caller's previous train/eval mode is restored.

        Args:
            low_dim_obs: [B, obs_dim] current low-dim observation.
            rgb_obs: [B, V, C, H, W] current RGB observation.
            actions: [B, T, action_dim] action sequence.

        Returns:
            next_obs: [B, T, obs_dim] predicted future low-dim states in raw
                (non-symlog) units.

        Raises:
            ValueError: if use_residual is on and low_dim_obs is None.
        """
        if self.use_residual and low_dim_obs is None:
            raise ValueError("使用残差预测时必须提供 low_dim_obs")

        was_training = self.training
        self.eval()
        try:
            mean, _ = self.forward(low_dim_obs, rgb_obs, actions)
        finally:
            # Restore the caller's mode rather than forcing train mode.
            self.train(was_training)

        # Targets were symlog-transformed in compute_loss, so map predictions
        # back to raw units first. For residual prediction this must happen
        # BEFORE the cumsum: symlog is nonlinear, so summing symlog-space
        # deltas is not the same as summing raw deltas.
        if self.use_symlog:
            mean = self.symexp(mean)

        if self.use_residual:
            # Accumulate per-step deltas onto the initial observation.
            return torch.cumsum(mean, dim=1) + low_dim_obs.unsqueeze(1)
        return mean

    def save(self, filepath: str):
        """Save model weights, optimizer state and a small config dict to ``filepath``."""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': {
                'obs_dim': self.obs_dim,
                'action_dim': self.action_dim,
                'use_symlog': self.use_symlog,
                'use_var': self.use_var,
                'use_residual': self.use_residual,
            }
        }
        torch.save(checkpoint, filepath)
        print(f"模型已保存至: {filepath}")

    def load(self, filepath: str):
        """Load model weights and optimizer state saved by ``save``."""
        checkpoint = torch.load(filepath, map_location=self.device)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"模型已从 {filepath} 加载")

