# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from dataclasses import dataclass, field
from typing import Optional, Callable
from functools import partial
from omegaconf import II
from enum import Enum, auto
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
import sys

from peft.lora.continual_lora import ContinualLora
from peft.adapter.continual_adapter import ContinualAdapter
from peft.prompt.continual_prompt import ContinualPrompt
# TODO: change — EPrompt is currently imported from dp_prompt; the hide_prompt variant below is disabled
from peft.prompt.dp_prompt import EPrompt
# from peft.prompt.hide_prompt import EPrompt
# from peft.lora.hide_lora import HideLoraPool
# from peft.lora.momentum_lora import MomentumLora


if 1:
    from base import (
        MaskSeed,
        D2vModalityConfig,
        ModalitySpecificEncoder, 
        get_annealed_rate,
    )

    from modules import (
        D2vDecoderConfig,
        AltBlock,
        Decoder1d,
    )

    from images import (
        D2vImageConfig,
        ImageEncoder,
    )


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# In EAT we build on data2vec 2.0 (its image-modality pipeline) and on Audio-MAE.
class Modality(Enum):
    """Data modalities a model in this file can be configured for."""

    # Explicit values match what ``auto()`` would assign (1, 2, 3).
    AUDIO = 1
    IMAGE = 2
    TEXT = 3

@dataclass
class D2vModalitiesConfig(FairseqDataclass):
    """Per-modality sub-configurations; only the image modality is defined here."""

    # default_factory gives every config object its own D2vImageConfig.
    # A shared instance as the default is rejected by dataclasses on
    # Python >= 3.11 (unhashable dataclass instances count as mutable
    # defaults) and would otherwise be shared across all configs.
    image: D2vImageConfig = field(default_factory=D2vImageConfig)
    
@dataclass
class Data2VecMultiConfig(FairseqDataclass):
    """Top-level configuration for the data2vec-multi / EAT model.

    Groups the reconstruction-loss settings, Transformer encoder
    hyper-parameters, teacher-target normalization switches, the EMA teacher
    schedule, per-modality sub-configs, loss weights, and the optional
    utterance-level (DINO-style) loss experiment.
    """

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )

    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )

    # number of Transformer blocks in the shared encoder
    depth: int = 12

    # standard vision Transformer
    start_drop_path_rate: float = 0
    end_drop_path_rate: float = 0
    num_heads: int = 12
    norm_eps: float = 1e-6
    norm_affine: bool = True
    encoder_dropout: float = 0.1
    post_mlp_drop: float = 0.1
    attention_dropout: float = 0.1
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0
    embed_dim: int = 768
    mlp_ratio: float = 4
    layer_norm_first: bool = False

    # EAT averages all Transformer block output (12 layers in total)
    average_top_k_layers: int = field(
        default=12, metadata={"help": "how many layers to average"}
    )

    end_of_block_targets: bool = False

    # clone batch for multi-mask strategy
    clone_batch: int = 16

    # normalization for teacher Transformer layer output
    layer_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    # EMA settings
    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_same_dtype: bool = True
    log_norms: bool = True
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # resolved by OmegaConf interpolation from the optimization config
    ema_anneal_end_step: int = II("optimization.max_update")

    # In EAT, the Transformer encoder and the CNN encoder are both EMA updated,
    # i.e. EAT runs with ema_encoder_only=False.
    # NOTE(review): the default below is True — confirm against the EAT configs.
    ema_encoder_only: bool = field(
        default=True,
        metadata={
            "help": "whether to momentum update only the shared transformer encoder"
        },
    )

    max_update: int = II("optimization.max_update")

    # default_factory gives each config its own sub-config instance; a shared
    # instance default is rejected by dataclasses on Python >= 3.11.
    modalities: D2vModalitiesConfig = field(default_factory=D2vModalitiesConfig)

    shared_decoder: Optional[D2vDecoderConfig] = None

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )

    supported_modality: Optional[Modality] = None
    mae_init: bool = False

    seed: int = II("common.seed")

    skip_ema: bool = False

    # d2v_loss is the frame-level loss while cls_loss is the utterance-level loss
    cls_loss: float = 0
    recon_loss: float = 0
    d2v_loss: float = 1

    decoder_group: bool = False

    # the experiment of using dino loss instead of direct utterance loss (not included in our paper)
    utterance_level: bool = field(default=False, metadata={"help": "if true, we will add utterance-level loss to the total loss"})
    init_center_token_zero: bool = field(default=False, metadata={"help": "if true, we will initialize the centor token with zero vertors"})
    center_exp: float = field(default=0.9, metadata={"help": "this value control the exponent decay of center value's coefficient"})
    softmax_temperature_student: float = field(default=0.1, metadata={"help": "this value control the temperature of softmax function of student output in the dino loss"})
    softmax_temperature_teacher: float = field(default=0.05, metadata={"help": "this value control the temperature of softmax function in teacher output the dino loss"})


# @register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
# class Data2VecMultiModel(BaseFairseqModel):
#     def make_modality_encoder(
#         self,
#         cfg: D2vModalityConfig,
#         embed_dim: int,
#         make_block: Callable[[float], nn.ModuleList],
#         norm_layer: Callable[[int], nn.LayerNorm],
#         layer_norm_first: bool,
#         alibi_biases,
#         task,
#     ) -> ModalitySpecificEncoder:

#         # import pdb; pdb.set_trace()
#         if cfg.type.value == Modality.IMAGE.value:
#             enc_cls = ImageEncoder
#         else:
#             raise Exception(f"unsupported modality {cfg.type}")

#         return enc_cls(
#             cfg,
#             embed_dim,
#             make_block,
#             norm_layer,
#             layer_norm_first,
#             alibi_biases,
#             task,
#         )

#     def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
#         super().__init__()
#         self.cfg = cfg
#         self.modalities = modalities
#         self.task = task

#         # import pdb; pdb.set_trace()
        
#         make_layer_norm = partial(
#             nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
#         )

#         def make_block(drop_path, dim=None, heads=None):
#             return AltBlock(
#                 cfg.embed_dim if dim is None else dim,
#                 cfg.num_heads if heads is None else heads,
#                 cfg.mlp_ratio,
#                 qkv_bias=True,
#                 drop=cfg.encoder_dropout,
#                 attn_drop=cfg.attention_dropout,
#                 mlp_drop=cfg.activation_dropout,
#                 post_mlp_drop=cfg.post_mlp_drop,
#                 drop_path=drop_path,
#                 norm_layer=make_layer_norm,
#                 layer_norm_first=cfg.layer_norm_first,
#                 ffn_targets=not cfg.end_of_block_targets,
#             )

#         self.alibi_biases = {}
#         self.modality_encoders = nn.ModuleDict()
        
#         # extract CNN encoder and CNN decoder from modified data2vec image modality (see image.py)
#         for mod in self.modalities:
#             mod_cfg = getattr(cfg.modalities, mod.name.lower())
#             # import pdb; pdb.set_trace()
#             enc = self.make_modality_encoder(
#                 mod_cfg,
#                 cfg.embed_dim,
#                 make_block,
#                 make_layer_norm,
#                 cfg.layer_norm_first,
#                 self.alibi_biases,
#                 task,
#             )
#             self.modality_encoders[mod.name] = enc

#         self.ema = None

#         self.average_top_k_layers = cfg.average_top_k_layers
#         self.loss_beta = cfg.loss_beta
#         self.loss_scale = cfg.loss_scale
#         self.utterance_level = cfg.utterance_level

#         self.dropout_input = nn.Dropout(cfg.dropout_input)

#         dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

#         self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

#         self.norm = None
#         if cfg.layer_norm_first:
#             self.norm = make_layer_norm(cfg.embed_dim)

#         if self.cfg.mae_init:
#             self.apply(self._init_weights)
#         else:
#             from fairseq.modules.transformer_sentence_encoder import init_bert_params

#             self.apply(init_bert_params)

#         for mod_enc in self.modality_encoders.values():
#             mod_enc.reset_parameters()

#         # make teacher model
#         if not skip_ema:
#             self.ema = self.make_ema_teacher(cfg.ema_decay)
#             self.shared_decoder = (
#                 Decoder1d(cfg.shared_decoder, cfg.embed_dim)
#                 if self.cfg.shared_decoder is not None
#                 else None
#             )
#             if self.shared_decoder is not None:
#                 self.shared_decoder.apply(self._init_weights)

#             self.recon_proj = None
#             if cfg.recon_loss > 0:
#                 self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim//3)
                
#             self.cls_proj = None
#             if cfg.utterance_level:
#                 self.cls_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)

#         for pn, p in self.named_parameters():
#             if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
#                 p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
#             if cfg.decoder_group and "decoder" in pn:
#                 p.param_group = "decoder"
        
#         # dino loss experiment
#         self.center = None
#         if self.utterance_level:
#             self.center_exp = cfg.center_exp
#             self.soft_tem_s = cfg.softmax_temperature_student
#             self.soft_tem_t = cfg.softmax_temperature_teacher
#             self.center = nn.Parameter(
#                     torch.zeros(1, 1, cfg.embed_dim, requires_grad=False)
#                 )
#             if not cfg.init_center_token_zero:
#                 nn.init.normal_(self.center)
#             elif self.center.size(1) > 1:
#                 nn.init.normal_(self.center[:, 1:])

#         self.num_updates = 0

#     def _init_weights(self, m):

#         try:
#             from apex.normalization import FusedLayerNorm

#             fn = FusedLayerNorm
#         except:
#             fn = nn.LayerNorm

#         if isinstance(m, nn.Linear):
#             torch.nn.init.xavier_uniform_(m.weight)
#             if isinstance(m, nn.Linear) and m.bias is not None:
#                 nn.init.constant_(m.bias, 0)
#         elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
#             if m.bias is not None:
#                 nn.init.constant_(m.bias, 0)
#             if m.weight is not None:
#                 nn.init.constant_(m.weight, 1.0)

#     @torch.no_grad()
#     def make_ema_teacher(self, ema_decay):
#         ema_config = EMAModuleConfig(
#             ema_decay=ema_decay,
#             ema_fp32=True,
#             log_norms=self.cfg.log_norms,
#             add_missing_params=False,
#         )

#         model_copy = self.make_target_model()

#         return EMAModule(
#             model_copy,
#             ema_config,
#             copy_model=False,
#         )

#     # teacher model (with independent CNN encoder and Transformer encoder)
#     def make_target_model(self):
#         logger.info("making target model")

#         model_copy = Data2VecMultiModel(
#             self.cfg, self.modalities, skip_ema=True, task=self.task
#         )

#         if self.cfg.ema_encoder_only:
#             model_copy = model_copy.blocks
#             for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
#                 p_t.data.copy_(p_s.data)
#         else:
#             for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
#                 p_t.data.copy_(p_s.data)

#             for mod_enc in model_copy.modality_encoders.values():
#                 mod_enc.decoder = None
#                 if not mod_enc.modality_cfg.ema_local_encoder:
#                     mod_enc.local_encoder = None
#                     mod_enc.project_features = None

#         model_copy.requires_grad_(False)
#         return model_copy

#     # teacher model updated with EMA
#     def set_num_updates(self, num_updates):
#         super().set_num_updates(num_updates)

#         if self.ema is not None and (
#             (self.num_updates == 0 and num_updates > 1)
#             or self.num_updates >= num_updates
#         ):
#             pass
#         elif self.training and self.ema is not None:
#             ema_weight_decay = None
#             if self.cfg.ema_decay != self.cfg.ema_end_decay:
#                 if num_updates >= self.cfg.ema_anneal_end_step:
#                     decay = self.cfg.ema_end_decay
#                 else:
#                     decay = get_annealed_rate(
#                         self.cfg.ema_decay,
#                         self.cfg.ema_end_decay,
#                         num_updates,
#                         self.cfg.ema_anneal_end_step,
#                     )
#                 self.ema.set_decay(decay, weight_decay=ema_weight_decay)
#             if self.ema.get_decay() < 1:
#                 self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)

#         self.num_updates = num_updates

#     def state_dict(self, destination=None, prefix="", keep_vars=False):
#         state = super().state_dict(destination, prefix, keep_vars)

#         if self.ema is not None:
#             state[prefix + "_ema"] = self.ema.fp32_params

#         return state

#     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
#         k = prefix + "_ema"
#         if self.ema is not None:
#             assert k in state_dict
#             self.ema.restore(state_dict[k], True)
#             del state_dict[k]
#         elif k in state_dict:
#             del state_dict[k]

#         return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

#     @classmethod
#     def build_model(cls, cfg: Data2VecMultiConfig, task=None):
#         """Build a new model instance."""
#         if task is None or not hasattr(task, "supported_modalities"):
#             modalities = (
#                 [cfg.supported_modality]
#                 if cfg.supported_modality is not None
#                 else [
#                     Modality.AUDIO,
#                     Modality.IMAGE,
#                     Modality.TEXT,
#                 ]
#             )
#         else:
#             modalities = task.supported_modalities
#         return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)

#     def forward(
#         self,
#         source,
#         target=None,
#         id=None,
#         mode=None,
#         padding_mask=None,
#         mask=True,
#         features_only=False,
#         force_remove_masked=False,
#         remove_extra_tokens=True,
#         precomputed_mask=None,
#     ):
#         if mode is None:
#             assert self.cfg.supported_modality is not None
#             mode = self.cfg.supported_modality

#         if isinstance(mode, Modality):
#             mode = mode.name

#         feature_extractor = self.modality_encoders[mode]

#         mask_seeds = None
#         if id is not None:
#             mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

