# m1

import os
import sys
from typing import Any, Dict, List, Union, Optional, Type
from zipfile import ZipFile
import gym
import clip
import numpy as np
import torch as th
from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import bc.MVPT_ConPE as MVPT_ConPE
import bc.MVPT_AVG as MVPT_AVG
import bc.clipvit as clipvit
from bc.model import IthorDisentangledVAE


def freeze_model(model):
    """Freeze ``model`` in place and return it.

    Every parameter stops requiring gradients, every sub-module whose class
    name contains ``"BatchNorm"`` gets ``momentum = 0.0``, and the model is
    switched to evaluation mode.
    """
    batch_norm_layers = [
        layer for layer in model.modules()
        if "BatchNorm" in type(layer).__name__
    ]
    for layer in batch_norm_layers:
        # With zero momentum the running statistics no longer move, even if
        # the module is later put back into train mode.
        layer.momentum = 0.0

    for weight in model.parameters():
        weight.requires_grad = False

    model.eval()
    return model


def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and any existing gradient, to float32 in place.

    Args:
        model: Module whose ``parameters()`` are converted.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None explicitly: taking the truth value of a tensor
        # raises for multi-element gradients and is False for a single zero.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

        
class ClipExtractor(BaseFeaturesExtractor):
    """Features extractor that embeds the ``"image"`` observation with CLIP ViT-B/32.

    Only the ``"image"`` key of the observation space is consumed; it
    contributes 512 features. When ``noise_std`` is non-zero, Gaussian noise
    clipped to ``±1.5 * noise_std`` is added to the embedding.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str):
        """Load CLIP and register it as the extractor for the ``"image"`` key.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)
        self.meta_mode = False
        self.clip_model = None

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        clip_model, _ = clip.load("ViT-B/32", device=device)
        for module in clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        clip_model.eval().float()
        # Only ``clip_model.visual`` is called in forward(), so the
        # ``transformer`` attribute is dropped.
        del clip_model.transformer

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = clip_model
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

        self.noise_std = 0.0

    def forward(self, observations) -> th.Tensor:
        """Return the concatenated embeddings of all registered observation keys.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor obtained by concatenating the per-key features along dim 1.
        """
        encoded_tensor_list = []

        for key, extractor in self.extractors.items():
            if key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[key].reshape(-1, 3, 224, 224)
            tens = extractor.visual(obs)
            if self.noise_std:
                noise = th.clip(th.normal(0, self.noise_std, size=tens.size()), -1.5 * self.noise_std, 1.5 * self.noise_std).to(tens.device)
                tens = tens + noise
            encoded_tensor_list.append(tens)

        return th.cat(encoded_tensor_list, dim=1)
            
        
