from typing import List, Optional, Any, cast, Dict, Tuple
import os
import pickle
import copy

import clip
import gym
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout
from clip.model import CLIP

from allenact.base_abstractions.preprocessor import Preprocessor
from allenact.utils.misc_utils import prepare_locals_for_super


class NaivePreprocessor(Preprocessor):
    """Flatten a raw image observation and optionally append a goal.

    No CLIP (or any other) network is applied here: `process` only permutes
    the image from BHWC to BCHW, flattens it to `(batch, -1)` and concatenates
    it with the remaining observations. `clip_model_type` is only used to
    validate the configuration and to pick the advertised output shape.
    """

    CLIP_RGB_MEANS = (0.485, 0.456, 0.406)
    CLIP_RGB_STDS = (0.229, 0.224, 0.225)

    def __init__(
        self,
        rgb_input_uuid: str,
        goal_sensor_uuid: str,
        clip_model_type: str,
        class_emb_only: bool,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        """Configure the preprocessor.

        Args:
            rgb_input_uuid: Key of the image observation to flatten.
            goal_sensor_uuid: Key of the goal observation; when falsy only the
                image observation is consumed.
            clip_model_type: One of 'ViT-B/32', 'ViT-B/16' or 'ViT-L/14'.
            class_emb_only: When true the first entry of the output shape is
                dropped.
            device: Device observations are moved to (CPU when `None`).
            device_ids: Device ids; defaults to all visible CUDA devices.
            **kwargs: Forwarded to `Preprocessor.__init__`. If it contains a
                "task" entry, goal embeddings are loaded from a pickle file.

        Raises:
            NotImplementedError: If `clip_model_type` is not supported.
        """
        # NOTE: every local defined in this method is forwarded to the parent
        # constructor through `prepare_locals_for_super(locals())`, so the
        # names `input_uuids` and `observation_space` must be kept.
        if clip_model_type in ("ViT-B/32", "ViT-B/16", "ViT-L/14"):
            # All supported models advertise the same flattened size.
            output_shape = (3 * 224 * 224 + 1,)
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32', 'ViT-B/16', or 'ViT-L/14'"
            )

        # optional
        # NOTE(review): `output_shape` is a 1-tuple, so this slice yields `()`
        # when `class_emb_only` is set -- confirm this is intended.
        if class_emb_only:
            output_shape = output_shape[1:]

        self.clip_model_type = clip_model_type

        self.class_emb_only = class_emb_only

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )

        low = -np.inf
        high = np.inf
        shape = output_shape
        self.output_shape = shape
        if goal_sensor_uuid:
            input_uuids = [rgb_input_uuid, goal_sensor_uuid]
        else:
            input_uuids = [rgb_input_uuid]

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)
        self.task = None
        if "task" in kwargs:
            self.task = kwargs["task"]
            # SECURITY: `pickle.load` can execute arbitrary code; this file
            # must come from a trusted source.
            with open('./allenact_plugins/clip_plugin/image_feature.pkl', 'rb') as f:
                self.goal_embs = pickle.load(f)

        super().__init__(**prepare_locals_for_super(locals()))

    @property
    def vit(self):
        """Always `None`: this preprocessor does not own a model."""
        return None

    def to(self, device: torch.device):
        """Record the target device and return `self`."""
        self.device = device
        return self

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Flatten the image observation and concatenate it with the rest.

        `obs` is modified in place: the image entry is replaced by its
        flattened `(batch, -1)` version and the goal entry may be replaced by
        an unsqueezed tensor or by looked-up goal embeddings.

        Args:
            obs: Mapping from observation uuid to tensor. The image is read
                as BHWC (it is permuted with `(0, 3, 1, 2)`).

        Returns:
            The flattened image alone when `obs` contains "expert_action",
            otherwise all values of `obs` concatenated along dimension 1.
        """
        rgb_uuid = self.input_uuids[0]
        batch_size = obs[rgb_uuid].size(0)
        # bhwc -> bchw, then one flat row per batch element
        obs[rgb_uuid] = (
            obs[rgb_uuid].to(self.device).permute(0, 3, 1, 2).reshape(batch_size, -1)
        )

        if "expert_action" in obs:
            return obs[rgb_uuid]

        # Only touch the goal observation when a goal sensor was configured.
        if len(self.input_uuids) > 1:
            goal_uuid = self.input_uuids[1]
            if len(obs[rgb_uuid].shape) != len(obs[goal_uuid].shape):
                obs[goal_uuid] = obs[goal_uuid].unsqueeze(dim=1)
                if self.task is not None:
                    # Replace each goal index by its stored embedding, placed
                    # on the same device as the image so `torch.cat` works.
                    obs[goal_uuid] = torch.cat(
                        [
                            self.goal_embs[obs[goal_uuid][i].item()].to(self.device)
                            for i in range(batch_size)
                        ]
                    )
        return torch.cat(list(obs.values()), dim=1)


class ClipResNetEmbedder(nn.Module):
    """Wrap the visual tower of a CLIP ResNet and run it without gradients.

    The final `attnpool` layer of `resnet.visual` is replaced depending on
    the arguments:

    * `pool=False`: replaced by an identity.
    * `pool=True`, `pooling_type="attn"`: left untouched.
    * `pool=True`, `pooling_type="avg"`: replaced by global average pooling
      followed by a flatten.
    """

    def __init__(self, resnet: CLIP, pool=True, pooling_type="avg"):
        super().__init__()
        self.model = resnet
        self.pool = pool
        self.pooling_type = pooling_type

        if pool:
            if self.pooling_type == "avg":
                average = nn.AdaptiveAvgPool2d((1, 1))
                flatten = nn.Flatten(start_dim=-3, end_dim=-1)
                self.model.visual.attnpool = nn.Sequential(average, flatten)
            elif self.pooling_type != "attn":
                raise NotImplementedError("`pooling_type` must be 'avg' or 'attn'.")
            # "attn": keep the attention pooling the model already has.
        else:
            self.model.visual.attnpool = nn.Identity()

        # The module is put in inference mode at construction.
        self.eval()

    def forward(self, x):
        """Return `self.model.visual(x)` computed under `torch.no_grad()`."""
        visual = self.model.visual
        with torch.no_grad():
            features = visual(x)
        return features


class ClipResNetPreprocessor(Preprocessor):
    """Turn an RGB or single-channel image observation into features of a
    CLIP ResNet ('RN50' or 'RN50x16')."""

    CLIP_RGB_MEANS = (0.485, 0.456, 0.406)
    CLIP_RGB_STDS = (0.229, 0.224, 0.225)

    def __init__(
        self,
        rgb_input_uuid: str,
        clip_model_type: str,
        pool: bool,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        input_img_height_width: Tuple[int, int] = (224, 224),
        **kwargs: Any,
    ):
        """Validate the configuration and declare the output space.

        The CLIP model itself is loaded lazily by the `resnet` property.
        All locals of this method are forwarded to the parent constructor via
        `prepare_locals_for_super(locals())`.
        """
        assert clip_model_type in clip.available_models()
        # Pooling is only allowed for the default 224x224 input.
        assert input_img_height_width == (224, 224) or pool == False
        assert all(iis % 32 == 0 for iis in input_img_height_width)

        # The spatial size of the feature map is the input size divided by 32.
        output_height_width = tuple(iis // 32 for iis in input_img_height_width)
        if clip_model_type not in ("RN50", "RN50x16"):
            raise NotImplementedError(
                f"Currently `clip_model_type` must be one of 'RN50' or 'RN50x16'"
            )
        output_shape = (
            (2048,) if clip_model_type == "RN50" else (3072,)
        ) + output_height_width

        # With pooling only the channel dimension remains.
        if pool:
            output_shape = output_shape[:1]

        self.clip_model_type = clip_model_type
        self.pool = pool

        if device is None:
            self.device = torch.device("cpu")
        else:
            self.device = device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )
        self._resnet: Optional[ClipResNetEmbedder] = None

        low, high = -np.inf, np.inf
        shape = output_shape

        input_uuids = [rgb_input_uuid]
        assert (
            len(input_uuids) == 1
        ), "resnet preprocessor can only consume one observation type"

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        super().__init__(**prepare_locals_for_super(locals()))

    @property
    def resnet(self) -> ClipResNetEmbedder:
        """The embedder, created on first access and cached afterwards."""
        if self._resnet is not None:
            return self._resnet

        clip_model = clip.load(self.clip_model_type, device=self.device)[0]
        self._resnet = ClipResNetEmbedder(clip_model, pool=self.pool).to(self.device)
        # Zero the momentum of every BatchNorm-like module.
        batch_norms = (
            module
            for module in self._resnet.modules()
            if "BatchNorm" in type(module).__name__
        )
        for batch_norm in batch_norms:
            batch_norm.momentum = 0.0
        self._resnet.eval()
        return self._resnet

    def to(self, device: torch.device) -> "ClipResNetPreprocessor":
        """Move the embedder (loading it if needed) to `device`."""
        moved = self.resnet.to(device)
        self._resnet = moved
        self.device = device
        return self

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Embed the image stored under the single input uuid.

        The image is permuted from BHWC to BCHW; a single-channel image is
        repeated over three channels before being embedded.
        """
        image = obs[self.input_uuids[0]].to(self.device)
        image = image.permute(0, 3, 1, 2)  # bhwc -> bchw
        single_channel = image.shape[1] == 1
        if single_channel:
            image = image.repeat(1, 3, 1, 1)
        return self.resnet(image).float()


