import casadi as ca
import torch
import numpy as np

def evaluate_casadi_expression(mx):
    """Numerically evaluate a CasADi expression and return it as a float64 array.

    The expression is wrapped in a throw-away ``ca.Function`` that takes a
    single scalar placeholder input (value 0). ``allow_free`` is passed so the
    Function can be built even though ``mx`` does not depend on that input.

    Args:
        mx: CasADi expression to evaluate.

    Returns:
        ``numpy.ndarray`` of dtype float64 holding the dense value of ``mx``.
    """
    placeholder = ca.MX.sym('dummy')
    evaluator = ca.Function('eval', [placeholder], [mx], {"allow_free": True})
    dense_value = evaluator(0).full()
    return np.array(dense_value, dtype=np.float64)

def project(input_vector, bound, *, verbose=True):
    """Project a motion feature vector onto a per-frame band using IPOPT.

    Solves

        min_X  ||X - input_vector||^2
        s.t.   -bound <= min_j P(X)[i * 22 + j, 2] <= bound   for i = 0..59,

    where ``P = convert_to_xyz(., 22)`` maps the flat feature vector to
    ``(60 * 22, 3)`` joint positions. In words, the smallest column-2
    coordinate over the 22 joints of every frame must lie in
    ``[-bound, bound]``.

    NOTE(review): column 2 is constrained here, while
    ``recover_root_rot_pos`` writes the root-height feature ``data[:, 3]``
    into column 1. Confirm which axis is intended as the vertical one.

    Args:
        input_vector: CasADi column vector of features in the layout
            ``convert_to_xyz`` expects (60 frames x 263 features). It is
            both the projection target and the solver's initial guess.
        bound: Half-width of the band allowed for each frame's minimum.
        verbose: Keyword-only. When True (the default, as before), print the
            pre/post-solve diagnostics. When False, nothing is printed by
            this function and the diagnostic forward passes are skipped.
            IPOPT's own log is still governed by ``ipopt.print_level``.

    Returns:
        ``sol['x']`` when IPOPT reports success, otherwise ``input_vector``
        unchanged.
    """
    n_frames, n_joints = 60, 22  # frame/joint layout produced by convert_to_xyz

    def column2_by_frame(flat_vector):
        """Evaluate convert_to_xyz numerically; return column 2 as (frames, joints)."""
        positions = evaluate_casadi_expression(convert_to_xyz(flat_vector, n_joints))
        return positions.reshape((n_frames, n_joints, 3))[:, :, 2]

    if verbose:
        print(f"Current bound: {bound}")
        z_before = column2_by_frame(input_vector)
        print("Normal inside: ", z_before.min(), z_before.max())
        frame_min_before = z_before.min(axis=1)
        print("Constraint inside: ", frame_min_before.min(), frame_min_before.max())

    X = ca.MX.sym('X', input_vector.numel())
    objective = ca.sumsqr(ca.reshape(input_vector, (-1, 1)) - ca.reshape(X, (-1, 1)))

    # One constraint per frame: the minimum column-2 value over that frame's joints.
    positions = convert_to_xyz(X, n_joints)
    frame_min_z = [
        ca.mmin(positions[i * n_joints:(i + 1) * n_joints, 2])
        for i in range(n_frames)
    ]

    nlp = {'x': X, 'f': objective, 'g': ca.vertcat(*frame_min_z)}
    opts = {
        "ipopt.max_iter": 150,
        "ipopt.print_level": 5,
        "print_time": 1,
        "ipopt.constr_viol_tol": 1e-3,
        "ipopt.tol": 1e-3,
        "ipopt.mu_strategy": 'adaptive',
        # ca.mmin is non-smooth; finite-difference Jacobian + L-BFGS Hessian.
        "ipopt.jacobian_approximation": 'finite-difference-values',
        "ipopt.hessian_approximation": 'limited-memory',
    }
    solver = ca.nlpsol('solver', 'ipopt', nlp, opts)

    # Warm-start from the input itself; X is unbounded, only g is constrained.
    sol = solver(x0=input_vector, lbx=-ca.inf, ubx=ca.inf, lbg=-bound, ubg=bound)

    stats = solver.stats()
    if not stats['success']:
        if verbose:
            print(f"IPOPT did not succeed ({stats.get('return_status', 'unknown')}); "
                  "returning input unchanged.")
        return input_vector

    if verbose:
        z_after = column2_by_frame(sol['x'])
        frame_min_after = z_after.min(axis=1)
        # 0.1 slack on top of the IPOPT tolerance for this sanity report.
        high_ok = bool(np.all(frame_min_after < bound + 0.1))
        low_ok = bool(np.all(frame_min_after > -bound - 0.1))
        print(f"Bool Evaluation | High: {high_ok} | Low: {low_ok} ")
        print("Complete inside: ", z_after.min(), z_after.max())

    return sol['x']


