import torch

from extensions.rl_poisoneddoors.poisoneddoors_experiments.base import (
    PoisonedDoorsBaseExperimentConfig,
)
from utils.experiment_utils import PipelineStage


class PoisonedDoorsPPOExperimentConfig(PoisonedDoorsBaseExperimentConfig):
    """Experiment configuration for the PoisonedDoors environment using PPO.

    The training pipeline consists of a single stage that optimizes only the
    PPO loss for ``TOTAL_TRAIN_STEPS`` steps.
    """

    # Use CUDA device index 4 when CUDA is available, otherwise no GPU.
    # NOTE(review): the index is hard-coded; confirm that the target machine
    # actually exposes a device with this index.
    GPU_ID = None if not torch.cuda.is_available() else 4
    # NOTE(review): presumably read by the base config to toggle expert
    # supervision -- confirm against PoisonedDoorsBaseExperimentConfig.
    USE_EXPERT = False

    @classmethod
    def extra_tag(cls):
        """Return the experiment tag: ``PPO__lr_`` followed by ``cls.lr()``."""
        return f"PPO__lr_{cls.lr()}"

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build the single-stage PPO training pipeline.

        The PPO loss and its optimization settings come from
        ``cls.rl_loss_default("ppo", ...)``; everything is then handed to
        ``cls._training_pipeline``.

        Args:
            **kwargs: Forwarded unchanged to ``cls._training_pipeline``.

        Returns:
            Whatever ``cls._training_pipeline`` returns.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        ppo_defaults = cls.rl_loss_default("ppo", steps=total_steps)

        # Pull out the pieces in the same order the pipeline call consumes
        # them: losses, mini-batch count, update repeats, then the stage.
        losses_by_name = {"ppo_loss": ppo_defaults["loss"]}
        mini_batch_count = ppo_defaults["num_mini_batch"]
        repeats_per_update = ppo_defaults["update_repeats"]

        # The early-stopping criterion is optional: ``.get`` yields ``None``
        # when the task info does not define one.
        stop_criterion = cls.task_info().get("early_stopping_criterion")
        ppo_stage = PipelineStage(
            loss_names=["ppo_loss"],
            max_stage_steps=total_steps,
            early_stopping_criterion=stop_criterion,
        )

        return cls._training_pipeline(
            named_losses=losses_by_name,
            num_mini_batch=mini_batch_count,
            update_repeats=repeats_per_update,
            pipeline_stages=[ppo_stage],
            **kwargs,
        )
