"""
Wrapper around the joint model (mixtures). Siglip from PaliGemma, action-time encoder, proprio encoder, action decoder. Flow matching training

Generates causal masking for the mixtures

Potentially customized to add/remove mixtures, e.g., remove proprio or add another vision module

"""

import logging
from typing import Optional, Tuple

import hydra
import torch
from torch import nn
import copy
from src.model.kv_cache import KVCache
from src.model.vla.modules import (
    ActionEncoder,
    SinusoidalPosEmb,
)
from src.utils.decorator import NoSyncBase
from src.utils.monitor import log_execution_time

from torch.linalg import inv
import torchvision.transforms.functional as TF
import torch.nn.functional as F

torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.suppress_errors = True

SIGLIP_MEAN, SIGLIP_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
ZOE_MEAN, ZOE_STD = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

from transformers import AutoTokenizer

import logging
from src.model.vla.processing import VLAProcessor
import hydra
import numpy as np
import pretty_errors
from PIL import Image
from omegaconf import OmegaConf, open_dict
from torch.utils.data import DataLoader
from torch.linalg import inv
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import transformers
from transformers import (
    AutoTokenizer,)
# from transformers import DINOv3ViTModel, DINOv3ViTConfig, DINOv3ViTImageProcessorFast
from src.model.vla.processing import VLAProcessor

log = logging.getLogger(__name__)