#         # extract (unmasked) features using CNN encoder
#         extractor_out = feature_extractor(
#             source,
#             padding_mask,
#             mask,
#             remove_masked=not features_only or force_remove_masked,
#             clone_batch=self.cfg.clone_batch if not features_only else 1,
#             mask_seeds=mask_seeds,
#             precomputed_mask=precomputed_mask,
#         )

#         # x in shape ( batch_size * clone batch, patch_frame(64) * patch_freqency(8) * unmask_ratio(0.2) + 1(cls_token), 768(feature dimension) )
#         # EAT does not employ the ablibi mechanism in Transformer
#         x = extractor_out["x"]
#         encoder_mask = extractor_out["encoder_mask"]
#         masked_padding_mask = extractor_out["padding_mask"]
#         masked_alibi_bias = extractor_out.get("alibi_bias", None)
#         alibi_scale = extractor_out.get("alibi_scale", None)

#         if self.dropout_input is not None:
#             x = self.dropout_input(x)

#         # standard Transformer (for student encoder)
#         layer_results = []
#         for i, blk in enumerate(self.blocks):
#             if (
#                 not self.training
#                 or self.cfg.layerdrop == 0
#                 or (np.random.random() > self.cfg.layerdrop)
#             ):
#                 ab = masked_alibi_bias
#                 if ab is not None and alibi_scale is not None:
#                     scale = (
#                         alibi_scale[i]
#                         if alibi_scale.size(0) > 1
#                         else alibi_scale.squeeze(0)
#                     )
#                     ab = ab * scale.type_as(ab)

#                 x, lr = blk(
#                     x,
#                     padding_mask=masked_padding_mask,
#                     alibi_bias=ab,
#                 )
#                 if features_only:
#                     layer_results.append(lr)

#         if self.norm is not None:
#             x = self.norm(x)

#         # extract features for fine-tuning
#         if features_only:
#             if remove_extra_tokens:
#                 x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
#                 if masked_padding_mask is not None:
#                     masked_padding_mask = masked_padding_mask[
#                         :, feature_extractor.modality_cfg.num_extra_tokens :
#                     ]

#             return {
#                 "x": x,
#                 "padding_mask": masked_padding_mask,
#                 "layer_results": layer_results,
#                 "mask": encoder_mask,
#             }

#         # decode features merged with masked tokens, dx in shape (batch_size * clone_batch, patch, 768)
#         xs = []
#         if self.shared_decoder is not None:
#             dx = self.forward_decoder(
#                 x,
#                 feature_extractor,
#                 self.shared_decoder,
#                 encoder_mask,
#             )
#             xs.append(dx)
#         if feature_extractor.decoder is not None:
#             dx = self.forward_decoder(
#                 x,
#                 feature_extractor,
#                 feature_extractor.decoder,
#                 encoder_mask,
#             )
#             xs.append(dx)
#             orig_x = x

#         assert len(xs) > 0

#         p = next(self.ema.model.parameters())
#         device = x.device
#         dtype = x.dtype
#         ema_device = p.device
#         ema_dtype = p.dtype

#         if not self.cfg.ema_same_dtype:
#             dtype = ema_dtype

#         if ema_device != device or ema_dtype != dtype:
#             logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
#             self.ema.model = self.ema.model.to(dtype=dtype, device=device)
#             ema_dtype = dtype

#             def to_device(d):
#                 for k, p in d.items():
#                     if isinstance(d[k], dict):
#                         to_device(d[k])
#                     else:
#                         d[k] = p.to(device=device)

#             to_device(self.ema.fp32_params)
#         tm = self.ema.model

#         # encode audio spectrogram using teacher model
#         with torch.no_grad():
#             tm.eval()

#             if self.cfg.ema_encoder_only:
#                 assert target is None
#                 ema_input = extractor_out["local_features"]
#                 ema_input = feature_extractor.contextualized_features(
#                     ema_input.to(dtype=ema_dtype),
#                     padding_mask,
#                     mask=False,
#                     remove_masked=False,
#                 )
#                 ema_blocks = tm
#             else:
#                 ema_blocks = tm.blocks
#                 if feature_extractor.modality_cfg.ema_local_encoder:
#                     inp = (
#                         target.to(dtype=ema_dtype)
#                         if target is not None
#                         else source.to(dtype=ema_dtype)
#                     )
#                     ema_input = tm.modality_encoders[mode](
#                         inp,
#                         padding_mask,
#                         mask=False,
#                         remove_masked=False,
#                     )
#                 else:
#                     assert target is None
#                     ema_input = extractor_out["local_features"]
#                     ema_feature_enc = tm.modality_encoders[mode]
#                     ema_input = ema_feature_enc.contextualized_features(
#                         ema_input.to(dtype=ema_dtype),
#                         padding_mask,
#                         mask=False,
#                         remove_masked=False,
#                     )

#             ema_padding_mask = ema_input["padding_mask"]
#             ema_alibi_bias = ema_input.get("alibi_bias", None)
#             ema_alibi_scale = ema_input.get("alibi_scale", None)
#             ema_input = ema_input["x"]

#             # extract target features using teacher CNN encoder
#             # ema_input in shape (batch_size, patch + 1(cls_token), feature_dimension)
#             y = []
#             ema_x = []
#             extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
#             for i, blk in enumerate(ema_blocks):  
#                 ab = ema_alibi_bias
#                 if ab is not None and alibi_scale is not None:
#                     scale = (
#                         ema_alibi_scale[i]
#                         if ema_alibi_scale.size(0) > 1
#                         else ema_alibi_scale.squeeze(0)
#                     )
#                     ab = ab * scale.type_as(ab)

#                 ema_input, lr = blk(
#                     ema_input,
#                     padding_mask=ema_padding_mask,
#                     alibi_bias=ab,
#                 )
#                 y.append(lr[:, extra_tokens:])
#                 ema_x.append(ema_input[:, extra_tokens:])

#         # EAT utilize total 12 Transformer block layer output average as target  
#         y = self.make_targets(y, self.average_top_k_layers)
#         orig_targets = y

#         # multiply the target value according to the number of clone batch
#         if self.cfg.clone_batch > 1:
#             y = y.repeat_interleave(self.cfg.clone_batch, 0)

#         # extract values in masked position to make prediction
#         masked = encoder_mask.mask.unsqueeze(-1)
#         masked_b = encoder_mask.mask.bool()
#         y = y[masked_b]     

#         if xs[0].size(1) == masked_b.size(1):
#             xs = [x[masked_b] for x in xs]
#         else:
#             xs = [x.reshape(-1, x.size(-1)) for x in xs]
            

#         sample_size = masked.sum().long()

#         result = {
#             "losses": {},
#             "sample_size": sample_size,
#         }

#         sample_size = result["sample_size"]

#         # EAT employ utterance-level loss by using mean pooling in patch dimension
#         if self.cfg.cls_loss > 0 and not self.utterance_level:
#             assert extra_tokens > 0
#             cls_target = orig_targets.mean(dim=1)
#             if self.cfg.clone_batch > 1:
#                 cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
#             cls_pred = x[:, extra_tokens - 1]
            
#             result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
#                 self.cfg.cls_loss * sample_size
#             )
            
#         # dino loss experiment
#         if self.cfg.cls_loss > 0 and self.utterance_level:
#             assert extra_tokens > 0
#             cls_target = orig_targets.mean(dim=1)
#             if self.cfg.clone_batch > 1:
#                 cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)  #(btz*clone,1,768)
#             cls_pred = x[:, extra_tokens - 1]
#             cls_target = cls_target - self.center
            
#             cls_pred = cls_pred.squeeze(dim=1)
#             cls_target = cls_target.squeeze(dim=1)
            
#             result["losses"]["cls"] = self.dino_loss(cls_pred, cls_target) * (
#                 self.cfg.cls_loss * sample_size
#             )
            
#             self.center = self.center_exp * self.center + (1 - self.center_exp) * (cls_target.mean(dim=0))
            
#         if self.cfg.recon_loss > 0:

#             with torch.no_grad():
#                 target = feature_extractor.patchify(source)  #(btz,1,512,16*16)
#                 mean = target.mean(dim=-1, keepdim=True)
#                 var = target.var(dim=-1, keepdim=True)
#                 target = (target - mean) / (var + 1.0e-6) ** 0.5   #(btz,1,512,1)

#                 if self.cfg.clone_batch > 1:
#                     target = target.repeat_interleave(self.cfg.clone_batch, 0)  #(btz*clone_btz,1,512,1)

#                 if masked_b is not None:
#                     target = target[masked_b]

#             recon = xs[0]
#             if self.recon_proj is not None:
#                 recon = self.recon_proj(recon)

#             result["losses"]["recon"] = (
#                 self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
#             )

#         if self.cfg.d2v_loss > 0:
#             for i, x in enumerate(xs):
#                 reg_loss = self.d2v_loss(x, y)
#                 n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
#                 result["losses"][n] = reg_loss * self.cfg.d2v_loss

#         # compute state for logs
#         suffix = "" if len(self.modalities) == 1 else f"_{mode}"
#         with torch.no_grad():
#             if encoder_mask is not None:
#                 result["masked_pct"] = 1 - (
#                     encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
#                 )
#             for i, x in enumerate(xs):
#                 n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
#                 result[n] = self.compute_var(x.float())
#             if self.ema is not None:
#                 for k, v in self.ema.logs.items():
#                     result[k] = v

#             y = y.float()
#             result[f"target_var{suffix}"] = self.compute_var(y)

#             if self.num_updates > 5100:
#                 if result[f"target_var{suffix}"] < self.cfg.min_target_var:
#                     logger.error(
#                         f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
#                     )
#                     raise Exception(
#                         f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
#                     )

#                 for k in result.keys():
#                     if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
#                         logger.error(
#                             f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
#                         )
#                         raise Exception(
#                             f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
#                         )

#             result["ema_decay"] = self.ema.get_decay() * 1000

#         return result

#     def forward_decoder(
#         self,
#         x,
#         feature_extractor,
#         decoder,
#         mask_info,
#     ):
#         x = feature_extractor.decoder_input(x, mask_info)
#         x = decoder(*x)

#         return x

#     def d2v_loss(self, x, y):
#         x = x.view(-1, x.size(-1)).float()
#         y = y.view(-1, x.size(-1))

#         if self.loss_beta == 0:
#             loss = F.mse_loss(x, y, reduction="none")
#         else:
#             loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)

#         if self.loss_scale is not None:
#             scale = self.loss_scale
#         else:
#             scale = 1 / math.sqrt(x.size(-1))

#         reg_loss = loss * scale

#         return reg_loss
    
#     def dino_loss(self,s,t):
#         t = t.detach()
#         s = F.softmax(s/self.soft_tem_s,dim=1)
#         t = F.softmax((t-self.center)/self.soft_tem_t,dim=1)
#         return - (t * torch.log(s)).sum(dim=1).mean()
    
#     # average top-k layers output from teacher model
#     def make_targets(self, y, num_layers):

#         with torch.no_grad():
#             target_layer_results = y[-num_layers:]

#             permuted = False
#             if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
#                 target_layer_results = [
#                     tl.transpose(1, 2) for tl in target_layer_results  # BTC -> BCT
#                 ]
#                 permuted = True
#             if self.cfg.batch_norm_target_layer:
#                 target_layer_results = [
#                     F.batch_norm(
#                         tl.float(), running_mean=None, running_var=None, training=True
#                     )
#                     for tl in target_layer_results
#                 ]
#             if self.cfg.instance_norm_target_layer:
#                 target_layer_results = [
#                     F.instance_norm(tl.float()) for tl in target_layer_results
#                 ]
#             if permuted:
#                 target_layer_results = [
#                     tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
#                 ]
#             if self.cfg.layer_norm_target_layer:
#                 target_layer_results = [
#                     F.layer_norm(tl.float(), tl.shape[-1:])
#                     for tl in target_layer_results
#                 ]

#         y = target_layer_results[0].float()
#         for tl in target_layer_results[1:]:
#             y.add_(tl.float())
#         y = y.div_(len(target_layer_results))

#         if self.cfg.layer_norm_targets:
#             y = F.layer_norm(y, y.shape[-1:])

#         if self.cfg.instance_norm_targets:
#             y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

#         return y

#     @staticmethod
#     def compute_var(y):
#         y = y.view(-1, y.size(-1))
#         if dist.is_initialized():
#             zc = torch.tensor(y.size(0)).cuda()
#             zs = y.sum(dim=0)
#             zss = (y**2).sum(dim=0)

#             dist.all_reduce(zc)
#             dist.all_reduce(zs)
#             dist.all_reduce(zss)

#             var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
#             return torch.sqrt(var + 1e-6).mean()
#         else:
#             return torch.sqrt(y.var(dim=0) + 1e-6).mean()

#     def extract_features(
#         self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
#     ):
#         res = self.forward(
#             source,
#             mode=mode,
#             padding_mask=padding_mask,
#             mask=mask,
#             features_only=True,
#             remove_extra_tokens=remove_extra_tokens,
#         )
#         return res

#     def remove_pretraining_modules(self, modality=None, keep_decoder=False):
#         self.ema = None
#         self.cfg.clone_batch = 1
#         self.recon_proj = None        

#         if not keep_decoder:
#             self.shared_decoder = None

#         modality = modality.lower() if modality is not None else None
#         for k in list(self.modality_encoders.keys()):
#             if modality is not None and k.lower() != modality:
#                 del self.modality_encoders[k]
#             else:
#                 self.modality_encoders[k].remove_pretraining_modules(
#                     keep_decoder=keep_decoder
#                 )
#                 if not keep_decoder:
#                     self.modality_encoders[k].decoder = None

