import torch.nn as nn
from typing import Sequence, Union

from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.utils.experiment_utils import Builder, TrainingPipeline
from allenact_plugins.clip_plugin.clip_preprocessors import (
    ClipViTPreprocessor,
    NaivePreprocessor
)
from projects.plugins.ithor_plugin.ithor_sensors import RGBSensorThor
from projects.plugins.robothor_plugin.robothor_sensors import GPSCompassSensorRoboThor
from projects.point_navigation.baseline_configs.clip.zeroshot_mixins import (
    ClipViTPreprocessGRUActorCriticMixin,
    CONPEPointNavActorCriticMixin
)
from projects.point_navigation.baseline_configs.ithor.pointnav_ithor_base import (
    PointNaviThorBaseConfig
)
from projects.point_navigation.baseline_configs.mixins import PointNavPPOMixin


class ObjectNaviThorClipViTRGBPPOExperimentConfig(PointNaviThorBaseConfig):
    """A CLIP-ViT point-navigation experiment configuration with RGB input.

    The configuration extends ``PointNaviThorBaseConfig``, observes the
    environment through an RGB sensor plus a GPS/compass goal sensor, builds
    its preprocessors and model through ``CONPEPointNavActorCriticMixin`` and
    trains with the pipeline provided by ``PointNavPPOMixin``.

    NOTE(review): the class name still says "ObjectNav"; it is kept unchanged
    so existing references to this class keep working.
    """

    CLIP_MODEL_TYPE = "ViT-B/32"

    # Noise settings forwarded to the mixin as ``noise_std`` / ``sm_noise``.
    # NOTE(review): the meaning of the individual SEMANTIC_NOISE fields is
    # defined by CONPEPointNavActorCriticMixin and is not visible here.
    # Alternative (noisy) setting kept for reference:
    # NOISE_STD = 0.02
    # SEMANTIC_NOISE = (0.0, True, "../logs/PROMPTS/10promptcls/checkpoints/text_features_best.pth", 0.996)
    NOISE_STD = 0.0
    SEMANTIC_NOISE = (0.0, True, "", 0.0)

    # RGB frames are normalized with the CLIP mean/std constants exposed by
    # ClipViTPreprocessor; the GPS/compass sensor provides the goal signal.
    SENSORS = [
        RGBSensorThor(
            height=PointNaviThorBaseConfig.SCREEN_SIZE,
            width=PointNaviThorBaseConfig.SCREEN_SIZE,
            use_resnet_normalization=True,
            mean=ClipViTPreprocessor.CLIP_RGB_MEANS,
            stdev=ClipViTPreprocessor.CLIP_RGB_STDS,
            uuid="rgb_lowres",
        ),
        GPSCompassSensorRoboThor(),
    ]

    # Checkpoint paths of the prompts handed to the model mixin.
    # Paths are relative, so they resolve against the launch directory.
    PROMPT = (
        "../logs/PROMPTS/BRIGHTNESS/checkpoints/contrastive__latest.pth",
        "../logs/PROMPTS/CONTRAST/checkpoints/contrastive__latest.pth",
        "../logs/PROMPTS/SATURATION/checkpoints/contrastive__latest.pth",
        "../logs/PROMPTS/HUE/checkpoints/contrastive__latest.pth",

        "../logs/PROMPTS/FOV_39-59/checkpoints/comparative_action_byol_latest.pth",
        "../logs/PROMPTS/FOV_69-89/checkpoints/comparative_action_byol_latest.pth",
        "../logs/PROMPTS/FOV_99-139/checkpoints/comparative_action_byol_latest.pth",

        "../logs/PROMPTS/LOOK/checkpoints/comparative_action_byol_latest.pth",
        "../logs/PROMPTS/ROTATE/checkpoints/comparative_action_byol_latest.pth",
        "../logs/PROMPTS/STEPSIZE/checkpoints/comparative_action_byol_latest.pth",
    )

    # Available multi-prompt modes; ``__init__`` selects one entry by index.
    MULTI_P_MODE = [
        ("COMPOSE", "UNIFORM", "AVG"),
        ("COMPOSE", "UNIFORM", "CAT"),
        ("COMPOSE", "WEIGHTED", "AVG"),
        ("COMPOSE", "WEIGHTED", "CAT"),
        ("ENSEMBLE", "UNIFORM", "AVG"),
        ("ENSEMBLE", "UNIFORM", "CAT"),
        ("ENSEMBLE", "WEIGHTED", "AVG"),
        ("ENSEMBLE", "WEIGHTED", "CAT"),
        ("ATTEMPT", "WEIGHTED", "AVG"),
        ("SESoM", "WEIGHTED", "AVG"),
    ]
    META_MODE = True

    # Source checkpoints forwarded to the mixin as ``source_model``.
    # NOTE(review): presumably one slot per source policy, with ``None``
    # meaning "do not load" -- confirm against the mixin.
    SOURCE_MODEL = (
        # "/path/to/MMRL/allenact/storage/MAIN-EXP/ConPE/checkpoints/PromptATTNCLIPViTGRU-DDPPO-MDPs/2023-04-29_01-03-12/exp_PromptATTNCLIPViTGRU-DDPPO-MDPs__stage_00__steps_000003000000.pt",
        None,
        "/path/to/MMRL/allenact/storage/MAIN-EXP-POINTNAV/EMBCLIP/checkpoints/ViTGRU-DDPPO-MDPs/2023-05-05_02-09-16/exp_ViTGRU-DDPPO-MDPs__stage_00__steps_000001505280.pt",
    )
    # SOURCE_MODEL = None

    def __init__(self, **kwargs):
        """Set the per-domain factors and build the preprocessing/model mixin.

        Args:
            **kwargs: Forwarded unchanged to ``PointNaviThorBaseConfig``.
        """
        super().__init__(**kwargs)

        #### Domain defined by domain factors ####
        # Each list holds one value per domain (four domains here).
        # NOTE(review): presumably index i across all lists describes
        # domain i -- confirm against the base config that consumes them.
        self.STEP_SIZE = [0.1, 0.15, 0.25, 0.3]
        self.ROTATION_DEGREES = [90.0, 60.0, 30.0, 10.0]
        self.VISIBILITY_DISTANCE = [1.5, 1.5, 1.5, 1.5]
        # NOTE(review): the 4-tuples look like (brightness, contrast,
        # saturation, hue) and ``None`` like "default lighting" -- confirm.
        self.LIGHTING_VALUE = [
            (0.6, 0.2, 1.5, -0.4),
            (1.1, 1.0, 0.5, -0.1),
            None,
            (2.0, 3.5, 2, 0.4),
        ]
        self.HORIZONTAL_FIELD_OF_VIEW = [59, 69, 79, 89]
        self.LOOK_DEGREES = [40, 10, 30, 20]

        # Alternative domain set A (five domains), kept for reference:
        # self.STEP_SIZE =                [0.25, 0.25, 0.25, 0.25, 0.25]
        # self.ROTATION_DEGREES =         [30.0, 30.0, 30.0, 30.0, 30.0]
        # self.VISIBILITY_DISTANCE =      [1.0 , 1.0, 1.0, 1.0, 1.0]
        # self.LIGHTING_VALUE =           [(0.2, None, None, None), (None, 3.4, None, None), (None, None, 0.5, None), (None, None, None, 0.4), None]
        # self.HORIZONTAL_FIELD_OF_VIEW = [79, 79, 79, 79, 139]
        # self.LOOK_DEGREES =             [30, 30, 30, 30, 30]

        # Alternative domain set B (four domains), kept for reference:
        # self.STEP_SIZE =             [0.25, 0.3, 0.05, 0.35]
        # self.ROTATION_DEGREES =      [30, 10, 90, 5]
        # self.VISIBILITY_DISTANCE =   [1.0, 1.0 , 1.0, 1.0]
        # self.LIGHTING_VALUE =        [None, (1.7, 1.4, 1.5, 0.3), (0.4, 1.2, 1.7, 0.0), (0.6, 1.3, 1.8, 0.2)]
        # self.HORIZONTAL_FIELD_OF_VIEW = [79, 89, 59, 129]
        # self.LOOK_DEGREES = [30, 15, 5, 10]

        # candidate 1 (single domain), kept for reference:
        # self.STEP_SIZE =             [0.25]
        # self.ROTATION_DEGREES =      [30.0]
        # self.VISIBILITY_DISTANCE =   [1.0]
        # self.LIGHTING_VALUE =        [None]
        # self.HORIZONTAL_FIELD_OF_VIEW = [79]
        # self.LOOK_DEGREES = [30.0]
        ##########################################
        self.DATA_GEN = False

        self.preprocessing_and_model = CONPEPointNavActorCriticMixin(
            sensors=self.SENSORS,
            clip_model_type=self.CLIP_MODEL_TYPE,
            screen_size=self.SCREEN_SIZE,
            goal_sensor_type=GPSCompassSensorRoboThor,
            pool=False,
            pooling_type='',
            target_types=self.TARGET_TYPES,
            prompt=self.PROMPT,
            # Index 6 selects ("ENSEMBLE", "WEIGHTED", "AVG").
            multi_p_mode=self.MULTI_P_MODE[6],
            meta_mode=self.META_MODE,
            noise_std=self.NOISE_STD,
            sm_noise=self.SEMANTIC_NOISE,
            source_model=self.SOURCE_MODEL,
        )

    def training_pipeline(self, **kwargs) -> TrainingPipeline:
        """Return the training pipeline built by ``PointNavPPOMixin``.

        ``kwargs`` are accepted for interface compatibility but are not
        forwarded; the pipeline is configured without auxiliary tasks,
        with a single belief and with advantage normalization enabled.
        """
        return PointNavPPOMixin.training_pipeline(
            auxiliary_uuids=[],
            multiple_beliefs=False,
            normalize_advantage=True,
            advance_scene_rollout_period=self.ADVANCE_SCENE_ROLLOUT_PERIOD,
        )

    def preprocessors(self) -> Sequence[Union[Preprocessor, Builder[Preprocessor]]]:
        """Return the preprocessors defined by the model mixin."""
        return self.preprocessing_and_model.preprocessors()

    def create_model(self, **kwargs) -> nn.Module:
        """Create the actor-critic model through the model mixin.

        Args:
            **kwargs: Forwarded to the mixin's ``create_model`` together
                with the number of actions in ``self.ACTION_SPACE``.
        """
        return self.preprocessing_and_model.create_model(
            num_actions=self.ACTION_SPACE.n, **kwargs
        )

    @classmethod
    def tag(cls) -> str:
        """Return the experiment tag.

        Declared as a classmethod (the receiver was already named ``cls``)
        so the tag is available from both the class and its instances.
        """
        return "PromptATTNCLIPViTGRU-DDPPO-MDPs"
# file name: pointnav_ithor_rgb_conpe_gru_ddppo_mdps