import numpy as np
import quaternion
from scipy.spatial.transform import Rotation as R
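# SciPy's Rotation.as_quat() returns scalar-last quaternions [x, y, z, w];
# the node features built here store quaternions scalar-first [w, x, y, z]
# (the order used by numpy-quaternion). SciPy reference:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html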

# Graph-level "globals" feature; presumably gravitational acceleration along -z.
graph_globals = [0.0, 0.0, -9.81]
# Per-axis position normalisation factors returned alongside every graph.
position_norm = [1.0, 1.0, 1.0]
# Stiffness normalisers: index 0 divides Young's modulus (init_bend_rod_graph),
# index 1 divides the torsion modulus (init_helix_spring_graph).
max_E_G = [1.0e6, 1.0e6]


def init_bend_rod_graph(num_nodes, length, angle, Young_M,  with_alpha=True):
    """Build a graph describing a straight rod split into ``num_nodes`` nodes.

    The rod starts at the origin along the local +z axis. That axis is
    rotated 90 deg about y, then ``angle`` degrees about y, then 45 deg
    about z. Every node shares that orientation. Consecutive nodes are
    joined by one directed edge ``i -> i + 1``.

    Args:
        num_nodes: number of nodes; must be >= 1.
        length: total rod length; each node covers ``length / num_nodes``.
        angle: extra tilt about the y axis, in degrees.
        Young_M: Young's modulus; stored on edges divided by ``max_E_G[0]``.
        with_alpha: if True the last node feature is the normalised index
            ``i / (num_nodes - 1)`` (0.0 for a single node); otherwise it is
            0 for the first node and 1 for all others.

    Returns:
        ``(position_norm, max_E_G, split_list, graph)``. The first two are
        fresh copies of the module constants. ``split_list`` is
        ``[num_nodes]``. ``graph`` is a dict with keys ``globals``,
        ``nodes``, ``edges``, ``senders`` and ``receivers``.

        Each node is a 17-item list:

        - position (3);
        - zero vector ``v`` (3);
        - quaternion [w, x, y, z] (4);
        - zero vector ``w`` (3);
        - ``[radius, density, segment_length, alpha_or_flag]``.

        Each edge is ``[0, 0, 0, Young_M / max_E_G[0], 0.24]``. The meaning
        of 0.24 is not documented here.

    Raises:
        ValueError: if ``num_nodes < 1``.
    """
    if num_nodes < 1:
        raise ValueError(f"num_nodes must be >= 1, got {num_nodes}")

    radius = 0.05
    density = 1000.0
    v = [0, 0, 0]
    w = [0, 0, 0]
    li = length / num_nodes

    y_axis = np.array([0.0, 1.0, 0.0])
    r_0 = R.from_rotvec(np.pi / 2 * y_axis)
    r_angle = R.from_rotvec(np.pi * angle / 180.0 * y_axis)
    r_1 = R.from_rotvec(np.pi / 4 * np.array([0.0, 0.0, 1.0]))
    # SciPy composes right-to-left: r_0 is applied first, r_1 last.
    r = r_1 * (r_angle * r_0)
    # SciPy quaternions are scalar-last (x, y, z, w); roll to (w, x, y, z).
    quat_p = np.roll(r.as_quat(), 1).tolist()

    graph = {
        # Copy so callers mutating graph["globals"] cannot alter the module constant.
        "globals": [list(graph_globals)],
        "nodes": [],
        "edges": [],
        "senders": [],
        "receivers": [],
    }

    for i in range(num_nodes):
        position = r.apply([0, 0, i * li]).tolist()
        if with_alpha:
            # A single node would otherwise divide 0 by 0.0.
            last = i / (num_nodes - 1.0) if num_nodes > 1 else 0.0
        else:
            last = 0 if i == 0 else 1
        graph["nodes"].append(position + v + quat_p + w + [radius, density, li, last])

    stiffness = Young_M / max_E_G[0]
    for i in range(num_nodes - 1):
        graph["edges"].append([0.0, 0.0, 0.0, stiffness, 0.24])
        graph["senders"].append(i)
        graph["receivers"].append(i + 1)

    split_list = [num_nodes]
    # Return copies so callers cannot mutate the shared module-level lists.
    return list(position_norm), list(max_E_G), split_list, graph


def _quat_between(from_vec, to_vec):
    """Return the unit quaternion [w, x, y, z] that minimally rotates from_vec onto to_vec."""
    from_norm = np.linalg.norm(from_vec, ord=2)
    to_norm = np.linalg.norm(to_vec, ord=2)
    # With A the angle between the vectors:
    # 2*cos(A/2) * (cos(A/2), sin(A/2)*axis) = (1 + cos A, sin A * axis).
    axis_sin_a = np.cross(from_vec, to_vec) / from_norm / to_norm
    cos_a = np.dot(from_vec, to_vec) / from_norm / to_norm
    temp = np.array([1 + cos_a, axis_sin_a[0], axis_sin_a[1], axis_sin_a[2]])
    temp_norm = np.linalg.norm(temp)
    if temp_norm < 1e-8:
        # Antiparallel vectors: the half-angle construction degenerates to 0/0.
        # A 180-degree turn about any axis perpendicular to from_vec is valid.
        helper = np.array([1.0, 0.0, 0.0])
        if abs(np.dot(helper, from_vec)) > 0.9 * from_norm:
            helper = np.array([0.0, 1.0, 0.0])
        axis = np.cross(from_vec, helper)
        axis = axis / np.linalg.norm(axis)
        return np.array([0.0, axis[0], axis[1], axis[2]])
    return temp / temp_norm


def init_helix_spring_graph(num_nodes, HR, HH, HW, Torsion_M, with_alpha=True):
    """Build a graph describing a helical spring split into ``num_nodes`` segments.

    The helix has radius ``HR``, descends from z=0 to z=-``HH`` and sweeps
    ``HW`` full turns (``HW * 2 * pi`` radians). Node ``i`` sits at the
    midpoint of the chord between helix samples ``i`` and ``i + 1``.

    Node 0's quaternion is the minimal rotation taking +z onto its chord.
    Each later node composes the minimal rotation from the previous chord to
    its own chord. As a result each node's local +z follows its chord.

    Args:
        num_nodes: number of nodes; must be >= 1.
        HR: helix radius.
        HH: total height drop of the helix.
        HW: number of windings.
        Torsion_M: torsion modulus; stored on edges divided by ``max_E_G[1]``.
        with_alpha: if True the last node feature is ``i / (num_nodes - 1)``
            (0.0 for a single node); otherwise 0 for the first node, 1 after.

    Returns:
        ``(position_norm, max_E_G, split_list, graph)``. The first two are
        fresh copies of the module constants. ``split_list`` is
        ``[num_nodes]``.

        Each node is a 17-item list:

        - position (3);
        - zero vector ``v`` (3);
        - quaternion [w, x, y, z] (4);
        - zero vector ``w`` (3);
        - ``[radius, density, chord_length, alpha_or_flag]``.

        Edge ``i -> i + 1`` holds ``2 / l * Im(conj(q_i) * q_{i+1})``, where
        ``l`` is the mean of the two chord lengths; this is presumably the
        rest Darboux vector. The edge then holds the constant ``0.1`` and
        ``Torsion_M / max_E_G[1]``.

    Raises:
        ValueError: if ``num_nodes < 1`` or a chord has zero length
            (e.g. ``HR == 0`` and ``HH == 0``).
    """
    if num_nodes < 1:
        raise ValueError(f"num_nodes must be >= 1, got {num_nodes}")

    radius = 0.05
    density = 1000.0
    v = [0, 0, 0]
    w = [0, 0, 0]
    phi_total = HW * 2 * np.pi
    h = - HH
    d_phi = phi_total / num_nodes
    d_z = h / num_nodes

    graph = {
        # Copy so callers mutating graph["globals"] cannot alter the module constant.
        "globals": [list(graph_globals)],
        "nodes": [],
        "edges": [],
        "senders": [],
        "receivers": [],
    }

    from_vec = np.array([0, 0, 1])
    for i in range(num_nodes):
        p1 = np.array([HR * np.cos(d_phi * i), HR * np.sin(d_phi * i), d_z * i])
        p2 = np.array([HR * np.cos(d_phi * (i + 1)), HR * np.sin(d_phi * (i + 1)), d_z * (i + 1)])
        to_vec = p2 - p1
        li = np.linalg.norm(to_vec, ord=2)
        if li == 0.0:
            raise ValueError("helix chord has zero length; check HR, HH and HW")

        position = 0.5 * (p1 + p2)
        quat = _quat_between(from_vec, to_vec)
        from_vec = to_vec

        if i == 0:
            quat_p = quat
        else:
            # World-frame composition: rotate the previous orientation by dq.
            dq = quaternion.as_quat_array(quat)
            q_prev = quaternion.as_quat_array(graph["nodes"][i - 1][6:10])
            quat_p = quaternion.as_float_array(dq * q_prev)

        if with_alpha:
            # A single node would otherwise divide 0 by 0.0.
            last = i / (num_nodes - 1.0) if num_nodes > 1 else 0.0
        else:
            last = 0 if i == 0 else 1
        graph["nodes"].append(
            position.tolist() + v + quat_p.tolist() + w + [radius, density, li, last]
        )

    torsion = Torsion_M / max_E_G[1]
    for i in range(num_nodes - 1):
        node_a = graph["nodes"][i]
        node_b = graph["nodes"][i + 1]
        seg_len = 0.5 * (node_a[15] + node_b[15])

        q_a = quaternion.as_quat_array(node_a[6:10])
        q_b = quaternion.as_quat_array(node_b[6:10])
        rest_d = 2.0 / seg_len * quaternion.as_float_array(q_a.conjugate() * q_b)[..., 1:4]

        graph["edges"].append(rest_d.tolist() + [0.1, torsion])
        graph["senders"].append(i)
        graph["receivers"].append(i + 1)

    split_list = [num_nodes]
    # Return copies so callers cannot mutate the shared module-level lists.
    return list(position_norm), list(max_E_G), split_list, graph


if __name__ == "__main__":
    # Smoke run: builds a graph and discards the result (nothing is printed or saved).
    # Uncomment the next line to exercise the bend-rod builder instead.
    #init_bend_rod_graph(30, 40, 0, 1.0e6)
    init_helix_spring_graph(30, 2, 2, 2, 1.0e5)