class Data2VecMultiModel_2(nn.Module):
    """Data2vec-style encoder followed by a LayerNorm + linear classifier.

    One modality-specific feature extractor is built per entry of
    ``modalities`` (only the image modality is supported), followed by
    ``cfg.depth`` ``AltBlock`` transformer blocks. No EMA teacher is created
    (``self.ema`` is always ``None``).
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Instantiate the feature extractor for one modality.

        Args:
            cfg: modality config; ``cfg.type`` selects the encoder class.
            embed_dim: transformer embedding dimension.
            make_block: factory building a transformer block from a drop-path rate.
            norm_layer: factory building a LayerNorm for a given dimension.
            layer_norm_first: forwarded unchanged to the encoder.
            alibi_biases: shared ALiBi bias cache, forwarded unchanged.
            task: fairseq task (may be None), forwarded unchanged.

        Returns:
            An ``ImageEncoder`` instance.

        Raises:
            ValueError: if ``cfg.type`` is not the image modality.
        """
        # Compared by ``.value`` rather than identity; presumably so an
        # equivalent enum defined in another module also matches.
        if cfg.type.value != Modality.IMAGE.value:
            raise ValueError(f"unsupported modality {cfg.type}")

        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
        """Build the modality encoders, transformer blocks and classifier.

        Args:
            cfg: model config (embedding size, depth, dropout rates, norms, ...).
            modalities: iterable of modality enum members; each member's
                ``name`` keys ``self.modality_encoders``.
            skip_ema: accepted for interface compatibility; unused.
            task: fairseq task forwarded to the modality encoders.
            **kwargs: must contain ``args`` exposing ``num_classes`` (the
                classifier output size). Other keys (e.g. ``dataset``) are
                accepted and ignored.
        """
        super().__init__()

        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            """Build one AltBlock; ``dim``/``heads`` default to the cfg values."""
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Build the modality feature extractor(s) (see images.py).
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            # NOTE: mutates the sub-config stored inside ``cfg`` in place.
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate increases linearly from start to end across blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        # Fed into MaskSeed in forward(); not incremented in this class,
        # presumably updated externally — confirm with the trainer.
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)

        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
    ):
        """Run the modality feature extractor and the transformer blocks.

        Args:
            source: raw input passed to the modality feature extractor.
            target: unused; kept for interface compatibility.
            id: optional sample ids; when given, a ``MaskSeed`` built from
                ``cfg.seed`` and ``self.num_updates`` is passed to the extractor.
            mode: modality name or ``Modality`` member; defaults to
                ``cfg.supported_modality``.
            padding_mask: optional padding mask forwarded to the extractor.
            mask: forwarded positionally to the extractor (presumably toggles
                masking).
            features_only: when True, masked tokens are kept (unless
                ``force_remove_masked``), ``clone_batch`` is forced to 1 and the
                per-block second outputs are collected in ``layer_results``.
            force_remove_masked: remove masked tokens even with ``features_only``.
            remove_extra_tokens: with ``features_only``, drop the leading
                ``num_extra_tokens`` positions from ``x`` and the padding mask.
            precomputed_mask: forwarded unchanged to the extractor.

        Returns:
            dict with keys ``"x"``, ``"padding_mask"``, ``"layer_results"`` and
            ``"mask"`` (the extractor's ``encoder_mask``).
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        # x is presumably (batch * clone_batch, kept tokens + extra tokens,
        # embed_dim) — confirm against ImageEncoder.
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]

        # This model does not use ALiBi; the extractor must not produce a bias.
        assert extractor_out.get("alibi_bias", None) is None

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for blk in self.blocks:
            x, lr = blk(x, padding_mask=masked_padding_mask, alibi_bias=None)
            if features_only:
                layer_results.append(lr)

        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            "layer_results": layer_results,
            "mask": encoder_mask,
        }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return ``forward(..., features_only=True)`` output (no masking by default)."""
        return self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, **kwargs):
        """Encode ``source`` and classify from the first output token.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``forward`` output extended with ``"pre_logits"`` (first token of
            ``x``, presumably the CLS token) and ``"logits"``.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        pre_logits = res["x"][:, 0]
        res["pre_logits"] = pre_logits
        res["logits"] = self.cls_head(self.layer_norm(pre_logits))
        return res

    def output_layer_only(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Apply only the final LayerNorm + classifier to precomputed features.

        ``source`` is treated as the pre-logit features; every other argument
        is accepted for interface compatibility and ignored.
        """
        return {
            "pre_logits": source,
            "logits": self.cls_head(self.layer_norm(source)),
        }

class Data2VecMultiModel_2_lora(nn.Module):
    """Data2vec-style image encoder + classifier with a shared LoRA module.

    Identical backbone to the plain variant. In addition, the first
    ``args.lora_depth`` transformer blocks receive a ``ContinualLora`` module
    and their depth index at call time.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Instantiate the feature extractor for one modality.

        Args:
            cfg: modality config; ``cfg.type`` selects the encoder class.
            embed_dim: transformer embedding dimension.
            make_block: factory building a transformer block from a drop-path rate.
            norm_layer: factory building a LayerNorm for a given dimension.
            layer_norm_first: forwarded unchanged to the encoder.
            alibi_biases: shared ALiBi bias cache, forwarded unchanged.
            task: fairseq task (may be None), forwarded unchanged.

        Returns:
            An ``ImageEncoder`` instance.

        Raises:
            ValueError: if ``cfg.type`` is not the image modality.
        """
        # Compared by ``.value`` rather than identity; presumably so an
        # equivalent enum defined in another module also matches.
        if cfg.type.value != Modality.IMAGE.value:
            raise ValueError(f"unsupported modality {cfg.type}")

        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
        """Build the encoders, transformer blocks, classifier and LoRA module.

        Args:
            cfg: model config (embedding size, depth, dropout rates, norms, ...).
            modalities: iterable of modality enum members; each member's
                ``name`` keys ``self.modality_encoders``.
            skip_ema: accepted for interface compatibility; unused.
            task: fairseq task forwarded to the modality encoders.
            **kwargs: must contain ``args`` exposing ``num_classes``,
                ``lora_depth`` and ``lora_rank``. Other keys (e.g. ``dataset``)
                are accepted and ignored.
        """
        super().__init__()

        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            """Build one AltBlock; ``dim``/``heads`` default to the cfg values."""
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Build the modality feature extractor(s) (see images.py).
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            # NOTE: mutates the sub-config stored inside ``cfg`` in place.
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate increases linearly from start to end across blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        # Fed into MaskSeed in forward(); not incremented in this class,
        # presumably updated externally — confirm with the trainer.
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)

        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # LoRA settings come from the run arguments; the module is built last
        # so parameter creation order matches earlier checkpoints/seeds.
        self.embed_dim = cfg.embed_dim
        self.lora_depth = kwargs["args"].lora_depth
        self.rank = kwargs["args"].lora_rank
        self.lora_layer = ContinualLora(dim=cfg.embed_dim, rank=self.rank, depth=self.lora_depth)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
    ):
        """Run the feature extractor and transformer blocks (LoRA on early blocks).

        Args:
            source: raw input passed to the modality feature extractor.
            target: unused; kept for interface compatibility.
            id: optional sample ids; when given, a ``MaskSeed`` built from
                ``cfg.seed`` and ``self.num_updates`` is passed to the extractor.
            mode: modality name or ``Modality`` member; defaults to
                ``cfg.supported_modality``.
            padding_mask: optional padding mask forwarded to the extractor.
            mask: forwarded positionally to the extractor (presumably toggles
                masking).
            features_only: when True, masked tokens are kept (unless
                ``force_remove_masked``) and ``clone_batch`` is forced to 1.
            force_remove_masked: remove masked tokens even with ``features_only``.
            remove_extra_tokens: with ``features_only``, drop the leading
                ``num_extra_tokens`` positions from ``x`` and the padding mask.
            precomputed_mask: forwarded unchanged to the extractor.

        Returns:
            dict with keys ``"x"``, ``"padding_mask"``, ``"layer_results"``
            (``x[0, :]`` after every block) and ``"mask"`` (the extractor's
            ``encoder_mask``).
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        # x is presumably (batch * clone_batch, kept tokens + extra tokens,
        # embed_dim) — confirm against ImageEncoder.
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]

        # This model does not use ALiBi; the extractor must not produce a bias.
        assert extractor_out.get("alibi_bias", None) is None

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for i, blk in enumerate(self.blocks):
            # Only the first ``lora_depth`` blocks get the LoRA module.
            lora_kwargs = (
                {"lora": self.lora_layer, "depth_id": i} if i < self.lora_depth else {}
            )
            x, _ = blk(x, padding_mask=masked_padding_mask, alibi_bias=None, **lora_kwargs)
            # NOTE(review): records x[0, :] — the first batch element with all
            # its tokens — regardless of features_only. x[:, 0] (first token of
            # every sample) may have been intended; confirm before relying on it.
            layer_results.append(x[0, :])

        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            "layer_results": layer_results,
            "mask": encoder_mask,
        }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return ``forward(..., features_only=True)`` output (no masking by default)."""
        return self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, **kwargs):
        """Encode ``source`` and classify from the first output token.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``forward`` output extended with ``"pre_logits"`` (first token of
            ``x``, presumably the CLS token) and ``"logits"``.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        pre_logits = res["x"][:, 0]
        res["pre_logits"] = pre_logits
        res["logits"] = self.cls_head(self.layer_norm(pre_logits))
        return res

    def output_layer_only(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Apply only the final LayerNorm + classifier to precomputed features.

        ``source`` is treated as the pre-logit features; every other argument
        is accepted for interface compatibility and ignored.
        """
        return {
            "pre_logits": source,
            "logits": self.cls_head(self.layer_norm(source)),
        }


class Data2VecMultiModel_2_adapter(nn.Module):
    """Data2vec-style image encoder + classifier with a shared adapter module.

    Identical backbone to the plain variant. In addition, the first
    ``adapter_depth`` (4) transformer blocks receive a ``ContinualAdapter``
    module and their depth index at call time.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Instantiate the feature extractor for one modality.

        Args:
            cfg: modality config; ``cfg.type`` selects the encoder class.
            embed_dim: transformer embedding dimension.
            make_block: factory building a transformer block from a drop-path rate.
            norm_layer: factory building a LayerNorm for a given dimension.
            layer_norm_first: forwarded unchanged to the encoder.
            alibi_biases: shared ALiBi bias cache, forwarded unchanged.
            task: fairseq task (may be None), forwarded unchanged.

        Returns:
            An ``ImageEncoder`` instance.

        Raises:
            ValueError: if ``cfg.type`` is not the image modality.
        """
        # Compared by ``.value`` rather than identity; presumably so an
        # equivalent enum defined in another module also matches.
        if cfg.type.value != Modality.IMAGE.value:
            raise ValueError(f"unsupported modality {cfg.type}")

        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
        """Build the encoders, transformer blocks, classifier and adapters.

        Args:
            cfg: model config (embedding size, depth, dropout rates, norms, ...).
            modalities: iterable of modality enum members; each member's
                ``name`` keys ``self.modality_encoders``.
            skip_ema: accepted for interface compatibility; unused.
            task: fairseq task forwarded to the modality encoders.
            **kwargs: must contain ``args`` exposing ``num_classes``. Other keys
                (e.g. ``dataset``) are accepted and ignored.
        """
        super().__init__()

        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            """Build one AltBlock; ``dim``/``heads`` default to the cfg values."""
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Build the modality feature extractor(s) (see images.py).
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            # NOTE: mutates the sub-config stored inside ``cfg`` in place.
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate increases linearly from start to end across blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        # Fed into MaskSeed in forward(); not incremented in this class,
        # presumably updated externally — confirm with the trainer.
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)

        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # Adapter settings are hard-coded: 4 adapted blocks with rank 4.
        # The module is built last so parameter creation order is unchanged.
        self.adapter_depth = 4
        self.embed_dim = cfg.embed_dim
        self.adapter_rank = 4
        self.adapters = ContinualAdapter(depth=self.adapter_depth, dim=cfg.embed_dim, rank=self.adapter_rank)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
    ):
        """Run the feature extractor and transformer blocks (adapters on early blocks).

        Args:
            source: raw input passed to the modality feature extractor.
            target: unused; kept for interface compatibility.
            id: optional sample ids; when given, a ``MaskSeed`` built from
                ``cfg.seed`` and ``self.num_updates`` is passed to the extractor.
            mode: modality name or ``Modality`` member; defaults to
                ``cfg.supported_modality``.
            padding_mask: optional padding mask forwarded to the extractor.
            mask: forwarded positionally to the extractor (presumably toggles
                masking).
            features_only: when True, masked tokens are kept (unless
                ``force_remove_masked``) and ``clone_batch`` is forced to 1.
            force_remove_masked: remove masked tokens even with ``features_only``.
            remove_extra_tokens: with ``features_only``, drop the leading
                ``num_extra_tokens`` positions from ``x`` and the padding mask.
            precomputed_mask: forwarded unchanged to the extractor.

        Returns:
            dict with keys ``"x"``, ``"padding_mask"``, ``"layer_results"``
            (always an empty list in this variant) and ``"mask"`` (the
            extractor's ``encoder_mask``).
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        # x is presumably (batch * clone_batch, kept tokens + extra tokens,
        # embed_dim) — confirm against ImageEncoder.
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]

        # This model does not use ALiBi; the extractor must not produce a bias.
        assert extractor_out.get("alibi_bias", None) is None

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        for i, blk in enumerate(self.blocks):
            # Only the first ``adapter_depth`` blocks get the adapter module.
            adapter_kwargs = (
                {"adapter": self.adapters, "depth_id": i} if i < self.adapter_depth else {}
            )
            x, _ = blk(x, padding_mask=masked_padding_mask, alibi_bias=None, **adapter_kwargs)

        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            # Per-block results are not collected in this variant.
            "layer_results": [],
            "mask": encoder_mask,
        }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return ``forward(..., features_only=True)`` output (no masking by default)."""
        return self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, **kwargs):
        """Encode ``source`` and classify from the first output token.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``forward`` output extended with ``"pre_logits"`` (first token of
            ``x``, presumably the CLS token) and ``"logits"``.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        pre_logits = res["x"][:, 0]
        res["pre_logits"] = pre_logits
        res["logits"] = self.cls_head(self.layer_norm(pre_logits))
        return res

    def output_layer_only(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Apply only the final LayerNorm + classifier to precomputed features.

        ``source`` is treated as the pre-logit features; every other argument
        is accepted for interface compatibility and ignored.
        """
        return {
            "pre_logits": source,
            "logits": self.cls_head(self.layer_norm(source)),
        }

