"""Baseline models for use in the object navigation task.

Object navigation is currently available as a Task in AI2-THOR and
Facebook's Habitat.
"""
from typing import Optional, List, Dict, cast, Tuple, Sequence

import gym
import torch
import torch.nn as nn
from gym.spaces import Dict as SpaceDict

from allenact.algorithms.onpolicy_sync.policy import ObservationType
from allenact.embodiedai.models import resnet as resnet
from allenact.embodiedai.models.basic_models import SimpleCNN
from allenact.embodiedai.models.visual_nav_models import (
    VisualNavActorCritic,
    FusionType,
)
from typing import (
    Optional,
    Tuple,
    Sequence,
    Union,
    Dict,
    Any,
)

import clip
from clip.model import CLIP
from torch.nn import Conv2d, Dropout
from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.utils.misc_utils import prepare_locals_for_super
import numpy as np

class CatObservations(nn.Module):
    """Join a fixed, ordered selection of observation tensors.

    # Attributes
    ordered_uuids : Keys of the observations to join, in concatenation order.
        Must contain at least one key.
    dim : Dimension along which the selected tensors are concatenated.
    """

    def __init__(self, ordered_uuids: Sequence[str], dim: int):
        super().__init__()
        assert len(ordered_uuids) != 0

        self.ordered_uuids = ordered_uuids
        self.dim = dim

    def forward(self, observations: ObservationType):
        """Return the selected observation(s) as a single tensor.

        With exactly one configured key the stored tensor is returned as-is
        (no `torch.cat` call, hence no copy); otherwise the selected tensors
        are concatenated along `self.dim` in the configured order.
        """
        first_uuid, *remaining_uuids = self.ordered_uuids
        if not remaining_uuids:
            return observations[first_uuid]

        selected = [observations[key] for key in self.ordered_uuids]
        return torch.cat(selected, dim=self.dim)