###########################
#     HELPER FUNCTIONS    #
###########################

def repeat_mx(mx, n_repeats):
    """Repeat every row of ``mx`` ``n_repeats`` times, keeping the copies adjacent.

    Row ``i`` of the input becomes output rows ``i * n_repeats`` through
    ``(i + 1) * n_repeats - 1``. This is the row-wise analogue of
    ``np.repeat(a, n_repeats, axis=0)``. The result has shape
    ``(mx.shape[0] * n_repeats, mx.shape[1])``.
    """
    # kron(A, ones(n, 1)) replaces each entry A[i, j] with an n-by-1 column of
    # A[i, j], i.e. stacks n copies of every row of A in place.
    ones_column = ca.MX.ones(n_repeats, 1)
    return ca.kron(mx, ones_column)

def convert_to_xyz(input_data, joints_num, *, n_frames=60):
    """Convert motion features to joint positions symbolically (CasADi).

    Symbolic counterpart of ``convert_to_xyz_torch`` for a single clip.

    Args:
        input_data: CasADi DM/MX holding ``n_frames`` frames of features. It
            is reshaped to ``(n_frames, numel // n_frames)`` with CasADi's
            ``reshape``, which is column-major. A flat vector must therefore
            list all frames of feature 0, then all frames of feature 1, and
            so on.
        joints_num: Number of joints including the root. Features
            ``4 .. 4 + 3 * (joints_num - 1)`` of each frame hold the xyz of
            the non-root joints.
        n_frames: Keyword-only number of frames. Defaults to 60, the value
            that was previously hard-coded.

    Returns:
        MX of shape ``(n_frames * joints_num, 3)``. Row
        ``i * joints_num + j`` is joint ``j`` of frame ``i``, and ``j == 0``
        is the root position from ``recover_root_rot_pos``.
    """
    n_feats = input_data.numel() // n_frames
    data = input_data.reshape((n_frames, n_feats))
    n_children = joints_num - 1

    r_rot_quat, r_pos = recover_root_rot_pos(data)

    # Stack every frame's child-joint xyz triples into (n_frames * n_children, 3).
    local = data[:, 4:n_children * 3 + 4]
    positions = ca.MX.zeros(n_frames * n_children, 3)
    for i in range(n_frames):
        # Column-major reshape to (3, n) then transpose: row j = features 3j..3j+2.
        positions[i * n_children:(i + 1) * n_children, :] = \
            local[i, :].reshape((3, n_children)).T

    # Rotate each child joint by its frame's inverse root quaternion.
    repeated_quat = repeat_mx(qinv(r_rot_quat), n_children)
    rotated = qrot(repeated_quat, positions)

    # Shift by the root position in columns 0 and 2 only (column 1 untouched).
    root_offset = repeat_mx(r_pos, n_children)
    rotated[:, 0] += root_offset[:, 0]
    rotated[:, 2] += root_offset[:, 2]

    # Interleave per frame: root row first, then its child joints.
    blocks = []
    for i in range(n_frames):
        blocks.append(r_pos[i, :])
        blocks.append(rotated[i * n_children:(i + 1) * n_children, :])
    return ca.vertcat(*blocks)