class Data2VecMultiModel_2_prompt(nn.Module):
    """Data2vec-style image encoder + classifier with layer-wise e-prompts.

    Identical backbone to the plain variant. In addition, a ``ContinualPrompt``
    module produces one prompt per block listed in ``e_prompt_layer_idx``
    (the first ``num_e_prompt`` blocks), which is passed to that block at call
    time.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Instantiate the feature extractor for one modality.

        Args:
            cfg: modality config; ``cfg.type`` selects the encoder class.
            embed_dim: transformer embedding dimension.
            make_block: factory building a transformer block from a drop-path rate.
            norm_layer: factory building a LayerNorm for a given dimension.
            layer_norm_first: forwarded unchanged to the encoder.
            alibi_biases: shared ALiBi bias cache, forwarded unchanged.
            task: fairseq task (may be None), forwarded unchanged.

        Returns:
            An ``ImageEncoder`` instance.

        Raises:
            ValueError: if ``cfg.type`` is not the image modality.
        """
        # Compared by ``.value`` rather than identity; presumably so an
        # equivalent enum defined in another module also matches.
        if cfg.type.value != Modality.IMAGE.value:
            raise ValueError(f"unsupported modality {cfg.type}")

        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
        """Build the encoders, transformer blocks, classifier and e-prompts.

        Args:
            cfg: model config (embedding size, depth, heads, dropout rates, ...).
            modalities: iterable of modality enum members; each member's
                ``name`` keys ``self.modality_encoders``.
            skip_ema: accepted for interface compatibility; unused.
            task: fairseq task forwarded to the modality encoders.
            **kwargs: must contain ``args`` exposing ``num_classes``. Other keys
                (e.g. ``dataset``) are accepted and ignored.
        """
        super().__init__()

        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            """Build one AltBlock; ``dim``/``heads`` default to the cfg values."""
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Build the modality feature extractor(s) (see images.py).
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            # NOTE: mutates the sub-config stored inside ``cfg`` in place.
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate increases linearly from start to end across blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])

        # Fed into MaskSeed in forward(); not incremented in this class,
        # presumably updated externally — confirm with the trainer.
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)

        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # E-prompt settings. The head count must match the attention blocks,
        # which are built with cfg.num_heads (previously hard-coded to 12).
        self.prompt_length = 20
        self.num_e_prompt = 5
        self.use_prefix_tune_for_e_prompt = True
        self.num_heads = cfg.num_heads
        self.e_prompt = ContinualPrompt(
            length=self.prompt_length,
            embed_dim=cfg.embed_dim,
            num_layers=self.num_e_prompt,
            use_prefix_tune_for_e_prompt=self.use_prefix_tune_for_e_prompt,
            num_heads=self.num_heads,
        )
        # One prompted block per generated prompt: blocks 0 .. num_e_prompt-1.
        self.e_prompt_layer_idx = list(range(self.num_e_prompt))

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
    ):
        """Run the feature extractor and the (partly prompted) transformer blocks.

        Args:
            source: raw input passed to the modality feature extractor.
            target: unused; kept for interface compatibility.
            id: optional sample ids; when given, a ``MaskSeed`` built from
                ``cfg.seed`` and ``self.num_updates`` is passed to the extractor.
            mode: modality name or ``Modality`` member; defaults to
                ``cfg.supported_modality``.
            padding_mask: optional padding mask forwarded to the extractor.
            mask: forwarded positionally to the extractor (presumably toggles
                masking).
            features_only: when True, masked tokens are kept (unless
                ``force_remove_masked``) and ``clone_batch`` is forced to 1.
            force_remove_masked: remove masked tokens even with ``features_only``.
            remove_extra_tokens: with ``features_only``, drop the leading
                ``num_extra_tokens`` positions from ``x`` and the padding mask.
            precomputed_mask: forwarded unchanged to the extractor.

        Returns:
            dict with keys ``"x"``, ``"padding_mask"``, ``"layer_results"``
            (the full ``x`` after every block) and ``"mask"`` (the extractor's
            ``encoder_mask``).
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        # x is presumably (batch * clone_batch, kept tokens + extra tokens,
        # embed_dim) — confirm against ImageEncoder.
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]

        # This model does not use ALiBi; the extractor must not produce a bias.
        assert extractor_out.get("alibi_bias", None) is None

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        # Prompts are generated once from the block input; entry k is used by
        # the k-th block listed in e_prompt_layer_idx.
        e_prompt = self.e_prompt(x)["batched_prompt"]

        layer_results = []
        e_prompt_counter = -1
        for i, blk in enumerate(self.blocks):
            if i in self.e_prompt_layer_idx:
                e_prompt_counter += 1
                x, _ = blk(
                    x,
                    prompt=e_prompt[e_prompt_counter],
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            else:
                x, _ = blk(x, padding_mask=masked_padding_mask, alibi_bias=None)
            layer_results.append(x)

        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            "layer_results": layer_results,
            "mask": encoder_mask,
        }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return ``forward(..., features_only=True)`` output (no masking by default)."""
        return self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, **kwargs):
        """Encode ``source`` and classify from the first output token.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``forward`` output extended with ``"pre_logits"`` (first token of
            ``x``, presumably the CLS token) and ``"logits"``.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        pre_logits = res["x"][:, 0]
        res["pre_logits"] = pre_logits
        res["logits"] = self.cls_head(self.layer_norm(pre_logits))
        return res

    def output_layer_only(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Apply only the final LayerNorm + classifier to precomputed features.

        ``source`` is treated as the pre-logit features; every other argument
        is accepted for interface compatibility and ignored.
        """
        return {
            "pre_logits": source,
            "logits": self.cls_head(self.layer_norm(source)),
        }

