import torch

from extensions.rl_minigrid.minigrid_experiments.base import (
    MiniGridBaseExperimentConfig,
)
from utils.experiment_utils import PipelineStage, LinearDecay


class MiniGridDaggerThenPPOExperimentConfig(MiniGridBaseExperimentConfig):
    """MiniGrid experiment: imitation learning (DAgger) followed by PPO.

    The training budget ``TOTAL_TRAIN_STEPS`` is split into two consecutive
    pipeline stages. The first ``tf_ratio()`` fraction of the steps optimizes
    the imitation loss while the teacher-forcing probability decays linearly
    from 1.0 to 0.0. The remaining steps optimize the PPO loss.
    """

    # NOTE(review): selects CUDA device index 1 whenever CUDA is available;
    # confirm the target machines expose at least two GPUs.
    GPU_ID = 1 if torch.cuda.is_available() else None
    # Presumably makes expert actions available for the imitation loss —
    # TODO confirm against MiniGridBaseExperimentConfig.
    USE_EXPERT = True

    @classmethod
    def extra_tag(cls):
        """Return a run tag encoding the learning rate and teacher-forcing ratio."""
        return f"DaggerThenPPO__lr_{cls.lr()}__tf_{cls.tf_ratio()}"

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build the two-stage (imitation, then PPO) training pipeline.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``cls._training_pipeline`` builds from the two stages.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        # int() truncates, so any rounding remainder ends up in the PPO stage.
        # NOTE(review): tf_ratio() of 0 or 1 gives one stage zero steps;
        # confirm PipelineStage/LinearDecay tolerate steps=0.
        dagger_steps = int(total_steps * cls.tf_ratio())
        ppo_steps = total_steps - dagger_steps

        # Look up per-loss defaults (PPO first, then imitation).
        ppo_defaults = cls.rl_loss_default("ppo", steps=ppo_steps)
        imitation_defaults = cls.rl_loss_default("imitation")

        losses = {
            "imitation_loss": imitation_defaults["loss"],
            "ppo_loss": ppo_defaults["loss"],
        }

        # Stage 1: imitation, with teacher forcing annealed linearly 1.0 -> 0.0
        # over exactly the length of this stage.
        dagger_stage = PipelineStage(
            loss_names=["imitation_loss"],
            max_stage_steps=dagger_steps,
            teacher_forcing=LinearDecay(startp=1.0, endp=0.0, steps=dagger_steps),
        )

        # Stage 2: PPO. The early-stopping criterion comes from task_info()
        # and is None when the task does not define one.
        stopping_rule = cls.task_info().get("early_stopping_criterion")
        ppo_stage = PipelineStage(
            loss_names=["ppo_loss"],
            max_stage_steps=ppo_steps,
            early_stopping_criterion=stopping_rule,
        )

        # A single value is passed for the whole pipeline, so take the
        # smaller of the two losses' defaults.
        mini_batches = min(
            ppo_defaults["num_mini_batch"], imitation_defaults["num_mini_batch"]
        )
        repeats = min(
            ppo_defaults["update_repeats"], imitation_defaults["update_repeats"]
        )

        return cls._training_pipeline(
            named_losses=losses,
            pipeline_stages=[dagger_stage, ppo_stage],
            num_mini_batch=mini_batches,
            update_repeats=repeats,
        )
