from extensions.rl_lighthouse.lighthouse_experiments.advisor_ppo import (
    LightHouseAdvisorPPOExperimentConfig,
)
from onpolicy_sync.losses.advisor import AdvisorWeightedStage
from utils.experiment_utils import Builder, PipelineStage


class LightHouseAdvisorA2CExperimentConfig(LightHouseAdvisorPPOExperimentConfig):
    """A2C and imitation with adaptive reweighting (ADVISOR) on LightHouse.

    Inherits everything from the PPO ADVISOR configuration except the training
    pipeline. The pipeline here wraps the default A2C loss, instead of PPO,
    inside ``AdvisorWeightedStage``.
    """

    @classmethod
    def tag(cls):
        """Return the identifier string for this experiment configuration."""
        experiment_tag = "LightHouseAdvisorA2C"
        return experiment_tag

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build a one-stage pipeline that trains with the A2C-based ADVISOR loss.

        The stage runs for ``cls.TOTAL_TRAIN_STEPS`` steps and may stop early
        according to ``cls.get_early_stopping_criterion()``. The mini-batch
        count and update repeats come from the A2C defaults returned by
        ``cls.rl_loss_default``. Extra keyword arguments are accepted but
        not used.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        a2c_defaults = cls.rl_loss_default("a2c", steps=total_steps)

        # Fixed (non-learned) alpha handed to AdvisorWeightedStage as
        # ``fixed_alpha``; its exact role in the reweighting is defined by
        # that loss class -- NOTE(review): confirm semantics there.
        advisor_builder = Builder(
            AdvisorWeightedStage,
            kwargs=dict(rl_loss=a2c_defaults["loss"], fixed_alpha=20),
        )

        # A single stage that optimizes only the ADVISOR loss for the
        # entire training budget.
        sole_stage = PipelineStage(
            loss_names=["advisor_loss"],
            early_stopping_criterion=cls.get_early_stopping_criterion(),
            max_stage_steps=total_steps,
        )

        return cls._training_pipeline(
            named_losses=dict(advisor_loss=advisor_builder),
            pipeline_stages=[sole_stage],
            num_mini_batch=a2c_defaults["num_mini_batch"],
            update_repeats=a2c_defaults["update_repeats"],
        )
