"""
Base Ant environment classes.
"""

import numpy as np
import mujoco as mj
import os
import gymnasium as gym
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.envs.mujoco.ant_v4 import AntEnv
from gymnasium.spaces import Box
from gymnasium import utils
from warnings import filterwarnings
from utils import *
from dotenv import load_dotenv

# Read variables from a .env file (via python-dotenv) into the process
# environment, then pick up the directory holding the MuJoCo XML models.
# BASE_XML_DIR is None when the variable is not set.
load_dotenv()
BASE_XML_DIR = os.environ.get('BASE_XML_DIR')

# Silence the DeprecationWarning about the `np.bool8` alias.
filterwarnings(
    "ignore",
    message="`np.bool8` is a deprecated alias for `np.bool_`",
    category=DeprecationWarning,
)


class RandomTerranAnt(AntEnv):
    """
    Ant environment with randomized terrain.

    Builds Gymnasium's ``AntEnv`` from ``{BASE_XML_DIR}/ant_random_terran.xml``
    with the torso's current position excluded from observations and a
    healthy-z range of (0.1, 10.0).

    Args:
        test_mode: Flag stored on the instance as ``test_mode`` (and, for
            backward compatibility, as ``test_model``). It is not read
            anywhere in this class.
        **kwargs: Forwarded unchanged to ``AntEnv.__init__``.
    """
    def __init__(self, test_mode=False, **kwargs):
        super(RandomTerranAnt, self).__init__(
            exclude_current_positions_from_observation=True,
            healthy_z_range=(0.1, 10.0),
            xml_file=f"{BASE_XML_DIR}/ant_random_terran.xml",
            **kwargs,
        )
        self.test_mode = test_mode
        # Legacy misspelled attribute, kept so existing readers of
        # ``test_model`` keep working.
        self.test_model = test_mode