class PromptAttentionExtractor(BaseFeaturesExtractor):
    """Features extractor that attends over several prompt-tuned CLIP embeddings.

    The plain CLIP image embedding is the query. One embedding per loaded
    prompt provides the values, and a per-prompt bottleneck MLP of each value
    provides the keys. The attention output is added to the CLIP embedding,
    giving 512 features for the ``"image"`` key.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str, prompt_env: str):
        """Build the CLIP backbone, load the prompts and create the attention MLPs.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Environment name interpolated into the prompt checkpoint paths.
            prompt_env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)

        self.noise_std = 0.00  # 0.002
        # [temperature, flag]: Gumbel noise is added to the scores only when the
        # flag is False. NOTE(review): confirm the intended polarity of the flag.
        self.sm_noise = [0.7, True]
        self.meta_mode = False

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        for module in self.clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        self.clip_model.eval().float()
        del self.clip_model.transformer

        model = MVPT_ConPE.CONPEMultiVisualPromptTuningCLIP(
            self.clip_model, device, clip_model_type="CLIPViT-B/32", n_vtk=8
        ).to(device)

        # Collect prompt checkpoints: representation factors first, then dynamics.
        print("Loading Prompts !!")
        representation_factors = ['BRIGHTNESS', "CONTRAST", "SATURATION", "HUE"]
        print(f"Representation Factors: {representation_factors}")
        prompt_paths = [
            f'/path/to/MMRL/logs/Robomani/{env}/PROMPTS/{df}/checkpoints/contrastive__latest.pth'
            for df in representation_factors
        ]
        dynamics_factors = ['GRAV_', 'XWIND']
        print(f"Dyanamics Factors: {dynamics_factors}")
        prompt_paths += [
            f'/path/to/MMRL/logs/Robomani/{env}/PROMPTS/{df}/checkpoints/comparative_action_byol_latest.pth'
            for df in dynamics_factors
        ]

        # multi_p_mode: [COMPOSE, UNIFORM/WEIGHTED, CAT/AVG]
        model.prompt_init(prompt_paths, multi_p_mode=['ENSEMBLE', 'WEIGHTED', 'AVG'])

        # One bottleneck MLP per prompt; it maps a prompt embedding to its key.
        self.source_prompt_attn_weight_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(512, 128, bias=False),
                nn.SiLU(),
                nn.Linear(128, 512, bias=False),
                nn.LayerNorm(512),
            )
            for _ in range(model.visual_backbone.prompt_num)
        ])

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = model
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        """Return the CLIP embedding plus the attention-weighted prompt embeddings.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch, -1)`` concatenated over observation keys.
        """
        encoded_tensor_list = []

        for obs_key, extractor in self.extractors.items():
            if obs_key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[obs_key]
            batch = obs.shape[0]
            obs = obs.reshape(-1, 3, 224, 224)

            tens = extractor(obs)
            # (batch, prompt_num, clip_dim)
            tens = tens.reshape(batch, extractor.visual_backbone.prompt_num, tens.size(-1))
            # Augmentation
            if self.noise_std:
                noise = th.clip(th.normal(0, self.noise_std, size=tens.size()), -1.5 * self.noise_std, 1.5 * self.noise_std).to(tens.device)
                tens = tens + noise
            tens = tens / tens.norm(dim=-1, keepdim=True)

            # Attention: the query is the unprompted CLIP embedding.
            clip_tens = self.clip_model.visual(obs)
            query = clip_tens.unsqueeze(1)  # (B, 1, clip_dim)

            prompt_embs = []
            prompt_keys = []
            for i, key_mlp in enumerate(self.source_prompt_attn_weight_list):
                p_emb = tens[:, i, :]  # (B, clip_dim)
                prompt_embs.append(p_emb.unsqueeze(1))
                prompt_keys.append(key_mlp(p_emb).unsqueeze(1))
            keys = th.cat(prompt_keys, dim=1)    # (B, prompt_num, clip_dim)
            value = th.cat(prompt_embs, dim=1)   # (B, prompt_num, clip_dim)

            # Cosine similarity between the query and each prompt embedding.
            q_norm = th.norm(query, dim=2, keepdim=True)
            v_norm = th.norm(value, dim=2, keepdim=True)
            dot_prod = th.bmm(query, value.permute(0, 2, 1))
            cos_sim = dot_prod / th.bmm(q_norm, v_norm.permute(0, 2, 1))  # (B, 1, prompt_num)

            # Scaled dot-product score against the learned keys.
            score = th.bmm(query, keys.transpose(1, 2)) / np.sqrt(query.size(-1))  # (B, 1, prompt_num)

            # Gumbel noise sampling; same condition as ConPEExtractor.
            if self.sm_noise[0] and not self.sm_noise[1]:
                gumbel_noise = -th.log(-th.log(th.rand_like(cos_sim)))
                score = score + gumbel_noise / self.sm_noise[0]

            attn_weights = th.softmax(cos_sim * score, -1)  # (B, 1, prompt_num)
            context = th.bmm(attn_weights, value)            # (B, 1, clip_dim)
            tens = clip_tens.unsqueeze(1) + context
            encoded_tensor_list.append(tens.view(batch, -1))

        return th.cat(encoded_tensor_list, dim=1)


