import torch
import numpy as np
from numpy import deg2rad, rad2deg
from matplotlib.pyplot import *
from sympy.physics.vector import vlatex
from pydy.viz.shapes import *
from pydy.viz.visualization_frame import VisualizationFrame
from pydy.viz import Scene
from src.multibody_sim.kinetics import *
# from src.multibody_sim.simulation import *

# Axis-label units keyed by quantity type (used for y-axis labels in plots).
# NOTE(review): 'qdot' is labelled "rad"; if it denotes an angular velocity,
# "rad/s" may be intended — confirm.
UnitDic = dict(
    angle="degree",
    qdot="rad",
    moment="Nm",
    GRF="N",
    translation="m",
)


# defining some required variables
# Symbolic body parameters. Order matters: this list is zipped position by
# position with the values of batch_data['body_constants'] in
# create_visual_frame and prepare_visualization.
constants = [
    # thigh
    thigh_length, thigh_com_dist, thigh_mass, thigh_inertia,
    # shank
    shank_length, shank_com_dist, shank_mass, shank_inertia,
    # foot
    foot_length, foot_com_dist, foot_mass, foot_inertia,
    # torso (no length symbol in this list)
    torso_com_dist, torso_mass, torso_inertia,
    # g
    g,
]

# Generalized coordinates: translation (x, y), pelvis angle, then the right
# and left hip/knee/ankle angles.
coordinates = [
    x, y, theta_pelvis,
    r_theta_hip, r_theta_knee, r_theta_ankle,
    l_theta_hip, l_theta_knee, l_theta_ankle,
]

# Generalized speeds, listed in an order mirroring ``coordinates``.
speeds = [
    hip_vel_x, hip_vel_y, omega_torso,
    r_omega_hip, r_omega_knee, r_omega_ankle,
    l_omega_hip, l_omega_knee, l_omega_ankle,
]

# Time derivatives; only used as label symbols in symbols_hip_data below.
der_x = x.diff()
der_y = y.diff()
der_hip_vel_x = hip_vel_x.diff()
der_hip_vel_y = hip_vel_y.diff()

# Label symbols for the hip-data channels plotted by plot_hip_data.
# NOTE(review): the original author remarks that (der_x, der_y) and
# (hip_vel_x, hip_vel_y) represent the same quantity — confirm in kinetics.
symbols_hip_data = [
    theta_pelvis,
    x, y,
    der_x, der_y,
    hip_vel_x, hip_vel_y,
    der_hip_vel_x, der_hip_vel_y,
    right_center_of_press_x, right_center_of_press_y,
    left_center_of_press_x, left_center_of_press_y,
]

# Label symbols for one leg's estimation outputs, in column order:
# hip/knee/ankle angles, hip/knee/ankle torques, ground force (x, y).
symbols_est_data_right = [r_theta_hip, r_theta_knee, r_theta_ankle,
                          right_hip_torque, right_knee_torque, right_ankle_torque,
                          right_ground_force_x, right_ground_force_y]

# Left-leg counterpart, same column order as the right-leg list; it must use
# the left-side symbols so plots for side="left" are labelled correctly.
symbols_est_data_left = [l_theta_hip, l_theta_knee, l_theta_ankle,
                         left_hip_torque, left_knee_torque, left_ankle_torque,
                         left_ground_force_x, left_ground_force_y]



def _com_marker_frame(shape_name, frame_name, body_frame, mass_center):
    """Return a VisualizationFrame drawing a thin black bar at a mass center.

    The marker is a 0.01 x 0.01 x 0.30 ``Box`` oriented with ``body_frame``
    and located at the point ``mass_center``.
    """
    marker = Box(name=shape_name, width=0.01, height=0.01, depth=0.30, color='black')
    return VisualizationFrame(frame_name, body_frame, mass_center, marker)


