from typing import cast

import gym
import torch.nn as nn

from extensions.rl_babyai.babyai_offpolicy import BabyAIOffPolicyAdvisorLoss
from extensions.rl_lighthouse.lighthouse_models import LinearAdvisorActorCritic
from extensions.rl_minigrid.minigrid_experiments.base import (
    MiniGridBaseExperimentConfig,
)
from extensions.rl_minigrid.minigrid_models import MiniGridSimpleConvRNN
from extensions.rl_minigrid.minigrid_sensors import EgocentricMiniGridSensor
from onpolicy_sync.losses.advisor import LinearAlphaScheduler
from rl_base.sensor import SensorSuite
from utils.experiment_utils import PipelineStage, OffPolicyPipelineComponent, Builder


class PPOWithOffPolicyAdvisorFixedAlphaDifferentHeadsLevelExperimentConfig(
    MiniGridBaseExperimentConfig
):
    """MiniGrid experiment pairing a PPO loss with an off-policy advisor loss.

    The pipeline built here has a single stage that optimizes ``ppo_loss``
    while an attached off-policy component optimizes
    ``offpolicy_advisor_loss``. With the flags below, the advisor loss is
    given a fixed alpha (no scheduler) and the model head is built with
    ``ensure_same_weights=False``.
    """

    # Forwarded to ``LinearAdvisorActorCritic`` as ``ensure_same_weights``.
    SAME_WEIGHTS_FOR_ADVISOR_HEAD = False
    # Truthy: pass ``cls.alpha()`` straight to the advisor loss.
    # Falsy: build a ``LinearAlphaScheduler`` instead.
    FIXED_ALPHA = True

    @classmethod
    def extra_tag(cls):
        """Return the experiment tag, embedding ``cls.alpha()`` and ``cls.lr()``."""
        tag_template = "AdvisorDiffHeadsFixedAlphaOffPolicy__alpha_{}__lr_{}"
        return tag_template.format(cls.alpha(), cls.lr())

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Assemble the single-stage PPO + off-policy advisor pipeline.

        ``kwargs`` is accepted but not used here. The return value is whatever
        ``cls._training_pipeline`` produces.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        ppo_defaults = cls.rl_loss_default("ppo", steps=total_steps)
        demo_defaults = cls.offpolicy_demo_defaults(also_using_ppo=True)

        # Exactly one of ``fixed_alpha`` / ``alpha_scheduler`` may be set; the
        # asserts double-check that (e.g. ``cls.alpha()`` must not be None).
        if cls.FIXED_ALPHA:
            fixed_alpha = cls.alpha()
            alpha_scheduler = None
            assert fixed_alpha is not None and alpha_scheduler is None
        else:
            fixed_alpha = None
            alpha_scheduler = LinearAlphaScheduler(
                cls.anneal_alpha_start(), cls.anneal_alpha_stop(), total_steps,
            )
            assert fixed_alpha is None and alpha_scheduler is not None

        named_losses = {
            "ppo_loss": ppo_defaults["loss"],
            "offpolicy_advisor_loss": BabyAIOffPolicyAdvisorLoss(
                fixed_alpha=fixed_alpha, alpha_scheduler=alpha_scheduler,
            ),
        }

        # ``.get`` yields None when the task info has no such entry.
        early_stopping = cls.task_info().get("early_stopping_criterion")

        offpolicy_component = OffPolicyPipelineComponent(
            data_iterator_builder=demo_defaults["data_iterator_builder"],
            loss_names=["offpolicy_advisor_loss"],
            updates=demo_defaults["offpolicy_updates"],
        )
        stage = PipelineStage(
            loss_names=["ppo_loss"],
            max_stage_steps=total_steps,
            early_stopping_criterion=early_stopping,
            offpolicy_component=offpolicy_component,
        )

        return cls._training_pipeline(
            named_losses=named_losses,
            pipeline_stages=[stage],
            num_mini_batch=demo_defaults["ppo_num_mini_batch"],
            update_repeats=demo_defaults["ppo_update_repeats"],
        )

    @classmethod
    def create_model(cls, **kwargs) -> nn.Module:
        """Build the ``MiniGridSimpleConvRNN`` used by this experiment.

        ``kwargs`` is accepted but not used here. The first sensor returned by
        ``cls.get_sensors()`` is treated as an ``EgocentricMiniGridSensor``
        (via ``cast``, so this is not checked at runtime).
        """
        sensors = cls.get_sensors()

        # One discrete action per name reported by the task class.
        action_names = cls.task_info()["task_class"].class_action_names()
        action_space = gym.spaces.Discrete(len(action_names))

        grid_sensor = cast(EgocentricMiniGridSensor, sensors[0])

        # Deferred construction of the head with this config's weight-sharing flag.
        head_builder = Builder(  # type: ignore
            LinearAdvisorActorCritic,
            kwargs={"ensure_same_weights": cls.SAME_WEIGHTS_FOR_ADVISOR_HEAD},
        )

        return MiniGridSimpleConvRNN(
            action_space=action_space,
            num_objects=grid_sensor.num_objects,
            num_colors=grid_sensor.num_colors,
            num_states=grid_sensor.num_states,
            observation_space=SensorSuite(sensors).observation_spaces,
            hidden_size=128,
            rnn_type=cls.RNN_TYPE,
            head_type=head_builder,
        )