class Data2VecMultiModel_2_dual_prompt(nn.Module):
    """Image encoder + linear classifier with DualPrompt-style prompts.

    Two kinds of learnable prompts are passed as ``prompt=`` to selected
    transformer blocks (prefix-tuning mode):

    * G-prompts (``self.g_prompt``): one slice per layer listed in
      ``g_prompt_layer_idx``, indexed only by that layer's position.
    * E-prompts: ``res['batched_prompt']`` produced by the ``EPrompt`` module
      (``self.e_prompt``) from the block input, one entry per layer listed in
      ``e_prompt_layer_idx``.

    All prompt hyper-parameters are hard-coded in ``__init__``.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the input encoder for ``cfg.type``; only IMAGE is supported."""
        if cfg.type.value != Modality.IMAGE.value:
            raise Exception(f"unsupported modality {cfg.type}")
        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    @staticmethod
    def _make_g_prompt(num_layers, length, embed_dim, num_heads, prompt_init,
                       prefix_tune, same_key_value):
        """Create the G-prompt ``nn.Parameter`` (``None`` if nothing to build).

        Shapes:
            * plain prompt-tuning: ``(num_layers, length, embed_dim)``
            * prefix-tuning:       ``(num_layers, 2, length, num_heads,
              embed_dim // num_heads)``

        With ``same_key_value`` the two entries along dim 1 start from the same
        initial values. They remain separate parameters, so they can drift
        apart during training.

        Raises:
            ValueError: if ``prompt_init`` is neither ``'zero'`` nor ``'uniform'``
                (previously this silently left ``g_prompt`` undefined).
        """
        if length is None or num_layers == 0:
            return None
        if not prefix_tune:
            shape = (num_layers, length, embed_dim)
        else:
            kv = 1 if same_key_value else 2
            shape = (num_layers, kv, length, num_heads, embed_dim // num_heads)

        if prompt_init == 'zero':
            data = torch.zeros(shape)
        elif prompt_init == 'uniform':
            # randn followed by uniform_ keeps the RNG stream identical to the
            # original construction (same draws, same resulting values).
            data = torch.randn(shape)
            nn.init.uniform_(data, -1, 1)
        else:
            raise ValueError(f"unknown prompt_init {prompt_init!r}")

        if prefix_tune and same_key_value:
            # Duplicate before wrapping: reassigning a registered Parameter
            # attribute to a plain tensor raises TypeError in nn.Module.
            data = data.repeat(1, 2, 1, 1, 1)
        return nn.Parameter(data)

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
        """Build the encoder blocks, the classifier head and the G/E prompts.

        Args:
            cfg: model configuration (embed_dim, depth, dropouts, ...). Note
                that ``cfg.modalities.<name>`` is mutated in place below.
            modalities: iterable of ``Modality`` members to build encoders for.
            skip_ema: accepted for interface compatibility; unused.
            task: forwarded to each modality encoder.
            **kwargs: must contain ``args`` with a ``num_classes`` attribute
                (output size of ``cls_head``); other keys are ignored.
        """
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            # Each per-modality config inherits the top-level type/max_length.
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate interpolated linearly from start to end over the depth.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(rate) for rate in dpr])
        self.num_updates = 0

        self.layer_norm = nn.LayerNorm(cfg.embed_dim, eps=1e-6)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # ---- prompt configuration (hard-coded) ----
        self.prompt_pool = True
        self.head_type = 'token'
        self.use_prompt_mask = True

        self.use_g_prompt = True
        self.g_prompt_layer_idx = [0, 1]
        self.use_prefix_tune_for_g_prompt = True
        g_prompt_length = 5

        self.use_e_prompt = True
        self.e_prompt_layer_idx = [2, 3, 4]
        self.use_prefix_tune_for_e_prompt = True
        e_prompt_length = 20
        top_k = 1

        prompt_init = 'uniform'
        same_key_value = False
        # NOTE(review): hard-coded instead of cfg.num_heads; presumably must
        # match the attention heads of AltBlock for prefix shapes to line up.
        num_heads = 12

        # Number of layers that actually receive a G-prompt.
        num_g_prompt = (
            len(self.g_prompt_layer_idx) if self.g_prompt_layer_idx is not None else 0
        )
        self.g_prompt = (
            self._make_g_prompt(
                num_g_prompt,
                g_prompt_length,
                cfg.embed_dim,
                num_heads,
                prompt_init,
                self.use_prefix_tune_for_g_prompt,
                same_key_value,
            )
            if self.use_g_prompt
            else None
        )

        if self.use_e_prompt and self.e_prompt_layer_idx is not None:
            self.e_prompt = EPrompt(
                length=e_prompt_length,
                embed_dim=cfg.embed_dim,
                embedding_key='cls',
                prompt_init=prompt_init,
                prompt_pool=True,
                prompt_key=True,
                pool_size=10,
                top_k=top_k,
                batchwise_prompt=False,
                prompt_key_init='uniform',
                num_layers=len(self.e_prompt_layer_idx),
                use_prefix_tune_for_e_prompt=self.use_prefix_tune_for_e_prompt,
                num_heads=num_heads,
                same_key_value=same_key_value,
            )

        # Prompt tokens are only counted here for the non-prefix-tuning modes.
        self.total_prompt_len = 0
        if self.prompt_pool:
            if not self.use_prefix_tune_for_g_prompt:
                self.total_prompt_len += g_prompt_length * len(self.g_prompt_layer_idx)
            if not self.use_prefix_tune_for_e_prompt:
                self.total_prompt_len += e_prompt_length * top_k * len(self.e_prompt_layer_idx)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        task_id=-1,
        cls_features=None,
        train=None,
    ):
        """Encode ``source`` and run the prompted transformer blocks.

        Args:
            mode: modality name or ``Modality``; defaults to
                ``cfg.supported_modality``.
            features_only: return a plain feature dict (optionally with the
                extra leading tokens stripped) instead of the EPrompt dict.
            task_id: task index used to build a fixed prompt mask when
                ``train`` is truthy and ``use_prompt_mask`` is set.
            cls_features: forwarded to ``EPrompt`` as ``cls_features``.
            train: truthy during training; enables the task prompt mask.

        Returns:
            A dict with ``x``, ``padding_mask``, ``layer_results`` (always an
            empty list) and ``mask``. On the non-feature path this is the dict
            returned by ``EPrompt``, extended with those keys.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality
        if isinstance(mode, Modality):
            mode = mode.name
        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        # ALiBi biases are not supported by this model.
        assert extractor_out.get("alibi_bias", None) is None
        if self.dropout_input is not None:
            x = self.dropout_input(x)

        if not (self.use_g_prompt or self.use_e_prompt):
            raise AssertionError

        prompt_mask = None
        if self.use_prompt_mask and train:
            # Each task owns a contiguous slice of top_k pool entries.
            # NOTE(review): the default task_id=-1 yields negative indices
            # here; callers should pass a real task id when train is set.
            top_k = self.e_prompt.top_k
            start, end = task_id * top_k, (task_id + 1) * top_k
            # Tasks past the pool capacity get no fixed mask.
            if end <= self.e_prompt.pool_size:
                prompt_mask = (
                    torch.arange(start, end, device=x.device)
                    .unsqueeze(0)
                    .expand(x.shape[0], -1)
                )
        res = self.e_prompt(x, prompt_mask=prompt_mask, cls_features=cls_features)
        e_prompt = res['batched_prompt']

        layer_results = []  # kept for interface compatibility; never filled
        g_prompt_counter = -1
        e_prompt_counter = -1
        for i, blk in enumerate(self.blocks):
            if i in self.g_prompt_layer_idx:
                if not self.use_prefix_tune_for_g_prompt:
                    raise AssertionError
                g_prompt_counter += 1
                # Prefix tuning, [B, 2, g_prompt_length, num_heads, embed_dim // num_heads]
                idx = torch.full(
                    (x.shape[0],), g_prompt_counter, dtype=torch.long, device=x.device
                )
                x, _ = blk(
                    x,
                    prompt=self.g_prompt[idx],
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            elif i in self.e_prompt_layer_idx:
                e_prompt_counter += 1
                if not self.use_prefix_tune_for_e_prompt:
                    raise AssertionError
                # Prefix tuning, [B, 2, top_k * e_prompt_length, num_heads, embed_dim // num_heads]
                x, _ = blk(
                    x,
                    prompt=e_prompt[e_prompt_counter],
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            else:
                x, _ = blk(x, padding_mask=masked_padding_mask, alibi_bias=None)

        if features_only:
            if remove_extra_tokens:
                n_extra = feature_extractor.modality_cfg.num_extra_tokens
                x = x[:, n_extra:]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[:, n_extra:]
            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        # Reuse the EPrompt output dict so its entries reach the caller too.
        res["x"] = x
        res["padding_mask"] = masked_padding_mask
        res["layer_results"] = layer_results
        res["mask"] = encoder_mask
        return res

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return encoder features only (``forward`` with ``features_only=True``)."""
        return self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=-1, train=None, cls_features=None):
        """Run ``forward`` and classify from token 0.

        Returns the ``forward`` dict with ``pre_logits`` (token-0 features) and
        ``logits`` (``cls_head(layer_norm(pre_logits))``) added.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
            task_id=task_id,
            train=train,
            cls_features=cls_features,
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        res["logits"] = self.cls_head(self.layer_norm(x))
        return res

class Data2VecMultiModel_2_l2p(nn.Module):
    """Image encoder + linear classifier with L2P-style prompting.

    No G-prompts are used (``use_g_prompt=False``). A single ``EPrompt`` module
    (``pool_size=10``, ``top_k=4``, ``batchwise_prompt=True``) supplies the
    ``prompt=`` argument of block 0 only, in prefix-tuning mode. All prompt
    hyper-parameters are hard-coded in ``__init__``.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the input encoder for ``cfg.type``; only IMAGE is supported."""
        if cfg.type.value != Modality.IMAGE.value:
            raise Exception(f"unsupported modality {cfg.type}")
        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    @staticmethod
    def _make_g_prompt(num_layers, length, embed_dim, num_heads, prompt_init,
                       prefix_tune, same_key_value):
        """Create the G-prompt ``nn.Parameter`` (``None`` if nothing to build).

        Shapes:
            * plain prompt-tuning: ``(num_layers, length, embed_dim)``
            * prefix-tuning:       ``(num_layers, 2, length, num_heads,
              embed_dim // num_heads)``

        With ``same_key_value`` the two entries along dim 1 start from the same
        initial values. They remain separate parameters.

        Raises:
            ValueError: if ``prompt_init`` is neither ``'zero'`` nor ``'uniform'``.
        """
        if length is None or num_layers == 0:
            return None
        if not prefix_tune:
            shape = (num_layers, length, embed_dim)
        else:
            kv = 1 if same_key_value else 2
            shape = (num_layers, kv, length, num_heads, embed_dim // num_heads)

        if prompt_init == 'zero':
            data = torch.zeros(shape)
        elif prompt_init == 'uniform':
            data = torch.randn(shape)
            nn.init.uniform_(data, -1, 1)
        else:
            raise ValueError(f"unknown prompt_init {prompt_init!r}")

        if prefix_tune and same_key_value:
            # Duplicate before wrapping: reassigning a registered Parameter
            # attribute to a plain tensor raises TypeError in nn.Module.
            data = data.repeat(1, 2, 1, 1, 1)
        return nn.Parameter(data)

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
        """Build the encoder blocks, the classifier head and the E-prompt pool.

        Args:
            cfg: model configuration (embed_dim, depth, dropouts, ...). Note
                that ``cfg.modalities.<name>`` is mutated in place below.
            modalities: iterable of ``Modality`` members to build encoders for.
            skip_ema: accepted for interface compatibility; unused.
            task: forwarded to each modality encoder.
            **kwargs: must contain ``args`` with a ``num_classes`` attribute
                (output size of ``cls_head``); other keys are ignored.
        """
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        for mod in self.modalities:
            # Each per-modality config inherits the top-level type/max_length.
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate interpolated linearly from start to end over the depth.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(rate) for rate in dpr])
        self.num_updates = 0

        self.layer_norm = nn.LayerNorm(cfg.embed_dim, eps=1e-6)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # ---- prompt configuration (hard-coded) ----
        self.prompt_pool = True
        self.head_type = 'token'
        self.use_prompt_mask = False

        self.use_g_prompt = False
        self.g_prompt_layer_idx = []
        self.use_prefix_tune_for_g_prompt = False
        g_prompt_length = 5

        self.use_e_prompt = True
        self.e_prompt_layer_idx = [0]
        self.use_prefix_tune_for_e_prompt = True
        e_prompt_length = 5
        top_k = 4

        prompt_init = 'uniform'
        same_key_value = False
        # NOTE(review): hard-coded instead of cfg.num_heads; presumably must
        # match the attention heads of AltBlock for prefix shapes to line up.
        num_heads = 12

        num_g_prompt = (
            len(self.g_prompt_layer_idx) if self.g_prompt_layer_idx is not None else 0
        )
        self.g_prompt = (
            self._make_g_prompt(
                num_g_prompt,
                g_prompt_length,
                cfg.embed_dim,
                num_heads,
                prompt_init,
                self.use_prefix_tune_for_g_prompt,
                same_key_value,
            )
            if self.use_g_prompt
            else None
        )

        if self.use_e_prompt and self.e_prompt_layer_idx is not None:
            self.e_prompt = EPrompt(
                length=e_prompt_length,
                embed_dim=cfg.embed_dim,
                embedding_key='cls',
                prompt_init=prompt_init,
                prompt_pool=True,
                prompt_key=True,
                pool_size=10,
                top_k=top_k,
                batchwise_prompt=True,
                prompt_key_init='uniform',
                num_layers=len(self.e_prompt_layer_idx),
                use_prefix_tune_for_e_prompt=self.use_prefix_tune_for_e_prompt,
                num_heads=num_heads,
                same_key_value=same_key_value,
            )

        # Prompt tokens are only counted here for the non-prefix-tuning modes.
        self.total_prompt_len = 0
        if self.prompt_pool:
            if not self.use_prefix_tune_for_g_prompt:
                self.total_prompt_len += g_prompt_length * len(self.g_prompt_layer_idx)
            if not self.use_prefix_tune_for_e_prompt:
                self.total_prompt_len += e_prompt_length * top_k * len(self.e_prompt_layer_idx)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        task_id=-1,
        cls_features=None,
        train=None,
    ):
        """Encode ``source`` and run the prompted transformer blocks.

        Args:
            mode: modality name or ``Modality``; defaults to
                ``cfg.supported_modality``.
            features_only: return a plain feature dict (optionally with the
                extra leading tokens stripped) instead of the EPrompt dict.
            task_id: task index used to build a fixed prompt mask when
                ``train`` is truthy and ``use_prompt_mask`` is set.
            cls_features: forwarded to ``EPrompt`` as ``cls_features``.
            train: truthy during training; enables the task prompt mask.

        Returns:
            A dict with ``x``, ``padding_mask``, ``layer_results`` (always an
            empty list) and ``mask``. On the non-feature path this is the dict
            returned by ``EPrompt``, extended with those keys.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality
        if isinstance(mode, Modality):
            mode = mode.name
        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        # ALiBi biases are not supported by this model.
        assert extractor_out.get("alibi_bias", None) is None
        if self.dropout_input is not None:
            x = self.dropout_input(x)

        if not (self.use_g_prompt or self.use_e_prompt):
            raise AssertionError

        prompt_mask = None
        if self.use_prompt_mask and train:
            # Each task owns a contiguous slice of top_k pool entries.
            # NOTE(review): the default task_id=-1 yields negative indices
            # here; callers should pass a real task id when train is set.
            top_k = self.e_prompt.top_k
            start, end = task_id * top_k, (task_id + 1) * top_k
            # Tasks past the pool capacity get no fixed mask.
            if end <= self.e_prompt.pool_size:
                prompt_mask = (
                    torch.arange(start, end, device=x.device)
                    .unsqueeze(0)
                    .expand(x.shape[0], -1)
                )
        res = self.e_prompt(x, prompt_mask=prompt_mask, cls_features=cls_features)
        e_prompt = res['batched_prompt']

        layer_results = []  # kept for interface compatibility; never filled
        g_prompt_counter = -1
        e_prompt_counter = -1
        for i, blk in enumerate(self.blocks):
            if i in self.g_prompt_layer_idx:
                if not self.use_prefix_tune_for_g_prompt:
                    raise AssertionError
                g_prompt_counter += 1
                # Prefix tuning, [B, 2, g_prompt_length, num_heads, embed_dim // num_heads]
                idx = torch.full(
                    (x.shape[0],), g_prompt_counter, dtype=torch.long, device=x.device
                )
                x, _ = blk(
                    x,
                    prompt=self.g_prompt[idx],
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            elif i in self.e_prompt_layer_idx:
                e_prompt_counter += 1
                if not self.use_prefix_tune_for_e_prompt:
                    raise AssertionError
                # Prefix tuning, [B, 2, top_k * e_prompt_length, num_heads, embed_dim // num_heads]
                x, _ = blk(
                    x,
                    prompt=e_prompt[e_prompt_counter],
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            else:
                x, _ = blk(x, padding_mask=masked_padding_mask, alibi_bias=None)

        if features_only:
            if remove_extra_tokens:
                n_extra = feature_extractor.modality_cfg.num_extra_tokens
                x = x[:, n_extra:]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[:, n_extra:]
            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        # Reuse the EPrompt output dict so its entries reach the caller too.
        res["x"] = x
        res["padding_mask"] = masked_padding_mask
        res["layer_results"] = layer_results
        res["mask"] = encoder_mask
        return res

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return encoder features only (``forward`` with ``features_only=True``)."""
        return self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=-1, train=None, cls_features=None):
        """Run ``forward`` and classify from token 0.

        Returns the ``forward`` dict with ``pre_logits`` (token-0 features) and
        ``logits`` (``cls_head(layer_norm(pre_logits))``) added.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
            task_id=task_id,
            train=train,
            cls_features=cls_features,
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        res["logits"] = self.cls_head(self.layer_norm(x))
        return res

class Data2VecMultiModel_2_sprompt(nn.Module):
    """Data2vec-style transformer classifier with a pool of prefix E-prompts.

    A modality-specific encoder (only IMAGE is supported) produces tokens that
    pass through ``cfg.depth`` ``AltBlock`` layers. Blocks whose index is in
    ``e_prompt_layer_idx`` receive a per-layer prompt taken from the output of
    ``self.e_prompt`` (an ``EPrompt``). A G-prompt path exists but is disabled
    by the hard-coded ``use_g_prompt = False``. ``naive_classification`` maps
    token 0 through ``layer_norm`` and ``cls_head`` to logits.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the encoder for ``cfg.type``; only the IMAGE modality is supported.

        Raises:
            Exception: if ``cfg.type`` is not IMAGE.
        """
        if cfg.type.value == Modality.IMAGE.value:
            enc_cls = ImageEncoder
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, **kwargs):
        """Build encoders, transformer blocks, classifier head and prompt modules.

        Args:
            cfg: model configuration (embedding size, depth, dropouts, ...).
            modalities: iterable of modality members (their ``.name`` selects
                the sub-config in ``cfg.modalities`` and keys the encoder dict).
            skip_ema: accepted for interface compatibility; unused.
            task: forwarded to the modality encoder constructor.
            **kwargs: must contain ``args`` with a ``num_classes`` attribute
                (classifier output size). ``dataset`` is accepted but unused.
        """
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            # Transformer block factory; dim/heads default to the config values.
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # extract CNN encoder and CNN decoder from modified data2vec image modality (see image.py)
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            enc = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )
            self.modality_encoders[mod.name] = enc

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate interpolated linearly from start to end over the blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # ---- prompt configuration (hard-coded) ----
        self.prompt_pool = True
        self.head_type = 'token'
        self.use_prompt_mask = True

        self.use_g_prompt = False
        self.g_prompt_layer_idx = []
        g_prompt_layer_idx = []
        # num_g_prompt : The actual number of layers to which g-prompt is attached.
        # In official code, create as many layers as the total number of layers and select them based on the index
        num_g_prompt = len(self.g_prompt_layer_idx) if self.g_prompt_layer_idx is not None else 0
        self.use_prefix_tune_for_g_prompt = False
        use_prefix_tune_for_g_prompt = False

        self.use_e_prompt = True
        use_e_prompt = True
        self.e_prompt_layer_idx = [0, 1, 2, 3, 4]
        e_prompt_layer_idx = [0, 1, 2, 3, 4]
        num_e_prompt = len(self.e_prompt_layer_idx) if self.e_prompt_layer_idx is not None else 0
        self.use_prefix_tune_for_e_prompt = True
        use_prefix_tune_for_e_prompt = True
        g_prompt_length = 5
        prompt_init = 'uniform'
        same_key_value = False
        # NOTE(review): prompts use a fixed 12 heads while blocks use
        # cfg.num_heads; confirm the two agree for non-default configs.
        num_heads = 12
        use_g_prompt = False

        if use_g_prompt and g_prompt_length is not None and len(g_prompt_layer_idx) != 0:
            if not use_prefix_tune_for_g_prompt:
                g_prompt_shape = (num_g_prompt, g_prompt_length, cfg.embed_dim)
                if prompt_init == 'zero':
                    self.g_prompt = nn.Parameter(torch.zeros(g_prompt_shape))
                elif prompt_init == 'uniform':
                    self.g_prompt = nn.Parameter(torch.randn(g_prompt_shape))
                    nn.init.uniform_(self.g_prompt, -1, 1)
            else:
                if same_key_value:
                    # Was ``embed_dim`` (undefined here -> NameError).
                    g_prompt_shape = (num_g_prompt, 1, g_prompt_length, num_heads, cfg.embed_dim // num_heads)
                    if prompt_init == 'zero':
                        self.g_prompt = nn.Parameter(torch.zeros(g_prompt_shape))
                    elif prompt_init == 'uniform':
                        self.g_prompt = nn.Parameter(torch.randn(g_prompt_shape))
                        nn.init.uniform_(self.g_prompt, -1, 1)
                    # ``repeat`` returns a plain tensor; assigning it over a
                    # registered Parameter raises TypeError, so re-wrap it.
                    # NOTE(review): key and value now share only their initial
                    # values, not storage.
                    self.g_prompt = nn.Parameter(self.g_prompt.detach().repeat(1, 2, 1, 1, 1))
                else:
                    g_prompt_shape = (num_g_prompt, 2, g_prompt_length, num_heads, cfg.embed_dim // num_heads)
                    if prompt_init == 'zero':
                        self.g_prompt = nn.Parameter(torch.zeros(g_prompt_shape))
                    elif prompt_init == 'uniform':
                        self.g_prompt = nn.Parameter(torch.randn(g_prompt_shape))
                        nn.init.uniform_(self.g_prompt, -1, 1)
        else:
            self.g_prompt = None

        embedding_key = 'cls'
        prompt_length = 20
        prompt_pool = True
        prompt_key = True
        top_k = 1
        batchwise_prompt = False
        prompt_key_init = 'uniform'
        pool_size = 10

        if use_e_prompt and e_prompt_layer_idx is not None:
            self.e_prompt = EPrompt(length=prompt_length, embed_dim=cfg.embed_dim, embedding_key=embedding_key,
                                    prompt_init=prompt_init,
                                    prompt_pool=prompt_pool, prompt_key=prompt_key, pool_size=pool_size, top_k=top_k,
                                    batchwise_prompt=batchwise_prompt,
                                    prompt_key_init=prompt_key_init, num_layers=num_e_prompt,
                                    use_prefix_tune_for_e_prompt=use_prefix_tune_for_e_prompt,
                                    num_heads=num_heads, same_key_value=same_key_value)

        # Number of extra tokens prepended (only non-prefix prompts add tokens).
        self.total_prompt_len = 0
        if self.prompt_pool:
            if not self.use_prefix_tune_for_g_prompt:
                self.total_prompt_len += g_prompt_length * len(self.g_prompt_layer_idx)
            if not self.use_prefix_tune_for_e_prompt:
                self.total_prompt_len += prompt_length * top_k * len(self.e_prompt_layer_idx)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        task_id=-1,
        cls_features=None,
        train=None,
        stop=False,
        features_mean=None,
    ):
        """Encode ``source`` and run the prompted transformer blocks.

        Args:
            source: raw input passed to the modality encoder.
            target: unused.
            id: optional sample ids; when given, builds a ``MaskSeed``.
            mode: modality name or member; defaults to ``cfg.supported_modality``.
            features_only: if True, return features (optionally without the
                encoder's extra tokens) instead of the prompt-module dict.
            task_id: when ``use_prompt_mask`` and ``train`` are set, selects
                pool slots ``[task_id * top_k, (task_id + 1) * top_k)``.
            cls_features, features_mean, stop: forwarded to ``self.e_prompt``.
                ``stop=True`` also opens a pdb breakpoint (debugging hook).

        Returns:
            features_only: dict with ``x``, ``padding_mask``, ``layer_results``
            (always empty) and ``mask``. Otherwise the dict returned by
            ``self.e_prompt`` with those same four keys added.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        # extract (unmasked) features using CNN encoder
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)

        # This model does not use ALiBi biases.
        assert masked_alibi_bias is None
        if self.dropout_input is not None:
            x = self.dropout_input(x)

        if stop == True:
            # Opt-in debugging hook.
            import pdb; pdb.set_trace()
        if self.use_g_prompt or self.use_e_prompt:
            if self.use_prompt_mask and train:
                # Restrict prompt selection to the slots owned by task_id.
                start = task_id * self.e_prompt.top_k
                end = (task_id + 1) * self.e_prompt.top_k
                single_prompt_mask = torch.arange(start, end).to(x.device)
                prompt_mask = single_prompt_mask.unsqueeze(0).expand(x.shape[0], -1)
                if end > self.e_prompt.pool_size:
                    prompt_mask = None
            else:
                prompt_mask = None
            g_prompt_counter = -1
            e_prompt_counter = -1
            res = self.e_prompt(x, prompt_mask=prompt_mask, cls_features=cls_features, stop=stop, features_mean=features_mean)
            e_prompt = res['batched_prompt']
        else:
            raise AssertionError

        layer_results = []
        for i, blk in enumerate(self.blocks):
            ab = None
            if i in self.g_prompt_layer_idx:
                if not self.use_prefix_tune_for_g_prompt:
                    raise AssertionError
                g_prompt_counter += 1
                # Prefix tuning, [B, 2, g_prompt_length, num_heads, embed_dim // num_heads]
                idx = torch.tensor([g_prompt_counter] * x.shape[0]).to(x.device)
                g_prompt = self.g_prompt[idx]
                x, _ = blk(
                    x,
                    prompt=g_prompt,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
            elif i in self.e_prompt_layer_idx:
                e_prompt_counter += 1
                if not self.use_prefix_tune_for_e_prompt:
                    raise AssertionError
                # Prefix tuning; one prompt entry per prompted layer, indexed on
                # the leading axis of batched_prompt.
                x, _ = blk(
                    x,
                    prompt=e_prompt[e_prompt_counter],
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
            else:
                x, _ = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )

        if features_only:
            if remove_extra_tokens:
                num_extra = feature_extractor.modality_cfg.num_extra_tokens
                x = x[:, num_extra:]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[:, num_extra:]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        # Extend the prompt module's output dict so its extra entries reach callers.
        res["x"] = x
        res["padding_mask"] = masked_padding_mask
        res["layer_results"] = layer_results
        res["mask"] = encoder_mask
        return res

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return ``forward(..., features_only=True)`` for ``source``."""
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=-1, train=None, cls_features=None):
        """Classify from token 0: adds ``pre_logits`` and ``logits`` to ``forward``'s dict."""
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
            task_id=task_id,
            train=train,
            cls_features=cls_features,
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        x = self.layer_norm(x)
        res["logits"] = self.cls_head(x)
        return res

class Data2VecMultiModel_2_hidepet(nn.Module):
    """Data2vec-style transformer classifier with per-layer prefix prompts.

    The prompt module is selected by ``prompt_type``: ``'hide'`` -> ``EPrompt``,
    ``'continual'`` -> ``ContinualPrompt``, ``'momentum'`` -> ``MomentumPrompt``.
    Its ``batched_prompt`` output is passed as ``prompt=`` to the blocks listed
    in ``e_prompt_layer_idx``. ``naive_classification`` maps token 0 through
    ``layer_norm`` and ``cls_head`` to logits.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the encoder for ``cfg.type``; only the IMAGE modality is supported.

        Raises:
            Exception: if ``cfg.type`` is not IMAGE.
        """
        if cfg.type.value == Modality.IMAGE.value:
            enc_cls = ImageEncoder
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, prompt_type=None, **kwargs):
        """Build encoders, transformer blocks, classifier head and the prompt module.

        Args:
            cfg: model configuration (embedding size, depth, dropouts, ...).
            modalities: iterable of modality members (their ``.name`` selects
                the sub-config in ``cfg.modalities`` and keys the encoder dict).
            skip_ema: accepted for interface compatibility; unused.
            task: forwarded to the modality encoder constructor.
            prompt_type: ``'hide'``, ``'continual'`` or ``'momentum'``; any
                other value creates no prompt module (``forward`` then raises).
            **kwargs: must contain ``args`` with a ``num_classes`` attribute.
                ``dataset`` is accepted but unused.
        """
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            # Transformer block factory; dim/heads default to the config values.
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # extract CNN encoder and CNN decoder from modified data2vec image modality (see image.py)
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            enc = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )
            self.modality_encoders[mod.name] = enc

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate interpolated linearly from start to end over the blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # ---- prompt configuration (hard-coded) ----
        self.prompt_tool = False
        self.head_type = 'token'
        # Task-id prompt masking; only consulted on the 'hide' path of forward().
        self.use_prompt_mask = False

        self.use_e_prompt = True
        use_e_prompt = True
        self.e_prompt_layer_idx = [0, 1, 2, 3, 4]
        e_prompt_layer_idx = [0, 1, 2, 3, 4]
        # One prompt entry per prompted block (5 with the list above).
        num_e_prompt = len(self.e_prompt_layer_idx) if self.e_prompt_layer_idx is not None else 0
        self.use_prefix_tune_for_e_prompt = True
        use_prefix_tune_for_e_prompt = True
        prompt_init = 'uniform'
        embedding_key = 'cls'
        prompt_length = 20
        # Was the *string* 'False', which is truthy; a real bool is intended.
        prompt_key = False
        pool_size = None
        batchwise_prompt = False
        # prompt_pool / top_k / prompt_key_init were used by the 'hide' branch
        # below without being defined (NameError). Values mirror
        # Data2VecMultiModel_2_sprompt.
        # NOTE(review): with pool_size=None a pooled EPrompt may still be
        # unusable; confirm the intended 'hide' configuration.
        prompt_pool = True
        top_k = 1
        prompt_key_init = 'uniform'
        same_key_value = False
        num_heads = 12
        self.prompt_type = prompt_type  # continual for tii

        if use_e_prompt and e_prompt_layer_idx is not None:
            if prompt_type == 'hide':
                self.e_prompt = EPrompt(length=prompt_length, embed_dim=cfg.embed_dim, embedding_key=embedding_key,
                                        prompt_init=prompt_init,
                                        prompt_pool=prompt_pool, prompt_key=prompt_key, pool_size=pool_size, top_k=top_k,
                                        batchwise_prompt=batchwise_prompt,
                                        prompt_key_init=prompt_key_init, num_layers=num_e_prompt,
                                        use_prefix_tune_for_e_prompt=use_prefix_tune_for_e_prompt,
                                        num_heads=num_heads, same_key_value=same_key_value)
            elif prompt_type == 'continual':
                self.e_prompt = ContinualPrompt(length=prompt_length, embed_dim=cfg.embed_dim, num_layers=num_e_prompt,
                                                use_prefix_tune_for_e_prompt=use_prefix_tune_for_e_prompt,
                                                num_heads=num_heads)
            elif prompt_type == 'momentum':
                # NOTE(review): MomentumPrompt is not imported in the module header;
                # confirm it is defined elsewhere in this file.
                self.e_prompt = MomentumPrompt(length=prompt_length, embed_dim=cfg.embed_dim, num_layers=num_e_prompt,
                                               use_prefix_tune_for_e_prompt=use_prefix_tune_for_e_prompt,
                                               num_heads=num_heads)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        task_id=None,
        train=None,
        prompt_id=None,
    ):
        """Encode ``source`` and run the prompted transformer blocks.

        Args:
            source: raw input passed to the modality encoder.
            target: unused.
            id: optional sample ids; when given, builds a ``MaskSeed``.
            mode: modality name or member; defaults to ``cfg.supported_modality``.
            features_only: if True, strip the encoder's extra tokens when
                ``remove_extra_tokens`` is set.
            task_id, train: forwarded to the prompt module (``task_id`` also
                drives the prompt mask on the 'hide' path).
            prompt_id: prompt index forwarded as ``prompt_idx`` on the 'hide'
                path (previously referenced but never defined).

        Returns:
            dict with ``x``, ``padding_mask``, ``layer_results`` (always empty)
            and ``mask``.

        Raises:
            ValueError: if ``self.prompt_type`` is not a supported prompt type.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        # extract (unmasked) features using CNN encoder
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)

        # This model does not use ALiBi biases.
        assert masked_alibi_bias is None
        if self.dropout_input is not None:
            x = self.dropout_input(x)

        if not self.use_e_prompt:
            raise AssertionError

        e_prompt_counter = -1
        if self.prompt_type == 'hide':
            if self.use_prompt_mask and train:
                # Restrict prompt selection to the slots owned by task_id.
                start = task_id * self.e_prompt.top_k
                end = (task_id + 1) * self.e_prompt.top_k
                single_prompt_mask = torch.arange(start, end).to(x.device)
                prompt_mask = single_prompt_mask.unsqueeze(0).expand(x.shape[0], -1)
                if end > self.e_prompt.pool_size:
                    prompt_mask = None
            else:
                prompt_mask = None
            res = self.e_prompt(x, train=train, task_id=task_id, prompt_mask=prompt_mask, prompt_idx=prompt_id)
        elif self.prompt_type in ['continual', 'momentum']:
            res = self.e_prompt(x, train=train)
        else:
            # Previously surfaced as an UnboundLocalError on ``res``.
            raise ValueError(
                f"unsupported prompt_type {self.prompt_type!r}; "
                "expected 'hide', 'continual' or 'momentum'"
            )
        e_prompt = res['batched_prompt']

        layer_results = []
        for i, blk in enumerate(self.blocks):
            ab = None
            if i in self.e_prompt_layer_idx:
                e_prompt_counter += 1
                if self.prompt_type == 'hide':
                    # 'hide' prompts are indexed per layer on axis 2.
                    layer_prompt = e_prompt[:, :, e_prompt_counter]
                else:
                    # Other prompt modules are indexed per layer on axis 0.
                    layer_prompt = e_prompt[e_prompt_counter]
                x, _ = blk(
                    x,
                    prompt=layer_prompt,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
            else:
                x, _ = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )

        # extract features for fine-tuning
        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            "layer_results": layer_results,
            "mask": encoder_mask,
        }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Return ``forward(..., features_only=True)`` for ``source``."""
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None,
        prompt_id=None):
        """Classify from token 0: adds ``pre_logits`` and ``logits`` to ``forward``'s dict.

        ``prompt_id`` is forwarded to ``forward`` (used on the 'hide' path).
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
            task_id=task_id,
            train=train,
            prompt_id=prompt_id,
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        x = self.layer_norm(x)
        res["logits"] = self.cls_head(x)
        return res

    def output_layer_only(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Apply only the classifier head to precomputed features ``source``.

        ``source`` is stored as ``pre_logits``; ``logits`` is
        ``cls_head(layer_norm(source))``. The other arguments are unused.
        """
        res = {}
        res["pre_logits"] = source
        source = self.layer_norm(source)
        res["logits"] = self.cls_head(source)
        return res

    def after_task(self, task_id=-1, device=None):
        """Notify the prompt module that a task ended ('momentum'/'hide' only)."""
        if self.prompt_type in ['momentum', 'hide']:
            self.e_prompt.after_task(task_id=task_id, device=device)

