#
# MIT License
#
# 
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
import torch.autograd.profiler as profiler

from ...differentiable_robot_model.coordinate_transform import matrix_to_quaternion, quaternion_to_matrix, CoordinateTransform, multiply_transform, multiply_inv_transform
from ..cost import DistCost, PoseCost, ZeroCost, FiniteDifferenceCost
from ...mpc.rollout.arm_base import ArmBase
from ...mpc.model.integration_utils import build_int_matrix, build_fd_matrix, tensor_step_acc, tensor_step_acc_offset_vel
from ...mpc.model.urdf_kinematic_model import URDFKinematicModel
from ...mpc.model.kinematic_model import KinematicModel
from ...util_file import join_path, get_assets_path

class ArmBalance(ArmBase):
    """
    Rollout function for reaching a Cartesian end-effector pose while keeping
    an object balanced on the end effector.

    On top of the ``ArmBase`` costs this adds a pose-reaching cost, an
    optional joint-space L2 cost, optional zero-velocity / zero-acceleration
    costs, and costs on the predicted object motion. The object is modelled
    as a point that slides in the end-effector x-y plane under the tangential
    component of a (scaled) gravity vector, with its end-effector-frame z
    coordinate pinned to ``exp_params['object_model']['z_offset']``.

    Todo:
    1. Update exp_params to be kwargs
    """

    def __init__(self, exp_params, tensor_args=None, world_params=None):
        """
        Parameters
        ----------
        exp_params: dict
            Experiment config. Keys read here include ``mppi.horizon``,
            ``cost.{joint_l2, goal_pose, object_goal_cost, object_vel_cost}``
            and ``object_model.{n_dofs, z_offset}``.
        tensor_args: dict, optional
            ``{'device': ..., 'dtype': ...}``. Defaults to CPU / float32; a
            fresh dict is built per instance so no default is shared.
        world_params: optional
            Forwarded unchanged to ``ArmBase``.
        """
        if tensor_args is None:
            tensor_args = {'device': "cpu", 'dtype': torch.float32}

        super(ArmBalance, self).__init__(exp_params=exp_params,
                                         tensor_args=tensor_args,
                                         world_params=world_params)
        # Targets and object state; populated by update_params().
        self.goal_state = None
        self.goal_ee_pos = None
        self.goal_ee_rot = None
        self.goal_ee_quat = None
        self.object_state_world = None
        self.object_quat_world = None
        self.object_rot_world = None
        self.object_vel_world = None
        mppi_params = exp_params['mppi']

        self.device = self.tensor_args['device']
        self.float_dtype = self.tensor_args['dtype']
        # NOTE(review): gravity is deliberately scaled by 0.1 (an earlier value
        # was 0.05); the reason is not documented here -- confirm before changing.
        self.gravity_world = 0.1 * torch.tensor([0.0, 0.0, -9.81], device=self.device, dtype=self.float_dtype)

        self.dist_cost = DistCost(**self.exp_params['cost']['joint_l2'], device=self.device, float_dtype=self.float_dtype)

        self.goal_cost = PoseCost(**exp_params['cost']['goal_pose'],
                                  tensor_args=self.tensor_args)
        self.obj_cost = DistCost(**self.exp_params['cost']['object_goal_cost'], device=self.device, float_dtype=self.float_dtype)
        self.obj_vel_cost = DistCost(**self.exp_params['cost']['object_vel_cost'], device=self.device, float_dtype=self.float_dtype)

        # Integrator used to roll the object forward in the end-effector frame.
        # (tensor_step_acc_offset_vel is an alternative that additionally
        # takes the end-effector velocity expressed in the end-effector frame.)
        self.step_fn = tensor_step_acc

        self.num_traj_points = self.dynamics_model.num_traj_points

        self._integrate_matrix = build_int_matrix(self.num_traj_points, device=self.device, dtype=self.float_dtype)

        self._fd_matrix = build_fd_matrix(self.num_traj_points, device=self.device,
                                          dtype=self.float_dtype, order=1)

        # Time-discretisation attributes mirrored from the robot dynamics model.
        self.dt_traj_params = self.dynamics_model.dt_traj_params
        self.dt = self.dynamics_model.dt
        self._dt_h = self.dynamics_model._dt_h
        self.dt_traj = self.dynamics_model.dt_traj
        self.traj_dt = self.dynamics_model.traj_dt

        self._traj_tstep = torch.matmul(self._integrate_matrix, self._dt_h)

        self.action_order = 0
        self._integrate_matrix_nth = self.dynamics_model._integrate_matrix_nth
        self._nth_traj_dt = self.dynamics_model._nth_traj_dt

        # Object model parameters.
        self.n_dofs_obj = exp_params['object_model']['n_dofs']
        self.obj_z_offset = exp_params['object_model']['z_offset']
        # Preallocated buffer handed to step_fn for every object rollout.
        self.obj_state_seq = torch.zeros((self.dynamics_model.batch_size, mppi_params['horizon'], 3 * self.n_dofs_obj),
                                         dtype=self.float_dtype, device=self.device)

    def cost_fn(self, state_dict, action_batch, no_coll=False, horizon_cost=True, return_dist=False):
        """
        Total cost of a batch of rollouts.

        Parameters
        ----------
        state_dict: dict
            Must contain ``ee_pos_seq``, ``ee_rot_seq`` and ``state_seq``.
            The optional ``object_state_seq`` (end-effector frame, as produced
            by ``rollout_object``) enables the object costs.
        action_batch: torch.Tensor
            Forwarded to ``ArmBase.cost_fn``.
        no_coll, horizon_cost: bool
            Forwarded to ``ArmBase.cost_fn``.
        return_dist: bool
            If True, return ``(cost, rot_err_norm, goal_dist)`` right after the
            goal and joint-L2 terms, i.e. WITHOUT the zero-vel/acc and object
            costs.

        Returns
        -------
        torch.Tensor, or tuple when ``return_dist`` is True.
        """
        cost = super(ArmBalance, self).cost_fn(state_dict, action_batch, no_coll, horizon_cost)
        cost_params = self.exp_params['cost']

        ee_pos_batch, ee_rot_batch = state_dict['ee_pos_seq'], state_dict['ee_rot_seq']
        state_batch = state_dict['state_seq']
        goal_state = self.goal_state

        goal_cost, rot_err_norm, goal_dist = self.goal_cost.forward(ee_pos_batch, ee_rot_batch,
                                                                    self.goal_ee_pos, self.goal_ee_rot)
        cost += goal_cost

        # Joint-space L2 distance to the goal configuration.
        if cost_params['joint_l2']['weight'] > 0.0 and goal_state is not None:
            disp_vec = state_batch[:, :, 0:self.n_dofs] - goal_state[:, 0:self.n_dofs]
            cost += self.dist_cost.forward(disp_vec)

        if return_dist:
            return cost, rot_err_norm, goal_dist

        if cost_params['zero_acc']['weight'] > 0:
            cost += self.zero_acc_cost.forward(state_batch[:, :, self.n_dofs * 2:self.n_dofs * 3], goal_dist=goal_dist)

        if cost_params['zero_vel']['weight'] > 0:
            cost += self.zero_vel_cost.forward(state_batch[:, :, self.n_dofs:self.n_dofs * 2], goal_dist=goal_dist)

        # Object costs, computed on the end-effector-frame object trajectory.
        if 'object_state_seq' in state_dict:
            n_obj = self.n_dofs_obj
            obj_state_batch = state_dict['object_state_seq']
            obj_pos_batch = obj_state_batch[:, :, 0:n_obj]
            # Desired object position: origin of the EE plane, lifted by z_offset.
            des_pos_batch = torch.zeros_like(obj_pos_batch)
            des_pos_batch[:, :, -1] = self.obj_z_offset
            cost += self.obj_cost.forward(obj_pos_batch - des_pos_batch)

            obj_vel_batch = obj_state_batch[:, :, n_obj:2 * n_obj]
            cost += self.obj_vel_cost.forward(obj_vel_batch)

        return cost

    def predict_next_obj_pose(self, object_state_world, ee_pos_world, ee_rot_world, dt):
        """
        One-step (semi-implicit Euler) prediction of the object state.

        The object is expressed in the end-effector frame, has its EE-normal
        velocity zeroed, is driven by the tangential part of gravity for one
        step of length ``dt``, has its EE-frame z position reset to
        ``obj_z_offset``, and is then mapped back to the world frame. The
        input acceleration is ignored (replaced by the tangential gravity).

        Parameters
        ----------
        object_state_world: torch.Tensor
            1-D: position, velocity, acceleration (``n_dofs_obj`` each, world
            frame) followed by a time-step entry, which is copied unchanged.
        ee_pos_world: torch.Tensor
            End-effector position in the world frame (presumably shape (3,)).
        ee_rot_world: torch.Tensor
            End-effector rotation matrix in the world frame (3, 3).
        dt: float or torch.Tensor
            Integration step.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            New 1-D object state in the world frame, and the object's world
            rotation (identity).
        """
        n_obj = self.n_dofs_obj
        object_state_world = object_state_world.to(self.device, dtype=self.float_dtype)
        object_pos_world = object_state_world[0:n_obj]
        object_vel_world = object_state_world[n_obj:2 * n_obj]
        object_rot_world = torch.eye(3, 3, device=self.device, dtype=self.float_dtype)
        # Slice instead of torch.tensor([...]) to avoid a device->host sync.
        tstep = object_state_world[-1:].detach().clone()
        ee_pos_world = ee_pos_world.to(self.device, dtype=self.float_dtype)
        ee_rot_world = ee_rot_world.to(self.device, dtype=self.float_dtype)
        ee_rot_world_t = ee_rot_world.transpose(-1, -2)

        # Tangential acceleration: gravity in the EE frame with its normal part removed.
        acc_ee = ee_rot_world_t @ self.gravity_world
        acc_ee[-1] = 0.0

        # Object position / velocity in the EE frame (normal velocity zeroed).
        object_rot_ee, object_pos_ee = multiply_inv_transform(ee_rot_world.unsqueeze(0), ee_pos_world.unsqueeze(0),
                                                              object_rot_world.unsqueeze(0), object_pos_world.unsqueeze(0))
        object_vel_ee = ee_rot_world_t @ object_vel_world
        object_vel_ee[-1] = 0.0

        # Semi-implicit Euler: update velocity first, then position with the new velocity.
        next_vel_ee = object_vel_ee + acc_ee * dt
        next_pos_ee = object_pos_ee.squeeze(0) + next_vel_ee * dt
        next_pos_ee[2] = self.obj_z_offset

        # Back to the world frame.
        _, object_pos_world = multiply_transform(ee_rot_world.unsqueeze(0), ee_pos_world.unsqueeze(0),
                                                 object_rot_ee, next_pos_ee.unsqueeze(0))
        object_vel_world = ee_rot_world @ next_vel_ee
        object_acc_world = ee_rot_world @ acc_ee

        object_state_world = torch.cat((object_pos_world.squeeze(0), object_vel_world, object_acc_world, tstep))

        return object_state_world, object_rot_world

    def rollout_object(self, object_state_world, object_rot_world, ee_pos_batch_world, ee_rot_batch_world, state_batch, lin_jac_batch):
        """
        Roll the object forward over the horizon in the end-effector frame.

        The object start state is transformed into the frame of the first
        end-effector pose of every particle. The normal velocity and normal
        acceleration are zeroed. The state is then integrated with
        ``self.step_fn`` using the per-step tangential gravity, and the
        EE-frame z position is pinned to ``obj_z_offset``.

        Parameters
        ----------
        object_state_world: torch.Tensor
            (1, 3*n_dofs_obj + 1): position, velocity, acceleration (world
            frame) and a trailing time-step entry.
        object_rot_world: torch.Tensor
            Object rotation in the world frame, (1, 3, 3) or (3, 3).
        ee_pos_batch_world, ee_rot_batch_world: torch.Tensor
            (batch, horizon, 3) and (batch, horizon, 3, 3) end-effector poses.
        state_batch, lin_jac_batch:
            Unused. They are kept for interface compatibility and would only
            be needed to compute the end-effector velocity for an
            offset-velocity step function.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            The object state sequence in the END-EFFECTOR frame, as returned
            by ``self.step_fn``, and the world-frame object rotation repeated
            over the horizon as (batch, horizon, 3, 3).
            NOTE(review): step_fn receives the preallocated
            ``self.obj_state_seq``. It presumably returns that buffer, so
            later rollouts may overwrite the result -- confirm.
        """
        n_obj = self.n_dofs_obj
        object_state_world = object_state_world.to(self.device, dtype=self.float_dtype)
        ee_pos_batch_world = ee_pos_batch_world.to(self.device)
        ee_rot_batch_world = ee_rot_batch_world.to(self.device)
        batch_size, horizon, _ = ee_pos_batch_world.shape

        # Broadcast the single object start state to every particle.
        object_pos_world = object_state_world[:, 0:n_obj].repeat(batch_size, 1)
        object_vel_world = object_state_world[:, n_obj:2 * n_obj].repeat(batch_size, 1)
        object_acc_world = object_state_world[:, 2 * n_obj:3 * n_obj].repeat(batch_size, 1)
        object_rot_world = object_rot_world.squeeze(0).repeat(batch_size, 1, 1)
        tsteps = object_state_world[:, -1].repeat(batch_size, 1)

        # Tangential acceleration at every step: gravity in the EE frame, normal part removed.
        acc_ee = ee_rot_batch_world.transpose(-1, -2) @ self.gravity_world
        acc_ee[:, :, -1] = 0.0

        # Object velocity / acceleration in the first EE frame, normal components zeroed.
        start_rot_t = ee_rot_batch_world[:, 0, :, :].transpose(-1, -2)
        object_vel_ee = (start_rot_t @ object_vel_world.unsqueeze(2)).squeeze(-1)
        object_acc_ee = (start_rot_t @ object_acc_world.unsqueeze(2)).squeeze(-1)
        object_vel_ee[:, -1] = 0.0
        object_acc_ee[:, -1] = 0.0

        # Object position in the first EE frame.
        _, object_pos_ee = multiply_inv_transform(ee_rot_batch_world[:, 0, :, :], ee_pos_batch_world[:, 0, :],
                                                  object_rot_world, object_pos_world)

        obj_state_ee = torch.cat((object_pos_ee, object_vel_ee, object_acc_ee, tsteps), dim=-1)
        object_state_seq_ee = self.step_fn(obj_state_ee, acc_ee, self.obj_state_seq, self._dt_h,
                                           self.n_dofs_obj, self._integrate_matrix)

        # Keep the object on the EE plane (z is the EE-normal axis).
        object_state_seq_ee[:, :, 2] = self.obj_z_offset

        # The object's world orientation is held constant over the horizon.
        object_rot_seq_world = object_rot_world.unsqueeze(1).repeat(1, horizon, 1, 1)

        return object_state_seq_ee, object_rot_seq_world

    def rollout_fn(self, start_state, act_seq):
        """
        Return sequence of costs and states encountered
        by simulating a batch of action sequences.

        Parameters
        ----------
        start_state: torch.Tensor
            Robot start state, passed to ``dynamics_model.rollout_open_loop``.
        act_seq: torch.Tensor [num_particles, horizon, d_act]
            Batch of action sequences.

        Returns
        -------
        dict
            ``actions``, ``costs``, ``ee_pos_seq`` and ``rollout_time``
            (always 0.0).
        """
        with profiler.record_function("robot_model"):
            state_dict = self.dynamics_model.rollout_open_loop(start_state, act_seq)

        with profiler.record_function("object_model"):
            object_state_seq, object_rot_seq = self.rollout_object(self.object_state_world,
                                                                   self.object_rot_world,
                                                                   state_dict['ee_pos_seq'],
                                                                   state_dict['ee_rot_seq'],
                                                                   state_dict['state_seq'],
                                                                   state_dict['lin_jac_seq'])
            # NOTE: object_state_seq is expressed in the end-effector frame.
            state_dict['object_state_seq'] = object_state_seq
            state_dict['object_rot_seq'] = object_rot_seq

        with profiler.record_function("cost_fns"):
            cost_seq = self.cost_fn(state_dict, act_seq)

        sim_trajs = dict(
            actions=act_seq,
            costs=cost_seq,
            ee_pos_seq=state_dict['ee_pos_seq'],
            rollout_time=0.0,
        )

        return sim_trajs

    def update_params(self, retract_state=None, goal_state=None, goal_ee_pos=None, goal_ee_rot=None, goal_ee_quat=None,
                      object_state=None, object_quat=None, object_rot=None):
        """
        Update params for the cost terms and the tracked object state.

        Only non-None arguments are applied, in the order listed below. Each
        value is stored with a leading batch dimension of 1.

        retract_state: forwarded to ``ArmBase.update_params``.
        goal_ee_pos: 3. Clears ``goal_state``.
        goal_ee_rot: 3,3. Also sets ``goal_ee_quat`` and clears ``goal_state``.
        goal_ee_quat: 4. Also sets ``goal_ee_rot`` and clears ``goal_state``.
        goal_state: joint state. The first ``n_dofs`` entries (positions) and
            the next ``n_dofs`` are passed to forward kinematics, and the
            resulting pose overrides any Cartesian goal given above.
        object_state: world-frame object position, velocity, acceleration and
            a trailing time-step entry, as consumed by ``rollout_object``.
        object_quat / object_rot: world-frame object orientation. Applied
            after ``object_quat``, so ``object_rot`` wins if both are given.

        Returns
        -------
        bool
            Always True.
        """
        super(ArmBalance, self).update_params(retract_state=retract_state)

        if goal_ee_pos is not None:
            self.goal_ee_pos = torch.as_tensor(goal_ee_pos, **self.tensor_args).unsqueeze(0)
            self.goal_state = None

        if goal_ee_rot is not None:
            self.goal_ee_rot = torch.as_tensor(goal_ee_rot, **self.tensor_args).unsqueeze(0)
            self.goal_ee_quat = matrix_to_quaternion(self.goal_ee_rot)
            self.goal_state = None

        if goal_ee_quat is not None:
            self.goal_ee_quat = torch.as_tensor(goal_ee_quat, **self.tensor_args).unsqueeze(0)
            self.goal_ee_rot = quaternion_to_matrix(self.goal_ee_quat)
            self.goal_state = None

        if goal_state is not None:
            self.goal_state = torch.as_tensor(goal_state, **self.tensor_args).unsqueeze(0)
            self.goal_ee_pos, self.goal_ee_rot = self.dynamics_model.robot_model.compute_forward_kinematics(
                self.goal_state[:, 0:self.n_dofs], self.goal_state[:, self.n_dofs:2 * self.n_dofs],
                link_name=self.exp_params['model']['ee_link_name'])
            self.goal_ee_quat = matrix_to_quaternion(self.goal_ee_rot)

        if object_state is not None:
            self.object_state_world = torch.as_tensor(object_state, **self.tensor_args).unsqueeze(0)

        if object_quat is not None:
            self.object_quat_world = torch.as_tensor(object_quat, **self.tensor_args).unsqueeze(0)
            self.object_rot_world = quaternion_to_matrix(self.object_quat_world)

        if object_rot is not None:
            self.object_rot_world = torch.as_tensor(object_rot, **self.tensor_args).unsqueeze(0)
            self.object_quat_world = matrix_to_quaternion(self.object_rot_world)

        return True
    



    # def rollout_object(self, object_pos_world, object_rot_world, state_dict):
    #     #get start state
    #     start_robot_state = state_dict['start_state']
    #     start_ee_pos, start_ee_rot = self.dynamics_model.robot_model.compute_forward_kinematics(start_robot_state[:,0:self.n_dofs], start_robot_state[:,self.n_dofs:2*self.n_dofs], link_name=self.exp_params['model']['ee_link_name'])

    #     ee_pos_batch, ee_rot_batch = state_dict['ee_pos_seq'], state_dict['ee_rot_seq']

    #     inp_device = object_pos_world.device
    #     object_pos_world = object_pos_world.to(self.device, dtype=self.float_dtype)
    #     object_rot_world = object_rot_world.to(self.device, dtype=self.float_dtype)
    #     ee_pos_batch = ee_pos_batch.to(self.device, dtype=self.float_dtype)
    #     ee_rot_batch = ee_rot_batch.to(self.device, dtype=self.float_dtype)

    #     batch_size, horizon, _ = ee_pos_batch.shape
        
    #     #get gravity accelerations tangential to end effector
    #     grav_ee_batch = (ee_rot_batch.transpose(-1,-2)) @ self.gravity_world #gravity in ee frame
    #     grav_ee_batch[:,:,-1] = 0.0

    #     #transform start object pose to end effector frame
    #     # ee_rot_transpose = ee_rot_batch.transpose(-1,-2)[:,0]
    #     # object_rot_ee = ee_rot_transpose @ object_rot_world
    #     # world_trans_ee = -(ee_rot_transpose @ ee_pos_batch[:,0].unsqueeze(2)).squeeze(2)
 

    #     # object_trans_ee = (ee_rot_transpose @ object_pos_world.unsqueeze(-1)).squeeze(-1) + world_trans_ee


    #     # object_rot_ee, object_trans_ee = multiply_inv_transform(ee_rot_batch[:,0], ee_pos_batch[:,0], object_rot_world, object_pos_world)
    #     object_start_rot_ee, object_start_trans_ee = multiply_inv_transform(start_ee_rot, start_ee_pos, object_rot_world, object_pos_world)

    #     # print('shapes', object_start_rot_ee.shape, object_start_trans_ee.shape)
    #     #assume zero vel and acceleration for now
    #     object_start_vel_ee = torch.zeros_like(object_start_trans_ee)
    #     object_start_acc_ee = torch.zeros_like(object_start_trans_ee)
    #     # print('start_robot_state', start_robot_state.shape)

    #     tstep = (start_robot_state[0,-1]).unsqueeze(-1).unsqueeze(-1)

    #     # print('object trans', object_start_trans_ee.shape, tstep.shape, tstep)
    #     object_start_state = torch.cat((object_start_trans_ee, object_start_vel_ee, object_start_acc_ee, tstep), dim=-1)
    #     # print('object_start_state', object_start_state.shape)


    #     object_state_dict = self.object_dynamics_model.rollout_open_loop(object_start_state, grav_ee_batch)
    #     # print('obj state batch', object_state_dict)       
    #     # object_pos_batch = object_pos_world.repeat(batch_size, horizon, 1)
    #     # object_rot_batch = object_rot_world.repeat(batch_size, horizon, 1, 1)
    #     return object_state_dict