#    Copyright 2023 Haotian Liu
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.


from abc import ABC, abstractmethod

import math
import warnings
from einops.layers.torch import Rearrange
from einops import rearrange, repeat

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from llava.model.multimodal_encoder.builder import build_vision_tower
from llava.model.multimodal_projector.builder import build_vision_projector
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import get_anyres_image_grid_shape

from .world_model_hd2 import RSSM

def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    """Apply Xavier initialisation to a module's own weight and bias.

    Only the attributes ``module.weight`` and ``module.bias`` are examined;
    child modules are not traversed, so a container without those attributes
    is left untouched.

    Args:
        module: Module whose ``weight`` / ``bias`` (if present and not None)
            are re-initialised in place.
        gain: Scaling factor forwarded to the Xavier initialiser.
        bias: Constant written into ``module.bias``.
        distribution: ``'uniform'`` or ``'normal'``; selects the Xavier variant.

    Raises:
        AssertionError: If ``distribution`` is neither of the two options.
    """
    assert distribution in ['uniform', 'normal']

    weight_param = getattr(module, 'weight', None)
    if weight_param is not None:
        if distribution == 'uniform':
            init_fn = nn.init.xavier_uniform_
        else:
            init_fn = nn.init.xavier_normal_
        init_fn(weight_param, gain=gain)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)

def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Sampling uses the inverse-CDF method: draw uniformly between the CDF
    values of the two cut-offs, map through ``erfinv``, then rescale, shift
    and finally clamp to ``[a, b]``.
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """
    sqrt_two = math.sqrt(2.)

    def _std_normal_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / sqrt_two)) / 2.

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # CDF values of the lower / upper cut-off.
    cdf_low = _std_normal_cdf((a - mean) / std)
    cdf_high = _std_normal_cdf((b - mean) / std)

    # Uniform samples in [2*cdf_low - 1, 2*cdf_high - 1] are the erf-domain
    # image of [a, b]; erfinv turns them into a truncated standard normal.
    tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1).erfinv_()

    # Rescale to the requested std and shift to the requested mean.
    tensor.mul_(std * sqrt_two).add_(mean)

    # Guard against values pushed slightly outside the interval.
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The sampling scheme works best when
    :math:`a \leq \text{mean} \leq b`.

    NOTE: as with the PyTorch ``trunc_normal_``, the bounds ``[a, b]`` apply to
    the distribution *after* mean/std are applied, so ``a`` and ``b`` should be
    chosen in the same range as ``mean`` and ``std``.

    The fill runs under ``torch.no_grad()``.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled


class FeedForward(nn.Module):
    """Two-layer MLP preceded by a LayerNorm.

    The pipeline is LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout
    -> Linear(hidden_dim, dim) -> Dropout, applied over the last axis.

    Args:
        dim: Size of the last input axis (and of the output).
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability used after the activation and the output.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Layers are created in pipeline order so parameter names
        # (net.0, net.1, net.4) stay stable.
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the output has the same shape as ``x``."""
        return self.net(x)

class FlowAttention(nn.Module):
    """Self-attention block with separate LayerNorms for query, key and value.

    Two modes are supported:

    * ``channel_fuse=False``: standard multi-head attention on a
      ``(b, n, dim)`` input; the projections have width ``dim_head * heads``
      and are split into ``heads`` heads.
    * ``channel_fuse=True``: the projections have width ``dim_head`` and no
      head split is performed; attention is computed over the second-to-last
      axis of whatever leading shape the input has.

    Args:
        dim: Size of the last input axis.
        heads: Number of heads (only used for the split when not fusing).
        dim_head: Per-head width; also defines the softmax scale.
        dropout: Dropout probability on the attention weights and output.
        with_res: If True, add a residual FeedForward + LayerNorm after
            the attention output.
        channel_fuse: Selects the mode described above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        inner_dim = dim_head if channel_fuse else dim_head * heads
        # The output projection is skipped only for a single head whose
        # width already equals ``dim``.
        project_out = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        if project_out:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

        self.channel_fuse = channel_fuse
        self.with_res = with_res
        if self.with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def forward(self, x, pos=None):
        """Attend ``x`` to itself.

        Args:
            x: Input features; last axis has size ``dim``.
            pos: Optional positional term added to the query and key inputs
                (not to the value input). Must broadcast against ``x``.

        Returns:
            Attention output, or the residual-normalised output when
            ``with_res`` is set.
        """
        # Positional information influences matching (q, k) but not content (v).
        qk_input = x if pos is None else x + pos

        q = self.to_q(self.norm_q(qk_input))
        k = self.to_k(self.norm_k(qk_input))
        v = self.to_v(self.norm_v(x))

        split_heads = not self.channel_fuse
        if split_heads:
            q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        out = torch.matmul(weights, v)

        if split_heads:
            out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        if self.with_res:
            out = self.res_norm(x + self.res_mlp(out + x))
        return out