class Data2VecMultiModel_2_hidepet_stage2(nn.Module):
    """Data2vec-style transformer encoder with per-layer expert (prefix) prompts.

    Components built in ``__init__``:
      * one modality encoder per entry of ``modalities`` (``make_modality_encoder``
        only accepts the image modality),
      * ``cfg.depth`` ``AltBlock`` transformer blocks,
      * a LayerNorm + linear classification head applied to token 0,
      * a prompt module ``self.e_prompt`` whose per-layer prompts are passed to
        the blocks whose index is in ``self.e_prompt_layer_idx``.

    ``prompt_type`` selects the prompt module and the matching forward path:
    ``'hide'`` (default) -> ``EPrompt``, ``'continual'`` -> ``ContinualPrompt``,
    ``'momentum'`` -> ``MomentumPrompt``.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the feature extractor for one modality.

        Raises:
            Exception: if ``cfg.type`` is not the image modality.
        """
        if cfg.type.value == Modality.IMAGE.value:
            enc_cls = ImageEncoder
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, prompt_type=None, **kwargs):
        """Construct encoders, transformer blocks, head and prompt module.

        Args:
            cfg: model configuration (embed_dim, depth, num_heads, dropouts, ...).
            modalities: iterable of ``Modality`` members to build encoders for.
            skip_ema: accepted for interface compatibility; not used here.
            task: forwarded to each modality encoder.
            prompt_type: ``'hide'``, ``'continual'`` or ``'momentum'``; ``None``
                means ``'hide'``.
            **kwargs: must contain ``'args'`` with a ``num_classes`` attribute.

        Raises:
            ValueError: if ``prompt_type`` is not one of the supported values.
        """
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Patch-embedding / modality-specific front end (see images.py).
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Linearly increasing stochastic-depth rate across the blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # ---- prompt configuration (fixed for this model) ----
        self.prompt_pool = True
        self.head_type = 'token'
        self.use_prompt_mask = True
        self.use_e_prompt = True
        # Block indices that receive a prompt; prompt layer k feeds the k-th
        # listed block.
        self.e_prompt_layer_idx = [0, 1, 2, 3, 4]
        self.use_prefix_tune_for_e_prompt = True
        self.top_k = 1
        # forward() and after_task() dispatch on this, so it must describe the
        # prompt module actually built below.
        self.prompt_type = 'hide' if prompt_type is None else prompt_type

        prompt_length = 20
        pool_size = 10
        num_e_prompt = len(self.e_prompt_layer_idx)
        # Prefix prompts are split per attention head, so their head count must
        # match the blocks built by make_block (which use cfg.num_heads).
        num_heads = cfg.num_heads
        # NOTE(review): this is the *string* 'False', which is truthy. If
        # EPrompt checks it with ``if prompt_key:`` learnable keys are enabled.
        # Kept as-is because changing it alters the parameter set/checkpoints.
        prompt_key = 'False'

        if self.prompt_type == 'hide':
            self.e_prompt = EPrompt(length=prompt_length, embed_dim=cfg.embed_dim, embedding_key='cls',
                                    prompt_init='uniform',
                                    prompt_pool=True, prompt_key=prompt_key, pool_size=pool_size, top_k=self.top_k,
                                    batchwise_prompt=False,
                                    prompt_key_init='uniform', num_layers=num_e_prompt,
                                    use_prefix_tune_for_e_prompt=self.use_prefix_tune_for_e_prompt,
                                    num_heads=num_heads, same_key_value=False)
        elif self.prompt_type == 'continual':
            self.e_prompt = ContinualPrompt(length=prompt_length, embed_dim=cfg.embed_dim, num_layers=num_e_prompt,
                                            use_prefix_tune_for_e_prompt=self.use_prefix_tune_for_e_prompt,
                                            num_heads=num_heads)
        elif self.prompt_type == 'momentum':
            # NOTE(review): MomentumPrompt is not among this file's top-level
            # imports; confirm it is defined elsewhere in the module.
            self.e_prompt = MomentumPrompt(length=prompt_length, embed_dim=cfg.embed_dim, num_layers=num_e_prompt,
                                           use_prefix_tune_for_e_prompt=self.use_prefix_tune_for_e_prompt,
                                           num_heads=num_heads)
        else:
            raise ValueError(f"unsupported prompt_type {prompt_type!r}")

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        task_id=None,
        train=None,
        prompt_id=None,
        prompt_momentum=None
    ):
        """Encode ``source`` through the prompted transformer.

        Args:
            source: model input for the selected modality encoder.
            target: unused; kept for interface compatibility.
            id: optional sample ids used to seed masking (``MaskSeed``).
            mode: modality name / ``Modality``; defaults to
                ``cfg.supported_modality``.
            padding_mask, mask, precomputed_mask: forwarded to the encoder.
            features_only: return only features (no prompt-module outputs).
            force_remove_masked: drop masked tokens even when features_only.
            remove_extra_tokens: when features_only, slice off the modality's
                ``num_extra_tokens`` leading tokens.
            task_id: current task; used to build the hide-prompt mask while
                training.
            train: training flag forwarded to the prompt module.
            prompt_id: forwarded as ``prompt_idx`` to the hide prompt module.
            prompt_momentum: unused; kept for interface compatibility.

        Returns:
            If ``features_only``: dict with ``x``, ``padding_mask``,
            ``layer_results`` and ``mask``. Otherwise the prompt module's
            output dict extended in place with those same keys.

        Raises:
            AssertionError: if ``self.use_e_prompt`` is False.
            ValueError: if ``self.prompt_type`` is not supported.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        # presumably x is (batch * clone_batch, tokens incl. extra tokens, dim)
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        # This model does not support ALiBi attention biases.
        assert extractor_out.get("alibi_bias", None) is None
        if self.dropout_input is not None:
            x = self.dropout_input(x)

        if not self.use_e_prompt:
            raise AssertionError("this model requires use_e_prompt=True")

        if self.prompt_type == 'hide':
            prompt_mask = None
            if self.use_prompt_mask and train:
                # During training each task is pinned to its own slice of the
                # prompt pool; no mask once the pool is exhausted.
                start = task_id * self.e_prompt.top_k
                end = (task_id + 1) * self.e_prompt.top_k
                if end <= self.e_prompt.pool_size:
                    single_prompt_mask = torch.arange(start, end).to(x.device)
                    prompt_mask = single_prompt_mask.unsqueeze(0).expand(x.shape[0], -1)
            res = self.e_prompt(x, train=train, task_id=task_id, prompt_mask=prompt_mask, prompt_idx=prompt_id)
        elif self.prompt_type in ('continual', 'momentum'):
            res = self.e_prompt(x, train=train)
        else:
            raise ValueError(f"unsupported prompt_type {self.prompt_type!r}")

        e_prompt = res['batched_prompt']

        # Intermediate layer outputs are not collected by this model, so
        # layer_results is always returned empty.
        layer_results = []
        e_prompt_counter = -1
        for i, blk in enumerate(self.blocks):
            if i in self.e_prompt_layer_idx:
                e_prompt_counter += 1
                # The hide prompt module lays out layers on dim 2; the others
                # on dim 0.
                if self.prompt_type == 'hide':
                    layer_prompt = e_prompt[:, :, e_prompt_counter]
                else:
                    layer_prompt = e_prompt[e_prompt_counter]
                x, lr = blk(
                    x,
                    prompt=layer_prompt,
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            else:
                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )

        if features_only:
            if remove_extra_tokens:
                num_extra = feature_extractor.modality_cfg.num_extra_tokens
                x = x[:, num_extra:]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[:, num_extra:]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }

        res["x"] = x
        res["padding_mask"] = masked_padding_mask
        res["layer_results"] = layer_results
        res["mask"] = encoder_mask
        return res

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Run ``forward`` with ``features_only=True`` and return its dict."""
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None, prompt_id=None, prompt_momentum=None):
        """Full forward pass, then classify from token 0.

        Adds ``"pre_logits"`` (token 0 of ``x``) and ``"logits"``
        (``cls_head(layer_norm(pre_logits))``) to the forward output dict.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
            task_id=task_id,
            train=train,
            prompt_id=prompt_id,
            prompt_momentum=prompt_momentum
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        x = self.layer_norm(x)
        res["logits"] = self.cls_head(x)
        return res

    def output_layer_only(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Apply only layer_norm + cls_head to precomputed features ``source``.

        The remaining parameters are ignored (signature compatibility only).
        """
        res = {}
        res["pre_logits"] = source
        source = self.layer_norm(source)
        res["logits"] = self.cls_head(source)
        return res

    def after_task(self, task_id=-1, device=None):
        """Forward the end-of-task hook to prompt modules that have one."""
        if self.prompt_type in ['momentum', 'hide']:
            self.e_prompt.after_task(task_id=task_id, device=device)

