import random
import numpy as np
from gymnasium import spaces

from src.envs.ant import AntEnv
from src.utils.misc import set_seed


class AntDirEnv(AntEnv):
    """Ant locomotion task that rewards moving in a goal direction.

    The task is a single angle ``self._goal`` (radians).  The forward reward
    is the torso's planar velocity projected onto the direction
    ``(cos(goal), sin(goal))``, minus control and contact costs, plus a
    constant survive bonus.
    """

    def __init__(self, goal_pos=None):
        """Create the environment.

        Args:
            goal_pos: goal angle in radians, as a scalar or a one-element
                sequence/array (only the first element is used).  When
                ``None``, one angle is sampled with :meth:`sample_tasks`.
        """
        if goal_pos is None:
            goal_pos = np.array(self.sample_tasks(1))
        self.set_task(goal_pos)
        self.task_dim = 1
        # NOTE(review): ``super(AntEnv, self)`` skips AntEnv.__init__ and
        # calls the next class in the MRO directly; presumably deliberate
        # (e.g. to avoid AntEnv's own setup) — confirm against src.envs.ant.
        super(AntEnv, self).__init__()
        # Replaces whatever observation space the parent configured; assumes
        # _get_obs() yields 27 values — TODO confirm against AntEnv.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(27,))

    def set_task(self, task):
        """Set the goal angle (radians).

        ``task`` may be a scalar or any array-like (list, tuple, ndarray of
        any shape); its first element is used, so ``a``, ``[a]`` and
        ``np.array([a])`` are equivalent.  An empty ``task`` raises
        ``IndexError``.
        """
        # Flatten so lists, tuples, 0-d and nested arrays all reduce to one
        # scalar; previously only 1-d ndarrays were unwrapped.
        self._goal = np.asarray(task).reshape(-1)[0]

    def get_task(self):
        """Return the goal angle as a one-element array."""
        return np.array([self._goal])

    def step(self, action):
        """Advance the simulation by ``self.frame_skip`` frames.

        Returns:
            A 5-tuple ``(obs, reward, terminated, truncated, info)``.
            ``terminated`` and ``truncated`` are always ``False``: no
            termination check is applied here, so episodes only end via an
            external time limit (e.g. ``max_episode_steps`` at registration).
        """
        torso_xyz_before = np.array(self.get_body_com("torso"))
        self.do_simulation(action, self.frame_skip)
        torso_xyz_after = np.array(self.get_body_com("torso"))
        # Displacement over this step (not yet divided by dt); reported
        # as-is under the "torso_velocity" info key.
        torso_velocity = torso_xyz_after - torso_xyz_before

        # Planar velocity projected onto the unit goal-direction vector.
        direction = (np.cos(self._goal), np.sin(self._goal))
        forward_reward = np.dot(torso_velocity[:2] / self.dt, direction)

        ctrl_cost = 0.5 * np.square(action).sum()
        # NOTE(review): relies on a mujoco-py style ``self.sim`` handle
        # provided by the parent class — confirm against AntEnv.
        contact_cost = 0.5 * 1e-3 * np.sum(
            np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
        survive_reward = 1.0
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward

        # The usual "healthy torso height" termination is intentionally
        # disabled for this task.
        terminated = truncated = False
        ob = self._get_obs()
        return ob, reward, terminated, truncated, dict(
            reward_forward=forward_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            reward_survive=survive_reward,
            torso_velocity=torso_velocity,
        )

    def sample_tasks(self, n_tasks):
        """Return ``n_tasks`` goal angles drawn uniformly from [0, 2*pi).

        Uses NumPy's global random state.
        """
        return [np.random.uniform(0., 2.0 * np.pi) for _ in range(n_tasks)]

    def pos_to_state(self, arg):
        """Convert an array-like goal (anything with ``tolist()``) to a tuple."""
        return tuple(arg.tolist())


def train_test_goals_ant(num_test_goals, seed, num_train_goals=1000):
    """Sample reproducible train and test goal angles for AntDir.

    Train goals are drawn after ``set_seed(seed)`` and test goals after
    ``set_seed(seed + 1)``, so the two sets differ while each is
    reproducible from ``seed``.

    Args:
        num_test_goals: number of test goal angles to draw.
        seed: base seed; ``seed + 1`` is used for the test goals.
        num_train_goals: number of train goal angles (default 1000).

    Returns:
        ``(train_goals, test_goals)``: float arrays of shape
        ``(num_train_goals, 1)`` and ``(num_test_goals, 1)`` with angles in
        ``[0, 2*pi)`` — the shape holds even when a count is 0.
    """
    # Assumes set_seed seeds NumPy's global RNG — TODO confirm in
    # src.utils.misc.  With the legacy global RandomState, one sized draw
    # consumes the stream exactly like repeated scalar draws, so seeded
    # values match the previous per-element implementation.
    set_seed(seed)
    train_goals = np.random.uniform(0., 2.0 * np.pi, size=(num_train_goals, 1))
    set_seed(seed + 1)
    test_goals = np.random.uniform(0., 2.0 * np.pi, size=(num_test_goals, 1))
    return train_goals, test_goals


if __name__ == "__main__":
    # Smoke test: register the env with gymnasium, build it with a fixed goal
    # angle, reset once and print the observation shape.
    import gymnasium as gym

    gym.register(
        'AntDir-v0',
        entry_point='src.envs.ant_dir:AntDirEnv',
        max_episode_steps=200,
        kwargs={},
    )

    goal = np.array([1.0])  # goal direction: 1 radian
    env = gym.make("AntDir-v0", goal_pos=goal)
    try:
        obs, _ = env.reset()
        print(obs.shape)
    finally:
        # Release simulator/renderer resources even if reset fails.
        env.close()
