#!/usr/bin/env python3
# Copyright (c) 2024 ByteDance. All Rights Reserved.
# GLEE Model.
# GLEE: General Object Foundation Model for Images and Videos at Scale (CVPR 2024)
# https://arxiv.org/abs/2312.09158

import torch
import torch.nn.functional as F
from torch import nn, einsum
from detectron2.modeling import build_backbone
from .pixel_decoder.maskdino_encoder import build_pixel_decoder
from .transformer_decoder.maskdino_decoder import build_transformer_decoder
import random
from collections import OrderedDict
from ..modules.point_features import point_sample
from timm.models.layers import trunc_normal_
from transformers import CLIPTokenizer, CLIPTextModel, BertTokenizer, BertModel
from .vos_utils import masks_to_boxes, FeatureFuser
import numpy as np 
import math
from itertools import chain
from peft import LoraConfig, TaskType
from typing import Optional, Callable

def rand_sample(x, max_len):
    """Randomly keep at most ``max_len`` entries along dim 1 of ``x``.

    When ``x.shape[1]`` does not exceed ``max_len`` the input object itself is
    returned. Otherwise ``max_len`` distinct positions along dim 1 are chosen
    with ``torch.randperm`` and the corresponding slices are returned in the
    sampled (shuffled) order.
    """
    width = x.shape[1]
    if width > max_len:
        keep = torch.randperm(width)[:max_len]
        return x[:, keep]
    return x


def agg_lang_feat(features, mask, pool_type="average"):
    """Masked average pooling of language features over the sequence axis.

    Args:
        features: token features, multiplied element-wise by the mask and
            summed over dim 1 (the original notes give (bs, seq_len, C)).
        mask: per-token validity mask (the original notes give (bs, seq_len));
            zero entries remove the token from the average.
        pool_type: kept for interface compatibility. Only average pooling is
            implemented and this value is not consulted.

    Returns:
        The sum of the unmasked token features divided by the mask sum of each
        row. A row whose mask sums to zero yields zeros instead of NaN.
    """
    # Zero out features of invalid tokens before summing.
    embedded = features * mask.unsqueeze(-1).float()
    token_count = mask.sum(-1).unsqueeze(-1).float()
    # A fully masked row would divide 0 by 0; substitute 1 so the (all-zero)
    # numerator is returned unchanged. Non-zero counts are left untouched.
    token_count = token_count.masked_fill(token_count == 0, 1.0)
    return embedded.sum(1) / token_count
    