# Helper functions
def qrot(q, v):
    """Rotate each row of ``v`` by the quaternion in the same row of ``q`` (CasADi).

    Args:
        q: ``(N, 4)`` quaternions, real part in column 0 and vector part in
            columns 1..3. The formula assumes unit quaternions.
        v: ``(N, 3)`` vectors.

    Returns:
        ``(N, 3)`` rotated vectors, ``v + 2 * (w * (u x v) + u x (u x v))``.
    """
    q_vec = q[:, 1:]
    q_real = q[:, :1]
    # dim=2 forces the cross product along columns. With the default (-1),
    # CasADi picks rows whenever the operand has exactly 3 rows, which would
    # silently mix different vectors together.
    uv = ca.cross(q_vec, v, 2)
    uuv = ca.cross(q_vec, uv, 2)
    # Expand the (N, 1) real part to (N, 3) explicitly rather than relying on
    # implicit broadcasting in CasADi's elementwise product.
    return v + 2 * (ca.repmat(q_real, 1, 3) * uv + uuv)

def qinv(q):
    """Return the conjugate of every quaternion row of ``q`` (CasADi).

    Column 0 (real part) is kept, and all remaining columns are negated.
    """
    real_part = q[:, :1]
    vector_part = q[:, 1:]
    return ca.horzcat(real_part, -vector_part)





##########################
#     TORCH FUNCTIONS    #
##########################


def convert_to_xyz_torch(data, joints_num):
    """Recover joint positions from motion features (PyTorch).

    Args:
        data: Tensor of shape ``(..., frames, features)``. Features 0..3 are
            read by ``recover_root_rot_pos_torch``. Features
            ``4 .. 4 + 3 * (joints_num - 1)`` hold the xyz of the non-root
            joints.
        joints_num: Number of joints including the root.

    Returns:
        Tensor of shape ``(..., frames, joints_num, 3)``. Index 0 on the
        joint axis is the root position.
    """
    root_quat, root_pos = recover_root_rot_pos_torch(data)

    first = 4
    last = first + 3 * (joints_num - 1)
    local = data[..., first:last]
    local = local.view(local.shape[:-1] + (-1, 3))

    # Broadcast each frame's inverse root quaternion across all of its joints.
    inverse_quat = qinv_torch(root_quat[..., None, :])
    inverse_quat = inverse_quat.expand(local.shape[:-1] + (4,))
    joints = qrot_torch(inverse_quat, local)

    # Shift by the root's x and z; the y component is not offset.
    joints[..., 0] += root_pos[..., 0:1]
    joints[..., 2] += root_pos[..., 2:3]

    root_row = root_pos.unsqueeze(-2)
    return torch.cat([root_row, joints], dim=-2)

def recover_root_rot_pos_torch(data):
    """Integrate root rotation and position from motion features (PyTorch).

    Args:
        data: Tensor of shape ``(..., frames, features)``.
            - Feature 0 is shifted by one frame and cumulatively summed into
              an angle ``a``.
            - Features 1..2 are shifted by one frame, rotated by the inverse
              quaternion, and cumulatively summed into the root's x and z.
            - Feature 3 is copied into the root's y.

    Returns:
        ``(r_rot_quat, r_pos)`` with shapes ``(..., frames, 4)`` and
        ``(..., frames, 3)``. The quaternion is ``(cos a, 0, sin a, 0)``.
        Both outputs share ``data``'s dtype and device. Previously they were
        always float32, which silently downcast float64 input.
    """
    rot_vel = data[..., 0]
    # Angle at frame t is the sum of feature 0 over frames 0..t-1.
    r_rot_ang = torch.zeros_like(rot_vel)  # already on data's device/dtype
    r_rot_ang[..., 1:] = rot_vel[..., :-1]
    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)

    # new_zeros inherits data's dtype and device (torch.zeros would be float32).
    r_rot_quat = data.new_zeros(data.shape[:-1] + (4,))
    r_rot_quat[..., 0] = torch.cos(r_rot_ang)
    r_rot_quat[..., 2] = torch.sin(r_rot_ang)

    r_pos = data.new_zeros(data.shape[:-1] + (3,))
    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
    # Rotate the per-frame planar displacement, then integrate over frames.
    r_pos = qrot_torch(qinv_torch(r_rot_quat), r_pos)
    r_pos = torch.cumsum(r_pos, dim=-2)

    r_pos[..., 1] = data[..., 3]
    return r_rot_quat, r_pos

