"""
Self-transition Any-step Dynamics Model (SADM)
基于 robobase 实现的动力学模型
"""

import torch
import torch.nn as nn
from torch.nn import functional as F


class Swish(nn.Module):
    """Swish activation: multiplies the input element-wise by its sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` with the same shape as ``x``."""
        gate = torch.sigmoid(x)
        return x * gate


class ResBlock(nn.Module):
    """Linear layer with activation and optional dropout, skip connection and LayerNorm.

    The forward pass computes ``act(linear(x))``, applies dropout when enabled,
    adds the input back when ``with_residual`` is set, and finally applies
    LayerNorm when enabled.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        activation: Activation module; a ``Swish`` instance is used when None.
        layer_norm: Whether to apply ``nn.LayerNorm(output_dim)`` at the end.
        with_residual: Whether to add the input to the activated output.
        dropout: Dropout probability; any falsy value disables dropout.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        activation=None,
        layer_norm=True,
        with_residual=True,
        dropout=0.1
    ):
        super().__init__()

        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = Swish() if activation is None else activation

        if layer_norm:
            self.layer_norm = nn.LayerNorm(output_dim)
        else:
            self.layer_norm = None

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.with_residual = with_residual

    def forward(self, x):
        """Apply the block to ``x`` and return the transformed tensor."""
        out = self.linear(x)
        out = self.activation(out)
        if self.dropout is not None:
            out = self.dropout(out)
        if self.with_residual:
            # Skip connection: the input is added before normalisation.
            out = x + out
        if self.layer_norm is None:
            return out
        return self.layer_norm(out)


class SADModel(nn.Module):
    """Self-transition Any-step Dynamics Model (SADM).

    An observation is encoded into the initial hidden state of a multi-layer
    GRU, the GRU consumes an action sequence, and an output head maps the last
    two top-layer hidden states to a vector of size ``out_dim * 2`` (described
    by the original authors as mean and log-variance; the split is left to the
    caller).

    Hidden states are exchanged with callers in a *flat* layout of shape
    ``[..., hidden_dim * rnn_num_layers]``; internally they are converted to
    the GRU layout ``[rnn_num_layers, N, hidden_dim]``.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        action_dim,
        hidden_dim=200,
        rnn_num_layers=3,
        dropout=0.1,
        device="cuda:0"
    ):
        """
        Args:
            in_dim: Input feature size (encoded observation size).
            out_dim: Size of the predicted state; the head emits ``out_dim * 2``.
            action_dim: Action size.
            hidden_dim: Hidden size of the MLPs and of the GRU.
            rnn_num_layers: Number of stacked GRU layers.
            dropout: Dropout probability used inside every ``ResBlock``.
            device: Device the module is moved to at the end of construction.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.action_dim = action_dim
        self.rnn_num_layers = rnn_num_layers
        self.hidden_dim = hidden_dim
        self.device = device

        # Observation encoder: maps features to the flat initial GRU state.
        self.encoder = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            ResBlock(hidden_dim, hidden_dim, dropout=dropout),
            ResBlock(hidden_dim, hidden_dim, dropout=dropout),
            ResBlock(hidden_dim, hidden_dim, dropout=dropout),
            nn.Linear(hidden_dim, hidden_dim * rnn_num_layers)
        )

        # GRU that consumes the action sequence.
        self.rnn_layer = nn.GRU(
            input_size=action_dim,
            hidden_size=hidden_dim,
            num_layers=rnn_num_layers,
            batch_first=True
        )

        # Output head: takes [second-to-last, last] top-layer hidden states.
        self.out_layer = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            ResBlock(hidden_dim, hidden_dim, dropout=dropout),
            ResBlock(hidden_dim, hidden_dim, dropout=dropout),
            ResBlock(hidden_dim, hidden_dim, dropout=dropout),
            nn.Linear(hidden_dim, self.out_dim * 2)
        )
        self.to(device)

    def _to_rnn_state(self, flat):
        """Convert flat hiddens ``[..., hidden_dim * L]`` to GRU layout ``[L, N, hidden_dim]``."""
        # reshape (not view) so non-contiguous buffers are accepted as well.
        stacked = flat.reshape(-1, self.hidden_dim, self.rnn_num_layers)
        return stacked.permute(2, 0, 1).contiguous()

    def _from_rnn_state(self, h_state):
        """Convert a GRU state ``[L, N, hidden_dim]`` back to flat layout ``[N, hidden_dim * L]``."""
        return h_state.permute(1, 2, 0).reshape(h_state.shape[1], -1)

    def forward(self, obs, act_seq):
        """
        Predict the outcome of applying a whole action sequence to ``obs``.

        Args:
            obs: [B, in_dim] current observation features.
            act_seq: [B, T, action_dim] action sequence; T must be at least 1
                because the last GRU output is indexed.

        Returns:
            output: [B, out_dim*2] prediction after the last action.
            h_state: [rnn_num_layers, B, hidden_dim] final GRU hidden state.
        """
        self.rnn_layer.flatten_parameters()

        # Encode the observation into the initial GRU hidden state.
        h_state = self._to_rnn_state(self.encoder(obs))

        # With a single action the "second-to-last" state is the initial
        # top-layer hidden state; otherwise it is the GRU output at step T-2.
        seclast_h = h_state[-1]

        rnn_out, h_state = self.rnn_layer(act_seq, h_state)

        if rnn_out.shape[1] > 1:
            seclast_h = rnn_out[:, -2]
        last_h = rnn_out[:, -1]

        next_in = torch.cat((seclast_h, last_h), dim=-1)
        output = self.out_layer(next_in)

        return output, h_state

    def forward_all(self, obs, act_seq):
        """
        Predict every time step of an action sequence starting from ``obs``.

        Note: this overwrites the stored rollout state (``self.hiddens``,
        ``self.n_hiddens``, ``self.n_parallels``).

        Args:
            obs: [B, in_dim] initial observation.
            act_seq: [B, T, action_dim] action sequence.

        Returns:
            outputs: [B, T, out_dim*2] predictions for all time steps. For
            T == 0 an empty tensor of shape [B, 0, out_dim*2] is returned.
        """
        # A single history entry: the encoding of the initial observation.
        self.set_hiddens(self.init_hiddens(obs.unsqueeze(1), None))

        num_steps = act_seq.shape[1]
        if num_steps == 0:
            # torch.stack cannot handle an empty list; return an empty result.
            return self.hiddens.new_empty((self.n_parallels, 0, self.out_dim * 2))

        outputs = []
        for t in range(num_steps):
            # transition returns [1, B, out_dim*2] here; drop the history axis.
            output = self.transition(act_seq[:, t])
            outputs.append(output.squeeze(0))

        return torch.stack(outputs, dim=1)

    def encode_obs(self, obs):
        """Encode observations into flat hidden states ``[B, hidden_dim * rnn_num_layers]``."""
        return self.encoder(obs)

    def init_hiddens(self, obs_seq, act_seq):
        """
        Build the stack of flat hidden states for a history of observations.

        Each of the first M-1 observations is rolled forward through all of
        its remaining actions; the last observation is only encoded.

        Args:
            obs_seq: [B, M, in_dim] observation sequence.
            act_seq: [B, M-1, action_dim] action sequence, or None to use only
                the last observation.

        Returns:
            hiddens: [M, B, hidden_dim * rnn_num_layers] when ``act_seq`` is
            given, otherwise [1, B, hidden_dim * rnn_num_layers].
        """
        hiddens = []
        m = obs_seq.shape[1]

        if act_seq is not None:
            for i in range(m - 1):
                _, hidden = self.forward(obs_seq[:, i], act_seq[:, i:])
                hiddens.append(self._from_rnn_state(hidden))

        # The most recent observation has no actions applied yet.
        hiddens.append(self.encode_obs(obs_seq[:, -1]))

        return torch.stack(hiddens, dim=0)

    def set_hiddens(self, hiddens, env_ids=None):
        """
        Store rollout hidden states.

        Args:
            hiddens: With ``env_ids=None`` a tensor of shape
                [n_hiddens, n_parallels, hidden_dim * rnn_num_layers] that
                replaces the whole buffer; otherwise the values written into
                the selected columns of the existing buffer.
            env_ids: Index selecting entries along the second axis, or None.
        """
        if env_ids is None:
            self.hiddens = hiddens
            self.n_hiddens, self.n_parallels, _ = hiddens.shape
        else:
            self.hiddens[:, env_ids] = hiddens

    def update_hiddens(self, hiddens, env_ids):
        """
        Slide the history window for the selected environments.

        The oldest entry (index 0) is dropped and ``hiddens`` is appended as
        the newest entry for ``env_ids``.
        """
        self.hiddens[:, env_ids] = torch.cat((self.hiddens[1:, env_ids], hiddens[None]), dim=0)

    def transition(self, action):
        """
        Advance every stored hidden state by one action.

        Args:
            action: [B, action_dim] action, with B == ``self.n_parallels``.

        Returns:
            output: [n_hiddens, B, out_dim*2] one prediction per history entry.
        """
        # Tile the action so every history entry sees the same action; the
        # ordering (history-major) matches the flattened hidden buffer.
        action = torch.cat([action] * self.n_hiddens, dim=0)

        h_state = self._to_rnn_state(self.hiddens)
        # Top-layer state before the step plays the "second-to-last" role.
        seclast_h = h_state[-1]

        # The input is a length-1 sequence, so rnn_out has exactly one step.
        rnn_out, h_state = self.rnn_layer(action[:, None], h_state)
        last_h = rnn_out[:, -1]

        # Persist the advanced state in the flat layout.
        self.hiddens = self._from_rnn_state(h_state).reshape(
            self.n_hiddens, self.n_parallels, -1
        )

        next_in = torch.cat((seclast_h, last_h), dim=-1)
        output = self.out_layer(next_in).view(self.n_hiddens, self.n_parallels, -1)

        return output