class UniCoD(nn.Module, NoSyncBase):

    @log_execution_time(log)
    def __init__(self, cfg, use_ddp: bool = False):
        super().__init__()
        self.cfg = cfg
        self.use_ddp = use_ddp  # used in NoSyncBase
        self.vocab_size = cfg.vocab_size
        self.pad_token_id = cfg.pad_token_id
        self.image_token_index = cfg.image_token_index
        self.use_lm_head = cfg.get("use_lm_head", False)

        self.max_image_text_tokens = cfg.max_image_text_tokens
        self.pred_dino = cfg.pred_dino
        self.num_dino_tokens = cfg.num_dino_tokens if self.pred_dino else 0
        self.dino_type = cfg.dino_type
        if not self.pred_dino:
            self.cfg.mixture.pop("dino")
            # self.cfg.joint.config.mixture.pop("dino")
        self.num_proprio_tokens = cfg.cond_steps
        self.num_action_tokens = cfg.horizon_steps
        self.total_num_tokens = (
            self.max_image_text_tokens + self.num_proprio_tokens + self.num_action_tokens + self.num_dino_tokens)

        self.image_text_hidden_size = cfg.mixture.vlm.hidden_size
        self.proprio_hidden_size = cfg.mixture.proprio.hidden_size
        self.action_hidden_size = cfg.mixture.action.hidden_size

        # Action parameterization
        self.num_inference_steps = cfg.num_inference_steps
        self.horizon_steps = cfg.horizon_steps
        self.action_dim = cfg.action_dim
        self.proprio_dim = cfg.proprio_dim
        self.final_action_clip_value = cfg.final_action_clip_value
        self.flow_sig_min = cfg.get("flow_sig_min", 0.001)

        # text input only
        self.embed_tokens = nn.Embedding(
            cfg.vocab_size,
            self.image_text_hidden_size,
            self.pad_token_id,
        )  # 0.527B parameters

        # Vision
        self.vision_tower = hydra.utils.instantiate(cfg.vision)
        self.multi_modal_projector = hydra.utils.instantiate(cfg.vision_projector)

        # Mixtures
        self.joint_model = hydra.utils.instantiate(cfg.joint)

        # Action, proprio, time encoders
        self.action_expert_adaptive_mode = cfg.action_expert_adaptive_mode
        if cfg.action_expert_adaptive_mode:  # adaLN or adaLN-Zero
            self.action_encoder = ActionEncoder(
                self.action_dim,
                self.action_hidden_size,
                time_cond=False,
            )
            self.time_embedding = SinusoidalPosEmb(cfg.time_hidden_size, cfg.time_max_period)
        else:  # matching pi0
            self.action_encoder = ActionEncoder(
                self.action_dim,
                self.action_hidden_size,
                time_cond=True,
            )
            self.time_embedding = SinusoidalPosEmb(self.action_hidden_size, cfg.time_max_period)
        self.proprio_encoder = nn.Linear(
            self.proprio_dim,
            self.proprio_hidden_size,
        )

        # Action decoder
        self.action_decoder = nn.Linear(
            self.action_hidden_size,
            self.action_dim,
        )

        # optional text output
        if self.use_lm_head:
            self.lm_head = nn.Linear(
                self.image_text_hidden_size,
                self.vocab_size,
                bias=False,
            )
            self.lm_head.weight = self.embed_tokens.weight  # tie weights

        if self.pred_dino:
            if self.dino_type == "siglip":
                self.dino_model = hydra.utils.instantiate(cfg.vision)
            elif self.dino_type == "siglip_imgloss":
                self.dino_model = hydra.utils.instantiate(cfg.vision)
                # 先聚合seq_len维度，然后输出为 (batch, 3, 224, 224)
                self.din_decoder = nn.Sequential(
                    # 将seqlen聚合成1，输出形状batch,hidden，利用1维cnn
                    nn.Conv1d(cfg.num_dino_tokens, 1, kernel_size=1),
                    nn.Flatten(start_dim=1),
                    nn.Linear(cfg.mixture.dino.hidden_size, 1024),
                    nn.GELU(),
                    nn.Linear(1024, 3 * 224 * 224),
                    nn.Tanh(),  # 输出范围为 [-1, 1]，可根据需要调整
                    nn.Unflatten(-1, (3, 224, 224)))

            elif self.dino_type == "siglip_hyc_wo_enc":
                self.dino_model = hydra.utils.instantiate(cfg.vision)
                self.siglip = self.dino_model
            elif self.dino_type == "siglip_hyc_w_encoder_decoder":
                self.siglip = hydra.utils.instantiate(cfg.vision)
                self.dino_model = self.siglip  # match hyc's code for loading ckpt
                self.din_encoder = nn.Linear(cfg.dino_hidden_size, cfg.mixture.dino.hidden_size)
                self.din_decoder = nn.Linear(cfg.mixture.dino.hidden_size, cfg.dino_hidden_size)
            elif self.dino_type == "dino":
                from transformers import DINOv3ViTModel, DINOv3ViTConfig, DINOv3ViTImageProcessorFast
                self.dino_config = DINOv3ViTConfig.from_pretrained(cfg.dino_model_path, local_files_only=True)
                self.dino_model = DINOv3ViTModel.from_pretrained(
                    cfg.dino_model_path, config=self.dino_config, local_files_only=True).eval()
            elif self.dino_type == "dino_future_distill":
                from transformers import DINOv3ViTModel, DINOv3ViTConfig, DINOv3ViTImageProcessorFast
                self.dino_config = DINOv3ViTConfig.from_pretrained(cfg.dino_model_path, local_files_only=True)
                self.dino_model = DINOv3ViTModel.from_pretrained(
                    cfg.dino_model_path, config=self.dino_config, local_files_only=True).eval()
                # self.din_encoder = nn.Linear(4096, cfg.mixture.dino.hidden_size)
                # self.din_decoder = nn.Linear(cfg.mixture.dino.hidden_size, 4096)

            for param in self.dino_model.parameters():
                param.requires_grad = False
            # self.dino_processor = DINOv3ViTImageProcessorFast.from_pretrained(
            #     cfg.dino_model_path, local_files_only=True)
            # self.din_encoder = nn.Linear(4096, cfg.mixture.dino.hidden_size)
            # self.din_decoder = nn.Linear(cfg.mixture.dino.hidden_size, 4096)
            # self.din_encoder = nn.Linear(self.image_text_hidden_size, cfg.mixture.dino.hidden_size)
            # self.din_decoder = nn.Linear(cfg.mixture.dino.hidden_size, self.image_text_hidden_size)

    @property
    def action_expert_parameters(self):
        return (list(self.action_encoder.parameters()) + list(self.action_decoder.parameters()) +
                list(self.proprio_encoder.parameters()) + list(self.joint_model.mixtures["action"].parameters())
               )  # note: action and proprio share weights

    @property
    def trainable_dino_expert_parameters(self):
        # return (list(self.din_encoder.parameters()) + list(self.din_decoder.parameters()))
        if self.dino_type == "siglip" or self.dino_type == "siglip_hyc_wo_enc":
            return (list(self.joint_model.mixtures["dino"].parameters()))
        if self.dino_type == "siglip_imgloss":
            return (list(self.joint_model.mixtures["dino"].parameters()) + list(self.din_decoder.parameters()))
        elif self.dino_type == "siglip_hyc_w_encoder_decoder":
            return (list(self.joint_model.mixtures["dino"].parameters()) + list(self.din_encoder.parameters()) +
                    list(self.din_decoder.parameters()))
        elif self.dino_type == "dino":
            return (list(self.joint_model.mixtures["dino"].parameters()))
        elif self.dino_type == "dino_future_distill":
            return (list(self.joint_model.mixtures["dino"].parameters()))
            # return (list(self.joint_model.mixtures["dino"].parameters()) + list(self.din_encoder.parameters()) +
            #         list(self.din_decoder.parameters()))
        else:
            raise NotImplementedError

    @property
    def trainable_vlm_parameters(self):
        return (list(self.vision_tower.parameters()) + list(self.multi_modal_projector.parameters()) +
                self.trainable_gemma_parameters)

    @property
    def lora_trainable_vlm_parameters(self):
        params = []
        for name, param in self.vision_tower.named_parameters():
            if "lora_" in name:
                params.append(param)
        for name, param in self.multi_modal_projector.named_parameters():
            if "lora_" in name:
                params.append(param)
        params.extend(self.trainable_lora_gemma_parameters)
        return params

    @property
    def trainable_gemma_parameters(self):
        gemma_parameters = []
        for name, param in self.joint_model.mixtures["vlm"].named_parameters():
            if not self._check_gemma_unused_parameter_by_name(name):
                gemma_parameters.append(param)
        return gemma_parameters

    @property
    def trainable_lora_gemma_parameters(self):
        gemma_parameters = []
        for name, param in self.joint_model.mixtures["vlm"].named_parameters():
            if not self._check_gemma_unused_parameter_by_name(name):
                if "lora_" in name:
                    gemma_parameters.append(param)
        return gemma_parameters

    @log_execution_time(log)
    def load_pretrained_weights(self):
        """vision, projector, lm from paligemma"""
        import glob
        import os

        from safetensors import safe_open

        # load tensors from files
        safetensors_files = glob.glob(os.path.join(self.cfg.pretrained_model_path, "*.safetensors"))
        tensors = {}
        for safetensors_file in safetensors_files:
            with safe_open(safetensors_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    tensors[key] = f.get_tensor(key)

        # load embed tokens
        embed_tokens_state_dict = self.embed_tokens.state_dict()
        for k, v in tensors.items():
            if "embed_tokens" in k:
                new_key = k.replace("language_model.model.embed_tokens.", "")
                embed_tokens_state_dict[new_key] = v
        self.embed_tokens.load_state_dict(embed_tokens_state_dict, strict=True)
        log.info("Loaded pre-trained weights for embed tokens")

        # load vision tower --- "vision_tower.vision_model" -> "vision_model"
        vision_tower_state_dict = self.vision_tower.state_dict()
        for k, v in tensors.items():
            if "vision_tower" in k:
                new_key = k.replace("vision_tower.", "")
                vision_tower_state_dict[new_key] = v
        self.vision_tower.load_state_dict(vision_tower_state_dict, strict=True)
        log.info("Loaded pre-trained weights for vision tower")
        if self.pred_dino and type(self.vision_tower) == type(
                self.dino_model) and (self.dino_type == "siglip" or self.dino_type == "siglip_hyc_w_encoder_decoder" or
                                      self.dino_type == "siglip_hyc_wo_enc" or self.dino_type == "siglip_imgloss"):
            self.dino_model.load_state_dict(vision_tower_state_dict, strict=True)
            log.info("Loaded pre-trained weights for dino model")
            # print("dino model parameters requires_grad: ", next(self.dino_model.parameters()).requires_grad)
            # print("vision tower parameters requires_grad: ", next(self.vision_tower.parameters()).requires_grad)

        # load projector --- "multi_modal_projector.linear" -> "linear"
        multi_modal_projector_state_dict = self.multi_modal_projector.state_dict()
        for k, v in tensors.items():
            if "multi_modal_projector" in k:
                new_key = k.replace("multi_modal_projector.", "")
                multi_modal_projector_state_dict[new_key] = v
        self.multi_modal_projector.load_state_dict(multi_modal_projector_state_dict, strict=True)
        log.info("Loaded pre-trained weights for projector")

        # load lm --- do not change any lora weights
        joint_model_state_dict = self.joint_model.state_dict()
        lora_keys = []
        for key in (joint_model_state_dict.keys()):  # avoid RuntimeError: OrderedDict mutated during iteration
            if "lora_" in key:
                lora_keys.append(key)
        for key in lora_keys:
            del joint_model_state_dict[key]
        for k, v in tensors.items():
            if "language_model.model" in k:
                new_key = k.replace("language_model.model.", "mixtures.vlm.")
                joint_model_state_dict[new_key] = v
        self.joint_model.load_state_dict(joint_model_state_dict, strict=False)
        log.info("Loaded pre-trained weights for lm part of the joint model")

    def _check_gemma_unused_parameter_by_name(self, name: str) -> bool:
        """no need to train vlm parameters after attention of last layer"""
        last_hidden_layer_index = self.joint_model.num_hidden_layers - 1
        if (f"{last_hidden_layer_index}.post" in name or f"{last_hidden_layer_index}.mlp" in name or
                f"{last_hidden_layer_index}.self_attn.o_proj" in name or
                f"{last_hidden_layer_index}.self_attn.v_proj" in name):  # final norm is not initialized
            return True
        return False

    def freeze_non_lora_weights_in_vlm(self):
        """Keep all bias frozen"""
        for name, param in self.vision_tower.named_parameters():
            param.requires_grad = True if "lora_" in name else False
        log.info("Froze non-lora weights in vision tower")

        for name, param in self.multi_modal_projector.named_parameters():
            param.requires_grad = True if "lora_" in name else False
        log.info("Froze non-lora weights in projector")

        for name, param in self.joint_model.mixtures["vlm"].named_parameters():
            if not self._check_gemma_unused_parameter_by_name(name):
                param.requires_grad = True if "lora_" in name else False
        log.info("Froze non-lora weights in lm part of the joint model")

    def freeze_unused_weights(self):
        """text embedding and part of last layer of vlm, including lora"""
        self.embed_tokens.weight.requires_grad = False
        for name, param in self.joint_model.mixtures["vlm"].named_parameters():
            if self._check_gemma_unused_parameter_by_name(name):
                param.requires_grad = False

    def freeze_all_weights(self):
        for _, param in self.named_parameters():
            param.requires_grad = False

    def tie_action_proprio_weights(self):
        """technically more than just tying weights"""
        self.joint_model.mixtures["proprio"] = self.joint_model.mixtures["action"]

    def build_text_cache(self):
        return KVCache()

    # ---------- Input preparation ----------#

    def build_causal_mask_and_position_ids(self, attention_mask: torch.Tensor,
                                           dtype: torch.dtype) -> Tuple[torch.FloatTensor]:
        """
        block attention --- padding for unused text tokens

                 img/text img/text img/text dino(padding) proprio action action
        img/text    x        x        x
        img/text    x        x        x
        img/text    x        x        x
        (padding)
        dino        x        x        x        x
        proprio     x        x        x        x         x
        action      x        x        x        x         x       x      x
        action      x        x        x        x         x       x      x
        """
        bsz = attention_mask.size(0)
        dino_start = self.max_image_text_tokens
        dino_end = self.max_image_text_tokens + self.num_dino_tokens
        proprio_start = dino_end
        proprio_end = proprio_start + self.num_proprio_tokens
        action_start = proprio_end
        image_text_token_cnts = torch.sum(attention_mask, dim=1)
        causal_mask = torch.full(
            (bsz, self.total_num_tokens, self.total_num_tokens),
            torch.finfo(dtype).min,
            dtype=dtype,
        )  # smallest value, avoid using inf for softmax nan issues with padding
        for idx, cnt in enumerate(image_text_token_cnts):
            causal_mask[idx, :cnt, :cnt] = 0  # image/text attend to itself
            causal_mask[idx, dino_start:, :cnt] = (
                0  # dino/proprio/action attend to image/text
            )

        causal_mask[:, dino_start:dino_end, dino_start:dino_end] = (
            0  # dino attend to itself
        )
        causal_mask[:, proprio_start:, dino_start:dino_end] = (
            0  # dino attend to itself
        )

        causal_mask[:, proprio_start:proprio_end, proprio_start:proprio_end] = (
            0  # proprio attend to itself
        )
        causal_mask[:, action_start:, proprio_start:] = (
            0  # action attend to itself and proprio
        )

        # add the head dimension
        # [Batch_Size, Q_Len, KV_Len] -> [Batch_Size, Num_Heads_Q, Q_Len, KV_Len]
        causal_mask = causal_mask.unsqueeze(1)

        # position ids for each blocks --- start at 1
        vlm_position_ids = torch.arange(1, self.max_image_text_tokens + 1).repeat(bsz, 1)
        proprio_position_ids = torch.arange(1, self.num_proprio_tokens + 1).repeat(bsz, 1)
        action_position_ids = torch.arange(
            self.num_proprio_tokens + 1,
            self.num_proprio_tokens + self.num_action_tokens + 1,
        ).repeat(bsz, 1)
        if self.num_dino_tokens > 0:
            dino_position_ids = torch.arange(self.max_image_text_tokens + 1,
                                             self.max_image_text_tokens + self.num_dino_tokens + 1).repeat(bsz, 1)
        # since proprio and action share the same mixture weights, makes sense to use [1 (proprio), 2 (action), 3 (action), ...] instead of [1 (proprio), 1 (action), 2 (action), ...]
        if self.num_dino_tokens > 0:
            return causal_mask, vlm_position_ids, dino_position_ids, proprio_position_ids, action_position_ids
        else:
            return causal_mask, vlm_position_ids, proprio_position_ids, action_position_ids

    def split_full_mask_into_submasks(
        self,
        causal_mask: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """split into ones for paligemma and action"""
        image_text_proprio_mask = causal_mask[
            ...,
            :self.max_image_text_tokens + self.num_dino_tokens + self.num_proprio_tokens,
            :self.max_image_text_tokens + self.num_dino_tokens + self.num_proprio_tokens,
        ]
        action_mask = causal_mask[..., -self.num_action_tokens:, :]
        return image_text_proprio_mask, action_mask

    def build_causal_mask_and_position_ids_for_text(
        self,
        q_len: int,
        attention_mask: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        dtype, device = attention_mask.dtype, attention_mask.device

        if kv_cache is None or kv_cache.num_items() == 0:
            # do not mask any token, because we're in the prefill phase
            # assume no padding
            causal_mask = torch.full((bsz, q_len, q_len), 0, dtype=dtype, device=device)
        else:
            assert q_len == 1, "Using KV cache so should only use one single token"
            kv_len = kv_cache.num_items() + q_len
            # also in this case we don't need to mask anything, since each query should be able to attend all previous tokens.
            # this only works when we have no padding
            causal_mask = torch.full((bsz, q_len, kv_len), 0, dtype=dtype, device=device)

        # add the head dimension
        # [Batch_Size, Q_Len, KV_Len] -> [Batch_Size, Num_Heads_Q, Q_Len, KV_Len]
        causal_mask = causal_mask.unsqueeze(1)

        if kv_cache is not None and kv_cache.num_items() > 0:
            # use the last location
            position_ids = attention_mask.cumsum(-1)[:, -1:]
        else:
            # create position_ids based on the size of the attention_mask
            # for padded tokens, use number 1
            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
        return causal_mask, position_ids

    # ---------- Inference ----------#

    def _forward_siglip_and_text_embedding(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
    ) -> torch.FloatTensor:
        dtype, device = pixel_values.dtype, pixel_values.device

        # text embedding
        # [Batch_Size, Seq_Len, Hidden_Size]
        inputs_embeds = self.embed_tokens(input_ids)

        # image features from siglip and projector
        # [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Hidden_Size]
        selected_image_feature = self.vision_tower(pixel_values)
        image_features = self.multi_modal_projector(selected_image_feature)

        # normalize the image features
        _, _, embed_dim = image_features.shape
        bsz, seq_len = input_ids.shape
        scaled_image_features = image_features / (self.image_text_hidden_size**0.5)

        # put embedding together - image, text, padding
        final_embedding = torch.full((bsz, seq_len, embed_dim), self.pad_token_id, dtype=dtype, device=device)

        # [Batch_Size, Seq_Len]
        text_mask = (input_ids != self.image_token_index) & (input_ids != self.pad_token_id)
        image_mask = input_ids == self.image_token_index
        final_embedding[text_mask] = inputs_embeds[text_mask]
        for i in range(bsz):
            image_indices = image_mask[i].nonzero(as_tuple=True)[0]
            num_image_tokens = len(image_indices)
            final_embedding[i, image_indices] = scaled_image_features[i, :num_image_tokens]
        return final_embedding

    def _forward_dino_embedding(self, pixel_values: torch.FloatTensor, is_input: bool) -> torch.FloatTensor:
        if self.dino_type == "siglip" or self.dino_type == "siglip_hyc_wo_enc":
            with torch.no_grad():
                dino_embeds = self.dino_model(pixel_values)
                # print('dino_inputs_embeds.shape:', dino_inputs_embeds.shape)
        elif self.dino_type == "siglip_imgloss":
            if is_input:
                with torch.no_grad():
                    dino_embeds = self.dino_model(pixel_values)
                    # print('dino_inputs_embeds.shape:', dino_inputs_embeds.shape)
            else:
                return pixel_values  # gt_dino_pixel_values, no need to pass thru dino model
        elif self.dino_type == "dino":
            with torch.no_grad():
                dino_embeds = self.dino_model(pixel_values)['last_hidden_state']
                # print('dino_inputs_embeds.shape:', dino_inputs_embeds.shape)
        elif self.dino_type == "dino_future_distill":
            with torch.no_grad():
                # dino_pixel_values = self.dino_processor(dino_pixel_values, return_tensors='pt')['pixel_values'] # moved to trainer
                if not is_input:
                    dino_embeds = self.dino_model(pixel_values)['last_hidden_state']
                else:
                    dino_embeds = self.dino_model.embeddings(pixel_values)
            # if not is_input:
            #     # requires grad
            #     dino_embeds = self.din_encoder(dino_embeds)
        elif self.dino_type == "siglip_hyc_w_encoder_decoder":
            with torch.no_grad():
                dino_embeds = self.dino_model(pixel_values)
            if is_input:
                dino_embeds = self.din_encoder(dino_embeds)
        return dino_embeds

    def infer_action(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        image_text_proprio_mask: torch.FloatTensor,
        action_mask: torch.FloatTensor,
        vlm_position_ids: torch.LongTensor,
        proprio_position_ids: torch.LongTensor,
        action_position_ids: torch.LongTensor,
        proprios: torch.FloatTensor,
        gt_dino_pixel_values: torch.FloatTensor = None,
        dino_pixel_values: torch.FloatTensor = None,
        dino_position_ids: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        dtype, device = pixel_values.dtype, pixel_values.device
        bsz = pixel_values.size(0)

        kv_caches = self.joint_model.build_mixture_caches()

        # merge the text tokens and the image tokens
        inputs_embeds = self._forward_siglip_and_text_embedding(input_ids, pixel_values)
        if self.pred_dino:
            dino_inputs_embeds = self._forward_dino_embedding(dino_pixel_values, is_input=True)

        # proprio
        proprio_embeds = self.proprio_encoder(proprios)

        # forward pass thru the vlm and proprio, cache the kv
        if self.pred_dino:
            _, kv_caches = self.joint_model(
                attention_mask=image_text_proprio_mask,
                position_ids_all={
                    "vlm": vlm_position_ids,
                    "dino": dino_position_ids,
                    "proprio": proprio_position_ids,
                },
                embeds_all={
                    "vlm": inputs_embeds,
                    "dino": dino_inputs_embeds,
                    "proprio": proprio_embeds,
                },
                kv_caches=kv_caches,
                return_caches=True,
            )
        else:
            _, kv_caches = self.joint_model(
                attention_mask=image_text_proprio_mask,
                position_ids_all={
                    "vlm": vlm_position_ids,
                    "proprio": proprio_position_ids,
                },
                embeds_all={
                    "vlm": inputs_embeds,
                    "proprio": proprio_embeds,
                },
                kv_caches=kv_caches,
                return_caches=True,
            )

        # sample pure action noise
        action = torch.randn((bsz, self.horizon_steps, self.action_dim), device=device, dtype=dtype)

        # forward euler integration --- using kv caches of vlm and proprio
        delta_t = 1.0 / self.num_inference_steps
        t = torch.zeros(bsz, device=device, dtype=dtype)
        for _ in range(self.num_inference_steps):
            # encode action and time into embedding
            time_cond = self.time_embedding(t)
            # [Batch_Size, Horizon_Steps, Embed_Dim]
            if self.action_expert_adaptive_mode:
                action_embeds = self.action_encoder(action)
            else:
                action_embeds = self.action_encoder(action, time_cond)
            # [Batch_Size, Horizon_Steps, Embed_Dim]
            action_embeds = self.joint_model(
                attention_mask=action_mask,
                position_ids_all={"action": action_position_ids},
                embeds_all={"action": action_embeds},
                time_cond=time_cond,
                kv_caches=kv_caches,
                cache_mode="append_non_active",  # use caches from other mixtures, i.e., vlm and proprio
            )["action"]
            # decode action: [Batch_Size, Horizon_Steps, Action_Dim]
            action_vel = self.action_decoder(action_embeds)
            action += delta_t * action_vel
            t += delta_t

        # clamp final output if specified
        if self.final_action_clip_value is not None:
            action = torch.clamp(
                action,
                -self.final_action_clip_value,
                self.final_action_clip_value,
            )
        return action

    def infer_action_naive(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        causal_mask: torch.FloatTensor,
        vlm_position_ids: torch.LongTensor,
        proprio_position_ids: torch.LongTensor,
        action_position_ids: torch.LongTensor,
        proprios: torch.FloatTensor,
    ) -> torch.FloatTensor:
        dtype, device = pixel_values.dtype, pixel_values.device
        bsz = pixel_values.size(0)

        kv_caches = self.joint_model.build_mixture_caches()

        # merge the text tokens and the image tokens
        inputs_embeds = self._forward_siglip_and_text_embedding(input_ids, pixel_values)

        # encode proprio
        proprio_embeds = self.proprio_encoder(proprios)

        # sample pure action noise
        action = torch.randn((bsz, self.horizon_steps, self.action_dim), device=device, dtype=dtype)

        # forward euler integration --- run vlm in each step, which is unnecessary
        delta_t = 1.0 / self.num_inference_steps
        t = torch.zeros(bsz, device=device, dtype=dtype)
        for _ in range(self.num_inference_steps):
            # encode action and time into embedding
            time_cond = self.time_embedding(t)
            # [Batch_Size, Horizon_Steps, Embed_Dim]
            if self.action_expert_adaptive_mode:
                action_embeds = self.action_encoder(action)
            else:
                action_embeds = self.action_encoder(action, time_cond)
            action_embeds = self.joint_model(
                attention_mask=causal_mask,
                position_ids_all={
                    "vlm": vlm_position_ids,
                    "proprio": proprio_position_ids,
                    "action": action_position_ids,
                },
                embeds_all={
                    "vlm": inputs_embeds.clone(),  # clone needed due to modified in-place
                    "proprio": proprio_embeds.clone(),
                    "action": action_embeds,
                },
                time_cond=time_cond,
                kv_caches=kv_caches,
                cache_mode="no_append",  # no new tokens
            )["action"]
            # decode action: [Batch_Size, Horizon_Steps, Action_Dim]
            action_vel = self.action_decoder(action_embeds)
            action += delta_t * action_vel
            t += delta_t

        # clamp final output if specified
        if self.final_action_clip_value is not None:
            action = torch.clamp(
                action,
                -self.final_action_clip_value,
                self.final_action_clip_value,
            )
        return action

    def infer_text(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple:
        q_len = input_ids.size(1)

        # text tokens + image tokens
        inputs_embeds = self._forward_siglip_and_text_embedding(input_ids, pixel_values)

        # build causal mask and position ids for text
        (
            causal_mask,
            position_ids,
        ) = self.build_causal_mask_and_position_ids_for_text(q_len, attention_mask, kv_cache)

        hidden_states = self.joint_model(
            attention_mask=causal_mask,
            position_ids_all={"vlm": position_ids},
            embeds_all={"vlm": inputs_embeds},
            kv_caches={"vlm": kv_cache},
            cache_mode="append",  # new tokens for the active mixture
            final_layer_post_attn_skip_names=[],  # do not skip vlm last layer
        )["vlm"]
        logits = self.lm_head(hidden_states)
        output = {
            "logits": logits,
        }
        if kv_cache is not None:
            output["kv_cache"] = kv_cache
        return output

    # ---------- Flow matching training ----------#

    def psi_t(
        self,
        x: torch.FloatTensor,
        x1: torch.FloatTensor,
        t: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Conditional Flow"""
        t = t[:, None, None]  # (B, 1, 1)
        return (1 - (1 - self.flow_sig_min) * t) * x + t * x1

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.ByteTensor,
        causal_mask: torch.FloatTensor,
        vlm_position_ids: torch.LongTensor,
        proprio_position_ids: torch.LongTensor,
        action_position_ids: torch.LongTensor,
        proprios: torch.FloatTensor,
        actions: torch.FloatTensor,
        t: torch.FloatTensor,
        gt_dino_pixel_values: torch.FloatTensor = None,
        dino_pixel_values: torch.FloatTensor = None,
        dino_position_ids: torch.LongTensor = None,
    ) -> torch.FloatTensor:
        """flow matching loss for action prediction, no use of kv cache"""
        # noisy action
        # [Batch_Size, Horizon_Steps, Action_Dim]
        if self.pred_dino:
            dino_inputs_embeds = self._forward_dino_embedding(dino_pixel_values, is_input=True)
            dino_gt = self._forward_dino_embedding(gt_dino_pixel_values, is_input=False)

        x0 = torch.randn_like(actions, device=t.device, dtype=t.dtype)
        x1 = actions
        psi_t = self.psi_t(x0, x1, t)
        # text tokens + image tokens
        inputs_embeds = self._forward_siglip_and_text_embedding(input_ids, pixel_values)

        # proprio
        proprio_embeds = self.proprio_encoder(proprios)

        # inference with noisy action
        # [Batch_Size, Embed_Dim]
        time_cond = self.time_embedding(t)
        # [Batch_Size, Horizon_Steps, Embed_Dim]
        if self.action_expert_adaptive_mode:
            action_embeds = self.action_encoder(psi_t)
        else:
            action_embeds = self.action_encoder(psi_t, time_cond)
        if self.pred_dino:
            joint_embeds = self.joint_model(
                attention_mask=causal_mask,
                position_ids_all={
                    "vlm": vlm_position_ids,
                    "dino": dino_position_ids,
                    "proprio": proprio_position_ids,
                    "action": action_position_ids,
                },
                embeds_all={
                    "vlm": inputs_embeds,
                    "dino": dino_inputs_embeds,
                    "proprio": proprio_embeds,
                    "action": action_embeds,
                },
                time_cond=time_cond,
                kv_caches={},  # no caching during training
            )
            action_embeds = joint_embeds["action"]
            dino_pred = joint_embeds["dino"]
            # print("dino_pred.shape:", dino_pred.shape)
            if self.dino_type == "siglip_hyc_w_encoder_decoder" or self.dino_type == "siglip_imgloss":
                dino_pred = self.din_decoder(dino_pred)
                # print("after din_decoder, dino_pred.shape:", dino_pred.shape)
        else:
            action_embeds = self.joint_model(
                attention_mask=causal_mask,
                position_ids_all={
                    "vlm": vlm_position_ids,
                    "proprio": proprio_position_ids,
                    "action": action_position_ids,
                },
                embeds_all={
                    "vlm": inputs_embeds,
                    "proprio": proprio_embeds,
                    "action": action_embeds,
                },
                time_cond=time_cond,
                kv_caches={},  # no caching during training
            )["action"]

        # [Batch_Size, Horizon_Steps, Action_Dim]
        v_psi = self.action_decoder(action_embeds)

        # compare to true velocity
        d_psi = x1 - (1 - self.flow_sig_min) * x0
        action_loss = torch.mean((v_psi - d_psi)**2)
        if self.pred_dino:
            dino_loss = torch.mean((dino_pred - dino_gt)**2)
            return {"action_loss": action_loss, "dino_loss": dino_loss}
        else:
            return {"action_loss": action_loss}
        #return torch.mean((v_psi - d_psi) ** 2)


class UniCoDInference(UniCoD):

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        image_text_proprio_mask: torch.FloatTensor,
        action_mask: torch.FloatTensor,
        vlm_position_ids: torch.LongTensor,
        proprio_position_ids: torch.LongTensor,
        action_position_ids: torch.LongTensor,
        proprios: torch.FloatTensor,
    ) -> torch.FloatTensor:
        return super().infer_action(
            input_ids,
            pixel_values,
            image_text_proprio_mask,
            action_mask,
            vlm_position_ids,
            proprio_position_ids,
            action_position_ids,
            proprios,
        )
