import gym
from torch import nn

from extensions.rl_lighthouse.lighthouse_experiments.base import (
    BaseLightHouseExperimentConfig,
)
from extensions.rl_lighthouse.lighthouse_models import LinearAdvisorActorCritic
from models.basic_models import RNNActorCritic
from onpolicy_sync.losses.advisor import AdvisorWeightedStage
from rl_base.sensor import SensorSuite
from utils.experiment_utils import Builder, PipelineStage


class LightHouseAdvisorPPOExperimentConfig(BaseLightHouseExperimentConfig):
    """Lighthouse experiment that trains with an ADVISOR-weighted PPO loss.

    A single pipeline stage optimizes ``AdvisorWeightedStage``. It wraps the
    default PPO loss for this experiment family and uses a fixed alpha.
    """

    @classmethod
    def tag(cls):
        """Return the identifier string used for this experiment."""
        experiment_tag = "LightHouseAdvisorPPO"
        return experiment_tag

    @classmethod
    def training_pipeline(cls, **kwargs):
        """Build a one-stage training pipeline around the ADVISOR loss.

        The stage runs for ``cls.TOTAL_TRAIN_STEPS`` steps. Mini-batch count
        and update repeats are taken from the PPO defaults. Extra ``kwargs``
        are accepted for interface compatibility but are not used here.
        """
        total_steps = cls.TOTAL_TRAIN_STEPS
        ppo_defaults = cls.rl_loss_default("ppo", steps=total_steps)

        # Passed straight through to AdvisorWeightedStage as ``fixed_alpha``.
        advisor_alpha = 20
        advisor_builder = Builder(
            AdvisorWeightedStage,
            kwargs={"rl_loss": ppo_defaults["loss"], "fixed_alpha": advisor_alpha},
        )

        only_stage = PipelineStage(
            loss_names=["advisor_loss"],
            early_stopping_criterion=cls.get_early_stopping_criterion(),
            max_stage_steps=total_steps,
        )

        return cls._training_pipeline(
            named_losses={"advisor_loss": advisor_builder},
            pipeline_stages=[only_stage],
            num_mini_batch=ppo_defaults["num_mini_batch"],
            update_repeats=ppo_defaults["update_repeats"],
        )

    @classmethod
    def create_model(cls, **kwargs) -> nn.Module:
        """Instantiate the actor-critic used by this experiment.

        Both variants read the first sensor's uuid as their input key and use
        a discrete action space of size ``2 * cls.WORLD_DIM``. When
        ``cls.RECURRENT_MODEL`` is truthy, an ``RNNActorCritic`` is returned
        with a ``LinearAdvisorActorCritic`` head. Otherwise a bare
        ``LinearAdvisorActorCritic`` is returned. Extra ``kwargs`` are ignored.
        """
        sensors = cls.get_sensors()

        # Constructor arguments shared by both model variants.
        common_args = dict(
            input_key=sensors[0].uuid,
            action_space=gym.spaces.Discrete(2 * cls.WORLD_DIM),
            observation_space=SensorSuite(sensors).observation_spaces,
        )

        if not cls.RECURRENT_MODEL:
            return LinearAdvisorActorCritic(ensure_same_weights=False, **common_args)

        advisor_head = Builder(  # type: ignore
            LinearAdvisorActorCritic, kwargs={"ensure_same_weights": False}
        )
        return RNNActorCritic(head_type=advisor_head, **common_args)
