import gym
import numpy as np
import warmup
from matplotlib import pyplot as plt
from pudb import set_trace


def com_main_body(env):
    """Return the whole-body centre of mass of a MuJoCo env in world coordinates.

    Computes the mass-weighted mean of every body's own centre of mass
    (``data.xipos``), skipping body 0 (the world body).

    Args:
        env: gym env whose ``unwrapped`` exposes MuJoCo ``model`` and ``data``.

    Returns:
        np.ndarray of shape (3,): the (x, y, z) centre-of-mass position.
    """
    data = env.unwrapped.data
    model = env.unwrapped.model
    # Weight by each body's own mass. body_subtreemass includes all descendant
    # bodies, so using it as a weight counts children repeatedly. Summing it over
    # all bodies (including the world body) also over-states the total mass.
    masses = np.asarray(model.body_mass)[1:]
    # xipos is each body's COM in world frame; body_xpos is only its frame origin.
    positions = np.asarray(data.xipos)[1:]
    return masses @ positions / masses.sum()


def get_reward(env, com_before):
    """Reward equal to exp of the change in COM z-coordinate since ``com_before``.

    Args:
        env: env passed through to ``com_main_body`` to get the current COM.
        com_before: COM position sampled before the step; index 2 is compared.

    Returns:
        ``np.exp(z_now - z_before)``; above 1 when the COM rose, below 1 when it fell.
    """
    z_now = com_main_body(env)[2]
    z_before = com_before[2]
    delta_z = z_now - z_before
    return np.exp(delta_z)


# Roll out random actions in a zero-gravity MuJoCo env, record the whole-body COM,
# the pelvis position and a COM-height reward each step, then plot them to a PDF.
# Alternative environment used in earlier experiments: ['Hopper-v3']
for env_str in ["leg2dof4m-v0"]:
    # One tuple per step: (COM position, pelvis position, reward).
    data_collect = []
    env = gym.make(env_str)
    for _ in range(1):
        ep_steps = 0
        env.reset()
        # Start the episode from rest.
        env.unwrapped.data.qvel[:] = 0.0
        # Disable gravity once per episode instead of on every step.
        # NOTE(review): assumes the env's step() never restores gravity; if it
        # does, move this back inside the step loop.
        env.unwrapped.model.opt.gravity[:] = 0.0
        while True:
            com_before = com_main_body(env)
            # Resample a random action every 100 steps (including step 0).
            if ep_steps % 100 == 0:
                # NOTE(review): randint's upper bound is exclusive, so entries are
                # drawn from {-2, -1, 0, 1}; confirm 2 was not meant to be included.
                action = np.random.randint(-2, 2, size=env.action_space.shape)
            _, _, done, _ = env.step(action)
            # Replace the env's own reward with the COM-height reward.
            reward = get_reward(env, com_before)
            data_collect.append(
                (
                    com_main_body(env).copy(),
                    env.unwrapped.data.get_body_xpos("pelvis").copy(),
                    reward,
                )
            )
            # Flag termination if any active contact has a geom2 other than geom 4.
            # NOTE(review): geom id 4 is presumably the one allowed contact geom;
            # confirm against the model XML.
            contacts = env.unwrapped.data.contact
            if any(contacts[i].geom2 != 4 for i in range(env.unwrapped.data.ncon)):
                done = True
            # NOTE(review): `done` is reassigned by env.step() every iteration, so
            # only the final step's value reaches the check below.
            # env.render()
            if ep_steps >= 100000:
                if done:
                    print("did it!")
                break
            ep_steps += 1
    env.close()

    # Drop the first 10 samples (presumably to skip the start-up transient).
    data_collect = data_collect[10:]
    fig, axs = plt.subplots(3, 1)
    # Top: trajectory in the y-z plane.
    axs[0].plot(
        [s[0][1] for s in data_collect], [s[0][2] for s in data_collect], label="com"
    )
    axs[0].plot(
        [s[1][1] for s in data_collect], [s[1][2] for s in data_collect], label="pelvis"
    )
    axs[0].set_xlabel("y")
    axs[0].set_ylabel("z")
    axs[0].legend()
    # Middle: z-coordinate over steps.
    axs[1].plot([s[0][2] for s in data_collect], label="com")
    axs[1].plot([s[1][2] for s in data_collect], label="pelvis")
    axs[1].set_xlabel("time")
    axs[1].set_ylabel("z")
    # Bottom: per-step reward.
    axs[2].plot([s[2] for s in data_collect])
    axs[2].set_ylabel("reward")
    plt.savefig(
        "com_plot_noframeskip_leg_motor_with_tendons_all_openai_frameskip_5.pdf"
    )
    # plt.show()
    plt.close(fig)
