import os
import torch
from pathlib import Path
from torch import nn
from benchmarl.algorithms import MappoConfig
from benchmarl.experiment import Experiment, ExperimentConfig
from benchmarl.models import MlpConfig, GnnConfig, SequenceModelConfig
from experimental_scenarios.narrowcorridor_task import VmasNarrowCorridorTask
from experimental_scenarios.connectivity_task import VmasConnectivityTask  
from experimental_scenarios.waypoint_task import VmasWaypointTask
from experimental_scenarios.sensorcoverage_task import VmasSensorCoverageTask

# Scenario identifiers (1..5). Assign one of them to ``mode`` below to pick
# which environment this script trains on.
(
    USE_CONNECTIVITY_ENV,
    USE_NARROWCORRIDOR_ENV,
    USE_NARROWCORRIDOR_ROBOMASTER_ENV,
    USE_WAYPOINT_ENV,
    USE_SENSORCOVERAGE_ENV,
) = range(1, 6)

# Scenario to train. Edit this assignment to run a different environment,
# e.g. ``mode = USE_CONNECTIVITY_ENV`` for the connectivity task.
mode = USE_WAYPOINT_ENV

# Start from the shared experiment settings in YAML, then override a few fields.
experiment_config = ExperimentConfig.get_from_yaml("./configs/experiment_config.yaml")
# Checkpointing interval; a non-zero value enables saving model checkpoints.
# NOTE(review): unit is presumably collected frames in BenchMARL - confirm.
experiment_config.checkpoint_interval = 120000
# Put experiment outputs in the directory containing this script (symlinks
# resolved), independent of the current working directory.
experiment_config.save_folder = Path(os.path.realpath(__file__)).parent
# Do not restore from any previous checkpoint.
experiment_config.restore_file = None

# Build the task for the scenario selected by ``mode``; each branch loads that
# scenario's task settings from its own YAML file.
if mode == USE_CONNECTIVITY_ENV:
    task = VmasConnectivityTask.CONNECTIVITY.get_from_yaml("./configs/connectivity_task_config.yaml")
elif mode == USE_NARROWCORRIDOR_ENV:
    task = VmasNarrowCorridorTask.NARROWCORRIDOR.get_from_yaml("./configs/narrowcorridor_task_config.yaml")
elif mode == USE_NARROWCORRIDOR_ROBOMASTER_ENV:
    # Same narrow-corridor task, loaded with the Robomaster robot settings file.
    task = VmasNarrowCorridorTask.NARROWCORRIDOR.get_from_yaml("./configs/robot_experiment_settings/robomaster_task_config.yaml")
elif mode == USE_WAYPOINT_ENV:
    # Using NavigationEnv (waypoint task)
    task = VmasWaypointTask.WAYPOINT.get_from_yaml("./configs/waypoint_task_config.yaml")
elif mode == USE_SENSORCOVERAGE_ENV:
    # Using SensorCoverageEnv
    task = VmasSensorCoverageTask.SENSORCOVERAGE.get_from_yaml("./configs/sensorcoverage_task_config.yaml")
else:
    # Fail fast: without this branch ``task`` would never be bound and the
    # script would crash later with an unhelpful NameError at Experiment(...).
    raise ValueError(
        f"Unknown scenario mode {mode!r}; set mode to one of the USE_*_ENV constants."
    )

# Load the MAPPO algorithm hyperparameters from YAML.
algorithm_config = MappoConfig.get_from_yaml("./configs/mappo_conf.yaml")

def _gnn_then_mlp(gnn_yaml_path, mlp_activation):
    """Return a sequence model config: a YAML-defined GNN followed by an MLP.

    The GNN's output (size 128, per ``intermediate_sizes``) is fed into an MLP
    with one hidden layer of 128 ``nn.Linear`` cells using ``mlp_activation``.
    """
    gnn_part = GnnConfig.get_from_yaml(gnn_yaml_path)
    mlp_part = MlpConfig(num_cells=[128], activation_class=mlp_activation, layer_class=nn.Linear)
    return SequenceModelConfig(model_configs=[gnn_part, mlp_part], intermediate_sizes=[128])


# Actor (policy) model: GNN followed by a ReLU MLP.
model_config = _gnn_then_mlp("./configs/actor_gnn.yaml", nn.ReLU)

# Critic model: GNN followed by a Tanh MLP.
critic_model_config = _gnn_then_mlp("./configs/critic_gnn.yaml", nn.Tanh)

# Assemble the experiment from the settings, task, algorithm and actor/critic
# model configs prepared above, with a fixed seed of 102.
experiment = Experiment(
    config=experiment_config,
    task=task,
    algorithm_config=algorithm_config,
    model_config=model_config,
    critic_model_config=critic_model_config,
    seed=102,
)

# Launch training.
experiment.run()