# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional

import numpy as np
import torch
from pydantic import Field
from scipy.spatial.transform import Rotation as R

from gr00t.data.transform.base import ModalityTransform


class DvrkActionTransform(ModalityTransform):
    """
    Convert absolute DVRK pose actions into actions relative to the robot state.

    Each input row (state or action) is laid out as
    ``[x, y, z, qx, qy, qz, qw, jaw_angle]``. Each output action row is
    ``[dx, dy, dz, rx, ry, rz, jaw_angle]``, where ``(rx, ry, rz)`` is the
    relative rotation as a rotation vector (axis-angle) and the jaw angle is
    passed through unchanged (it stays absolute).
    """

    apply_to: List[str] = Field(
        ..., description="List of action keys to apply the transform to."
    )

    state_keys: List[str] = Field(
        ..., description="List of state keys corresponding to the action keys."
    )

    def __call__(self, data: Dict) -> Dict:
        """Apply the transform to the data."""
        return self.apply(data)

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array, moving torch tensors to the CPU first."""
        if isinstance(value, torch.Tensor):
            # detach() so tensors that require grad can be converted as well.
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def apply(self, data: Dict) -> Dict:
        """
        Replace each absolute action in ``data`` with its relative counterpart.

        ``apply_to[i]`` is paired with ``state_keys[i]``. A pair is skipped when
        either key is missing from ``data``. Action timestep ``t`` is expressed
        relative to state timestep ``t``; when the state sequence is shorter
        than the action sequence, the last state is reused for the remaining
        actions.

        Args:
            data: Dictionary containing state and action data. States and
                actions are arrays/tensors of shape ``[T, 8]``.

        Returns:
            The same dictionary, with each processed action key replaced by a
            float64 NumPy array of shape ``[T, 7]``.

        Raises:
            ValueError: If ``apply_to`` and ``state_keys`` differ in length, or
                if a state sequence is empty while its action sequence is not.
        """
        if len(self.apply_to) != len(self.state_keys):
            # zip() would silently drop the unmatched keys and leave those
            # actions absolute, so fail loudly instead.
            raise ValueError(
                "apply_to and state_keys must have the same length, got "
                f"{len(self.apply_to)} and {len(self.state_keys)}."
            )

        for action_key, state_key in zip(self.apply_to, self.state_keys):
            if action_key not in data or state_key not in data:
                continue

            # scipy's Rotation works on NumPy arrays, so convert up front.
            state_np = self._to_numpy(data[state_key])
            action_np = self._to_numpy(data[action_key])

            num_actions = action_np.shape[0]
            if num_actions == 0:
                # Nothing to convert; keep the output layout consistent.
                data[action_key] = np.zeros((0, 7))
                continue

            num_states = state_np.shape[0]
            if num_states == 0:
                raise ValueError(
                    f"State '{state_key}' is empty; cannot compute relative "
                    f"actions for '{action_key}'."
                )

            # Pair action t with state t, clamping to the last available state.
            state_index = np.minimum(np.arange(num_actions), num_states - 1)

            # The result is left as a NumPy array (not converted back to a tensor).
            data[action_key] = self.compute_diff_actions(
                state_np[state_index], action_np
            )

        return data

    @staticmethod
    def compute_diff_actions(qpos, action):
        """
        Compute actions relative to the current pose, with axis-angle rotation.

        Parameters:
        - qpos: Current pose, either a single pose of shape [8] that every
                action is compared against, or one pose per action of shape
                [n_actions x 8]. Layout: xyz, quaternion xyzw, jaw angle.
        - action: Commanded absolute actions, shape [n_actions x 8], same layout.

        Returns:
        - delta_action: float64 array of shape (n_actions, 7) laid out as
                [delta_translation (3), delta_rotation as rotvec (3), jaw_angle].
                The rotation delta is ``inv(R_qpos) * R_action``; the jaw angle
                is copied from the action and is NOT relative.
        """
        qpos = np.asarray(qpos)
        action = np.asarray(action)

        num_actions = action.shape[0]
        delta_action = np.zeros((num_actions, 7))
        if num_actions == 0:
            # Avoid constructing an empty Rotation, which not every scipy
            # version accepts.
            return delta_action

        # Translation delta; the ellipsis index lets a single [8] pose
        # broadcast against all actions and a [n_actions x 8] batch line up
        # row by row.
        delta_action[:, 0:3] = action[:, 0:3] - qpos[..., 0:3]

        # Quaternions are scalar-last (x, y, z, w), which is what from_quat expects.
        r_init = R.from_quat(qpos[..., 3:7])
        r_actions = R.from_quat(action[:, 3:7])

        # Relative rotation expressed in the frame of the current pose.
        delta_action[:, 3:6] = (r_init.inv() * r_actions).as_rotvec()

        # Jaw angle is the last column of the action and is passed through as-is.
        delta_action[:, 6] = action[:, -1]

        return delta_action