import numpy as np


class FTG(object):
    """Periodic foot trajectory generator for four legs.

    Each leg follows a cycle of length ``T_all = 2 * T``:

    * swing  (``0 <= phase < T``): x follows a cycloid rising from 0 to
      ``S_x`` and z lifts from 0 up to ``H_z`` (at ``T / 2``) and back to 0;
    * stance (``T <= phase < T_all``): x falls from ``S_x`` back to 0 and
      z stays at 0.

    ``step`` adds these per-leg offsets to a nominal foot pose plus a
    caller-supplied residual and clips the result to per-coordinate bounds.
    Foot vectors have 12 entries laid out ``[x0, y0, z0, x1, y1, z1, ...]``.
    Legs 0/1 have positive nominal x and legs 2/3 negative x, so they are
    presumably front and hind legs respectively (TODO confirm ordering
    against the robot model). Units are presumably meters and seconds.
    """

    def __init__(self):
        # Not read anywhere in this class; kept for external users.
        self.N = 6
        # Duration of one phase; a full cycle is swing (T) + stance (T).
        self.T = 0.3
        self.T_all = 2 * self.T

        # Step length in x, constant y offset, and swing height in z.
        self.S_x = 0.14
        self.S_y = 0.
        self.H_z = 0.18

        # Multiplier on time when computing each leg's phase in ``step``.
        self.f_0 = np.array([1.1])

        # Initial phase per leg: legs 0 and 3 start at swing onset while
        # legs 1 and 2 start at stance onset (offset by T).
        self.phi_0 = np.array([0, self.T, self.T, 0])
        # Copy so an in-place edit of ``phi`` cannot alter ``phi_0``.
        self.phi = self.phi_0.copy()

        # Not read anywhere in this class; kept for external users.
        self.desired_turn_vel = 0.2
        self.beta = np.array([0, 1, 2, 3, 4, 5, 6]) / (self.N + 1)

        # Nominal foot positions the trajectory offsets are added to.
        self.target_foot_pos = np.array([0.16840759, -0.13204999, -0.32405686,
                                         0.16840759, 0.13204999, -0.32405686,
                                         -0.19759241, -0.13204999, -0.32405686,
                                         -0.19759241, 0.13204999, -0.32405686])
        # Copy so an in-place edit of one array cannot change the other.
        self.foot_pos_last = self.target_foot_pos.copy()

        # Per-coordinate clip bounds applied to the output of ``step``.
        self.UPPER_BOUND = [0.35, -0.0, -0.1,
                            0.35, 0.2, -0.1,
                            -0.05, -0.0, -0.1,
                            -0.05, 0.2, -0.1]
        self.LOWER_BOUND = [0.05, -0.2, -0.35,
                            0.05, 0.05, -0.35,
                            -0.3, -0.2, -0.35,
                            -0.3, 0.05, -0.35]

    def x_trajectory(self, t, S_x, isSwing=True):
        """Return the x offset of a foot at phase ``t``.

        Args:
            t: Phase time; ``[0, T)`` is swing and ``[T, T_all)`` is stance.
            S_x: Step length.
            isSwing: If True use the swing profile (cycloid from 0 at
                ``t = 0`` to ``S_x`` at ``t = T``); otherwise the stance
                profile (from ``S_x`` at ``t = T`` back to 0 at ``T_all``).
        """
        if isSwing:
            x = S_x * (t / self.T - 1 / (2 * np.pi) * np.sin((2 * np.pi * t / self.T)))
        else:
            x = S_x * ((self.T_all - t) / self.T + 1 / (2 * np.pi) * np.sin((2 * np.pi * t / self.T)))
        return x

    def z_trajectory(self, t, H_z, isSwing=True):
        """Return the z offset (foot lift) at phase ``t``.

        During swing, z rises from 0 at ``t = 0`` to ``H_z`` at ``t = T / 2``
        and returns to 0 at ``t = T``; during stance it is 0.
        """
        if isSwing:
            f_E = t / self.T - 1 / (4 * np.pi) * np.sin(4 * np.pi * t / self.T)
            # Equals 2*f_E*H_z before T/2 and (2 - 2*f_E)*H_z after it.
            z = H_z * (np.sign(self.T / 2 - t) * (2 * f_E - 1) + 1)
        else:
            z = 0
        return z

    def reset(self, DesiredDirection, DesiredTurningDirection):
        """Set the phase-rate multiplier ``f_0`` from the motion command.

        If both commands are entirely zero, ``f_0`` becomes 0 so the phase
        computed in ``step`` no longer advances with time; otherwise ``f_0``
        becomes 1.0. Either argument may be any array-like.
        """
        # ``np.any`` returns np.bool_, so an identity test against ``False``
        # would never match; use truthiness instead.
        if not np.any(DesiredDirection) and not np.any(DesiredTurningDirection):
            self.f_0 = np.array([0])
        else:
            # NOTE(review): __init__ uses 1.1 while reset uses 1.0 — confirm.
            self.f_0 = np.array([1.])

    def step(self, t,
             frequency_offset,
             target_foot_pos_res,
             ):
        """Compute the clipped 12-element target foot position at time ``t``.

        Args:
            t: Time; each leg's phase is ``(t * f_0 + phi_0 + frequency_offset)
                % T_all`` and is stored in ``self.phi``.
            frequency_offset: Scalar or per-leg (length 4) value added to the
                phase before wrapping. Despite its name it shifts the phase.
            target_foot_pos_res: 12 values (any array-like) added on top of
                the nominal pose ``self.target_foot_pos``.

        Returns:
            ndarray of shape (12,) clipped to ``[LOWER_BOUND, UPPER_BOUND]``.
        """
        self.phi = (t * self.f_0 + self.phi_0 + frequency_offset) % self.T_all

        n_legs = len(self.phi)
        x_pos_list = np.zeros(shape=n_legs)
        z_pos_list = np.zeros(shape=n_legs)
        for j, phase in enumerate(self.phi):
            # The modulo can round up to exactly T_all; treating every
            # phase >= T as stance leaves no gap where the offsets stay 0.
            is_swing = phase < self.T
            x_pos_list[j] = self.x_trajectory(phase, self.S_x, isSwing=is_swing)
            z_pos_list[j] = self.z_trajectory(phase, self.H_z, isSwing=is_swing)
            # Legs 2 and 3 use an x offset in [-S_x, 0] instead of [0, S_x].
            if j in (2, 3):
                x_pos_list[j] -= self.S_x

        offsets = np.column_stack((x_pos_list, np.full(n_legs, self.S_y), z_pos_list))
        nominal = self.target_foot_pos.reshape(-1, 3)
        residual = np.asarray(target_foot_pos_res, dtype=float).reshape(-1, 3)
        target_foot_pos = (offsets + nominal + residual).reshape(-1)
        return np.clip(target_foot_pos, a_min=self.LOWER_BOUND, a_max=self.UPPER_BOUND)


