from extensions.rl_poisoneddoors.poisoneddoors_experiments.base import (
    PoisonedDoorsBaseExperimentConfig,
)

from extensions.rl_poisoneddoors.poisoneddoors_offpolicy import (
    PoisonedDoorsOffPolicyAdvisorLoss,
)
from onpolicy_sync.losses.advisor import LinearAlphaScheduler
from utils.experiment_utils import PipelineStage, OffPolicyPipelineComponent


class PPOWithOffPolicyAdvisorFixedAlphaDifferentHeadsLevelExperimentConfig(
    PoisonedDoorsBaseExperimentConfig
):
    """Poisoned-doors experiment combining PPO with an off-policy advisor loss.

    The training pipeline has a single stage: ``ppo_loss`` is the stage's own
    loss, and ``offpolicy_advisor_loss`` is attached through the stage's
    off-policy component. The advisor loss receives either a fixed alpha or a
    linear alpha schedule, selected by ``FIXED_ALPHA``.
    """

    # Configuration flags. Only FIXED_ALPHA is read in this class; the other
    # two are presumably consumed by the base config -- TODO confirm.
    SAME_WEIGHTS_FOR_ADVISOR_HEAD = False
    FIXED_ALPHA = True
    INCLUDE_AUXILIARY_HEAD = True

    @classmethod
    def extra_tag(cls):
        """Return the tag string for this experiment, embedding alpha and lr."""
        alpha_value = cls.alpha()
        lr_value = cls.lr()
        return (
            f"PPODiffHeadsFixedAlphaAdvisorOffPolicy__alpha_{alpha_value}"
            f"__lr_{lr_value}"
        )

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build the one-stage PPO + off-policy advisor training pipeline.

        ``kwargs`` is accepted but not used here. The result is whatever
        ``cls._training_pipeline`` returns.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        ppo_defaults = cls.rl_loss_default("ppo", steps=total_steps)
        demo_defaults = cls.offpolicy_demo_defaults(also_using_ppo=True)

        # Exactly one of (fixed alpha, alpha scheduler) is handed to the loss.
        use_fixed_alpha = cls.FIXED_ALPHA
        if use_fixed_alpha:
            fixed_alpha = cls.alpha()
            alpha_scheduler = None
        else:
            fixed_alpha = None
            alpha_scheduler = LinearAlphaScheduler(
                cls.anneal_alpha_start(), cls.anneal_alpha_stop(), total_steps,
            )

        # Sanity check: the chosen mode must match what was actually built
        # (in particular, a fixed-alpha run needs a non-None alpha).
        has_fixed_alpha = fixed_alpha is not None
        has_scheduler = alpha_scheduler is not None
        if use_fixed_alpha:
            assert has_fixed_alpha and not has_scheduler
        else:
            assert has_scheduler and not has_fixed_alpha

        losses = {
            "ppo_loss": ppo_defaults["loss"],
            "offpolicy_advisor_loss": PoisonedDoorsOffPolicyAdvisorLoss(
                fixed_alpha=fixed_alpha, alpha_scheduler=alpha_scheduler,
            ),
        }

        # May be None when the task info carries no such entry.
        stopping_criterion = cls.task_info().get("early_stopping_criterion")

        advisor_component = OffPolicyPipelineComponent(
            data_iterator_builder=demo_defaults["data_iterator_builder"],
            loss_names=["offpolicy_advisor_loss"],
            updates=demo_defaults["offpolicy_updates"],
        )

        single_stage = PipelineStage(
            loss_names=["ppo_loss"],
            max_stage_steps=total_steps,
            early_stopping_criterion=stopping_criterion,
            offpolicy_component=advisor_component,
        )

        return cls._training_pipeline(
            named_losses=losses,
            pipeline_stages=[single_stage],
            num_mini_batch=demo_defaults["ppo_num_mini_batch"],
            update_repeats=demo_defaults["ppo_update_repeats"],
        )