class CrossAttention(nn.Module):
    """Multi-head attention where ``x`` provides the attention queries and
    ``query`` provides the keys and values.

    All four projection weights start as identity matrices; ``to_out`` keeps
    its default-initialised bias.

    NOTE: a second ``class CrossAttention`` statement further down this module
    rebinds the name, so that later definition is the one other code sees.
    """

    def __init__(self, dim, heads=8, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5

        # Projections are created in q, k, v, out order.
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        """Overwrite every projection weight with the identity matrix."""
        for proj in (self.to_q, self.to_k, self.to_v, self.to_out):
            proj.weight.data.copy_(torch.eye(proj.in_features))

    def _attention(self, q, k, v, mask):
        """Scaled dot-product attention on per-head tensors.

        ``mask`` (if given) is broadcast along the head and key axes, so a zero
        entry suppresses a whole query row (all its logits become -1e4).
        """
        scores = torch.einsum('bhqd,bhkd->bhqk', q, k) * self.scale
        if mask is not None:
            row_mask = mask.to(q.device).unsqueeze(1).unsqueeze(-1)
            scores = scores.masked_fill(row_mask == 0, -1e4)

        weights = self.dropout(F.softmax(scores, dim=-1))
        return torch.einsum('bhqk,bhkd->bhqd', weights, v)

    def forward(self, x, query, mask=None):
        b, n, d = x.shape
        h = self.heads
        head_dim = d // h

        def split_heads(t):
            # (b, len, d) -> (b, h, len, head_dim)
            return t.reshape(b, -1, h, head_dim).transpose(1, 2)

        q = split_heads(self.to_q(x))
        k = split_heads(self.to_k(query))
        v = split_heads(self.to_v(query))
        attended = self._attention(q, k, v, mask)
        # Merge the heads back into the feature axis.
        return self.to_out(attended.transpose(1, 2).reshape(b, n, d))

class CrossAttention(nn.Module):
    """Multi-head cross attention with identity-initialised projections.

    ``forward(x, query, mask)`` projects ``x`` with ``to_q`` and ``query`` with
    ``to_k`` / ``to_v``; the output therefore has the sequence length of ``x``.
    This definition follows an earlier class statement of the same name in
    this module and is the one bound to ``CrossAttention`` afterwards.
    """

    def __init__(self, dim, heads=8, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5

        # projection layers
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim)

        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        """Set each projection weight to the identity (biases are untouched)."""
        for attr in ("to_q", "to_k", "to_v", "to_out"):
            layer = getattr(self, attr)
            identity = torch.eye(layer.in_features)
            layer.weight.data.copy_(identity)

    def _to_heads(self, t, batch, head_dim):
        """Reshape (batch, len, dim) into (batch, heads, len, head_dim)."""
        return t.reshape(batch, -1, self.heads, head_dim).transpose(1, 2)

    def _attention(self, q, k, v, mask):
        """Softmax attention; zero mask entries blank out entire query rows."""
        logits = torch.einsum('bhqd,bhkd->bhqk', q, k) * self.scale
        if mask is not None:
            # (b, len) -> (b, 1, len, 1): broadcast over heads and keys.
            expanded = mask.to(q.device).unsqueeze(1).unsqueeze(-1)
            logits = logits.masked_fill(expanded == 0, -1e4)

        probs = F.softmax(logits, dim=-1)
        probs = self.dropout(probs)
        return torch.einsum('bhqk,bhkd->bhqd', probs, v)

    def forward(self, x, query, mask=None):
        b, n, d = x.shape
        head_dim = d // self.heads

        q = self._to_heads(self.to_q(x), b, head_dim)
        k = self._to_heads(self.to_k(query), b, head_dim)
        v = self._to_heads(self.to_v(query), b, head_dim)

        merged = self._attention(q, k, v, mask).transpose(1, 2).reshape(b, n, d)
        return self.to_out(merged)


def triplet_cosine_loss(anchor, positive, negative, margin=0.2):
    """Hinge loss on cosine similarities along the last dimension.

    With a ``positive`` the result is ``mean(relu(margin + cos(anchor,
    negative) - cos(anchor, positive)))``. When ``positive`` is ``None`` the
    result is ``mean(relu(margin - cos(anchor, negative)))``.
    """
    neg_sim = F.cosine_similarity(anchor, negative, dim=-1)
    if positive is None:
        hinge = margin - neg_sim
    else:
        pos_sim = F.cosine_similarity(anchor, positive, dim=-1)
        hinge = margin + neg_sim - pos_sim
    return F.relu(hinge).mean()


def decorrelation_loss(a, b):
    """Absolute value of the mean row-wise cosine similarity between a and b.

    Both inputs are flattened to rows of size ``shape[-1]`` (via ``view``),
    L2-normalised per row, and compared row by row.
    """
    def unit_rows(t):
        return F.normalize(t.view(-1, t.shape[-1]), dim=-1)

    row_cosine = (unit_rows(a) * unit_rows(b)).sum(dim=-1)
    return row_cosine.mean().abs()


class TextCrossAttentionBlock(nn.Module):
    """``depth`` residual ``CrossAttention`` layers followed by one LayerNorm.

    Each layer adds its attention output to the running state; the LayerNorm
    (weight 1, bias 0) is applied once after the last layer.
    """

    def __init__(self, dim, heads, depth, dropout=0.1):
        super().__init__()
        stack = []
        for _ in range(depth):
            stack.append(CrossAttention(dim=dim, heads=heads, dropout=dropout))
        self.layers = nn.ModuleList(stack)

        self.norm = nn.LayerNorm(dim)
        nn.init.constant_(self.norm.weight, 1.0)
        nn.init.constant_(self.norm.bias, 0.0)

    def forward(self, x, query, attention_mask=None):
        hidden = x
        for attend in self.layers:
            # Residual connection around every attention layer.
            hidden = hidden + attend(hidden, query, attention_mask)
        return self.norm(hidden)

class DisentangledAggregation(nn.Module):
    """Refine token features through three parallel cross-attention branches.

    Three learned banks (``q_obj``, ``q_att``, ``q_rel``) are passed as the
    key/value source of separate ``TextCrossAttentionBlock`` branches whose
    attention queries are the (residual-MLP + norm) input tokens. The branch
    outputs are summed, normalised, and passed through a second residual MLP.
    One shared LayerNorm instance is used at all three normalisation points.
    """

    def __init__(self, hidden_dim, heads=8, depth=1, dropout=0.1, max_seq_len=128):
        super().__init__()
        # Learned banks, created in obj / att / rel order.
        self.q_obj = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
        self.q_att = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
        self.q_rel = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))

        self.cross_obj = TextCrossAttentionBlock(hidden_dim, heads, depth, dropout)
        self.cross_att = TextCrossAttentionBlock(hidden_dim, heads, depth, dropout)
        self.cross_rel = TextCrossAttentionBlock(hidden_dim, heads, depth, dropout)

        self.mlp = self._residual_mlp(hidden_dim)
        self.mlp2 = self._residual_mlp(hidden_dim)

        self.norm = nn.LayerNorm(hidden_dim)
        nn.init.constant_(self.norm.weight, 1.0)
        nn.init.constant_(self.norm.bias, 0.0)
        self._init_weights(hidden_dim)

    @staticmethod
    def _residual_mlp(hidden_dim):
        """Bias-free Linear -> GELU -> Linear stack of constant width."""
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
        )

    def _init_weights(self, hidden_dim):
        """Set every Linear weight in both MLPs to the identity matrix."""
        identity = torch.eye(hidden_dim)
        for stack in (self.mlp, self.mlp2):
            for layer in stack:
                if isinstance(layer, nn.Linear):
                    layer.weight.data.copy_(identity)

    def forward(self, x, mask=None, ref=None):
        """Return fused features, plus a disentanglement loss when ``ref`` is given.

        ``ref`` is only tested against ``None``; its contents are not read.
        When it is set, the batch is split as ``B // 7`` leading rows followed
        by groups of six rows each, and only the grouped rows feed the loss.
        """
        B, L, D = x.shape
        x = self.norm(x + self.mlp(x))

        branches = (
            (self.q_obj, self.cross_obj),
            (self.q_att, self.cross_att),
            (self.q_rel, self.cross_rel),
        )
        obj, att, rel = [
            block(x, bank.expand(B, -1, -1).contiguous(), mask)
            for bank, block in branches
        ]

        fusion = self.norm(obj + att + rel)
        fusion = self.norm(fusion + self.mlp2(fusion))

        if ref is None:
            return fusion

        n_groups = B // 7

        def regroup(t):
            # Drop the first n_groups rows, then view the rest as groups of 6.
            return t[n_groups:].view(n_groups, 6, L, D)

        obj, att, rel = regroup(obj), regroup(att), regroup(rel)

        obj_loss = triplet_cosine_loss(obj[:, 0], obj[:, 1:3].mean(dim=1), obj[:, 3])
        att_loss = triplet_cosine_loss(att[:, 1], att[:, [2, 5]].mean(dim=1), att[:, 4])
        rel_loss = triplet_cosine_loss(rel[:, 2], None, rel[:, 5])

        decor1 = decorrelation_loss(obj, att)
        decor2 = decorrelation_loss(obj, rel)
        decor3 = decorrelation_loss(att, rel)
        disen_loss = obj_loss + att_loss + rel_loss + (decor1 + decor2 + decor3) * 0.1

        return fusion, disen_loss