class ObjectNavActorCritic(VisualNavActorCritic):
    """Baseline recurrent actor critic model for object-navigation.

    The observation embedding fed to the state encoders is the concatenation
    of a visual embedding (omitted when the model is blind) and a learned
    embedding of the goal object type.

    # Attributes
    action_space : The space of actions available to the agent. Currently only discrete
        actions are allowed (so this space will always be of type `gym.spaces.Discrete`).
    observation_space : The observation space expected by the agent. This observation space
        should include (optionally) 'rgb' images and 'depth' images and is required to
        have a component corresponding to the goal `goal_sensor_uuid`.
    goal_sensor_uuid : The uuid of the sensor of the goal object. See `GoalObjectTypeThorSensor`
        as an example of such a sensor.
    hidden_size : The hidden size of the GRU RNN. Also used as the output size of the
        "simple_cnn", "gnresnet18" and "projection" visual encoders.
    object_type_embedding_dim: The dimensionality of the embedding corresponding to the goal
        object type.
    rgb_uuid : Key of the RGB observation, or `None` if it is not used.
    depth_uuid : Key of the depth observation, or `None` if it is not used.
    backbone : One of "simple_cnn", "gnresnet18", "identity" or "projection". Any other
        value raises `NotImplementedError`.
    resnet_baseplanes : Only used by the "gnresnet18" backbone, where it is passed as
        `baseplanes` and, halved, as `ngroups`.
    """

    def __init__(
        self,
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str,
        # RNN
        hidden_size=512,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        # Aux loss
        multiple_beliefs=False,
        beliefs_fusion: Optional[FusionType] = None,
        auxiliary_uuids: Optional[Sequence[str]] = None,
        # below are custom params
        rgb_uuid: Optional[str] = None,
        depth_uuid: Optional[str] = None,
        object_type_embedding_dim=8,
        trainable_masked_hidden_state: bool = False,
        # perception backbone params,
        backbone="gnresnet18",
        resnet_baseplanes=32,
    ):
        """Initializer.

        See class documentation for parameter definitions.
        """
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size,
            multiple_beliefs=multiple_beliefs,
            beliefs_fusion=beliefs_fusion,
            auxiliary_uuids=auxiliary_uuids,
        )

        self.rgb_uuid = rgb_uuid
        self.depth_uuid = depth_uuid

        self.goal_sensor_uuid = goal_sensor_uuid
        # The goal space is read through `.n`, i.e. it is expected to expose a
        # number of discrete categories.
        self._n_object_types = self.observation_space.spaces[self.goal_sensor_uuid].n
        self.object_type_embedding_size = object_type_embedding_dim

        # Build the visual encoder selected by `backbone` and record the size
        # of the embedding it produces.
        self.backbone = backbone
        if backbone == "simple_cnn":
            self.visual_encoder = SimpleCNN(
                observation_space=observation_space,
                output_size=hidden_size,
                rgb_uuid=rgb_uuid,
                depth_uuid=depth_uuid,
            )
            self.visual_encoder_output_size = hidden_size
            assert self.is_blind == self.visual_encoder.is_blind
        elif backbone == "gnresnet18":  # resnet family
            self.visual_encoder = resnet.GroupNormResNetEncoder(
                observation_space=observation_space,
                output_size=hidden_size,
                rgb_uuid=rgb_uuid,
                depth_uuid=depth_uuid,
                baseplanes=resnet_baseplanes,
                ngroups=resnet_baseplanes // 2,
                make_backbone=getattr(resnet, backbone),
            )
            self.visual_encoder_output_size = hidden_size
            assert self.is_blind == self.visual_encoder.is_blind
        elif backbone in ["identity", "projection"]:
            # These backbones consume the configured observations directly,
            # concatenated along their last dimension.
            # NOTE(review): `CatObservations` asserts a non-empty uuid list, so
            # a blind model cannot be built with these two backbones.
            good_uuids = [
                uuid for uuid in [self.rgb_uuid, self.depth_uuid] if uuid is not None
            ]
            cat_model = CatObservations(ordered_uuids=good_uuids, dim=-1,)
            after_cat_size = sum(
                observation_space[uuid].shape[-1] for uuid in good_uuids
            )
            if backbone == "identity":
                self.visual_encoder = cat_model
                self.visual_encoder_output_size = after_cat_size
            else:
                # "projection": linear layer + ReLU mapping to `hidden_size`.
                self.visual_encoder = nn.Sequential(
                    cat_model, nn.Linear(after_cat_size, hidden_size), nn.ReLU(True)
                )
                self.visual_encoder_output_size = hidden_size

        else:
            raise NotImplementedError

        # The helpers below are not defined in this class (presumably inherited
        # from `VisualNavActorCritic`); they are sized using the combined
        # goal + visual embedding width.
        self.create_state_encoders(
            obs_embed_size=self.goal_visual_encoder_output_dims,
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
            trainable_masked_hidden_state=trainable_masked_hidden_state,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=self.goal_visual_encoder_output_dims,
            action_embed_size=action_embed_size,
        )

        # One learned vector per goal object type.
        self.object_type_embedding = nn.Embedding(
            num_embeddings=self._n_object_types,
            embedding_dim=object_type_embedding_dim,
        )

        self.train()

    @property
    def is_blind(self) -> bool:
        """True if the model is blind (e.g. neither 'depth' or 'rgb' is an
        input observation type)."""
        return self.rgb_uuid is None and self.depth_uuid is None

    @property
    def goal_visual_encoder_output_dims(self) -> int:
        """Width of the embedding returned by `forward_encoder`: the goal
        embedding size, plus the visual encoder output size unless blind."""
        dims = self.object_type_embedding_size
        if self.is_blind:
            return dims
        return dims + self.visual_encoder_output_size

    def get_object_type_encoding(
        self, observations: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Get the object type encoding from input batched observations."""
        # The goal observation is cast to int64 before the embedding lookup.
        # noinspection PyTypeChecker
        return self.object_type_embedding(  # type:ignore
            observations[self.goal_sensor_uuid].to(torch.int64)
        )

    def forward_encoder(self, observations: ObservationType) -> torch.Tensor:
        """Embed the observations as `[visual embedding, goal embedding]`
        concatenated along the last dimension (goal embedding only if blind)."""
        target_encoding = self.get_object_type_encoding(
            cast(Dict[str, torch.Tensor], observations)
        )
        obs_embeds = [target_encoding]
        # The visual embedding, when present, is placed before the goal embedding.
        if not self.is_blind:
            perception_embed = self.visual_encoder(observations)
            obs_embeds = [perception_embed] + obs_embeds

        obs_embeds = torch.cat(obs_embeds, dim=-1)
        return obs_embeds


class ResnetTensorNavActorCritic(VisualNavActorCritic):
    """Recurrent actor critic operating on pre-computed ResNet feature tensors.

    The observation embedding is produced by a goal-conditioned encoder:
    `ResnetDualTensorGoalEncoder` when both an RGB and a depth preprocessor
    uuid are given, `ResnetTensorGoalEncoder` (on whichever uuid is available)
    otherwise.
    """

    def __init__(
        # base params
        self,
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        goal_sensor_uuid: str,
        hidden_size=512,
        num_rnn_layers=1,
        rnn_type="GRU",
        add_prev_actions=False,
        add_prev_action_null_token=False,
        action_embed_size=6,
        multiple_beliefs=False,
        beliefs_fusion: Optional[FusionType] = None,
        auxiliary_uuids: Optional[List[str]] = None,
        # custom params
        rgb_resnet_preprocessor_uuid: Optional[str] = None,
        depth_resnet_preprocessor_uuid: Optional[str] = None,
        goal_dims: int = 32,
        resnet_compressor_hidden_out_dims: Tuple[int, int] = (128, 32),
        combiner_hidden_out_dims: Tuple[int, int] = (128, 32),
    ):
        """Build the goal/visual encoder, then the recurrent state encoders,
        the actor-critic head and the auxiliary models (in that order)."""
        super().__init__(
            action_space=action_space,
            observation_space=observation_space,
            hidden_size=hidden_size,
            multiple_beliefs=multiple_beliefs,
            beliefs_fusion=beliefs_fusion,
            auxiliary_uuids=auxiliary_uuids,
        )

        has_both_modalities = (
            rgb_resnet_preprocessor_uuid is not None
            and depth_resnet_preprocessor_uuid is not None
        )
        if has_both_modalities:
            self.goal_visual_encoder = ResnetDualTensorGoalEncoder(  # type:ignore
                observation_spaces=self.observation_space,
                goal_sensor_uuid=goal_sensor_uuid,
                rgb_resnet_preprocessor_uuid=rgb_resnet_preprocessor_uuid,
                depth_resnet_preprocessor_uuid=depth_resnet_preprocessor_uuid,
                goal_embed_dims=goal_dims,
                resnet_compressor_hidden_out_dims=resnet_compressor_hidden_out_dims,
                combiner_hidden_out_dims=combiner_hidden_out_dims,
            )
        else:
            # At most one modality is configured: prefer RGB, fall back to
            # depth (which may itself be None).
            if rgb_resnet_preprocessor_uuid is None:
                single_modality_uuid = depth_resnet_preprocessor_uuid
            else:
                single_modality_uuid = rgb_resnet_preprocessor_uuid
            self.goal_visual_encoder = ResnetTensorGoalEncoder(
                observation_spaces=self.observation_space,
                goal_sensor_uuid=goal_sensor_uuid,
                resnet_preprocessor_uuid=single_modality_uuid,
                goal_embed_dims=goal_dims,
                resnet_compressor_hidden_out_dims=resnet_compressor_hidden_out_dims,
                combiner_hidden_out_dims=combiner_hidden_out_dims,
            )

        state_encoder_options = dict(
            num_rnn_layers=num_rnn_layers,
            rnn_type=rnn_type,
            add_prev_actions=add_prev_actions,
            add_prev_action_null_token=add_prev_action_null_token,
            prev_action_embed_size=action_embed_size,
        )
        self.create_state_encoders(
            obs_embed_size=self.goal_visual_encoder.output_dims,
            **state_encoder_options,
        )

        self.create_actorcritic_head()

        self.create_aux_models(
            obs_embed_size=self.goal_visual_encoder.output_dims,
            action_embed_size=action_embed_size,
        )

        self.train()

    @property
    def is_blind(self) -> bool:
        """Whether the underlying goal/visual encoder reports itself as blind."""
        return self.goal_visual_encoder.is_blind

    def forward_encoder(self, observations: ObservationType) -> torch.FloatTensor:
        """Embed the observations with the goal/visual encoder."""
        encoder = self.goal_visual_encoder
        return encoder(observations)


class ResnetTensorGoalEncoder(nn.Module):
    """Fuse one pre-computed ResNet feature tensor with a goal embedding.

    The feature tensor is compressed with two 1x1 convolutions, the goal
    embedding is broadcast over the spatial grid and concatenated on the
    channel axis, and the result goes through two more 1x1 convolutions
    before being flattened.

    The encoder is "blind" when `resnet_preprocessor_uuid` is not a key of the
    observation space; it then returns only the goal embedding.

    # Attributes
    goal_uuid : Key of the goal observation.
    resnet_uuid : Key of the ResNet feature observation.
    goal_embed_dims : Size of the goal embedding.
    resnet_hid_out_dims : (hidden, output) channels of the feature compressor.
    combine_hid_out_dims : (hidden, output) channels of the goal/feature combiner.
    resnet_tensor_shape : Shape of the feature space (only set when not blind);
        index 0 is used as the channel count and indices 1 and 2 as the
        spatial grid.
    """

    def __init__(
        self,
        observation_spaces: "SpaceDict",
        goal_sensor_uuid: str,
        resnet_preprocessor_uuid: str,
        goal_embed_dims: int = 32,
        resnet_compressor_hidden_out_dims: Tuple[int, int] = (128, 32),
        combiner_hidden_out_dims: Tuple[int, int] = (128, 32),
    ) -> None:
        """Create the goal embedding and, unless blind, the conv stacks.

        Raises `NotImplementedError` if the goal space is neither
        `gym.spaces.Discrete` nor `gym.spaces.Box`.
        """
        super().__init__()
        self.goal_uuid = goal_sensor_uuid
        self.resnet_uuid = resnet_preprocessor_uuid
        self.goal_embed_dims = goal_embed_dims
        self.resnet_hid_out_dims = resnet_compressor_hidden_out_dims
        self.combine_hid_out_dims = combiner_hidden_out_dims

        # Discrete goals get a lookup table, Box goals a linear projection.
        self.goal_space = observation_spaces.spaces[self.goal_uuid]
        if isinstance(self.goal_space, gym.spaces.Discrete):
            self.embed_goal = nn.Embedding(
                num_embeddings=self.goal_space.n, embedding_dim=self.goal_embed_dims,
            )
        elif isinstance(self.goal_space, gym.spaces.Box):
            self.embed_goal = nn.Linear(self.goal_space.shape[-1], self.goal_embed_dims)
        else:
            raise NotImplementedError

        self.blind = self.resnet_uuid not in observation_spaces.spaces
        if not self.blind:
            self.resnet_tensor_shape = observation_spaces.spaces[self.resnet_uuid].shape
            self.resnet_compressor = nn.Sequential(
                nn.Conv2d(self.resnet_tensor_shape[0], self.resnet_hid_out_dims[0], 1),
                nn.ReLU(),
                nn.Conv2d(*self.resnet_hid_out_dims[0:2], 1),
                nn.ReLU(),
            )
            self.target_obs_combiner = nn.Sequential(
                nn.Conv2d(
                    self.resnet_hid_out_dims[1] + self.goal_embed_dims,
                    self.combine_hid_out_dims[0],
                    1,
                ),
                nn.ReLU(),
                nn.Conv2d(*self.combine_hid_out_dims[0:2], 1),
            )

    @property
    def is_blind(self):
        """True when the ResNet feature key is absent from the observation space."""
        return self.blind

    @property
    def output_dims(self):
        """Size of the last dimension of the tensor returned by `forward`."""
        if self.blind:
            return self.goal_embed_dims
        else:
            return (
                self.combine_hid_out_dims[-1]
                * self.resnet_tensor_shape[1]
                * self.resnet_tensor_shape[2]
            )

    def get_object_type_encoding(
        self, observations: Dict[str, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """Get the object type encoding from input batched observations."""
        return cast(
            torch.FloatTensor,
            self.embed_goal(observations[self.goal_uuid].to(torch.int64)),
        )

    def compress_resnet(self, observations):
        """Apply the 1x1-conv compressor to the (already flattened) features."""
        return self.resnet_compressor(observations[self.resnet_uuid])

    def distribute_target(self, observations):
        """Embed the goal and broadcast it over the feature grid."""
        target_emb = self.embed_goal(observations[self.goal_uuid])
        return target_emb.view(-1, self.goal_embed_dims, 1, 1).expand(
            -1, -1, self.resnet_tensor_shape[-2], self.resnet_tensor_shape[-1]
        )

    def adapt_input(self, observations):
        """Flatten the leading batch dimensions of the features and goal.

        A 6-dimensional feature tensor is read as
        (step, sampler, agent, C, H, W); anything else as
        (step, sampler, C, H, W).

        Returns a NEW observation dict holding the flattened tensors (the
        caller's dict is left untouched), followed by
        `use_agent, nstep, nsampler, nagent`, which `adapt_output` needs to
        undo the flattening.
        """
        # Shallow copy: the reshaped entries below must not leak back into the
        # caller's dict, which other consumers may read afterwards.
        observations = {**observations}

        resnet = observations[self.resnet_uuid]
        goal = observations[self.goal_uuid]

        use_agent = False
        nagent = 1

        if len(resnet.shape) == 6:
            use_agent = True
            nstep, nsampler, nagent = resnet.shape[:3]
        else:
            nstep, nsampler = resnet.shape[:2]

        observations[self.resnet_uuid] = resnet.view(-1, *resnet.shape[-3:])
        observations[self.goal_uuid] = goal.view(-1, goal.shape[-1])

        return observations, use_agent, nstep, nsampler, nagent

    @staticmethod
    def adapt_output(x, use_agent, nstep, nsampler, nagent):
        """Restore the leading batch dimensions removed by `adapt_input`."""
        if use_agent:
            return x.view(nstep, nsampler, nagent, -1)
        return x.view(nstep, nsampler * nagent, -1)

    def forward(self, observations):
        """Return the fused goal/visual embedding for a batch of observations.

        When blind, the goal embedding of the unmodified goal observation is
        returned directly.
        """
        # Blind encoders have no feature entry to reshape, so this must be
        # handled before `adapt_input`, which indexes the feature key.
        if self.blind:
            return self.embed_goal(observations[self.goal_uuid])

        observations, use_agent, nstep, nsampler, nagent = self.adapt_input(
            observations
        )

        embs = [
            self.compress_resnet(observations),
            self.distribute_target(observations),
        ]
        x = self.target_obs_combiner(torch.cat(embs, dim=1,))
        x = x.reshape(x.size(0), -1)  # flatten

        return self.adapt_output(x, use_agent, nstep, nsampler, nagent)


class ResnetDualTensorGoalEncoder(nn.Module):
    """Fuse RGB and depth ResNet feature tensors with a goal embedding.

    Each modality has its own 1x1-conv compressor and its own goal/feature
    combiner; the two combined maps are concatenated on the channel axis and
    flattened.

    The encoder is "blind" when either feature key is missing from the
    observation space; it then returns only the goal embedding.

    # Attributes
    goal_uuid : Key of the goal observation.
    rgb_resnet_uuid : Key of the RGB ResNet feature observation.
    depth_resnet_uuid : Key of the depth ResNet feature observation.
    goal_embed_dims : Size of the goal embedding.
    resnet_hid_out_dims : (hidden, output) channels of each feature compressor.
    combine_hid_out_dims : (hidden, output) channels of each combiner.
    resnet_tensor_shape : Shape of the RGB feature space (only set when not
        blind). It also sizes the depth compressor, so both feature spaces
        are assumed to share it.
    """

    def __init__(
        self,
        observation_spaces: "SpaceDict",
        goal_sensor_uuid: str,
        rgb_resnet_preprocessor_uuid: str,
        depth_resnet_preprocessor_uuid: str,
        goal_embed_dims: int = 32,
        resnet_compressor_hidden_out_dims: Tuple[int, int] = (128, 32),
        combiner_hidden_out_dims: Tuple[int, int] = (128, 32),
    ) -> None:
        """Create the goal embedding and, unless blind, the per-modality stacks.

        Raises `NotImplementedError` if the goal space is neither
        `gym.spaces.Discrete` nor `gym.spaces.Box`.
        """
        super().__init__()
        self.goal_uuid = goal_sensor_uuid
        self.rgb_resnet_uuid = rgb_resnet_preprocessor_uuid
        self.depth_resnet_uuid = depth_resnet_preprocessor_uuid
        self.goal_embed_dims = goal_embed_dims
        self.resnet_hid_out_dims = resnet_compressor_hidden_out_dims
        self.combine_hid_out_dims = combiner_hidden_out_dims

        # Discrete goals get a lookup table, Box goals a linear projection.
        self.goal_space = observation_spaces.spaces[self.goal_uuid]
        if isinstance(self.goal_space, gym.spaces.Discrete):
            self.embed_goal = nn.Embedding(
                num_embeddings=self.goal_space.n, embedding_dim=self.goal_embed_dims,
            )
        elif isinstance(self.goal_space, gym.spaces.Box):
            self.embed_goal = nn.Linear(self.goal_space.shape[-1], self.goal_embed_dims)
        else:
            raise NotImplementedError

        self.blind = (
            self.rgb_resnet_uuid not in observation_spaces.spaces
            or self.depth_resnet_uuid not in observation_spaces.spaces
        )
        if not self.blind:
            self.resnet_tensor_shape = observation_spaces.spaces[
                self.rgb_resnet_uuid
            ].shape
            self.rgb_resnet_compressor = nn.Sequential(
                nn.Conv2d(self.resnet_tensor_shape[0], self.resnet_hid_out_dims[0], 1),
                nn.ReLU(),
                nn.Conv2d(*self.resnet_hid_out_dims[0:2], 1),
                nn.ReLU(),
            )
            self.depth_resnet_compressor = nn.Sequential(
                nn.Conv2d(self.resnet_tensor_shape[0], self.resnet_hid_out_dims[0], 1),
                nn.ReLU(),
                nn.Conv2d(*self.resnet_hid_out_dims[0:2], 1),
                nn.ReLU(),
            )
            self.rgb_target_obs_combiner = nn.Sequential(
                nn.Conv2d(
                    self.resnet_hid_out_dims[1] + self.goal_embed_dims,
                    self.combine_hid_out_dims[0],
                    1,
                ),
                nn.ReLU(),
                nn.Conv2d(*self.combine_hid_out_dims[0:2], 1),
            )
            self.depth_target_obs_combiner = nn.Sequential(
                nn.Conv2d(
                    self.resnet_hid_out_dims[1] + self.goal_embed_dims,
                    self.combine_hid_out_dims[0],
                    1,
                ),
                nn.ReLU(),
                nn.Conv2d(*self.combine_hid_out_dims[0:2], 1),
            )

    @property
    def is_blind(self):
        """True when either feature key is absent from the observation space."""
        return self.blind

    @property
    def output_dims(self):
        """Size of the last dimension of the tensor returned by `forward`."""
        if self.blind:
            return self.goal_embed_dims
        else:
            return (
                2
                * self.combine_hid_out_dims[-1]
                * self.resnet_tensor_shape[1]
                * self.resnet_tensor_shape[2]
            )

    def get_object_type_encoding(
        self, observations: Dict[str, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """Get the object type encoding from input batched observations."""
        return cast(
            torch.FloatTensor,
            self.embed_goal(observations[self.goal_uuid].to(torch.int64)),
        )

    def compress_rgb_resnet(self, observations):
        """Apply the RGB compressor to the (already flattened) RGB features."""
        return self.rgb_resnet_compressor(observations[self.rgb_resnet_uuid])

    def compress_depth_resnet(self, observations):
        """Apply the depth compressor to the (already flattened) depth features."""
        return self.depth_resnet_compressor(observations[self.depth_resnet_uuid])

    def distribute_target(self, observations):
        """Embed the goal and broadcast it over the feature grid."""
        target_emb = self.embed_goal(observations[self.goal_uuid])
        return target_emb.view(-1, self.goal_embed_dims, 1, 1).expand(
            -1, -1, self.resnet_tensor_shape[-2], self.resnet_tensor_shape[-1]
        )

    def adapt_input(self, observations):
        """Flatten the leading batch dimensions of both features and the goal.

        A 6-dimensional RGB feature tensor is read as
        (step, sampler, agent, C, H, W); anything else as
        (step, sampler, C, H, W).

        Returns a NEW observation dict holding the flattened tensors (the
        caller's dict is left untouched), followed by
        `use_agent, nstep, nsampler, nagent`, which `adapt_output` needs to
        undo the flattening.
        """
        # Shallow copy: the reshaped entries below must not leak back into the
        # caller's dict, which other consumers may read afterwards.
        observations = {**observations}

        rgb = observations[self.rgb_resnet_uuid]
        depth = observations[self.depth_resnet_uuid]

        use_agent = False
        nagent = 1

        if len(rgb.shape) == 6:
            use_agent = True
            nstep, nsampler, nagent = rgb.shape[:3]
        else:
            nstep, nsampler = rgb.shape[:2]

        observations[self.rgb_resnet_uuid] = rgb.view(-1, *rgb.shape[-3:])
        observations[self.depth_resnet_uuid] = depth.view(-1, *depth.shape[-3:])
        observations[self.goal_uuid] = observations[self.goal_uuid].view(-1, 1)

        return observations, use_agent, nstep, nsampler, nagent

    @staticmethod
    def adapt_output(x, use_agent, nstep, nsampler, nagent):
        """Restore the leading batch dimensions removed by `adapt_input`."""
        if use_agent:
            return x.view(nstep, nsampler, nagent, -1)
        return x.view(nstep, nsampler * nagent, -1)

    def forward(self, observations):
        """Return the fused goal/RGB/depth embedding for a batch of observations.

        When blind, the goal embedding of the unmodified goal observation is
        returned directly.
        """
        # Blind encoders lack at least one feature entry, so this must be
        # handled before `adapt_input`, which indexes both feature keys.
        if self.blind:
            return self.embed_goal(observations[self.goal_uuid])

        observations, use_agent, nstep, nsampler, nagent = self.adapt_input(
            observations
        )

        # The broadcast goal embedding is identical for both modalities, so it
        # is computed once and shared.
        target = self.distribute_target(observations)

        rgb_embs = [self.compress_rgb_resnet(observations), target]
        rgb_x = self.rgb_target_obs_combiner(torch.cat(rgb_embs, dim=1,))
        depth_embs = [self.compress_depth_resnet(observations), target]
        depth_x = self.depth_target_obs_combiner(torch.cat(depth_embs, dim=1,))
        x = torch.cat([rgb_x, depth_x], dim=1)
        x = x.reshape(x.shape[0], -1)  # flatten

        return self.adapt_output(x, use_agent, nstep, nsampler, nagent)


def freeze_model(model):
    """Freeze `model` in place and return it.

    Gradients are disabled on every parameter, the momentum of every module
    whose class name contains "BatchNorm" is set to 0.0, and the model is
    switched to evaluation mode.
    """
    for weight in model.parameters():
        weight.requires_grad = False

    batch_norm_layers = [
        layer for layer in model.modules() if "BatchNorm" in type(layer).__name__
    ]
    for layer in batch_norm_layers:
        layer.momentum = 0.0

    model.eval()
    return model

class PromptClipViTEmbedder(nn.Module):
    """Frozen CLIP vision transformer with learnable prompt tokens.

    The CLIP model is frozen with `freeze_model`. A set of prompt embeddings
    (plus, optionally, one extra task prompt embedding) is projected by a
    linear layer and inserted between the class token and the patch tokens
    before the transformer runs. Only the class-token output is returned.

    # Attributes
    model : The frozen CLIP model.
    task_prompt : Whether the extra task prompt embedding is created and used.
    prompt_dropout : Dropout applied to the projected prompts.
    prompt_proj : Linear projection applied to the prompt embeddings.
    prompt_embeddings : Parameter of shape (1, num_tokens, prompt_dim),
        zero-initialised; 8 tokens for 'ViT-B/32', 5 for 'ViT-B/16'.
    prompt_task_embeddings : Parameter of shape (1, 1, prompt_dim),
        zero-initialised; only present when `task_prompt` is set.
    """

    # Number of prompt tokens used for each supported CLIP variant.
    _NUM_PROMPT_TOKENS = {"ViT-B/32": 8, "ViT-B/16": 5}

    def __init__(self, model: "CLIP", clip_model_type: str = "ViT-B/32", task_prompt: bool = False):
        """Freeze `model` and create the prompt parameters.

        Raises `NotImplementedError` if `clip_model_type` is not one of
        'ViT-B/32' or 'ViT-B/16'.
        """
        super().__init__()
        self.model = freeze_model(model)
        print("load clip model", clip_model_type)
        self.task_prompt = task_prompt

        # prompt config
        if clip_model_type not in self._NUM_PROMPT_TOKENS:
            raise NotImplementedError(
                "`clip_model_type` must be one of 'ViT-B/32' or 'ViT-B/16', "
                f"got {clip_model_type!r}"
            )
        num_tokens = self._NUM_PROMPT_TOKENS[clip_model_type]
        _, prompt_dim = self.model.visual.positional_embedding.shape
        # Output width of the prompt projection. The projected prompts are
        # concatenated with the patch tokens in `forward`, so this has to
        # equal the token width of the visual transformer.
        hidden_size = 768
        self.prompt_dropout = nn.Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        self.prompt_embeddings = nn.Parameter(
            torch.zeros(1, num_tokens, prompt_dim)
        )

        # task prompt
        if self.task_prompt:
            num_task_tokens = 1
            self.prompt_task_embeddings = nn.Parameter(
                torch.zeros(1, num_task_tokens, prompt_dim)
            )
        self.eval()

    def forward(self, x):
        """Embed a batch of images.

        `x` is passed straight to the CLIP patch convolution `visual.conv1`.
        Returns the layer-normalised class token, multiplied by `visual.proj`
        when that projection is not `None`.
        """
        m = self.model.visual
        x = m.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the class embedding, broadcast over the batch.
        x = torch.cat(
            [
                m.class_embedding.to(x.dtype)
                + torch.zeros(
                    x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
                ),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + m.positional_embedding.to(x.dtype)

        if self.task_prompt:
            prompt_embeddings = torch.cat(
                (self.prompt_embeddings, self.prompt_task_embeddings), dim=1
            )
        else:
            prompt_embeddings = self.prompt_embeddings

        # Insert the projected prompts right after the class token. They are
        # added after the positional embedding, so they carry none.
        batch_size = x.size(0)
        prompts = self.prompt_dropout(
            self.prompt_proj(prompt_embeddings).expand(batch_size, -1, -1)
        )
        x = torch.cat((x[:, :1, :], prompts, x[:, 1:, :]), dim=1)

        x = m.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = m.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the class token; `x` is 2-D ([batch, width]) from here on.
        x = m.ln_post(x[:, 0, :])

        if m.proj is not None:
            x = x @ m.proj
        # Without a projection the normalised class token is returned as-is.
        # (The previous fallback read an undefined `class_emb_only` attribute
        # and indexed this 2-D tensor with three indices.)
        return x

class PromptClipViTPreprocessor(Preprocessor):
    """Preprocess RGB images with a prompt-augmented CLIP vision transformer.

    The embedder (`PromptClipViTEmbedder`) is built lazily on first access to
    `vit`: the CLIP weights are loaded with `clip.load`, the learned prompt
    weights are read from a fixed checkpoint path, and every parameter is
    frozen.

    NOTE: 'ViT-L/14' is accepted here, but `PromptClipViTEmbedder` only
    supports 'ViT-B/32' and 'ViT-B/16', so building the embedder for
    'ViT-L/14' raises `NotImplementedError` on first use.
    """

    CLIP_RGB_MEANS = (0.48145466, 0.4578275, 0.40821073)
    CLIP_RGB_STDS = (0.26862954, 0.26130258, 0.27577711)

    def __init__(
        self,
        rgb_input_uuid: str,
        clip_model_type: str,
        task_prompt: bool = False,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        """Configure the preprocessor; no model is loaded here.

        # Parameters
        rgb_input_uuid : Key of the RGB observation to consume.
        clip_model_type : CLIP model name; must be in `clip.available_models()`.
        task_prompt : Forwarded to `PromptClipViTEmbedder`.
        device : Device the embedder is placed on (CPU when `None`).
        device_ids : Defaults to the indices of all visible CUDA devices.
        kwargs : Forwarded to the `Preprocessor` initializer.

        Raises `NotImplementedError` for CLIP models other than 'ViT-B/32',
        'ViT-B/16' and 'ViT-L/14'.
        """
        assert clip_model_type in clip.available_models()

        # Size of the feature vector declared in the output observation space.
        if clip_model_type in ("ViT-B/32", "ViT-B/16"):
            output_shape = (512,)
        elif clip_model_type == "ViT-L/14":
            output_shape = (768,)
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32', "
                f"'ViT-B/16', or 'ViT-L/14', got {clip_model_type!r}"
            )
        self.clip_model_type = clip_model_type

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )
        self.task_prompt = task_prompt
        self._vit: Optional[PromptClipViTEmbedder] = None

        low = -np.inf
        high = np.inf
        shape = output_shape

        input_uuids = [rgb_input_uuid]
        assert (
            len(input_uuids) == 1
        ), "resnet preprocessor can only consume one observation type"

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        # Every local defined above is forwarded to the parent initializer, so
        # local names in this method are part of the call contract.
        super().__init__(**prepare_locals_for_super(locals()))

    @property
    def vit(self) -> PromptClipViTEmbedder:
        """The frozen prompt-augmented embedder, built on first access."""
        if self._vit is None:
            self._vit = PromptClipViTEmbedder(
                model=clip.load(self.clip_model_type, device=self.device)[0].float(),
                clip_model_type=self.clip_model_type,
                task_prompt=self.task_prompt,
            ).to(self.device)
            for module in self._vit.modules():
                if "BatchNorm" in type(module).__name__:
                    module.momentum = 0.0

            # Load the learned prompt weights.
            # NOTE(security): `torch.load` unpickles the file; only use
            # checkpoints from a trusted source.
            pretrained_dict = torch.load(
                "projects/object_navigation/pretrained_model_ckpts/contrastive_500.pth",
                map_location=torch.device("cpu"),
            )
            model_dict_key = ["prompt_proj.weight", "prompt_proj.bias", "prompt_embeddings"]
            model_dict = self._vit.state_dict()
            # Checkpoint keys are matched by substring; if several checkpoint
            # entries match the same key, the last one wins.
            prompt_dict = {}
            for k, v in pretrained_dict.items():
                for key in model_dict_key:
                    if key in k:
                        prompt_dict[key] = v
            missing = [key for key in model_dict_key if key not in prompt_dict]
            if missing:
                # Those weights keep their initial values; make that visible
                # instead of silently reporting success.
                print(f"WARNING: prompt weights not found in checkpoint: {missing}")
            model_dict.update(prompt_dict)
            self._vit.load_state_dict(model_dict)
            print(f"loaded:{model_dict_key}")

            print("Turning off gradients in both the image and the text encoder")
            for param in self._vit.parameters():
                param.requires_grad_(False)
            # Double check that nothing is left trainable.
            enabled = {
                name
                for name, param in self._vit.named_parameters()
                if param.requires_grad
            }
            print(f"Parameters to be updated: {enabled}")
        return self._vit

    def to(self, device: torch.device) -> "PromptClipViTPreprocessor":
        """Move the embedder to `device` (building it first if necessary)."""
        self._vit = self.vit.to(device)
        self.device = device
        return self

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Embed the RGB observation and return the features as float32.

        The observation is moved to `self.device` and permuted from
        channels-last to channels-first. The features are returned without
        any normalisation.
        """
        x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
        x = self.vit(x).float()
        return x