#!/usr/bin/env python

from email.quoprimime import body_check
from typing import Optional
from sympy.physics.mechanics import KanesMethod
from sympy import lambdify, diff, Matrix, MutableDenseMatrix
import torch
import numpy as np
import jax
import jax.numpy as jnp
from jax.numpy import apply_along_axis
from src.multibody_sim.kinetics import *
from src.utils.utils import IK_columns
import os
import cloudpickle

# Equations of Motion
# ===================
import time

# When True, equation_of_motions.simplify() runs sympy's simplify() on every
# row of Kane's equations before they are lambdified. This is very slow
# (the code estimates 3-5h), so it is off by default.
SIMPLIFY_KANE: bool = False

class equation_of_motions:
    """Derive the model's equations of motion and lambdify them for JAX.

    The constructor builds Kane's equations (``fr + frstar``) with sympy and
    optionally simplifies them (controlled by ``SIMPLIFY_KANE``). It then stores
    two JAX-compatible callables:

    * ``self.F(*params)``: the residual ``fr + frstar``, which is zero for
      dynamically consistent inputs.
    * ``self.F_grad(*params)``: the derivative of that residual with respect to
      every parameter, stacked so that there is one row per parameter.

    ``params`` is always ordered as returned by ``_all_params``:
    coordinates, speeds, speed derivatives, joint torques, ground forces,
    ground moments, body constants.

    All sympy symbols used here come from ``src.multibody_sim.kinetics``.
    """

    def __init__(self):
        print("calculating equations of motion...")
        # Caching via eqm.pkl is handled at module level, not here.
        fr, frstar = self.get_fr_star()
        F_sym = self.simplify(fr + frstar)
        self.get_F_and_grad(F_sym)

    # ------------------------------------------------------------------
    # Shared parameter orderings. Kane's method and both lambdified
    # functions must agree on these, so they are defined exactly once.
    # ------------------------------------------------------------------
    @staticmethod
    def _coordinates():
        """Return the generalized coordinates in their canonical order."""
        return [
            x,
            y,
            theta_pelvis,
            r_theta_hip,
            r_theta_knee,
            r_theta_ankle,
            l_theta_hip,
            l_theta_knee,
            l_theta_ankle,
        ]

    @staticmethod
    def _speeds():
        """Return the generalized speeds, aligned with ``_coordinates``."""
        return [
            hip_vel_x,
            hip_vel_y,
            omega_torso,
            r_omega_hip,
            r_omega_knee,
            r_omega_ankle,
            l_omega_hip,
            l_omega_knee,
            l_omega_ankle,
        ]

    @classmethod
    def _der_speeds(cls):
        """Return the time derivatives of the generalized speeds, same order."""
        return [speed.diff() for speed in cls._speeds()]

    @classmethod
    def _all_params(cls):
        """Return every lambdify argument, in the order F and F_grad expect."""
        torques = [
            right_hip_torque,
            right_knee_torque,
            right_ankle_torque,
            left_hip_torque,
            left_knee_torque,
            left_ankle_torque,
        ]
        ground_forces = [
            grf_r_x,
            grf_r_y,
            grf_l_x,
            grf_l_y,
        ]
        moments = [
            m_r,
            m_l,
        ]
        body_constants = [
            thigh_length,
            thigh_com_dist,
            thigh_mass,
            thigh_inertia,
            shank_length,
            shank_com_dist,
            shank_mass,
            shank_inertia,
            foot_length,
            foot_com_dist,
            foot_mass,
            foot_inertia,
            torso_com_dist,
            torso_mass,
            torso_inertia,
            g,
        ]
        return (
            cls._coordinates()
            + cls._speeds()
            + cls._der_speeds()
            + torques
            + ground_forces
            + moments
            + body_constants
        )

    def get_fr_star(self):
        """Run Kane's method on the model and return ``(fr, frstar)``."""
        kane = KanesMethod(
            ground_frame,
            self._coordinates(),
            self._speeds(),
            kinematical_differential_equations,
        )

        loads = [
            right_grf,
            right_moment,
            left_grf,
            left_moment,
            torso_grav_force,
            right_thigh_grav_force,
            right_shank_grav_force,
            right_foot_grav_force,
            left_thigh_grav_force,
            left_shank_grav_force,
            left_foot_grav_force,
            right_thigh_torque,
            right_shank_torque,
            right_foot_torque,
            left_thigh_torque,
            left_shank_torque,
            left_foot_torque,
        ]

        bodies = [
            torso,
            right_thigh,
            right_shank,
            right_foot,
            left_thigh,
            left_shank,
            left_foot,
        ]

        fr, frstar = kane.kanes_equations(bodies, loads)
        return fr, frstar

    def simplify(self, F_sym):
        """Simplify each row of ``F_sym`` in place when ``SIMPLIFY_KANE`` is set.

        Returns ``F_sym`` (unchanged when simplification is disabled).
        """
        if not SIMPLIFY_KANE:  # Simplification takes a while
            print('Skipping simplification of Kane equations')
            return F_sym
        print("Simplifying Kane's equations. Estimated duration: 3-5h")
        for i in range(len(F_sym)):
            print("Simplifying Kane's equations: #", i, "of", len(F_sym))
            F_sym[i] = F_sym[i].simplify()
        return F_sym

    def get_F_and_grad(self, F_sym):
        """Lambdify the residual and its parameter gradients into JAX callables.

        Sets ``self.F`` and ``self.F_grad``. Both take the parameters returned
        by ``_all_params`` positionally, in that order.
        """
        # F_sym is the vector that should equal the zero vector (F: Ax - b = 0).
        all_params = self._all_params()

        F = lambdify(all_params, F_sym, modules="jax")

        # Differentiate with respect to each parameter in exactly the lambdify
        # argument order, so gradient row k always belongs to argument k of F.
        grad_list = [F_sym.diff(param) for param in all_params]

        # Transposed so that rows index parameters (needed for the matrix
        # multiplication in Kmatrix.backward).
        grads = Matrix([grad_list]).T

        F_grad = lambdify(all_params, grads, modules="jax")

        self.F = F
        self.F_grad = F_grad

    def jit(self):
        """Replace ``F`` and ``F_grad`` with their ``jax.jit``-compiled versions."""
        self.F = jax.jit(self.F)
        self.F_grad = jax.jit(self.F_grad)

    # Wrappers that unpack one row of parameter values into the positional
    # arguments of F / F_grad (used with jax.vmap).
    def Fwp(self, arr):
        return self.F(*arr)

    def Fwp_grad(self, arr):
        return self.F_grad(*arr)

    # Wrap the function F in a PyTorch autograd function
    # ===================
