"""
Dynamics Model Training Script
Learns a dynamics model (SADM) from replay buffer data.
Supports both pixel and low-dim observations as input (like DreamerV3).
Only predicts low-dimensional observations (no pixel reconstruction, no reward prediction).
"""

from typing import Iterator, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

from robobase import utils
from robobase.method.core import Method
from robobase.replay_buffer.replay_buffer import ReplayBuffer
from robobase.method.utils import (
    extract_from_spec,
    extract_many_from_spec,
    extract_from_batch,
    extract_many_from_batch,
    stack_tensor_dictionary,
    
)
from robobase.models.fusion import FusionMultiCamFeature
from robobase.models.encoder import EncoderModule
from robobase.models.fully_connected import FullyConnectedModule
from robobase.models.model_based.sadm import SADModel
from robobase.models.model_based.utils import BatchTimeInputDictWrapperModule

def soft_clamp(x : torch.Tensor, _min=None, _max=None):
    """Softly bound ``x`` using softplus so gradients never vanish at the limits.

    The upper bound (``_max``) is applied first, then the lower bound (``_min``).
    Either bound may be ``None`` to skip it; with both ``None`` the input is
    returned unchanged.

    Args:
        x: Tensor to bound.
        _min: Optional lower bound (tensor or scalar, broadcastable to ``x``).
        _max: Optional upper bound (tensor or scalar, broadcastable to ``x``).

    Returns:
        The softly-bounded tensor.
    """
    # Smooth ceiling: ~x well below _max, approaches _max from below above it.
    capped = x if _max is None else _max - F.softplus(_max - x)
    # Smooth floor applied on top of the ceiling result.
    return capped if _min is None else _min + F.softplus(capped - _min)

