"""
Directional Ant environment classes.
These environments reward the ant for moving in specific directions.
"""

import numpy as np
import mujoco as mj
import os
from gymnasium.envs.mujoco.ant_v4 import AntEnv
from gymnasium.envs.mujoco.humanoid_v4 import HumanoidEnv
from gymnasium.envs.mujoco.swimmer_v4 import SwimmerEnv
from gymnasium.envs.mujoco.walker2d_v4 import Walker2dEnv
import gymnasium as gym

from dotenv import load_dotenv

# Let python-dotenv populate the process environment, then look up the
# directory used as the prefix for the MuJoCo XML model path.
# BASE_XML_DIR is None when the variable is not set.
load_dotenv()
BASE_XML_DIR = os.environ.get('BASE_XML_DIR')


class AntDir(AntEnv):
    """
    Base class for directional ant environments.

    Subclasses override :meth:`calculate_reward` to reward motion in a
    particular direction; the base implementation always returns 0.
    Episodes are cut off after ``max_steps`` calls to :meth:`step`.
    """

    def __init__(self, test_mode=False, **kwargs):
        """Create the environment from ``{BASE_XML_DIR}/ant.xml``.

        Args:
            test_mode: Accepted for caller compatibility; not used by this
                class.
            **kwargs: Forwarded unchanged to ``AntEnv.__init__``.
        """
        super().__init__(
            exclude_current_positions_from_observation=True,
            healthy_z_range=(0.2, 1.0),
            xml_file=f"{BASE_XML_DIR}/ant.xml",
            **kwargs,
        )

        # Per-episode step counter and the step budget that triggers
        # truncation in step().
        self.step_count = 0
        self.max_steps = 1000

    def calculate_reward(self, x_velocity, y_velocity):
        """Return the reward for one step. To be implemented by subclasses.

        NOTE: despite the parameter names, ``step`` calls this method
        positionally with the torso xy position *before* and *after* the
        simulation step (not with velocities). The directional subclasses
        in this file name the parameters ``xy_position_before`` and
        ``xy_position_after`` accordingly.

        Returns:
            0 in the base class.
        """
        return 0

    def _site_height(self, site_name):
        """Return the z component of the named MuJoCo site's position."""
        site_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, site_name)
        return self.data.site_xpos[site_id][2]

    def step(self, action):
        """Advance the simulation by one environment step.

        Args:
            action: Control input passed to ``do_simulation``.

        Returns:
            Tuple ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is True once ``step_count`` reaches ``max_steps``;
            in that case ``terminated`` is reported as True as well.
        """
        self.step_count += 1
        xy_position_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        # Finite-difference velocity of the torso in the xy plane.
        xy_velocity = (xy_position_after - xy_position_before) / self.dt
        x_velocity, y_velocity = xy_velocity

        # Subclasses receive the raw positions and derive what they need.
        rewards = self.calculate_reward(xy_position_before, xy_position_after)

        terminated = self.terminated
        observation = self._get_obs()

        position_torso = self.get_body_com("torso").copy()

        front_left_leg_height = self._site_height("front_left_leg_site")
        front_right_leg_height = self._site_height("front_right_leg_site")
        back_left_leg_height = self._site_height("back_left_leg_site")
        back_right_leg_height = self._site_height("back_right_leg_site")

        info = {
            "reward_forward": rewards,
            "x_position": xy_position_after[0],
            "y_position": xy_position_after[1],
            # This key used to appear twice in the literal (xy norm first,
            # full torso norm second). A dict literal keeps only the last
            # value, so the full x/y/z norm is what was actually reported
            # and is the value kept here.
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
            "x_velocity": x_velocity,
            "y_velocity": y_velocity,
            "ant_y_position": position_torso[1],
            # NOTE(review): sensordata[1..4] are presumably the per-leg
            # sensors declared in ant.xml (index 0 is skipped) — confirm
            # against the model file.
            "front_left_leg_height": front_left_leg_height,
            "front_left_leg_site": self.data.sensordata[1],
            "front_right_leg_height": front_right_leg_height,
            "front_right_leg_site": self.data.sensordata[2],
            "back_left_leg_height": back_left_leg_height,
            "back_left_leg_site": self.data.sensordata[3],
            "back_right_leg_height": back_right_leg_height,
            "back_right_leg_site": self.data.sensordata[4],
        }

        if self.render_mode == "human":
            self.render()

        truncated = self.step_count >= self.max_steps
        # NOTE(review): a time-limit truncation is also reported as
        # termination. Gymnasium treats the two flags as distinct; this is
        # kept as-is because existing callers may rely on ``terminated``
        # alone to end an episode.
        terminated = terminated or truncated

        return observation, rewards, terminated, truncated, info

    def get_obs(self):
        """Return the current observation (public wrapper for ``_get_obs``)."""
        return self._get_obs()

    def set_obs(self, obs):
        """Set the simulator state from a flat vector and return the observation.

        The first ``model.nq`` entries of ``obs`` are used as ``qpos`` and the
        remainder as ``qvel``.

        NOTE(review): the environment is built with
        ``exclude_current_positions_from_observation=True``, so a vector
        returned by ``get_obs`` presumably lacks some position entries and
        may not round-trip through this method — confirm with callers.
        """
        qpos = obs[:self.model.nq]
        qvel = obs[self.model.nq:]
        self.set_state(qpos, qvel)
        return self._get_obs()

    def reset_model(self):
        """Reset the step counter and restore the initial qpos/qvel.

        No noise is added to the initial state.

        Returns:
            The observation after the reset.
        """
        self.step_count = 0
        qpos = self.init_qpos
        qvel = self.init_qvel
        self.set_state(qpos, qvel)
        observation = self._get_obs()

        return observation


class AntPosX(AntDir):
    """Directional ant rewarded for moving along +x."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the x velocity implied by the two xy positions.

        The velocity is the position difference divided by ``self.dt``.
        """
        displacement = xy_position_after - xy_position_before
        # Two-way unpacking is kept so a non-2D input still raises.
        vel_x, _vel_y = displacement / self.dt
        return vel_x


class AntNegX(AntDir):
    """Directional ant rewarded for moving along -x."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the negated x velocity implied by the two xy positions.

        The velocity is the position difference divided by ``self.dt``.
        """
        displacement = xy_position_after - xy_position_before
        # Two-way unpacking is kept so a non-2D input still raises.
        vel_x, _vel_y = displacement / self.dt
        return -vel_x


class AntPosY(AntDir):
    """Directional ant rewarded for moving along +y."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the y velocity implied by the two xy positions.

        The velocity is the position difference divided by ``self.dt``.
        """
        displacement = xy_position_after - xy_position_before
        # Two-way unpacking is kept so a non-2D input still raises.
        _vel_x, vel_y = displacement / self.dt
        return vel_y


class AntNegY(AntDir):
    """Directional ant rewarded for moving along -y."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the negated y velocity implied by the two xy positions.

        The velocity is the position difference divided by ``self.dt``.
        """
        displacement = xy_position_after - xy_position_before
        # Two-way unpacking is kept so a non-2D input still raises.
        _vel_x, vel_y = displacement / self.dt
        return -vel_y
