import torch.nn as nn
from typing import Sequence, Union

from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.utils.experiment_utils import Builder, TrainingPipeline
from allenact_plugins.clip_plugin.clip_preprocessors import (
    ClipViTPreprocessor,
    NaivePreprocessor
)
from projects.plugins.ithor_plugin.ithor_sensors import RGBSensorThor
from projects.plugins.robothor_plugin.robothor_sensors import GPSCompassSensorRoboThor
from projects.point_navigation.baseline_configs.clip.zeroshot_mixins import (
    ClipViTPreprocessGRUActorCriticMixin,
    CLIPViTGRUActorCriticMixin
)
from projects.point_navigation.baseline_configs.ithor.pointnav_ithor_base import (
    PointNaviThorAUTOTESTBaseConfig
)
from projects.point_navigation.baseline_configs.mixins import PointNavPPOMixin


class ObjectNaviThorClipViTRGBPPOExperimentConfig(PointNaviThorAUTOTESTBaseConfig):
    """CLIP ViT point-navigation experiment configuration in iTHOR with RGB input.

    Builds on ``PointNaviThorAUTOTESTBaseConfig`` and wires together:

    * an RGB sensor normalized with CLIP's mean/std plus a GPS/compass sensor,
    * ``ClipViTPreprocessGRUActorCriticMixin`` for preprocessors and the
      GRU actor-critic model (loading encoder weights from ``CKPT``),
    * ``PointNavPPOMixin`` for the PPO training pipeline.

    NOTE(review): the class name says "ObjectNavi", but this is a point-navigation
    config (see the base class and ``tag``). The name is kept unchanged so that
    existing references to it keep working.
    """

    # CLIP backbone identifier forwarded to the preprocessing/model mixin.
    CLIP_MODEL_TYPE = "ViT-B/32"
    # Forwarded to the mixin as ``noise_std``. Presumably the std of noise
    # injected into the visual features; 0.0 disables it (confirm in the mixin).
    NOISE_STD = 0.0

    SENSORS = [
        # RGB frames at the base config's screen size, normalized with CLIP's
        # channel means/stds rather than the default ResNet statistics.
        RGBSensorThor(
            height=PointNaviThorAUTOTESTBaseConfig.SCREEN_SIZE,
            width=PointNaviThorAUTOTESTBaseConfig.SCREEN_SIZE,
            use_resnet_normalization=True,
            mean=ClipViTPreprocessor.CLIP_RGB_MEANS,
            stdev=ClipViTPreprocessor.CLIP_RGB_STDS,
            uuid="rgb_lowres",
        ),
        GPSCompassSensorRoboThor(),
    ]

    # The three flags below are not read in this file. Presumably they are
    # consumed by the base config or the mixins; verify before changing them.
    PROMPT = False
    # Checkpoint path forwarded to the mixin as ``ckpt``.
    # Alternative checkpoints that have been used:
    #   "/path/to/MMRL/logs/curl/ObjNav12mdps_16shot/checkpoint_0499.pth.tar"
    #   ("/path/to/MMRL/logs/curl/ObjNav12mdps_16shot/checkpoint_0499.pth.tar",
    #    "/path/to/MMRL/logs/acp/checkpoints/comparative_action_byol_latest.pth")
    CKPT = "/path/to/MMRL/logs/atc/checkpoints/comparative_action_byol_latest.pth"
    MULTI_P_MODE = [None]
    META_MODE = False

    def __init__(self, **kwargs):
        """Initialize the base config and build the preprocessing/model mixin.

        Args:
            **kwargs: Forwarded unchanged to the base config's ``__init__``.
        """
        super().__init__(**kwargs)

        # Not read in this file; presumably checked by the base config or the
        # tooling that drives it — verify.
        self.DATA_GEN = False

        # Single source of truth for both ``preprocessors()`` and
        # ``create_model()`` below.
        self.preprocessing_and_model = ClipViTPreprocessGRUActorCriticMixin(
            sensors=self.SENSORS,
            clip_model_type=self.CLIP_MODEL_TYPE,
            screen_size=self.SCREEN_SIZE,
            goal_sensor_type=GPSCompassSensorRoboThor,
            pool=False,
            pooling_type='',
            target_types=self.TARGET_TYPES,
            ckpt=self.CKPT,
            noise_std=self.NOISE_STD,
        )

    def training_pipeline(self, **kwargs) -> TrainingPipeline:
        """Return the PPO training pipeline from ``PointNavPPOMixin``.

        No auxiliary losses and a single belief are used, and advantages are
        normalized. ``kwargs`` are accepted for interface compatibility but are
        not used.
        """
        return PointNavPPOMixin.training_pipeline(
            auxiliary_uuids=[],
            multiple_beliefs=False,
            normalize_advantage=True,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
        )

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Return the preprocessors defined by the CLIP preprocessing mixin."""
        return self.preprocessing_and_model.preprocessors()

    def create_model(self, **kwargs) -> nn.Module:
        """Build the GRU actor-critic model via the mixin.

        Args:
            **kwargs: Forwarded to the mixin's ``create_model``.

        Returns:
            The model, sized to this config's discrete action space.
        """
        return self.preprocessing_and_model.create_model(
            num_actions=self.ACTION_SPACE.n, **kwargs
        )

    @classmethod
    def tag(cls) -> str:
        """Return the experiment's unique identifier string.

        Declared as a classmethod (matching AllenAct's ``ExperimentConfig.tag``)
        so that both ``Config.tag()`` and ``Config().tag()`` work.
        """
        return "pointnav_ithor_rgb_clip_vit32_comps_gru_ddppo_autotest"
# file name: pointnav_ithor_rgb_clip_vit32_comps_gru_ddppo_autotest