from collections import OrderedDict
from typing import Tuple, Optional, List, Dict
import math
from operator import mul
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout


from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Module-level tokenizer shared by the prompt learners below, which use
# `_tokenizer.encode(name)` to count the tokens of each class name.
_tokenizer = _Tokenizer()

class CLIPImageClassifier(nn.Module):
    """CLIP image encoder followed by a linear classification head.

    Args:
        clip_model: model exposing ``dtype`` and ``encode_image``.
        num_classes: number of logits produced by the linear head.
        feat_dim: size of the feature vector fed into the head.
    """

    def __init__(self, clip_model, num_classes, feat_dim):
        super(CLIPImageClassifier, self).__init__()

        self.dtype = clip_model.dtype
        self.backbone = clip_model
        self.head = nn.Linear(feat_dim, num_classes)
        self._features_dim = feat_dim

    def forward(self, x):
        """Return ``(predictions, features)`` in training mode, else ``predictions``."""
        features = self.backbone.encode_image(x)
        features = features.type(self.dtype)
        predictions = self.head(features)

        if self.training:
            return predictions, features
        return predictions

    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer.

        NOTE: this is a plain method here (call it), whereas the other
        classifiers in this file expose ``features_dim`` as a property.
        """
        return self._features_dim

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        Args:
            optimize_head: when true, the linear head is returned as the only
                parameter group; otherwise the list is empty.
            base_lr: learning rate assigned to the head's parameter group.
        """
        # The original code tested the misspelled name `optimiza_head`, which
        # raised NameError on every call.
        if optimize_head:
            return [
                {"params": self.head.parameters(), "lr": 1.0 * base_lr},
            ]
        return []

## CoOp ##
class TextEncoder(nn.Module):
    """Encodes already-embedded prompts with the text tower of a CLIP model.

    The transformer, positional embedding, final layer norm and text
    projection are taken from ``clip_model`` as-is.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Return one projected feature vector per prompt.

        Args:
            prompts: embedded prompts, batch-first (N, L, D).
            tokenized_prompts: token ids of the same prompts; the position of
                the largest id in each row selects the feature to return.
        """
        positions = self.positional_embedding.type(self.dtype)
        hidden = (prompts + positions).permute(1, 0, 2)      # NLD -> LND
        hidden = self.transformer(hidden).permute(1, 0, 2)   # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        # The eot token carries the highest id in each sequence, so argmax
        # over the token ids locates it.
        eot_index = tokenized_prompts.argmax(dim=-1)
        batch_index = torch.arange(hidden.shape[0])
        return hidden[batch_index, eot_index] @ self.text_projection

class PromptLearner(nn.Module):
    """Learnable context vectors that are spliced into CLIP text prompts.

    For every class name a prompt of the form ``"X X ... X <name>."`` is
    tokenized and embedded once in ``__init__``. At forward time the
    embeddings of the placeholder tokens are replaced by the learnable
    ``ctx`` parameter.

    Args:
        clip_model: CLIP model providing ``dtype``, ``ln_final``,
            ``visual.input_resolution`` and ``token_embedding``.
        classnames: class names; underscores are replaced by spaces.
        device: device the tokenized prompts are moved to.
        use_csc: if true, a separate context is learned per class; otherwise
            one context is shared by all classes.
        n_ctx: number of context tokens. Defaults to 8 (the default used by
            ``TextPromptCLIPImageClassifier``) so that callers in this file
            that omit the argument keep working.
    """

    def __init__(self, clip_model, classnames, device, use_csc, n_ctx=8):
        super().__init__()
        n_cls = len(classnames)
        ctx_init = None  # e.g. "a photo of"; None selects random initialization
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = 224
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init

        else:
            # random initialization
            if use_csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = "end"

    def forward(self):
        """Assemble the embedded prompts for all classes.

        Returns:
            Tensor with one embedded prompt per class, built by concatenating
            the SOS prefix, the context vectors and the class/EOS suffix in
            the order selected by ``class_token_position``.

        Raises:
            ValueError: if ``class_token_position`` is not one of
                ``"end"``, ``"middle"`` or ``"front"``.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # A shared context is broadcast to every class.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(
                f"Unsupported class_token_position: {self.class_token_position!r}"
            )

        return prompts

class CustomCLIPText(nn.Module):
    """Text branch that turns learned prompts into CLIP text features.

    Combines a ``PromptLearner`` (which builds the embedded prompts) with a
    ``TextEncoder`` (which encodes them).
    """

    def __init__(self, clip_model, classnames, device, use_csc, n_ctx):
        super().__init__()
        learner = PromptLearner(clip_model, classnames, device, use_csc, n_ctx)
        self.prompt_learner = learner
        self.tokenized_prompts = learner.tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        self.dtype = clip_model.dtype

    def forward(self):
        """Return the encoded prompts, cast to the CLIP model's dtype."""
        embedded_prompts = self.prompt_learner()
        encoded = self.text_encoder(embedded_prompts, self.tokenized_prompts)
        return encoded.type(self.dtype)

class TextPromptCLIPImageClassifier(nn.Module):
    """CLIP classifier whose class text features come from learned prompts.

    Every parameter whose name does not contain ``"prompt_learner"`` is
    frozen at construction time.
    """

    def __init__(self, clip_model, classnames, feat_dim, device, use_csc=False, n_ctx=8):
        super(TextPromptCLIPImageClassifier, self).__init__()
        self.visual_backbone = clip_model.visual
        self.textual_backbone = CustomCLIPText(clip_model, classnames, device, use_csc, n_ctx)
        self.logit_scale = clip_model.logit_scale
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        for param_name, weight in self.named_parameters():
            if "prompt_learner" in param_name:
                continue
            weight.requires_grad_(False)

        # Report which parameters are still trainable.
        enabled = {
            param_name
            for param_name, weight in self.named_parameters()
            if weight.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

        # Explicitly make sure nothing in the image encoder is trainable.
        for weight in self.visual_backbone.parameters():
            weight.requires_grad_(False)
            assert not weight.requires_grad

    def forward(self, images, mode=None):
        """Compute scaled cosine-similarity logits between images and classes.

        ``mode`` is accepted but not used. In training mode the result is
        ``(logits, image_features, image_features)`` (the features appear
        twice); in eval mode it is ``(logits, image_features)``.
        """
        img = self.visual_backbone(images)
        txt = self.textual_backbone()

        # Unit-normalise both sides along the feature axis.
        img = img / img.norm(dim=-1, keepdim=True)
        txt = txt / txt.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits = scale * img @ txt.t()

        if not self.training:
            return logits, img
        return logits, img, img

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        Only the prompt learner's parameters are returned; ``optimize_head``
        is not used.
        """
        prompt_group = {
            "params": self.textual_backbone.prompt_learner.parameters(),
            "lr": 1.0 * base_lr,
        }
        return [prompt_group]
## CoOp END ##

## CoCoOp ##
class CoPromptLearner(nn.Module):
    """Prompt learner whose context is shifted by an image-conditioned bias.

    A shared context ``ctx`` is learned together with ``meta_net``, a small
    two-layer network that maps image features to a bias added to every
    context token.

    Args:
        clip_model: CLIP model providing ``dtype``, ``ln_final``,
            ``visual.output_dim``, ``visual.input_resolution`` and
            ``token_embedding``.
        classnames: class names; underscores are replaced by spaces.
        device: device the tokenized prompts are moved to.
    """

    def __init__(self, clip_model, classnames, device):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = 8
        ctx_init = None
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        vis_dim = clip_model.visual.output_dim
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = 224
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            # Move the tokens to `device` like PromptLearner does; without it
            # the embedding lookup would mix devices when the model is not on
            # the CPU.
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)

        # Bottleneck network producing one bias vector per image.
        self.meta_net = nn.Sequential(OrderedDict([
            ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
            ("relu", nn.ReLU(inplace=True)),
            ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
        ]))

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)  # (n_cls, n_tkn)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Concatenate prefix, context and suffix along the token axis.

        Args:
            ctx: context tokens, shape (dim0, n_ctx, ctx_dim), where dim0 is
                either batch_size (during training) or n_cls (during testing).
            prefix: the sos token, shape (n_cls, 1, ctx_dim).
            suffix: remaining tokens, shape (n_cls, *, ctx_dim).
            label: optional index; when given, ``prefix`` and ``suffix`` are
                indexed with it before concatenation.
        """
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        return torch.cat(
            [
                prefix,  # (dim0, 1, dim)
                ctx,     # (dim0, n_ctx, dim)
                suffix,  # (dim0, *, dim)
            ],
            dim=1,
        )

    def forward(self, im_features):
        """Build one set of class prompts per image.

        Args:
            im_features: image features, one row per image.

        Returns:
            Stacked prompts with the image index as the leading axis,
            followed by the per-class prompt tensor for that image.
        """
        prefix = self.token_prefix
        suffix = self.token_suffix
        bias = self.meta_net(im_features).unsqueeze(1)  # (batch, 1, ctx_dim)
        ctx_shifted = self.ctx.unsqueeze(0) + bias      # (batch, n_ctx, ctx_dim)

        # Use instance-conditioned context tokens for all classes
        prompts = [
            self.construct_prompts(
                ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1),
                prefix,
                suffix,
            )  # (n_cls, n_tkn, ctx_dim)
            for ctx_shifted_i in ctx_shifted
        ]
        return torch.stack(prompts)

