import torch

from extensions.rl_poisoneddoors.poisoneddoors_experiments.advisor_fixed_alpha_different_heads import (
    PoisonedDoorsAdvisorFixedAlphaDiffHeadsExperimentConfig,
)
from onpolicy_sync.losses.advisor import (
    AdvisorImitationStage,
    AdvisorWeightedStage,
)
from utils.experiment_utils import Builder, PipelineStage, LinearDecay


class PoisonedDoorsDaggerThenAdvisorFixedAlphaDiffHeadsExperimentConfig(
    PoisonedDoorsAdvisorFixedAlphaDiffHeadsExperimentConfig
):
    """Imitation warm-up with teacher forcing, then fixed-alpha advisor loss.

    The training pipeline built by this config has two consecutive stages:

    1. ``AdvisorImitationStage`` with the teacher-forcing schedule pinned at
       1.0, run for ``tf_ratio()`` of ``TOTAL_TRAIN_STEPS``.
    2. ``AdvisorWeightedStage`` wrapping the default PPO loss with
       ``fixed_alpha=alpha()``, run for the remaining steps.

    ``alpha()``, ``lr()``, ``tf_ratio()``, ``task_info()``,
    ``rl_loss_default()``, ``_training_pipeline()`` and ``TOTAL_TRAIN_STEPS``
    are not defined in this class; they are expected to come from the parent
    config.
    """

    # Prefer the second CUDA device (index 1) as before, but fall back to
    # device 0 when only a single device is visible so that single-GPU hosts
    # are not handed an invalid device index. ``None`` when CUDA is
    # unavailable.
    GPU_ID = (
        (1 if torch.cuda.device_count() > 1 else 0)
        if torch.cuda.is_available()
        else None
    )
    # The two flags below are not read inside this class; presumably they are
    # consumed by the parent config / model construction -- TODO confirm.
    USE_EXPERT = True
    SAME_WEIGHTS_FOR_ADVISOR_HEAD = False

    @classmethod
    def extra_tag(cls):
        """Return the tag string identifying this experiment variant.

        The tag embeds the current values of ``cls.alpha()``, ``cls.lr()``
        and ``cls.tf_ratio()``.
        """
        tag_template = (
            "BCTeacherForcingThenAdvisorDiffHeadsFixedAlpha"
            "__alpha_{}__lr_{}__tf_{}"
        )
        return tag_template.format(cls.alpha(), cls.lr(), cls.tf_ratio())

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build the two-stage (imitation warm-up, then advisor) pipeline.

        ``**kwargs`` is accepted for interface compatibility and is not used
        here.

        Returns:
            Whatever ``cls._training_pipeline`` returns for the configured
            losses and stages.
        """
        training_steps = cls.TOTAL_TRAIN_STEPS

        # Split the step budget: the first ``tf_ratio()`` fraction (truncated
        # to an int) goes to the warm-up stage, the remainder to the weighted
        # advisor stage, so the two always sum to ``TOTAL_TRAIN_STEPS``.
        steps_advisor_warmup_stage = int(training_steps * cls.tf_ratio())
        steps_advisor_weighted_stage = training_steps - steps_advisor_warmup_stage

        # Default PPO settings; only the "loss", "num_mini_batch" and
        # "update_repeats" entries are used below.
        ppo_info = cls.rl_loss_default("ppo", steps=steps_advisor_weighted_stage)
        fixed_alpha = cls.alpha()

        named_losses = {
            "advisor_imitation_warmup": Builder(AdvisorImitationStage,),
            "advisor_loss": Builder(
                AdvisorWeightedStage,
                kwargs={"rl_loss": ppo_info["loss"], "fixed_alpha": fixed_alpha,},
            ),
        }

        warmup_stage = PipelineStage(
            loss_names=["advisor_imitation_warmup"],
            max_stage_steps=steps_advisor_warmup_stage,
            # Start and end probabilities are both 1.0, so the schedule is
            # intended to stay at full teacher forcing for the whole stage.
            teacher_forcing=LinearDecay(
                startp=1.0, endp=1.0, steps=steps_advisor_warmup_stage,
            ),
        )
        advisor_stage = PipelineStage(
            loss_names=["advisor_loss"],
            max_stage_steps=steps_advisor_weighted_stage,
            # ``.get`` yields None when the task info defines no criterion.
            early_stopping_criterion=cls.task_info().get(
                "early_stopping_criterion"
            ),
        )

        return cls._training_pipeline(
            named_losses=named_losses,
            pipeline_stages=[warmup_stage, advisor_stage],
            num_mini_batch=ppo_info["num_mini_batch"],
            update_repeats=ppo_info["update_repeats"],
        )