class ConPEExtractor(BaseFeaturesExtractor):
    """Prompt-attention features extractor whose non-prompt weights are frozen.

    The query comes from ``extractor(obs, meta_mode)`` when ``meta_mode`` is
    set, otherwise from the plain CLIP visual encoder. One embedding per
    loaded prompt provides the values, and a per-prompt bottleneck MLP of each
    value provides the keys. The attention output is added to the query
    embedding, giving 512 features for the ``"image"`` key.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str, prompt_env: str):
        """Build the CLIP backbone, load the prompts and freeze non-prompt weights.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class (prompt paths are fixed to ``reach``).
            prompt_env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)
        ##### OPTIONS #####
        self.noise_std = 0.00  # 0.002
        # [temperature, flag]: Gumbel noise is added to the scores only when the
        # flag is False. NOTE(review): confirm the intended polarity of the flag.
        self.sm_noise = [0.07, True]
        self.meta_mode = True

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        for module in self.clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        self.clip_model.eval().float()
        del self.clip_model.transformer

        model = MVPT_ConPE.CONPEMultiVisualPromptTuningCLIP(
            self.clip_model, device, clip_model_type="CLIPViT-B/32", n_vtk=8
        ).to(device)

        # Collect prompt checkpoints: representation factors first, then dynamics.
        print("Loading Prompts !!")
        representation_factors = ['BRIGHTNESS', "CONTRAST", "SATURATION", "HUE"]
        print(f"Representation Factors: {representation_factors}")
        prompt_paths = [
            f'/path/to/MMRL/logs/Robomani/reach/PROMPTS/{df}/checkpoints/contrastive__latest.pth'
            for df in representation_factors
        ]
        dynamics_factors = ['GRAV_', 'XWIND']
        print(f"Dyanamics Factors: {dynamics_factors}")
        prompt_paths += [
            f'/path/to/MMRL/logs/Robomani/reach/PROMPTS/{df}/checkpoints/comparative_action_byol_latest.pth'
            for df in dynamics_factors
        ]

        # multi_p_mode: [COMPOSE, UNIFORM/WEIGHTED, CAT/AVG]
        model.prompt_init(prompt_paths, multi_p_mode=['ENSEMBLE', 'WEIGHTED', 'AVG'], meta_mode=self.meta_mode)

        # One bottleneck MLP per prompt; it maps a prompt embedding to its key.
        self.source_prompt_attn_weight_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(512, 128, bias=False),
                nn.SiLU(),
                nn.Linear(128, 512, bias=False),
                nn.LayerNorm(512),
            )
            for _ in range(model.visual_backbone.prompt_num)
        ])

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = model
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

        print("Turning off gradients in both the image and the text encoder")
        # Inside the prompted model only parameters whose name contains
        # "prompt" stay trainable; the raw CLIP model is frozen entirely.
        for name, param in self.extractors.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)
        for param in self.clip_model.parameters():
            param.requires_grad_(False)

        # Report what is still trainable.
        enabled = {name for name, param in self.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

    def forward(self, observations) -> th.Tensor:
        """Return the query embedding plus the attention-weighted prompt embeddings.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch, -1)`` concatenated over observation keys.
        """
        encoded_tensor_list = []

        for obs_key, extractor in self.extractors.items():
            if obs_key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[obs_key]
            batch = obs.shape[0]
            obs = obs.reshape(-1, 3, 224, 224)

            tens = extractor(obs)
            # (batch, prompt_num, clip_dim)
            tens = tens.reshape(batch, extractor.visual_backbone.prompt_num, tens.size(-1))
            # Augmentation
            if self.noise_std:
                noise = th.clip(th.normal(0, self.noise_std, size=tens.size()), -1.5 * self.noise_std, 1.5 * self.noise_std).to(tens.device)
                tens = tens + noise
            tens = tens / tens.norm(dim=-1, keepdim=True)

            # Attention query: the prompted model in meta mode, else plain CLIP.
            if self.meta_mode:
                clip_tens = extractor(obs, self.meta_mode)
            else:
                clip_tens = self.clip_model.visual(obs)
            query = clip_tens.unsqueeze(1)  # (B, 1, clip_dim)

            prompt_embs = []
            prompt_keys = []
            for i, key_mlp in enumerate(self.source_prompt_attn_weight_list):
                p_emb = tens[:, i, :]  # (B, clip_dim)
                prompt_embs.append(p_emb.unsqueeze(1))
                prompt_keys.append(key_mlp(p_emb).unsqueeze(1))
            keys = th.cat(prompt_keys, dim=1)    # (B, prompt_num, clip_dim)
            value = th.cat(prompt_embs, dim=1)   # (B, prompt_num, clip_dim)

            # Cosine similarity between the query and each prompt embedding.
            q_norm = th.norm(query, dim=2, keepdim=True)
            v_norm = th.norm(value, dim=2, keepdim=True)
            dot_prod = th.bmm(query, value.permute(0, 2, 1))
            cos_sim = dot_prod / th.bmm(q_norm, v_norm.permute(0, 2, 1))  # (B, 1, prompt_num)

            # Scaled dot-product score against the learned keys.
            score = th.bmm(query, keys.transpose(1, 2)) / np.sqrt(query.size(-1))  # (B, 1, prompt_num)

            # Gumbel noise sampling
            if self.sm_noise[0] and not self.sm_noise[1]:
                gumbel_noise = -th.log(-th.log(th.rand_like(cos_sim)))
                score = score + gumbel_noise / self.sm_noise[0]

            attn_weights = th.softmax(cos_sim * score, -1)  # (B, 1, prompt_num)
            context = th.bmm(attn_weights, value)            # (B, 1, clip_dim)
            tens = clip_tens.unsqueeze(1) + context
            encoded_tensor_list.append(tens.view(batch, -1))

        return th.cat(encoded_tensor_list, dim=1)


class ATTEMPTExtractor(BaseFeaturesExtractor):
    """Features extractor using the ``MVPT_AVG`` prompted CLIP model in ``ATTEMPT`` mode.

    The ``"image"`` observation is passed through the prompted model and the
    result is L2-normalised along the last dimension; 512 features are
    declared for the ``"image"`` key.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str, prompt_env: str):
        """Build the frozen CLIP backbone and load the prompts.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class (prompt paths are fixed to ``reach``).
            prompt_env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)

        ##### OPTIONS #####
        self.noise_std = 0.00  # 0.002
        self.meta_mode = True

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        clip_model, _ = clip.load("ViT-B/32", device=device)
        convert_models_to_fp32(clip_model)
        self.clip_model = freeze_model(clip_model)
        del self.clip_model.transformer

        model = MVPT_AVG.AVGMultiVisualPromptTuningCLIP(
            clip_model, device, clip_model_type="CLIPViT-B/32", n_vtk=8
        ).to(device)

        # Collect prompt checkpoints: representation factors first, then dynamics.
        print("Loading Prompts !!")
        representation_factors = ['BRIGHTNESS', "CONTRAST", "SATURATION", "HUE"]
        print(f"Representation Factors: {representation_factors}")
        prompt_paths = [
            f'/path/to/MMRL/logs/Robomani/reach/PROMPTS/{df}/checkpoints/contrastive__latest.pth'
            for df in representation_factors
        ]
        dynamics_factors = ['GRAV_', 'XWIND']
        print(f"Dyanamics Factors: {dynamics_factors}")
        prompt_paths += [
            f'/path/to/MMRL/logs/Robomani/reach/PROMPTS/{df}/checkpoints/comparative_action_byol_latest.pth'
            for df in dynamics_factors
        ]

        # multi_p_mode: [COMPOSE, UNIFORM/WEIGHTED, CAT/AVG]
        model.prompt_init(prompt_paths, multi_p_mode=['ATTEMPT', 'WEIGHTED', 'AVG'], meta_mode=self.meta_mode)

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = model
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        """Return the L2-normalised output of the prompted model.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor obtained by concatenating the per-key features along dim 1.
        """
        encoded_tensor_list = []

        for key, extractor in self.extractors.items():
            if key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[key].reshape(-1, 3, 224, 224)
            tens = extractor(obs)
            encoded_tensor_list.append(tens / tens.norm(dim=-1, keepdim=True))

        return th.cat(encoded_tensor_list, dim=1)