def create_visual_frame(batch_data, system=None):
    """Build the pydy Scene (shapes and visualization frames) of the full-body model.

    Drawn segment lengths (thigh, shank, foot) are looked up from
    ``batch_data['body_constants']``, whose values are matched to the
    module-level ``constants`` list by position.

    Parameters
    ----------
    batch_data : dict
        Must contain ``'body_constants'`` (torch tensor or numpy array): either a
        single constants vector or a stack of vectors, in which case all rows
        must be identical. The dict is not modified.
    system : optional
        Currently unused; accepted for API compatibility.

    Returns
    -------
    Scene
        Scene rooted at ``ground_frame`` / ``O`` holding all visualization frames.

    Raises
    ------
    ValueError
        If the stacked constants rows differ (data from different subjects).
    """
    body_constants = batch_data['body_constants']
    if isinstance(body_constants, torch.Tensor):
        # detach/cpu: Tensor.numpy() refuses tensors that require grad or live on GPU
        body_constants = body_constants.detach().cpu().numpy()
    body_constants = np.asarray(body_constants)

    # A stack of constants vectors must describe a single subject.
    if body_constants.ndim >= 2:
        if not np.all(np.equal(body_constants, body_constants[0])):
            raise ValueError("The data do not belong the same subject, constants are different !")
        reference_constants = body_constants[0]
    else:
        # A single vector is used directly; indexing [0] would give a scalar.
        reference_constants = body_constants

    constants_dict = dict(zip(constants, reference_constants))

    # Foot geometry, hard-coded for now (TODO: make these subject-specific).
    foot_height = 0.0702
    heel_dist = 0.06
    radius = foot_height / 2

    # Joint markers: black spheres positioned in the ground frame.
    r_ankle_shape = Sphere(name='r_ankle', color='black', radius=radius)
    r_knee_shape = Sphere(name='r_knee', color='black', radius=radius)
    l_ankle_shape = Sphere(name='l_ankle', color='black', radius=radius)
    l_knee_shape = Sphere(name='l_knee', color='black', radius=radius)
    hip_shape = Sphere(name='hip_shape', color='black', radius=radius)

    r_ankle_viz_frame = VisualizationFrame('r_ankle', ground_frame, right_ankle, r_ankle_shape)
    r_knee_viz_frame = VisualizationFrame('r_knee', ground_frame, right_knee, r_knee_shape)
    l_ankle_viz_frame = VisualizationFrame('l_ankle', ground_frame, left_ankle, l_ankle_shape)
    l_knee_viz_frame = VisualizationFrame('l_knee', ground_frame, left_knee, l_knee_shape)
    hip_viz_frame = VisualizationFrame('hip_frame', ground_frame, hip, hip_shape)

    # Geometric centre points of the segment shapes (not the mass centers).
    torso_center = Point('t_c')
    r_tigh_center = Point('rt_c')
    r_shank_center = Point('rs_c')
    r_foot_center = Point('rf_c')
    l_tigh_center = Point('lt_c')
    l_shank_center = Point('ls_c')
    l_foot_center = Point('lf_c')

    # NOTE(review): the torso is drawn with the thigh length because
    # `constants` has no torso-length symbol — confirm this is intended.
    torso_center.set_pos(hip, thigh_length/2 * torso_frame.y)
    r_tigh_center.set_pos(hip, thigh_length/2 * -right_thigh_frame.y)
    r_shank_center.set_pos(right_knee, shank_length/2 * -right_shank_frame.y)
    # Foot box centre: shifted forward from the ankle by half the foot length
    # minus the heel distance, and down by half the foot height.
    r_foot_center.set_pos(right_ankle, (foot_length/2 - heel_dist) * right_foot_frame.x +
                                       (-foot_height/2) * right_foot_frame.y)
    l_tigh_center.set_pos(hip, thigh_length/2 * -left_thigh_frame.y)
    l_shank_center.set_pos(left_knee, shank_length/2 * -left_shank_frame.y)
    l_foot_center.set_pos(left_ankle, (foot_length/2 - heel_dist) * left_foot_frame.x +
                                      (-foot_height/2) * left_foot_frame.y)

    # Torso
    torso_shape = Cylinder(name='torso', radius=radius, length=constants_dict[thigh_length], color='darkorange')
    torso_viz_frame = VisualizationFrame('Torso', torso_frame, torso_center, torso_shape)

    # Thighs
    r_thigh_shape = Cylinder(name='r_thigh', radius=radius, length=constants_dict[thigh_length], color='blue')
    r_thigh_viz_frame = VisualizationFrame('R_Thigh', right_thigh_frame, r_tigh_center, r_thigh_shape)
    l_thigh_shape = Cylinder(name='l_thigh', radius=radius, length=constants_dict[thigh_length], color='blue')
    l_thigh_viz_frame = VisualizationFrame('L_Thigh', left_thigh_frame, l_tigh_center, l_thigh_shape)

    # Shanks
    r_shank_shape = Cylinder(name='r_shank', radius=radius, length=constants_dict[shank_length], color='green')
    r_shank_viz_frame = VisualizationFrame('R_Shank', right_shank_frame, r_shank_center, r_shank_shape)
    l_shank_shape = Cylinder(name='l_shank', radius=radius, length=constants_dict[shank_length], color='green')
    l_shank_viz_frame = VisualizationFrame('L_Shank', left_shank_frame, l_shank_center, l_shank_shape)

    # Feet
    r_foot_shape = Box(name='r_foot', width=constants_dict[foot_length], height=foot_height, depth=0.08, color='red')
    r_foot_viz_frame = VisualizationFrame('R_Foot', right_foot_frame, r_foot_center, r_foot_shape)
    l_foot_shape = Box(name='l_foot', width=constants_dict[foot_length], height=foot_height, depth=0.08, color='red')
    l_foot_viz_frame = VisualizationFrame('L_Foot', left_foot_frame, l_foot_center, l_foot_shape)

    # Centers of pressure
    r_cop_shape = Sphere(name='r_cop', color='black', radius=0.01)
    r_cop_viz_frame = VisualizationFrame('R_CoP', right_foot_frame, right_ground_reaction_point, r_cop_shape)
    l_cop_shape = Sphere(name='l_cop', color='yellow', radius=0.01)
    l_cop_viz_frame = VisualizationFrame('L_CoP', left_foot_frame, left_ground_reaction_point, l_cop_shape)

    # Mass-center markers for every segment
    r_foot_com_viz_frame = _com_marker_frame('rf_com', 'Rf_CoM', right_foot_frame, right_foot_mass_center)
    l_foot_com_viz_frame = _com_marker_frame('lf_com', 'Lf_CoM', left_foot_frame, left_foot_mass_center)
    r_shank_com_viz_frame = _com_marker_frame('rs_com', 'Rs_CoM', right_shank_frame, right_shank_mass_center)
    l_shank_com_viz_frame = _com_marker_frame('ls_com', 'Ls_CoM', left_shank_frame, left_shank_mass_center)
    r_thigh_com_viz_frame = _com_marker_frame('rt_com', 'Rt_CoM', right_thigh_frame, right_thigh_mass_center)
    l_thigh_com_viz_frame = _com_marker_frame('lt_com', 'Lt_CoM', left_thigh_frame, left_thigh_mass_center)
    torso_com_viz_frame = _com_marker_frame('t_com', 'T_CoM', torso_frame, torso_mass_center)

    # NOTE: `system` is accepted but not currently forwarded to Scene.
    scene = Scene(ground_frame, O, name='biomechEst_fullBody')

    scene.visualization_frames = [r_ankle_viz_frame,
                                  r_knee_viz_frame,
                                  l_ankle_viz_frame,
                                  l_knee_viz_frame,
                                  torso_viz_frame,
                                  hip_viz_frame,
                                  r_thigh_viz_frame,
                                  r_shank_viz_frame,
                                  r_foot_viz_frame,
                                  r_cop_viz_frame,
                                  l_thigh_viz_frame,
                                  l_shank_viz_frame,
                                  l_foot_viz_frame,
                                  l_cop_viz_frame,
                                  r_foot_com_viz_frame,
                                  r_shank_com_viz_frame,
                                  r_thigh_com_viz_frame,
                                  l_foot_com_viz_frame,
                                  l_shank_com_viz_frame,
                                  l_thigh_com_viz_frame,
                                  torso_com_viz_frame,
                                  ]

    return scene


