#!/usr/bin/env python

import torch
import numpy as np
from pydy.system import System

from sympy.physics.mechanics import KanesMethod
from src.multibody_sim.kinetics import *
from src.multibody_sim.visualization import create_visual_frame, plot_trajectories



# List the symbolic arguments
# ===========================

# Constants
# ---------

# NOTE: the order of this list is significant. simulate() and
# visualize_simulation() build their constants dicts with
# dict(zip(constants, num_constants[0].numpy())), so this order must match
# the column order of the numeric constants tensor.
constants = [thigh_length, thigh_com_dist, thigh_mass, thigh_inertia, 
            shank_length, shank_com_dist, shank_mass, shank_inertia,
            foot_length, foot_com_dist, foot_mass, foot_inertia,
            torso_com_dist, torso_mass, torso_inertia,
            g ]


# Time Varying
# ------------

# Generalized coordinates (q) and generalized speeds (u) handed to
# KanesMethod. visualize_simulation() uses `coordinates + speeds` as the
# state symbols, so this order also defines the column layout of the
# integrated state trajectory.
coordinates = [x, y, theta_pelvis, r_theta_hip, r_theta_knee, r_theta_ankle,
                l_theta_hip, l_theta_knee, l_theta_ankle]

speeds = [hip_vel_x, hip_vel_y, omega_torso, r_omega_hip, r_omega_knee, r_omega_ankle,
            l_omega_hip, l_omega_knee, l_omega_ankle]


# ground_frame and kinematical_differential_equations come from the star
# import of src.multibody_sim.kinetics (presumably the inertial frame and
# the q' <-> u relations -- confirm there).
kane = KanesMethod(ground_frame,
                   coordinates,
                   speeds,
                   kinematical_differential_equations)

# External loads applied to the model. Judging by their names, these are
# ground reaction forces, per-segment gravity forces and per-segment
# torques; their definitions live in src.multibody_sim.kinetics.
loads = [right_ground_force,
        left_ground_force,
        torso_grav_force,
        right_thigh_grav_force,
        right_shank_grav_force,
        right_foot_grav_force,
        left_thigh_grav_force,
        left_shank_grav_force,
        left_foot_grav_force,
        torso_torque,
        right_thigh_torque,
        right_shank_torque,
        right_foot_torque,
        left_thigh_torque,
        left_shank_torque,
        left_foot_torque]

# Rigid bodies of the model; kanes_equations returns the generalized active
# forces (fr) and generalized inertia forces (frstar).
bodies = [torso, right_thigh, right_shank, right_foot, left_thigh, left_shank, left_foot]
fr, frstar = kane.kanes_equations(bodies, loads)

# Mass matrix and forcing vector of the full first-order system (kinematic
# and dynamic equations together), per sympy's KanesMethod.mass_matrix_full /
# forcing_full. They are not referenced in this part of the file; presumably
# they are consumed by importers of this module.
mass_matrix = kane.mass_matrix_full
forcing_vector = kane.forcing_full


## Given the forces and torques, calculates
#  the forward dynamics and trajectories
def simulate(y_d, y_dl, num_constants, hip_data, initial_conds, duration):
    """Integrate the model's forward dynamics driven by recorded input data.

    Joint torques, ground reaction forces and centre-of-pressure (COP)
    positions are fed to the pydy ``System`` as piecewise-constant
    (zero-order-hold) signals sampled from the input tensors.

    Parameters
    ----------
    y_d, y_dl : torch.Tensor
        Right- and left-side data, indexed ``[batch, sample, channel]``.
        Only batch element 0 is read. Channels 3, 4 and 5 drive the hip,
        knee and ankle torques; channels 6 and 7 drive the ground force x / y.
        ``y_d.shape[1]`` sets the number of output time points.
    num_constants : torch.Tensor
        Numeric model constants, one row per batch element, with columns in
        the order of the module-level ``constants`` list. All rows must be
        identical.
    hip_data : torch.Tensor
        Indexed ``[batch, sample, channel]``. The last four channels of batch
        element 0 are read as right COP x/y and left COP x/y.
    initial_conds
        Assigned as-is to the system's ``initial_conditions``.
    duration : float
        Simulated time span; the output times are evenly spaced over
        ``[0, duration]``.

    Returns
    -------
    y_traj : numpy.ndarray
        Integrated state trajectory (one row per output time), with the four
        COP channels appended as extra columns.
    system : pydy.system.System
        The configured system that was integrated.

    Raises
    ------
    ValueError
        If the rows of ``num_constants`` are not all identical.
    """
    # All rows of num_constants must describe the same subject/model.
    first_tensor = num_constants[0]
    are_identical = torch.all(torch.eq(num_constants, first_tensor))
    if not are_identical:
        raise ValueError("The data do not belong the same subject, constants are different !")

    nSample = y_d.shape[1]

    # Map every symbolic constant to its numeric value (row 0).
    constants_dict = dict(zip(constants, num_constants[0].numpy()))

    # Named `system` rather than `sys` to avoid shadowing the stdlib module.
    system = System(kane)

    # The foot length and foot COM distance are not needed by the Kane
    # equations, so both are left out of the system's constants.
    system.constants = {
        sym: val
        for sym, val in constants_dict.items()
        if sym not in (foot_length, foot_com_dist)
    }

    system.times = np.linspace(0.0, duration, num=nSample)
    system.initial_conditions = initial_conds

    # Samples per unit time: converts an integration time t to a sample index.
    fSample = (nSample - 1) / duration

    # Symbol -> (source tensor, channel) for every time-varying input.
    signal_sources = {
        right_hip_torque: (y_d, 3),
        right_knee_torque: (y_d, 4),
        right_ankle_torque: (y_d, 5),

        left_hip_torque: (y_dl, 3),
        left_knee_torque: (y_dl, 4),
        left_ankle_torque: (y_dl, 5),

        right_ground_force_x: (y_d, 6),
        right_ground_force_y: (y_d, 7),

        left_ground_force_x: (y_dl, 6),
        left_ground_force_y: (y_dl, 7),

        right_center_of_press_x: (hip_data, -4),
        right_center_of_press_y: (hip_data, -3),

        left_center_of_press_x: (hip_data, -2),
        left_center_of_press_y: (hip_data, -1),
    }
    # NOTE: pydy also accepts the array form {'symbols': ..., 'values': ...}
    # for specifieds, but it did not work for this model, so one callable
    # per symbol is used instead.
    system.specifieds = {
        sym: _zoh_signal(data, channel, fSample, nSample)
        for sym, (data, channel) in signal_sources.items()
    }

    y_traj = system.integrate()

    # Append the COP samples for visualisation. Use batch element 0 only, to
    # match the specifieds above and the row count of y_traj; flattening all
    # batch elements would break the concatenation when batch > 1.
    cop = hip_data[0, :, -4:].numpy()
    y_traj = np.concatenate((y_traj, cop), axis=1)

    return y_traj, system