class CoCustomCLIPText(nn.Module):
    """Text branch for image-conditioned prompts.

    For every image a separate set of class prompts is produced by
    ``CoPromptLearner`` and encoded, and the image is scored against its own
    set of class text features.
    """

    def __init__(self, clip_model, classnames,  device):
        super().__init__()
        self.prompt_learner = CoPromptLearner(clip_model, classnames, device)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image_features, label=None):
        """Score each image against its instance-conditioned class prompts.

        Args:
            image_features: un-normalised image features, one row per image.
            label: accepted for interface compatibility; not used.

        Returns:
            ``(logits, image_features)`` where ``image_features`` is the
            normalised version of the input.
        """
        tokenized_prompts = self.tokenized_prompts
        # The scale is exponentiated before use, exactly as the other
        # classifiers in this file do with the same `clip_model.logit_scale`;
        # the original multiplied by the raw value.
        logit_scale = self.logit_scale.exp()

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        prompts = self.prompt_learner(image_features)

        logits = []
        for pts_i, imf_i in zip(prompts, image_features):
            text_features = self.text_encoder(pts_i, tokenized_prompts)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            logits.append(logit_scale * imf_i @ text_features.t())

        return torch.stack(logits), image_features

class CoTextPromptCLIPImageClassifier(nn.Module):
    """CLIP classifier using image-conditioned learned text prompts.

    Every parameter whose name does not contain ``"prompt_learner"`` is
    frozen at construction time.
    """

    def __init__(self, clip_model, classnames, feat_dim, device):
        super(CoTextPromptCLIPImageClassifier, self).__init__()
        self.visual_backbone = clip_model.visual
        self.textual_backbone = CoCustomCLIPText(clip_model, classnames, device)
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        for param_name, weight in self.named_parameters():
            if "prompt_learner" in param_name:
                continue
            weight.requires_grad_(False)

        # Report which parameters are still trainable.
        enabled = {
            param_name
            for param_name, weight in self.named_parameters()
            if weight.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

        # Explicitly make sure nothing in the image encoder is trainable.
        for weight in self.visual_backbone.parameters():
            weight.requires_grad_(False)
            assert not weight.requires_grad

    def forward(self, images):
        """Return ``(logits, image_features)`` when training, else ``logits``."""
        raw_features = self.visual_backbone(images)
        logits, image_features = self.textual_backbone(raw_features)

        if not self.training:
            return logits
        return logits, image_features

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        Only the prompt learner's parameters are returned; ``optimize_head``
        is not used.
        """
        prompt_group = {
            "params": self.textual_backbone.prompt_learner.parameters(),
            "lr": 1.0 * base_lr,
        }
        return [prompt_group]
## CoCoOp END ##


class VisualPromptCLIPImageClassifier(nn.Module):
    """CLIP classifier with a learnable visual prompt and fixed text features.

    The class text features are computed once from the template
    ``"a photo of a {class}"`` and stored; only the ``prompter`` inside the
    visual backbone is left trainable.

    Args:
        clip_model: CLIP model providing ``logit_scale``, ``encode_text`` and
            whatever ``CustomCLIPVisual`` requires.
        classnames: class names inserted into the text template.
        feat_dim: value reported by ``features_dim``.
        device: device for the tokenized text and the visual prompter.
    """

    def __init__(self, clip_model, classnames, feat_dim, device):
        super(VisualPromptCLIPImageClassifier, self).__init__()
        self.visual_backbone = CustomCLIPVisual(clip_model, device)
        self._features_dim = feat_dim
        self.logit_scale = clip_model.logit_scale
        text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classnames]).to(device)
        # The text features are constants for this model. Computing them
        # without autograd avoids storing a tensor that still references the
        # text encoder's graph, which would be reused by every backward pass.
        with torch.no_grad():
            text_features = clip_model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        self.text_features = text_features

        print("Turning off gradients in both the image and the text encoder")
        # Only parameters under `visual_backbone` are visited here.
        for name, param in self.visual_backbone.named_parameters():
            if "prompter" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {
            name
            for name, param in self.visual_backbone.named_parameters()
            if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, images):
        """Return ``(logits, image_features)`` when training, else ``logits``."""
        image_features = self.visual_backbone(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # The scale is exponentiated before use, matching the other
        # classifiers in this file that share `clip_model.logit_scale`; the
        # original multiplied by the raw value.
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        if self.training:
            return logits, image_features
        return logits

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        Only the visual prompter's parameters are returned; ``optimize_head``
        is not used.
        """
        params = [
            {"params": self.visual_backbone.prompter.parameters(), "lr": 1.0 * base_lr},
        ]

        return params

class VisualPromptTuningCLIPImageClassifier(nn.Module):
    """CLIP classifier with a prompt-tuned vision transformer and fixed text features.

    The class text features are computed once from the template
    ``"a photo of a {class}"`` and stored. Parameters whose name contains
    ``"prompt"`` are left untouched by the first freeze pass; afterwards
    every parameter of ``clip_model`` itself is frozen.

    Args:
        clip_model: CLIP model providing ``logit_scale`` and ``encode_text``.
        classnames: class names inserted into the text template.
        feat_dim: value reported by ``features_dim``.
        device: device the tokenized text is moved to.
        clip_model_type: forwarded to ``PromptVisionTransformer``.
        DeepPrompt: forwarded to ``PromptVisionTransformer``.
    """

    def __init__(self, clip_model, classnames, feat_dim, device, clip_model_type="CLIPViT-B/16", DeepPrompt=False):
        super(VisualPromptTuningCLIPImageClassifier, self).__init__()
        self.visual_backbone = PromptVisionTransformer(clip_model, clip_model_type, DeepPrompt)
        self._features_dim = feat_dim
        self.logit_scale = clip_model.logit_scale
        text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classnames]).to(device)
        # The text features are constants for this model. Computing them
        # without autograd avoids storing a tensor that still references the
        # text encoder's graph, which would be reused by every backward pass.
        with torch.no_grad():
            text_features = clip_model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        self.text_features = text_features

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.named_parameters():
            # A name without "prompt" cannot contain "prompt_learner" either,
            # so a single test is sufficient.
            if "prompt" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {
            name
            for name, param in self.named_parameters()
            if param.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

        for name, param in clip_model.named_parameters():
            param.requires_grad_(False)
            assert not param.requires_grad

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, images):
        """Return ``(logits, image_features)`` when training, else ``logits``."""
        image_features = self.visual_backbone(images)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # The scale is exponentiated before use, matching the other
        # classifiers in this file that share `clip_model.logit_scale`; the
        # original multiplied by the raw value.
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ self.text_features.t()
        if self.training:
            return logits, image_features
        return logits

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        All parameters of the visual backbone are returned as one group;
        ``optimize_head`` is not used.
        """
        params = [
            {"params": self.visual_backbone.parameters(), "lr": 1.0 * base_lr},
        ]

        return params

## CoOp ##
class CustomCLIPText_(nn.Module):
    """Learned-prompt text branch that directly scores given image features.

    Args:
        clip_model: CLIP model providing ``logit_scale``, ``dtype`` and the
            pieces needed by ``PromptLearner`` / ``TextEncoder``.
        classnames: class names for which prompts are learned.
        device: device the tokenized prompts are moved to.
        use_csc: whether class-specific contexts are learned.
        n_ctx: number of context tokens forwarded to ``PromptLearner``.
            Defaults to 8. The original call omitted this argument although
            ``PromptLearner`` requires it, so construction failed with a
            TypeError.
    """

    def __init__(self, clip_model, classnames, device, use_csc, n_ctx=8):
        super().__init__()
        self.prompt_learner = PromptLearner(clip_model, classnames, device, use_csc, n_ctx)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype

    def forward(self, image_features):
        """Return scaled cosine-similarity logits for ``image_features``."""
        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)
        text_features = text_features.type(self.dtype)

        # Unit-normalise both sides along the feature axis.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

class TextPromptCLIPDiscriminator(nn.Module):
    """Discriminator that scores features with a learned-prompt text branch.

    Wraps ``CustomCLIPText_`` and freezes every wrapped parameter whose name
    does not contain ``"prompt_learner"``.
    """

    def __init__(self, clip_model, classnames, device, use_csc=True):
        super(TextPromptCLIPDiscriminator, self).__init__()

        self.model = CustomCLIPText_(clip_model, classnames, device, use_csc)

        print("Turning off gradients in both the image and the text encoder")
        for param_name, weight in self.model.named_parameters():
            if "prompt_learner" in param_name:
                continue
            weight.requires_grad_(False)

        # Report which parameters are still trainable.
        enabled = {
            param_name
            for param_name, weight in self.model.named_parameters()
            if weight.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

    def forward(self, features):
        """Return the wrapped model's output for ``features``."""
        return self.model(features)

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        Only the prompt learner's parameters are returned.
        """
        prompt_group = {
            "params": self.model.prompt_learner.parameters(),
            "lr": 1 * base_lr,
        }
        return [prompt_group]
## CoOp END ##

## ADAPT MODULES
## vpCLIP ##
class VisualPromptCLIPWrapper(nn.Module):
    """Pairs a visual-prompted CLIP image encoder with an existing text branch.

    The text branch is taken from ``target_model.textual_backbone``. Every
    parameter whose name does not contain ``"prompt"`` is frozen.
    """

    def __init__(self, clip_model, target_model, feat_dim, device=None):
        super(VisualPromptCLIPWrapper, self).__init__()

        self.visual_backbone = CustomCLIPVisual(clip_model, device)
        self.textual_backbone = target_model.textual_backbone
        self.logit_scale = clip_model.logit_scale
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        for param_name, weight in self.named_parameters():
            # Names containing "prompt" (which includes "prompt_learner")
            # stay as they are.
            if "prompt" in param_name:
                continue
            weight.requires_grad_(False)

        # Report which parameters are still trainable.
        enabled = {
            param_name
            for param_name, weight in self.named_parameters()
            if weight.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, images):
        """Return ``(logits, image_features)`` when training, else ``logits``."""
        img = self.visual_backbone(images)
        txt = self.textual_backbone()

        # Unit-normalise both sides along the feature axis.
        img = img / img.norm(dim=-1, keepdim=True)
        txt = txt / txt.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits = scale * img @ txt.t()

        if not self.training:
            return logits
        return logits, img

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        Only the visual prompter's parameters are returned; ``optimize_head``
        is not used.
        """
        prompter_group = {
            "params": self.visual_backbone.prompter.parameters(),
            "lr": 1.0 * base_lr,
        }
        return [prompter_group]