# Build the equations of motion once and cache them with cloudpickle in
# "eqm.pkl" (relative to the current working directory). Later imports load
# the cached object instead of re-deriving it.
# NOTE(review): unpickling can execute arbitrary code. Only ever load an
# eqm.pkl produced by this code, never one from an untrusted source.
_EQM_CACHE_PATH = "eqm.pkl"
if os.path.exists(_EQM_CACHE_PATH):
    with open(_EQM_CACHE_PATH, "rb") as f:
        eqm = cloudpickle.load(f)
else:
    eqm = equation_of_motions()
    # Dump to a temporary file and atomically rename it into place, so an
    # interrupted write cannot leave a truncated cache that breaks later imports.
    _eqm_tmp_path = _EQM_CACHE_PATH + ".tmp"
    with open(_eqm_tmp_path, "wb") as f:
        cloudpickle.dump(eqm, f)
    os.replace(_eqm_tmp_path, _EQM_CACHE_PATH)
# JIT-compile only after any dump, so the cache stores the plain
# lambdified functions rather than the jitted wrappers.
eqm.jit()


#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Define a custom PyTorch autograd function to wrap the expression and compute gradients
class Kmatrix(torch.autograd.Function):
    """PyTorch autograd wrapper around the JAX-lambdified residual ``eqm.F``.

    ``forward`` evaluates ``eqm.F`` independently for every row of ``inputs``
    along all leading dimensions. The last dimension holds the parameter values
    in the order ``eqm.F`` was lambdified with.

    ``backward`` evaluates the symbolic Jacobian ``eqm.F_grad`` at the same
    rows and contracts it with the incoming gradient. ``weights`` is saved but
    not used by either pass, and receives no gradient.
    """

    @staticmethod
    def _rows_to_jax(inputs):
        """Copy ``inputs`` to host memory for JAX, flattened to (rows, n_params)."""
        host = inputs.detach().cpu().numpy()
        arr = jax.device_put(host)
        return arr.reshape((-1, arr.shape[-1]))

    @staticmethod
    def forward(
        ctx,
        inputs,
        weights,
    ):
        # Evaluate the residual for every row with jax.vmap, then restore the
        # original leading dimensions (works for any number of them).
        leading_shape = tuple(inputs.shape[:-1])
        rows = Kmatrix._rows_to_jax(inputs)
        result = jax.vmap(eqm.Fwp)(rows)
        result = result.reshape(leading_shape + tuple(result.shape[-2:]))

        # Store the input values for use in the backward pass.
        ctx.save_for_backward(inputs, weights)

        # The result is returned as a CPU tensor; callers such as
        # Kmatrix_loss_moveEst move it to their device. The copy makes the
        # numpy buffer writable, as torch.from_numpy requires.
        return torch.from_numpy(np.asarray(result).copy()).squeeze(-1)

    @staticmethod
    def backward(ctx, grad_output):
        inputs, _weights = ctx.saved_tensors

        # Jacobian of the residual, one row per input parameter.
        leading_shape = tuple(inputs.shape[:-1])
        rows = Kmatrix._rows_to_jax(inputs)
        grads = jax.vmap(eqm.Fwp_grad)(rows)
        grads = grads.reshape(leading_shape + tuple(grads.shape[-2:]))

        grads = torch.from_numpy(np.asarray(grads).copy()).to(
            device=grad_output.device, dtype=grad_output.dtype
        )
        grad_inputs = torch.matmul(grads, grad_output.unsqueeze(-1)).squeeze(-1)

        # Return the gradient on the inputs' own device and dtype. A global
        # device could differ from the inputs' device, and autograd rejects
        # that mismatch.
        return grad_inputs.to(device=inputs.device, dtype=inputs.dtype), None