class ClipViTEmbedder(nn.Module):
    """Run the visual transformer of a CLIP model without gradients and
    return the embedding of its class token.

    All transformer blocks of `model.visual` are used (an earlier variant
    that dropped the last block is no longer active).
    """

    def __init__(self, model: CLIP, class_emb_only: bool = False):
        """Store the CLIP model and switch the module to eval mode.

        Args:
            model: CLIP model whose `visual` tower is used.
            class_emb_only: Kept for interface compatibility. `forward`
                normalises only the class token, so its output is the class
                embedding regardless of this flag.
        """
        super().__init__()
        self.model = model
        self.class_emb_only = class_emb_only

        self.eval()

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch fed directly to `model.visual.conv1`.

        Returns:
            `ln_post` applied to the class token, multiplied by
            `model.visual.proj` when that projection exists.
        """
        m = self.model.visual
        with torch.no_grad():
            x = m.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            # Prepend the class embedding, broadcast over the batch.
            class_token = m.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([class_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + m.positional_embedding.to(x.dtype)
            x = m.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND
            x = m.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

            # Only the class token (index 0) is kept from here on, so `x` is
            # two-dimensional and must not be indexed with `[:, 0, :]` again.
            x = m.ln_post(x[:, 0, :])
            if m.proj is not None:
                x = x @ m.proj
            return x


class ClipViTPreprocessor(Preprocessor):
    """Embed an image observation with a CLIP ViT and concatenate the result
    with the remaining observations (e.g. a goal sensor)."""

    CLIP_RGB_MEANS = (0.485, 0.456, 0.406)
    CLIP_RGB_STDS = (0.229, 0.224, 0.225)

    def __init__(
        self,
        rgb_input_uuid: str,
        goal_sensor_uuid: str,
        clip_model_type: str,
        class_emb_only: bool,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        """Validate the configuration and declare the output space.

        Args:
            rgb_input_uuid: Key of the image observation.
            goal_sensor_uuid: Key of the goal observation; when falsy only the
                image observation is consumed.
            clip_model_type: One of 'ViT-B/32', 'ViT-B/16' or 'ViT-L/14'.
            class_emb_only: Forwarded to `ClipViTEmbedder`; when true the first
                entry of the output shape is dropped.
            device: Device used for processing (CPU when `None`).
            device_ids: Device ids; defaults to all visible CUDA devices.
            **kwargs: Forwarded to `Preprocessor.__init__`. The optional
                entries "noise_std", "ckpt" and "task" are also read here.

        Raises:
            NotImplementedError: If `clip_model_type` is not supported.
        """
        # NOTE: every local defined in this method is forwarded to the parent
        # constructor through `prepare_locals_for_super(locals())`, so the
        # names `input_uuids` and `observation_space` must be kept and no
        # local may share a name with a key of `kwargs`.
        assert clip_model_type in clip.available_models()

        if clip_model_type in ("ViT-B/32", "ViT-B/16"):
            output_shape = (512 + 1,)
        elif clip_model_type == "ViT-L/14":
            output_shape = (768,)
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32', 'ViT-B/16', or 'ViT-L/14'"
            )

        # optional
        self.noise_std = kwargs.get("noise_std")
        self.ckpt = kwargs.get("ckpt")
        # `process` reads both attributes below; the goal embeddings are
        # loaded lazily the first time they are needed.
        self.task = kwargs.get("task")
        self.goal_embs = None

        # NOTE(review): `output_shape` is a 1-tuple, so this slice yields `()`
        # when `class_emb_only` is set -- confirm this is intended.
        if class_emb_only:
            output_shape = output_shape[1:]

        self.clip_model_type = clip_model_type

        self.class_emb_only = class_emb_only

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )
        self._vit: Optional[ClipViTEmbedder] = None

        low = -np.inf
        high = np.inf
        shape = output_shape
        self.output_shape = shape

        if goal_sensor_uuid:
            input_uuids = [rgb_input_uuid, goal_sensor_uuid]
        else:
            input_uuids = [rgb_input_uuid]

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        super().__init__(**prepare_locals_for_super(locals()))

    @property
    def vit(self) -> ClipViTEmbedder:
        """The embedder, created on first access and cached afterwards."""
        if self._vit is None:
            self._vit = ClipViTEmbedder(
                model=clip.load(self.clip_model_type, device=self.device)[0],
                class_emb_only=self.class_emb_only,
            ).to(self.device)
            for module in self._vit.modules():
                if "BatchNorm" in type(module).__name__:
                    module.momentum = 0.0
            self._vit.eval()
        return self._vit

    @staticmethod
    def _select_weights(source_dict, marker, prefix_len, model_dict):
        """Pick the checkpoint entries whose key contains `marker`, strip the
        first `prefix_len` characters of their key and keep only the entries
        whose stripped key exists in `model_dict`."""
        stripped = {
            key[prefix_len:]: value
            for key, value in source_dict.items()
            if marker in key
        }
        return {key: value for key, value in stripped.items() if key in model_dict}

    def _goal_embeddings(self):
        """Return the goal embeddings, loading the pickle file on first use."""
        if getattr(self, "goal_embs", None) is None:
            # SECURITY: `pickle.load` can execute arbitrary code; this file
            # must come from a trusted source.
            with open('./allenact_plugins/clip_plugin/image_feature.pkl', 'rb') as f:
                self.goal_embs = pickle.load(f)
        return self.goal_embs

    def to(self, device: torch.device) -> "ClipViTPreprocessor":
        """Move the embedder to `device` and load checkpoint weights.

        * `ckpt` is a tuple: the first checkpoint (keys containing
          "encoder_q" under its "state_dict" entry) is loaded into `vit`; a
          deep copy `vit2` then receives the second checkpoint (keys
          containing "net").
        * `ckpt` is any other truthy value: a single checkpoint is loaded
          into `vit`, using the "encoder_q" layout when it has a "state_dict"
          entry and the "net" layout otherwise.
        """
        self._vit = self.vit.to(device)
        self.device = device
        if isinstance(self.ckpt, tuple):
            first = torch.load(self.ckpt[0])["state_dict"]
            self.vit.load_state_dict(
                self._select_weights(first, "encoder_q", 17, self.vit.state_dict())
            )

            self.vit2 = copy.deepcopy(self.vit)
            second = torch.load(self.ckpt[1])
            self.vit2.load_state_dict(
                self._select_weights(second, "net", 4, self.vit2.state_dict())
            )

        elif self.ckpt:
            source_dict = torch.load(self.ckpt)
            model_dict = self.vit.state_dict()
            if "state_dict" in source_dict:
                pretrained_dict = self._select_weights(
                    source_dict["state_dict"], "encoder_q", 17, model_dict
                )
            else:
                pretrained_dict = self._select_weights(
                    source_dict, "net", 4, model_dict
                )
            self.vit.load_state_dict(pretrained_dict)
        return self

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Embed the image and concatenate it with the other observations.

        `obs` is modified in place: the image entry is replaced by its
        embedding and the goal entry may be replaced by an unsqueezed tensor
        or by looked-up goal embeddings.

        Returns:
            All values of `obs` concatenated along dimension 1.
        """
        with torch.no_grad():
            std = self.noise_std
            rgb_uuid = self.input_uuids[0]
            in_x = obs[rgb_uuid].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
            # If the input is depth, repeat it across all 3 channels
            if in_x.shape[1] == 1:
                in_x = in_x.repeat(1, 3, 1, 1)
            x = self.vit(in_x).float()
            if isinstance(self.ckpt, tuple):
                # Sum the features of the two separately loaded encoders.
                x += self.vit2(in_x).float()
            if std:
                # Gaussian noise clipped to +-1.5 standard deviations.
                noise = torch.clip(
                    torch.normal(0, std, size=x.size()), -1.5 * std, 1.5 * std
                ).to(self.device)
                x += noise
            obs[rgb_uuid] = x

            # Only touch the goal observation when a goal sensor was configured.
            if len(self.input_uuids) > 1:
                goal_uuid = self.input_uuids[1]
                if len(x.shape) != len(obs[goal_uuid].shape):
                    obs[goal_uuid] = obs[goal_uuid].unsqueeze(dim=1)
                    if self.task is not None:
                        goal_embs = self._goal_embeddings()
                        batch_size = x.size(0)
                        # Keep the embeddings on the same device as `x` so
                        # that the concatenation below works.
                        obs[goal_uuid] = torch.cat(
                            [
                                goal_embs[obs[goal_uuid][i].item()].to(self.device)
                                for i in range(batch_size)
                            ]
                        )
            return torch.cat(list(obs.values()), dim=1)


