import os
from pathlib import Path

import hydra
from hydra.utils import instantiate
import numpy as np
import matplotlib.pyplot as plt
from navsim.agents.abstract_agent import AbstractAgent
from navsim.common.dataloader import SceneLoader
from navsim.common.dataclasses import SceneFilter, SensorConfig
from navsim.visualization.plots import plot_bev_frame
from navsim.visualization.plots import plot_bev_with_agent, plot_camera_with_agent
from navsim.agents.constant_velocity_agent import ConstantVelocityAgent
from navsim.agents.gaussianfusion import transfuser_agent, transfuser_config
from navsim.visualization.plots import plot_cameras_frame


# Dataset split to visualize; one of "mini", "test", "trainval".
SPLIT = "test"
# Name of the scene-filter config (file stem inside the config directory below).
FILTER = "all_scenes"

# Build the SceneFilter from its Hydra config. Hydra resolves `config_path`
# relative to the file that calls `initialize`, so this script must stay at
# its current location relative to the `navsim` package directory.
hydra.initialize(
    config_path="../navsim/planning/script/config/common/train_test_split/scene_filter"
)
cfg = hydra.compose(config_name=FILTER)
scene_filter: SceneFilter = instantiate(cfg)

# Fail early with a readable message instead of the opaque TypeError that
# `Path(None)` raises when the environment variable is missing.
_openscene_data_root_env = os.getenv("OPENSCENE_DATA_ROOT")
if _openscene_data_root_env is None:
    raise RuntimeError(
        "Environment variable OPENSCENE_DATA_ROOT is not set; "
        "it must point to the OpenScene data root directory."
    )
openscene_data_root = Path(_openscene_data_root_env)

# Logs and sensor blobs live in per-split sub-directories of the data root.
scene_loader = SceneLoader(
    openscene_data_root / f"navsim_logs/{SPLIT}",
    openscene_data_root / f"sensor_blobs/{SPLIT}",
    scene_filter,
    sensor_config=SensorConfig.build_all_sensors(),
)


# Checkpoint whose predictions are visualized. This is a hard-coded absolute
# path; adjust it to wherever the checkpoint lives on your machine.
CHECKPOINT_PATH = (
    "/home/yaya/source/navsim/ckpts/ablation_study/gf_no_implicit_raw_head.ckpt"
)

# Build the agent from a default-constructed config.
agent_cfg = transfuser_config.TransfuserConfig()
agent = transfuser_agent.TransfuserAgent(
    agent_cfg,
    1e-4,  # NOTE(review): presumably the learning rate (unused for plotting) -- confirm
    CHECKPOINT_PATH,
)

# Prepare the agent for inference. The call order is kept as in the original
# script. NOTE(review): presumably these load the checkpoint, switch to eval
# mode, and move the model to the GPU -- confirm against TransfuserAgent.
agent.initialize()
agent.eval()
agent.model_to_cuda()

# Walk through every scene the loader exposes and show two figures per scene,
# one from `plot_bev_with_agent` and one from `plot_camera_with_agent`.
# Both helpers return a (figure, axes) pair; only the figure is needed here.
#
# The original script also loaded one random scene before this loop, but the
# loop immediately overwrote both `token` and `scene`, so that load was pure
# wasted work and has been dropped.
for token in scene_loader.tokens:
    scene = scene_loader.get_scene_from_token(token)

    fig, _ = plot_bev_with_agent(scene, agent)
    plt.show()
    # Release the figure explicitly: when `plt.show()` does not block (e.g. a
    # non-interactive backend) figures would otherwise pile up across scenes.
    plt.close(fig)

    fig, _ = plot_camera_with_agent(scene, agent)
    plt.show()
    plt.close(fig)

# NOTE: to plot the cameras of a single frame instead, call
# `plot_cameras_frame(scene, frame_idx)` with
# `frame_idx = scene.scene_metadata.num_history_frames - 1` (the current frame).
