from typing import Sequence, Union

import torch.nn as nn

from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.utils.experiment_utils import Builder, TrainingPipeline
from allenact_plugins.clip_plugin.clip_preprocessors import (
    ClipResNetPreprocessor,
    ClipViTPreprocessor,
    PromptClipViTPreprocessor
)
from projects.plugins.ithor_plugin.ithor_sensors import (
    GoalObjectTypeThorSensor,
    RGBSensorThor,
)
from projects.object_navigation.baseline_configs.clip.zeroshot_mixins import (
    ZeroshotClipResNetPreprocessGRUActorCriticMixin,
    ZeroshotClipViTPreprocessGRUActorCriticMixin,

)
from projects.object_navigation.baseline_configs.robothor.zeroshot_objectnav_robothor_base import (
    ZeroshotObjectNavRoboThorBaseConfig,
)
from projects.object_navigation.baseline_configs.robothor.clip.objectnav_robothor_rgb_clipresnet50gru_ddppo import (
    ObjectNavRoboThorClipRGBPPOExperimentConfig,
)
from projects.object_navigation.baseline_configs.ithor.objectnav_ithor_base import (
    ObjectNaviThorBaseConfig,
)
from projects.object_navigation.baseline_configs.navigation_base import ObjectNavPPOMixin


# class ZeroshotObjectNavRoboThorClipRGBPPOExperimentConfig(
#     ZeroshotObjectNavRoboThorBaseConfig,
#     ObjectNavRoboThorClipRGBPPOExperimentConfig
# ):
class ObjectNaviThorClipViTRGBPPOExperimentConfig(ObjectNaviThorBaseConfig):
    """Zero-shot CLIP ViT object-navigation experiment configuration in iTHOR
    with RGB input.

    Visual features come from a prompt-conditioned CLIP ``ViT-B/32`` encoder
    (see ``PROMPT``), fed to the GRU actor-critic built by
    ``ZeroshotClipViTPreprocessGRUActorCriticMixin``. The training pipeline is
    taken from ``ObjectNavPPOMixin``.
    """

    CLIP_MODEL_TYPE = "ViT-B/32"
    # Passed through to the preprocessing/model mixin as ``noise_std``.
    NOISE_STD = 0.0

    SENSORS = [
        RGBSensorThor(
            height=ObjectNaviThorBaseConfig.SCREEN_SIZE,
            width=ObjectNaviThorBaseConfig.SCREEN_SIZE,
            # Normalize frames with CLIP's RGB statistics rather than ImageNet's.
            use_resnet_normalization=True,
            mean=PromptClipViTPreprocessor.CLIP_RGB_MEANS,
            stdev=PromptClipViTPreprocessor.CLIP_RGB_STDS,
            uuid="rgb_lowres",
        ),
        GoalObjectTypeThorSensor(object_types=ObjectNaviThorBaseConfig.TARGET_TYPES),
    ]

    # (state_prompt, action_prompt) checkpoint paths handed to the mixin.
    # Per the alternatives below, ``False`` in a slot disables that prompt and
    # ``True`` in the state slot selects the "original state" features.
    # NOTE(review): the exact interpretation lives in
    # ZeroshotClipViTPreprocessGRUActorCriticMixin -- confirm there.
    #
    # Active setting: state + action prompts together.
    PROMPT = (
        "projects/object_navigation/prompts/state_method/contrastive_state_latest.pth",
        "projects/object_navigation/prompts/action_method/comparative_action_byol_latest.pth",
    )
    # Alternative settings (swap in to run ablations):
    #
    # state only
    # PROMPT = ("projects/object_navigation/prompts/state_method/contrastive_state_latest.pth",
    #           False)
    # PROMPT = ("projects/object_navigation/prompts/state_con/12mdps_fullshot/contrastive_state_latest.pth",
    #           False)
    #
    # action only
    # PROMPT = (False,
    #           "projects/object_navigation/prompts/action_method/comparative_action_byol_latest.pth")
    # PROMPT = (False,
    #           "projects/object_navigation/prompts/action_con/12mdps_fullshot/contrastive_action_latest.pth")
    #
    # original state + action only
    # PROMPT = (True,
    #           "projects/object_navigation/prompts/action_con/12mdps_16shot/contrastive_action_byol_latest.pth")
    # PROMPT = (True,
    #           "projects/object_navigation/prompts/action_con/12mdps_fullshot/contrastive_action_latest.pth")

    def __init__(self, **kwargs):
        """Build the base config, then the CLIP-ViT preprocessing/model mixin.

        All keyword arguments are forwarded to ``ObjectNaviThorBaseConfig``.
        """
        super().__init__(**kwargs)

        self.preprocessing_and_model = ZeroshotClipViTPreprocessGRUActorCriticMixin(
            sensors=self.SENSORS,
            clip_model_type=self.CLIP_MODEL_TYPE,
            screen_size=self.SCREEN_SIZE,
            goal_sensor_type=GoalObjectTypeThorSensor,
            pool=True,
            pooling_type="attn",
            target_types=self.TARGET_TYPES,
            prompt=self.PROMPT,
            noise_std=self.NOISE_STD,
        )

    def training_pipeline(self, **kwargs) -> TrainingPipeline:
        """Return the PPO training pipeline from ``ObjectNavPPOMixin``.

        Runs with no auxiliary losses and a single belief. ``**kwargs`` is
        accepted for interface compatibility but is not used.
        """
        return ObjectNavPPOMixin.training_pipeline(
            auxiliary_uuids=[],
            multiple_beliefs=False,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
        )

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Return the preprocessors defined by the CLIP-ViT mixin."""
        return self.preprocessing_and_model.preprocessors()

    def create_model(self, **kwargs) -> nn.Module:
        """Create the actor-critic model sized to this config's action space.

        Extra keyword arguments are forwarded to the mixin's ``create_model``.
        """
        return self.preprocessing_and_model.create_model(
            num_actions=self.ACTION_SPACE.n, **kwargs
        )

    @classmethod
    def tag(cls) -> str:
        """Return the identifier string for this experiment.

        A classmethod so it can be called on the class as well as on an
        instance (the original took ``cls`` but lacked the decorator).
        """
        # NOTE(review): "Noise00" and "State-Con-Action-Com" appear to encode
        # NOISE_STD and PROMPT -- keep in sync if those change.
        return "State-Con-Action-Com-PromptClipViTGRU-DDPPO-MDP1-Noise00"