class SESoMExtractor(BaseFeaturesExtractor):
    """SESoM-style baseline: attention over prompt embeddings with a learned query.

    The query is a bottleneck MLP of the frozen CLIP image embedding. One
    L2-normalised embedding per loaded prompt provides the values, and a
    per-prompt bottleneck MLP of each value provides the keys. The output is
    the softmax-weighted sum of the values, giving 512 features for the
    ``"image"`` key.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str, prompt_env: str):
        """Build the frozen CLIP backbone, load the prompts and create the attention layers.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class (prompt paths are fixed to ``reach``).
            prompt_env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)

        ##### OPTIONS #####
        self.noise_std = 0.00  # 0.002
        self.meta_mode = True

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        clip_model, _ = clip.load("ViT-B/32", device=device)
        convert_models_to_fp32(clip_model)
        self.clip_model = freeze_model(clip_model)
        del self.clip_model.transformer

        model = MVPT_AVG.AVGMultiVisualPromptTuningCLIP(
            clip_model, device, clip_model_type="CLIPViT-B/32", n_vtk=8
        ).to(device)

        # Collect prompt checkpoints: representation factors first, then dynamics.
        print("Loading Prompts !!")
        representation_factors = ['BRIGHTNESS', "CONTRAST", "SATURATION", "HUE"]
        print(f"Representation Factors: {representation_factors}")
        prompt_paths = [
            f'/path/to/MMRL/logs/Robomani/reach/PROMPTS/{df}/checkpoints/contrastive__latest.pth'
            for df in representation_factors
        ]
        dynamics_factors = ['GRAV_', 'XWIND']
        print(f"Dyanamics Factors: {dynamics_factors}")
        prompt_paths += [
            f'/path/to/MMRL/logs/Robomani/reach/PROMPTS/{df}/checkpoints/comparative_action_byol_latest.pth'
            for df in dynamics_factors
        ]

        # multi_p_mode: [COMPOSE, UNIFORM/WEIGHTED, CAT/AVG]
        model.prompt_init(prompt_paths, multi_p_mode=['SESoM', 'WEIGHTED', 'AVG'])

        # BASELINE REFERENCE: SESoM -- bottleneck MLP that produces the query.
        self.attn_W_down = nn.Linear(512, 128, bias=False)
        self.attn_W_up = nn.Linear(128, 512, bias=False)
        self.attn_non_linear = nn.SiLU()
        self.attn_layer_norm = nn.LayerNorm(512)

        # One bottleneck MLP per prompt; it maps a prompt embedding to its key.
        self.source_prompt_attn_weight_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(512, 128, bias=False),
                nn.SiLU(),
                nn.Linear(128, 512, bias=False),
                nn.LayerNorm(512),
            )
            for _ in range(model.visual_backbone.prompt_num)
        ])

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = model
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

    def forward(self, observations) -> th.Tensor:
        """Return the attention-weighted sum of the prompt embeddings.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch, -1)`` concatenated over observation keys.
        """
        encoded_tensor_list = []

        for obs_key, extractor in self.extractors.items():
            if obs_key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[obs_key]
            batch = obs.shape[0]
            obs = obs.reshape(-1, 3, 224, 224)

            tens = extractor(obs)
            tens = tens / tens.norm(dim=-1, keepdim=True)
            # (batch, prompt_num, clip_dim)
            tens = tens.reshape(batch, extractor.visual_backbone.prompt_num, tens.size(-1))

            clip_tens = self.clip_model.visual(obs)

            # BASELINE REFERENCE: SESoM -- query from the CLIP embedding.
            hidden = self.attn_W_down(clip_tens)
            hidden = self.attn_non_linear(hidden)
            hidden = self.attn_W_up(hidden)
            hidden = self.attn_layer_norm(hidden)  # (B, clip_dim)

            prompt_embs = []
            prompt_keys = []
            for i, key_mlp in enumerate(self.source_prompt_attn_weight_list):
                p_emb = tens[:, i, :]  # (B, clip_dim)
                prompt_embs.append(p_emb.unsqueeze(1))
                prompt_keys.append(key_mlp(p_emb).unsqueeze(1))
            keys = th.cat(prompt_keys, dim=1)    # (B, prompt_num, clip_dim)
            value = th.cat(prompt_embs, dim=1)   # (B, prompt_num, clip_dim)

            query = hidden.unsqueeze(1)  # (B, 1, clip_dim)
            score = th.bmm(query, keys.transpose(1, 2)) / np.sqrt(query.size(-1))  # (B, 1, prompt_num)
            attn_weights = th.softmax(score, -1)  # (B, 1, prompt_num)
            # Weighted sum of the prompt embeddings.
            weighted = th.bmm(attn_weights, value)  # (B, 1, clip_dim)
            encoded_tensor_list.append(weighted.view(batch, -1))

        return th.cat(encoded_tensor_list, dim=1)