class FlowCrossAttention(nn.Module):
    """Cross-attention block: queries from one tensor, keys/values from another.

    Two modes are supported:

    * ``channel_fuse=False``: standard multi-head attention on ``(b, n, dim)``
      inputs; projections have width ``dim_head * heads`` and are split into
      ``heads`` heads.
    * ``channel_fuse=True``: inputs are 4-D ``(b, n, h, d)`` with ``h`` equal to
      ``heads`` (einops validates this); the ``h`` axis is moved in front of
      ``n`` so attention runs over ``n`` independently for every ``h``.

    Args:
        dim: Size of the last input axis.
        heads: Number of heads (or required size of the ``h`` axis when fusing).
        dim_head: Per-head width; also defines the softmax scale.
        dropout: Dropout probability on the attention weights and output.
        with_res: If True, add a residual FeedForward + LayerNorm on top of
            the query input.
        channel_fuse: Selects the mode described above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        inner_dim = dim_head if channel_fuse else dim_head * heads
        # The output projection is skipped only for a single head whose
        # width already equals ``dim``.
        project_out = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        if project_out:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

        self.channel_fuse = channel_fuse
        self.with_res = with_res
        if self.with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def forward(self, q_ini, kv_ini, q_pos=None, k_pos=None, flatten=True):
        """Attend ``q_ini`` to ``kv_ini``.

        Args:
            q_ini: Query-side features.
            kv_ini: Key/value-side features.
            q_pos: Optional positional term added to the query input.
            k_pos: Optional positional term added to the key input. Positional
                terms are applied only when BOTH ``q_pos`` and ``k_pos`` are given.
            flatten: Only honoured in ``channel_fuse`` mode, where it merges
                axes 2 and 3 of the result.

        Returns:
            The attended features (residual-normalised when ``with_res``).
        """
        use_pos = q_pos is not None and k_pos is not None
        q_src = q_ini + q_pos if use_pos else q_ini
        k_src = kv_ini + k_pos if use_pos else kv_ini

        q = self.to_q(self.norm_q(q_src))
        k = self.to_k(self.norm_k(k_src))
        v = self.to_v(self.norm_v(kv_ini))

        if self.channel_fuse:
            split_pattern, merge_pattern = 'b n h d -> b h n d', 'b h n d -> b n h d'
        else:
            split_pattern, merge_pattern = 'b n (h d) -> b h n d', 'b h n d -> b n (h d)'

        q, k, v = (rearrange(t, split_pattern, h = self.heads) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        out = torch.matmul(weights, v)
        out = self.to_out(rearrange(out, merge_pattern))

        if self.with_res:
            out = self.res_norm(q_ini + self.res_mlp(out + q_ini))

        if self.channel_fuse and flatten:
            return out.flatten(2,3)
        return out

class SceneFlow(nn.Module):
    """Refine per-view image features with "flow" queries placed on a ring of views.

    ``forward`` takes token features of shape ``[B*N, L, C]`` with ``N = 6``
    views (hard-coded) and a square token grid (``H = W = sqrt(L)``), and
    returns a tensor of the same shape. Internally each view is cut into
    vertical column patches; neighbouring views provide padding columns
    (view adjacency is given by ``ids_pad_left_right``). Column features are
    mixed by attention, summarised into per-column flow queries, passed
    through two RSSM "world models" (temporal and spatial) and fused back
    into the feature map by 3x3 convolutions.

    Shape legend used in comments: ``B`` batch, ``T`` frames (``TN // N``),
    ``N`` views, ``C`` channels, ``H``/``W`` token grid, ``P`` patch width
    (``flow_patches[idx]``), ``Y`` flow positions per view (``flow_num``).
    NOTE(review): several reshapes assume ``W // P == flow_num`` — confirm
    against the configured vision tower.

    Args:
        flow_patches: Patch width per feature level.
        flow_ids: Indices of the feature levels that are processed.
        flow_attn_ids: Levels that get the local (padded-patch) self-attention.
        flow_wm_ids: Levels that are enhanced by the world models.
        flow_pad_single: Number of patches borrowed from each neighbouring view.
        flow_num: Number of flow positions per view.
        spat_size: Group size used when splitting the spatial sequences.
        in_channels: Channel width ``C`` of the input features.
    """

    def __init__(self,
                 flow_patches=[1],
                 flow_ids=[0],
                 flow_attn_ids=[0],
                 flow_wm_ids=[0],
                 flow_pad_single=1,
                 flow_num=24,
                 spat_size=2,
                 in_channels=1024,
                 ):
        super(SceneFlow, self).__init__()
        # Copy list arguments so neither the shared default lists in the
        # signature nor caller-owned lists are aliased by this instance.
        self.flow_patches = list(flow_patches)
        self.flow_ids = list(flow_ids)
        self.flow_attn_ids = list(flow_attn_ids)
        self.flow_wm_ids = list(flow_wm_ids)
        self.flow_pad_single = flow_pad_single
        # Window size in patches: the view's own patch plus padding on both sides.
        self.flow_pad_sum = 1 + 2 * flow_pad_single
        # Row i holds (left neighbour, right neighbour) of view i on the ring.
        self.ids_pad_left_right = torch.tensor([[5,1],[0,2],[1,3],[2,4],[3,5],[4,0]], dtype=torch.long, requires_grad=False)
        self.frame_num = 1
        self.flow_num = flow_num
        self.spat_size = spat_size
        self.embed_dims = in_channels

        # Learned stand-in for an 18-dim ego pose per frame (forward() uses the
        # embedding weight directly instead of a pose read from metadata).
        self.ego_pose_temporal = nn.Embedding(self.frame_num, 18)

        # Reduce a padded window (pad_sum * P columns) back to P columns.
        self.flow_mlps = nn.ModuleList([
            nn.Linear(self.flow_pad_sum*self.flow_patches[flow_id], self.flow_patches[flow_id]) for flow_id in self.flow_ids
        ])
        self.flow_attns = nn.ModuleList([
            FlowAttention(dim=in_channels, heads=4, dim_head=in_channels//4, dropout=0.) for flow_attn_id in self.flow_attn_ids
        ])

        self.world_models_temporal = nn.ModuleList([
            RSSM(embedding_dim=in_channels) for flow_wm_id in self.flow_wm_ids
        ])
        self.world_models_spatial = nn.ModuleList([
            RSSM(embedding_dim=in_channels) for flow_wm_id in self.flow_wm_ids
        ])
        self.patch_selfattn = FlowAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_local = FlowCrossAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)

        self.patch_crossattn_temporal = FlowCrossAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_spatial = FlowCrossAttention(dim=in_channels, heads=self.frame_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)

        # Fuse [features, world-model prediction] (2C channels) back to C channels.
        self.flow_wm_convs = nn.ModuleList([
            nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) for flow_wm_id in self.flow_wm_ids
        ])
        self.flow_wm_convs_spatial = nn.ModuleList([
            nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) for flow_wm_id in self.flow_wm_ids
        ])

        # self.can_bus_mlp = nn.Sequential(
        #     nn.Linear(18, self.embed_dims // 2),
        #     nn.ReLU(inplace=True),
        #     nn.Linear(self.embed_dims // 2, self.embed_dims),
        #     nn.ReLU(inplace=True),
        #     nn.LayerNorm(self.embed_dims)
        # )
        # xavier_init(self.can_bus_mlp, distribution='uniform', bias=0.)

        # NOTE(review): xavier_init only inspects the module's own weight/bias
        # attributes, which an nn.Sequential container does not define, so the
        # three calls below appear to leave the default Linear initialisation
        # in place. Confirm whether per-layer Xavier init was intended before
        # changing it (doing so alters training from scratch).
        self.flow_query_xys = self.generate_flow_xys()
        self.flow_query_embed = nn.Sequential(
            nn.Linear(2, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(self.embed_dims)
        )
        xavier_init(self.flow_query_embed, distribution='uniform', bias=0.)

        self.flow_spatial_pe_map = nn.Sequential(
            nn.Linear(2, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(self.embed_dims)
        )
        xavier_init(self.flow_spatial_pe_map, distribution='uniform', bias=0.)

        self.flow_temporal_pe_map = nn.Sequential(
            nn.Linear(18, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(self.embed_dims)
        )
        xavier_init(self.flow_temporal_pe_map, distribution='uniform', bias=0.)

    def generate_flow_xys(self,):
        """Return the (x, y) anchor of every flow position, shape ``[6, flow_num, 2]``.

        Each of the six views is assigned one straight segment, given by an
        x-range and a y-range (ordered f, fr, br, b, bl, fl, following the
        variable names below). The segment is divided into ``flow_num`` equal
        cells and the cell midpoints are returned.

        Midpoints are computed from integer indices rather than from a
        float-step ``torch.arange(start, end, step)``, whose element count
        depends on floating-point rounding and could differ from ``flow_num``.
        """
        pc_range = 51.2
        tan_30 = math.tan(math.radians(30))
        flow_items = self.flow_num
        p_x = pc_range * tan_30

        f_range_x = [-p_x, p_x]
        f_range_y = [pc_range, pc_range]

        fr_range_x = [p_x, pc_range]
        fr_range_y = [pc_range, 0]

        br_range_x = [pc_range, p_x]
        br_range_y = [0, -pc_range]

        b_range_x = [-p_x, p_x]
        b_range_y = [-pc_range, -pc_range]

        bl_range_x = [-p_x, -pc_range]
        bl_range_y = [-pc_range, 0]

        fl_range_x = [-pc_range, -p_x]
        fl_range_y = [0, pc_range]

        out_dtype = torch.get_default_dtype()
        # Cell-centre offsets 0.5, 1.5, ..., flow_items - 0.5 (in units of one
        # step); float64 keeps the arithmetic exact enough before the cast.
        centers = torch.arange(flow_items, dtype=torch.float64) + 0.5

        flow_xys_all = []
        for range_x, range_y in zip([f_range_x, fr_range_x, br_range_x, b_range_x, bl_range_x, fl_range_x],
                                    [f_range_y, fr_range_y, br_range_y, b_range_y, bl_range_y, fl_range_y]):
            step_x = (range_x[1]-range_x[0])/flow_items
            step_y = (range_y[1]-range_y[0])/flow_items
            # A zero step (constant coordinate) naturally yields a constant column.
            xys = torch.stack([
                range_x[0] + centers * step_x,
                range_y[0] + centers * step_y,
            ], dim=-1).to(out_dtype)
            flow_xys_all.append(xys)

        flow_xys_all = torch.stack(flow_xys_all, dim=0)
        return flow_xys_all

    def get_forward_point(self, sdc_trajs):
        """Map the ego heading of the FIRST trajectory to a flow-position index.

        For each trajectory the entry at index 1 is taken as an ``(x, y)``
        displacement. Its direction is projected onto the front boundary
        (``y = 51.2``) or onto a slanted side boundary, and the hit point is
        converted to an index using a step of ``2 * half / flow_num``.

        Only ``forward_points[0]`` is returned, so ``sdc_trajs`` must be
        non-empty.

        NOTE(review): the constants ``11`` and ``22`` look like they were
        derived for ``flow_num == 22`` (half and full width of one view);
        confirm before using this with another ``flow_num``.
        NOTE(review): the steep branch ignores the sign of ``x``; confirm
        whether left-leaning headings are intended to mirror.
        """
        ego_xys = [one[1] for one in sdc_trajs]
        forward_points = []
        # tan_30 = math.tan(math.radians(30))
        tan_60 = math.tan(math.radians(60))
        half = 51.2 / tan_60
        step = half / self.flow_num * 2
        for ego_xy in ego_xys:
            if sum(ego_xy) == 0:
                # No usable displacement: fall back to the default index.
                forward_points.append(11)
                continue
            left_right = 1 if ego_xy[0] >= 0 else 0
            if ego_xy[0] == 0:
                # Heading straight along y: infinite slope instead of a
                # division by zero; 51.2 / inf below evaluates to 0.
                tan = math.inf
            else:
                tan = abs(ego_xy[1]/ego_xy[0])
            if tan >= tan_60:
                forward_point = (51.2/tan + half)//step
            else:
                if left_right:
                    x = 51.2 * tan_60 / (tan + tan_60)
                    y = x * tan
                    delta = math.sqrt((x-51.2/tan_60)**2 + (y-51.2)**2)
                    forward_point = 22 + delta//step
                else:
                    x = 51.2 * tan_60 / (-tan - tan_60)
                    y = - x * tan
                    delta = math.sqrt((x+51.2/tan_60)**2 + (y-51.2)**2)
                    # Unary minus binds tighter than //: this is (-delta) // step.
                    forward_point = - delta//step
            forward_points.append(int(forward_point))
        return forward_points[0]


    def forward(self, x):
        """Enhance view features with flow attention and world-model predictions.

        Args:
            x: Token features ``[B*N, L, C]`` with ``N = 6`` and ``L`` a
                perfect square.

        Returns:
            Tensor of shape ``[B*N, L, C]``.
        """
        BN, L, C = x.shape
        N = 6
        B = BN//N
        H = W = int(math.sqrt(L))
        # [BN, L, C] -> [B, N, C, H, W]
        x = x.transpose(1,2).reshape(B, N, C, H, W)

        # can_bus = mlvl_feats[0].new_tensor(
        #     [each['ego_pose'][0] for each in img_metas])  # [:, :]
        # can_bus = self.can_bus_mlp(can_bus)[:, None, :]

        # Single feature level; the multi-level loop structure is kept.
        mlvl_feats = [x]
        device = mlvl_feats[0].device
        B, TN, C, _, _ = mlvl_feats[0].shape
        T = TN // N
        flow_spatial_pe = self.flow_spatial_pe_map(self.flow_query_xys.to(device)).unsqueeze(0)  # [1, N, Y, C]
        flow_spatial_pe_local = rearrange(flow_spatial_pe.repeat(B*T, 1, 1, 1), 'b n y c -> (b n) y c')  # [BTN, Y, C]

        # pad spatial PE with the neighbouring views' PE for the padded-window self-attn
        flow_spatial_pe_3p = flow_spatial_pe_local.transpose(1,2).reshape(B, T, N, C, 1, self.flow_num)
        pe_pad_left = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single:]
        pe_pad_right = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single]
        # [B, T, N, C, 1, Y + 2*pad]
        flow_spatial_pe_3p = torch.cat([pe_pad_left, flow_spatial_pe_3p, pe_pad_right], dim=-1)
        # [BTN, C*1*pad_sum, Y]
        flow_spatial_pe_3p_unfold = F.unfold(flow_spatial_pe_3p.flatten(0,2), kernel_size=(1, self.flow_pad_sum), stride=1)
        # [BTN, Y, pad_sum, C]
        flow_spatial_pe_3p = rearrange(flow_spatial_pe_3p_unfold, 'b (c l) y -> b y l c', c=C, l=self.flow_pad_sum)

        # ego_pose_temporal = torch.from_numpy(
        #     np.stack([each['ego_pose'] for each in img_metas], axis=0)).to(device).to(flow_spatial_pe.dtype)
        ego_pose_temporal = self.ego_pose_temporal.weight
        flow_temporal_pe = self.flow_temporal_pe_map(ego_pose_temporal)  # [frame_num, C]
        # flow_temporal_pe_local = flow_temporal_pe.reshape(B, T, 1, 1, C).repeat(1, 1, N, self.flow_num, 1).flatten(0,2)  # [BTN, Y, C]
        # The reshape below relies on frame_num == 1 (a single learned pose).
        flow_temporal_pe_local = flow_temporal_pe.reshape(1, 1, 1, 1, C).repeat(B, T, N, self.flow_num, 1).flatten(0,2)  # [BTN, Y, C]

        # World-model side outputs are collected but not returned by this method.
        outputs_wm_temporal = []
        outputs_wm_spatial = []

        mlvl_feats_flow = []
        mlvl_feats_flow_unfold = []
        mlvl_feats_flow_unfold_fuse = []

        H_W_Ps = []
        for idx, mlvl_feat in enumerate(mlvl_feats):
            if idx not in self.flow_ids:
                mlvl_feats_flow.append(mlvl_feat)
                continue
            flow_patch = self.flow_patches[idx]
            B, TN, C, H, W = mlvl_feat.shape
            H_W_Ps.append([H,W,flow_patch])
            mlvl_feat_reshaped = mlvl_feat.reshape(B, T, N, C, H, W)

            # pad two sides of each view with columns from its ring neighbours
            feat_pad_left = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single*flow_patch:]
            feat_pad_right = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single*flow_patch]
            mlvl_feat_pad_reshaped = torch.cat([feat_pad_left, mlvl_feat_reshaped, feat_pad_right], dim=-1)

            # [BTN, C*H*pad_sum*P, W/P]
            mlvl_feat_pad_unfold = F.unfold(mlvl_feat_pad_reshaped.flatten(0,2), kernel_size=(H, self.flow_pad_sum*flow_patch), stride=flow_patch)
            # [BTN, C, H, pad_sum*P, W/P]
            mlvl_feat_flow_unfold = mlvl_feat_pad_unfold.reshape(B*TN, C, H, self.flow_pad_sum*flow_patch, W//flow_patch)

            if idx in self.flow_attn_ids:
                # interaction inside each padded window
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, 'b c h p y -> (b h y) p c')
                # [BTN, Y, pad_sum, C] -> [BTN * H * Y, pad_sum*P, C]
                flow_spatial_pe_3p_selfattn = flow_spatial_pe_3p.unsqueeze(3).unsqueeze(1).repeat(1,H,1,1,flow_patch,1).flatten(0,2).flatten(1,2)
                # [BTN * H * Y, pad_sum*P, C]
                mlvl_feat_flow_unfold = self.flow_attns[self.flow_attn_ids.index(idx)](mlvl_feat_flow_unfold, pos=flow_spatial_pe_3p_selfattn)
                # [BTN, C, H, pad_sum*P, W/P]
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, '(b h y) p c -> b c h p y', b=B*TN, h=H, y=W//flow_patch)

            # [BTN, C, H, W/P, pad_sum*P] -> [BTN, C, H, W/P, P]
            mlvl_feat_flow_unfold = self.flow_mlps[self.flow_ids.index(idx)](mlvl_feat_flow_unfold.transpose(3,4))
            mlvl_feats_flow_unfold.append(mlvl_feat_flow_unfold)
            # average over H: [BTN, C, W/P, P]
            mlvl_feats_flow_unfold_fuse.append(mlvl_feat_flow_unfold.mean(dim=2))

        # [BTN, C, Y, sum of P over levels]
        mlvl_feats_flow_unfold_fuse = torch.cat(mlvl_feats_flow_unfold_fuse, dim=-1)
        # [BTN, Y, C] -> [BTN, 1, Y, C]
        flow_spatial_pe_selfattn = flow_spatial_pe_local.unsqueeze(1)
        # [BTN, sum P, Y, C]
        mlvl_feats_flow_unfold_fuse = self.patch_selfattn(rearrange(mlvl_feats_flow_unfold_fuse, 'b c y p -> b p y c'), pos=flow_spatial_pe_selfattn)

        # NOTE: mlvl_feat, W and flow_patch below are the values left over from
        # the last iteration of the loop above.
        # [N, Y, C] -> [BTN, Y, C]
        flow_query = self.flow_query_embed(self.flow_query_xys.to(mlvl_feat.device)).unsqueeze(0).repeat(B*T, 1, 1, 1).flatten(0,1)
        # get local messages from the patch axis of [BTN, sum P, Y, C]
        flow_query = self.patch_crossattn_local(flow_query.unsqueeze(1), mlvl_feats_flow_unfold_fuse).reshape(B*TN, W//flow_patch, C)

        # get temporal messages across T
        # prompt for temporal world model
        flow_temporal_pe_crossattn = rearrange(flow_temporal_pe_local, '(b t n) y c -> (b n) t y c', b=B, t=T, n=6)
        # [BN, T, Y, C]
        flow_query_temporal = self.patch_crossattn_temporal(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=6),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=6, y=self.flow_num),
            q_pos=flow_temporal_pe_crossattn,
            k_pos=flow_temporal_pe_crossattn,
            flatten=False
        )

        # get spatial messages across the Y flow positions
        # prompt for spatial world model
        flow_spatial_pe_crossattn = rearrange(flow_spatial_pe_local, '(b t n) y c -> (b n) y t c', b=B, t=T, n=6)
        # [BN, Y, T, C]
        flow_query_spatial = self.patch_crossattn_spatial(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=6),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=6, y=self.flow_num),
            q_pos=flow_spatial_pe_crossattn,
            k_pos=flow_spatial_pe_crossattn,
            flatten=False
        )

        # split left/right spatial flow from ego forward_point
        # sdc_trajs = [each['sdc_traj'] for each in img_metas]
        # forward_point = self.get_forward_point(sdc_trajs)
        # NOTE(review): fixed index while trajectory metadata is unavailable;
        # 11 looks tuned for flow_num == 22 — confirm for the current flow_num.
        forward_point = 11
        flow_query_spatial = rearrange(flow_query_spatial, '(b n) y t c -> (b t n) y c', b=B, t=T, n=6)
        # Concatenate all views into one ring of N*Y positions.
        flow_query_spat = rearrange(flow_query_spatial, '(b t n) y c -> b (n y) t c', b=B, t=T, n=6)
        if forward_point >= 0:
            # Left half: positions behind forward_point, walked in reverse.
            flow_query_spat_l = torch.cat([flow_query_spat[:, forward_point+3*self.flow_num:], flow_query_spat[:, :forward_point]], dim=1).flip(dims=[1])
            # Right half: the 3*Y positions starting at forward_point.
            flow_query_spat_r = flow_query_spat[:, forward_point:forward_point+3*self.flow_num]
        else:
            flow_query_spat_r = torch.cat([flow_query_spat[:, forward_point:], flow_query_spat[:, :forward_point+3*self.flow_num]], dim=1)
            flow_query_spat_l = flow_query_spat[:, forward_point+3*self.flow_num:forward_point].flip(dims=[1])
        # [B, 2, 3*Y, T, C]
        flow_query_spat = torch.stack([flow_query_spat_l, flow_query_spat_r], dim=1)
        flow_query_spat = rearrange(flow_query_spat, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)

        for idx, (mlvl_feat_flow_unfold, (H, W, flow_patch)) in enumerate(zip(mlvl_feats_flow_unfold, H_W_Ps)):
            if idx in self.flow_wm_ids:
                layer_idx = self.flow_wm_ids.index(idx)

                # temporal world model
                # input features: [BTN, C, H, W/P, P], averaged over H and P;
                # the time axis is flipped before the call and flipped back after.
                output_wm, mlvl_feat_flow_unfold_fut = self.world_models_temporal[layer_idx](
                    rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> (b n) t y c', b=B, t=T, n=6).flip(dims=[1]),
                    flow_query_temporal.flip(dims=[1]))
                outputs_wm_temporal.append(output_wm)

                # enhance temporal messages with temporal prediction:
                # broadcast the prediction over H and P and append it on the channel axis
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        rearrange(mlvl_feat_flow_unfold_fut.flip(dims=[1]), '(b n) t y c -> (b t n) y c', b=B, n=6, y=W//flow_patch).unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)  # [BTN, 2C, H, W/P, P]
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                # 2C -> C
                mlvl_feat_flow_fold = self.flow_wm_convs[layer_idx](mlvl_feat_flow_fold)

                mlvl_feat_flow_unfold = F.unfold(mlvl_feat_flow_fold, kernel_size=(H, flow_patch), stride=flow_patch).reshape(B*TN, C, H, flow_patch, W//flow_patch).transpose(3,4)  # [BTN, C, H, W/P, P]

                # split left/right spatial flow from ego forward_point
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> b (n y) t c', b=B, t=T, n=6)
                if forward_point >= 0:
                    mlvl_feat_flow_unfold_spatial_l = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:], mlvl_feat_flow_unfold_spatial[:, :forward_point]], dim=1).flip(dims=[1])
                    mlvl_feat_flow_unfold_spatial_r = mlvl_feat_flow_unfold_spatial[:, forward_point:forward_point+3*self.flow_num]
                else:
                    mlvl_feat_flow_unfold_spatial_r = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point:], mlvl_feat_flow_unfold_spatial[:, :forward_point+3*self.flow_num]], dim=1)
                    mlvl_feat_flow_unfold_spatial_l = mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:forward_point].flip(dims=[1])
                # [B, 2, 3*Y, T, C]
                mlvl_feat_flow_unfold_spatial = torch.stack([mlvl_feat_flow_unfold_spatial_l, mlvl_feat_flow_unfold_spatial_r], dim=1)
                # split into spat_size interleaved sub-sequences
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold_spatial, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)

                # spatial world model
                output_wm_spatial, mlvl_feat_flow_unfold_fut_spatial = self.world_models_spatial[layer_idx](
                    mlvl_feat_flow_unfold_spatial,
                    flow_query_spat)
                outputs_wm_spatial.append(output_wm_spatial)

                # enhance spatial messages with spatial prediction:
                # undo the sub-sequence split, the left/right split and the ring rotation
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, '(b n p) y t c -> b n (y p) t c', b=B, n=2, p=self.spat_size)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,1], mlvl_feat_flow_unfold_fut_spatial[:,0].flip(dims=[1])], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,-forward_point:], mlvl_feat_flow_unfold_fut_spatial[:,:-forward_point]], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, 'b (n y) t c -> (b t n) y c', n=6, y=self.flow_num)
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        mlvl_feat_flow_unfold_fut_spatial.unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)  # [BTN, 2C, H, W/P, P]
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                # 2C -> C
                mlvl_feat_flow_fold = self.flow_wm_convs_spatial[layer_idx](mlvl_feat_flow_fold)
            else:
                # No world model for this level: just fold the columns back.
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
            mlvl_feats_flow.append(mlvl_feat_flow_fold.reshape(B, TN, C, H, W))
        # [B, TN, C, H, W] -> [B*TN, H*W, C]
        x = mlvl_feats_flow[0].flatten(0,1).flatten(2,3).transpose(1,2)
        return x

class img_adapter(nn.Module):
    """Resample the token axis of image features with a three-layer MLP.

    The MLP acts on the token axis: an input ``[B, N, D]`` with
    ``N == img_token_num`` becomes ``[B, M, D]`` with ``M == new_token_num``.

    Args:
        img_token_num: Number of tokens in the input.
        new_token_num: Number of tokens in the output.
    """

    def __init__(self, img_token_num, new_token_num):
        super(img_adapter, self).__init__()
        self.img_token_num = img_token_num
        self.new_token_num = new_token_num

        stages = [
            nn.Linear(img_token_num, new_token_num),
            nn.GELU(),
            nn.Linear(new_token_num, new_token_num),
            nn.GELU(),
            nn.Linear(new_token_num, new_token_num),
        ]
        self.adapter = nn.Sequential(*stages)

    def forward(self, x):
        """Map ``[B, N, D]`` to ``[B, M, D]``."""
        # Move tokens to the last axis so the Linear layers mix tokens,
        # then restore the original layout.
        tokens_last = x.transpose(1, 2)          # [B, D, N]
        resampled = self.adapter(tokens_last)    # [B, D, M]
        return resampled.transpose(1, 2)         # [B, M, D]


class IdentityMap(nn.Module):
    """Pass-through module: ``forward`` hands its first argument back as-is."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` unchanged; extra positional/keyword arguments are ignored."""
        return x

    @property
    def config(self):
        """Static description of this adapter: its type is ``'identity'``."""
        return dict(img_adapter_type='identity')


class LlavaMetaModel:
    """Mixin that attaches the vision-side modules to a model.

    The class declares no base of its own yet forwards ``config`` through
    ``super().__init__`` and reads ``self.config`` / ``self.dtype`` without
    defining them, so it is meant to be combined with another base class that
    provides those.

    Attributes created here:
        vision_tower, mm_projector: built in ``__init__`` only when ``config``
            has ``mm_vision_tower``; otherwise created by
            ``initialize_vision_modules``.
        image_newline: parameter of size ``config.hidden_size``; only present
            when the patch-merge type contains ``'unpad'``.
        img_adapter: token-count adapter, always built.
        scene_flow: ``SceneFlow`` module, always built.
    """

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            # delay_load=True is forwarded to the builder; load_model() is
            # called explicitly later in initialize_vision_modules.
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                # Allocated uninitialised (torch.empty); presumably filled
                # from a checkpoint afterwards -- TODO confirm.
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

        # self.img_adapter = IdentityMap()
        new_token_num = 128  # token per view
        self.img_adapter = self.build_img_adapter(new_token_num=new_token_num)

        self.scene_flow = self.build_scene_flow()

    def build_scene_flow(self):
        """Create the ``SceneFlow`` module with its default arguments."""
        return SceneFlow()

    def build_img_adapter(self, img_token_num=576, new_token_num=128):
        """Create the adapter mapping ``img_token_num`` tokens to ``new_token_num``."""
        return img_adapter(img_token_num=img_token_num,
                           new_token_num=new_token_num)

    def get_vision_tower(self):
        """Return the vision tower, or ``None`` if none has been attached.

        The tower may be stored wrapped in a one-element list (see
        ``initialize_vision_modules``); in that case the element is returned.
        """
        vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    @staticmethod
    def _extract_submodule_weights(weights, keyword):
        """Select the entries of ``weights`` that belong to sub-module ``keyword``.

        Keys are matched on ``keyword + '.'`` and everything up to and
        including its first occurrence is stripped, e.g.
        ``'model.mm_projector.0.weight'`` -> ``'0.weight'``.  Keys that contain
        ``keyword`` without the trailing dot are skipped instead of raising
        ``IndexError``.
        """
        prefix = keyword + '.'
        return {k.split(prefix, 1)[1]: v for k, v in weights.items() if prefix in k}

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build or load the vision tower and projector and record them in ``self.config``.

        Args:
            model_args: object providing ``vision_tower``,
                ``mm_vision_select_layer``, ``mm_vision_select_feature``,
                ``pretrain_mm_mlp_adapter`` and ``mm_patch_merge_type``;
                ``mm_projector_type`` is optional (default ``'linear'``).
            fsdp: when non-empty, a freshly built tower is stored inside a
                one-element list rather than directly on ``self``.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            # get_vision_tower() copes with both storage forms.  Indexing
            # self.vision_tower[0] just because fsdp is set fails when the
            # tower was attached in __init__ as a bare module.
            vision_tower = self.get_vision_tower()
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)

            if 'unpad' in mm_patch_merge_type:
                # Random init scaled by 1/sqrt(hidden_size).
                embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
                self.image_newline = nn.Parameter(
                    torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
                )
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            # NOTE(security): torch.load unpickles the file; only point this
            # at checkpoints from a trusted source.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict(
                self._extract_submodule_weights(mm_projector_weights, 'mm_projector')
            )