def _zoh_signal(data, channel, f_sample, n_sample):
    """Return a pydy ``specifieds`` callable ``f(x, t)`` for ``data[0, :, channel]``.

    The state ``x`` is ignored. At time ``t`` the callable returns the sample
    at index ``int(t * f_sample)`` (zero-order hold). The index is clamped to
    ``n_sample - 1`` because ODE solvers such as scipy's ``odeint`` may
    evaluate the right-hand side slightly past the last requested time.
    """
    last_index = n_sample - 1

    def signal(x, t):
        index = min(int(t * f_sample), last_index)
        return data[0, index, channel].numpy()

    return signal


## method for visualisation of simulation results 
def visualize_simulation(y_traj, num_constants, t, slower=1):
    """Build a pydy scene that replays a simulated trajectory.

    Parameters
    ----------
    y_traj : array-like
        State trajectory whose columns follow ``coordinates``, then
        ``speeds``, then the four centre-of-pressure channels (right x/y,
        left x/y).
    num_constants : torch.Tensor
        Numeric constants; row 0 is mapped onto the module-level
        ``constants`` list, in that order.
    t : array-like
        Time stamps of the trajectory samples.
    slower : number, optional
        Factor multiplied into ``t``; values above 1 stretch the timeline.

    Returns
    -------
    The scene created by ``create_visual_frame``, with its constants, times,
    state symbols and state trajectories assigned.
    """
    scene = create_visual_frame(num_constants)

    row_values = num_constants[0].numpy()
    scene.constants = {symbol: value for symbol, value in zip(constants, row_values)}

    scene.times = t * slower

    # The extra state columns carry the centre-of-pressure samples.
    cop_symbols = [right_center_of_press_x, right_center_of_press_y,
                   left_center_of_press_x, left_center_of_press_y]
    scene.states_symbols = [*coordinates, *speeds, *cop_symbols]
    scene.states_trajectories = y_traj

    return scene









# # defining some additional variables 
# der_hip_vel_x, der_hip_vel_y = hip_vel_x.diff(), hip_vel_y.diff()


# specified = [right_hip_torque, right_knee_torque, right_ankle_torque, 
#              right_ground_force_x, right_ground_force_y,
#              left_hip_torque, left_knee_torque, left_ankle_torque, 
#              left_ground_force_x, left_ground_force_y]

# defining some variables that are K_loss input data;
# this data is based on the equation_of_motion() function 

# der_x, der_y = x.diff(), y.diff()
# der_hip_vel_x, der_hip_vel_y = hip_vel_x.diff(), hip_vel_y.diff()

# der_r_theta_hip, der_r_theta_knee, der_r_theta_ankle = r_theta_hip.diff(), r_theta_knee.diff(), r_theta_ankle.diff()
# der_r_omega_hip, der_r_omega_knee, der_r_omega_ankle = r_omega_hip.diff(), r_omega_knee.diff(), r_omega_ankle.diff()

# der_l_theta_hip, der_l_theta_knee, der_l_theta_ankle = l_theta_hip.diff(), l_theta_knee.diff(), l_theta_ankle.diff()
# der_l_omega_hip, der_l_omega_knee, der_l_omega_ankle = l_omega_hip.diff(), l_omega_knee.diff(), l_omega_ankle.diff()


# symbols_hip_data = [theta_pelvis,
#                     x, y,
#                     der_x, der_y,                                      # this pair &
#                     hip_vel_x, hip_vel_y,                              # this pair are actually the same quantity
#                     der_hip_vel_x, der_hip_vel_y,
#                     right_center_of_press_x, right_center_of_press_y,
#                     left_center_of_press_x, left_center_of_press_y]  

# symbols_est_data = [r_theta_hip, r_theta_knee, r_theta_ankle,
#                     right_hip_torque, right_knee_torque, right_ankle_torque, 
#                     right_ground_force_x, right_ground_force_y]