def prepare_visualization(batch_data, t, scene, slower=1):
    """Attach constants, the time vector and state trajectories to a pydy Scene.

    Parameters
    ----------
    batch_data : dict
        Batched data as torch tensors or numpy arrays. Keys read:
        ``'body_constants'``, ``'translation_data'``, ``'angles'``,
        ``'qdot_translation_data'``, ``'qdot_angles'``, ``'right_cops'`` and
        ``'left_cops'``. The dict itself is not modified.
    t : numpy.ndarray or torch.Tensor
        Time vector; ``t * slower`` is stored as ``scene.times``.
    scene : Scene
        Scene to configure (e.g. from ``create_visual_frame``); mutated in place.
    slower : float, optional
        Factor applied to the time vector (presumably > 1 slows the animation).

    Returns
    -------
    scene : Scene
        The same scene object, with constants, times and states set.
    y_traj : numpy.ndarray
        Trajectories for the whole batch, concatenated along the last axis;
        the columns are expected to line up with ``scene.states_symbols``.
        Only ``y_traj[0]`` is placed on the scene.

    Raises
    ------
    ValueError
        If stacked ``body_constants`` rows differ (different subjects).
    """
    def _to_numpy(value):
        # detach/cpu: Tensor.numpy() refuses tensors that require grad or live on GPU
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    # A stack of constants vectors must describe a single subject.
    body_constants = _to_numpy(batch_data['body_constants'])
    if body_constants.ndim >= 2:
        if not np.all(np.equal(body_constants, body_constants[0])):
            raise ValueError("The data do not belong the same subject, constants are different !")
        reference_constants = body_constants[0]
    else:
        # A single vector is used directly; indexing [0] would give a scalar.
        reference_constants = body_constants
    scene.constants = dict(zip(constants, reference_constants))

    scene.times = t * slower
    scene.states_symbols = coordinates + speeds + [right_center_of_press_x, right_center_of_press_y,
                                                   left_center_of_press_x, left_center_of_press_y]

    # Same block order as states_symbols: coordinates (translation, angles),
    # speeds (translational, angular), then right and left CoP.
    trajectory_keys = ('translation_data', 'angles',
                       'qdot_translation_data', 'qdot_angles',
                       'right_cops', 'left_cops')
    y_traj = np.concatenate([_to_numpy(batch_data[key]) for key in trajectory_keys], axis=-1)

    # The scene animates the first sample of the batch only.
    scene.states_trajectories = y_traj[0]

    return scene, y_traj


