import torch
from qpsolvers import solve_qp
import numpy as np
from src.multibody_sim.equations_of_motion import *
import matplotlib.pyplot as plt
import time
#from PIPeqm import *

# Gains of the Cartesian PD law used by ``get_r_ddot_des``.
k_p = 10_000  # position gain (100**2)
k_p_1 = 10  # alternative position gain, not referenced in this part of the file (was 100**2 * (2/3))
k_d = 10  # velocity (damping) gain
# Body constants; assigned by ``PIPadapter.__init__`` and read by the kinematics helpers.
bc = None

def get_phi_ddot_des(phi, phi_dot, q_des, kp=6600, kd=60):
    """Joint-angle PD controller.

    :param phi: current joint angle(s).
    :param phi_dot: current joint velocity (velocities).
    :param q_des: target joint angle(s).
    :param kp: proportional gain.
    :param kd: damping gain applied to ``phi_dot``.
    :return: desired joint acceleration ``kp * (q_des - phi) - kd * phi_dot``.
    """
    position_error = q_des - phi
    return kp * position_error - kd * phi_dot

def get_r_ddot_des(rcc, r, r_dot):
    """Cartesian position PD controller driven by a predicted velocity.

    The reference position is one 0.01 step ahead along the velocity taken
    from ``rcc``:

        r_ref = r + T(v) * 0.01
        r_ddot_des = k_p * (r_ref - r) - k_d * r_dot

    ``rcc`` is expected to be built like
    ``torch.stack([reconstructed_data[key][0, :, :] for key in reconstructed_data.keys()])``;
    channels 1 and 4 of its last axis are used as the x and y velocity and
    are concatenated along the last axis.

    :param rcc: stacked reconstruction tensor (indexed ``[:, :, channel]``).
    :param r: current position.
    :param r_dot: current velocity.
    :return: desired acceleration, using the module-level gains ``k_p``/``k_d``.
    """
    velocity = torch.cat([rcc[:, :, 1], rcc[:, :, 4]], dim=-1)
    r_ref = r + velocity * 0.01
    return k_p * (r_ref - r) - k_d * r_dot

def block_diagonal_matrix_np(matrix2d_list):
    r"""
    Generate a block diagonal 2d matrix using a series of 2d matrices. (numpy, single)

    Blocks are placed along the diagonal in the given order and all other
    entries are zero. The result is float64 (the ``np.zeros`` default)
    regardless of the blocks' dtypes. Zero-sized blocks such as ``(0, n)``
    are allowed and only contribute columns (or rows).

    :param matrix2d_list: An iterable of 2d matrices (anything with a
        2-tuple ``shape`` that can be assigned into a numpy slice). It may be
        a generator; it is materialized once.
    :return: The block diagonal matrix, of shape
        ``(sum of block rows, sum of block cols)``; ``(0, 0)`` for no blocks.
    """
    # Materialize first: the shape pass and the fill pass both iterate.
    blocks = list(matrix2d_list)
    n_rows = sum(m.shape[0] for m in blocks)
    n_cols = sum(m.shape[1] for m in blocks)
    ret = np.zeros((n_rows, n_cols))
    r, c = 0, 0
    for m in blocks:
        lr, lc = m.shape
        ret[r:r + lr, c:c + lc] = m
        r += lr
        c += lc
    return ret

def simple_FK(q, r_only=False, l_only=False):
    """Planar forward kinematics of the pelvis and both legs.

    Segment lengths come from the module-level ``bc`` (set by
    ``PIPadapter.__init__``). Following the ``body_constants`` ordering,
    ``bc[0]`` is the thigh length and ``bc[4]`` the shank length. Every
    segment hangs from its parent joint along ``length * (sin(a), -cos(a))``,
    where ``a`` is the cumulative sum of the pelvis angle ``q[2]`` and the
    joint angles down the chain.

    :param q: 1-d tensor of generalized coordinates. Indices used here are
        0/1 (pelvis x/y), 2 (pelvis angle), 3/4 (right hip/knee) and
        6/7 (left hip/knee).
    :param r_only: return as soon as the right leg is computed; the left-leg
        columns stay zero.
    :param l_only: skip the right leg (its knee/ankle columns stay zero),
        unless ``r_only`` is also set.
    :return: (2, 7) tensor. Row 0 is x and row 1 is y. Columns are pelvis,
        hip_R, knee_R, ankle_R, hip_L, knee_L, ankle_L. The dtype follows
        ``q``'s floating dtype.
    """
    # Match q's dtype so float64 states (and the Jacobians taken through this
    # function) are not silently rounded to float32.
    dtype = q.dtype if q.is_floating_point() else torch.get_default_dtype()
    c = torch.zeros((2, 7), dtype=dtype, device=q.device)  # joints
    c[0, 0] = q[0]  # Pelvis
    c[1, 0] = q[1]
    c[:, 1] = c[:, 0]  # Hip_R coincides with the pelvis
    if r_only or not l_only:
        c[0, 2] = c[0, 1] + torch.sin(q[2] + q[3]) * bc[0]  # Knee_R
        c[1, 2] = c[1, 1] + torch.cos(q[2] + q[3]) * -bc[0]  # Knee_R
        c[0, 3] = c[0, 2] + torch.sin(q[2] + q[3] + q[4]) * bc[4]  # Ankle_R
        c[1, 3] = c[1, 2] + torch.cos(q[2] + q[3] + q[4]) * -bc[4]  # Ankle_R
        if r_only:
            return c
    c[:, 4] = c[:, 0]  # Hip_L coincides with the pelvis
    c[0, 5] = c[0, 4] + torch.sin(q[2] + q[6]) * bc[0]  # Knee_L
    c[1, 5] = c[1, 4] + torch.cos(q[2] + q[6]) * -bc[0]  # Knee_L
    c[0, 6] = c[0, 5] + torch.sin(q[2] + q[6] + q[7]) * bc[4]  # Ankle_L
    c[1, 6] = c[1, 5] + torch.cos(q[2] + q[6] + q[7]) * -bc[4]  # Ankle_L
    return c