class Data2VecMultiModel_2_ranpac(nn.Module):
    """Data2vec-style transformer encoder with a fixed random projection (RanPAC).

    Besides the encoder and a LayerNorm + linear head, the model registers a
    non-trainable buffer ``W_rand`` of shape ``(cfg.embed_dim, rp_dim)`` drawn
    from ``torch.randn``. ``forward_rp`` projects token 0 through it followed by
    a ReLU.

    NOTE(review): a second class with this exact name is defined right after
    this one in the module, so this definition is shadowed at import time.
    The two differ only in what ``forward_rp`` returns. Here it is the forward
    dict with ``"x"`` replaced by the projected features; in the later
    definition it is the projected tensor alone. Consider renaming or
    removing one of them.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the feature extractor for one modality.

        Raises:
            Exception: if ``cfg.type`` is not the image modality.
        """
        if cfg.type.value == Modality.IMAGE.value:
            enc_cls = ImageEncoder
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, rp_dim=None, **kwargs):
        """Construct encoders, blocks, head and the random-projection buffer.

        Args:
            cfg: model configuration.
            modalities: iterable of ``Modality`` members to build encoders for.
            skip_ema: accepted for interface compatibility; not used here.
            task: forwarded to each modality encoder.
            rp_dim: output width of the random projection; required.
            **kwargs: must contain ``'args'`` with a ``num_classes`` attribute.

        Raises:
            ValueError: if ``rp_dim`` is None.
        """
        super().__init__()
        if rp_dim is None:
            raise ValueError("rp_dim must be set to build the random projection W_rand")
        self.cfg = cfg
        self.modalities = modalities
        self.task = task
        self.rp_dim = rp_dim

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Patch-embedding / modality-specific front end (see images.py).
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Linearly increasing stochastic-depth rate across the blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # Buffer (not a Parameter): saved in state_dict, moved with .to(), never trained.
        W_rand = torch.randn(cfg.embed_dim, self.rp_dim)
        self.register_buffer('W_rand', W_rand)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
    ):
        """Encode ``source`` through the plain (unprompted) transformer.

        Returns:
            dict with ``x``, ``padding_mask``, ``layer_results`` (per-block
            outputs, only collected when ``features_only``) and ``mask``.
            When ``features_only`` and ``remove_extra_tokens``, the modality's
            ``num_extra_tokens`` leading tokens are sliced off ``x`` and
            ``padding_mask``.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        # This model does not support ALiBi attention biases.
        assert extractor_out.get("alibi_bias", None) is None
        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for blk in self.blocks:
            x, lr = blk(
                x,
                padding_mask=masked_padding_mask,
                alibi_bias=None,
            )
            if features_only:
                layer_results.append(lr)

        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            "layer_results": layer_results,
            "mask": encoder_mask,
        }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Run ``forward`` with ``features_only=True`` and return its dict."""
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, **kwargs):
        """Full forward pass, then classify from token 0.

        Extra keyword arguments are accepted and ignored so callers written
        for the prompted models can be reused.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        x = self.layer_norm(x)
        res["logits"] = self.cls_head(x)
        return res

    def forward_rp(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Project token 0 through ``W_rand`` + ReLU.

        ``task_id`` and ``train`` are accepted for interface compatibility and
        ignored.

        Returns:
            The forward dict with ``"pre_logits"`` set to token 0 and ``"x"``
            replaced by ``relu(token0 @ W_rand)``.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        res["x"] = F.relu(x @ self.W_rand)
        return res

class Data2VecMultiModel_2_ranpac(nn.Module):
    """Data2vec-style transformer encoder with a fixed random projection (RanPAC).

    Besides the encoder and a LayerNorm + linear head, the model registers a
    non-trainable buffer ``W_rand`` of shape ``(cfg.embed_dim, rp_dim)`` drawn
    from ``torch.randn``. ``forward_rp`` returns ``relu(token0 @ W_rand)``.

    NOTE(review): a class with this exact name is also defined just above.
    Because this definition comes later, it is the one bound to the name at
    import time.
    """

    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the feature extractor for one modality.

        Raises:
            Exception: if ``cfg.type`` is not the image modality.
        """
        if cfg.type.value == Modality.IMAGE.value:
            enc_cls = ImageEncoder
        else:
            raise Exception(f"unsupported modality {cfg.type}")

        return enc_cls(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, rp_dim=None, **kwargs):
        """Construct encoders, blocks, head and the random-projection buffer.

        Args:
            cfg: model configuration.
            modalities: iterable of ``Modality`` members to build encoders for.
            skip_ema: accepted for interface compatibility; not used here.
            task: forwarded to each modality encoder.
            rp_dim: output width of the random projection; required.
            **kwargs: must contain ``'args'`` with a ``num_classes`` attribute.

        Raises:
            ValueError: if ``rp_dim`` is None.
        """
        super().__init__()
        if rp_dim is None:
            raise ValueError("rp_dim must be set to build the random projection W_rand")
        self.cfg = cfg
        self.modalities = modalities
        self.task = task
        self.rp_dim = rp_dim

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Patch-embedding / modality-specific front end (see images.py).
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Linearly increasing stochastic-depth rate across the blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.num_updates = 0

        norm_layer_before_mlp = partial(nn.LayerNorm, eps=1e-6)
        self.layer_norm = norm_layer_before_mlp(cfg.embed_dim)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        # Buffer (not a Parameter): saved in state_dict, moved with .to(), never trained.
        W_rand = torch.randn(cfg.embed_dim, self.rp_dim)
        self.register_buffer('W_rand', W_rand)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
    ):
        """Encode ``source`` through the plain (unprompted) transformer.

        Returns:
            dict with ``x``, ``padding_mask``, ``layer_results`` (per-block
            outputs, only collected when ``features_only``) and ``mask``.
            When ``features_only`` and ``remove_extra_tokens``, the modality's
            ``num_extra_tokens`` leading tokens are sliced off ``x`` and
            ``padding_mask``.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        # This model does not support ALiBi attention biases.
        assert extractor_out.get("alibi_bias", None) is None
        if self.dropout_input is not None:
            x = self.dropout_input(x)

        layer_results = []
        for blk in self.blocks:
            x, lr = blk(
                x,
                padding_mask=masked_padding_mask,
                alibi_bias=None,
            )
            if features_only:
                layer_results.append(lr)

        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            "layer_results": layer_results,
            "mask": encoder_mask,
        }

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        """Run ``forward`` with ``features_only=True`` and return its dict."""
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, **kwargs):
        """Full forward pass, then classify from token 0.

        Extra keyword arguments are accepted and ignored so callers written
        for the prompted models can be reused.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        x = res["x"][:, 0]
        res["pre_logits"] = x
        x = self.layer_norm(x)
        res["logits"] = self.cls_head(x)
        return res

    def forward_rp(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, train=None):
        """Return ``relu(token0 @ W_rand)`` for ``source``.

        ``task_id`` and ``train`` are accepted for interface compatibility and
        ignored. Only the projected tensor is returned, not the forward dict.
        """
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
        )
        cls_token = res["x"][:, 0]
        return F.relu(cls_token @ self.W_rand)

