from extensions.rl_lighthouse.lighthouse_experiments.base import (
    BaseLightHouseExperimentConfig,
)
from utils.experiment_utils import PipelineStage, LinearDecay


class LightHouseDaggerExperimentConfig(BaseLightHouseExperimentConfig):
    """Find goal in lighthouse env using imitation learning.

    Training with Dagger: both pipeline stages optimize only the
    ``"imitation"`` loss obtained from ``rl_loss_default``. The first stage
    linearly anneals teacher forcing from 1.0 to 0.0; the second stage sets
    no teacher-forcing schedule and attaches the config's early-stopping
    criterion.
    """

    @classmethod
    def tag(cls):
        """Return the identifier string for this experiment config."""
        return "LightHouseDagger"

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build the two-stage DAgger training pipeline.

        Stage 1 runs for half of ``cls.TOTAL_TRAIN_STEPS`` while a
        ``LinearDecay`` takes the teacher-forcing probability from 1.0 down
        to 0.0 over that same span. Stage 2 runs for another half with the
        criterion returned by ``cls.get_early_stopping_criterion()``.
        Mini-batch count and update repeats come from the imitation loss
        defaults.

        Args:
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``cls._training_pipeline`` builds from these settings.
        """
        # NOTE: each stage gets TOTAL_TRAIN_STEPS // 2, so the stages sum to
        # one step fewer than TOTAL_TRAIN_STEPS when that total is odd.
        half_steps = cls.TOTAL_TRAIN_STEPS // 2
        imitation_defaults = cls.rl_loss_default("imitation")
        loss_key = "imitation_loss"
        losses_by_name = {loss_key: imitation_defaults["loss"]}

        # Start fully teacher-forced and hand control to the learner
        # linearly over the first half of training.
        annealing_stage = PipelineStage(
            loss_names=[loss_key],
            teacher_forcing=LinearDecay(startp=1.0, endp=0.0, steps=half_steps),
            max_stage_steps=half_steps,
        )
        # No teacher_forcing argument here; this stage can also end early.
        finetune_stage = PipelineStage(
            loss_names=[loss_key],
            early_stopping_criterion=cls.get_early_stopping_criterion(),
            max_stage_steps=half_steps,
        )

        return cls._training_pipeline(
            named_losses=losses_by_name,
            pipeline_stages=[annealing_stage, finetune_stage],
            num_mini_batch=imitation_defaults["num_mini_batch"],
            update_repeats=imitation_defaults["update_repeats"],
        )