def plot_trajectories(batch_est, batch_true, t):
    """Plot estimated vs. reference joint angles (in degrees) for sample 0.

    One subplot per angle, each titled with the RMSE between the two curves.
    Unused cells of the subplot grid are hidden.

    Parameters
    ----------
    batch_est, batch_true : dict
        Must contain ``'angles'`` (torch tensor or numpy array) in radians.
        Only sample ``[0]`` is plotted; its columns are assumed to follow the
        order of ``labels`` below.
    t : array-like
        Time vector for the x axis.
    """
    labels = [["angle", "pelvis_tilt"],
              ["angle", "hip_flexion_r"],
              ["angle", "knee_angle_r"],
              ["angle", "ankle_angle_r"],
              ["angle", "hip_flexion_l"],
              ["angle", "knee_angle_l"],
              ["angle", "ankle_angle_l"]]

    def _angles_deg(batch):
        angles = batch['angles'][0]
        if isinstance(angles, torch.Tensor):
            # detach/cpu so tensors requiring grad or on GPU can be converted
            angles = angles.detach().cpu().numpy()
        return rad2deg(np.asarray(angles))

    est_angles = _angles_deg(batch_est)
    true_angles = _angles_deg(batch_true)

    # Grid with two columns and as many rows as needed (ceil division).
    num_cols = 2
    num_rows = -(-len(labels) // num_cols)
    fig, axs = subplots(num_rows, num_cols, figsize=(10, 10), squeeze=False)

    for i, (quantity, name) in enumerate(labels):
        ax = axs[i // num_cols, i % num_cols]

        rmse = float(np.sqrt(np.mean((est_angles[:, i] - true_angles[:, i]) ** 2)))

        ax.plot(t, est_angles[:, i], c="g", label="pred")
        ax.plot(t, true_angles[:, i], c="r", linestyle="dashed", label="real")
        ax.set_ylabel(f"{quantity} in [{UnitDic[quantity]}]", fontsize='small')
        ax.set_xlabel('time [s]', fontsize='small')
        ax.set_title(f"{name}  (rmse={rmse:.3f})", fontsize='small')
        ax.legend(loc="upper right", fontsize='x-small')

    # Hide grid cells that received no data.
    for ax in axs.flat[len(labels):]:
        ax.set_visible(False)

    fig.subplots_adjust(hspace=0.5)
    show()


def compare_trajectories(y_est, y_true,t):
    """Plot estimated and true angles (degrees) in two stacked panels.

    Only sample ``[0]`` and the first three columns of ``y_est`` / ``y_true``
    are plotted. Values are converted from radians with ``rad2deg``, working on
    copies so the caller's arrays are untouched. Inputs must support ``.copy()``
    (e.g. numpy arrays). Legend labels are the LaTeX forms of
    ``coordinates[3:6]`` (right hip, knee, ankle).

    NOTE(review): the plotted columns (``[..., :3]``) match those labels only if
    the inputs' first three columns are the right hip/knee/ankle angles (as in
    ``symbols_est_data_right``). For arrays laid out like ``coordinates`` they
    would be x, y and the pelvis angle — confirm with callers.
    """
    legend_labels = [f"${vlatex(sym)}$" for sym in coordinates[3:6]]

    est_deg = rad2deg(y_est.copy()[0, :, :3])
    true_deg = rad2deg(y_true.copy()[0, :, :3])

    _, (ax_est, ax_true) = subplots(2, 1, figsize=(8, 6))

    # Same decorations on both panels; only the data and the title differ.
    panels = ((ax_est, est_deg, 'Estimated Angles'),
              (ax_true, true_deg, 'True Angles'))
    for ax, data, title in panels:
        ax.plot(t, data, label=legend_labels)
        ax.set_xlabel('Time [s]')
        ax.set_ylabel('Angle [deg]')
        ax.set_title(title)
        ax.legend()

    tight_layout()


def plot_losses(ax, K_loss, t, indexes=range(9), linestyle=None):
    """Plot selected components of the Kane's-method loss vector over time.

    Each legend entry shows the coordinate symbol, the summed absolute loss of
    that component over all samples, and its unit (N for the first two
    components, Nm for the rest).

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    K_loss : torch.Tensor
        Loss per sample; assumed shape (n_samples, 9) with columns ordered like
        ``coordinates``. Detached and cloned; the caller's tensor is untouched.
    t : numpy.ndarray
        Time vector (copied).
    indexes : iterable of int, optional
        Columns of ``K_loss`` to plot (default: all nine).
    linestyle : str, optional
        Matplotlib line style; ``None`` uses the default.
    """
    K_loss = K_loss.detach().clone()
    t = t.copy()
    # A list works uniformly for tensor advanced indexing and label selection.
    indexes = list(indexes)

    sum_K_loss = torch.sum(torch.abs(K_loss), dim=0)

    loss_labels = ["${}$".format(vlatex(s)) for s in coordinates]
    unit_list = ['(N)', '(N)', '(Nm)', '(Nm)', '(Nm)', '(Nm)', '(Nm)', '(Nm)', '(Nm)']
    # .item() also works for tensors on GPU, unlike .numpy()
    label_list = ['Kloss({}): {:.2f} {}'.format(label, sum_K_loss[i].item(), unit_list[i])
                  for i, label in enumerate(loss_labels)]
    selected_labels = [label_list[i] for i in indexes]

    # Author's note: the first two losses depend on the hip-force estimate;
    # the last three are expected to be close to zero.
    ax.plot(t, K_loss[:, indexes], linestyle=linestyle, label=selected_labels)
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Kane loss')


def plot_GRFs(ax, y_d, t):
    """Plot the ground-force columns of ``y_d`` as dashed lines against ``t``.

    ``y_d`` is indexed as a 3-D tensor. Its columns from index 6 onward are
    taken as the force components; the reshape requires exactly two of them.
    The first two axes are merged so everything is drawn as one long series.
    The caller's tensor is not modified (a detached clone is used).
    """
    snapshot = y_d.detach().clone()
    n_rows = snapshot.shape[0] * snapshot.shape[1]
    forces = snapshot[:, :, 6:].reshape((n_rows, 2))

    ax.plot(t, forces, linestyle='--', label=['GFx', 'GRy'])
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('ground force [N]')

def plot_hip_data(ax, hip_data, t, indexes=range(len(symbols_hip_data)), linestyle=None):
    """Plot hip-data channels against time, merging the first two axes.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    hip_data : torch.Tensor or numpy.ndarray
        3-D data (assumed batch, time, channel). The first two axes are merged
        into one before plotting. Every channel present is plotted. The
        caller's data is not modified.
    t : array-like
        Time vector matching the merged first axis.
    indexes : iterable of int, optional
        Positions in ``symbols_hip_data`` used for the legend labels.
        Presumably these should correspond to the channels in ``hip_data`` —
        confirm with callers.
    linestyle : str, optional
        Matplotlib line style; ``None`` uses the default.
    """
    if isinstance(hip_data, torch.Tensor):
        # Tensor.numpy() raises for tensors that require grad (or live on GPU);
        # detach/cpu first, consistent with the other plotting helpers.
        hip_data = hip_data.detach().cpu().numpy()
    hip_data = np.array(hip_data, copy=True)
    hip_data = hip_data.reshape((hip_data.shape[0] * hip_data.shape[1], hip_data.shape[2]))

    selected_symbols = [symbols_hip_data[i] for i in indexes]
    hip_labels = ["${}$".format(vlatex(s)) for s in selected_symbols]
    # A single label is passed as a plain string.
    if len(hip_labels) == 1:
        hip_labels = hip_labels[0]
    ax.plot(t, hip_data, linestyle=linestyle, label=hip_labels)
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('hip data')
    ax.legend()

def plot_estimation_variables(ax, y_d, t, indexes=range(len(symbols_est_data_right)), side="right", linestyle=None):
    """Plot estimation outputs for one leg, labelled with that leg's symbols.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    y_d : torch.Tensor
        Data plotted against ``t`` as-is, after ``detach().numpy().copy()``.
    t : array-like
        Time vector.
    indexes : iterable of int, optional
        Positions in the chosen side's symbol list used for legend labels.
    side : {"right", "left"}
        Selects ``symbols_est_data_right`` or ``symbols_est_data_left``.
    linestyle : str, optional
        Matplotlib line style; ``None`` uses the default.

    Raises
    ------
    ValueError
        If ``side`` is neither "right" nor "left".
    """
    values = y_d.detach().numpy().copy()

    if side not in ("right", "left"):
        raise ValueError("You have entered an invalid side!")
    symbol_pool = symbols_est_data_right if side == "right" else symbols_est_data_left

    legend = [f"${vlatex(symbol_pool[i])}$" for i in indexes]
    ax.plot(t, values, linestyle=linestyle, label=legend)
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('est data')