if __name__ == '__main__':
    # Demo: sample the generator over time and plot each leg's x/z
    # trajectory, first against sample index and then as an x-z path.
    import matplotlib.pyplot as plt

    ftg = FTG()
    T = 200  # number of samples
    dt = 0.001
    N = 5  # time advances by N * dt per sample

    # ``step`` adds ``target_foot_pos_res`` on top of the nominal pose it
    # already stores, so the demo passes a zero residual (passing the
    # nominal pose here would double it and saturate the clip bounds).
    residual = np.zeros(shape=12)
    phase_offset = np.zeros(shape=4)

    x_pos_list = np.zeros(shape=(T, 4))
    z_pos_list = np.zeros(shape=(T, 4))
    for i in range(T):
        target_foot_position = np.squeeze(ftg.step(t=i * N * dt,
                                                   frequency_offset=phase_offset,
                                                   target_foot_pos_res=residual))
        x_pos_list[i, :] = target_foot_position[0::3]
        z_pos_list[i, :] = target_foot_position[2::3]

    # x and z against sample index, one subplot per leg.
    plt.figure()
    for leg in range(4):
        plt.subplot(221 + leg)
        plt.plot(x_pos_list[:, leg])
        plt.plot(z_pos_list[:, leg])
    plt.show()

    # z against x: the path each foot traces in the x-z plane.
    plt.figure()
    for leg in range(4):
        plt.subplot(221 + leg)
        plt.plot(x_pos_list[:, leg], z_pos_list[:, leg])
    plt.show()
