# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Optional, Tuple
from dataclasses import dataclass
import math
import torch
from torch import nn
import torch.nn.functional as F

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import (
    ParallelEmbedding,
    RowParallelLinear,
    ColumnParallelLinear,
)
from ..components import RMSNorm
from flash_attn import flash_attn_func

import open_clip

# for honeybee
import torch.utils.checkpoint
from tqdm import tqdm
from transformers.models.auto import AutoModelForCausalLM, AutoConfig
from ..honeybee.honeybee.configuration_honeybee import HoneybeeConfig
from ..honeybee.honeybee.visual_encoders import build_encoder
from ..honeybee.pipeline.utils import check_local_file
from ..honeybee.honeybee.projectors import (
    CAbstractor,
    DAbstractor,
    MLPProjector,
    HoneybeeVisualProjectorModel, # Resampler
)
from ..honeybee.honeybee.common_layers import HoneybeePreTrainedModel

# Weight initializer passed as `init_method` to every model-parallel linear layer in this file.
default_linear_init = nn.init.xavier_uniform_


@dataclass
class ModelArgs:
    """Hyper-parameters for the LLaMA-style building blocks in this module.

    ``vocab_size`` is a placeholder (-1) and must be set from the tokenizer
    before an embedding table is built from it.
    """

    # model width, depth and number of attention heads
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    # placeholder, defined later by the tokenizer
    vocab_size: int = -1
    # SwiGLU hidden size is rounded up to a multiple of this value
    multiple_of: int = 256
    # epsilon used by RMSNorm
    norm_eps: float = 1e-5

    # batch / sequence limits (max_seq_len also sizes the rotary table)
    max_batch_size: int = 32
    max_seq_len: int = 2048


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding (RoPE) table.

    Returns a complex64 tensor of shape ``(end, dim // 2)`` whose entry
    ``[t, k]`` equals ``exp(1j * t * theta ** (-2k / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = 1.0 / theta ** exponents
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # unit-magnitude complex numbers with the given phase
    return torch.polar(torch.ones_like(angles), angles)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``; the view keeps
    those two axes and sets every other axis of ``x`` to size 1.
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    last_axis = ndim - 1
    target_shape = []
    for axis, size in enumerate(x.shape):
        target_shape.append(size if axis in (1, last_axis) else 1)
    return freqs_cis.view(*target_shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key features by the complex factors in ``freqs_cis`` (RoPE).

    The last dimension of ``xq``/``xk`` is read as consecutive (real, imag)
    pairs in float32. The pairs are multiplied by ``freqs_cis``, which is
    broadcast against the query layout via ``reshape_for_broadcast``. They are
    then flattened back from dim 3 onward and cast to the input dtypes.
    """
    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sine/cosine positional encodings.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, where ``w_k = 10000 ** (-2k / dim)``. The table is
    precomputed for ``max_seq_len`` positions and stored with a leading batch
    axis of size 1.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.embedding = self._create_positional_embeddings()

    def _create_positional_embeddings(self):
        """Return the ``(1, max_seq_len, dim)`` float32 encoding table."""
        position = torch.arange(0, self.max_seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim, 2).float() * (-math.log(10000.0) / self.dim)
        )
        embeddings = torch.zeros(self.max_seq_len, self.dim)
        embeddings[:, 0::2] = torch.sin(position * div_term)
        # For odd `dim` there is one fewer odd column than there are frequencies,
        # so drop the last frequency for the cosine half.
        embeddings[:, 1::2] = torch.cos(position * div_term[: self.dim // 2])
        return embeddings.unsqueeze(0)  # Add batch dimension

    def forward(self, x):
        """Return encodings for the first ``x.size(1)`` positions on ``x``'s device.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        return self.embedding[:, :seq_len, :].to(x.device)
    
class Attention(nn.Module):
    """Multi-head self-attention with model-parallel projections and FlashAttention.

    ``wq``/``wk``/``wv`` are column-parallel and ``wo`` is row-parallel, so each
    rank owns ``n_heads // model_parallel_world_size`` heads. An optional
    key/value cache (see ``allocate_kv_cache``) enables incremental decoding.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        world_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // world_size
        self.head_dim = args.dim // args.n_heads
        proj_dim = args.n_heads * self.head_dim

        def _column_parallel():
            return ColumnParallelLinear(
                args.dim,
                proj_dim,
                bias=False,
                gather_output=False,
                init_method=default_linear_init,
            )

        # Keep the construction order wq, wk, wv, wo so random init is unchanged.
        self.wq = _column_parallel()
        self.wk = _column_parallel()
        self.wv = _column_parallel()
        self.wo = RowParallelLinear(
            proj_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=default_linear_init,
        )

        self.flash = True
        self.k_cache, self.v_cache = None, None

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
        """Attend over ``x`` (and cached keys/values when a cache is allocated).

        ``mask`` is only tested for ``None``: any non-None value turns on causal
        masking inside ``flash_attn_func``. ``prompt`` is unused.
        """
        bsz, seqlen, _ = x.shape
        head_shape = (bsz, seqlen, self.n_local_heads, self.head_dim)
        xq = self.wq(x).view(*head_shape)
        xk = self.wk(x).view(*head_shape)
        xv = self.wv(x).view(*head_shape)

        if freqs_cis is not None:
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        cache_ready = self.k_cache is not None and self.v_cache is not None
        if cache_ready:
            # move the cache to the activations' device/dtype (no-op once it matches)
            self.k_cache = self.k_cache.to(xk)
            self.v_cache = self.v_cache.to(xv)
            stop = start_pos + seqlen
            self.k_cache[:bsz, start_pos:stop] = xk
            self.v_cache[:bsz, start_pos:stop] = xv
            keys = self.k_cache[:bsz, :stop]
            values = self.v_cache[:bsz, :stop]
        else:
            keys, values = xk, xv

        attended = flash_attn_func(
            xq, keys, values, dropout_p=0.0, causal=mask is not None)
        return self.wo(attended.contiguous().view(bsz, seqlen, -1))

    def allocate_kv_cache(self, max_batch_size: int, max_seq_len: int) -> None:
        """(Re)allocate uninitialised key/value caches if their shape changed."""
        shape = (max_batch_size, max_seq_len, self.n_local_heads, self.head_dim)
        if self.k_cache is None or self.k_cache.size() != shape:
            self.k_cache = torch.empty(shape)
        if self.v_cache is None or self.v_cache.size() != shape:
            self.v_cache = torch.empty(shape)

    def destroy_kv_cache(self) -> None:
        """Drop both caches."""
        self.k_cache, self.v_cache = None, None


class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries from ``x_q``, keys/values from ``x_kv``.

    The projections are split the same way as in ``Attention``. The optional
    key/value cache stores the projected ``x_kv`` states.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        world_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // world_size
        self.head_dim = args.dim // args.n_heads
        proj_dim = args.n_heads * self.head_dim

        def _column_parallel():
            return ColumnParallelLinear(
                args.dim,
                proj_dim,
                bias=False,
                gather_output=False,
                init_method=default_linear_init,
            )

        # Keep the construction order wq, wk, wv, wo so random init is unchanged.
        self.wq = _column_parallel()
        self.wk = _column_parallel()
        self.wv = _column_parallel()
        self.wo = RowParallelLinear(
            proj_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=default_linear_init,
        )

        self.flash = True
        self.k_cache, self.v_cache = None, None

    def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
        """Attend from ``x_q`` (e.g. learnable queries) to ``x_kv``.

        ``mask`` is only tested for ``None`` (non-None enables causal masking in
        ``flash_attn_func``). ``prompt`` is unused.
        """
        bsz, seqlen_q, _ = x_q.shape
        _, seqlen_kv, _ = x_kv.shape
        q_shape = (bsz, seqlen_q, self.n_local_heads, self.head_dim)
        kv_shape = (bsz, seqlen_kv, self.n_local_heads, self.head_dim)
        xq = self.wq(x_q).view(*q_shape)
        xk = self.wk(x_kv).view(*kv_shape)
        xv = self.wv(x_kv).view(*kv_shape)

        if freqs_cis is not None:
            # TODO: positional encoding for cross-attention needs more work.
            # apply_rotary_emb broadcasts freqs_cis against the query length,
            # so this only works when seqlen_q == seqlen_kv.
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        cache_ready = self.k_cache is not None and self.v_cache is not None
        if cache_ready:
            # move the cache to the activations' device/dtype (no-op once it matches)
            self.k_cache = self.k_cache.to(xk)
            self.v_cache = self.v_cache.to(xv)
            stop = start_pos + seqlen_kv
            self.k_cache[:bsz, start_pos:stop] = xk
            self.v_cache[:bsz, start_pos:stop] = xv
            keys = self.k_cache[:bsz, :stop]
            values = self.v_cache[:bsz, :stop]
        else:
            keys, values = xk, xv

        attended = flash_attn_func(
            xq, keys, values, dropout_p=0.0, causal=mask is not None)
        return self.wo(attended.contiguous().view(bsz, seqlen_q, -1))

    def allocate_kv_cache(self, max_batch_size: int, max_seq_len: int) -> None:
        """(Re)allocate uninitialised key/value caches if their shape changed."""
        shape = (max_batch_size, max_seq_len, self.n_local_heads, self.head_dim)
        if self.k_cache is None or self.k_cache.size() != shape:
            self.k_cache = torch.empty(shape)
        if self.v_cache is None or self.v_cache.size() != shape:
            self.v_cache = torch.empty(shape)

    def destroy_kv_cache(self) -> None:
        """Drop both caches."""
        self.k_cache, self.v_cache = None, None

class FeedForward(nn.Module):
    """SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))`` with model-parallel linears.

    The requested ``hidden_dim`` is scaled by 2/3 and then rounded up to a
    multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
    ):
        super().__init__()
        scaled = int(2 * hidden_dim / 3)
        inner = multiple_of * ((scaled + multiple_of - 1) // multiple_of)

        # construction order w1, w2, w3 preserved (affects random init)
        self.w1 = ColumnParallelLinear(
            dim, inner, bias=False, gather_output=False, init_method=default_linear_init,
        )
        self.w2 = RowParallelLinear(
            inner, dim, bias=False, input_is_parallel=True, init_method=default_linear_init,
        )
        self.w3 = ColumnParallelLinear(
            dim, inner, bias=False, gather_output=False, init_method=default_linear_init,
        )

    def _silu_gating(self, x, y):
        """Return ``silu(x) * y``."""
        gate = F.silu(x)
        return gate * y

    def forward(self, x):
        gated = self._silu_gating(self.w1(x), self.w3(x))
        return self.w2(gated)


class TransformerBlock(nn.Module):
    """Pre-norm LLaMA decoder layer.

    Computes ``h = x + attention(attention_norm(x))`` and then
    ``h + feed_forward(ffn_norm(h))``.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        # submodule construction order preserved (affects random init)
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def _forward_ffn(self, h):
        residual = h
        return residual + self.feed_forward(self.ffn_norm(h))

    def _forward_attention(self, x, start_pos, freqs_cis, mask, prompt):
        normed = self.attention_norm(x)
        return x + self.attention.forward(normed, start_pos, freqs_cis, mask, prompt)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
        attended = self._forward_attention(x, start_pos, freqs_cis, mask, prompt)
        return self._forward_ffn(attended)


class QformerBlock(nn.Module):
    """Q-Former style block with a fixed width of 1024 and 8 heads.

    It applies self-attention over the queries, then cross-attention from the
    queries to ``x_kv``, then a SwiGLU feed-forward, each with a residual
    connection. Remaining settings come from ``ModelArgs`` defaults. Note that
    ``attention_norm`` is shared by the self-attention input and both
    cross-attention inputs.
    """

    def __init__(self):
        super().__init__()
        args = ModelArgs(dim=1024, n_heads=8)
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = self.dim // self.n_heads
        # submodule construction order preserved (affects random init)
        self.attention = Attention(args)
        self.cross_attention = CrossAttention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def _forward_ffn(self, h):
        residual = h
        return residual + self.feed_forward(self.ffn_norm(h))

    def _forward_attention(self, x, start_pos, freqs_cis, mask, prompt):
        normed = self.attention_norm(x)
        return x + self.attention.forward(normed, start_pos, freqs_cis, mask, prompt)

    def _forward_cross_attention(self, x_q, x_kv, start_pos, freqs_cis, mask, prompt):
        normed_q = self.attention_norm(x_q)
        normed_kv = self.attention_norm(x_kv)
        return x_q + self.cross_attention.forward(normed_q, normed_kv, start_pos, freqs_cis, mask, prompt)

    def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
        # 1) self-attention among the learnable queries
        queries = self._forward_attention(x_q, start_pos, freqs_cis, mask, prompt)
        # 2) cross-attention from the queries to x_kv, 3) feed-forward
        fused = self._forward_cross_attention(queries, x_kv, start_pos, freqs_cis, mask, prompt)
        return self._forward_ffn(fused)
    
class Mlp(nn.Module):
    """Two-layer MLP (Linear -> activation -> Linear), as in ViT / MLP-Mixer.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are not given (or are falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features

        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class Transformer(nn.Module):

    def build_projector(self, config: HoneybeeConfig):
        """Construct the visual projector ("abstractor") named by the projector config.

        ``config.projector_config.projector_type`` selects one of "mlp",
        "resampler", "c-abs" or "d-abs". The projector is freshly initialised
        here; pretrained weights are loaded afterwards in ``__init__``.
        """
        proj_config = config.projector_config
        proj_type = proj_config.projector_type
        num_input_tokens = self.vision_model.get_num_tokens()

        projector_classes = {
            "mlp": MLPProjector,
            "resampler": HoneybeeVisualProjectorModel,
            "c-abs": CAbstractor,
            "d-abs": DAbstractor,
        }
        projector_cls = projector_classes[proj_type]
        abstractor = projector_cls(proj_config, num_input_tokens=num_input_tokens)

        if proj_type == "d-abs":
            # deformable attention only supports fp32
            abstractor.to(torch.float)

        return abstractor

    def build_language_model(self, config: HoneybeeConfig):
        """Load the causal LM named by ``config.lm_config.pretrained_lm_name_or_path``.

        The model is first loaded with ``attn_implementation="flash_attention_2"``.
        If that raises, it is loaded again with the default attention
        implementation. The reason for the fallback is printed instead of being
        silently discarded.

        To build the architecture without pretrained weights, use
        ``AutoConfig.from_pretrained`` + ``AutoModelForCausalLM.from_config``.
        """
        lm_local_files_only, lm_file_name = check_local_file(
            config.lm_config.pretrained_lm_name_or_path
        )

        try:
            return AutoModelForCausalLM.from_pretrained(
                lm_file_name,
                # when True, never download from the Hugging Face Hub; a missing model raises
                local_files_only=lm_local_files_only,
                attn_implementation="flash_attention_2",
            )
        except Exception as err:  # deliberate best-effort fallback; any failure retries without FA2
            print(
                f"Loading {lm_file_name} with flash_attention_2 failed ({err!r}); "
                "falling back to the default attention implementation."
            )
            return AutoModelForCausalLM.from_pretrained(
                lm_file_name,
                local_files_only=lm_local_files_only,
            )
    
    def __init__(
        self,
        config: HoneybeeConfig,
        params: ModelArgs,
        honeybee_ckpt_path: str = "./LLM_ckpt/honeybee/7B-C-Abs-M256/last/pytorch_model.bin",
        num_action_classes: int = 27,
    ):
        """Build the Honeybee-based multimodal model.

        Args:
            config: Honeybee config (vision encoder, projector and LM settings).
            params: LLaMA-style hyper-parameters (embedding size, vocab, rotary table size).
            honeybee_ckpt_path: pretrained Honeybee checkpoint whose ``vision_model.*``,
                ``abstractor.*`` and ``language_model.*`` entries are loaded (strict=False).
            num_action_classes: output size of ``output_action`` (27 is the MM-Fi setting).
        """
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        # Trainable token embeddings (as in OneLLM); overwritten below with the LM's
        # `model.embed_tokens.weight`.
        self.tok_embeddings = ParallelEmbedding(
            params.vocab_size, params.dim, init_method=nn.init.normal_,
        )

        # ---- Honeybee components (construction order preserved; it affects random init) ----
        print("Build vision model ...")
        self.vision_model = build_encoder(config.vision_config)

        def _set_hf_initialized(module):
            module._is_hf_initialized = True
        self.vision_model.apply(_set_hf_initialized)

        print("Build projector ...")
        self.proj_type = config.projector_config.projector_type
        self.abstractor = self.build_projector(config)

        print("Build LM ...")
        self.language_model = self.build_language_model(config)

        print("Load HoneyBee pretrained weights...")
        # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
        honeybee_state_dict = torch.load(honeybee_ckpt_path, map_location="cpu")
        vision_model_state_dict = self._strip_first_key_component(honeybee_state_dict, "vision_model")
        abstractor_state_dict = self._strip_first_key_component(honeybee_state_dict, "abstractor")
        llm_state_dict = self._strip_first_key_component(honeybee_state_dict, "language_model")

        vis_msg = self.vision_model.load_state_dict(vision_model_state_dict, strict=False)
        abs_msg = self.abstractor.load_state_dict(abstractor_state_dict, strict=False)
        llm_msg = self.language_model.load_state_dict(llm_state_dict, strict=False)

        # Initialise the standalone token embedding from the LM's input embedding table.
        # NOTE(review): ParallelEmbedding shards the embedding dim across model-parallel
        # ranks; this full-table copy presumably assumes a world size of 1 — confirm.
        with torch.no_grad():
            self.tok_embeddings.weight.copy_(llm_state_dict['model.embed_tokens.weight'])
        self.tok_embeddings.weight.requires_grad = True

        # missing / unexpected keys reported by the strict=False loads
        print(vis_msg)
        print(abs_msg)
        print(llm_msg)

        # Freeze the vision encoder (parameters cast to float16) and the language model.
        for param in self.vision_model.parameters():
            param.requires_grad = False
            param.data = param.data.half()
        for param in self.language_model.parameters():
            param.requires_grad = False

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)

        # Action-classification head applied to mean-pooled prefix hidden states.
        self.output_action = ColumnParallelLinear(
            params.dim, num_action_classes, bias=False, init_method=default_linear_init,
        )

        # Complex rotary table (RoPE) for twice max_seq_len positions.
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

        # Width of the CLIP vision encoder (channels of its patch embedding).
        clip_width = self.vision_model.vision_model.embeddings.patch_embedding.out_channels
        self.clip_positional_embeddings = SinusoidalPositionalEmbedding(dim=clip_width)

        # Per-modality adapters:
        #   conv1[modal]      tokenizer mapping raw non-image inputs to clip_width features
        #   clip_proj1[modal] Linear + LayerNorm applied after the shared CLIP encoder
        #   start/end_tag     learnable boundary embeddings around the modality tokens
        self.conv1 = nn.ModuleDict()
        self.clip_proj1 = nn.ModuleDict()
        self.clip_proj2 = nn.ModuleDict()  # currently left empty
        self.start_tag = nn.ParameterDict()
        self.end_tag = nn.ParameterDict()

        self.modals = ['video', 'depth', 'wifi', 'mmwave', 'lidar', 'infra', 'rfid']

        for modal in self.modals:
            # TODO: ensure the same modality from different datasets is preprocessed consistently.
            # video / depth / infra reuse the CLIP patch embedding, so they get no conv1 adapter.
            if modal == 'lidar':
                from model.lib.point_utils import MMFiPointPatchEmbed
                self.conv1[modal] = MMFiPointPatchEmbed(
                    in_channels=3, channels=clip_width, sample_number=1024)
            elif modal == 'mmwave':
                from model.lib.point_utils import MMFiPointPatchEmbed
                self.conv1[modal] = MMFiPointPatchEmbed(
                    in_channels=5, channels=clip_width, sample_number=64)
            elif modal == 'wifi':
                self.conv1[modal] = nn.Conv2d(3, clip_width, kernel_size=(3, 3), stride=(2, 2))
            elif modal == 'rfid':
                # rfid input reportedly has shape [23, 148]; a 1x1 conv lifts its 23 channels
                self.conv1[modal] = nn.Conv2d(23, clip_width, kernel_size=1, stride=1)

            self.clip_proj1[modal] = nn.Sequential(
                nn.Linear(clip_width, clip_width),
                nn.LayerNorm(clip_width))

            self.start_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
            self.end_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
        # TODO: Freeze some parameters here: the LLM for pretraining and the projection for finetuning.

    @staticmethod
    def _strip_first_key_component(state_dict, marker):
        """Return ``{key minus its first dotted component: value}`` for keys containing ``marker``.

        E.g. ``"vision_model.vision_model.x"`` -> ``"vision_model.x"``. Matching is by
        substring, and later duplicates overwrite earlier ones.
        """
        return {
            key.partition(".")[2]: value
            for key, value in state_dict.items()
            if marker in key
        }

    # @torch.no_grad()

    def clip_encode_image(self, x, modal='video'):
        """Encode token sequences with the (frozen) CLIP vision transformer.

        Args:
            x: batch-first tokens ``(N, L, width)`` produced by ``encode_image``
                (no CLS token is prepended).
            modal: unused; kept for interface compatibility.

        Returns:
            ``(N, L, width)`` encoder outputs with ``post_layernorm`` applied to every token.
        """
        # Fixed sinusoidal positions (self.clip_positional_embeddings) are added for every modality.
        x = x + self.clip_positional_embeddings(x).to(x.dtype)
        x = self.vision_model.vision_model.pre_layrnorm(x)

        # The Hugging Face CLIP encoder is batch-first (batch, seq, dim), so the tokens
        # are passed through unchanged. The NLD -> LND permute inherited from open_clip
        # would make attention run across samples instead of across tokens.
        outputs = self.vision_model.vision_model.encoder(inputs_embeds=x)
        x = outputs.last_hidden_state

        # preserve all spatial tokens
        return self.vision_model.vision_model.post_layernorm(x)

    def encode_image(self, x, modal='video'):
        """Tokenize one modality, run the shared CLIP encoder, and project to LM tokens.

        Frame-based modalities (``video``/``depth``/``infra``) are presumably
        ``(B, T, C, H, W)``. Time is folded into the batch, and they go through
        the CLIP patch embedding (single-channel frames are tiled to 3 channels).
        ``lidar``/``mmwave``/``wifi``/``rfid`` go through their ``self.conv1[modal]``
        adapter; ``rfid`` is treated as a single frame (T=1).

        The encoder features are then averaged over T and passed through
        ``clip_proj1[modal]``. In eval mode they are cast to float16. Finally they
        go through the Honeybee abstractor, whose output (callers read its
        ``.last_hidden_state``) is returned.

        Raises:
            ValueError: if ``modal`` has no tokenizer here.
        """
        def _to_tokens(feature_map):
            # (N, C, *spatial) -> (N, prod(spatial), C)
            return feature_map.flatten(2).permute(0, 2, 1)

        if modal in ('video', 'depth', 'infra'):
            B, T = x.shape[:2]
            bsz = B * T
            x = x.reshape(bsz, *x.shape[2:])
            if modal != 'video':
                x = x.repeat(1, 3, 1, 1)
            x = _to_tokens(self.vision_model.vision_model.embeddings.patch_embedding(x))

        elif modal in ('lidar', 'mmwave'):
            B, T = x.shape[:2]
            bsz = B * T
            if not self.training:
                # NOTE(review): squeeze() removes every size-1 dim, including a batch of 1.
                x = x.squeeze()
            x = x.reshape(bsz, *x.shape[2:])
            x_conv, _ = self.conv1[modal](x)  # second output (selected points) is unused
            x = x_conv.to(x.dtype).squeeze().permute(0, 2, 1)

        elif modal == 'wifi':
            B, T = x.shape[:2]
            bsz = B * T
            if not self.training:
                x = x.squeeze()
            x = x.reshape(bsz, *x.shape[2:])
            x = _to_tokens(self.conv1[modal](x))

        elif modal == 'rfid':
            x = x.unsqueeze(dim=1)  # treat the sample as a single frame (T=1)
            B, T = x.shape[:2]
            bsz = B * T
            x = x.reshape(bsz, *x.shape[2:])
            if not self.training:
                x = x.squeeze()
            x = x.unsqueeze(dim=-1)  # trailing spatial dim of size 1 for the 1x1 Conv2d
            x = _to_tokens(self.conv1[modal](x))

        else:
            raise ValueError(
                f"Unsupported modality {modal!r}; expected one of "
                "'video', 'depth', 'infra', 'lidar', 'mmwave', 'wifi', 'rfid'"
            )

        image_feats = self.clip_encode_image(x, modal=modal)

        # unfold time and average the per-frame features
        bsz = bsz // T
        image_feats = image_feats.reshape(
            bsz, T, *image_feats.shape[1:]).mean(dim=1)

        image_feats = self.clip_proj1[modal](image_feats)  # Linear + LayerNorm

        # honeybee abstractor
        if not self.training:
            image_feats = image_feats.half()
        return self.abstractor(image_feats)

    def forward(self, examples, image=None, modal='image'):
        """Training forward: splice modality tokens into the prompt and run the HF LM.

        The input sequence becomes
        ``[bos, start_tag, modality tokens, end_tag, rest of examples]``.
        ``prefix_len`` counts bos + start_tag + modality tokens.

        Args:
            examples: token ids of shape ``(batch, seq_len)``.
            image: optional raw modality input passed to ``encode_image``.
            modal: modality name, either a string or a list/tuple whose first element
                is used. A dataset prefix is stripped (``"mmfi_wifi"`` -> ``"wifi"``).

        Returns:
            ``(output, output_action)``: LM logits from position ``prefix_len`` onward
            (starting at end_tag), and ``output_action`` computed from the mean of the
            final hidden states over the first ``prefix_len`` positions.
            NOTE(review): without ``image``, ``prefix_len`` is 0 and that mean is NaN.
        """
        if isinstance(modal, (list, tuple)):
            modal = modal[0]
        modal = modal.split("_")[-1]  # e.g. xrf55_wifi, mmfi_wifi -> wifi
        _bsz, _ = examples.shape
        h = self.tok_embeddings(examples)
        self.freqs_cis = self.freqs_cis.to(h.device)

        prefix_len = 0
        if image is not None:
            h_bos, h_caption = h[:, :1], h[:, 1:]
            image_tokens = self.encode_image(image, modal).last_hidden_state
            h = torch.cat((h_bos, self.start_tag[modal].expand(
                _bsz, -1, -1), image_tokens, self.end_tag[modal].expand(_bsz, -1, -1), h_caption), dim=1)
            prefix_len = image_tokens.shape[1] + 1 + 1

        # Hugging Face 2-D attention masks use 1 = attend, 0 = padding; the LM applies
        # causal masking itself. An all-zero mask would mark every token as padding.
        attention_mask = torch.ones(h.shape[:2], dtype=torch.long, device=h.device)
        outputs = self.language_model(
            inputs_embeds=h, attention_mask=attention_mask, output_hidden_states=True)

        output = outputs.logits[:, prefix_len:, :]
        last_hidden_state = outputs.hidden_states[-1]
        output_action = self.output_action(last_hidden_state[:, :prefix_len, :].mean(dim=1))

        return output, output_action

    @torch.inference_mode()
    def forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, modal='image', past_key_values=None):
        """Run one decoding step through the HuggingFace language model.

        When ``image`` is given (the first step), the embedded prompt is spliced as
        ``[bos, start_tag[modal], image tokens, end_tag[modal], remaining prompt]``.
        On later steps only the new ``tokens`` are embedded; earlier context comes
        from ``past_key_values``.

        Args:
            tokens: token ids of shape ``(batch, seqlen)``.
            start_pos: text position of ``tokens[:, 0]``. Kept for API compatibility.
                With the HF model, positions are derived from ``past_key_values``.
            image: optional modality input forwarded to ``self.encode_image``.
            modal: modality key; a list is reduced to its first element.
            past_key_values: KV cache returned by the previous call, or ``None``.

        Returns:
            ``(logits, past_key_values)``:
            - ``logits`` are the float32 logits of the last position, shape ``(batch, vocab)``.
            - ``past_key_values`` is the updated cache to pass to the next call.
        """
        modal = modal[0] if isinstance(modal, list) else modal
        _bsz = tokens.shape[0]
        h = self.tok_embeddings(tokens)
        # Kept so other code paths that read self.freqs_cis find it on h's device.
        self.freqs_cis = self.freqs_cis.to(h.device)

        if image is not None:
            h_bos, h_caption = h[:, :1], h[:, 1:]
            image_tokens = self.encode_image(image, modal).last_hidden_state
            # Number of image tokens spliced into the sequence on the first step.
            self.cache_image_words = image_tokens.shape[1]
            h = torch.cat(
                (
                    h_bos,
                    self.start_tag[modal].repeat(_bsz, 1, 1),
                    image_tokens,
                    self.end_tag[modal].repeat(_bsz, 1, 1),
                    h_caption,
                ),
                dim=1,
            )
        elif start_pos == 0:
            self.cache_image_words = 0

        # attention_mask=None makes the HF model build its own causal mask over both
        # the cached and the new positions. The previous mask was triu(zeros) == all
        # zeros, which in HF's 1=attend / 0=masked convention masks every token.
        # During cached decoding it also had the wrong length (seqlen, not past + seqlen).
        # NOTE(review): the training forward builds the same all-zero mask; fix it
        # together with this one so training and inference see the same attention.
        outputs = self.language_model(
            inputs_embeds=h,
            past_key_values=past_key_values,
            attention_mask=None,
            use_cache=True,
        )

        # Only the last position's logits are needed to pick the next token.
        output = outputs.logits[:, -1, :]
        return output.float(), outputs.past_key_values

    @torch.inference_mode()
    def forward_action_inference(self, tokens: torch.Tensor, start_pos: int, image=None, modal='image'):
        """Predict the action output from the multimodal prefix.

        The prompt is spliced as
        ``[bos, start_tag[modal], image tokens, end_tag[modal], remaining prompt]``
        and run through the language model once. The final hidden states of the
        first ``1 + num_image_tokens + 1`` positions are averaged and fed to
        ``self.output_action``.

        Args:
            tokens: token ids of shape ``(batch, seqlen)``.
            start_pos: kept for API compatibility; unused by the HF model path.
            image: modality input forwarded to ``self.encode_image``. Required.
            modal: modality key; a list is reduced to its first element.

        Returns:
            ``self.output_action`` applied to the pooled prefix, cast to float32.

        Raises:
            ValueError: if ``image`` is None. The action head pools over the image
                prefix, so there is nothing to predict from without it.
        """
        modal = modal[0] if isinstance(modal, list) else modal
        if image is None:
            # Previously this fell through and crashed with UnboundLocalError on
            # `prelen` after a full LLM pass; fail early with a clear message instead.
            raise ValueError("forward_action_inference requires an image input")

        _bsz = tokens.shape[0]
        h = self.tok_embeddings(tokens)
        # Kept so other code paths that read self.freqs_cis find it on h's device.
        self.freqs_cis = self.freqs_cis.to(h.device)

        h_bos, h_caption = h[:, :1], h[:, 1:]
        image_tokens = self.encode_image(image, modal).last_hidden_state
        self.cache_image_words = image_tokens.shape[1]
        h = torch.cat(
            (
                h_bos,
                self.start_tag[modal].repeat(_bsz, 1, 1),
                image_tokens,
                self.end_tag[modal].repeat(_bsz, 1, 1),
                h_caption,
            ),
            dim=1,
        )
        # Pooling prefix: bos + start_tag + image tokens. end_tag is not included,
        # which matches prefix_len in the training forward.
        prelen = 1 + image_tokens.shape[1] + 1

        # attention_mask=None lets HF build the causal mask. The previous triu(zeros)
        # mask was all zeros, i.e. "mask every token" in HF's convention.
        outputs = self.language_model(inputs_embeds=h, attention_mask=None, output_hidden_states=True)

        last_hidden_state = outputs.hidden_states[-1]
        output_action = self.output_action(last_hidden_state[:, :prelen, :].mean(dim=1))
        return output_action.float()


    def extract_features(self, x, modal='mmfi_video'):
        """Encode one modality input into CLIP features and connector query features.

        Every recognised branch:
        - flattens the leading (B, T) dims into a batch of B*T samples;
        - tokenizes it with a modality-specific conv (CLIP ``visual.conv1`` for
          video/depth, ``self.conv1[modal]`` otherwise);
        - runs ``self.clip_encode_image``;
        - averages the result over T.

        For an unrecognised ``modal``, ``x`` is passed to ``clip_encode_image``
        unchanged and T = 1.

        Args:
            x: modality input. Assumed to be laid out as (B, T, ...) in training;
                the eval-time layout is not visible here — TODO confirm per modality.
            modal: one of 'mmfi_video', 'mmfi_depth', 'mmfi_lidar', 'mmfi_mmwave',
                'mmfi_infra', 'mmfi_wifi'.

        Returns:
            ``(image_feats, query_feat, image_feats_2)``:
            - image_feats: ``clip_encode_image`` output averaged over T.
            - query_feat: learnable queries after the ``self.connector`` layers.
            - image_feats_2: ``self.clip_proj2[modal]`` applied to ``query_feat``.

        NOTE(review): several branches call ``.squeeze()`` with no dim in eval mode.
        That also drops B or T when either is 1, which can break the following
        reshape/permute — confirm the eval inputs keep those dims.
        """
        bsz = x.size(0)
        T = 1

        if modal == 'mmfi_video':
            # Flatten (B, T) into one batch of B*T frames.
            B, T = x.shape[:2]
            bsz = B * T
            # pos_embedding = self.clip.visual.positional_embedding[1:]  # remove cls visual embeddings

            x = x.reshape(bsz, *x.shape[2:])
            # Patchify with CLIP's conv1, view as (bsz, 1024, -1), then permute to
            # (bsz, num_patches, 1024). The view assumes conv1 has 1024 output channels.
            x_conv = self.clip.visual.conv1(x).view(x.shape[0],1024,-1).permute(0,2,1)

            x = x_conv
            
        elif modal == 'mmfi_depth':
            B, T = x.shape[:2]
            bsz = B * T
            x = x.reshape(bsz, *x.shape[2:])
            # Tile the channel dim 3x so CLIP's RGB conv1 can be reused (presumably the
            # depth map has 1 channel), and cast to float. In eval mode, cast to half —
            # per the author's note, this matches open_clip's half() weights.
            x = x.repeat(1,3,1,1).float()
            if not self.training:
                x = x.half()

            # x_conv = self.clip.visual.conv1(x)
            # -> (bsz, num_patches, 1024), as in the video branch.
            x_conv = self.clip.visual.conv1(x).view(x.shape[0],1024,-1).permute(0,2,1)
            
            x = x_conv 

        elif modal == 'mmfi_lidar':
            # Author's shape note: [B, 1024, 3] -> [B, 1024, 1024, 1]. The first 1024 is the
            # 3 -> 1024 feature dim; the second 1024 is the number of randomly sampled points.
            # Open question from the author: do all point clouds here share a fixed size of 16384?
            B, T = x.shape[:2]
            bsz = B * T
            # NOTE(review): dim-less squeeze() also removes B/T when they are 1.
            if not self.training:
                x = x.squeeze()
            x = x.reshape(bsz, *x.shape[2:])

            # Modality-specific conv1: float input in training, half in eval. It returns
            # (features, x_selected); x_selected is unused here.
            if self.training:
                x_conv, x_selected = self.conv1[modal](x.float())
                x_conv = x_conv.to(x.dtype).squeeze().permute(0,2,1)  # the first 1024 is the feature dim (author note)

            else:
                x_conv, x_selected = self.conv1[modal](x.half())
                x_conv = x_conv.to(x.dtype).squeeze().permute(0,2,1)

            x = x_conv 

        elif modal == 'mmfi_mmwave':
            # Author's shape note: [B, 64, 5] -> [B, 1024, 64, 1].
            # Open question from the author: do all point clouds here share a fixed size of 16384?
            B, T = x.shape[:2]
            bsz = B * T
            # NOTE(review): dim-less squeeze() also removes B/T when they are 1.
            if not self.training:
                x = x.squeeze()
            x = x.reshape(bsz, *x.shape[2:])

            # Same tokenization scheme as the lidar branch.
            if self.training:
                x_conv, x_selected = self.conv1[modal](x.float())
                x_conv = x_conv.to(x.dtype).squeeze().permute(0,2,1)  # the first 1024 is the feature dim (author note)

            else:
                x_conv, x_selected = self.conv1[modal](x.half())
                x_conv = x_conv.to(x.dtype).squeeze().permute(0,2,1)
            
            x = x_conv
            
        elif modal == 'mmfi_infra':
            B, T = x.shape[:2]
            bsz = B * T
            x = x.reshape(bsz, *x.shape[2:])
            # Unlike the other branches, the eval-time squeeze runs after the reshape here.
            if not self.training:
                x = x.squeeze()
            # Swap dims 1 and 2 and add a trailing singleton dim before the modality conv1.
            x = x.float().permute(0,2,1).unsqueeze(dim=-1)
            if not self.training:
                x = x.half()

            x_conv = self.conv1[modal](x).squeeze().permute(0,2,1)

            x = x_conv

        elif modal == 'mmfi_wifi':
            B, T = x.shape[:2]
            bsz = B * T
            # NOTE(review): dim-less squeeze() also removes B/T when they are 1.
            if not self.training:
                x = x.squeeze()
            x = x.reshape(bsz, *x.shape[2:])


            if self.training:
                x = x.float()
            else:
                x = x.half()

            # -> (bsz, -1, 1024). The view assumes conv1[modal] produces 1024 channels.
            x_conv = self.conv1[modal](x).view(x.shape[0],1024,-1).permute(0,2,1) 
            x = x_conv
             
        # Shared CLIP encoder over the B*T token sequences.
        image_feats = self.clip_encode_image(x, modal=modal)  
        
        # Undo the (B, T) flattening and average over T.
        bsz = int(bsz / T)
        image_feats = image_feats.reshape(
            bsz, T, *image_feats.shape[1:]).mean(dim=1)

        image_feats_2 = self.clip_proj1[modal](image_feats)  # per-modality projection (a linear layer, per the original comment)

        # Learnable per-modality queries, tiled over the batch, go through the first
        # connector layer together with the projected features. The trailing
        # (0, None, None) args presumably are (start_pos, freqs_cis, mask) — confirm
        # against the connector layer's signature.
        query_feat = self.connector[0](self.learnable_queries[modal].repeat(image_feats_2.shape[0], 1, 1), image_feats_2, 0, None, None)

        for i in range(1, self.num_ca):  # remaining connector layers (num_ca presumably = number of cross-attention layers)
            query_feat = self.connector[i](query_feat, image_feats_2, 0, None, None)

        image_feats_2 = self.clip_proj2[modal](query_feat)

        return image_feats, query_feat, image_feats_2
    
    @torch.inference_mode()
    def forward_extract_features(self, tokens, image=None, modal='image'):
        """Return the intermediate modality features plus the caption token embeddings.

        Args:
            tokens: token ids. Position 0 (BOS) is dropped from the returned embeddings.
            image: modality input passed to ``self.extract_features``.
            modal: modality key; a list is reduced to its first element.

        Returns:
            ``(before_connector_feats, after_connector_feats, llm_image_feats, caption_embeds)``.
            The first three come from ``self.extract_features``; ``caption_embeds`` is
            ``self.tok_embeddings(tokens)[:, 1:]``.
        """
        if isinstance(modal, list):
            modal = modal[0]
        # Embed first, then extract, preserving the original call order.
        caption_embeds = self.tok_embeddings(tokens)[:, 1:]
        pre_connector, post_connector, llm_feats = self.extract_features(image, modal)
        return pre_connector, post_connector, llm_feats, caption_embeds
    
    # Commented out for honeybee: the HuggingFace language model keeps its own KV cache (past_key_values).
    # def _allocate_kv_cache(self, max_batch_size: int) -> None:
    #     for layer in self.layers:
    #         layer.attention.allocate_kv_cache(
    #             max_batch_size, self.params.max_seq_len)

    # def _destroy_kv_cache(self) -> None:
    #     for layer in self.layers:
    #         layer.attention.destroy_kv_cache()