class CURLExtractor(BaseFeaturesExtractor):
    """Features extractor wrapping a ``ClipViTEmbedder`` initialised from a CURL checkpoint.

    Only the ``"image"`` key of the observation space is consumed; it
    contributes 512 features. When ``noise_std`` is non-zero, Gaussian noise
    clipped to ``±1.5 * noise_std`` is added to the embedding.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str):
        """Build the embedder and load the ``encoder_q`` weights of the checkpoint.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)
        self.meta_mode = False
        self.clip_model = None

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        clip_model, _ = clip.load("ViT-B/32", device=device)
        for module in clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        clip_model.eval().float()
        del clip_model.transformer

        self.embedder = clipvit.ClipViTEmbedder(clip_model)
        model_dict = self.embedder.state_dict()
        # map_location keeps the load independent of the device the checkpoint
        # was saved from.
        source_dict = th.load(
            "/path/to/MMRL/logs/curl_metaworld/ObjNav12mdps_16shot/checkpoint_0499.pth.tar",
            map_location=device,
        )["state_dict"]
        # Keep the "encoder_q" entries and drop their first 17 characters
        # (presumably a "module.encoder_q." prefix -- confirm against the checkpoint).
        temp_dict = {k[17:]: v for k, v in source_dict.items() if "encoder_q" in k}
        pretrained_dict = {k: v for k, v in temp_dict.items() if k in model_dict}
        self.embedder.load_state_dict(pretrained_dict)

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = self.embedder
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

        self.noise_std = 0.0

    def forward(self, observations) -> th.Tensor:
        """Return the concatenated embeddings of all registered observation keys.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor obtained by concatenating the per-key features along dim 1.
        """
        encoded_tensor_list = []

        for key, extractor in self.extractors.items():
            if key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[key].reshape(-1, 3, 224, 224)
            tens = extractor(obs)
            if self.noise_std:
                noise = th.clip(th.normal(0, self.noise_std, size=tens.size()), -1.5 * self.noise_std, 1.5 * self.noise_std).to(tens.device)
                tens = tens + noise
            encoded_tensor_list.append(tens)

        return th.cat(encoded_tensor_list, dim=1)