class AntRDir(AntEnv):
    """
    Base class for Ant tasks that reward planar motion of the torso.

    Uses ``{BASE_XML_DIR}/ant_random_terran.xml`` with the torso's current
    position excluded from observations and a healthy-z range of
    (0.1, 10.0). Subclasses override :meth:`calculate_reward`; this base
    class always returns 0.

    Args:
        test_mode: Accepted for signature compatibility; not used here.
        max_steps: Number of ``step`` calls after which an episode is
            truncated. Defaults to 1000.
        **kwargs: Forwarded unchanged to ``AntEnv.__init__``.

    Raises:
        ValueError: If one of the leg sites reported by ``step`` is missing
            from the loaded model.
    """

    # MuJoCo sites whose world-frame z coordinate is reported by ``step``.
    _LEG_SITE_NAMES = (
        "front_left_leg_site",
        "front_right_leg_site",
        "back_left_leg_site",
        "back_right_leg_site",
    )

    def __init__(self, test_mode=False, max_steps=1000, **kwargs):
        super(AntRDir, self).__init__(
            exclude_current_positions_from_observation=True,
            healthy_z_range=(0.1, 10.0),
            xml_file=f"{BASE_XML_DIR}/ant_random_terran.xml",
            **kwargs,
        )

        self.step_count = 0
        self.max_steps = max_steps

        # Resolve site ids once rather than on every step. mj_name2id returns
        # -1 for an unknown name, and site_xpos[-1] would silently read the
        # last site in the model, so reject missing sites up front.
        self._leg_site_ids = {}
        for site_name in self._LEG_SITE_NAMES:
            site_id = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_SITE, site_name)
            if site_id < 0:
                raise ValueError(
                    f"site {site_name!r} not found in the loaded MuJoCo model"
                )
            self._leg_site_ids[site_name] = site_id

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return the reward for one step from the torso's planar positions.

        Args:
            xy_position_before: Torso (x, y) before the simulation step.
            xy_position_after: Torso (x, y) after the simulation step.

        Returns:
            0 in this base class; subclasses override it.
        """
        return 0

    def step(self, action):
        """Advance the simulation by one environment step.

        Args:
            action: Control input passed to ``do_simulation``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``reward``
            comes from :meth:`calculate_reward`; ``truncated`` becomes True
            once ``step_count`` reaches ``max_steps``.
        """
        self.step_count += 1
        xy_position_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        xy_position_after = self.get_body_com("torso")[:2].copy()

        xy_velocity = (xy_position_after - xy_position_before) / self.dt
        x_velocity, y_velocity = xy_velocity

        rewards = self.calculate_reward(xy_position_before, xy_position_after)

        terminated = self.terminated
        observation = self._get_obs()

        position_torso = self.get_body_com("torso").copy()
        site_xpos = self.data.site_xpos
        sensordata = self.data.sensordata
        ids = self._leg_site_ids

        info = {
            "reward_forward": rewards,
            "x_position": xy_position_after[0],
            "y_position": xy_position_after[1],
            # Norm of the full 3-D torso position (x, y, z).
            # NOTE(review): includes torso height; confirm an x/y-only
            # distance is not what is wanted under this key.
            "distance_from_origin": np.linalg.norm(position_torso, ord=2),
            "x_velocity": x_velocity,
            "y_velocity": y_velocity,
            "ant_y_position": position_torso[1],
            # NOTE(review): sensordata indices 1-4 are hard-coded; confirm
            # they match the sensor order declared in the XML.
            "front_left_leg_height": site_xpos[ids["front_left_leg_site"]][2],
            "front_left_leg_site": sensordata[1],
            "front_right_leg_height": site_xpos[ids["front_right_leg_site"]][2],
            "front_right_leg_site": sensordata[2],
            "back_left_leg_height": site_xpos[ids["back_left_leg_site"]][2],
            "back_left_leg_site": sensordata[3],
            "back_right_leg_height": site_xpos[ids["back_right_leg_site"]][2],
            "back_right_leg_site": sensordata[4],
        }

        if self.render_mode == "human":
            self.render()
        truncated = self.step_count >= self.max_steps
        # NOTE(review): a time-limit truncation is also reported as
        # termination. Gymnasium's API keeps the two flags separate; check
        # whether callers rely on this before changing it.
        terminated = terminated or truncated
        return observation, rewards, terminated, truncated, info

    def get_obs(self):
        """Return the current observation (public wrapper for ``_get_obs``)."""
        return self._get_obs()

    def set_obs(self, obs):
        """Set the simulator state from a flat ``[qpos, qvel]`` vector.

        The first ``model.nq`` entries become ``qpos`` and the remainder
        ``qvel``. Returns the observation for the new state.

        NOTE(review): this is presumably not the inverse of :meth:`get_obs`.
        With ``exclude_current_positions_from_observation=True`` the
        observation omits the torso's current position, so passing an
        observation here would misalign qpos and qvel. Confirm callers pass
        full state vectors.
        """
        qpos = obs[:self.model.nq]
        qvel = obs[self.model.nq:]
        self.set_state(qpos, qvel)
        return self._get_obs()

    def reset_model(self):
        """Reset the step counter and restore the initial qpos/qvel.

        No noise is added to the initial state. Returns the resulting
        observation.
        """
        self.step_count = 0
        qpos = self.init_qpos
        qvel = self.init_qvel
        self.set_state(qpos, qvel)
        observation = self._get_obs()

        return observation


class AntRPosX(AntRDir):
    """Ant task rewarding motion toward +x: reward is the torso's x velocity."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return ``dx / dt`` for the torso displacement over one step."""
        delta = xy_position_after - xy_position_before
        return delta[0] / self.dt


class AntRNegX(AntRDir):
    """Ant task rewarding motion toward -x: reward is the negated x velocity."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return ``-(dx / dt)`` for the torso displacement over one step."""
        delta = xy_position_after - xy_position_before
        return -(delta[0] / self.dt)


class AntRPosY(AntRDir):
    """Ant task rewarding motion toward +y: reward is the torso's y velocity."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return ``dy / dt`` for the torso displacement over one step."""
        delta = xy_position_after - xy_position_before
        return delta[1] / self.dt


class AntRNegY(AntRDir):
    """Ant task rewarding motion toward -y: reward is the negated y velocity."""

    def calculate_reward(self, xy_position_before, xy_position_after):
        """Return ``-(dy / dt)`` for the torso displacement over one step."""
        delta = xy_position_after - xy_position_before
        return -(delta[1] / self.dt)