class SADynamics(nn.Module):
    """
    Trainer for dynamics model (SADM) that learns to predict low-dimensional next observations.

    Like DreamerV3:
    - Supports both pixel (RGB) and low-dim observations as input
    - Uses encoder and fusion modules to process multi-modal inputs

    Unlike DreamerV3:
    - Uses SADM instead of RSSM
    - Does NOT reconstruct pixel observations (only predicts low-dim states)
    - Does NOT predict rewards
    - Does NOT train actor/critic for policy learning

    Prediction targets are the ``eef`` observation and, when ``use_qpos_pred``
    is set, the ``low_dim_state`` observation as well (concatenated in that
    order). The model output is chunked into a mean and a log-variance; the
    log-variance is softly bounded by the trainable ``min_logvar`` /
    ``max_logvar`` parameters.
    """

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        # Dynamics model parameters
        hidden_dim: int,
        rnn_num_layers: int,
        dropout: float,
        # Training parameters
        learning_rate: float,
        weight_decay: float,
        grad_clip: Optional[float],
        # Loss scales
        obs_loss_scale: float,
        use_symlog: bool,
        use_var: bool = True,
        use_qpos_pred: bool = False,
        use_pixels: bool = True,
        use_torch_compile: bool = False,
        # Model components (like DreamerV3)
        pixel_encoder_model: Optional[EncoderModule] = None,
        low_dim_encoder_model: Optional[FullyConnectedModule] = None,
        *args,
        **kwargs
    ):
        """
        Args:
            observation_space: Gymnasium observation space
            action_space: Gymnasium action space
            device: torch device
            hidden_dim: Hidden dimension for SADM
            rnn_num_layers: Number of RNN layers in SADM
            dropout: Dropout rate
            learning_rate: Learning rate for optimizer
            weight_decay: Weight decay for optimizer
            grad_clip: Gradient clipping magnitude (applied to SADM parameters)
            obs_loss_scale: Multiplier applied to the total observation loss
            use_symlog: Whether to use symlog transformation on observations
            use_var: If True, train with a Gaussian NLL using predicted
                log-variances; otherwise plain MSE
            use_qpos_pred: If True, also predict ``low_dim_state``
            use_pixels: If True, RGB observations are encoded and fused into
                the SADM input feature
            use_torch_compile: Use torch.compile for acceleration
            pixel_encoder_model: Factory called as
                ``pixel_encoder_model(input_shape=...)`` to build the pixel
                encoder (like DreamerV3)
            low_dim_encoder_model: Factory called as
                ``low_dim_encoder_model(input_shapes=...)`` to build the
                low-dim encoder
        """
        super(SADynamics, self).__init__()
        # NOTE(review): original comments suggest these spaces carry extra
        # leading axes (framestack for observations, sequence length for
        # actions) — confirm against the environment wrapper.
        self.observation_space = observation_space
        self.action_space = action_space
        self.device = device
        self._eval_env_running = False

        # Model parameters
        self.hidden_dim = hidden_dim
        self.rnn_num_layers = rnn_num_layers
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.use_symlog = use_symlog
        self.use_torch_compile = use_torch_compile
        self.pixel_encoder_model = pixel_encoder_model
        self.low_dim_encoder_model = low_dim_encoder_model
        # Loss scales
        self.obs_loss_scale = obs_loss_scale
        self.use_var = use_var
        self.use_qpos_pred = use_qpos_pred

        # Check what types of observations we have
        self.rgb_spaces = extract_many_from_spec(
            self.observation_space, r"rgb.*", missing_ok=True
        )
        self.use_pixels = use_pixels
        self.use_multicam_fusion = len(self.rgb_spaces) > 1

        # Get dimensions
        self.low_dim_size = self._get_low_dim_size()
        self.eef_dim_size = self._get_eef_dim_size()
        self.action_dim = action_space.shape[-1]
        print("Dimensions (low_dim_size, eef_dim_size, action_dim):")
        print(self.low_dim_size, self.eef_dim_size, self.action_dim)
        # Build model components (like DreamerV3)
        self.trainable_modules = []
        (
            self.pixel_encoder,
            self.low_dim_encoder,
            self.view_fusion,
        ) = (None,) * 3

        self.build_encoder()
        self.build_view_fusion()

        # Get feature dimension after encoding
        self.feat_dim = self._compute_feat_dim()
        # Output layout: [eef | low_dim_state (only if use_qpos_pred)]
        self.output_size = self.eef_dim_size
        self.output_size += self.low_dim_size if self.use_qpos_pred else 0
        assert self.output_size > 0, "Output size must be greater than 0"
        # Build SADM dynamics model
        self.dynamics_model = self._build_dynamics_model()
        self.trainable_modules.append(self.dynamics_model)

        # Trainable soft bounds for the predicted log-variance (one per output dim)
        self.register_parameter(
            "max_logvar",
            nn.Parameter(torch.ones(self.output_size, device=self.device) * 0.5, requires_grad=True)
        )
        self.register_parameter(
            "min_logvar",
            nn.Parameter(torch.ones(self.output_size, device=self.device) * -10, requires_grad=True)
        )

        # Setup optimizer over all parameters; frozen encoder params receive
        # no gradients and are therefore left untouched by AdamW.
        self.optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            eps=1e-8
        )

        # For logging
        self.logging = True

    def _get_low_dim_size(self) -> int:
        """Return the last-axis size of the ``low_dim_state`` spec, or 0 if absent."""
        low_dim_state_spec = extract_from_spec(
            self.observation_space, "low_dim_state", missing_ok=True
        )
        if low_dim_state_spec is not None:
            return low_dim_state_spec.shape[-1]
        return 0

    def _get_eef_dim_size(self) -> int:
        """Return the last-axis size of the ``eef`` spec, or 0 if absent."""
        eef_spec = extract_from_spec(
            self.observation_space, "eef", missing_ok=True
        )
        if eef_spec is not None:
            return eef_spec.shape[-1]
        return 0

    def build_encoder(self):
        """Build the pixel and/or low-dimensional encoders.

        Both encoders are created frozen (``requires_grad_(False)``).
        """
        rgb_spaces = self.rgb_spaces
        if len(rgb_spaces) > 0:
            # Build encoder for pixels
            rgb_shapes = [s.shape for s in rgb_spaces.values()]
            assert np.all(
                [sh == rgb_shapes[0] for sh in rgb_shapes]
            ), "Expected all RGB obs to be same shape."

            num_views = len(rgb_shapes)
            # Drop the leading axis of each view's spec shape.
            obs_shape = rgb_shapes[0][1:]
            self.pixel_encoder = self.pixel_encoder_model(
                input_shape=(num_views, *obs_shape)
            )
            self.pixel_encoder.to(self.device)
            if self.use_torch_compile:
                self.pixel_encoder = torch.compile(self.pixel_encoder)
            self.trainable_modules.append(self.pixel_encoder)
            self.pixel_encoder.requires_grad_(requires_grad=False)

        self.low_dim_latent_size = 0
        if self.low_dim_size > 0:
            # Build encoder for low_dim_obs
            self.low_dim_encoder = BatchTimeInputDictWrapperModule(
                self.low_dim_encoder_model(
                    input_shapes={"low_dim_obs": (self.low_dim_size,)},
                ),
                "low_dim_obs",
            )
            self.low_dim_encoder.to(self.device)
            if self.use_torch_compile:
                self.low_dim_encoder = torch.compile(self.low_dim_encoder)
            self.low_dim_latent_size = self.low_dim_encoder.output_shape[-1]
            self.trainable_modules.append(self.low_dim_encoder)
            self.low_dim_encoder.requires_grad_(requires_grad=False)

    def build_view_fusion(self):
        """Build the view fusion for multi-view RGB observations (like DreamerV3).

        Raises:
            ValueError: if ``use_pixels`` is set but no ``rgb*`` observation
                exists (so no pixel encoder was built).
        """
        self.rgb_latent_size = 0
        if not self.use_pixels:
            return
        if self.pixel_encoder is None:
            raise ValueError(
                "use_pixels=True but the observation space has no 'rgb*' entries; "
                "set use_pixels=False for low-dim-only setups."
            )
        if self.use_multicam_fusion:
            self.view_fusion = FusionMultiCamFeature(
                input_shape=self.pixel_encoder.output_shape,
                mode="flatten",
            )
            self.view_fusion.to(self.device)
            if self.use_torch_compile:
                self.view_fusion = torch.compile(self.view_fusion)
            self.rgb_latent_size = self.view_fusion.output_shape[-1]
        else:
            # Single view: select view index 0 on the second-to-last axis.
            # Input can be [B, T, ..] or [B, ..]
            self.view_fusion = lambda x: x[..., 0, :]
            self.rgb_latent_size = self.pixel_encoder.output_shape[-1]

    def _compute_feat_dim(self) -> int:
        """Compute the SADM input size: fused RGB latent + low-dim latent."""
        return self.rgb_latent_size + self.low_dim_latent_size

    def _build_dynamics_model(self) -> SADModel:
        """Build the SADM dynamics model with the fused feature as input."""
        model = SADModel(
            in_dim=self.feat_dim,
            out_dim=self.output_size,
            action_dim=self.action_dim,
            hidden_dim=self.hidden_dim,
            rnn_num_layers=self.rnn_num_layers,
            dropout=self.dropout,
            device=self.device
        )

        return model

    def symlog(self, x: torch.Tensor) -> torch.Tensor:
        """Apply symlog transformation: sign(x) * log(|x| + 1)"""
        # log1p is more accurate than log(1 + |x|) for small |x|.
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def symexp(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of symlog: sign(x) * (exp(|x|) - 1)"""
        # expm1 is more accurate than exp(|x|) - 1 for small |x|.
        return torch.sign(x) * torch.expm1(torch.abs(x))

    def extract_batch(
        self, replay_iter: Iterator[Dict[str, torch.Tensor]]
    ) -> tuple[
        Dict[str, torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        torch.Tensor,
        Optional[torch.Tensor],
    ]:
        """
        Pull one batch from ``replay_iter``, move it to ``self.device`` and split it.

        Low-dim and eef observations are symlog-transformed when ``use_symlog``.

        Returns:
            batch: Full batch dictionary (on device)
            rgb_obs: RGB observations without the last timestep (or None)
            low_dim_obs: low-dim observations without the last timestep (or None)
            low_dim_obs_tp1: low-dim observations without the first timestep (or None)
            actions: ``batch['action'][:, 1:]``
            eef_obs: full eef observation sequence, NOT time-shifted (or None)

        NOTE(review): actions are taken from index 1 onward, i.e. obs at t is
        paired with the action stored at t+1 — presumably the buffer stores the
        action that led into each observation; confirm.
        """
        batch = next(replay_iter)
        batch = {k: v.to(self.device) for k, v in batch.items()}

        # Extract RGB observations if available (keys excluding "tp1" variants)
        rgb_obs = None
        if self.use_pixels:
            rgb_dict = extract_many_from_batch(batch, r"rgb(?!.*?tp1)")
            rgb_obs = stack_tensor_dictionary(rgb_dict, 2)  # views stacked on axis 2

        # Extract low-dim observations if available
        low_dim_obs = None
        if self.low_dim_size > 0:
            low_dim_obs = extract_from_batch(batch, 'low_dim_state')
            if self.use_symlog:
                low_dim_obs = self.symlog(low_dim_obs)

        eef_obs = None
        if self.eef_dim_size > 0:
            eef_obs = extract_from_batch(batch, 'eef')
            if self.use_symlog:
                eef_obs = self.symlog(eef_obs)

        # Extract actions
        actions = batch['action']

        ret_rgb = None if rgb_obs is None else rgb_obs[:, :-1]
        ret_low_dim = None if low_dim_obs is None else low_dim_obs[:, :-1]
        ret_low_dim_tp1 = None if low_dim_obs is None else low_dim_obs[:, 1:]
        return batch, ret_rgb, ret_low_dim, ret_low_dim_tp1, actions[:, 1:], eef_obs

    def encode(
        self, low_dim_obs: Optional[torch.Tensor], rgb_obs: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Encode raw low dimensional / RGB observations into embeddings (like DreamerV3)

        Inputs for which no encoder was built (low-dim without a low-dim
        encoder, RGB with ``use_pixels=False``) are ignored, matching
        ``feat_dim`` which excludes them.

        Args:
            low_dim_obs: low dimensional observations (or None)
            rgb_obs: multi-view RGB observations (or None)

        Returns:
            embed: Concatenation of the available embeddings along the last axis
            outs_dict: Dictionary containing each embedding

        Raises:
            ValueError: if no usable observation was provided.
        """
        outs, outs_dict = [], dict()
        if low_dim_obs is not None and self.low_dim_encoder is not None:
            low_dim_latent = self.low_dim_encoder(low_dim_obs)
            outs.append(low_dim_latent)
            outs_dict["low_dim_latent"] = low_dim_latent
        if rgb_obs is not None and self.use_pixels:
            multi_view_feats = self.pixel_encoder(rgb_obs.float())
            fused_view_feats = self.view_fusion(multi_view_feats)
            outs.append(fused_view_feats)
            outs_dict["fused_view_feats"] = fused_view_feats
        if not outs:
            raise ValueError(
                "encode() got no observation with a matching encoder to embed."
            )
        return torch.cat(outs, -1), outs_dict

    def _obs_loss(
        self, pred: torch.Tensor, target: torch.Tensor, logvar: torch.Tensor
    ):
        """Return ``(obs_loss, mse_loss, var_loss)`` for one prediction group.

        With ``use_var`` this is a Gaussian NLL (up to constants): the squared
        error weighted by ``exp(-logvar)`` plus the mean log-variance.
        Otherwise it is plain MSE and ``var_loss`` is 0.
        """
        sq_err = torch.pow(pred - target, 2)
        if not self.use_var:
            mse_loss = sq_err.mean()
            return mse_loss, mse_loss, 0
        mse_loss = (sq_err * torch.exp(-logvar)).mean()
        var_loss = logvar.mean()
        return mse_loss + var_loss, mse_loss, var_loss

    def compute_loss(
        self,
        rgb_obs: Optional[torch.Tensor],
        low_dim_obs: Optional[torch.Tensor],
        low_dim_obs_tp1: Optional[torch.Tensor],
        actions: torch.Tensor,
        eef_obs_tp1: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute the training loss for low-dim observation prediction.

        Only the first timestep of the observations is encoded; SADM predicts
        the trajectory from that embedding and the action sequence. Each
        predicted group (eef, and low_dim_state when ``use_qpos_pred``) is
        scored against its own slice of the predicted log-variance.

        Args:
            rgb_obs: RGB observations over time (or None)
            low_dim_obs: low-dim observations over time (or None)
            low_dim_obs_tp1: next low-dim observations (target when use_qpos_pred)
            actions: action tensor; ``actions[:, :, 0]`` is fed to SADM
            eef_obs_tp1: next eef observations (target), or None if no eef obs

        Returns:
            total_loss: ``obs_loss_scale`` * sum of the per-group losses
            metrics: Dictionary of (unscaled) loss components
        """
        # Encode observations to get fused features
        first_low_dim_obs = None if low_dim_obs is None else low_dim_obs[:, 0]
        first_rgb_obs = None if rgb_obs is None else rgb_obs[:, 0]
        embeds = self.encode(first_low_dim_obs, first_rgb_obs)[0]
        # NOTE(review): indexing a third axis implies actions carry an extra
        # per-step (action chunk) axis; only its first entry is used — confirm.
        action_seq = actions[:, :, 0]
        # Forward pass through SADM
        model_out = self.dynamics_model.forward_all(embeds, action_seq)
        mean, logvar = torch.chunk(model_out, 2, dim=-1)
        logvar = soft_clamp(logvar, self.min_logvar, self.max_logvar)

        # (metric prefix, size, target) per predicted group, in output order.
        groups = [("eef", self.eef_dim_size, eef_obs_tp1)]
        if self.use_qpos_pred:
            groups.append(("q", self.low_dim_size, low_dim_obs_tp1))
        sizes = [size for _, size, _ in groups]
        # Split the log-variance exactly like the mean so each group is
        # weighted by its own variance entries.
        means = torch.split(mean, sizes, dim=-1)
        logvars = torch.split(logvar, sizes, dim=-1)

        total_loss = None
        metrics = {}
        for (prefix, size, target), group_mean, group_logvar in zip(groups, means, logvars):
            if size == 0:
                # Nothing to predict for this group (e.g. no 'eef' observation).
                continue
            obs_loss, mse_loss, var_loss = self._obs_loss(group_mean, target, group_logvar)
            total_loss = obs_loss if total_loss is None else total_loss + obs_loss
            metrics[f"{prefix}_obs_loss"] = obs_loss.item()
            metrics[f"{prefix}_mse_loss"] = mse_loss.item()
            metrics[f"{prefix}_var_loss"] = var_loss.item() if self.use_var else 0

        # output_size > 0 (asserted in __init__) guarantees at least one group.
        total_loss = self.obs_loss_scale * total_loss
        return total_loss, metrics

    def update(
        self,
        replay_iter: Iterator[Dict[str, torch.Tensor]],
        step: int,
        replay_buffer: ReplayBuffer = None,
    ) -> Dict[str, float]:
        """
        Perform one training update.

        Args:
            replay_iter: Iterator from replay buffer
            step: Current training step (unused here)
            replay_buffer: Replay buffer (optional, unused here)

        Returns:
            metrics: Dictionary of training metrics
        """
        metrics = {}
        # Extract batch
        batch, rgb_obs, low_dim_obs, low_dim_obs_tp1, actions, eef_obs = self.extract_batch(replay_iter)
        # extract_batch returns the un-shifted eef sequence; targets start at t+1.
        eef_obs_tp1 = eef_obs[:, 1:] if eef_obs is not None else None

        # Compute loss
        loss, loss_metrics = self.compute_loss(rgb_obs, low_dim_obs, low_dim_obs_tp1, actions, eef_obs_tp1)
        metrics.update(loss_metrics)

        # Optimization step
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (SADM parameters only)
        if self.grad_clip is not None:
            grad_norm = nn.utils.clip_grad_norm_(
                self.dynamics_model.parameters(),
                self.grad_clip
            )
            if self.logging:
                metrics['grad_norm'] = grad_norm.item()

        self.optimizer.step()

        return metrics

    def reset(self, step: int, agents_to_reset: list[int]):
        """
        Reset method (not used for dynamics model training).

        Args:
            step: Current step
            agents_to_reset: List of agent indices to reset
        """
        pass

    def step_multi(
        self, rgb_obs, low_dim_obs, action_sequence: torch.Tensor
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Predict future observations given the current obs and an action sequence.

        Runs without gradients with the SADM and encoders in eval mode; each
        module's previous train/eval mode is restored afterwards, even on error.

        Args:
            rgb_obs: current RGB observations (or None)
            low_dim_obs: current low-dim observations, un-transformed (or None)
            action_sequence: action sequence passed to ``forward_all``
        Returns:
            next_eef_pred: predicted eef observations (symexp'd if use_symlog)
            next_qpos_pred: predicted low_dim_state, or None unless use_qpos_pred
        """
        modules = [
            m for m in (self.dynamics_model, self.pixel_encoder, self.low_dim_encoder)
            if m is not None
        ]
        previous_modes = [m.training for m in modules]
        for m in modules:
            m.eval()

        try:
            with torch.no_grad():
                # Match the training-time input transform.
                if self.use_symlog and low_dim_obs is not None:
                    low_dim_obs = self.symlog(low_dim_obs)

                # Encode to get features
                embed = self.encode(low_dim_obs, rgb_obs)[0]

                # Forward through SADM; only the mean half is used here.
                model_out = self.dynamics_model.forward_all(embed, action_sequence)
                mean, _ = torch.chunk(model_out, 2, dim=-1)
                next_low_dim_pred = self.symexp(mean) if self.use_symlog else mean
                if self.use_qpos_pred:
                    next_eef_pred, next_qpos_pred = torch.split(
                        next_low_dim_pred, [self.eef_dim_size, self.low_dim_size], dim=-1
                    )
                else:
                    next_eef_pred, next_qpos_pred = next_low_dim_pred, None
        finally:
            for m, was_training in zip(modules, previous_modes):
                m.train(was_training)

        return next_eef_pred, next_qpos_pred

    
    # def save(self, filepath: str):
    #     """Save model checkpoint."""
    #     checkpoint = {
    #         'dynamics_model_state_dict': self.dynamics_model.state_dict(),
    #         'optimizer_state_dict': self.optimizer.state_dict(),
    #         'low_dim_size': self.low_dim_size,
    #         'action_dim': self.action_dim,
    #         'feat_dim': self.feat_dim,
    #     }
        
    #     # Save encoder states if they exist
    #     if self.pixel_encoder is not None:
    #         checkpoint['pixel_encoder_state_dict'] = self.pixel_encoder.state_dict()
    #     if self.low_dim_encoder is not None:
    #         checkpoint['low_dim_encoder_state_dict'] = self.low_dim_encoder.state_dict()
        
    #     torch.save(checkpoint, filepath)
    #     print(f"Model saved to {filepath}")
    
    # def load(self, filepath: str):
    #     """Load model checkpoint."""
    #     checkpoint = torch.load(filepath, map_location=self.device)
        
    #     self.dynamics_model.load_state_dict(checkpoint['dynamics_model_state_dict'])
    #     self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
    #     # Load encoder states if they exist
    #     if 'pixel_encoder_state_dict' in checkpoint and self.pixel_encoder is not None:
    #         self.pixel_encoder.load_state_dict(checkpoint['pixel_encoder_state_dict'])
    #     if 'low_dim_encoder_state_dict' in checkpoint and self.low_dim_encoder is not None:
    #         self.low_dim_encoder.load_state_dict(checkpoint['low_dim_encoder_state_dict'])
        
    #     print(f"Model loaded from {filepath}")
