import math
import os
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
import argparse

import habitat_sim
import numpy as np

from typing import Dict, List, Tuple

from habitat_utils import save_obs, make_cfg_mp3d, load_poses_from_file


def main() -> None:
    """Render and save observations for one scene along a stored trajectory.

    Command-line arguments:
        --dataset_dir: root directory of the mp3d scene dataset.
        --save_dir: root directory for the generated data; frames go to
            ``<save_dir>/<split>/<scene_dir>``.
        --pose_dir: directory holding one ``<split>.txt`` pose file.
        --split: name of the pose file (without ``.txt``) and of the output
            subfolder.
        --scene_dir: scene name; the mesh is read from
            ``<dataset_dir>/<scene_dir>/<scene_dir>.glb``.

    Each row of the pose file is one pose: the first three values are used
    as the agent position and the remaining values as the agent rotation.
    For every pose the agent is placed there, the sensors are rendered, and
    the observations are handed to ``save_obs``.

    Raises:
        OSError: if the pose file cannot be read (from ``np.loadtxt``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="../scene_datasets/mp3d",
        help="the directory where mp3d dataset is stored. This directory should contain val and train subfolders",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="data/UwV_walks",
        help="the directory where the generated data will be saved. This directory should contain val or train subfolders.",
    )
    parser.add_argument(
        "--pose_dir",
        type=str,
        default="../bellman/bellman-exploration/output/UvW_trajs",
        help="the directory containing the <split>.txt pose file to be used for generating the data",
    )
    parser.add_argument("--split", type=str, required=True)
    parser.add_argument("--scene_dir", type=str, required=True)
    args = parser.parse_args()

    split = args.split
    scene_name = args.scene_dir
    root_dataset_dir = args.dataset_dir

    # NOTE(review): the --dataset_dir help mentions val/train subfolders, but
    # the split is not part of the scene path below -- confirm the layout.
    # The trailing slash is kept because the string is passed on verbatim to
    # make_cfg_mp3d.
    scene_data_dir = f"{root_dataset_dir}/{scene_name}/"
    save_dir = f"{args.save_dir}/{split}/{scene_name}"
    scene_mesh = os.path.join(scene_data_dir, scene_name + ".glb")

    # Load the trajectory before the simulator is created so that a missing
    # or malformed pose file fails fast. ndmin=2 keeps a single-pose file as
    # one row instead of a flat vector that would be iterated value by value.
    pose_file = os.path.join(args.pose_dir, split + ".txt")
    poses = np.loadtxt(pose_file, ndmin=2)
    num_poses = len(poses)

    os.makedirs(save_dir, exist_ok=True)

    sim_settings = {
        "scene": scene_mesh,
        "default_agent": 0,
        "sensor_height": 1.5,
        "color_sensor": True,
        "depth_sensor": True,
        "semantic_sensor": True,
        "lidar_sensor": False,
        "move_forward": 0.2,
        "move_backward": 0.2,
        "turn_left": 5,
        "turn_right": 5,
        "look_up": 5,
        "look_down": 5,
        "look_left": 5,
        "look_right": 5,
        "width": 1080,
        "height": 720,
        "enable_physics": False,
        "seed": 42,
        "lidar_fov": 360,
        "depth_img_for_lidar_n": 20,
        "img_save_dir": save_dir,
    }
    # NOTE(review): habitat_sim is imported at module load, before these
    # variables are set -- confirm the quiet logging actually takes effect.
    os.environ["MAGNUM_LOG"] = "quiet"
    os.environ["HABITAT_SIM_LOG"] = "quiet"

    sim_cfg = make_cfg_mp3d(sim_settings, root_dataset_dir, scene_data_dir, scene_name)
    sim = habitat_sim.Simulator(sim_cfg)
    try:
        agent_id = sim_settings["default_agent"]
        # The agent object does not change between frames, so fetch it once.
        agent = sim.initialize_agent(agent_id)

        pbar = tqdm(poses, total=num_poses, desc="saving frames")
        for frame_idx, pose in enumerate(pbar):
            pbar.set_description(f"saving frame {frame_idx + 1}/{num_poses}")

            # Set the agent body pose directly (approach taken from VLMaps).
            # An earlier variant that overrode the individual sensor states
            # produced a strange camera height.
            agent_state = habitat_sim.AgentState()
            agent_state.position = pose[:3]
            agent_state.rotation = pose[3:]
            agent.set_state(agent_state, reset_sensors=True, infer_sensor_states=False)

            obs = sim.get_sensor_observations(agent_id)
            save_obs(save_dir, sim_settings, obs, pose, frame_idx)
    finally:
        # Release the simulator even when rendering or saving fails.
        sim.close()

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
