#
# MIT License
#
# 
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.#
from storm_kit.mpc.model.urdf_kinematic_model import URDFKinematicModel
from typing import List, Tuple, Dict, Optional, Any
import torch
import torch.autograd.profiler as profiler

from ...differentiable_robot_model.differentiable_robot_model import DifferentiableRobotModel
from urdfpy import URDF
from .model_base import DynamicsModelBase
from .urdf_kinematic_model import URDFKinematicModel
from .integration_utils import build_int_matrix, build_fd_matrix, tensor_step_acc, tensor_step_vel, tensor_step_pos, tensor_step_jerk

class SphereBalanceModel(DynamicsModelBase):
    """Dynamics model for an arm plus a sphere (object) it should balance.

    All arm dynamics are delegated to an internal ``URDFKinematicModel``.
    The sphere part is not implemented yet: ``get_next_state`` and
    ``tensor_step`` only advance the arm, and ``rollout_open_loop`` attaches
    an all-zero placeholder for the sphere trajectory.
    """

    def __init__(self, urdf_path, dt, batch_size=1000, horizon=5,
                 tensor_args=None, ee_link_name='ee_link', link_names=None,
                 dt_traj_params=None, vel_scale=0.5, control_space='acc'):
        """
        Args:
        urdf_path: path to the arm URDF, forwarded to URDFKinematicModel
        dt: base timestep; num_traj_points = int(round(horizon / dt))
        batch_size: number of trajectories handled per rollout
        horizon: rollout horizon (see num_traj_points above)
        tensor_args: dict with 'device' and 'dtype' keys; defaults to
            {'device': 'cpu', 'dtype': torch.float32}
        ee_link_name: end-effector link name, forwarded to the arm model
        link_names: extra link names forwarded to the arm model; defaults
            to a fresh empty list
        dt_traj_params: forwarded to the arm model
        vel_scale: forwarded to the arm model
        control_space: control space of the arm model (default 'acc')
        """
        # Build fresh defaults per instance: shared mutable default
        # arguments would leak any mutation across every instance.
        if tensor_args is None:
            tensor_args = {'device': 'cpu', 'dtype': torch.float32}
        if link_names is None:
            link_names = []

        # Forward control_space instead of hard-coding 'acc', so the
        # constructor argument actually takes effect.
        self.arm_model = URDFKinematicModel(urdf_path, dt, batch_size, horizon,
                                            tensor_args, ee_link_name, link_names,
                                            dt_traj_params, vel_scale,
                                            control_space=control_space)

        self.urdf_path = urdf_path
        self.device = tensor_args['device']

        self.float_dtype = tensor_args['dtype']
        self.tensor_args = tensor_args
        self.dt = dt
        self.ee_link_name = ee_link_name
        self.batch_size = batch_size
        self.horizon = horizon
        self.num_traj_points = int(round(horizon / dt))
        self.link_names = link_names
        self.control_space = control_space
        self.traj_dt = self.arm_model.traj_dt

        self.n_dofs_arm = self.arm_model.n_dofs
        self.n_dofs_obj = 3
        self.n_dofs = self.n_dofs_arm + self.n_dofs_obj

        # Sized over arm + object DOFs (3 blocks of n_dofs plus one extra
        # entry). NOTE(review): the states produced by the arm model are
        # presumably sized for n_dofs_arm only — confirm the intended layout
        # once sphere dynamics are added.
        self.d_state = 3 * self.n_dofs + 1
        self.d_action = self.arm_model.d_action

        self._traj_tstep = self.arm_model._traj_tstep

    def get_next_state(self, curr_state: torch.Tensor, act: torch.Tensor, dt):
        """ Does a single step from the current state
        Args:
        curr_state: current state
        act: action
        dt: time to integrate
        Returns:
        next_state

        Only the arm is integrated for now; the sphere is not predicted yet.
        """
        # predict state for arm; forward dt so the caller's step is honored
        next_state = self.arm_model.get_next_state(curr_state, act, dt)

        # TODO: predict state for sphere

        return next_state

    def tensor_step(self, state: torch.Tensor, act: torch.Tensor, state_seq: torch.Tensor, dt=None) -> torch.Tensor:
        """
        Args:
        state: [1,N]
        act: [H,N]
        state_seq: buffer passed through to the arm model's tensor_step
        dt: optional timestep passed through to the arm model
        Returns:
        the arm model's state sequence (sphere is not stepped yet)
        todo:
        Integration  with variable dt along trajectory
        """
        state = state.to(self.device, dtype=self.float_dtype)
        act = act.to(self.device, dtype=self.float_dtype)

        return self.arm_model.tensor_step(state, act, state_seq, dt)

    def rollout_open_loop(self, start_state: torch.Tensor, act_seq: torch.Tensor,
                          dt=None) -> Dict[str, Any]:
        """Roll the arm out open-loop and attach a placeholder sphere rollout.

        Args:
        start_state: initial state; cast to self.device / self.float_dtype
        act_seq: action sequence, presumably [batch, horizon, d_act]
            (TODO confirm); cast to self.device / self.float_dtype
        dt: optional timestep forwarded to the arm model
        Returns:
        the arm model's state dict, with an added 'sphere_state_seq' entry:
        zeros of shape [batch_size, num_traj_points, n_dofs_obj] on the
        caller's device, in self.float_dtype.
        """
        # remember the caller's device before moving inputs to the model device
        inp_device = start_state.device
        start_state = start_state.to(self.device, dtype=self.float_dtype)
        act_seq = act_seq.to(self.device, dtype=self.float_dtype)

        # rollout arm
        state_dict = self.arm_model.rollout_open_loop(start_state, act_seq, dt)

        # rollout object: no sphere dynamics yet, so emit zeros. Allocate
        # directly on the target device in the model dtype rather than
        # building a default-dtype CPU tensor and copying it over.
        # NOTE(review): sized by the configured batch_size, not by act_seq's
        # leading dimension — confirm callers always pass batch_size rollouts.
        state_dict['sphere_state_seq'] = torch.zeros(
            self.batch_size, self.num_traj_points, self.n_dofs_obj,
            device=inp_device, dtype=self.float_dtype)

        return state_dict

    def enforce_bounds(self, state_batch):
        """
            Project state into bounds (delegates to the arm model)
        """
        return self.arm_model.enforce_bounds(state_batch)

    def integrate_action(self, act_seq):
        """Integrate an action sequence using the arm model."""
        return self.arm_model.integrate_action(act_seq)

    def integrate_action_step(self, act, dt):
        """Integrate a single action over dt using the arm model."""
        return self.arm_model.integrate_action_step(act, dt)

    #Rendering
    def render(self, state):
        """Rendering is not implemented for this model."""
        pass

    def render_trajectory(self, state_list):
        """Trajectory rendering is not implemented for this model."""
        pass