# Default torch device. Kmatrix_loss_moveEst overwrites it with the device it
# is called with.
device_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# IK_columns names grouped by the quantity they hold. The order inside each
# group is significant: it is meant to line up with the parameter order used
# when eqm.F was lambdified (x, y, pelvis, right hip/knee/ankle,
# left hip/knee/ankle). Keep the groups in sync with that ordering.
_IK_COLUMN_GROUPS = {
    "coordinates": (
        "tx", "ty", "a_pelvis",
        "a_hip_r", "a_knee_r", "a_ankle_r",
        "a_hip_l", "a_knee_l", "a_ankle_l",
    ),
    "speeds": (
        "dtx", "dty", "da_pelvis",
        "da_hip_r", "da_knee_r", "da_ankle_r",
        "da_hip_l", "da_knee_l", "da_ankle_l",
    ),
    "der_speeds": (
        "ddtx", "ddty", "dda_pelvis",
        "dda_hip_r", "dda_knee_r", "dda_ankle_r",
        "dda_hip_l", "dda_knee_l", "dda_ankle_l",
    ),
}

# Positions of each group's columns inside IK_columns. Kmatrix_loss_moveEst
# uses them to slice the IK predictions along their last axis.
IK_pred_data_mapping = {
    group: [IK_columns.index(column) for column in columns]
    for group, columns in _IK_COLUMN_GROUPS.items()
}


def Kmatrix_loss_moveEst(
    IK_pred: torch.Tensor,
    torques: torch.Tensor,
    ground_forces: torch.Tensor,
    center_of_pressures: torch.Tensor,
    body_constants: torch.Tensor,
    weights,
    device: torch.device,
) -> torch.Tensor:
    """Evaluate the Kane's-method residual for a batch of predicted motions.

    The IK coordinates, speeds and speed derivatives are gathered from
    ``IK_pred`` via ``IK_pred_data_mapping``. They are concatenated along the
    last axis with the remaining inputs, in the argument order ``eqm.F`` was
    lambdified with. That order is 9 coordinates + 9 speeds + 9 speed
    derivatives + 6 torques + 4 ground forces + 2 ground moments + 16 body
    constants, 55 values per step in total.

    Args:
        IK_pred: IK predictions, indexed as ``[:, :, column]``. Presumably
            (batch, seq_len, len(IK_columns)); TODO confirm.
        torques: joint torques, 6 values per step (right hip/knee/ankle,
            then left).
        ground_forces: ground reaction forces, 4 values per step
            (right x/y, left x/y).
        center_of_pressures: despite the name, fills the 2 ground-moment
            slots (right, left) of ``eqm.F``.
        body_constants: 16 segment/gravity constants per step.
        weights: converted to a tensor on ``device`` and handed to Kmatrix,
            which saves it but does not use it.
        device: device of the returned tensor. It is also stored in the
            module-global ``device_``.

    Returns:
        The residual on ``device``, one entry per generalized speed along the
        last axis.
    """
    global device_
    device_ = device

    inputs = torch.cat(
        (
            IK_pred[:, :, IK_pred_data_mapping["coordinates"]],
            IK_pred[:, :, IK_pred_data_mapping["speeds"]],
            IK_pred[:, :, IK_pred_data_mapping["der_speeds"]],
            torques,
            ground_forces,
            center_of_pressures,
            body_constants,
        ),
        dim=-1,
    )

    # as_tensor avoids the copy (and UserWarning) torch.tensor emits when
    # weights is already a tensor.
    weights_t = torch.as_tensor(weights, device=device)
    loss_vector = Kmatrix.apply(inputs, weights_t)

    return loss_vector.to(device)  # type: ignore