def geta(q, qdot, qddot):
    """
    Cartesian accelerations of the seven model joints (``simple_FK`` layout).

    The planar chain of ``simple_FK`` is differentiated twice, so every
    column equals ``J(q) @ qddot + Jdot(q, qdot) @ qdot`` for that joint.
    Passing ``qddot = 0`` leaves only the velocity-product (bias) term
    ``Jdot @ qdot``; ``PIPadapter.step`` calls it that way.

    Segment lengths are read from the module-level ``bc`` exactly as in
    ``simple_FK``: ``bc[0]`` for the thigh and ``bc[4]`` for the shank.

    :param q: 1-d tensor of generalized coordinates.
    :param qdot: generalized velocities, indexed like ``q``.
    :param qddot: generalized accelerations, indexed like ``q``.
    :return: (2, 7) tensor. Row 0 is x and row 1 is y. Columns are pelvis,
        hip_R, knee_R, ankle_R, hip_L, knee_L, ankle_L.
    """
    dtype = q.dtype if q.is_floating_point() else torch.get_default_dtype()
    c_ddot = torch.zeros((2, 7), dtype=dtype, device=q.device)

    def _link(child, parent, angle, angle_dot, angle_ddot, length):
        # End point = parent + length * (sin a, -cos a). Differentiating twice:
        #   x'' = parent_x'' + length * (cos a * a'' - sin a * a'^2)
        #   y'' = parent_y'' + length * (sin a * a'' + cos a * a'^2)
        sin_a, cos_a = torch.sin(angle), torch.cos(angle)
        c_ddot[0, child] = c_ddot[0, parent] + length * (cos_a * angle_ddot - sin_a * angle_dot**2)
        c_ddot[1, child] = c_ddot[1, parent] + length * (sin_a * angle_ddot + cos_a * angle_dot**2)

    # Pelvis translation; both hips coincide with the pelvis.
    c_ddot[0, 0] = qddot[0]
    c_ddot[1, 0] = qddot[1]
    c_ddot[:, 1] = c_ddot[:, 0]
    c_ddot[:, 4] = c_ddot[:, 0]

    # Right leg: thigh (hip_R -> knee_R), then shank (knee_R -> ankle_R).
    _link(2, 1, q[2] + q[3], qdot[2] + qdot[3], qddot[2] + qddot[3], bc[0])
    _link(3, 2, q[2] + q[3] + q[4], qdot[2] + qdot[3] + qdot[4],
          qddot[2] + qddot[3] + qddot[4], bc[4])
    # Left leg: thigh (hip_L -> knee_L), then shank (knee_L -> ankle_L).
    _link(5, 4, q[2] + q[6], qdot[2] + qdot[6], qddot[2] + qddot[6], bc[0])
    _link(6, 5, q[2] + q[6] + q[7], qdot[2] + qdot[6] + qdot[7],
          qddot[2] + qddot[6] + qddot[7], bc[4])
    return c_ddot