class PromptClipViTEmbedder(nn.Module):
    """CLIP visual transformer with extra prompt tokens, run without
    gradients.

    A projected set of prompt embeddings is inserted between the class token
    and the patch tokens before the transformer is applied. 'ViT-B/32' has
    three prompt sets (default, "action", "dynamics"); 'ViT-B/16' only has the
    default one.
    """

    def __init__(self, model: CLIP, class_emb_only: bool = False, clip_model_type: str = "ViT-B/32"):
        """Create the prompt parameters and switch the module to eval mode.

        Args:
            model: CLIP model whose `visual` tower is used.
            class_emb_only: Kept for interface compatibility. `forward`
                normalises only the class token, so its output is the class
                embedding regardless of this flag.
            clip_model_type: 'ViT-B/32' (16 prompt tokens) or 'ViT-B/16'
                (8 prompt tokens).

        Raises:
            NotImplementedError: For any other `clip_model_type`.
        """
        super().__init__()
        self.model = model
        self.class_emb_only = class_emb_only

        # prompt config
        if clip_model_type == "ViT-B/32":
            num_tokens = 16
        elif clip_model_type == "ViT-B/16":
            num_tokens = 8
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32' or 'ViT-B/16' "
                "('ViT-L/14' is not supported)"
            )
        _, prompt_dim = self.model.visual.positional_embedding.shape
        hidden_size = 768

        # NOTE: the registration order of the modules and parameters below
        # defines the `state_dict` key order; keep it unchanged.
        self.prompt_dropout = Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        self.prompt_embeddings = nn.Parameter(torch.zeros(
            1, num_tokens, prompt_dim))

        if clip_model_type == "ViT-B/32":
            # action
            self.action_prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.action_prompt_embeddings = nn.Parameter(torch.zeros(
                1, num_tokens, prompt_dim))
            # dynamics
            self.dynamics_prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.dynamics_prompt_embeddings = nn.Parameter(torch.zeros(
                1, num_tokens, prompt_dim))

        self.eval()

    def forward(self, x, mode=None):
        """Embed a batch of images using the prompt set chosen by `mode`.

        Args:
            x: Image batch fed directly to `model.visual.conv1`.
            mode: "action" or "dynamics" select the corresponding prompt set
                (only created for 'ViT-B/32'); any other value selects the
                default prompt set.

        Returns:
            `ln_post` applied to the class token, multiplied by
            `model.visual.proj` when that projection exists.
        """
        m = self.model.visual
        if mode == "action":
            prompt_proj = self.action_prompt_proj
            prompt_embeddings = self.action_prompt_embeddings
        elif mode == "dynamics":
            prompt_proj = self.dynamics_prompt_proj
            prompt_embeddings = self.dynamics_prompt_embeddings
        else:
            prompt_proj = self.prompt_proj
            prompt_embeddings = self.prompt_embeddings

        with torch.no_grad():
            x = m.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            # Prepend the class embedding, broadcast over the batch.
            class_token = m.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([class_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + m.positional_embedding.to(x.dtype)

            # Insert the prompt tokens right after the class token; they do
            # not receive a positional embedding.
            batch_size = x.size(0)
            prompts = self.prompt_dropout(
                prompt_proj(prompt_embeddings).expand(batch_size, -1, -1)
            )
            x = torch.cat((x[:, :1, :], prompts, x[:, 1:, :]), dim=1)

            x = m.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND
            x = m.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

            # Only the class token (index 0) is kept from here on, so `x` is
            # two-dimensional and must not be indexed with `[:, 0, :]` again.
            x = m.ln_post(x[:, 0, :])
            if m.proj is not None:
                x = x @ m.proj
            return x


class PromptClipViTPreprocessor(Preprocessor):
    """Embed an image observation with a prompt-augmented CLIP ViT.

    `prompt` (passed through `kwargs`) holds up to two entries: the first
    controls the default ("state") embedding and the second the "action"
    embedding. A string entry is a path to learned prompt weights and makes
    the prompted encoder be used; any other truthy entry uses the plain
    `encode_image` of the CLIP model; a falsy entry disables that embedding.
    """

    CLIP_RGB_MEANS = (0.485, 0.456, 0.406)
    CLIP_RGB_STDS = (0.229, 0.224, 0.225)

    def __init__(
        self,
        rgb_input_uuid: str,
        clip_model_type: str,
        class_emb_only: bool,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        """Validate the configuration and declare the output space.

        Args:
            rgb_input_uuid: Key of the image observation.
            clip_model_type: One of 'ViT-B/32', 'ViT-B/16' or 'ViT-L/14'.
            class_emb_only: Forwarded to the embedder; when true the first
                entry of the output shape is dropped.
            device: Device used for processing (CPU when `None`).
            device_ids: Device ids; defaults to all visible CUDA devices.
            **kwargs: Forwarded to `Preprocessor.__init__`. Must contain
                "prompt"; "noise_std" is optional.

        Raises:
            NotImplementedError: If `clip_model_type` is not supported.
        """
        # NOTE: every local defined in this method is forwarded to the parent
        # constructor through `prepare_locals_for_super(locals())`, so the
        # names `input_uuids` and `observation_space` must be kept and no
        # local may share a name with a key of `kwargs`.
        assert clip_model_type in clip.available_models()

        if clip_model_type == "ViT-B/32":
            output_shape = (512 + 1,)
        elif clip_model_type == "ViT-B/16":
            output_shape = (512,)
        elif clip_model_type == "ViT-L/14":
            output_shape = (768,)
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32', 'ViT-B/16', or 'ViT-L/14'"
            )

        # optional
        self.prompt = kwargs["prompt"]
        self.noise_std = kwargs.get("noise_std")

        # Two active prompts produce two concatenated embeddings.
        if len(self.prompt) > 1 and self.prompt[0] and self.prompt[1]:
            output_shape = (512 * 2 + 1,)

        # NOTE(review): `output_shape` is a 1-tuple, so this slice yields `()`
        # when `class_emb_only` is set -- confirm this is intended.
        if class_emb_only:
            output_shape = output_shape[1:]

        self.clip_model_type = clip_model_type

        self.class_emb_only = class_emb_only

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )
        self._vit: Optional[PromptClipViTEmbedder] = None

        low = -np.inf
        high = np.inf
        shape = output_shape
        self.output_shape = shape

        input_uuids = [rgb_input_uuid]

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        super().__init__(**prepare_locals_for_super(locals()))

    @staticmethod
    def _select_prompt_entries(state, prompt_key):
        """Return the entries of `state` whose name contains one of the
        strings in `prompt_key`.

        When the matching string does not itself contain "action", names
        containing "action" or "dynamics" are skipped so that the default
        prompt keys do not also match the task-specific ones.
        """
        selected = {}
        for name, value in state.items():
            for key in prompt_key:
                if key not in name:
                    continue
                if "action" not in key and ("action" in name or "dynamics" in name):
                    continue
                selected[name] = value
        return selected

    def load_weight(self, path, prompt_key):
        """Copy learned prompt weights from a checkpoint into the embedder.

        Source and target entries are paired purely by their iteration order
        (via `zip`), not by name, so the checkpoint must list the selected
        entries in the same order as the embedder's `state_dict`.

        Args:
            path: Checkpoint path; anything that is not a non-empty string
                makes this method a no-op.
            prompt_key: Key fragments identifying the entries to transfer.
        """
        if not (path and isinstance(path, str)):
            return

        pretrained_dict = torch.load(path)
        model_dict = self._vit.state_dict()

        source_prompt_dict = self._select_prompt_entries(pretrained_dict, prompt_key)
        target_prompt_dict = self._select_prompt_entries(model_dict, prompt_key)

        assert len(source_prompt_dict.keys())
        assert target_prompt_dict, "the embedder has no parameters matching `prompt_key`"
        print(source_prompt_dict.keys())
        print(target_prompt_dict.keys())

        # key value switching
        prompt_dict = {
            t_k: source_prompt_dict[s_k]
            for s_k, t_k in zip(source_prompt_dict.keys(), target_prompt_dict.keys())
        }
        print(prompt_dict.keys(), "loaded")

        model_dict.update(prompt_dict)
        first_target = next(iter(target_prompt_dict))
        assert (model_dict[first_target] == prompt_dict[first_target]).all()
        self._vit.load_state_dict(model_dict)

    @property
    def vit(self) -> PromptClipViTEmbedder:
        """The embedder, created (and its prompts loaded) on first access."""
        if self._vit is None:
            clip_model = clip.load(self.clip_model_type, device=self.device)[0]
            self._vit = PromptClipViTEmbedder(
                model=clip_model,
                class_emb_only=self.class_emb_only, clip_model_type=self.clip_model_type
            ).to(self.device)
            for module in self._vit.modules():
                if "BatchNorm" in type(module).__name__:
                    module.momentum = 0.0

            # load learned prompt
            targets_keys = [
                ["prompt_proj.weight", "prompt_proj.bias", "prompt_embeddings"],
                ["action_prompt_proj.weight", "action_prompt_proj.bias", "action_prompt_embeddings"],
            ]
            # `zip` stops at the shorter sequence, so a single-entry `prompt`
            # only loads the default prompt.
            for prompt_path, keys in zip(self.prompt, targets_keys):
                self.load_weight(prompt_path, keys)
            self._vit.eval()

        return self._vit

    def to(self, device: torch.device) -> "PromptClipViTPreprocessor":
        """Move the embedder (loading it if needed) to `device`."""
        self._vit = self.vit.to(device)
        self.device = device
        return self

    def _embed(self, obs: Dict[str, Any], prompt_entry: Any, mode: Optional[str]):
        """Compute one L2-normalised (optionally noised) image embedding.

        A string `prompt_entry` uses the prompted encoder with `mode`; any
        other value uses the CLIP model's own `encode_image`.
        """
        image = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
        if isinstance(prompt_entry, str):
            features = self.vit(image, mode=mode)
        else:
            features = self.vit.model.encode_image(image)
        features = features.float()
        features = features / features.norm(dim=-1, keepdim=True)

        std = self.noise_std
        if std:
            # Gaussian noise clipped to +-1.5 standard deviations, followed
            # by a re-normalisation.
            noise = torch.clip(
                torch.normal(0, std, size=features.size()), -1.5 * std, 1.5 * std
            ).to(self.device)
            features = features + noise
            features = features / features.norm(dim=-1, keepdim=True)
        return features

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Embed the image once per active prompt and concatenate the results.

        When `obs` contains "goal_object_type_ind", that tensor is unsqueezed
        along dimension 1 and appended to the embeddings.

        Raises:
            NotImplementedError: If no prompt configuration was given.
        """
        if not self.prompt:
            raise NotImplementedError(
                "PromptClipViTPreprocessor requires a non-empty `prompt`."
            )

        with torch.no_grad():
            full_state = []
            # First entry -> default prompt, second entry -> "action" prompt.
            for prompt_entry, mode in zip(self.prompt, (None, "action")):
                if prompt_entry:
                    full_state.append(self._embed(obs, prompt_entry, mode))

            if "goal_object_type_ind" in obs:
                goal = obs["goal_object_type_ind"]
                full_state.append(goal.unsqueeze(1).to(self.device))

            return torch.cat(full_state, dim=1)


class PromptClipATTMViTEmbedder(nn.Module):
    """CLIP visual transformer with extra prompt tokens, run without
    gradients.

    A projected set of 8 prompt embeddings is inserted between the class
    token and the patch tokens before the transformer is applied. 'ViT-B/32'
    has three prompt sets (default, "action", "dynamics"); 'ViT-B/16' only
    has the default one.
    """

    def __init__(self, model: CLIP, class_emb_only: bool = False, clip_model_type: str = "ViT-B/32"):
        """Create the prompt parameters and switch the module to eval mode.

        Args:
            model: CLIP model whose `visual` tower is used.
            class_emb_only: Kept for interface compatibility. `forward`
                normalises only the class token, so its output is the class
                embedding regardless of this flag.
            clip_model_type: 'ViT-B/32' or 'ViT-B/16'.

        Raises:
            NotImplementedError: For any other `clip_model_type`.
        """
        super().__init__()
        self.model = model
        self.class_emb_only = class_emb_only

        # prompt config
        if clip_model_type not in ("ViT-B/32", "ViT-B/16"):
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32' or 'ViT-B/16' "
                "('ViT-L/14' is not supported)"
            )
        _, prompt_dim = self.model.visual.positional_embedding.shape
        num_tokens = 8
        hidden_size = 768

        # NOTE: the registration order of the modules and parameters below
        # defines the `state_dict` key order; keep it unchanged.
        self.prompt_dropout = Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        self.prompt_embeddings = nn.Parameter(torch.zeros(
            1, num_tokens, prompt_dim))

        if clip_model_type == "ViT-B/32":
            # action
            self.action_prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.action_prompt_embeddings = nn.Parameter(torch.zeros(
                1, num_tokens, prompt_dim))
            # dynamics
            self.dynamics_prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.dynamics_prompt_embeddings = nn.Parameter(torch.zeros(
                1, num_tokens, prompt_dim))

        self.eval()

    def forward(self, x, mode=None):
        """Embed a batch of images using the prompt set chosen by `mode`.

        Args:
            x: Image batch fed directly to `model.visual.conv1`.
            mode: "action" or "dynamics" select the corresponding prompt set
                (only created for 'ViT-B/32'); any other value selects the
                default prompt set.

        Returns:
            `ln_post` applied to the class token, multiplied by
            `model.visual.proj` when that projection exists.
        """
        m = self.model.visual
        if mode == "action":
            prompt_proj = self.action_prompt_proj
            prompt_embeddings = self.action_prompt_embeddings
        elif mode == "dynamics":
            prompt_proj = self.dynamics_prompt_proj
            prompt_embeddings = self.dynamics_prompt_embeddings
        else:
            prompt_proj = self.prompt_proj
            prompt_embeddings = self.prompt_embeddings

        with torch.no_grad():
            x = m.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            # Prepend the class embedding, broadcast over the batch.
            class_token = m.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([class_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + m.positional_embedding.to(x.dtype)

            # Insert the prompt tokens right after the class token; they do
            # not receive a positional embedding.
            batch_size = x.size(0)
            prompts = self.prompt_dropout(
                prompt_proj(prompt_embeddings).expand(batch_size, -1, -1)
            )
            x = torch.cat((x[:, :1, :], prompts, x[:, 1:, :]), dim=1)

            x = m.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND
            x = m.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

            # Only the class token (index 0) is kept from here on, so `x` is
            # two-dimensional and must not be indexed with `[:, 0, :]` again.
            x = m.ln_post(x[:, 0, :])
            if m.proj is not None:
                x = x @ m.proj
            return x


class PromptClipATTMViTPreprocessor(Preprocessor):
    """Preprocess RGB images with a CLIP ViT encoder and optional learned prompts.

    ``process`` emits the L2-normalised plain CLIP image embedding, followed
    by one extra embedding per enabled ``prompt`` slot (slot 0 uses the
    default prompt tokens, slot 1 the ``"action"`` prompt tokens) and, when
    the observations contain it, the ``"goal_object_type_ind"`` value.

    Required keyword arguments: ``prompt`` (a sequence of per-slot entries; a
    string entry is loaded as a checkpoint path, any other truthy entry makes
    the slot fall back to the plain CLIP embedding), ``noise_std``,
    ``weighted_sum`` and ``weighted_cat``.
    """

    CLIP_RGB_MEANS = (0.485, 0.456, 0.406)
    CLIP_RGB_STDS = (0.229, 0.224, 0.225)

    def __init__(
        self,
        rgb_input_uuid: str,
        clip_model_type: str,
        class_emb_only: bool,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        assert clip_model_type in clip.available_models()

        if clip_model_type == "ViT-B/32":
            output_shape = (512 + 1,)
        elif clip_model_type == "ViT-B/16":
            output_shape = (512,)
        elif clip_model_type == "ViT-L/14":
            output_shape = (768,)
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32', 'ViT-B/16', or 'ViT-L/14'"
            )

        # Required configuration; a missing key raises KeyError.
        self.prompt = kwargs["prompt"]
        self.noise_std = kwargs["noise_std"]
        self.weighted_sum = kwargs["weighted_sum"]
        self.weighted_cat = kwargs["weighted_cat"]

        # Two enabled prompt slots widen the declared observation.
        # NOTE(review): process() concatenates the plain CLIP feature plus one
        # feature per enabled slot, which can be wider than this declared
        # shape — confirm what downstream consumers expect.
        if len(self.prompt) != 1 and self.prompt[1] and self.prompt[0]:
            output_shape = (512 * 2 + 1,)

        if self.weighted_sum:
            output_shape = (512 * 1 + 1,)

        if class_emb_only:
            output_shape = output_shape[1:]

        self.clip_model_type = clip_model_type

        self.class_emb_only = class_emb_only

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )
        # Built lazily by the ``vit`` property.
        self._vit: Optional[PromptClipATTMViTEmbedder] = None

        # The local names below are forwarded to the base constructor through
        # ``prepare_locals_for_super(locals())`` (presumably ``input_uuids``
        # and ``observation_space`` are consumed there — keep them stable).
        low = -np.inf
        high = np.inf
        shape = output_shape
        self.output_shape = shape

        input_uuids = [rgb_input_uuid]

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        super().__init__(**prepare_locals_for_super(locals()))

    @staticmethod
    def _matching_entries(state, prompt_key):
        """Return the entries of ``state`` whose name contains one of ``prompt_key``.

        A non-action key fragment never selects an "action"/"dynamics" entry,
        so ``"prompt_proj.weight"`` does not pick up
        ``"action_prompt_proj.weight"``.
        """
        selected = {}
        for name, value in state.items():
            for key in prompt_key:
                if key not in name:
                    continue
                if "action" not in key and ("action" in name or "dynamics" in name):
                    continue
                selected[name] = value
        return selected

    def load_weight(self, path, prompt_key):
        """Copy learned prompt parameters from the checkpoint at ``path``.

        Does nothing unless ``path`` is a non-empty string.  Source and target
        entries are paired positionally (in iteration order of the two state
        dicts), not by name.

        Args:
            path: checkpoint file readable by ``torch.load``.
            prompt_key: key fragments selecting the prompt entries to copy.
        """
        if not (path and isinstance(path, str)):
            return

        # Load onto the CPU so checkpoints saved on another device still
        # open; ``load_state_dict`` copies into the existing parameters.
        pretrained_dict = torch.load(path, map_location="cpu")
        model_dict = self._vit.state_dict()

        source_prompt_dict = self._matching_entries(pretrained_dict, prompt_key)
        target_prompt_dict = self._matching_entries(model_dict, prompt_key)

        assert len(source_prompt_dict.keys())
        print(source_prompt_dict.keys())
        print(target_prompt_dict.keys())

        # key value switching
        prompt_dict = {
            t_k: source_prompt_dict[s_k]
            for s_k, t_k in zip(source_prompt_dict.keys(), target_prompt_dict.keys())
        }
        print(prompt_dict.keys(), "loaded")

        model_dict.update(prompt_dict)
        first_key = list(target_prompt_dict.keys())[0]
        assert (model_dict[first_key] == prompt_dict[first_key]).all()
        self._vit.load_state_dict(model_dict)

    @property
    def vit(self) -> PromptClipATTMViTEmbedder:
        """Build the prompt-augmented CLIP embedder on first access and cache it."""
        if self._vit is None:
            clip_model = clip.load(self.clip_model_type, device=self.device)[0]
            self._vit = PromptClipATTMViTEmbedder(
                model=clip_model,
                class_emb_only=self.class_emb_only,
                clip_model_type=self.clip_model_type,
            ).to(self.device)
            for module in self._vit.modules():
                if "BatchNorm" in type(module).__name__:
                    module.momentum = 0.0

            # load learned prompt: slot 0 -> default prompt, slot 1 -> action prompt
            targets_keys = [
                ["prompt_proj.weight", "prompt_proj.bias", "prompt_embeddings"],
                ["action_prompt_proj.weight", "action_prompt_proj.bias", "action_prompt_embeddings"],
            ]
            for i, prompt_path in enumerate(self.prompt):
                self.load_weight(prompt_path, targets_keys[i])
            self._vit.eval()

        return self._vit

    def to(self, device: torch.device) -> "PromptClipATTMViTPreprocessor":
        """Move the embedder to ``device`` and remember it for later inputs."""
        self._vit = self.vit.to(device)
        self.device = device
        return self

    def _noisy_unit(self, feat, std):
        """L2-normalise ``feat``; if ``std`` is truthy add Gaussian noise and renormalise."""
        feat = feat / feat.norm(dim=-1, keepdim=True)
        if std:
            noise = torch.normal(0, std, size=feat.size()).to(self.device)
            feat = feat + noise
            feat = feat / feat.norm(dim=-1, keepdim=True)
        return feat

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Encode the RGB observation into the concatenated feature vector.

        Raises:
            NotImplementedError: if no prompt configuration was supplied.
        """
        if not self.prompt:
            # The previous code evaluated the exception class without raising
            # it and then failed on an unbound local.
            raise NotImplementedError(
                "PromptClipATTMViTPreprocessor requires a non-empty `prompt`"
            )

        with torch.no_grad():
            std = self.noise_std
            image = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw

            # Plain CLIP embedding is always the first feature.
            full_state = [
                self._noisy_unit(self.vit.model.encode_image(image).float(), std)
            ]

            # Slot 0 -> default prompt tokens, slot 1 -> action prompt tokens.
            for slot, mode in enumerate((None, "action")):
                # A single-entry ``prompt`` is accepted by __init__, so a
                # missing slot is simply treated as disabled.
                spec = self.prompt[slot] if slot < len(self.prompt) else None
                if not spec:
                    continue
                if isinstance(spec, str):
                    feat = self.vit(image, mode=mode)
                else:
                    feat = self.vit.model.encode_image(image)
                full_state.append(self._noisy_unit(feat.float(), std))

            if "goal_object_type_ind" in obs.keys():
                goal = obs["goal_object_type_ind"]
                full_state.append(goal.unsqueeze(1).to(self.device))

            return torch.cat(full_state, dim=1)


class SNPromptClipViTEmbedder(nn.Module):
    """CLIP ViT image encoder with learned prompt tokens inserted after the class token.

    A default prompt (``prompt_proj`` / ``prompt_embeddings``) is created for
    both supported model types; the ``"action"`` prompt parameters are only
    created for ``"ViT-B/32"``.
    """

    def __init__(self, model: CLIP, class_emb_only: bool = False, clip_model_type: str = "ViT-B/32"):
        super().__init__()
        self.model = model
        self.class_emb_only = class_emb_only

        if clip_model_type not in ("ViT-B/32", "ViT-B/16"):
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32' or 'ViT-B/16'"
            )

        # prompt config (registration order is kept stable because weight
        # loading elsewhere pairs state-dict entries by position)
        _, prompt_dim = self.model.visual.positional_embedding.shape
        num_tokens = 8
        hidden_size = 768
        self.prompt_dropout = Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        self.prompt_embeddings = nn.Parameter(torch.zeros(1, num_tokens, prompt_dim))

        if clip_model_type == "ViT-B/32":
            # action
            self.action_prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.action_prompt_embeddings = nn.Parameter(
                torch.zeros(1, num_tokens, prompt_dim)
            )

        self.eval()

    def forward(self, x, mode=None):
        """Return the (projected) class-token embedding for an image batch.

        Args:
            x: input handed straight to ``self.model.visual.conv1``
                (assumes a (batch, channels, height, width) image batch —
                TODO confirm with callers).
            mode: ``"action"`` selects the action prompt parameters; any
                other value selects the default prompt.

        Returns:
            The layer-normalised class token, multiplied by
            ``self.model.visual.proj`` when that projection exists.
        """
        m = self.model.visual
        with torch.no_grad():
            x = m.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

            # Broadcast the class embedding to one token per batch element.
            cls_token = m.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([cls_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + m.positional_embedding.to(x.dtype)

            if mode == "action":
                prompt_proj = self.action_prompt_proj
                prompt_embeddings = self.action_prompt_embeddings
            else:
                prompt_proj = self.prompt_proj
                prompt_embeddings = self.prompt_embeddings

            # Insert the prompt tokens right after the class token, cast to
            # the patch-token dtype (a no-op when the dtypes already match).
            prompt_tokens = self.prompt_dropout(
                prompt_proj(prompt_embeddings).expand(x.size(0), -1, -1)
            ).to(x.dtype)
            x = torch.cat((x[:, :1, :], prompt_tokens, x[:, 1:, :]), dim=1)

            x = m.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND
            x = m.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

            # Only the class token is kept, so ``x`` is 2-D from here on.
            x = m.ln_post(x[:, 0, :])

            if m.proj is not None:
                x = x @ m.proj

            # ``x`` is already the class-token embedding; the old
            # ``x[:, 0, :]`` in the ``class_emb_only`` branch raised
            # IndexError on this 2-D tensor.
            return x


class SNPromptClipViTPreprocessor(Preprocessor):
    """Preprocess RGB images with a prompt-augmented CLIP ViT and label-preserving noise.

    ``prompt`` (required keyword argument) is a sequence of per-slot pairs
    ``(prompt_checkpoint, text_feature_spec)``: slot 0 drives the default
    prompt, slot 1 the ``"action"`` prompt.  When ``noise_std`` (also
    required) is truthy, noise is added to the prompted embeddings and
    re-sampled until the thresholded similarity to the loaded text features
    mostly matches that of the clean embedding.
    """

    CLIP_RGB_MEANS = (0.485, 0.456, 0.406)
    CLIP_RGB_STDS = (0.229, 0.224, 0.225)

    def __init__(
        self,
        rgb_input_uuid: str,
        clip_model_type: str,
        class_emb_only: bool,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        assert clip_model_type in clip.available_models()

        if clip_model_type == "ViT-B/32":
            output_shape = (512 + 1,)
        elif clip_model_type == "ViT-B/16":
            output_shape = (512,)
        elif clip_model_type == "ViT-L/14":
            output_shape = (768,)
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32', 'ViT-B/16', or 'ViT-L/14'"
            )

        # Required configuration; a missing key raises KeyError.
        self.prompt = kwargs["prompt"]
        self.noise_std = kwargs["noise_std"]
        # Filled in by load_weight() when the ``vit`` property is first used.
        self.state_text_features = None
        self.action_text_features = None
        self.cnt = 0  # resampling attempt counter

        # Both slots enabled -> two embeddings plus the goal index.
        if len(self.prompt) != 1 and self.prompt[1][1] and self.prompt[0][1]:
            output_shape = (512 * 2 + 1,)

        if class_emb_only:
            output_shape = output_shape[1:]

        self.clip_model_type = clip_model_type

        self.class_emb_only = class_emb_only

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )
        # Built lazily by the ``vit`` property.
        self._vit: Optional[SNPromptClipViTEmbedder] = None

        # The local names below are forwarded to the base constructor through
        # ``prepare_locals_for_super(locals())`` (presumably ``input_uuids``
        # and ``observation_space`` are consumed there — keep them stable).
        low = -np.inf
        high = np.inf
        shape = output_shape
        self.output_shape = shape

        input_uuids = [rgb_input_uuid]

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        super().__init__(**prepare_locals_for_super(locals()))

    @staticmethod
    def _matching_entries(state, prompt_key):
        """Return the entries of ``state`` whose name contains one of ``prompt_key``.

        A non-action key fragment never selects an "action"/"dynamics" entry.
        """
        selected = {}
        for name, value in state.items():
            for key in prompt_key:
                if key not in name:
                    continue
                if "action" not in key and ("action" in name or "dynamics" in name):
                    continue
                selected[name] = value
        return selected

    def load_weight(self, path, prompt_key):
        """Load either learned prompt weights or text features from ``path``.

        Does nothing unless ``path`` is a non-empty string.  With a
        ``prompt_key`` the matching checkpoint entries are copied into the
        embedder (paired positionally, not by name).  With ``prompt_key=None``
        the file is loaded as text features, routed by whether the path
        contains ``"state"`` (checked first) or ``"action"``.
        """
        if not (path and isinstance(path, str)):
            return

        if prompt_key is None:
            # NOTE(review): routing is by substring of the whole path, so a
            # directory name containing "state" also matches.
            if "state" in path:
                features = torch.load(path, map_location="cpu").to(self.device)
                self.state_text_features = features / features.norm(dim=-1, keepdim=True)
                print("state text features loaded")
            elif "action" in path:
                features = torch.load(path, map_location="cpu").to(self.device)
                self.action_text_features = features / features.norm(dim=-1, keepdim=True)
                print("action text features loaded")
            return

        # Load onto the CPU so checkpoints saved on another device still
        # open; ``load_state_dict`` copies into the existing parameters.
        pretrained_dict = torch.load(path, map_location="cpu")
        model_dict = self._vit.state_dict()

        source_prompt_dict = self._matching_entries(pretrained_dict, prompt_key)
        target_prompt_dict = self._matching_entries(model_dict, prompt_key)

        assert len(source_prompt_dict.keys())
        print(source_prompt_dict.keys())
        print(target_prompt_dict.keys())

        # key value switching
        prompt_dict = {
            t_k: source_prompt_dict[s_k]
            for s_k, t_k in zip(source_prompt_dict.keys(), target_prompt_dict.keys())
        }

        model_dict.update(prompt_dict)
        first_key = list(target_prompt_dict.keys())[0]
        assert (model_dict[first_key] == prompt_dict[first_key]).all()
        self._vit.load_state_dict(model_dict)
        print(prompt_dict.keys(), "loaded")

    @property
    def vit(self) -> SNPromptClipViTEmbedder:
        """Build the embedder on first access, loading prompts and text features."""
        if self._vit is None:
            clip_model = clip.load(self.clip_model_type, device=self.device)[0]
            self._vit = SNPromptClipViTEmbedder(
                model=clip_model,
                class_emb_only=self.class_emb_only,
                clip_model_type=self.clip_model_type,
            ).to(self.device)
            for module in self._vit.modules():
                if "BatchNorm" in type(module).__name__:
                    module.momentum = 0.0

            # load learned prompt: slot 0 -> default prompt, slot 1 -> action prompt
            targets_keys = [
                ["prompt_proj.weight", "prompt_proj.bias", "prompt_embeddings"],
                ["action_prompt_proj.weight", "action_prompt_proj.bias", "action_prompt_embeddings"],
            ]
            for i, entry in enumerate(self.prompt):
                self.load_weight(entry[0], targets_keys[i])
                self.load_weight(entry[1], None)
            self._vit.eval()

        return self._vit

    def to(self, device: torch.device) -> "SNPromptClipViTPreprocessor":
        """Move the embedder to ``device`` and remember it for later inputs."""
        self._vit = self.vit.to(device)
        self.device = device
        return self

    def make_logits(self, x, text_features):
        """Return a boolean mask of text features with positive similarity to ``x``."""
        # Align device/dtype with the embeddings (no-op when they already
        # match) so e.g. fp16 features saved from CLIP can be multiplied
        # with float32 embeddings.
        text_features = text_features.to(device=x.device, dtype=x.dtype)
        logits = 100. * x @ text_features.t()
        return torch.sigmoid(logits) > 0.5

    def _clipped_noise(self, like, std):
        """Sample Gaussian noise shaped like ``like``, clipped to +/-1.5 std."""
        noise = torch.normal(0, std, size=like.size())
        return torch.clip(noise, -1.5 * std, 1.5 * std).to(self.device)

    @staticmethod
    def _unit(feat):
        """L2-normalise ``feat`` along its last dimension."""
        return feat / feat.norm(dim=-1, keepdim=True)

    def resampling(self, ori_x, x, pre, post, std, t_f):
        """Re-draw noise until at least 90% of the labels match the clean ones.

        Gives up (printing "giveup") after 500 draws and returns the last
        noisy sample.
        """
        while torch.mean((pre == post).float()) < 0.90:
            x = self._unit(ori_x + self._clipped_noise(ori_x, std))

            post = self.make_logits(x, t_f)
            self.cnt += 1
            if self.cnt == 500:
                print("giveup")
                break
        self.cnt = 0
        return x

    def _perturb_preserving_labels(self, feat, text_features, std):
        """Add noise to ``feat`` and resample it against ``text_features``.

        When no text features were loaded the embedding is returned without
        noise, matching the previous behaviour.
        NOTE(review): confirm that skipping the noise in that case is intended.
        """
        noise = self._clipped_noise(feat, std)
        if text_features is None:
            return feat
        ori_x = feat.detach()
        pre = self.make_logits(ori_x, text_features)
        noisy = self._unit(feat + noise)
        post = self.make_logits(noisy, text_features)
        return self.resampling(ori_x, noisy, pre, post, std, text_features)

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Encode the RGB observation into the concatenated feature vector.

        Raises:
            NotImplementedError: if no prompt configuration was supplied.
        """
        if not self.prompt:
            # The previous code evaluated the exception class without raising
            # it and then failed on an unbound local.
            raise NotImplementedError(
                "SNPromptClipViTPreprocessor requires a non-empty `prompt`"
            )

        with torch.no_grad():
            std = self.noise_std
            image = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
            full_state = []

            slots = ((None, "state_text_features"), ("action", "action_text_features"))
            for slot, (mode, feature_attr) in enumerate(slots):
                # A single-entry ``prompt`` is accepted by __init__, so a
                # missing slot is simply treated as disabled.
                spec = self.prompt[slot][1] if slot < len(self.prompt) else None
                if not spec:
                    continue

                if isinstance(spec, str):
                    # Accessing ``self.vit`` first guarantees the text
                    # features have been loaded before they are read.
                    feat = self._unit(self.vit(image, mode=mode).float())
                    if std:
                        feat = self._perturb_preserving_labels(
                            feat, getattr(self, feature_attr), std
                        )
                else:
                    feat = self._unit(self.vit.model.encode_image(image).float())
                    if std:
                        # Fresh noise per slot; the action fallback used to
                        # reference a ``noise`` variable it never sampled.
                        feat = self._unit(feat + self._clipped_noise(feat, std))
                full_state.append(feat)

            if "goal_object_type_ind" in obs.keys():
                goal = obs["goal_object_type_ind"]
                full_state.append(goal.unsqueeze(1).to(self.device))

            return torch.cat(full_state, dim=1)


class SNPromptClipATTMViTEmbedder(nn.Module):
    """CLIP ViT image encoder that splices learned prompt tokens behind the class token.

    The default prompt parameters exist for both supported model types; the
    ``"action"`` prompt parameters are only created for ``"ViT-B/32"``.
    """

    def __init__(self, model: CLIP, class_emb_only: bool = False, clip_model_type: str = "ViT-B/32"):
        super().__init__()
        self.model = model
        self.class_emb_only = class_emb_only

        supports_action_prompt = clip_model_type == "ViT-B/32"
        if not supports_action_prompt and clip_model_type != "ViT-B/16":
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32' or 'ViT-B/16'"
            )

        # prompt config (registration order is kept stable because weight
        # loading elsewhere pairs state-dict entries by position)
        _, prompt_dim = self.model.visual.positional_embedding.shape
        num_tokens = 8
        hidden_size = 768
        self.prompt_dropout = Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        self.prompt_embeddings = nn.Parameter(torch.zeros(1, num_tokens, prompt_dim))

        if supports_action_prompt:
            # action
            self.action_prompt_proj = nn.Linear(prompt_dim, hidden_size)
            self.action_prompt_embeddings = nn.Parameter(
                torch.zeros(1, num_tokens, prompt_dim)
            )

        self.eval()

    def forward(self, x, mode=None):
        """Return the (projected) class-token embedding for an image batch.

        Args:
            x: input handed straight to ``self.model.visual.conv1``
                (assumes a (batch, channels, height, width) image batch —
                TODO confirm with callers).
            mode: ``"action"`` selects the action prompt parameters; any
                other value selects the default prompt.

        Returns:
            The layer-normalised class token, multiplied by
            ``self.model.visual.proj`` when that projection exists.
        """
        visual = self.model.visual
        with torch.no_grad():
            x = visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

            # Broadcast the class embedding to one token per batch element.
            cls_token = visual.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([cls_token, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + visual.positional_embedding.to(x.dtype)

            if mode == "action":
                prompt_proj = self.action_prompt_proj
                prompt_embeddings = self.action_prompt_embeddings
            else:
                prompt_proj = self.prompt_proj
                prompt_embeddings = self.prompt_embeddings

            # Insert the prompt tokens right after the class token, cast to
            # the patch-token dtype (a no-op when the dtypes already match).
            prompt_tokens = self.prompt_dropout(
                prompt_proj(prompt_embeddings).expand(x.size(0), -1, -1)
            ).to(x.dtype)
            x = torch.cat((x[:, :1, :], prompt_tokens, x[:, 1:, :]), dim=1)

            x = visual.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND
            x = visual.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

            # Only the class token is kept, so ``x`` is 2-D from here on.
            x = visual.ln_post(x[:, 0, :])

            if visual.proj is not None:
                x = x @ visual.proj

            # ``x`` is already the class-token embedding; the old
            # ``x[:, 0, :]`` in the ``class_emb_only`` branch raised
            # IndexError on this 2-D tensor.
            return x


class SNPromptClipATTMViTPreprocessor(Preprocessor):
    """Preprocess RGB or depth image using a ResNet model with CLIP model
    weights."""

    CLIP_RGB_MEANS = (0.485, 0.456, 0.406)
    CLIP_RGB_STDS = (0.229, 0.224, 0.225)

    def __init__(
        self,
        rgb_input_uuid: str,
        clip_model_type: str,
        class_emb_only: bool,
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        """Configure the preprocessor and declare its observation space.

        Args:
            rgb_input_uuid: uuid of the RGB observation to consume.
            clip_model_type: a name from ``clip.available_models()``; one of
                'ViT-B/32', 'ViT-B/16' or 'ViT-L/14'.
            class_emb_only: when true the first entry of the output shape is
                dropped.
            device: device to run on (CPU when ``None``).
            device_ids: explicit device list; defaults to all CUDA devices.
            **kwargs: must contain ``prompt``, ``noise_std``,
                ``weighted_sum`` and ``weighted_cat`` (KeyError otherwise).

        Raises:
            NotImplementedError: for an unsupported ``clip_model_type``.
        """
        assert clip_model_type in clip.available_models()

        if clip_model_type == "ViT-B/32":
            output_shape = (512 + 1,)
        elif clip_model_type == "ViT-B/16":
            output_shape = (512,)
        elif clip_model_type == "ViT-L/14":
            output_shape = (768,)
        else:
            raise NotImplementedError(
                "Currently `clip_model_type` must be one of 'ViT-B/32', 'ViT-B/16', or 'ViT-L/14'"
            )

        # Required configuration.
        self.prompt = kwargs["prompt"]
        self.noise_std = kwargs["noise_std"]
        self.weighted_sum = kwargs["weighted_sum"]
        self.weighted_cat = kwargs["weighted_cat"]

        # Filled in later by load_weight().
        self.state_text_features = None
        self.action_text_features = None
        self.cnt = 0  # resampling attempt counter

        # NOTE(review): this tests the truthiness of the whole slot entries,
        # while the rest of the class indexes ``self.prompt[i][0]`` /
        # ``self.prompt[i][1]`` — confirm whether ``[1]`` of each entry was
        # meant here.
        if len(self.prompt) != 1 and self.prompt[1] and self.prompt[0]:
            output_shape = (512 * 2 + 1,)

        if self.weighted_sum:
            output_shape = (512 * 1 + 1,)

        if class_emb_only:
            output_shape = output_shape[1:]

        self.clip_model_type = clip_model_type

        self.class_emb_only = class_emb_only

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )
        # Built lazily by the ``vit`` property.
        self._vit: Optional[SNPromptClipATTMViTEmbedder] = None

        # The local names below are forwarded to the base constructor through
        # ``prepare_locals_for_super(locals())`` (presumably ``input_uuids``
        # and ``observation_space`` are consumed there — keep them stable).
        low = -np.inf
        high = np.inf
        shape = output_shape
        self.output_shape = shape

        input_uuids = [rgb_input_uuid]

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        # Classification text features, read from a path relative to the
        # current working directory.
        self.class_text_features = torch.load("projects/object_navigation/prompts/16shot_CoOp/classification_state_text_features_best.pth")

        super().__init__(**prepare_locals_for_super(locals()))

    def load_weight(self, path, prompt_key):
        """Load either learned prompt weights or text features from ``path``.

        Does nothing unless ``path`` is a non-empty string.

        * With a ``prompt_key`` the checkpoint entries whose names contain one
          of the key fragments are copied into ``self._vit``.  Source and
          target entries are paired positionally, not by name.
        * With ``prompt_key=None`` the file is loaded as text features, moved
          to ``self.device`` and L2-normalised; it is stored as state features
          when the path contains ``"state"`` (checked first) or as action
          features when it contains ``"action"``.
        """
        if not (path and isinstance(path, str)):
            return

        if prompt_key is None:
            # NOTE(review): routing is by substring of the whole path, so a
            # directory name containing "state" also matches.
            if "state" in path:
                features = torch.load(path, map_location="cpu").to(self.device)
                self.state_text_features = features / features.norm(dim=-1, keepdim=True)
                print("state text features loaded")
            elif "action" in path:
                features = torch.load(path, map_location="cpu").to(self.device)
                self.action_text_features = features / features.norm(dim=-1, keepdim=True)
                print("action text features loaded")
            return

        def matching_entries(state):
            # A non-action key fragment must not select "action"/"dynamics"
            # entries (e.g. "prompt_proj.weight" vs "action_prompt_proj.weight").
            selected = {}
            for name, value in state.items():
                for key in prompt_key:
                    if key not in name:
                        continue
                    if "action" not in key and ("action" in name or "dynamics" in name):
                        continue
                    selected[name] = value
            return selected

        # Load onto the CPU so checkpoints saved on another device still
        # open; ``load_state_dict`` copies into the existing parameters.
        pretrained_dict = torch.load(path, map_location="cpu")
        model_dict = self._vit.state_dict()

        source_prompt_dict = matching_entries(pretrained_dict)
        target_prompt_dict = matching_entries(model_dict)

        assert len(source_prompt_dict.keys())
        print(source_prompt_dict.keys())
        print(target_prompt_dict.keys())

        # key value switching
        prompt_dict = {
            t_k: source_prompt_dict[s_k]
            for s_k, t_k in zip(source_prompt_dict.keys(), target_prompt_dict.keys())
        }

        model_dict.update(prompt_dict)
        first_key = list(target_prompt_dict.keys())[0]
        assert (model_dict[first_key] == prompt_dict[first_key]).all()
        self._vit.load_state_dict(model_dict)
        print(prompt_dict.keys(), "loaded")


    @property
    def vit(self) -> SNPromptClipATTMViTEmbedder:
        """Return the prompt-augmented CLIP embedder, building it on first use.

        The first access loads the CLIP model, wraps it in the embedder on
        ``self.device``, zeroes the momentum of any BatchNorm-like module,
        loads the per-slot prompt weights and text features through
        ``load_weight`` and switches the embedder to eval mode.
        """
        if self._vit is not None:
            return self._vit

        backbone = clip.load(self.clip_model_type, device=self.device)[0]
        embedder = SNPromptClipATTMViTEmbedder(
            model=backbone,
            class_emb_only=self.class_emb_only,
            clip_model_type=self.clip_model_type,
        )
        self._vit = embedder.to(self.device)

        for submodule in self._vit.modules():
            if "BatchNorm" in type(submodule).__name__:
                submodule.momentum = 0.0

        # State-dict key fragments per prompt slot: default prompt, then
        # action prompt.
        keys_per_slot = [
            ["prompt_proj.weight", "prompt_proj.bias", "prompt_embeddings"],
            ["action_prompt_proj.weight", "action_prompt_proj.bias", "action_prompt_embeddings"],
        ]
        for slot, entry in enumerate(self.prompt):
            # entry[0]: prompt checkpoint, entry[1]: text-feature file.
            self.load_weight(entry[0], keys_per_slot[slot])
            self.load_weight(entry[1], None)

        self._vit.eval()
        return self._vit

    def to(self, device: torch.device) -> "SNPromptClipATTMViTPreprocessor":
        """Move the (lazily built) embedder to ``device`` and record the device.

        Returns ``self`` so calls can be chained.
        """
        relocated = self.vit.to(device)
        self._vit, self.device = relocated, device
        return self
    
    def make_logits(self, x, text_features):
        """Return a boolean mask of text features with positive similarity to ``x``.

        The similarity ``100 * x @ text_features.T`` is passed through a
        sigmoid and thresholded at 0.5, so an entry is true exactly when the
        scaled dot product is strictly positive.

        Args:
            x: embeddings, one row per sample.
            text_features: text embeddings, one row per label.

        Returns:
            Boolean tensor with one row per sample and one column per label.
        """
        # Align device/dtype with the embeddings (a no-op when they already
        # match) so e.g. fp16 text features can be multiplied with float32
        # image embeddings instead of raising a dtype/device mismatch.
        text_features = text_features.to(device=x.device, dtype=x.dtype)
        logits = 100. * x @ text_features.t()
        return torch.sigmoid(logits) > 0.5
    
    def resampling(self, ori_x, x, pre, post, std, t_f):
        """Re-draw noise until the noisy labels mostly agree with the clean ones.

        While fewer than 90% of the entries of ``pre`` and ``post`` agree, a
        fresh clipped Gaussian perturbation of ``ori_x`` is drawn,
        L2-normalised and re-labelled with ``make_logits``.  After 100 draws
        the loop prints "giveup" and returns the last sample.

        Args:
            ori_x: clean embedding the noise is added to.
            x: current noisy embedding (returned untouched if it already
                agrees well enough).
            pre: labels of the clean embedding.
            post: labels of ``x``.
            std: noise standard deviation; noise is clipped to +/-1.5 std.
            t_f: text features handed to ``make_logits``.

        Returns:
            The accepted (or last attempted) noisy embedding.
        """
        bound = 1.5 * std
        while True:
            agreement = torch.mean((pre == post).float())
            if not agreement < 0.90:
                break

            jitter = torch.normal(0, std, size=ori_x.size())
            jitter = torch.clip(jitter, -bound, bound).to(self.device)
            x = ori_x + jitter
            x = x / x.norm(dim=-1, keepdim=True)

            post = self.make_logits(x, t_f)
            self.cnt += 1
            if self.cnt == 100:
                self.cnt = 0
                print("giveup")
                break

        # The counter is per call; always leave it reset.
        self.cnt = 0
        return x
    
    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        with torch.no_grad():
            std = self.noise_std
            if self.prompt:
                full_state = []
                ori_x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
                clip_x = self.vit.model.encode_image(ori_x).float()
                clip_x = clip_x / clip_x.norm(dim=-1, keepdim=True)
                if std:
                    noise = torch.clip(torch.normal(0, std, size=clip_x.size()), -1.5 * std, 1.5 * std).to(self.device)
                    if self.class_text_features is not None:
                        ori_x = clip_x.detach()
                        pre = self.make_logits(ori_x, self.class_text_features)
                        clip_x = clip_x + noise
                        clip_x = clip_x / clip_x.norm(dim=-1, keepdim=True)
                        post = self.make_logits(clip_x, self.class_text_features)
                        clip_x = self.resampling(ori_x, clip_x, pre, post, std, self.class_text_features)
                full_state.append(clip_x)

                if self.prompt[0][1] and isinstance(self.prompt[0][1], str):
                    s_x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
                    s_x = self.vit(s_x).float()
                    s_x = s_x / s_x.norm(dim=-1, keepdim=True)
                    if std:
                        noise = torch.clip(torch.normal(0, std, size=s_x.size()), -1.5 * std, 1.5 * std).to(self.device)
                        if self.state_text_features is not None:
                            ori_x = s_x.detach()
                            pre = self.make_logits(ori_x, self.state_text_features)
                            s_x = s_x + noise
                            s_x = s_x / s_x.norm(dim=-1, keepdim=True)
                            post = self.make_logits(s_x, self.state_text_features)
                            s_x = self.resampling(ori_x, s_x, pre, post, std, self.state_text_features)
                    full_state.append(s_x)

                elif self.prompt[0][1]:
                    s_x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
                    s_x = self.vit.model.encode_image(s_x).float()
                    s_x = s_x / s_x.norm(dim=-1, keepdim=True)
                    if std:
                        noise = torch.clip(torch.normal(0, std, size=s_x.size()), -1.5 * std, 1.5 * std).to(self.device)
                        s_x += noise
                        s_x = s_x / s_x.norm(dim=-1, keepdim=True)
                    full_state.append(s_x)
                
                if self.prompt[1][1] and isinstance(self.prompt[1][1], str):
                    a_x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
                    a_x = self.vit(a_x, mode="action").float()
                    a_x = a_x / a_x.norm(dim=-1, keepdim=True)
                    if std:
                        noise = torch.clip(torch.normal(0, std, size=a_x.size()), -1.5 * std, 1.5 * std).to(self.device)
                        if self.action_text_features is not None:
                            ori_x = a_x.detach()
                            pre = self.make_logits(ori_x, self.action_text_features)
                            a_x = a_x + noise
                            a_x = a_x / a_x.norm(dim=-1, keepdim=True)
                            post = self.make_logits(a_x, self.action_text_features)
                            a_x = self.resampling(ori_x, a_x, pre, post, std, self.action_text_features)
                    full_state.append(a_x)

                elif self.prompt[1][1]:
                    a_x = obs[self.input_uuids[0]].to(self.device).permute(0, 3, 1, 2)  # bhwc -> bchw
                    a_x = self.vit.model.encode_image(a_x).float()
                    a_x = a_x / a_x.norm(dim=-1, keepdim=True)
                    if std:
                        #noise = torch.clip(torch.normal(0, std, size=a_x.size()), -1.5 * std, 1.5 * std).to(self.device)
                        a_x += noise
                        a_x = a_x / a_x.norm(dim=-1, keepdim=True)
                    full_state.append(a_x)

                if "goal_object_type_ind" in obs.keys():
                    goal = obs["goal_object_type_ind"]
                    goal = goal.unsqueeze(1).to(self.device)
                    full_state.append(goal)
                
                x = torch.cat(full_state, dim=1)
            else:
                NotImplementedError
        return x



class ClipTextPreprocessor(Preprocessor):
    """Turn goal object indices into CLIP text embeddings.

    Each goal index is mapped to its name in ``object_types``, wrapped in the
    prompt ``"navigate to the <name>"``, tokenized with ``clip.tokenize`` and
    encoded with the text encoder of the CLIP ``'RN50'`` model. The CLIP model
    is loaded lazily, on the first access to ``text_encoder``.
    """

    def __init__(
        self,
        goal_sensor_uuid: str,
        object_types: List[str],
        device: Optional[torch.device] = None,
        device_ids: Optional[List[torch.device]] = None,
        **kwargs: Any,
    ):
        """Configure the preprocessor; no CLIP weights are loaded here.

        Args:
            goal_sensor_uuid: UUID of the observation holding goal indices.
            object_types: Object-type names, indexed by goal index.
            device: Device used for tokens and the CLIP model (CPU if None).
            device_ids: Devices to use; defaults to all visible CUDA device
                indices.
            **kwargs: Forwarded to the base-class constructor.
        """
        # NOTE: `locals()` is handed to `prepare_locals_for_super` at the end
        # of this method, so the local variable names below may be part of
        # what the base class receives. Do not add, rename or drop locals
        # without checking that helper.
        output_shape = (1024,)

        self.object_types = object_types

        self.device = torch.device("cpu") if device is None else device
        self.device_ids = device_ids or cast(
            List[torch.device], list(range(torch.cuda.device_count()))
        )

        # Lazily loaded CLIP model. It must exist before `text_encoder` is
        # first read; previously it was only created by `to()`, so using the
        # encoder without calling `to()` first raised AttributeError.
        self._clip_model = None

        low = -np.inf
        high = np.inf
        shape = output_shape

        observation_space = gym.spaces.Box(low=low, high=high, shape=shape)

        input_uuids = [goal_sensor_uuid]

        super().__init__(**prepare_locals_for_super(locals()))

    @property
    def text_encoder(self):
        """Return CLIP's ``encode_text``, loading the 'RN50' model on first use."""
        if self._clip_model is None:
            self._clip_model = clip.load('RN50', device=self.device)[0]
            self._clip_model.eval()
        return self._clip_model.encode_text

    def to(self, device: torch.device):
        """Switch to ``device``.

        The cached CLIP model is dropped rather than moved; it is reloaded on
        the new device the next time ``text_encoder`` is accessed.

        Returns:
            This preprocessor, to allow call chaining.
        """
        self.device = device
        self._clip_model = None
        return self

    def process(self, obs: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Encode the goal object types found in ``obs`` as float text features.

        Args:
            obs: Observation dictionary; ``obs[self.input_uuids[0]]`` is
                iterated and each element is used to index ``object_types``
                (presumably a batch of integer goal indices -- confirm
                against the sensor).

        Returns:
            The output of CLIP's text encoder, cast to float, computed without
            gradient tracking.
        """
        object_inds = obs[self.input_uuids[0]]
        object_types = [self.object_types[ind] for ind in object_inds]
        x = clip.tokenize([f"navigate to the {obj}" for obj in object_types]).to(self.device)
        with torch.no_grad():
            return self.text_encoder(x).float()