def qrot_torch(q, v):
    """Rotate vectors ``v`` by quaternions ``q`` (real part first).

    Args:
        q: Tensor of shape ``(*, 4)``. The formula assumes unit quaternions.
        v: Tensor of shape ``(*, 3)`` with the same leading shape as ``q``.

    Returns:
        Tensor shaped like ``v``, equal to ``v + 2 * (w * (u x v) + u x (u x v))``.
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]
    out_shape = list(v.shape)

    flat_q = q.reshape(-1, 4)
    flat_v = v.reshape(-1, 3)
    real = flat_q[:, :1]
    axis = flat_q[:, 1:]

    once = torch.cross(axis, flat_v, dim=1)
    twice = torch.cross(axis, once, dim=1)
    rotated = flat_v + 2 * (real * once + twice)
    return rotated.view(out_shape)

def qinv_torch(q):
    """Return the quaternion conjugate: keep the real part, negate the vector part.

    Args:
        q: Tensor of shape ``(*, 4)`` with the real part in the first slot.

    Returns:
        Tensor of the same shape, dtype and device as ``q``.
    """
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    # Broadcast a per-component sign vector over the leading dimensions.
    signs = q.new_tensor([1, -1, -1, -1])
    return q * signs

def recover_root_rot_pos(data):
    """Integrate root rotation and position from a ``(frames, features)`` matrix (CasADi).

    Symbolic counterpart of ``recover_root_rot_pos_torch`` for a single clip:
    - Column 0 is shifted by one frame and cumulatively summed into an angle
      ``a``, giving quaternions ``(cos a, 0, sin a, 0)``.
    - Columns 1..2 are shifted by one frame, rotated by the inverse
      quaternion, and cumulatively summed into the root's x/z.
    - Column 3 is copied into the root's y.

    Returns:
        ``(quat, root)`` as MX of shapes ``(frames, 4)`` and ``(frames, 3)``.
    """
    n_frames = data.shape[0]

    # Angle: feature 0 delayed by one frame, then integrated (row vector).
    angle_rate = data[:, 0]
    angle = ca.MX.zeros(1, n_frames)
    angle[:, 1:] = angle_rate[:-1]
    angle = ca.cumsum(angle)

    quat = ca.MX.zeros(n_frames, 4)
    quat[:, 0] = ca.cos(angle)
    quat[:, 2] = ca.sin(angle)

    # Planar displacement: delayed by one frame, rotated, integrated down the frames.
    root = ca.MX.zeros(n_frames, 3)
    root[1:, [0, 2]] = data[:-1, 1:3]
    root = ca.cumsum(qrot(qinv(quat), root))
    root[:, 1] = data[:, 3]

    return quat, root


# sample = torch.rand(1, 60, 263) * 10000
# for i, s in enumerate(sample):
#     pytorch_output = convert_to_xyz_torch(s.unsqueeze(0), 22)#.squeeze(0)
#     casadi_output = convert_to_xyz(ca.DM(s.numpy().flatten()[:, np.newaxis]), 22)
#     casadi_output_tensor = torch.tensor(evaluate_casadi_expression(casadi_output).reshape(60, 22, 3)).float()
#     print(i, torch.norm(casadi_output_tensor - pytorch_output, p=2)/pytorch_output.norm(p=2))