def unpad_image(tensor, original_size):
    """
    Crop away the padding of an image tensor that was padded and resized.

    The aspect ratio of the original image is compared with that of the
    tensor: a relatively wider original means the padding sits above and
    below (rows are cropped), otherwise it sits left and right (columns
    are cropped). The same margin is removed from both sides.

    Args:
    tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
    original_size (tuple): The original size of PIL image (width, height).

    Returns:
    torch.Tensor: The unpadded image tensor.
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    if orig_w / orig_h > cur_w / cur_h:
        # Width is the binding side: rescale the height and trim rows.
        kept_h = int(orig_h * (cur_w / orig_w))
        margin = (cur_h - kept_h) // 2
        return tensor[:, margin:cur_h - margin, :]

    # Height is the binding side: rescale the width and trim columns.
    kept_w = int(orig_w * (cur_h / orig_h))
    margin = (cur_w - kept_w) // 2
    return tensor[:, :, margin:cur_w - margin]


class LlavaMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, image_sizes=None
    ):
        """Replace image placeholder tokens by image features and build padded batch inputs.

        Every ``IMAGE_TOKEN_INDEX`` entry in ``input_ids`` is replaced by the
        feature rows of the next encoded image, the labels receive
        ``IGNORE_INDEX`` at those positions, and the per-sample sequences are
        optionally truncated, padded to a common length and stacked.

        Args:
            input_ids: 2-D token-id tensor (indexed as ``input_ids.shape[1]``
                and iterated row by row).
            position_ids: position ids or ``None``.
            attention_mask: mask over ``input_ids`` or ``None`` (then treated
                as all ones).
            past_key_values: passed through untouched.
            labels: labels aligned with ``input_ids`` or ``None``.
            images: tensor of images, or a list of tensors / a 5-D tensor when
                each sample carries a group of images.
            image_sizes: per-image original sizes; read only in the
                ``'anyres'`` / ``'unpad'`` code paths.

        Returns:
            Tuple ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. On the early-exit path ``input_ids`` and
            the other inputs are returned unchanged with ``inputs_embeds`` set
            to ``None``. Otherwise ``input_ids`` is ``None`` and
            ``inputs_embeds`` is the stacked embedding tensor;
            ``position_ids``, ``attention_mask`` and ``labels`` are ``None``
            whenever the corresponding argument was ``None``.

        Raises:
            NotImplementedError: a multi-entry image group is merged spatially
                without ``image_aspect_ratio == 'anyres'``, or both
                ``tune_mm_mlp_adapter`` and ``mm_use_im_start_end`` are set in
                the config.
            ValueError: ``mm_patch_merge_type`` is neither ``'flat'`` nor
                starts with ``'spatial'``.
        """
        vision_tower = self.get_vision_tower()
        # Nothing to merge: no tower, no images, or a one-token input
        # (presumably an incremental decoding step -- TODO confirm).
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            if type(images) is list:
                # Promote single 3-D images to a group of one.
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            # Encode all groups in one call, then split the result back per group.
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
            image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
            if mm_patch_merge_type == 'flat':
                # Merge the first two dims of every group's features.
                image_features = [x.flatten(0, 1) for x in image_features]
            elif mm_patch_merge_type.startswith('spatial'):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    if image_feature.shape[0] > 1:
                        # Entry 0 is kept as the base feature; the remaining
                        # entries are rearranged on a 2-D grid.
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]
                        if image_aspect_ratio == 'anyres':
                            num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            raise NotImplementedError
                        if 'unpad' in mm_patch_merge_type:
                            # -> (C, num_patch_height*height, num_patch_width*width)
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            # Append image_newline as one extra column to every row.
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                            ), dim=-1)
                            # Back to a flat (tokens, C) sequence.
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            # Reorder to (num_patch_height, height, num_patch_width, width, C)
                            # and flatten everything but the channel dim.
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                    else:
                        # Single-entry group: use it directly, followed by one
                        # image_newline row when unpadding is enabled.
                        image_feature = image_feature[0]
                        if 'unpad' in mm_patch_merge_type:
                            image_feature = torch.cat((
                                image_feature,
                                self.model.image_newline[None].to(image_feature.device)
                            ), dim=0)
                    new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        # The originals are remembered so that arguments passed as None are
        # returned as None again at the end.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        # (boolean-mask indexing turns the batch into a list of 1-D tensors of
        # differing length; _input_ids is not read again in this method)
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        # Index of the next unused entry of image_features (shared across the batch).
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # Sample without image token: a zero-row slice of the next
                # image feature is concatenated (adds no rows) and that
                # feature is consumed anyway.
                # NOTE(review): presumably keeps the vision modules in the
                # graph and the image indexing aligned -- confirm.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Positions of the image tokens, framed by sentinels -1 and len so
            # that consecutive pairs delimit the text segments between images.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            # Embed all text segments in a single call, then split them apart again.
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # Interleave: text segment, image features, text segment, ...
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions carry IGNORE_INDEX as label.
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            # Bring every piece onto self.device before concatenating.
            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        # Pre-filled outputs: labels with IGNORE_INDEX, mask and positions with zeros.
        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                # Left padding: zero rows first, real content right-aligned.
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                # Right padding (default): real content first, zero rows after.
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Hand back None for everything the caller passed in as None.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            # Restore the caller's mask dtype (it was converted to bool above).
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