class Data2VecMultiModel_2_ranpac_lora(nn.Module):
    def make_modality_encoder(
        self,
        cfg: D2vModalityConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases,
        task,
    ) -> ModalitySpecificEncoder:
        """Build the modality-specific encoder selected by ``cfg.type``.

        Only the image modality is supported. ``cfg.type`` is matched by
        ``.value`` rather than by identity, so any enum member whose value
        equals ``Modality.IMAGE.value`` is accepted.

        Raises:
            ValueError: if ``cfg.type`` is not the image modality. (ValueError
                is a subclass of Exception, so ``except Exception`` callers
                still catch it.)
        """
        if cfg.type.value != Modality.IMAGE.value:
            raise ValueError(f"unsupported modality {cfg.type}")

        return ImageEncoder(
            cfg,
            embed_dim,
            make_block,
            norm_layer,
            layer_norm_first,
            alibi_biases,
            task,
        )

    def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None, rp_dim=None, **kwargs):
        """Build the encoders, Transformer stack, head, LoRA modules and RanPAC projection.

        Args:
            cfg: model configuration. Fields read here include ``modalities``,
                ``type``, ``max_length``, ``embed_dim``, ``num_heads``,
                ``depth``, the dropout / drop-path rates, ``lora_depth`` and
                ``lora_rank``.
            modalities: iterable of modality objects; only ``.name`` is read
                (presumably ``Modality`` members). One encoder is built per entry.
            skip_ema: accepted for signature compatibility; unused
                (``self.ema`` is always None).
            task: forwarded to every modality encoder and stored on ``self.task``.
            rp_dim: width of the random projection. If it is None or equals
                ``cfg.embed_dim``, no projection matrix is created and
                ``W_rand`` is registered as a ``None`` buffer.
            **kwargs: must contain ``args`` with a ``num_classes`` attribute,
                which sizes ``cls_head``. Other keys are ignored.
        """
        super().__init__()
        self.cfg = cfg
        self.modalities = modalities
        self.task = task

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
        )

        def make_block(drop_path, dim=None, heads=None):
            # Block factory shared by the modality encoders and the main stack.
            return AltBlock(
                cfg.embed_dim if dim is None else dim,
                cfg.num_heads if heads is None else heads,
                cfg.mlp_ratio,
                qkv_bias=True,
                drop=cfg.encoder_dropout,
                attn_drop=cfg.attention_dropout,
                mlp_drop=cfg.activation_dropout,
                post_mlp_drop=cfg.post_mlp_drop,
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.layer_norm_first,
                ffn_targets=not cfg.end_of_block_targets,
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()

        # Build one encoder per modality. Each per-modality config is updated
        # in place with the global ``type`` and ``max_length``.
        for mod in self.modalities:
            mod_cfg = getattr(cfg.modalities, mod.name.lower())
            mod_cfg.type = cfg.type
            mod_cfg.max_length = cfg.max_length
            self.modality_encoders[mod.name] = self.make_modality_encoder(
                mod_cfg,
                cfg.embed_dim,
                make_block,
                make_layer_norm,
                cfg.layer_norm_first,
                self.alibi_biases,
                task,
            )

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale
        self.utterance_level = cfg.utterance_level

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # Drop-path rate grows linearly from start to end across the blocks.
        dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.num_updates = 0

        self.layer_norm = partial(nn.LayerNorm, eps=1e-6)(cfg.embed_dim)
        self.cls_head = nn.Linear(cfg.embed_dim, kwargs['args'].num_classes)

        self.embed_dim = cfg.embed_dim
        self.lora_depth = cfg.lora_depth
        self.rank = cfg.lora_rank
        # forward() applies LoRA to the first ``lora_depth`` blocks when it is
        # positive, or to the last ``|lora_depth|`` blocks when it is negative.
        self.lora_layer = ContinualLora(dim=cfg.embed_dim, rank=self.rank, depth=abs(self.lora_depth))
        # NOTE(review): forward() looks up lora_layer_for_task_{k} for k up to
        # task_id, but only task 1's module is created here. Presumably later
        # ones are attached elsewhere; confirm.
        self.lora_layer_for_task_1 = ContinualLora(dim=cfg.embed_dim, rank=self.rank, depth=abs(self.lora_depth))

        self.rp_dim = rp_dim
        # The original test ``embed_dim != rp_dim`` was true for rp_dim=None,
        # so it called torch.randn(embed_dim, None) and crashed. Treat a
        # missing rp_dim like "no projection" instead.
        if self.rp_dim is None or cfg.embed_dim == self.rp_dim:
            W_rand = None
            print("W_rand is None!!")
        else:
            W_rand = torch.randn(cfg.embed_dim, self.rp_dim)
        self.register_buffer('W_rand', W_rand)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        task_id=None,
        drs_lora=0,
    ):
        """Encode ``source`` with the modality encoder and the LoRA-augmented stack.

        Args:
            source: raw input passed to the modality-specific feature extractor.
            target: unused.
            id: optional sample ids. When given, a ``MaskSeed`` built from
                ``cfg.seed``, ``self.num_updates`` and ``id`` is passed to the
                extractor.
            mode: modality name or ``Modality`` member. Defaults to
                ``cfg.supported_modality``.
            padding_mask, mask, precomputed_mask: forwarded to the extractor.
            features_only: if True, masked tokens are kept (unless
                ``force_remove_masked``) and ``clone_batch=1`` is used.
            remove_extra_tokens: only with ``features_only``. Strips the
                leading ``num_extra_tokens`` positions from ``x`` and from the
                padding mask.
            task_id: selects which LoRA modules are active in LoRA blocks
                (see ``_lora_block_kwargs``). Must not be None when at least
                one block is a LoRA block.
            drs_lora: when equal to 1, LoRA modules are passed via the block's
                ``drs_lora`` keyword instead of ``lora``.

        Returns:
            dict with ``"x"``, ``"padding_mask"``, ``"layer_results"`` (one
            entry per block) and ``"mask"`` (the extractor's encoder mask).

        Raises:
            TypeError: if a LoRA block is reached while ``task_id`` is None.
            AssertionError: if the extractor returns an alibi bias, or if an
                unsupported ``task_id`` reaches a LoRA block on the non-drs path.
        """
        if mode is None:
            assert self.cfg.supported_modality is not None
            mode = self.cfg.supported_modality

        if isinstance(mode, Modality):
            mode = mode.name

        feature_extractor = self.modality_encoders[mode]

        mask_seeds = None
        if id is not None:
            mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)

        # Extract (possibly masked) patch features with the modality encoder.
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.clone_batch if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        # ALiBi is not used in this model; the extractor must not produce a bias.
        assert extractor_out.get("alibi_bias", None) is None

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        num_blocks = len(self.blocks)
        layer_results = []
        for i, blk in enumerate(self.blocks):
            if self._is_lora_block(i, num_blocks):
                # depth_id indexes the LoRA module's per-layer weights; it counts
                # from the top of the stack when lora_depth is negative.
                depth_id = i if self.lora_depth > 0 else num_blocks - i - 1
                x, lr = blk(
                    x,
                    **self._lora_block_kwargs(task_id, drs_lora),
                    depth_id=depth_id,
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            else:
                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=None,
                )
            layer_results.append(lr)

        if features_only and remove_extra_tokens:
            num_extra = feature_extractor.modality_cfg.num_extra_tokens
            x = x[:, num_extra:]
            if masked_padding_mask is not None:
                masked_padding_mask = masked_padding_mask[:, num_extra:]

        return {
            "x": x,
            "padding_mask": masked_padding_mask,
            "layer_results": layer_results,
            "mask": encoder_mask,
        }

    def _is_lora_block(self, i, num_blocks):
        """Return True if block ``i`` gets LoRA (first N blocks if lora_depth>0, last |N| if <0)."""
        return (self.lora_depth > 0 and i < self.lora_depth) or (
            self.lora_depth < 0 and i >= num_blocks + self.lora_depth
        )

    def _lora_block_kwargs(self, task_id, drs_lora):
        """Return the LoRA keyword arguments passed to a LoRA-enabled block."""
        if task_id is None:
            raise TypeError(
                "task_id must be provided when LoRA blocks are active "
                f"(lora_depth={self.lora_depth})"
            )
        if drs_lora == 1:
            # DRS path: the base LoRA plus the modules of earlier tasks (1 .. task_id-1).
            modules = [self.lora_layer]
            if task_id > 1:
                modules.extend(
                    getattr(self, f"lora_layer_for_task_{k}") for k in range(1, task_id)
                )
            return {"drs_lora": modules}
        if task_id == 0:
            return {"lora": self.lora_layer}
        if task_id >= 1:
            # The base LoRA plus the modules of tasks 1 .. task_id (inclusive).
            modules = [self.lora_layer]
            modules.extend(
                getattr(self, f"lora_layer_for_task_{k}") for k in range(1, task_id + 1)
            )
            return {"lora": modules}
        if task_id == -1:
            # Block is called without any LoRA module.
            return {}
        raise AssertionError(f"unsupported task_id {task_id!r}")

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def naive_classification(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, drs_lora=0, task_id=None, ifaug=False, **kwargs):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
            task_id=task_id,
            drs_lora=drs_lora,
        )
        # import pdb; pdb.set_trace()
        x = res["x"][:, 0]
        res["pre_logits"] = x
        # import pdb; pdb.set_trace()
        res["layer_results"] = [_[:, 0] for _ in res["layer_results"]]
        x = self.layer_norm(x)
        res["logits"] = self.cls_head(x)
        return res

    def forward_rp(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=False, task_id=None, drs_lora=None, train=None):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=False,
            remove_extra_tokens=remove_extra_tokens,
            task_id=task_id,
            drs_lora=drs_lora
        )
        # import pdb; pdb.set_trace()
        x = res["x"][:, 0]
        res["pre_logits"] = x
        if self.W_rand is not None:
            x = F.relu(x @ self.W_rand)
        res["x"] = x
        return res