def gc_r(q, d, toe=False, heel=False):
    """Ground-contact point of the right foot.

    The contact point starts from the right ankle (column 3 of
    ``simple_FK``) and is offset by the module-level ``heel_x``, ``heel_y``
    and ``toe_x`` (set by ``PIPadapter.__init__``), rotated by the
    cumulative right-leg angle ``q[2] + q[3] + q[4] + q[5]``. Heel and toe
    share the ``heel_y`` offset.

    :param q: 1-d tensor of generalized coordinates.
    :param d: offset added to the x coordinate of the blended point only.
    :param toe: return the toe point (ignored when ``heel`` is set).
    :param heel: return the heel point (takes precedence over ``toe``).
    :return: length-2 tensor ``[x, y]``. Without ``heel``/``toe`` the result
        is ``w * heel + (1 - w) * toe`` with ``w = tanh(7 * angle) / 2 + 0.5``,
        and ``d`` is added to x.
    """
    # Only the right ankle is needed, so skip the left leg.
    ankle = simple_FK(q, r_only=True)[:, 3]
    angle_r = q[2] + q[3] + q[4] + q[5]
    sin_a, cos_a = torch.sin(angle_r), torch.cos(angle_r)

    def _point(offset_x):
        # Rotate (offset_x, heel_y) by angle_r and attach it to the ankle.
        return (ankle[0] - sin_a * heel_y + cos_a * offset_x,
                ankle[1] + cos_a * heel_y + sin_a * offset_x)

    if heel:
        return torch.stack(_point(heel_x))
    if toe:
        return torch.stack(_point(toe_x))
    heel_px, heel_py = _point(heel_x)
    toe_px, toe_py = _point(toe_x)
    cp_ratio = torch.tanh(angle_r * 7) / 2 + 0.5
    # torch.stack keeps the computation dtype (a float32 zeros buffer would round it).
    return torch.stack([
        cp_ratio * heel_px + (1 - cp_ratio) * toe_px + d,
        cp_ratio * heel_py + (1 - cp_ratio) * toe_py,
    ])

def gcr1(q):
    """Blended right-foot contact point with +0.1 added to x (see ``gc_r``)."""
    return gc_r(q, d=0.1)

def cfkp(q):
    """Pelvis position ``[x, y]`` (column 0 of ``simple_FK``)."""
    return simple_FK(q)[:, 0]

def cfkkr(q):
    """Right knee position ``[x, y]`` (column 2 of ``simple_FK``)."""
    return simple_FK(q)[:, 2]

def cfkar(q):
    """Right ankle position ``[x, y]`` (column 3 of ``simple_FK``)."""
    return simple_FK(q)[:, 3]

def cfkkl(q):
    """Left knee position ``[x, y]`` (column 5 of ``simple_FK``)."""
    return simple_FK(q)[:, 5]

def cfkal(q):
    """Left ankle position ``[x, y]`` (column 6 of ``simple_FK``)."""
    return simple_FK(q)[:, 6]

def gcr2(q):
    """Blended right-foot contact point with -0.1 added to x (see ``gc_r``)."""
    return gc_r(q, d=-0.1)

def gcl1(q):
    """Blended left-foot contact point with +0.1 added to x (see ``gc_l``)."""
    return gc_l(q, d=0.1)

def gcl2(q):
    """Blended left-foot contact point with -0.1 added to x (see ``gc_l``)."""
    return gc_l(q, d=-0.1)

def gcr0(q):
    """Blended right-foot contact point without an x offset (see ``gc_r``)."""
    return gc_r(q, d=0)

def gcl0(q):
    """Blended left-foot contact point without an x offset (see ``gc_l``)."""
    return gc_l(q, d=0)

def gcrt(q):
    """Right toe contact point ``[x, y]`` (see ``gc_r``)."""
    return gc_r(q, d=0, toe=True)

def gcrh(q):
    """Right heel contact point ``[x, y]`` (see ``gc_r``)."""
    return gc_r(q, d=0, heel=True)

def gclt(q):
    """Left toe contact point ``[x, y]`` (see ``gc_l``)."""
    return gc_l(q, d=0, toe=True)

def gclh(q):
    """Left heel contact point ``[x, y]`` (see ``gc_l``)."""
    return gc_l(q, d=0, heel=True)

def gc_l(q, d, heel=False, toe=False):
    """Ground-contact point of the left foot.

    The contact point starts from the left ankle (column 6 of
    ``simple_FK``) and is offset by the module-level ``heel_x``, ``heel_y``
    and ``toe_x`` (set by ``PIPadapter.__init__``), rotated by the
    cumulative left-leg angle ``q[2] + q[6] + q[7] + q[8]``. Heel and toe
    share the ``heel_y`` offset.

    :param q: 1-d tensor of generalized coordinates.
    :param d: offset added to the x coordinate of the blended point only.
    :param heel: return the heel point (takes precedence over ``toe``).
    :param toe: return the toe point (ignored when ``heel`` is set).
    :return: length-2 tensor ``[x, y]``. Without ``heel``/``toe`` the result
        is ``w * heel + (1 - w) * toe`` with ``w = tanh(7 * angle) / 2 + 0.5``,
        and ``d`` is added to x.
    """
    # Only the left ankle is needed, so skip the right leg.
    ankle = simple_FK(q, l_only=True)[:, 6]
    angle_l = q[2] + q[6] + q[7] + q[8]
    sin_a, cos_a = torch.sin(angle_l), torch.cos(angle_l)

    def _point(offset_x):
        # Rotate (offset_x, heel_y) by angle_l and attach it to the ankle.
        return (ankle[0] - sin_a * heel_y + cos_a * offset_x,
                ankle[1] + cos_a * heel_y + sin_a * offset_x)

    if heel:
        return torch.stack(_point(heel_x))
    if toe:
        return torch.stack(_point(toe_x))
    heel_px, heel_py = _point(heel_x)
    toe_px, toe_py = _point(toe_x)
    cp_ratio = torch.tanh(angle_l * 7) / 2 + 0.5
    # torch.stack keeps the computation dtype (a float32 zeros buffer would round it).
    return torch.stack([
        cp_ratio * heel_px + (1 - cp_ratio) * toe_px + d,
        cp_ratio * heel_py + (1 - cp_ratio) * toe_py,
    ])

