from extensions.rl_poisoneddoors.poisoneddoors_experiments.base import (
    PoisonedDoorsBaseExperimentConfig,
)
from onpolicy_sync.losses.advisor import AdvisorWeightedStage, LinearAlphaScheduler
from utils.experiment_utils import Builder, PipelineStage


class PoisonedDoorsAdvisorFixedAlphaDiffHeadsExperimentConfig(
    PoisonedDoorsBaseExperimentConfig
):
    """PoisonedDoors experiment trained with one advisor-weighted loss.

    The training pipeline consists of a single stage that optimizes
    ``"advisor_loss"``: an ``AdvisorWeightedStage`` wrapping the default
    PPO loss supplied by the base config, with its alpha driven by a
    ``LinearAlphaScheduler``.

    NOTE(review): the class name and ``extra_tag`` refer to a *fixed*
    alpha (``cls.alpha()``), while ``training_pipeline`` builds the
    scheduler from ``cls.anneal_alpha_start()`` and
    ``cls.anneal_alpha_stop()``. Whether these agree depends on the base
    config, which is not visible here -- confirm this is intended.
    """

    # Configuration flags; presumably consumed by the base experiment
    # config / model construction (their use is not visible in this file).
    USE_EXPERT = True
    SAME_WEIGHTS_FOR_ADVISOR_HEAD = False
    INCLUDE_AUXILIARY_HEAD = True

    @classmethod
    def extra_tag(cls):
        """Return a tag string embedding ``cls.alpha()`` and ``cls.lr()``."""
        tag_template = "AdvisorDiffHeadsFixedAlpha__alpha_{}__lr_{}"
        return tag_template.format(cls.alpha(), cls.lr())

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build the single-stage training pipeline for this experiment.

        ``kwargs`` is accepted but not used by this implementation.

        Returns whatever ``cls._training_pipeline`` returns for one stage
        running ``"advisor_loss"`` for ``cls.TOTAL_TRAIN_STEPS`` steps.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        ppo_defaults = cls.rl_loss_default("ppo", steps=total_steps)

        # Advisor loss: the default PPO loss combined with an alpha
        # schedule spanning the whole training run.
        rl_loss = ppo_defaults["loss"]
        alpha_schedule = LinearAlphaScheduler(
            cls.anneal_alpha_start(), cls.anneal_alpha_stop(), total_steps,
        )
        advisor_loss_builder = Builder(
            AdvisorWeightedStage,
            kwargs={"rl_loss": rl_loss, "alpha_scheduler": alpha_schedule},
        )

        # The only stage; it may end early if the task info provides an
        # "early_stopping_criterion" entry (``None`` otherwise).
        stopping_criterion = cls.task_info().get("early_stopping_criterion")
        only_stage = PipelineStage(
            loss_names=["advisor_loss"],
            max_stage_steps=total_steps,
            early_stopping_criterion=stopping_criterion,
        )

        return cls._training_pipeline(
            named_losses={"advisor_loss": advisor_loss_builder},
            pipeline_stages=[only_stage],
            num_mini_batch=ppo_defaults["num_mini_batch"],
            update_repeats=ppo_defaults["update_repeats"],
        )