class ATCExtractor(BaseFeaturesExtractor):
    """Features extractor wrapping a ``ClipViTEmbedder`` initialised from an ATC checkpoint.

    Only the ``"image"`` key of the observation space is consumed; it
    contributes 512 features. When ``noise_std`` is non-zero, Gaussian noise
    clipped to ``±1.5 * noise_std`` is added to the embedding.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str):
        """Build the embedder and load the ``net`` weights of the checkpoint.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)
        self.meta_mode = False
        self.clip_model = None

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        clip_model, _ = clip.load("ViT-B/32", device=device)
        for module in clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        clip_model.eval().float()
        del clip_model.transformer

        self.embedder = clipvit.ClipViTEmbedder(clip_model)
        model_dict = self.embedder.state_dict()
        # map_location keeps the load independent of the device the checkpoint
        # was saved from.
        source_dict = th.load(
            "/path/to/MMRL/logs/atc_metaworld_/checkpoints/comparative_action_byol_latest.pth",
            map_location=device,
        )
        # Keep the entries whose name contains "net" and drop their first 4
        # characters (presumably a "net." prefix -- confirm against the checkpoint).
        temp_dict = {k[4:]: v for k, v in source_dict.items() if "net" in k}
        pretrained_dict = {k: v for k, v in temp_dict.items() if k in model_dict}
        self.embedder.load_state_dict(pretrained_dict)

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = self.embedder
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

        self.noise_std = 0.0

    def forward(self, observations) -> th.Tensor:
        """Return the concatenated embeddings of all registered observation keys.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor obtained by concatenating the per-key features along dim 1.
        """
        encoded_tensor_list = []

        for key, extractor in self.extractors.items():
            if key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[key].reshape(-1, 3, 224, 224)
            tens = extractor(obs)
            if self.noise_std:
                noise = th.clip(th.normal(0, self.noise_std, size=tens.size()), -1.5 * self.noise_std, 1.5 * self.noise_std).to(tens.device)
                tens = tens + noise
            encoded_tensor_list.append(tens)

        return th.cat(encoded_tensor_list, dim=1)