## vpCLIP END ##

## vptCLIP ##
class VisualPromptTuningCLIPWrapper(nn.Module):
    """Pairs a prompt-tuned vision transformer with an existing text branch.

    The text branch is taken from ``target_model.textual_backbone``. Every
    parameter whose name does not contain ``"prompt"`` is frozen. ``device``
    is accepted but not used.
    """

    def __init__(self, clip_model, target_model, feat_dim, clip_model_type="CLIPViT-B/16", DeepPrompt=False, device=None):
        super(VisualPromptTuningCLIPWrapper, self).__init__()

        self.visual_backbone = PromptVisionTransformer(clip_model, clip_model_type, DeepPrompt)
        self.textual_backbone = target_model.textual_backbone
        self.logit_scale = clip_model.logit_scale
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        for param_name, weight in self.named_parameters():
            # Names containing "prompt" (which includes "prompt_learner")
            # stay as they are.
            if "prompt" in param_name:
                continue
            weight.requires_grad_(False)

        # Report which parameters are still trainable.
        enabled = {
            param_name
            for param_name, weight in self.named_parameters()
            if weight.requires_grad
        }
        print(f"Parameters to be updated: {enabled}")

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, images):
        """Return ``(logits, image_features)`` when training, else ``logits``."""
        img = self.visual_backbone(images)
        txt = self.textual_backbone()

        # Unit-normalise both sides along the feature axis.
        img = img / img.norm(dim=-1, keepdim=True)
        txt = txt / txt.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits = scale * img @ txt.t()

        if not self.training:
            return logits
        return logits, img

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        All parameters of the visual backbone are returned as one group;
        ``optimize_head`` is not used.
        """
        backbone_group = {
            "params": self.visual_backbone.parameters(),
            "lr": 1.0 * base_lr,
        }
        return [backbone_group]
## vptCLIP END ##

# Prompt Learning Models for policy learning

class Prompter(nn.Module):
    """Learnable border ("padding") prompt added to input images.

    Four learnable strips (top, bottom, left, right) surround a zero-filled
    centre to form an ``image_size`` x ``image_size`` prompt. The input is
    passed through a bias-free 1x1 convolution and the prompt is added.

    Args:
        device: stored on the instance for callers that read it; the prompt
            itself is built on the device of the learnable strips.
        pad_size: width of each border strip in pixels.
        image_size: height and width of the images the prompt is added to.
    """

    def __init__(self, device, pad_size=30, image_size=224):
        super().__init__()
        self.device = device
        self.base_size = image_size - pad_size*2
        self.pad_up = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
        self.pad_down = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
        self.pad_left = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size]))
        self.pad_right = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size]))
        self.conv2d = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return ``conv2d(x)`` plus the border prompt, for every image in ``x``."""
        # Build the zero centre on the same device/dtype as the learnable
        # strips so the concatenation works wherever the module lives, even
        # when `self.device` is None or stale after a `.to(...)` call.
        base = torch.zeros(
            1, 3, self.base_size, self.base_size,
            device=self.pad_up.device, dtype=self.pad_up.dtype,
        )
        prompt = torch.cat([self.pad_left, base, self.pad_right], dim=3)
        prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=2)
        # The prompt has a leading dimension of 1 and broadcasts over the
        # batch, so no per-sample copy is needed.
        return self.conv2d(x) + prompt

class CustomCLIPVisual(nn.Module):
    """CLIP image encoder preceded by a learnable ``Prompter``."""

    def __init__(self, clip_model, device):
        super().__init__()
        self.dtype = clip_model.dtype
        self.prompter = Prompter(device)
        self.visual_encoder = clip_model.visual

    def forward(self, images):
        """Prompt the images, encode them and cast to the CLIP model's dtype."""
        encoded = self.visual_encoder(self.prompter(images))
        return encoded.type(self.dtype)

class CustomCLIPTextual(nn.Module):
    """Text branch that turns learned prompts into CLIP text features.

    Args:
        clip_model: CLIP model providing ``dtype`` and the pieces needed by
            ``PromptLearner`` / ``TextEncoder``.
        classnames: class names for which prompts are learned.
        device: device the tokenized prompts are moved to.
        use_csc: whether class-specific contexts are learned.
        n_ctx: number of context tokens forwarded to ``PromptLearner``.
            Defaults to 8 so that ``DualPromptCLIPImageClassifier``, which
            constructs this class without ``n_ctx``, no longer fails with a
            TypeError.
    """

    def __init__(self, clip_model, classnames, device, use_csc, n_ctx=8):
        super().__init__()
        self.prompt_learner = PromptLearner(clip_model, classnames, device, use_csc, n_ctx)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        self.dtype = clip_model.dtype

    def forward(self):
        """Return the encoded prompts, cast to the CLIP model's dtype."""
        prompts = self.prompt_learner()
        text_features = self.text_encoder(prompts, self.tokenized_prompts)
        return text_features.type(self.dtype)

