import logging
from typing import Optional, List, Tuple

from gym_minigrid.minigrid import MiniGridEnv

from extensions.rl_babyai.babyai_offpolicy import BabyAIOffPolicyExpertCELoss
from extensions.rl_minigrid.minigrid_experiments.base import (
    MiniGridBaseExperimentConfig,
)
from utils.experiment_utils import PipelineStage, OffPolicyPipelineComponent

# Module-level logger registered under the fixed name "embodiedrl" (not
# ``__name__``). It is not referenced by the class defined in this module.
LOGGER = logging.getLogger("embodiedrl")


class PPOBabyAIBossLevelExperimentConfig(MiniGridBaseExperimentConfig):
    """Experiment config driven solely by the off-policy expert CE loss.

    Despite the ``PPO`` prefix in the class name, the pipeline built here
    registers only ``offpolicy_expert_ce_loss``, gives its single stage an
    empty on-policy loss list, requests the off-policy demo defaults with
    ``also_using_ppo=False``, and passes ``num_mini_batch=0`` and
    ``update_repeats=0`` to ``_training_pipeline``. The experiment tag
    correspondingly begins with ``PureOffPolicyBC``.
    """

    # Defaults to ``None`` and is not read by any method of this class.
    # NOTE(review): presumably a demonstration cache populated and consumed
    # elsewhere -- confirm against the base class / data iterator code.
    DATASET: Optional[List[Tuple[str, bytes, List[int], MiniGridEnv.Actions]]] = None

    @classmethod
    def extra_tag(cls):
        """Return the experiment tag, which embeds the value of ``cls.lr()``."""
        learning_rate = cls.lr()
        return "PureOffPolicyBC__lr_{}".format(learning_rate)

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build a one-stage pipeline whose only loss is the off-policy one.

        ``**kwargs`` is accepted but not used by this implementation.

        Returns:
            Whatever ``cls._training_pipeline`` returns for the assembled
            losses and stage.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        demo_defaults = cls.offpolicy_demo_defaults(also_using_ppo=False)

        losses = {"offpolicy_expert_ce_loss": BabyAIOffPolicyExpertCELoss()}

        # ``.get`` yields ``None`` when the task info carries no criterion.
        stopping_criterion = cls.task_info().get("early_stopping_criterion")

        offpolicy_component = OffPolicyPipelineComponent(
            data_iterator_builder=demo_defaults["data_iterator_builder"],
            loss_names=["offpolicy_expert_ce_loss"],
            updates=demo_defaults["offpolicy_updates"],
        )

        # The stage lists no on-policy losses; the loss is attached through
        # the off-policy component instead.
        bc_stage = PipelineStage(
            loss_names=[],
            max_stage_steps=total_steps,
            early_stopping_criterion=stopping_criterion,
            offpolicy_component=offpolicy_component,
        )

        return cls._training_pipeline(
            named_losses=losses,
            pipeline_stages=[bc_stage],
            num_mini_batch=0,
            update_repeats=0,
        )