def gc_ht_l(q):
    """Heel and toe contact points of the left foot.

    ``gc_l`` has no combined heel-and-toe mode, so both points are queried
    separately and stacked.

    :param q: 1-d tensor of generalized coordinates.
    :return: (2, 2) tensor; row 0 is the heel ``[x, y]``, row 1 the toe.
    """
    return torch.stack([gc_l(q, 0, heel=True), gc_l(q, 0, toe=True)])

class PIPadapter():
    """Physics-based refinement of a kinematic gait estimate, one QP per frame.

    Every ``step`` solves, with ``qpsolvers.solve_qp``, a quadratic program
    over ``x = [qddot (9), contact forces f (2 per active contact point),
    generalized forces tau (9)]``:

    * equality: ``-M qddot + J^T f + tau = h``, with ``M``/``h`` evaluated by
      ``mm_kane``/``h_kane`` at the current state;
    * inequalities: velocity bounds on each active contact point, and a
      friction cone ``|f_x| <= 0.6 / sqrt(2) * f_y`` per contact;
    * cost: joint-angle PD tracking, Cartesian joint-velocity tracking, a
      contact-force regularizer and ``||0.1 * tau||^2``.

    The state is then advanced with semi-implicit Euler at ``dt = 0.01``.
    """

    def __init__(self, datasample):
        """Store body and foot geometry from ``datasample`` for the module helpers.

        This sets the module-level ``bc``, ``heel_x``, ``heel_y`` and ``toe_x``
        that ``simple_FK``, ``geta``, ``gc_r`` and ``gc_l`` read. It also
        prepares ``self.V0``, the 46-entry argument vector for
        ``mm_kane``/``h_kane``. Its last 16 entries are the body constants,
        and the final one (gravity) is forced to -9.81.

        :param datasample: mapping with ``'body_constants'`` and
            ``'ground_contact_model'`` tensors, both read at ``[0, 0, ...]``.
        """
        global bc
        global heel_x
        global heel_y
        global toe_x

        self.V0 = np.zeros(46)
        bc = datasample['body_constants'][0,0,:].detach()
        # PIP is tuned towards a weight of roughly 75 kgs
        #bc[[2,3,6,7,10,11,13,14]] *= 75
        self.V0[-16:] = bc
        self.V0[-1] = -9.81
        heel_x = datasample['ground_contact_model'][0,0,0].detach()
        toe_x = datasample['ground_contact_model'][0,0,5].detach()
        heel_y = datasample['ground_contact_model'][0,0,1].detach()

    def predict(self, estimation, reconstructed_data, grf_0, gc_positions):
        """Simulate the whole sequence, one ``step`` per frame.

        :param estimation: dict whose ``'IK_data'`` tensor holds, per
            coordinate, interleaved channels along the last axis. Channel
            ``::3`` is read as the target coordinates and ``1::3`` as their
            velocities.
        :param reconstructed_data: dict of per-joint tensors (see ``step``).
        :param grf_0: reference ground reaction forces. Its second dimension
            sets the number of simulated frames.
        :param gc_positions: dict of per-contact-point tensors. Channel 3 is
            treated as a height and channel 1 as an x speed when building the
            contact probabilities.
        :return: ``(q, grf, torque, grf_names)``: per-frame lists of the
            integrated coordinates (copies), QP contact forces, generalized
            forces, and the ``[foot, point]`` names of the active contacts.
        """
        self.t0 = 0
        self.t = self.t0
        self.t_max = grf_0.shape[1]
        #self.t_max = 80
        q_des = estimation['IK_data'][0,:,::3].detach().cpu().numpy().copy()
        q_dot_des = estimation['IK_data'][0,:,1::3].detach().cpu().numpy().copy()
        # Copy the initial state: ``step`` integrates self.q / self.qdot in
        # place, which would otherwise write into rows of q_des / q_dot_des.
        self.q = q_des[self.t].copy()
        self.qdot = q_dot_des[self.t].copy()
        self.last_x = np.zeros((0))
        # Contact probability from height: ~1.0 well below 0.015, ~0.02 well above.
        ground_contact_probabilities = torch.cat([gc_positions[key][0,:,3:4] for key in gc_positions.keys()],dim=-1)
        ground_contact_probabilities = ground_contact_probabilities.detach().cpu().numpy().copy()
        ground_contact_probabilities = -np.tanh((ground_contact_probabilities-0.015)*20) * 0.49 + 0.51 # Ground contact by y position

        # Contact probability from x speed: 1.0 at rest, falling towards 0.01.
        ground_contact_probabilities2 = torch.cat([gc_positions[key][0,:,1:2] for key in gc_positions.keys()],dim=-1)
        ground_contact_probabilities2 = ground_contact_probabilities2.detach().cpu().numpy().copy()
        ground_contact_probabilities2 = (1-np.tanh(np.abs(ground_contact_probabilities2)*2)) * .99 +0.01 # Ground contact by x_speed

        # A point counts as likely in contact only if both cues agree.
        self.ground_contact_probabilities = np.minimum(ground_contact_probabilities, ground_contact_probabilities2)

        q = []
        grf = []
        grf_names = []
        reconstructed_data = {key: reconstructed_data[key].detach() for key in reconstructed_data.keys()}
        torque = []
        for _ in range(self.t_max):
            q_ret, grf_, torque_, _qdot, _qddot, grf_names_ = self.step(self.q, q_des[self.t], reconstructed_data, grf_0)
            # step returns its state array by reference; store a snapshot.
            q.append(q_ret.copy())
            grf.append(grf_)
            torque.append(torque_)
            grf_names.append(grf_names_)
        return q, grf, torque, grf_names

    def get_V0(self):
        """Return a fresh copy of the ``mm_kane``/``h_kane`` argument vector."""
        return self.V0.copy()

    def step(self, q, q_des, reconstructed_data, grf):
        """Advance the simulation by one frame (dt = 0.01).

        :param q: current generalized coordinates (numpy, 9 entries); used
            for contact detection and FK.
        :param q_des: target coordinates for this frame. Entries 2..8 feed
            the joint PD term.
        :param reconstructed_data: dict of per-joint tensors. Entry
            ``[0, t, [1, 4]]`` is the desired (x, y) velocity of that joint.
        :param grf: reference ground reaction forces; currently unused and
            kept for interface compatibility.
        :return: ``(self.q, grf, torque, self.qdot, self.qddot, cp_names)``.
            ``grf`` holds the QP contact forces (2 per active contact, ordered
            like ``cp_names``). ``self.q``/``self.qdot`` are the updated state
            arrays, returned by reference.
        :raises RuntimeError: if neither quadprog nor cvxopt returns a solution.
        """
        qdot_size = 9
        # Js[0] is an empty placeholder so np.vstack(Js) works without
        # contacts; Js[k + 1] is the 2 x 9 Jacobian of the k-th active contact.
        Js = [np.empty((0,qdot_size))]
        nc = 0
        cps = []
        cp_names = []
        q = torch.from_numpy(q)
        # TODO: a disabled experiment shifted q onto the floor at t == 0 when
        # the reference vertical GRF exceeded 0.1
        # (q[1] -= min(gc_l(q,0)[1], gc_r(q,0)[1]) + 0.03).
        pos = simple_FK(q)
        # Contact points are evaluated with the pelvis translation zeroed.
        # x and y enter the chain additively, so the Jacobians are unaffected.
        q2 = q.clone()
        q2[:2] = 0
        contact_fns = {
            ('left_foot', 'heel'): gclh,
            ('left_foot', 'toe'): gclt,
            ('right_foot', 'heel'): gcrh,
            ('right_foot', 'toe'): gcrt,
        }
        for foot_idx, foot_candidate in enumerate(['left_foot','right_foot']):
            for cp_idx, cp_candidate in enumerate(['heel','toe']):
                contact_fn = contact_fns[(foot_candidate, cp_candidate)]
                gc = contact_fn(q2)
                # NOTE(review): here left_foot maps to probability columns 0/1,
                # but the velocity-bound loop below maps left_foot to 2/3. One of
                # the two is presumably wrong; verify against the key order of
                # gc_positions in predict.
                cp_prob = self.ground_contact_probabilities[self.t,foot_idx*2+cp_idx]
                # NOTE(review): gc comes from q2 (pelvis at the origin), so gc[1]
                # is the height relative to the pelvis, not the ground. While the
                # foot is below the pelvis, gc[1] < 0 holds and the point is
                # always treated as in contact.
                if (cp_prob>0.5 and gc[1] < 0.03) or gc[1] < 0:
                    nc += 1
                    cp_names.append([foot_candidate,cp_candidate])
                    cps.append(gc)
                    Js.append(torch.autograd.functional.jacobian(contact_fn, q2))

        # Initialize quadratic program. Suffix 1 = qddot, 2 = contact forces,
        # 3 = generalized forces; A/b are least-squares cost terms and G/h
        # are inequality constraints (G x <= h).
        As1, bs1, As2, bs2, As3, bs3 = [np.zeros((0, qdot_size))], [np.empty(0)], [np.empty((0, nc * 2))], \
                               [np.empty(0)], [np.zeros((0, qdot_size))], [np.empty(0)]
        Gs1, hs1, Gs2, hs2, Gs3, hs3 = [np.zeros((0, qdot_size))], [np.empty(0)], [np.empty((0, nc * 2))], \
                               [np.empty(0)], [np.zeros((0, qdot_size))], [np.empty(0)]

        # Joint-angle PD tracking on the 7 non-translational coordinates.
        n_0_dims = 2
        A = np.hstack((np.zeros((qdot_size - n_0_dims, n_0_dims)), np.eye((qdot_size - n_0_dims))))
        b = get_phi_ddot_des(self.q[n_0_dims:], self.qdot[n_0_dims:], q_des[n_0_dims:], 2400, 60)
        b_track = np.array([1,1,1,1,1,1,1])
        As1.append(A)
        bs1.append(b*b_track)

        # Cartesian velocity tracking for the pelvis, knees and ankles (the
        # hips coincide with the pelvis):
        # J qddot ~= 60 * (v_des - J qdot) - c_ddot, where c_ddot comes from geta with qddot = 0.
        q = torch.from_numpy(self.q)
        qdot = torch.from_numpy(self.qdot)
        c_ddot = geta(q ,qdot, torch.zeros(9))
        for joint_idx, func, joint in zip([0, 2, 3, 5, 6], [cfkp, cfkkr, cfkar, cfkkl, cfkal], ['pelvis','knee_r','ankle_r','knee_l','ankle_l']):
            A = torch.autograd.functional.jacobian(func, q)
            cur_vel = A @ qdot
            a_des = 60 * (reconstructed_data[joint][0,self.t,[1,4]]-cur_vel) #+ k_p * (reconstructed_data[joint][0,self.t,[0,3]]-A[:,joint_idx])
            b = a_des - c_ddot[:,joint_idx]
            As1.append(A.numpy())
            bs1.append(b.numpy())

        # Contact-force regularizer. Each contact's forces are weighted by
        # max(0.5 + 0.5 * tanh(gc[1]), 0.005), using gc[1] from detection.
        if nc != 0:
            weights = [np.eye(2) * max(np.tanh(gc[1]).numpy()*0.5+0.5, 0.005) for gc in cps]
            As2.append(block_diagonal_matrix_np(weights)*0.01)
            bs2.append(np.zeros(2*nc))
        # Generalized-force regularizer (0.1 on every entry).
        As3.append(block_diagonal_matrix_np([np.eye(3) * 0.1, 0.1 * np.eye(qdot_size - 3)]))
        bs3.append(np.zeros(qdot_size))

        # Velocity bounds for each active contact. With v = J qdot and
        # v_next = v + 0.01 J qddot, enforce
        #   -th  <= v_next_x <= th
        #   th_y <= v_next_y <= max(th, th_y) + 1e-6
        # th is ~0 for probabilities >= 0.85 and grows as the probability
        # drops. th_y keeps this foot's ankle at or above y = -0.03 after one frame.
        for idx, cp_name in enumerate(cp_names):
            foot_idx = 1 if cp_name[0] == 'left_foot' else 0
            cp_idx = 0 if cp_name[1] == 'heel' else 1
            J = Js[idx+1].numpy()
            v = (J @ self.qdot)
            th = -np.log(min(self.ground_contact_probabilities[self.t, foot_idx*2+cp_idx] , 0.84999) / 0.85)
            # Ankle of *this* foot: simple_FK column 3 = ankle_R, 6 = ankle_L.
            th_y = (-0.03-pos[1,3*foot_idx+3])/0.01
            Gs1.append(-0.01 * J)
            hs1.append(v - [-th, th_y])
            Gs1.append(0.01 * J)
            hs1.append(-v + [th, max(th,th_y)+1e-6])

        # Friction cone per contact: |f_x| <= 0.6 / sqrt(2) * f_y (implies f_y >= 0).
        if nc != 0:
            Gs2.append(block_diagonal_matrix_np([np.array([[np.sqrt(2), -0.6],[-np.sqrt(2), -0.6]]) for _ in range(nc)]))
            hs2.append(np.zeros(2 * nc))

        # Equations of motion, -M qddot + J^T f + tau = h, evaluated at the
        # current state; the torque/GRF/moment slots of V0 are zero. M is
        # symmetrized against round-off.
        # NOTE(review): the sign of h relative to the Kane forcing term depends
        # on equations_of_motion's conventions; confirm.
        v0 = self.get_V0()
        v0[:qdot_size] = self.q.copy()
        v0[qdot_size:2*qdot_size] = self.qdot.copy()
        M = mm_kane(*v0)
        M = (M + M.T) / 2
        h = h_kane(*v0)
        A_ = np.hstack((-M, np.vstack(Js).T, np.eye(qdot_size)))
        b_ = h.squeeze()

        # Assemble: minimize 1/2 x^T P x + q^T x with P = blockdiag(A^T A) and q = -A^T b.
        As1, bs1, As2, bs2, As3, bs3 = np.vstack(As1), np.concatenate(bs1), np.vstack(As2), np.concatenate(bs2), np.vstack(As3), np.concatenate(bs3)
        Gs1, hs1, Gs2, hs2, Gs3, hs3 = np.vstack(Gs1), np.concatenate(hs1), np.vstack(Gs2), np.concatenate(hs2), np.vstack(Gs3), np.concatenate(hs3)
        G_ = block_diagonal_matrix_np([Gs1, Gs2, Gs3])
        h_ = np.concatenate((hs1, hs2, hs3))
        P_ = block_diagonal_matrix_np([np.dot(As1.T, As1), np.dot(As2.T, As2), np.dot(As3.T, As3)])
        q_ = np.concatenate((-np.dot(As1.T, bs1), -np.dot(As2.T, bs2), -np.dot(As3.T, bs3)))

        # Solve, warm-starting from the previous solution when sizes match.
        init_ = self.last_x if len(self.last_x) == len(q_) else None
        for i,a in enumerate([P_, q_, G_, h_, A_, b_]):
            assert not np.isnan(a).any(), f"Matrix {i} contains NaN values"
            assert not np.isinf(a).any(), f"Matrix {i} contains Inf values"
        x = solve_qp(P_, q_, G_, h_, A_, b_, solver='quadprog', initvals=init_)

        if x is None or np.linalg.norm(x) > 1e5:
            print('Fallback to cvxopt')
            x = solve_qp(P_, q_, G_, h_, A_, b_, solver='cvxopt', initvals=init_)
        if x is None:
            raise RuntimeError(f'QP could not be solved at frame {self.t}')

        self.qddot = x[:qdot_size]
        grf = x[qdot_size:-qdot_size]
        torque = x[-qdot_size:]
        self.last_x = x

        # Semi-implicit Euler (the position update uses the new velocity).
        self.qdot += self.qddot * 0.01
        self.q += self.qdot * 0.01

        self.t += 1
        if False:  # debug plot of q vs q_des
            if self.t == self.t0+1:
                fig, self.axs = plt.subplots(3,3)
                self.axs = self.axs.flatten()
            for a in range(9):
                self.axs[a].plot(self.t,self.q[a], 'ro')
                self.axs[a].plot(self.t,q_des[a], 'bo')

            plt.draw()
        return self.q, grf, torque, self.qdot, self.qddot, cp_names

    def step_w_gcm(self, q, q_des, reconstructed_data, grf):
        '''
            Placeholder, not implemented: the body is empty, so it returns None.
        '''

    def convert(self, q, grf, torque, cp_names):
        """Pack the outputs of ``predict`` into estimation/reconstruction form.

        :param q: per-frame coordinate arrays (9 entries each).
        :param grf: per-frame QP contact forces, 2 entries per active contact.
        :param torque: per-frame generalized forces (not used here).
        :param cp_names: per-frame ``[foot, point]`` names matching ``grf``.
        :return: ``({'IK_data': (1, T, 27)}, recon, (1, T, 4) grf)``.
            ``recon`` maps each joint name to a (1, T, 9) tensor with x in
            channel 0, y in channel 3 and ``q[2 + joint index]`` in channel 6.
            The grf columns are right x/y (0, 1) and left x/y (2, 3), with
            heel and toe forces of the same foot summed.
        """
        estimation = np.zeros((1, len(q), 27))
        for i, q_i in enumerate(q):
            estimation[0,i,::3] = q_i
        # NOTE(review): only x gets a velocity, and it is the raw per-frame
        # difference (not divided by dt = 0.01); other velocities stay zero.
        estimation[0,1:,1] = np.diff(estimation[0,:,0])
        # FK once per frame, shared by all joints.
        fks = [simple_FK(torch.from_numpy(q_i)) for q_i in q]
        recon = {}
        for idx, key in enumerate(['pelvis', 'hip_r', 'knee_r', 'ankle_r', 'hip_l', 'knee_l', 'ankle_l']):
            recon[key] = torch.zeros(1, len(q), 9)
            for i, (q_i, fk) in enumerate(zip(q, fks)):
                recon[key][0,i,0] = fk[0,idx]
                recon[key][0,i,3] = fk[1,idx]
                recon[key][0,i,6] = torch.tensor(q_i[2+idx])

        grf_ = np.zeros((1, len(q), 4))
        for i in range(len(q)):
            for ii, c in enumerate(cp_names[i]):
                if c[0] == 'left_foot':
                    idxx, idxy = 2, 3
                else:
                    idxx, idxy = 0, 1
                try:
                    # Heel and toe of the same foot both contribute to its GRF.
                    grf_[0,i,idxx] += grf[i][2*ii]
                    grf_[0,i,idxy] += grf[i][2*ii+1]
                except IndexError:
                    print('To err is human')
                    continue
        return {'IK_data':torch.from_numpy(estimation)}, recon, torch.from_numpy(grf_)




    