class ACPExtractor(BaseFeaturesExtractor):
    """Features extractor that averages two ``ClipViTEmbedder`` outputs.

    ``embedder1`` is initialised from the CURL checkpoint (``encoder_q``
    weights) and ``embedder2`` from the ACP checkpoint (``net`` weights). The
    feature for the ``"image"`` key is the mean of both embeddings, and 512
    features are declared.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str):
        """Build both embedders and load their checkpoints.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)
        self.meta_mode = False
        self.clip_model = None

        # Prefer the GPU, but do not crash on CPU-only hosts.
        device = th.device('cuda' if th.cuda.is_available() else 'cpu')
        clip_model, _ = clip.load("ViT-B/32", device=device)
        for module in clip_model.modules():
            if "BatchNorm" in type(module).__name__:
                module.momentum = 0.0
        clip_model.eval().float()
        del clip_model.transformer

        # NOTE(review): both embedders are built from the same ``clip_model``
        # object. If ClipViTEmbedder does not copy the weights, the second
        # load_state_dict below overwrites the first -- confirm.
        self.embedder1 = clipvit.ClipViTEmbedder(clip_model)
        model_dict = self.embedder1.state_dict()
        # map_location keeps the load independent of the device the checkpoint
        # was saved from.
        source_dict = th.load(
            "/path/to/MMRL/logs/curl_metaworld/ObjNav12mdps_16shot/checkpoint_0499.pth.tar",
            map_location=device,
        )["state_dict"]
        # Keep the "encoder_q" entries and drop their first 17 characters
        # (presumably a "module.encoder_q." prefix -- confirm against the checkpoint).
        temp_dict = {k[17:]: v for k, v in source_dict.items() if "encoder_q" in k}
        pretrained_dict = {k: v for k, v in temp_dict.items() if k in model_dict}
        self.embedder1.load_state_dict(pretrained_dict)

        self.embedder2 = clipvit.ClipViTEmbedder(clip_model)
        model_dict = self.embedder2.state_dict()
        source_dict = th.load(
            "/path/to/MMRL/logs/metaworld_acp_/checkpoints/comparative_action_byol_latest.pth",
            map_location=device,
        )
        # Keep the entries whose name contains "net" and drop their first 4
        # characters (presumably a "net." prefix -- confirm against the checkpoint).
        temp_dict = {k[4:]: v for k, v in source_dict.items() if "net" in k}
        pretrained_dict = {k: v for k, v in temp_dict.items() if k in model_dict}
        self.embedder2.load_state_dict(pretrained_dict)

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = self.embedder1
                total_concat_size += 512
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

        self.noise_std = 0.0

    def forward(self, observations) -> th.Tensor:
        """Return the mean of the two embedders' outputs for the image.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor obtained by concatenating the per-key features along dim 1.
        """
        encoded_tensor_list = []

        for key, extractor in self.extractors.items():
            if key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[key].reshape(-1, 3, 224, 224)
            # ``extractor`` is embedder1; average it with embedder2.
            tens = (extractor(obs) + self.embedder2(obs)) / 2
            encoded_tensor_list.append(tens)

        return th.cat(encoded_tensor_list, dim=1)


class LUSRExtractor(BaseFeaturesExtractor):
    """Features extractor backed by a frozen ``IthorDisentangledVAE`` encoder.

    Only the ``"image"`` key of the observation space is consumed; it
    contributes 32 features, taken from ``encoder.get_feature``.
    """

    def __init__(self, observation_space: gym.spaces.Dict, env: str):
        """Load the pretrained VAE and disable gradients on all its parameters.

        Args:
            observation_space: Dict observation space; only ``"image"`` is used.
            env: Not used by this class.
        """
        # features_dim=1 is a placeholder; the real size is assigned below.
        super().__init__(observation_space, features_dim=1)
        self.meta_mode = False
        self.clip_model = None
        source_model = "/path/to/MMRL/logs/lusr_metaworld/model_200_ithor_cnn.pt"

        self.embedder = IthorDisentangledVAE(class_latent_size=16, content_latent_size=32)
        # Load onto the CPU so a checkpoint saved from a GPU also works on
        # CPU-only hosts; load_state_dict copies the values into the existing
        # parameters whatever device the loaded tensors are on.
        self.embedder.load_state_dict(th.load(source_model, map_location="cpu"))
        print("Turning off gradients in both the image and the text encoder")
        for param in self.embedder.parameters():
            param.requires_grad_(False)

        extractors = {}
        total_concat_size = 0
        for key in observation_space.spaces:
            if key == "image":
                extractors[key] = self.embedder
                total_concat_size += 32
        self.extractors = nn.ModuleDict(extractors)
        self._features_dim = total_concat_size

        self.noise_std = 0.0

    def forward(self, observations) -> th.Tensor:
        """Return the VAE encoder features of the image observation.

        Args:
            observations: Mapping from observation key to tensor; the image
                tensor is reshaped to ``(-1, 3, 224, 224)``.

        Returns:
            Tensor obtained by concatenating the per-key features along dim 1.
        """
        encoded_tensor_list = []

        for key, extractor in self.extractors.items():
            if key != 'image':
                # Only image features are produced; never re-append stale ones.
                continue
            obs = observations[key].reshape(-1, 3, 224, 224)
            encoded_tensor_list.append(extractor.encoder.get_feature(obs))

        return th.cat(encoded_tensor_list, dim=1)