class DualPromptCLIPImageClassifier(nn.Module):
    """CLIP classifier tuned through an image prompter and learned text prompts.

    Images go through ``CustomCLIPVisual`` and the class prompts through
    ``CustomCLIPTextual``; logits are the scaled similarities between the
    L2-normalised image and text features. Only parameters whose name contains
    ``"prompter"`` or ``"prompt_learner"`` are left trainable.
    """

    def __init__(self, clip_model, classnames, feat_dim, device, use_csc=False, n_ctx=8):
        """Build both prompted branches and freeze everything else.

        Args:
            clip_model: CLIP model providing ``logit_scale`` and both encoders.
            classnames: class names handed to the text prompt learner.
            feat_dim: value reported by ``features_dim``.
            device: device handed to the visual prompter and the prompt learner.
            use_csc: forwarded unchanged to the text prompt learner.
            n_ctx: number of context tokens forwarded to the text prompt learner.
        """
        super(DualPromptCLIPImageClassifier, self).__init__()
        self.logit_scale = clip_model.logit_scale
        self.visual_backbone = CustomCLIPVisual(clip_model, device)
        # CustomCLIPTextual requires n_ctx; leaving it out raised a TypeError.
        self.textual_backbone = CustomCLIPTextual(clip_model, classnames, device, use_csc, n_ctx)
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.named_parameters():
            if "prompter" not in name and "prompt_learner" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {name for name, param in self.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # Explicitly freeze both pretrained encoders regardless of their names.
        for encoder in (self.visual_backbone.visual_encoder, self.textual_backbone.text_encoder):
            for param in encoder.parameters():
                param.requires_grad_(False)

    def forward(self, images, mode=None):
        """Return class logits for ``images``.

        Args:
            images: batch passed to the visual backbone.
            mode: accepted for call-site compatibility; not used.

        Returns:
            ``(logits, image_features)`` in training mode, otherwise ``logits``.
            ``image_features`` are the L2-normalised image features.
        """
        image_features = self.visual_backbone(images)
        text_features = self.textual_backbone()

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        if self.training:
            return logits, image_features
        return logits

    @property
    def features_dim(self) -> int:
        """The feature dimension supplied at construction (``feat_dim``)."""
        return self._features_dim

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        ``optimize_head`` is accepted for interface compatibility and ignored;
        both prompt modules are optimised at ``base_lr``.
        """
        params = [
            {"params": self.visual_backbone.prompter.parameters(), "lr": 1.0 * base_lr},
            {"params": self.textual_backbone.prompt_learner.parameters(), "lr": 1.0 * base_lr},
        ]

        return params


class PromptVisionTransformer(nn.Module):
    """CLIP ViT image encoder with learnable visual prompt tokens.

    The layers of ``clip_model.visual`` are reused as-is. ``n_vtk`` learnable
    prompt tokens are projected and inserted between the class token and the
    patch tokens. With ``DeepPrompt`` enabled, the prompt tokens are replaced
    by a separate learnable set before every transformer block after the first.
    """

    # Patch size per supported backbone; checked in this order.
    _PATCH_SIZES = (("ViT-B/32", (32, 32)), ("ViT-B/16", (16, 16)))

    def __init__(self, clip_model, clip_model_type, DeepPrompt, n_vtk):
        """
        Args:
            clip_model: CLIP model whose ``visual`` submodule is wrapped.
            clip_model_type: backbone name; must contain ``"ViT-B/32"`` or
                ``"ViT-B/16"``.
            DeepPrompt: whether to use per-layer (deep) prompts.
            n_vtk: number of visual prompt tokens.

        Raises:
            ValueError: if ``clip_model_type`` names an unsupported backbone.
        """
        super().__init__()

        # Fail early with a clear message instead of hitting unbound locals.
        for tag, size in self._PATCH_SIZES:
            if tag in clip_model_type:
                patch_size = size
                break
        else:
            raise ValueError(
                f"Unsupported clip_model_type {clip_model_type!r}: "
                "expected a name containing 'ViT-B/32' or 'ViT-B/16'"
            )

        visual = clip_model.visual
        self.input_resolution = visual.input_resolution
        self.output_dim = visual.output_dim
        self.conv1 = visual.conv1

        self.class_embedding = visual.class_embedding
        self.positional_embedding = visual.positional_embedding
        self.ln_pre = visual.ln_pre

        self.transformer = visual.transformer

        self.ln_post = visual.ln_post
        self.proj = visual.proj
        self.Deep = DeepPrompt

        # prompt config
        _, prompt_dim = self.positional_embedding.shape
        self.num_tokens = n_vtk
        hidden_size = 768
        self.prompt_dropout = Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out')

        # Bound of the uniform init: sqrt(6 / (3 * patch_area + prompt_dim)).
        val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + prompt_dim))  # noqa

        self.prompt_embeddings = nn.Parameter(torch.zeros(
            1, self.num_tokens, prompt_dim))
        # xavier_uniform initialization
        nn.init.uniform_(self.prompt_embeddings.data, -val, val)

        if self.Deep:  # Deep prompt version noqa
            # One prompt set per block after the first (12 blocks assumed).
            total_d_layer = 12 - 1
            self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
                total_d_layer, self.num_tokens, prompt_dim))
            # xavier_uniform initialization
            nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

    def forward(self, x):
        """Encode images with prompt tokens inserted after the class token.

        Args:
            x: image batch accepted by ``conv1``.

        Returns:
            The class-token feature after ``ln_post``, multiplied by ``proj``
            when ``proj`` is not ``None``.
        """
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # incorporate_prompt: [cls, prompts, patches]
        B = x.size(0)
        prompts = self.prompt_dropout(
            self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1))
        x = torch.cat((x[:, :1, :], prompts, x[:, 1:, :]), dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        if self.Deep:  # Deep prompt version
            n_deep = self.deep_prompt_embeddings.shape[0]
            hidden_states = x
            for i in range(self.transformer.layers):
                # Block 0 sees the shallow prompts; later blocks get their
                # own prompt set swapped in while one is available.
                if 0 < i <= n_deep:
                    deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings[i - 1]).expand(B, -1, -1))
                    deep_prompt_emb = deep_prompt_emb.permute(1, 0, 2)  # NLD -> LND
                    hidden_states = torch.cat((
                        hidden_states[:1, :, :],
                        deep_prompt_emb,
                        hidden_states[(1 + self.num_tokens):, :, :]
                    ), dim=0)
                hidden_states = self.transformer.resblocks[i](hidden_states)
            x = hidden_states
        else:
            x = self.transformer(x)

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

class DualPromptViTCLIPImageClassifier(nn.Module):
    """CLIP classifier with visual prompt tokens and learned text prompts.

    Images are encoded by ``PromptVisionTransformer`` and class prompts by
    ``CustomCLIPTextual``. Parameters whose name contains ``"prompt"`` stay
    trainable; everything else is frozen.
    """

    def __init__(self, clip_model, classnames, feat_dim, device, clip_model_type="CLIPViT-B/16", DeepPrompt=False, n_vtk=5, use_csc=False, n_ctx=8):
        super(DualPromptViTCLIPImageClassifier, self).__init__()
        self.logit_scale = clip_model.logit_scale
        self.visual_backbone = PromptVisionTransformer(clip_model, clip_model_type, DeepPrompt, n_vtk)
        self.textual_backbone = CustomCLIPTextual(clip_model, classnames, device, use_csc, n_ctx)
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        for param_name, weight in self.named_parameters():
            # "prompt_learner" names also contain "prompt", so one test suffices.
            if "prompt" in param_name:
                continue
            weight.requires_grad_(False)
        # Report what is still trainable.
        enabled = {param_name for param_name, weight in self.named_parameters() if weight.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # The pretrained CLIP weights and the text encoder must stay frozen.
        for frozen_source in (clip_model, self.textual_backbone.text_encoder):
            for _, weight in frozen_source.named_parameters():
                weight.requires_grad_(False)
                assert not weight.requires_grad

    def forward(self, images):
        """Return logits (plus normalised image features while training)."""
        img = self.visual_backbone(images)
        txt = self.textual_backbone()

        img = img / img.norm(dim=-1, keepdim=True)
        txt = txt / txt.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits = scale * img @ txt.t()

        return (logits, img) if self.training else logits

    @property
    def features_dim(self) -> int:
        """Feature dimension supplied at construction (``feat_dim``)."""
        return self._features_dim

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """Optimizer parameter groups: the visual backbone and the text prompt
        learner, both at ``base_lr``. ``optimize_head`` is not used.
        """
        lr = 1.0 * base_lr
        return [
            {"params": self.visual_backbone.parameters(), "lr": lr},
            {"params": self.textual_backbone.prompt_learner.parameters(), "lr": lr},
        ]

class TextPromptLearner(nn.Module):
    """Learnable context vectors placed around fixed class-name token embeddings.

    The context (``self.ctx``) is either shared by all classes or, with
    ``use_csc``, specific to each class. ``forward`` assembles the full prompt
    embeddings according to ``self.class_token_position``.
    """

    def __init__(self, clip_model, classnames, device, use_csc, n_ctx):
        super().__init__()
        n_cls = len(classnames)
        ctx_init = None  # "a photo of"
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = 224
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # Initialise the context from the embeddings of the given words.
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            init_tokens = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                init_embedding = clip_model.token_embedding(init_tokens).type(dtype)
            ctx_vectors = init_embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # Random initialisation, per class or shared.
            if use_csc:
                print("Initializing class-specific contexts")
                ctx_shape = (n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_shape = (n_ctx, ctx_dim)
            ctx_vectors = torch.empty(*ctx_shape, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            # One placeholder word per context token.
            prompt_prefix = " ".join("X" for _ in range(n_ctx))

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [f"{prompt_prefix} {name}." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.class_token_position = "end"

    def forward(self):
        """Assemble prompt embeddings for every class.

        Returns:
            Tensor with one row per class, concatenated along dim 1 in the
            order given by ``class_token_position``:
            "end" (prefix, ctx, suffix), "middle" (prefix, first half of ctx,
            class tokens, second half of ctx, rest) or "front" (prefix, class
            tokens, ctx, rest).

        Raises:
            ValueError: for any other ``class_token_position``.
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # Shared context: present it once per class.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix
        position = self.class_token_position

        if position == "end":
            return torch.cat([prefix, ctx, suffix], dim=1)

        if position == "middle":
            split_at = self.n_ctx // 2
            per_class = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                row = slice(i, i + 1)
                pieces = [
                    prefix[row],                # (1, 1, dim)
                    ctx[row, :split_at],        # (1, n_ctx//2, dim)
                    suffix[row, :name_len],     # (1, name_len, dim)
                    ctx[row, split_at:],        # (1, n_ctx//2, dim)
                    suffix[row, name_len:],     # (1, *, dim)
                ]
                per_class.append(torch.cat(pieces, dim=1))
            return torch.cat(per_class, dim=0)

        if position == "front":
            per_class = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                row = slice(i, i + 1)
                pieces = [
                    prefix[row],                # (1, 1, dim)
                    suffix[row, :name_len],     # (1, name_len, dim)
                    ctx[row],                   # (1, n_ctx, dim)
                    suffix[row, name_len:],     # (1, *, dim)
                ]
                per_class.append(torch.cat(pieces, dim=1))
            return torch.cat(per_class, dim=0)

        raise ValueError

class PromptTextual(nn.Module):
    """Text branch that encodes learned prompts, optionally for a subset of classes."""

    def __init__(self, clip_model, classnames, device, use_csc, n_ctx):
        super().__init__()
        self.prompt_learner = TextPromptLearner(clip_model, classnames, device, use_csc, n_ctx)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        self.dtype = clip_model.dtype

    def forward(self, indices, mode=None):
        """Encode the learned prompts.

        Args:
            indices: index applied to the prompts and their token ids when
                ``mode == "scene"``; ignored when ``mode == "name"``.
            mode: ``"scene"`` to encode only the prompts selected by
                ``indices``, ``"name"`` to encode all prompts.

        Returns:
            Text features cast to ``self.dtype``.

        Raises:
            ValueError: if ``mode`` is neither ``"scene"`` nor ``"name"``
                (this used to fail later with an unbound-variable error).
        """
        if mode not in ("scene", "name"):
            raise ValueError(f"mode must be 'scene' or 'name', got {mode!r}")

        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        if mode == "scene":
            prompts = prompts[indices]
            tokenized_prompts = tokenized_prompts[indices]

        text_features = self.text_encoder(prompts, tokenized_prompts)
        return text_features.type(self.dtype)

class DualPromptViTCLIP(nn.Module):
    """CLIP with visual prompt tokens and a mode-aware learned text branch.

    Images are encoded by ``PromptVisionTransformer`` and prompts by
    ``PromptTextual``. Parameters whose name contains ``"prompt"`` stay
    trainable; everything else is frozen.
    """

    def __init__(self, clip_model, classnames, feat_dim, device, clip_model_type="CLIPViT-B/16", DeepPrompt=False, n_vtk=5, use_csc=False, n_ctx=8):
        super(DualPromptViTCLIP, self).__init__()
        self.logit_scale = clip_model.logit_scale
        self.visual_backbone = PromptVisionTransformer(clip_model, clip_model_type, DeepPrompt, n_vtk)
        self.textual_backbone = PromptTextual(clip_model, classnames, device, use_csc, n_ctx)
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        for param_name, weight in self.named_parameters():
            # "prompt_learner" names also contain "prompt", so one test suffices.
            if "prompt" in param_name:
                continue
            weight.requires_grad_(False)
        # Report what is still trainable.
        enabled = {param_name for param_name, weight in self.named_parameters() if weight.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # The pretrained CLIP weights and the text encoder must stay frozen.
        for frozen_source in (clip_model, self.textual_backbone.text_encoder):
            for _, weight in frozen_source.named_parameters():
                weight.requires_grad_(False)
                assert not weight.requires_grad

    def forward(self, images, indices=None, mode=None):
        """Compute image-to-text logits.

        ``indices`` and ``mode`` are passed straight to the text branch.

        Returns:
            Training: ``(logits_per_image, logits_per_text, image_features)``.
            Otherwise: ``(logits_per_image, image_features)``.
        """
        img = self.visual_backbone(images)
        txt = self.textual_backbone(indices, mode)

        txt = txt / txt.norm(dim=-1, keepdim=True)
        img = img / img.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits_per_image = scale * img @ txt.t()
        logits_per_text = logits_per_image.t()

        if not self.training:
            return logits_per_image, img
        return logits_per_image, logits_per_text, img

    @property
    def features_dim(self) -> int:
        """Feature dimension supplied at construction (``feat_dim``)."""
        return self._features_dim

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """Optimizer parameter groups: the visual backbone and the text prompt
        learner, both at ``base_lr``. ``optimize_head`` is not used.
        """
        lr = 1.0 * base_lr
        return [
            {"params": self.visual_backbone.parameters(), "lr": lr},
            {"params": self.textual_backbone.prompt_learner.parameters(), "lr": lr},
        ]


class DualPromptCLIP(nn.Module):
    """CLIP with an image prompter and a mode-aware learned text branch.

    Images are encoded by ``CustomCLIPVisual`` and prompts by
    ``PromptTextual``. Parameters whose name contains ``"prompter"`` or
    ``"prompt_learner"`` stay trainable; everything else is frozen.
    ``clip_model_type`` is accepted but not used.
    """

    def __init__(self, clip_model, classnames, feat_dim, device, clip_model_type="CLIPViT-B/16", use_csc=False, n_ctx=8):
        super(DualPromptCLIP, self).__init__()
        self.logit_scale = clip_model.logit_scale
        self.visual_backbone = CustomCLIPVisual(clip_model, device)
        self.textual_backbone = PromptTextual(clip_model, classnames, device, use_csc, n_ctx)
        self._features_dim = feat_dim

        print("Turning off gradients in both the image and the text encoder")
        trainable_tags = ("prompter", "prompt_learner")
        for param_name, weight in self.named_parameters():
            if not any(tag in param_name for tag in trainable_tags):
                weight.requires_grad_(False)
        # Report what is still trainable.
        enabled = {param_name for param_name, weight in self.named_parameters() if weight.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # The pretrained CLIP weights and the text encoder must stay frozen.
        for frozen_source in (clip_model, self.textual_backbone.text_encoder):
            for _, weight in frozen_source.named_parameters():
                weight.requires_grad_(False)
                assert not weight.requires_grad

    def forward(self, images, indices=None, mode=None):
        """Compute image-to-text logits.

        ``indices`` and ``mode`` are passed straight to the text branch.

        Returns:
            Training: ``(logits_per_image, logits_per_text, image_features)``.
            Otherwise: ``(logits_per_image, image_features)``.
        """
        img = self.visual_backbone(images)
        txt = self.textual_backbone(indices, mode)

        txt = txt / txt.norm(dim=-1, keepdim=True)
        img = img / img.norm(dim=-1, keepdim=True)

        scale = self.logit_scale.exp()
        logits_per_image = scale * img @ txt.t()
        logits_per_text = logits_per_image.t()

        if not self.training:
            return logits_per_image, img
        return logits_per_image, logits_per_text, img

    @property
    def features_dim(self) -> int:
        """Feature dimension supplied at construction (``feat_dim``)."""
        return self._features_dim

    def freeze_bn(self):
        """Switch all BatchNorm1d/BatchNorm2d submodules to eval mode."""
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                module.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """Optimizer parameter groups: the visual backbone and the text prompt
        learner, both at ``base_lr``. ``optimize_head`` is not used.
        """
        lr = 1.0 * base_lr
        return [
            {"params": self.visual_backbone.parameters(), "lr": lr},
            {"params": self.textual_backbone.prompt_learner.parameters(), "lr": lr},
        ]


class MTTextPromptLearner(nn.Module):
    """Learnable text contexts for two label sets: scenes and class names.

    ``self.ctx`` holds the context used with the scene labels and
    ``self.ctx_n`` the one used with the class names. ``forward(mode)``
    assembles the prompt embeddings for the selected label set.
    """

    def __init__(self, clip_model, class_scenes, class_names, device, use_csc, n_ctx):
        """
        Args:
            clip_model: CLIP model providing ``dtype``, ``ln_final``,
                ``visual.input_resolution`` and ``token_embedding``.
            class_scenes: scene labels (underscores are replaced by spaces).
            class_names: class-name labels (underscores are replaced by spaces).
            device: device the tokenized prompts are moved to.
            use_csc: ``(scene_flag, name_flag)`` selecting a per-label context
                for each set; a single bool applies to both sets.
            n_ctx: number of context tokens.
        """
        super().__init__()
        n_cls = len(class_scenes)
        n_cls_n = len(class_names)
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = 224
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        # A plain bool used to crash on ``use_csc[0]``; treat it as "both sets".
        if isinstance(use_csc, bool):
            scene_csc = name_csc = use_csc
        else:
            scene_csc, name_csc = use_csc[0], use_csc[1]

        # scene context first, then name context: keeps the parameter order
        # and the random-number consumption of the initialisation unchanged.
        self.ctx = nn.Parameter(
            self._init_context(n_cls, n_ctx, ctx_dim, dtype, scene_csc))  # to be optimized
        name_ctx = self._init_context(n_cls_n, n_ctx, ctx_dim, dtype, name_csc)
        prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx_n = nn.Parameter(name_ctx)  # to be optimized

        scene_lens, scene_tokenized_prompts, scene_embedding = self._tokenize_and_embed(
            clip_model, class_scenes, prompt_prefix, device, dtype)
        name_lens, name_tokenized_prompts, name_embedding = self._tokenize_and_embed(
            clip_model, class_names, prompt_prefix, device, dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("scene_token_prefix", scene_embedding[:, :1, :])  # SOS
        self.register_buffer("scene_token_suffix", scene_embedding[:, 1 + n_ctx :, :])  # CLS, EOS
        self.register_buffer("name_token_prefix", name_embedding[:, :1, :])  # SOS
        self.register_buffer("name_token_suffix", name_embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_cls_n = n_cls_n
        self.n_ctx = n_ctx
        self.scene_tokenized_prompts = scene_tokenized_prompts  # torch.Tensor
        self.name_tokenized_prompts = name_tokenized_prompts  # torch.Tensor
        self.scene_lens = scene_lens
        self.name_lens = name_lens
        self.class_token_position = "end"

    @staticmethod
    def _init_context(n_labels, n_ctx, ctx_dim, dtype, class_specific):
        """Return a normal(std=0.02) context, per label or shared."""
        if class_specific:
            print("Initializing class-specific contexts")
            ctx_vectors = torch.empty(n_labels, n_ctx, ctx_dim, dtype=dtype)
        else:
            print("Initializing a generic context")
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
        nn.init.normal_(ctx_vectors, std=0.02)
        return ctx_vectors

    @staticmethod
    def _tokenize_and_embed(clip_model, labels, prompt_prefix, device, dtype):
        """Return (token counts, tokenized prompts, token embeddings) for ``labels``."""
        labels = [label.replace("_", " ") for label in labels]
        label_lens = [len(_tokenizer.encode(label)) for label in labels]
        prompts = [prompt_prefix + " " + label + "." for label in labels]

        tokenized = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized).type(dtype)
        return label_lens, tokenized, embedding

    def _assemble(self, prefix, ctx, suffix, label_lens, n_labels):
        """Concatenate prefix, context and label tokens per ``class_token_position``."""
        position = self.class_token_position

        if position == "end":
            return torch.cat(
                [
                    prefix,  # (n_labels, 1, dim)
                    ctx,     # (n_labels, n_ctx, dim)
                    suffix,  # (n_labels, *, dim)
                ],
                dim=1,
            )

        if position == "middle":
            half_n_ctx = self.n_ctx // 2
            rows = []
            for i in range(n_labels):
                label_len = label_lens[i]
                rows.append(torch.cat(
                    [
                        prefix[i : i + 1, :, :],            # (1, 1, dim)
                        ctx[i : i + 1, :half_n_ctx, :],     # (1, n_ctx//2, dim)
                        suffix[i : i + 1, :label_len, :],   # (1, label_len, dim)
                        ctx[i : i + 1, half_n_ctx:, :],     # (1, n_ctx//2, dim)
                        suffix[i : i + 1, label_len:, :],   # (1, *, dim)
                    ],
                    dim=1,
                ))
            return torch.cat(rows, dim=0)

        if position == "front":
            rows = []
            for i in range(n_labels):
                label_len = label_lens[i]
                rows.append(torch.cat(
                    [
                        prefix[i : i + 1, :, :],            # (1, 1, dim)
                        suffix[i : i + 1, :label_len, :],   # (1, label_len, dim)
                        ctx[i : i + 1, :, :],               # (1, n_ctx, dim)
                        suffix[i : i + 1, label_len:, :],   # (1, *, dim)
                    ],
                    dim=1,
                ))
            return torch.cat(rows, dim=0)

        raise ValueError(f"unknown class_token_position: {position!r}")

    def forward(self, mode=None):
        """Assemble prompt embeddings for the scene or the class-name labels.

        Args:
            mode: ``"scene"`` or ``"name"``.

        Returns:
            Prompt embeddings with one row per label of the selected set.

        Raises:
            ValueError: if ``mode`` is not one of the two supported values, or
                if ``class_token_position`` is unknown.
        """
        # Pick the context, fixed tokens, label count and token lengths of
        # the requested label set so every layout uses matching data.
        if mode == "scene":
            ctx, n_labels, label_lens = self.ctx, self.n_cls, self.scene_lens
            prefix, suffix = self.scene_token_prefix, self.scene_token_suffix
        elif mode == "name":
            ctx, n_labels, label_lens = self.ctx_n, self.n_cls_n, self.name_lens
            prefix, suffix = self.name_token_prefix, self.name_token_suffix
        else:
            raise ValueError(f"mode must be 'scene' or 'name', got {mode!r}")

        if ctx.dim() == 2:
            # Shared context: present it once per label.
            ctx = ctx.unsqueeze(0).expand(n_labels, -1, -1)

        return self._assemble(prefix, ctx, suffix, label_lens, n_labels)

class MTPromptTextual(nn.Module):
    """Text branch producing features for both scene and class-name prompts."""

    def __init__(self, clip_model, class_scenes, class_names, device, use_csc, n_ctx):
        super().__init__()
        learner = MTTextPromptLearner(clip_model, class_scenes, class_names, device, use_csc, n_ctx)
        self.prompt_learner = learner
        self.scene_tokenized_prompts = learner.scene_tokenized_prompts
        self.name_tokenized_prompts = learner.name_tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        self.dtype = clip_model.dtype

    def forward(self, indices=None):
        """Encode scene prompts, then class-name prompts.

        ``indices`` is accepted but not used: all prompts of both sets are
        encoded.

        Returns:
            ``(scene_text_features, dynamic_text_features)``, both cast to
            ``self.dtype``.
        """
        # scene
        scene_prompts = self.prompt_learner("scene")
        scene_text_features = self.text_encoder(
            scene_prompts, self.scene_tokenized_prompts
        ).type(self.dtype)

        # dynamic (class names)
        name_prompts = self.prompt_learner("name")
        dynamic_text_features = self.text_encoder(
            name_prompts, self.name_tokenized_prompts
        ).type(self.dtype)

        return scene_text_features, dynamic_text_features

class MultiPromptViTCLIP(nn.Module):
    """CLIP with visual prompt tokens and two learned text prompt sets.

    Images are encoded by ``PromptVisionTransformer``; ``MTPromptTextual``
    supplies text features for the scene labels and for the class names.
    Parameters whose name contains ``"prompt"`` stay trainable; everything
    else is frozen.
    """

    def __init__(self, clip_model, class_scenes, class_names, feat_dim, device, clip_model_type="CLIPViT-B/16", DeepPrompt=False, n_vtk=5, use_csc=False, n_ctx=8):
        """
        Args:
            clip_model: CLIP model providing ``logit_scale`` and both encoders.
            class_scenes: scene labels for the first prompt set.
            class_names: class names for the second prompt set.
            feat_dim: value reported by ``features_dim``.
            device: device handed to the text prompt learner.
            clip_model_type: backbone name handed to the visual backbone.
            DeepPrompt: whether the visual backbone uses per-layer prompts.
            n_vtk: number of visual prompt tokens.
            use_csc: ``(scene_flag, name_flag)`` pair, or a single bool applied
                to both prompt sets.
            n_ctx: number of text context tokens.
        """
        super(MultiPromptViTCLIP, self).__init__()
        # The prompt learner indexes use_csc as a pair; the default (a plain
        # bool) used to crash there, so expand it to one flag per prompt set.
        if isinstance(use_csc, bool):
            use_csc = (use_csc, use_csc)

        self.logit_scale = clip_model.logit_scale
        self.visual_backbone = PromptVisionTransformer(clip_model, clip_model_type, DeepPrompt, n_vtk)
        self.textual_backbone = MTPromptTextual(clip_model, class_scenes, class_names, device, use_csc, n_ctx)
        self._features_dim = feat_dim
        self.temp_text_features = None

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.named_parameters():
            # "prompt_learner" names also contain "prompt".
            if "prompt" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {name for name, param in self.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        # Explicitly freeze the pretrained CLIP weights and the text encoder.
        for frozen_source in (clip_model, self.textual_backbone.text_encoder):
            for param in frozen_source.parameters():
                param.requires_grad_(False)

    def forward(self, images, indices=None, mode=None):
        """Score images against the scene prompts and the class-name prompts.

        Args:
            images: batch passed to the visual backbone.
            indices: forwarded to the text branch.
            mode: ``"eval"`` selects the two-logit evaluation output; only
                consulted when the module is not in training mode.

        Returns:
            Training: ``(logits_per_image, dynamic_logits, image_features)``.
            Not training with ``mode == "eval"``: ``(logits_per_image, dynamic_logits)``.
            Otherwise: ``(dynamic_logits, image_features)``.
            ``logits_per_image`` is the sigmoid of the image/scene similarity;
            ``dynamic_logits`` is the raw image/class-name similarity.
        """
        image_features = self.visual_backbone(images)
        scene_text_features, dynamic_text_features = self.textual_backbone(indices)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        scene_text_features = scene_text_features / scene_text_features.norm(dim=-1, keepdim=True)
        dynamic_text_features = dynamic_text_features / dynamic_text_features.norm(dim=-1, keepdim=True)

        logits_per_image = torch.sigmoid(image_features @ scene_text_features.t())
        # Identical in every output variant, so compute it once.
        dynamic_logits = image_features @ dynamic_text_features.t()

        if self.training:
            return logits_per_image, dynamic_logits, image_features
        if mode == "eval":
            return logits_per_image, dynamic_logits
        return dynamic_logits, image_features

    @property
    def features_dim(self) -> int:
        """The feature dimension supplied at construction (``feat_dim``)."""
        return self._features_dim

    def freeze_bn(self):
        """Put every BatchNorm1d/BatchNorm2d submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        ``optimize_head`` is accepted for interface compatibility and ignored.
        """
        params = [
            {"params": self.visual_backbone.parameters(), "lr": 1.0 * base_lr},
            {"params": self.textual_backbone.prompt_learner.parameters(), "lr": 1.0 * base_lr},
        ]

        return params


# VPT-CoCoOp version
class CoPromptVisionTransformer(nn.Module):
    """CLIP vision transformer with learnable, per-image shifted visual prompts.

    The module reuses the layers of ``clip_model.visual`` and inserts
    ``n_vtk`` prompt tokens between the class token and the patch tokens.
    In every forward call a caller-supplied ``bias`` is added to the projected
    prompt tokens. When ``DeepPrompt`` is enabled, a separate set of learned
    prompt tokens (without the bias) overwrites the prompt positions before
    each transformer block after the first one.
    """

    # Patch size of every supported backbone; matched in this order.
    _PATCH_SIZES = {"ViT-B/32": (32, 32), "ViT-B/16": (16, 16)}

    def __init__(self, clip_model, clip_model_type, DeepPrompt, n_vtk):
        """Wire up the CLIP visual layers and create the prompt parameters.

        Args:
            clip_model: model whose ``visual`` attribute supplies ``conv1``,
                the embeddings, layer norms, ``transformer`` and ``proj``.
            clip_model_type: identifier containing "CLIP" followed by the
                backbone name, e.g. ``"CLIPViT-B/16"``.
            DeepPrompt: if truthy, also learn prompts for the deeper blocks.
            n_vtk: number of visual prompt tokens.

        Raises:
            ValueError: if ``clip_model_type`` names neither ViT-B/32 nor
                ViT-B/16 (previously this surfaced later as an unbound
                local variable error).
        """
        super().__init__()
        visual = clip_model.visual
        self.input_resolution = visual.input_resolution
        self.output_dim = visual.output_dim
        self.conv1 = visual.conv1

        self.class_embedding = visual.class_embedding
        self.positional_embedding = visual.positional_embedding
        self.ln_pre = visual.ln_pre

        self.transformer = visual.transformer

        self.ln_post = visual.ln_post
        self.proj = visual.proj
        self.Deep = DeepPrompt
        self.feature_extractor_type = clip_model_type.split('CLIP')[1]

        # prompt config
        patch_size = next(
            (size for tag, size in self._PATCH_SIZES.items() if tag in clip_model_type),
            None,
        )
        if patch_size is None:
            raise ValueError(
                f"Unsupported clip_model_type {clip_model_type!r}; "
                f"expected one containing {' or '.join(self._PATCH_SIZES)}"
            )
        _, prompt_dim = self.positional_embedding.shape
        self.num_tokens = n_vtk

        # The projected prompts are concatenated with the embedded tokens in
        # forward(), so the projection must output the token width. This was
        # hard-coded to 768 before.
        hidden_size = prompt_dim
        self.prompt_dropout = Dropout(0.1)
        self.prompt_proj = nn.Linear(prompt_dim, hidden_size)
        nn.init.kaiming_normal_(self.prompt_proj.weight, a=0, mode='fan_out')

        val = math.sqrt(6. / float(3 * reduce(mul, patch_size, 1) + prompt_dim))  # noqa

        self.prompt_embeddings = nn.Parameter(torch.zeros(
            1, self.num_tokens, prompt_dim))
        # xavier_uniform initialization
        nn.init.uniform_(self.prompt_embeddings.data, -val, val)

        if self.Deep:  # Deep prompt version noqa
            # One prompt set per transformer block after the first; forward()
            # iterates over ``self.transformer.layers`` blocks (was 12 - 1).
            total_d_layer = self.transformer.layers - 1
            self.deep_prompt_embeddings = nn.Parameter(torch.zeros(
                total_d_layer, self.num_tokens, prompt_dim))
            # xavier_uniform initialization
            nn.init.uniform_(self.deep_prompt_embeddings.data, -val, val)

    def forward(self, x, bias):
        """Encode images with bias-shifted prompt tokens inserted after the class token.

        Args:
            x: image batch passed to ``conv1``.
            bias: per-image shift; it is unsqueezed to ``(batch, 1, width)``
                and added to the projected prompt tokens, so its last
                dimension must match the projection output width.

        Returns:
            The class-token output after ``ln_post``, multiplied by
            ``self.proj`` when that is not ``None``.
        """
        B = x.size(0)

        # meta-token: shift the shared prompts by the per-image bias
        bias = bias.unsqueeze(1)  # (batch, 1, width)
        projected_embeddings = self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)
        ctx_shifted = projected_embeddings + bias  # (batch, num_tokens, width)

        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # incorporate_prompt -> (batch, cls_token + n_prompt + n_patches, width)
        x = torch.cat((
                x[:, :1, :],
                self.prompt_dropout(ctx_shifted),
                x[:, 1:, :]
            ), dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND

        if self.Deep:  # Deep prompt version
            blocks = self.transformer.resblocks
            n_deep = self.deep_prompt_embeddings.shape[0]

            hidden_states = blocks[0](x)
            for i in range(1, self.transformer.layers):
                if i <= n_deep:
                    # expand() prepends the batch dimension to the 2-D projection
                    deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings[i - 1]).expand(B, -1, -1))
                    deep_prompt_emb = deep_prompt_emb.permute(1, 0, 2)  # NLD -> LND

                    # overwrite the prompt positions, keep class and patch tokens
                    hidden_states = torch.cat((
                        hidden_states[:1, :, :],
                        deep_prompt_emb,
                        hidden_states[(1 + self.num_tokens):, :, :]
                    ), dim=0)

                hidden_states = blocks[i](hidden_states)
            x = hidden_states
        else:
            x = self.transformer(x)

        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x

## CoCoOp ##
class CoMTTextPromptLearner(nn.Module):
    """Learnable text context shared by two label sets ("scene" and "name").

    A single context tensor ``self.ctx`` of shape ``(n_ctx, ctx_dim)`` is
    learned. For each label set the token embeddings of the prompt
    ``"<prefix> <label>."`` are computed once; the first token and everything
    after the context slots are kept as buffers. ``forward`` shifts the
    context by a per-sample bias and splices it between those buffers.
    """

    def __init__(self, clip_model, class_scenes, class_names, device, use_csc, n_ctx):
        """Create the context parameter and the fixed prefix/suffix buffers.

        Args:
            clip_model: model providing ``dtype``, ``ln_final``,
                ``visual.input_resolution`` and ``token_embedding``.
            class_scenes: label strings for "scene" mode ("_" becomes " ").
            class_names: label strings for "name" mode ("_" becomes " ").
            device: device the tokenized prompts are moved to before embedding.
            use_csc: accepted for interface compatibility; not used here.
            n_ctx: number of learnable context tokens.
        """
        super().__init__()
        n_cls_c = len(class_scenes)
        n_cls_n = len(class_names)
        ctx_init = None  # set to a phrase to initialise the context from words
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = 224
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            # moved to `device` like the class prompts below
            prompt = clip.tokenize(ctx_init).to(device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        self.ctx = nn.Parameter(ctx_vectors)

        # class scenes prompt, then class name prompt
        scene_lens, scene_tokenized_prompts, scene_embedding = self._embed_prompts(
            clip_model, class_scenes, prompt_prefix, device, dtype)
        name_lens, name_tokenized_prompts, name_embedding = self._embed_prompts(
            clip_model, class_names, prompt_prefix, device, dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("scene_token_prefix", scene_embedding[:, :1, :])  # SOS
        self.register_buffer("scene_token_suffix", scene_embedding[:, 1 + n_ctx :, :])  # CLS, EOS
        self.register_buffer("name_token_prefix", name_embedding[:, :1, :])  # SOS
        self.register_buffer("name_token_suffix", name_embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls_c = n_cls_c
        self.n_cls_n = n_cls_n
        self.n_ctx = n_ctx
        self.scene_tokenized_prompts = scene_tokenized_prompts  # torch.Tensor
        self.name_tokenized_prompts = name_tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens
        self.scene_lens = scene_lens

    @staticmethod
    def _embed_prompts(clip_model, labels, prompt_prefix, device, dtype):
        """Tokenize and embed ``"<prompt_prefix> <label>."`` for every label.

        Returns ``(token counts per label, tokenized prompts, token embeddings)``.
        """
        labels = [label.replace("_", " ") for label in labels]
        lens = [len(_tokenizer.encode(label)) for label in labels]
        prompts = [prompt_prefix + " " + label + "." for label in labels]

        tokenized = torch.cat([clip.tokenize(p) for p in prompts]).to(device)
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized).type(dtype)
        return lens, tokenized, embedding

    def construct_prompts(self, ctx, prefix, suffix, label=None):
        """Concatenate prefix, context and suffix along the token axis.

        Args:
            ctx: context tokens, ``(..., n_ctx, ctx_dim)``.
            prefix: the SOS token, ``(..., 1, ctx_dim)``.
            suffix: remaining tokens, ``(..., *, ctx_dim)``.
            label: optional index; when given, ``prefix`` and ``suffix`` are
                indexed with it along their first dimension first.

        Returns:
            The concatenation on the second-to-last axis (axis 1 for the
            usual 3-D inputs).
        """
        if label is not None:
            prefix = prefix[label]
            suffix = suffix[label]

        return torch.cat([prefix, ctx, suffix], dim=-2)

    def forward(self, bias, mode):
        """Build instance-conditioned prompts for every class of one label set.

        Args:
            bias: per-sample shift, ``(batch, ctx_dim)``.
            mode: ``"scene"`` or ``"name"``; selects which buffers are used.

        Returns:
            Tensor of shape ``(batch, n_cls, n_tkn, ctx_dim)``.

        Raises:
            ValueError: if ``mode`` is neither ``"scene"`` nor ``"name"``.
        """
        if mode == "scene":
            prefix = self.scene_token_prefix
            suffix = self.scene_token_suffix
            n_cls = self.n_cls_c
        elif mode == "name":
            prefix = self.name_token_prefix
            suffix = self.name_token_suffix
            n_cls = self.n_cls_n
        else:
            raise ValueError(f'mode must be "scene" or "name", got {mode!r}')

        # (1, n_ctx, ctx_dim) + (batch, 1, ctx_dim) -> (batch, n_ctx, ctx_dim)
        ctx_shifted = self.ctx.unsqueeze(0) + bias.unsqueeze(1)
        batch = ctx_shifted.size(0)

        # Use instance-conditioned context tokens for all classes: broadcast
        # the context over classes and the class tokens over the batch, then
        # concatenate once instead of once per sample.
        ctx_all = ctx_shifted.unsqueeze(1).expand(-1, n_cls, -1, -1)
        prefix_all = prefix.unsqueeze(0).expand(batch, -1, -1, -1)
        suffix_all = suffix.unsqueeze(0).expand(batch, -1, -1, -1)

        return self.construct_prompts(ctx_all, prefix_all, suffix_all)

class CoMTPromptTextual(nn.Module):
    """Text branch: builds instance-conditioned prompts and encodes them.

    Wraps a ``CoMTTextPromptLearner`` and a ``TextEncoder`` built from the
    same CLIP model, and exposes two modes selecting the label set.
    """

    def __init__(self, clip_model, class_scenes, class_names, device, use_csc, n_ctx):
        """Create the prompt learner and the text encoder from ``clip_model``."""
        super().__init__()
        self.prompt_learner = CoMTTextPromptLearner(clip_model, class_scenes, class_names, device, use_csc, n_ctx)
        self.scene_tokenized_prompts = self.prompt_learner.scene_tokenized_prompts
        self.name_tokenized_prompts = self.prompt_learner.name_tokenized_prompts
        self.text_encoder = TextEncoder(clip_model)
        self.dtype = clip_model.dtype

    def forward(self, bias, indices=None, mode=None):
        """Encode the prompts produced for ``bias``.

        Args:
            bias: per-sample context shift forwarded to the prompt learner.
            indices: required in ``"scene"`` mode; ``indices[i]`` selects the
                scene class whose prompt is encoded for sample ``i``.
                Ignored in ``"name"`` mode.
            mode: ``"scene"`` or ``"name"``.

        Returns:
            ``"scene"``: the encoder output for one selected prompt per sample.
            ``"name"``: the per-sample encoder outputs for all name prompts,
            stacked along a new leading batch axis.

        Raises:
            ValueError: if ``mode`` is neither ``"scene"`` nor ``"name"``
                (previously ``None`` was returned silently).
            TypeError: if ``indices`` is missing in ``"scene"`` mode.
        """
        if mode not in ("scene", "name"):
            raise ValueError(f'mode must be "scene" or "name", got {mode!r}')

        promptses = self.prompt_learner(bias, mode)

        if mode == "scene":
            if indices is None:
                raise TypeError('indices is required when mode is "scene"')
            # pick, for every sample, the prompt of its own scene class
            prompts = torch.stack(
                [promptses[i, indices[i]] for i in range(promptses.size(0))]
            )
            tokenized_prompts = self.scene_tokenized_prompts[indices]
            return self.text_encoder(prompts, tokenized_prompts)

        # "name": every sample has its own prompt per class, encoded sample by sample
        tokenized_prompts = self.name_tokenized_prompts
        return torch.stack(
            [self.text_encoder(pts_i, tokenized_prompts) for pts_i in promptses]
        )

class CoMTDPViTCLIP(nn.Module):
    """CLIP with image-conditioned visual and textual prompts.

    A small meta-network maps features from the unmodified CLIP image encoder
    to two biases: one added to the visual prompt tokens and one added to the
    text context tokens. Logits are computed between the prompted image
    features and the prompted text features, for either the "scene" or the
    "name" label set.
    """

    def __init__(self, clip_model, class_scenes, class_names, feat_dim, device, clip_model_type="CLIPViT-B/16", DeepPrompt=False, n_vtk=5, use_csc=False, n_ctx=8):
        """Build both prompted backbones and the meta-network, then freeze CLIP.

        Only parameters whose name contains "prompt" keep ``requires_grad``;
        every parameter of ``clip_model`` and of the text encoder is frozen.
        """
        super(CoMTDPViTCLIP, self).__init__()
        self.logit_scale = clip_model.logit_scale
        self.visual_backbone = CoPromptVisionTransformer(clip_model, clip_model_type, DeepPrompt, n_vtk)
        self.textual_backbone = CoMTPromptTextual(clip_model, class_scenes, class_names, device, use_csc, n_ctx)
        self._features_dim = feat_dim

        vis_dim = clip_model.visual.output_dim
        vp_dim = clip_model.visual.positional_embedding.shape[1]
        ctx_dim = clip_model.ln_final.weight.shape[0]

        # meta net copy: shared bottleneck followed by one head per modality
        self.image_encoder = clip_model.visual
        self.meta_prompt_linear = nn.Linear(vis_dim, vis_dim // 16)
        self.meta_prompt_relu = nn.ReLU(inplace=True)
        self.meta_prompt_vp = nn.Linear(vis_dim // 16, vp_dim)
        self.meta_prompt_ctx = nn.Linear(vis_dim // 16, ctx_dim)

        print("Turning off gradients in both the image and the text encoder")
        for name, param in self.named_parameters():
            if "prompt" not in name:
                param.requires_grad_(False)
        # Double check
        enabled = {name for name, param in self.named_parameters() if param.requires_grad}
        print(f"Parameters to be updated: {enabled}")

        for name, param in clip_model.named_parameters():
            param.requires_grad_(False)
            assert not param.requires_grad
        for name, param in self.textual_backbone.text_encoder.named_parameters():
            param.requires_grad_(False)
            assert not param.requires_grad

    def forward(self, images, indices=None, mode=None):
        """Compute image-to-text logits for one label set.

        Args:
            images: image batch.
            indices: forwarded to the textual backbone (used in "scene" mode).
            mode: ``"scene"`` or ``"name"``.

        Returns:
            Training: ``(logits_per_image, logits_per_text, image_features)``.
            Evaluation: ``(logits_per_image, image_features)``.

        Raises:
            ValueError: if ``mode`` is neither ``"scene"`` nor ``"name"``
                (previously this failed after the forward pass with an
                unbound ``logits_per_image``).
        """
        if mode not in ("scene", "name"):
            raise ValueError(f'mode must be "scene" or "name", got {mode!r}')

        # get each meta token from the frozen, unprompted image encoder
        with torch.no_grad():
            image_features_ori = self.image_encoder(images)
            image_features_ori = image_features_ori / image_features_ori.norm(dim=-1, keepdim=True)
        x = self.meta_prompt_relu(self.meta_prompt_linear(image_features_ori))
        vp_bias = self.meta_prompt_vp(x)
        ctx_bias = self.meta_prompt_ctx(x)

        image_features = self.visual_backbone(images, vp_bias)
        logit_scale = self.logit_scale.exp()

        if mode == "scene":
            text_features = self.textual_backbone(ctx_bias, indices, mode)

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            logits_per_image = logit_scale * image_features @ text_features.t()
        else:  # "name"
            text_featureses = self.textual_backbone(ctx_bias, indices, mode)

            # NOTE(review): unlike "scene" mode, the image features are not
            # L2-normalised here, so these are not cosine logits and the
            # returned features are unnormalised. Looks unintended - confirm
            # before changing, since it alters training numerics.
            logits_per_image = torch.stack([
                logit_scale * imf_i @ (txf_i / txf_i.norm(dim=-1, keepdim=True)).t()
                for txf_i, imf_i in zip(text_featureses, image_features)
            ])

        logits_per_text = logits_per_image.t()

        if self.training:
            return logits_per_image, logits_per_text, image_features
        else:
            return logits_per_image, image_features

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def freeze_bn(self):
        """Put every BatchNorm1d / BatchNorm2d sub-module into eval mode."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.eval()

    def get_parameters(self, optimize_head=False, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer

        ``optimize_head`` is accepted but not consulted.
        """
        # NOTE(review): the meta_prompt_* layers keep requires_grad=True in
        # __init__ but are not part of any group returned here - confirm
        # whether the meta-network is meant to be optimised.
        params = [
            {"params": self.visual_backbone.parameters(), "lr": 1.0 * base_lr},
            {"params": self.textual_backbone.prompt_learner.parameters(), "lr": 1.0 * base_lr},
        ]

        return params

class LinearProbe(nn.Module):
    """Linear classifier applied to L2-normalised features of a backbone."""

    def __init__(self, model, num_classes, feat_dim):
        """Store ``model`` as the backbone and create a ``feat_dim -> num_classes`` head."""
        super().__init__()

        self.backbone = model
        self.head = nn.Linear(feat_dim, num_classes)
        self._features_dim = feat_dim

    def forward(self, x, mode=None):
        """Return predictions together with the normalised features.

        ``mode`` is accepted but not used. In training mode the features are
        returned twice, giving a 3-tuple; otherwise a 2-tuple is returned.
        """
        raw = self.backbone(x)
        features = raw / raw.norm(dim=-1, keepdim=True)
        predictions = self.head(features)

        if not self.training:
            return predictions, features
        return predictions, features, features

    def features_dim(self) -> int:
        """Feature dimension before the final `head` layer."""
        return self._features_dim

    def freeze_bn(self):
        """Switch all BatchNorm1d / BatchNorm2d sub-modules to eval mode."""
        bn_types = (nn.BatchNorm2d, nn.BatchNorm1d)
        for module in self.modules():
            if isinstance(module, bn_types):
                module.eval()

    def get_parameters(self, optimize_head=True, base_lr=1.0) -> List[Dict]:
        """Optimizer parameter groups: the head at ``1.0 * base_lr``, or nothing.

        An empty list is returned when ``optimize_head`` is falsy.
        """
        if not optimize_head:
            return []
        return [{"params": self.head.parameters(), "lr": 1.0 * base_lr}]