def calc_M(q):
    """
     Unfinished stub for the mass matrix at coordinates q.

     The body only converts ``q`` with ``Matrix`` (presumably sympy's, reached
     via the star import from equations_of_motion; verify) and then falls off
     the end, so the function returns None. The mass matrix actually used by
     ``PIPadapter.step`` comes from ``mm_kane``.
        :param q: The state of the system
        :return: currently always None (intended: the 9x9 mass matrix for q).
    """
    q = Matrix(q)


# Symbolic model wiring. The symbols below (coordinates, speeds, torques, ...)
# as well as ``equation_of_motions`` and ``lambdify`` are not defined in this
# file; presumably they come from the star import of
# src.multibody_sim.equations_of_motion (verify there).
#
# ``all_params`` fixes the positional argument order of ``mm_kane``/``h_kane``:
# 9 coordinates + 9 speeds + 6 torques + 4 ground forces + 2 moments
# + 16 body constants = 46 entries. That matches ``PIPadapter.V0``
# (np.zeros(46)), whose first 9 / next 9 / last 16 entries are filled with
# q / qdot / body constants.
# Generalized coordinates (positions and angles).
coordinates = [
    x,
    y,
    theta_pelvis,
    r_theta_hip,
    r_theta_knee,
    r_theta_ankle,
    l_theta_hip,
    l_theta_knee,
    l_theta_ankle,
]
# Generalized speeds, in the same order as the coordinates.
speeds = [
    hip_vel_x,
    hip_vel_y,
    omega_torso,
    r_omega_hip,
    r_omega_knee,
    r_omega_ankle,
    l_omega_hip,
    l_omega_knee,
    l_omega_ankle,
]
# Joint torques (right leg first).
torques = [
    right_hip_torque,
    right_knee_torque,
    right_ankle_torque,
    left_hip_torque,
    left_knee_torque,
    left_ankle_torque,
]
# Ground reaction forces (right x/y, then left x/y).
ground_forces = [
    grf_r_x,
    grf_r_y,
    grf_l_x,
    grf_l_y,
]
# Ground contact moments (right, left).
moments = [
    m_r,
    m_l
]
# body constants params; order must match datasample['body_constants'],
# which PIPadapter.__init__ copies into the last 16 entries of V0
# (the last one, g, is overwritten with -9.81 there).
body_constants = [
    thigh_length,
    thigh_com_dist,
    thigh_mass,
    thigh_inertia,
    shank_length,
    shank_com_dist,
    shank_mass,
    shank_inertia,
    foot_length, 
    foot_com_dist,
    foot_mass,
    foot_inertia,
    torso_com_dist,
    torso_mass,
    torso_inertia,
    g,
]
# All parameters of the Kane's-method expressions, in lambdify argument order.
all_params = (
    coordinates
    + speeds
    + torques
    + ground_forces
    + moments
    + body_constants
)

# Presumably a sympy KanesMethod-like object (``.mass_matrix``, ``.forcing``
# and ``.u`` are used below); the meaning of the arguments is defined in
# equations_of_motion (confirm there).
kane = equation_of_motions.get_fr_star(None, True)

# Numeric mass matrix: mm_kane(*all_params) -> M.
mm_kane = lambdify(all_params, kane.mass_matrix) 
#h = kane.forcing - kane.mass_matrix * Matrix([qdot.diff() for qdot in kane.u])
#h_kane = lambdify(all_params, h)
from sympy import symbols, zeros
u_dot_zero = zeros(len(kane.u), 1)  # Same shape as u_dot

# Substituting u_dot = 0 into the forcing vector to isolate h
# (if the forcing vector contains no u_dot terms, this substitution is a no-op).
h = kane.forcing.subs(dict(zip(kane.u.diff(), u_dot_zero)))
# Numeric forcing term: h_kane(*all_params) -> h (column vector; step squeezes it).
h_kane = lambdify(all_params, h)
# Module-level argument vector; not referenced in this part of the file
# (PIPadapter keeps its own self.V0).
V0 = np.zeros(46)