import numpy as np
import struct
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def generate_heightfield(width=100, height=100, max_height=0.0,
                         frequency=10.0, extent=20.0):
    """
    Generate a heightfield of dense, low bumps from a product of two sine waves.

    The raw surface is ``sin(frequency * x) * sin(frequency * y)``, sampled on
    a ``height x width`` grid spanning ``[0, extent]`` along both axes. It is
    then rescaled linearly so that its minimum is 0 and its maximum is
    ``max_height``.

    Args:
        width: Number of grid points along x (columns of the result).
        height: Number of grid points along y (rows of the result).
        max_height: Height of the tallest point after rescaling.
        frequency: Angular frequency of both sine waves; larger values give
            denser bumps. Defaults to 10.0.
        extent: Upper end of the sampled coordinate range on each axis.
            Defaults to 20.0.

    Returns:
        2D numpy array of shape ``(height, width)``. If the raw surface is
        constant (e.g. a 1x1 grid), an all-zero array is returned instead of
        the NaNs min-max scaling would produce. An empty grid is returned
        unchanged.
    """
    x = np.linspace(0, extent, width)
    y = np.linspace(0, extent, height)
    xx, yy = np.meshgrid(x, y)  # both shaped (height, width)

    terrain = np.sin(xx * frequency) * np.sin(yy * frequency)

    if terrain.size == 0:
        # min()/max() raise on empty arrays; nothing to rescale.
        return terrain

    lo = terrain.min()
    span = terrain.max() - lo
    if span == 0:
        # Constant surface: min-max scaling would divide by zero.
        return np.zeros_like(terrain)

    return (terrain - lo) / span * max_height

def save_heightfield(heightfield, filename):
    """
    Save a heightfield as a binary file for MuJoCo.

    File layout: two int32 values, the row count followed by the column
    count, then ``rows * cols`` float32 heights in row-major (C) order. Both
    header and data use native byte order. This is intended to match
    MuJoCo's custom binary hfield format.

    Args:
        heightfield: 2D array-like of height values, shaped ``(rows, cols)``.
        filename: Output binary file name.

    Raises:
        ValueError: If ``heightfield`` is not two-dimensional. The check runs
            before the file is opened, so no partial file is left behind.
    """
    heightfield = np.asarray(heightfield)
    if heightfield.ndim != 2:
        raise ValueError(
            f"heightfield must be 2-D, got an array with shape {heightfield.shape}"
        )
    n_rows, n_cols = heightfield.shape

    # Build the whole payload before opening the file so that a failure while
    # preparing it cannot leave a truncated file on disk.
    header = struct.pack('ii', n_rows, n_cols)  # nrow, then ncol
    payload = heightfield.astype(np.float32).tobytes(order='C')

    with open(filename, 'wb') as f:
        f.write(header + payload)

def visualize_heightfield(heightfield):
    """
    Display ``heightfield`` as a 3D surface plot.

    Column indices are plotted along X, row indices along Y and the array
    values as Z, coloured with matplotlib's ``terrain`` colormap and shown
    with a colorbar. Calls ``plt.show()`` before returning.

    Args:
        heightfield: 2D numpy array of height values.
    """
    figure = plt.figure(figsize=(10, 8))
    axes = figure.add_subplot(111, projection='3d')

    n_rows = heightfield.shape[0]
    n_cols = heightfield.shape[1]
    grid_x, grid_y = np.meshgrid(np.arange(0, n_cols), np.arange(0, n_rows))

    surface = axes.plot_surface(
        grid_x, grid_y, heightfield,
        cmap='terrain', linewidth=0, antialiased=True,
    )
    figure.colorbar(surface, ax=axes)

    axes.set_title('Generated Heightfield')
    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    axes.set_zlabel('Height')

    plt.show()

def main():
    """
    Generate the terrain heightfield and write it to ``heightfield.bin``.

    Grid size (200x200) and peak height (1e-5) are meant to match the XML
    model. The interactive 3D preview is left disabled.
    """
    grid_width, grid_height = 200, 200
    peak_height = 0.00001  # kept very low for gentle hills (was 0.5)
    out_file = "heightfield.bin"

    print(f"Generating {grid_width}x{grid_height} heightfield...")
    terrain = generate_heightfield(grid_width, grid_height, peak_height)

    save_heightfield(terrain, out_file)
    print(f"Saved heightfield to {out_file}")

    # To inspect the result interactively:
    # print("Visualizing heightfield...")
    # visualize_heightfield(terrain)


if __name__ == "__main__":
    main()