class TASE_Model(nn.Module):
    def __init__(self, cfg, matcher, device, video_info, contras_mean):
        """Build the backbone, CLIP text encoders, pixel decoder and predictor.

        Args:
            cfg: configuration node; the attributes read here are
                ``FIND_UNUSED_PARAMETERS`` and several ``MODEL.*`` entries.
            matcher: stored as ``self.matcher`` (called by the tracking loss).
            device: target passed to ``self.to(device)`` part-way through
                construction.
            video_info: stored as ``self.video_info``.
            contras_mean: stored as ``self.contras_mean``.
        """
        super().__init__()
        self.cfg = cfg
        self.matcher = matcher
        self.backbone = build_backbone(cfg)
        # Channel counts of the backbone outputs, in the backbone's own order.
        output_channels = [v for k,v in self.backbone._out_feature_channels.items()]
        
        self.lang_encoder = None
        
        self.find_unused_params = cfg.FIND_UNUSED_PARAMETERS
        
        # The visual-prompt fuser only exists when visual prompting is enabled;
        # other methods test for it with hasattr(self, 'sot_fuser').
        if cfg.MODEL.VISUAL_PROMPT:
            self.sot_fuser = FeatureFuser(output_channels[-3:], 256)
        
        self.text_encode_type = cfg.MODEL.TEXT.ARCH
        self.early_fusion = cfg.MODEL.EARLYFUSION
        self.hier_training = cfg.MODEL.HIER_TRAINING
        # Cache for the pooled embedding of the empty prompt; filled lazily in forward().
        self.root = None
        # NOTE(review): absolute, machine-specific checkpoint path — consider
        # reading it from cfg.
        self.CLIP_PATH = "/data/weight/GLEE/clip_vit_base_patch32"
        self.tokenizer = CLIPTokenizer.from_pretrained(self.CLIP_PATH) 
        self.tokenizer.add_special_tokens({'cls_token': self.tokenizer.eos_token})
        # Two copies of the same checkpoint: a trainable encoder and a frozen teacher.
        self.text_encoder = CLIPTextModel.from_pretrained(self.CLIP_PATH)
        self.text_encoder_teacher = CLIPTextModel.from_pretrained(self.CLIP_PATH)
        # Repeats the assignment made above; lang_encoder is still None when
        # handed to the transformer decoder below.
        self.lang_encoder = None
        for p in self.text_encoder_teacher.parameters():
            p.requires_grad = False
        # Projection applied as `token_features @ lang_projection`; uniform [0, 1) init.
        self.lang_projection = nn.Parameter(torch.rand(cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM, cfg.MODEL.DIM_PROJ))
        
        self.pixel_decoder = build_pixel_decoder(cfg, self.backbone.output_shape())
        transformer_predictor_in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
        self.predictor = build_transformer_decoder(cfg, transformer_predictor_in_channels, lang_encoder = self.lang_encoder, mask_classification=True,)
        # NOTE(review): parameters registered after this call (mask_sptial_embed,
        # pn_indicator, disentangler, any LoRA adapter) are not moved by it —
        # confirm the caller moves the finished model to the device again.
        self.to(device)
        self.hier_teacher = cfg.MODEL.HIER_TEACHER
        self.hier_posneg = cfg.MODEL.HIER_POSNEG
        
        self.video_info = video_info
        self.contras_mean = contras_mean

        self.track_loss_version = cfg.MODEL.TRACK_VERSION

        # Tasks for which the tracking loss calls the matcher with cost=["cls", "box"].
        self.no_mask_tasks = ['obj365', 'obj365_clip','openimage', 'openimage_clip', 'vg', 'vg_syn', 'grit', 'bdd_det', 'bdd_track_box'] 

        hidden_dim = 256
        self.max_spatial_len = [512,512,512,512]
        # Four (hidden_dim x hidden_dim) matrices, truncated-normal initialised with std 0.02.
        self.mask_sptial_embed = nn.ParameterList([nn.Parameter(torch.empty(hidden_dim, hidden_dim)) for x in range(4)])
        trunc_normal_(self.mask_sptial_embed[0], std=.02)
        trunc_normal_(self.mask_sptial_embed[1], std=.02)
        trunc_normal_(self.mask_sptial_embed[2], std=.02)
        trunc_normal_(self.mask_sptial_embed[3], std=.02)
        self.pn_indicator = nn.Embedding(2, hidden_dim)

        self.disentangler = DisentangledAggregation(hidden_dim, depth=1)

        # Optional LoRA adapter on the trainable text encoder's q_proj / v_proj modules.
        if cfg.MODEL.LORA:
            key = ["q_proj", "v_proj"]
            peft_config = LoraConfig(
                target_modules = key,
                task_type=TaskType.FEATURE_EXTRACTION,
                inference_mode=False, 
                r=cfg.MODEL.LORA_RANK, lora_alpha=cfg.MODEL.LORA_ALPHA, lora_dropout=0.1,
            )
            self.text_encoder.add_adapter(peft_config, adapter_name="default")
        

    @property
    def device(self):
        return self.pixel_mean.device
        
    def forward(self, images, prompts, task, targets=None, batch_name_list=None, is_train = True, visual_prompt_type='scribble'):
        """Encode text/visual prompts, run backbone and decoders, return outputs and losses.

        Args:
            images: either a ``torch.Tensor`` or an object exposing ``.tensor``.
            prompts: mapping whose optional keys ``'grounding'``,
                ``'references'`` and ``'spatial'`` are read. When the random
                box/point prompt mode is drawn during training,
                ``prompts['spatial']`` is replaced in place with box masks.
            task: task name; selects the text branch and the return layout.
            targets: forwarded to the predictor and to the tracking loss.
            batch_name_list: class names embedded by the class branches.
            is_train: enables random visual-prompt-mode sampling and the
                tracking-loss return path.
            visual_prompt_type: prompt mode used when no random mode is drawn.

        Returns:
            For the video/clip tasks listed below with ``is_train`` set:
            ``(outputs, track_loss, dist_loss + params_zero_loss)``.
            Otherwise a 4-tuple ``(outputs, fake_track_loss + params_zero_loss,
            dist_loss + params_zero_loss, eloss + params_zero_loss)``.
            NOTE(review): the two paths return tuples of different length.
        """
        extra = {}
        eloss, dist_loss = 0.0, 0.0

        ###################
        # CLASS #
        ###################

        early_semantic = None
        if self.text_encode_type == 'clip_frozen':
            # For 'grounding' / 'rvos' this branch does nothing, and the elif
            # branches below are skipped as well.
            if task not in ['grounding', 'rvos']:
                assert batch_name_list
                classes_name_list = batch_name_list
                token_x, attn_mask = self.get_text_embedding(classes_name_list, plus=True)
                token_x = token_x @ self.lang_projection
                lang_feat_pool = agg_lang_feat(token_x, attn_mask, pool_type="average")
                extra['class_embeddings'] = lang_feat_pool
                # Zero-valued loss that still references the text features.
                dist_loss = (lang_feat_pool*0).sum()
                # Keep only valid tokens of all class names and share them across the batch.
                gather_all_classtoken = token_x.flatten(0,1)[attn_mask.flatten(0,1)>0]
                gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images),1,1)
                gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0
                early_semantic = {"hidden":gather_all_classtoken.float(),"masks":gather_all_classtoken_mask}

        #############
        # GROUNDING #
        #############

        elif "grounding" in prompts and task not in ['vg_syn', 'vg_syn2']:
            token_x, attn_mask = self.get_text_embedding(prompts['grounding'])
            token_x = token_x @ self.lang_projection
            token_x = self.disentangler(token_x, mask=attn_mask)
            extra['grounding_tokens'] = token_x.permute(1, 0, 2).contiguous()
            lang_feat_pool = agg_lang_feat(token_x, attn_mask, pool_type="average").unsqueeze(1)
            extra['grounding_nonzero_mask'] = ~attn_mask.bool()
            extra['grounding_class'] = lang_feat_pool.squeeze(1)
            early_semantic = {"hidden": token_x.float(), "masks": attn_mask > 0}

        ########################################
        # Hierarchical representation learning #
        ########################################

        elif task in ['vg_syn', 'vg_syn2'] and self.training:
            assert len(prompts['references']) == len(images), f"Expected batch size {len(images)}"
            grounding = prompts.get("grounding", batch_name_list)
            references = prompts['references']

            # Flatten arbitrarily nested lists of sentences.
            while any(isinstance(i, list) for i in references):
                references = list(chain.from_iterable(references))

            while any(isinstance(i, list) for i in grounding):
                grounding = list(chain.from_iterable(grounding))
            batch_size = len(grounding)

            prompt_list = grounding + references

            # Root embedding: the teacher's embedding of the empty prompt,
            # computed once and cached (detached) on the module.
            if self.root is None:
                self.root, attn_mask = self.get_text_embedding([""], teacher=True)
                self.root = self.root @ self.lang_projection
                self.root_token = self.root.detach()
                self.root = agg_lang_feat(self.root, attn_mask, pool_type="average").unsqueeze(1)
                self.root = self.root.detach()

            token_x, attn_mask = self.get_text_embedding(prompt_list)
            token_x = token_x @ self.lang_projection
            # Passing `ref` switches the disentangler to also return its loss.
            token_x, disen_loss = self.disentangler(token_x, ref=self.root_token)
            extra['grounding_tokens'] = token_x[:batch_size].permute(1, 0, 2).contiguous()
            lang_feat_emb = agg_lang_feat(token_x, attn_mask, pool_type="average").unsqueeze(1)
            lang_feat_pool, hierarchy_emb = lang_feat_emb[:batch_size], lang_feat_emb[batch_size:]

            extra['grounding_nonzero_mask'] = ~attn_mask[:batch_size].bool()
            extra['grounding_class'] = lang_feat_pool.squeeze(1)
            gather_all_classtoken = token_x[:batch_size].flatten(0, 1)[attn_mask[:batch_size].flatten(0, 1)>0]
            gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images), 1, 1)
            gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0
            early_semantic = {"hidden": gather_all_classtoken.float(), "masks": gather_all_classtoken_mask}
            dist_loss = (lang_feat_pool*0).sum()

            # Index pairs into the reference sentences, which come in groups of six.
            num_groups = len(references)//6
            pos_indices = [(i, i + 1) for b in range(num_groups) for i in range(b * 6, b * 6 + 2)]
            pos_indices += [(i, i + 4) for b in range(num_groups) for i in range(b * 6, b * 6 + 2)]
            neg_indices = [(i, i + 3) for b in range(num_groups) for i in range(b * 6, b * 6 + 3)]

            def ij2ext_pos(i, j,  tau=1., eps=1e-8):
                # Angle (acos) between the root-relative direction of i and the
                # normalised difference of the two root-relative directions.
                a = hierarchy_emb[i:i+1] - self.root
                b = hierarchy_emb[j:j+1] - self.root
                a_norm = torch.nn.functional.normalize(a, dim=-1, eps=eps)
                b_norm = torch.nn.functional.normalize(b, dim=-1, eps=eps)
                b_prime = b_norm - a_norm
                b_prime = b_prime / (b_prime.norm(dim=-1, keepdim=True) + eps)
                ext_c = (a_norm * b_prime).sum(dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
                ext_a = ext_c.acos() / tau
                return ext_a

            def ij2ext_neg(i, j,  tau=1., eps=1e-8):
                # Same angle, measured relative to the root for the first item
                # of a group and to the normalised previous item otherwise.
                if i % 6 == 0:
                    ref = self.root
                else:
                    ref = torch.nn.functional.normalize(hierarchy_emb[i-1:i], dim=-1, eps=eps)
                a = hierarchy_emb[i:i+1] - ref
                b = ref - hierarchy_emb[j:j+1]
                a_norm = torch.nn.functional.normalize(a, dim=-1, eps=eps)
                b_norm = torch.nn.functional.normalize(b, dim=-1, eps=eps)
                b_prime = b_norm - a_norm
                b_prime = b_prime / (b_prime.norm(dim=-1, keepdim=True) + eps)
                ext_c = (a_norm * b_prime).sum(dim=-1).clamp(min=-1.0 + eps, max=1.0 - eps)
                ext_a = ext_c.acos() / tau
                return ext_a

            P = torch.stack([ij2ext_pos(i, j) for i, j in pos_indices])
            N = torch.stack([ij2ext_neg(i, j) for i, j in neg_indices])
            eloss = P.mean() * 2.0 + N.mean() + disen_loss

        #####################
        # STAGE 2/3 - CLASS #
        #####################

        elif self.text_encode_type in ["clip_teacher", "bert"]:
            classes_name_list = batch_name_list
            token_x, attn_mask = self.get_text_embedding(classes_name_list)
            valid_mask = attn_mask.bool()
            token_x_teacher, attn_mask = self.get_text_embedding(classes_name_list, teacher=True, plus=True)
            # Distillation: match student and teacher features on valid student tokens.
            dist_loss = F.mse_loss(token_x[valid_mask], token_x_teacher[valid_mask])
            token_x = token_x @ self.lang_projection
            lang_feat_pool = agg_lang_feat(token_x, attn_mask, pool_type="average")
            extra['class_embeddings'] = lang_feat_pool
            gather_all_classtoken = token_x.flatten(0,1)[attn_mask.flatten(0,1)>0]
            gather_all_classtoken = gather_all_classtoken.unsqueeze(0).repeat(len(images),1,1)
            gather_all_classtoken_mask = torch.ones_like(gather_all_classtoken[:,:,0])>0
            early_semantic = {"hidden":gather_all_classtoken.float(),"masks":gather_all_classtoken_mask}

        if isinstance(images, torch.Tensor):
            features = self.backbone(images)
            image_tensor = images
        else:
            features = self.backbone(images.tensor)
            image_tensor = images.tensor

        if 'spatial' in prompts:
            # One single-image batch and one prompt mask per sample.
            key_images = [image_tensor[kid].unsqueeze(0) for kid in range(len(image_tensor))]
            key_promptmasks = [m.unsqueeze(0) for m in prompts['spatial']]

            if is_train:
                if np.random.rand() > 0.6:
                    # image prompt mode
                    if np.random.rand() > 0.8:  # box mode
                        prompt_mode = 'box'
                    else:
                        # Sample one positive point and grow it into a small
                        # rectangle of roughly (h//20, w//20) pixels.
                        prompt_mode = 'point'
                        non_zero_pos_points = [rand_sample((m.nonzero()[:,1:]).t(), 1).t() for m in prompts['spatial']]
                        new_point2mask = []
                        _,h,w = prompts['spatial'][0].shape
                        point_h = h//40
                        point_w = w//40
                        for point,pmask in zip(non_zero_pos_points, key_promptmasks):
                            zeromask = torch.zeros_like(pmask)
                            py, px = int(point[0,0]), int(point[0,1])
                            # Clamp the window start at 0: a negative slice start
                            # would wrap around and select an empty region for
                            # points close to the top/left border.
                            y0 = max(py - point_h, 0)
                            x0 = max(px - point_w, 0)
                            zeromask[:,:, y0: py+point_h , x0: px+point_w ] = True
                            new_point2mask.append(zeromask)
                        key_promptmasks = new_point2mask

                    # Replace the visual prompt with the bounding box of the
                    # (possibly point-derived) mask.
                    new_prompts = []
                    for ori_mask, pmask in zip(prompts['spatial'], key_promptmasks):
                        zeromask = torch.zeros_like(ori_mask)
                        x1,y1,x2,y2 = masks_to_boxes(pmask[0])[0].long().tolist()  # [xyxy]
                        zeromask[:, y1:y2 , x1:x2] = True
                        new_prompts.append(zeromask)
                    prompts['spatial'] = new_prompts

                else:
                    prompt_mode = visual_prompt_type
            else:
                prompt_mode = visual_prompt_type

            ref_feats, ref_masks = self.get_template(key_images, key_promptmasks, prompt_mode)

            early_fusion = {"hidden":ref_feats,"masks":ref_masks}

            # Visual-prompt tokens are appended after any text tokens.
            if early_semantic is None:
                early_semantic = early_fusion
            else:
                early_semantic["hidden"] = torch.cat([early_semantic["hidden"],early_fusion["hidden"]],dim=1)
                early_semantic["masks"] = torch.cat([early_semantic["masks"],early_fusion["masks"]],dim=1)

        mask_features, _, multi_scale_features, zero_loss = self.pixel_decoder.forward_features(features, masks=None, early_fusion = early_semantic)

        # Zero-valued terms that reference otherwise unused parameters so that
        # they still take part in the loss graph.
        params_zero_loss = (self.pn_indicator.weight*0).sum()
        if zero_loss is not None:
            params_zero_loss += zero_loss
        for p in self.mask_sptial_embed:
            params_zero_loss += (p*0).sum()

        params_zero_loss += (
            (self.predictor.coco_label_enc.weight*0).sum()
            + (self.predictor.obj365_label_enc.weight*0).sum()
            + (self.predictor.vg_label_enc.weight*0).sum()
            + (self.predictor.grounding_label_enc.weight*0).sum()
            + (self.predictor.ytvis19_label_enc.weight*0).sum()
            + (self.predictor.ytvis21_label_enc.weight*0).sum()
            + (self.predictor.ovis_label_enc.weight*0).sum()
            + (self.predictor.uvo_label_enc.weight*0).sum()
            + (self.predictor.bdd_det.weight*0).sum()
            + (self.predictor.bdd_inst.weight*0).sum()
        )

        if hasattr(self,'sot_fuser') and not self.find_unused_params:
            # For EVA02 checkpointing: when not in visual prompt mode, make a
            # fake loss so the fuser parameters participate in the loss.
            # zero_loss may be None (see the check above), so fall back to
            # params_zero_loss as the dtype/device reference.
            like = zero_loss if zero_loss is not None else params_zero_loss
            dummy_inputs = [torch.zeros(1,256,1,1).to(like) for _ in range(3)]
            fake_fuser_loss = self.sot_fuser(dummy_inputs)
            params_zero_loss += (fake_fuser_loss.sum())*0

        if task in ['vis', 'ovis', 'ytvis19' ,'ytvis21', 'uvo_video', 'burst', 'rvos','coco_clip','obj365_clip','sa1b_clip',
            'bdd_track_seg', 'bdd_track_box', 'uvof_clip','lvis_clip','openimage_clip'] and is_train:
            outputs = self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)
            track_loss = self.get_tracking_contrastive_lossv3(outputs[0], targets, task)

            return outputs, track_loss, dist_loss+params_zero_loss
        else:
            outputs = self.predictor(multi_scale_features, mask_features, extra=extra, task=task, masks=None, targets=targets)
            fake_track_loss = (outputs[0]['pred_track_embed']*0).sum()
            losses = (
                fake_track_loss + params_zero_loss,
                dist_loss + params_zero_loss,
                eloss + params_zero_loss,
            )
            return outputs, *losses


    def create_group_attention_mask(self, attn_mask, group_size):
        """
        Creates a block-diagonal attention mask.

        attn_mask: (batch_size, seq_len) binary attention mask (1 = attend, 0 = ignore)
        group_size: The number of tokens per sentence pair
        """
        batch_size, seq_len = attn_mask.shape
        mask = torch.zeros((batch_size, seq_len, seq_len), dtype=torch.float32, device=attn_mask.device)

        for b in range(batch_size):
            for i in range(0, seq_len, group_size):  # Split into sentence pairs
                indices = torch.arange(i, i + group_size, device=attn_mask.device)
                mask[b, indices[:, None], indices] = 1  # Allow self-attention within each pair

        return mask  # shape: (batch_size, seq_len, seq_len)


    def get_text_embedding(self, prompts, teacher=False, plus=False):
        """Tokenize ``prompts`` and encode them with the student or teacher text encoder.

        Args:
            prompts: text input handed to the tokenizer.
            teacher: use ``self.text_encoder_teacher`` instead of
                ``self.text_encoder``.
            plus: use ``batch_encode_plus`` with the padding strategy selected
                by ``cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX``; otherwise always
                pad to ``MAX_QUERY_LEN``.

        Returns:
            ``(last_hidden_state, attention_mask)`` from the chosen encoder
            and the tokenizer respectively.
        """
        lang_cfg = self.cfg.MODEL.LANGUAGE_BACKBONE
        if plus:
            tokens = self.tokenizer.batch_encode_plus(
                prompts,
                max_length=lang_cfg.MAX_QUERY_LEN,
                padding='max_length' if lang_cfg.PAD_MAX else "longest",
                return_special_tokens_mask=True, return_tensors='pt', truncation=True)
        else:
            tokens = self.tokenizer(
                prompts, padding='max_length', truncation=True,
                max_length=lang_cfg.MAX_QUERY_LEN, return_tensors='pt')

        encoder = self.text_encoder_teacher if teacher else self.text_encoder
        # Move the tokens to wherever the chosen encoder lives instead of a
        # hard-coded "cuda", so CPU and multi-device runs work as well.
        tokens = {key: value.to(encoder.device) for key, value in tokens.items()}

        input_ids, attention_mask = tokens['input_ids'], tokens['attention_mask']
        enc = encoder(input_ids, attention_mask)
        return enc['last_hidden_state'], attention_mask
     

    def get_template(self, imgs, pad_masks, prompt_mode='scribble'):
        """Crop a square region around each prompt mask and encode it as reference tokens.

        For every (image, mask) pair a square window centred on the mask's
        bounding box is cut out (side = ceil(sqrt(w*h) * 2), top-left clamped
        at 0), resized to 256x256 and passed through the backbone under
        ``torch.no_grad``; all backbone outputs except the first are fused by
        ``self.sot_fuser``. For 'scribble' and 'point' prompts the mask is
        added onto the image before cropping.

        Args:
            imgs: iterable of single-image batches (original notes: (N, 3, H, W)).
            pad_masks: iterable of matching masks (original notes: (N, 1, H, W)).
            prompt_mode: prompt type; only 'scribble' / 'point' overlay the mask.

        Returns:
            ``(ref_feats, ref_masks)``: fused features flattened to
            (bs, L, C) and an all-True mask of shape (bs, L).
        """
        overlay_modes = ('scribble', 'point')
        crops = []

        for image_i, mask_i in zip(imgs, pad_masks):
            if prompt_mode in overlay_modes:
                image_with_mask = image_i + mask_i.to(image_i)
            else:
                image_with_mask = image_i

            box_i = masks_to_boxes(mask_i[0])  # [xyxy]
            box_i[:, 2:] = box_i[:, 2:] - box_i[:, :2]  # -> xywh
            x, y, w, h = box_i[0].long().tolist()

            self.search_area_factor = 2

            crop_sz = math.ceil(math.sqrt(w * h) * self.search_area_factor)
            half = crop_sz * 0.5
            x1 = max(0, round(x + 0.5 * w - half))
            y1 = max(0, round(y + 0.5 * h - half))
            x2, y2 = x1 + crop_sz, y1 + crop_sz

            im_crop = image_with_mask[:, :, y1:y2, x1:x2]
            # An empty crop falls back to the whole image before resizing.
            if 0 in im_crop.shape[-2:]:
                im_crop = image_with_mask
            crops.append(F.interpolate(im_crop, (256, 256), mode='bilinear', align_corners=False))

        croped_img_with_mask = torch.cat(crops, dim=0)
        with torch.no_grad():
            ref_srcs = self.backbone(croped_img_with_mask.contiguous())
        ref_srcs = list(ref_srcs.values())
        ref_feats = self.sot_fuser(ref_srcs[1:]).float()

        ref_feats = ref_feats.flatten(-2).permute(0, 2, 1)  # (bs, L, C)
        ref_masks = torch.ones_like(ref_feats[:, :, 0]) > 0  # [bs, L]

        return ref_feats, ref_masks


    def get_tracking_contrastive_lossv3(self, video_outputs, video_targets, task):  # IDOL track loss
        if task in self.no_mask_tasks:
            indices_all = self.matcher(video_outputs, video_targets, 'task', cost=["cls", "box"])
        else:
            indices_all = self.matcher(video_outputs, video_targets, 'task' )
        
        video_len = self.video_info['len']
        track_loss = 0
        num_inst = 0

        batch_similarity = []
        batch_label = []
        for i in range(self.video_info['bz']): 
            indices = indices_all[i*video_len:(i+1)*video_len]
            bz_embedding = video_outputs['pred_track_embed'][i*video_len:(i+1)*video_len]
            bz_target = video_targets[i*video_len:(i+1)*video_len]
            zero = torch.tensor(0).to(bz_embedding.device)
            one = torch.tensor(1).to(bz_embedding.device)
            video_contras = {}
            memory = {}
            for f,(findice,fembed,ftarget) in enumerate(zip(indices,bz_embedding,bz_target)):
                vf_embed_k = fembed[findice[0]]
                if len(vf_embed_k.shape) ==1:
                    vf_embed_k.unsqueeze(0)
                vf_gt_id_k = ftarget['inst_id'][findice[1]]


                # neg sample
                sampled_index = set(random.sample(range(300),20)) 
                neg_index = sampled_index - set(findice[0].tolist())
                neg_index = list(neg_index)
                vf_embed_neg = fembed[neg_index]
                vf_embed = torch.cat([vf_embed_k,vf_embed_neg],dim=0)
                vf_gt_id = torch.cat([vf_gt_id_k,zero.repeat(len(neg_index))-2],dim=0) 

                video_contras[f] = (vf_embed,vf_gt_id)

                if f > 0:
                    num_inst = num_inst + len(ftarget['inst_id'])
                    similarity_matric =  torch.einsum("ac,bc->ab", video_contras[f-1][0], vf_embed_k)  #[num_1, num_gt]

                    v0_gt_id_m = video_contras[f-1][1].unsqueeze(-1).repeat(1,len(vf_gt_id_k))
                    v1_gt_id_m = vf_gt_id_k.unsqueeze(0).repeat(len(video_contras[f-1][1]),1)
                    similarity_label = (v0_gt_id_m == v1_gt_id_m).float()  # can be treat as one hot label 
                    # use focal loss instand of contrastive
                    # aux  cosine
                    # aux_contrastive_embed=nn.functional.normalize(video_contras[f-1][0].float(),dim=1)
                    # key_embed_i=nn.functional.normalize(vf_embed_k.float(),dim=1)    
                    # cosine = torch.einsum('nc,kc->nk',[aux_contrastive_embed,key_embed_i])

                    # batch_similarity_aux.append(cosine.flatten() )
                    batch_similarity.append(similarity_matric.flatten() )
                    batch_label.append(similarity_label.flatten() )
        if len(batch_similarity)==0 or torch.cat(batch_similarity).shape[0] == 0:
            track_loss = (video_outputs['pred_track_embed']*0).sum()
        else:
            contras_loss = 0
            aux_loss = 0
            for pred, label in zip(batch_similarity, batch_label):
                if len(pred) == 0:
                    continue
                pred = pred.unsqueeze(0)
                label = label.unsqueeze(0)
                # aux_pred = aux_pred.unsqueeze(0)

                pos_inds = (label == 1)
                neg_inds = (label == 0)
                pred_pos = pred * pos_inds.float()
                pred_neg = pred * neg_inds.float()
                # use -inf to mask out unwanted elements.
                pred_pos[neg_inds] = pred_pos[neg_inds] + float('inf')
                pred_neg[pos_inds] = pred_neg[pos_inds] + float('-inf')
                _pos_expand = torch.repeat_interleave(pred_pos, pred.shape[1], dim=1)
                _neg_expand = pred_neg.repeat(1, pred.shape[1])
                # [bz,N], N is all pos and negative samples on reference frame, label indicate it's pos or negative
                x = torch.nn.functional.pad((_neg_expand - _pos_expand), (0, 1), "constant", 0) 
                contras_loss += torch.logsumexp(x, dim=1)

            # track_loss = (contras_loss + 1.5*aux_loss)
            track_loss = contras_loss/max(num_inst,1)

        track_loss = track_loss #  /(self.video_info['bz'])
        return track_loss