# ------------------------------------------------------------------------
# Modified by Institution (Author); copyright in the modifications is also held by Institution (Author)
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.

# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
This script implements the DETR (DEtection TRansformer) model and associated criteria for object detection tasks. 
It includes a custom class `ClipMatcher` for multi-object tracking, various utility functions for feature visualization, 
and an integration of optical flow techniques. The code has been adapted from Deformable DETR, with additional modules 
and layers to support tracking, memory banks, and the ETEM (Expected Temporal Embedding) layer.
"""

import copy
import math
import numpy as np
import torch
from torchvision.utils import save_image
import os
import torch.nn.functional as F
from torch import nn, Tensor
from typing import List
import logging
import cv2
import datetime
import sys
original_sys_path = sys.path.copy()

# Importing utility modules and functions from external sources
from util import box_ops, checkpoint
from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
                       accuracy, get_world_size, interpolate, get_rank,
                       is_dist_avail_and_initialized, inverse_sigmoid)

from models.structures import Instances, Boxes, pairwise_iou, matched_boxlist_iou
from scipy.sparse import coo_matrix
from .backbone import build_backbone
from .matcher import build_matcher
from .deformable_transformer_plus import build_deforamble_transformer
from .etem import build as build_query_interaction_layer  # ETEM layer instead of ETEM
from .memory_bank import build_memory_bank
from .deformable_detr import SetCriterion, MLP
from .segmentation import sigmoid_focal_loss
import matplotlib.pyplot as plt

# Adjusting system path for external imports and restoring after import
sys.path.append('models/core')
from raft import RAFT
from utils.utils import InputPadder
from utils import flow_viz
sys.path = original_sys_path  # Restore the original sys.path after import

# Utility function to save a frame tensor as a PNG image
def save_frame_as_png(frame_tensor, frame_index, save_dir="saved_frames"):
    """
    Saves the current frame as a PNG file.

    A leading batch dimension of size 1 is dropped. Tensors whose maximum
    exceeds 1.0 are assumed to be in [0, 255] and are rescaled to [0, 1]
    before being handed to ``torchvision.utils.save_image``.

    Args:
    frame_tensor (torch.Tensor): The frame tensor to save, expected shape [C, H, W]
        (or [1, C, H, W]).
    frame_index (int): Index of the frame for naming the file.
    save_dir (str): Directory where to save the frame images (created if missing).

    Returns:
    str: Path of the written PNG file.
    """
    # exist_ok avoids the check-then-create race when several processes
    # (e.g. distributed ranks) save into the same directory concurrently.
    os.makedirs(save_dir, exist_ok=True)

    # Drop the batch dimension if it's a single image in the batch.
    if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1:
        frame_tensor = frame_tensor.squeeze(0)

    # save_image scales values by 255 in place, so it needs a floating-point tensor.
    if not frame_tensor.is_floating_point():
        frame_tensor = frame_tensor.float()

    # Normalize the image tensor to [0, 1] if it's not already.
    if frame_tensor.max() > 1.0:
        frame_tensor = frame_tensor / 255.0

    file_path = os.path.join(save_dir, f"frame_{frame_index}.png")
    save_image(frame_tensor, file_path)
    logging.debug("Frame %s saved to %s", frame_index, file_path)
    return file_path

# Utility function to draw HSV-based optical flow visualization
def draw_hsv(flow):
    """
    Render a dense optical-flow field as a BGR colour image.

    Hue encodes the flow direction, saturation is fixed at its maximum, and
    value encodes the flow magnitude (scaled by 4 and capped at 255).

    Args:
    flow (np.ndarray): Optical flow array of shape [H, W, 2] holding (dx, dy).

    Returns:
    np.ndarray: uint8 BGR image of shape [H, W, 3].
    """
    height, width = flow.shape[:2]
    dx = flow[:, :, 0]
    dy = flow[:, :, 1]

    magnitude = np.sqrt(dx * dx + dy * dy)
    # arctan2 lies in [-pi, pi]; shifting by pi gives [0, 2*pi], which is then
    # mapped onto [0, 180], the hue range OpenCV uses for 8-bit HSV images.
    hue = (np.arctan2(dy, dx) + np.pi) * (180 / np.pi / 2)

    hsv_image = np.zeros((height, width, 3), np.uint8)
    hsv_image[..., 0] = hue
    hsv_image[..., 1] = 255
    hsv_image[..., 2] = np.minimum(magnitude * 4, 255)

    return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)

# Function to visualize and save optical flow as an image
def visualize_optical_flow(flow, save_path='optical_flow.png'):
    """
    Visualizes the optical flow using HSV color mapping and saves the image.

    Args:
    flow (np.ndarray): Optical flow array of shape [H, W, 2].
    save_path (str): Output image path (defaults to 'optical_flow.png').
    """
    flow_bgr = draw_hsv(flow)
    # draw_hsv produces OpenCV-ordered BGR pixels, while matplotlib's imshow
    # interprets the last axis as RGB; reverse the channels so colours are correct.
    flow_rgb = np.ascontiguousarray(flow_bgr[..., ::-1])

    fig = plt.figure(figsize=(10, 5))
    plt.imshow(flow_rgb)
    plt.title("Optical Flow Visualization")
    plt.axis('off')
    fig.savefig(save_path)
    # Close the figure so repeated calls do not accumulate open figures
    # (and so later pyplot calls do not draw onto this one).
    plt.close(fig)

# Function to visualize and save feature maps as heatmaps
def visualize_features(tensor, save_dir="/work/Institution/author/results/splatted_features_mot17"):
    """
    Visualizes a feature map as a 'hot' heatmap and saves it as a PNG.

    Leading dimensions are collapsed by repeatedly taking index 0, so an
    [N, C, H, W] batch, a [C, H, W] map and a plain [H, W] map all reduce to
    the first [H, W] slice.

    Args:
    tensor (torch.Tensor): Feature tensor with at least 2 dimensions.
    save_dir (str): Directory to write the image into (created if missing).
        Defaults to the previously hard-coded results directory.

    Returns:
    str: Path of the written PNG file.

    Raises:
    ValueError: If ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError(f"Expected a tensor with at least 2 dimensions, got shape {tuple(tensor.shape)}")

    # Move tensor to CPU and detach from the graph, then keep the first 2-D slice.
    feature_map = tensor.detach().cpu()
    while feature_map.dim() > 2:
        feature_map = feature_map[0]

    # Microsecond resolution so several calls within the same second do not
    # overwrite each other's output.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"splatted_feature_{timestamp}.png"

    os.makedirs(save_dir, exist_ok=True)  # Ensure the directory exists
    file_path = os.path.join(save_dir, filename)

    # Draw on a dedicated figure so we neither paint over another open figure
    # nor leave a new empty one behind (plt.clf() after plt.close() would create one).
    fig, ax = plt.subplots()
    image = ax.imshow(feature_map.float().numpy(), cmap='hot')
    fig.colorbar(image, ax=ax)
    fig.savefig(file_path)
    plt.close(fig)
    return file_path

# Function to visualize the original image and its corresponding optical flow side by side
def viz_flow(img, flo, save_path='combined_plot.png'):
    """
    Visualizes the original image and its corresponding optical flow, stacked vertically.

    Args:
    img (torch.Tensor): The original image tensor, [N, 3, H, W] (first item used) or [3, H, W].
    flo (torch.Tensor): The optical flow tensor, [..., 2, H, W] or [..., H, W, 2];
        leading batch dimensions are reduced by taking index 0.
    save_path (str): Output image path (defaults to 'combined_plot.png').

    Raises:
    ValueError: If ``img`` is not shaped [N, 3, H, W] or [3, H, W].
    """
    logging.debug("Original flow shape: %s", flo.shape)
    # Take the first element of every leading (batch) dimension so an
    # unbatched [2, H, W] flow is left intact instead of being sliced to [H, W].
    while flo.dim() > 3:
        flo = flo[0]
    logging.debug("Squeezed flow shape: %s", flo.shape)

    # flow_to_image expects a numpy array laid out as [H, W, 2].
    if flo.dim() == 3 and flo.size(0) == 2:
        flo = flo.permute(1, 2, 0)
    flo_np = flo.detach().cpu().numpy()

    if img.dim() == 4 and img.size(1) == 3:  # [N, C, H, W]
        img_chw = img[0]
    elif img.dim() == 3 and img.size(0) == 3:  # [C, H, W]
        img_chw = img
    else:
        raise ValueError(f"Expected image of shape [N, 3, H, W] or [3, H, W], got {tuple(img.shape)}")
    img_np = img_chw.permute(1, 2, 0).detach().cpu().numpy()  # [H, W, C] for RGB display

    # Map flow to an RGB image using the RAFT visualization utility.
    flo_img = flow_viz.flow_to_image(flo_np)

    fig = plt.figure(figsize=(6, 10))  # Vertical layout: image on top, flow below
    plt.subplot(2, 1, 1)
    plt.imshow(img_np)
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.imshow(flo_img / 255.0)  # Assuming flow image needs scaling
    plt.title("Optical Flow")
    plt.axis('off')

    fig.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Class for computing loss and tracking for clips, with focus on multi-object tracking
class ClipMatcher(SetCriterion):
    """Clip-level tracking criterion built on Deformable DETR's ``SetCriterion``.

    Expected call sequence per training clip:

    1. ``initialize_for_single_clip(gt_instances)`` resets the per-clip state.
    2. ``match_for_single_frame(outputs)`` is called once per frame. It assigns
       track slots to ground-truth objects, stores that frame's losses in
       ``self.losses_dict`` and advances the frame counter.
    3. ``forward(outputs, input_data)`` divides every loss found in
       ``outputs["losses_dict"]`` by the distributed-averaged sample count.
    """

    def __init__(self, num_classes, matcher, weight_dict, losses):
        """
        Initialize the ClipMatcher with the specified parameters.

        Args:
        num_classes (int): Number of object categories; index ``num_classes`` is background.
        matcher (nn.Module): Module to compute a matching between targets and proposals.
        weight_dict (dict): Dictionary containing loss names and their relative weights.
        losses (list): List of all the losses to be applied.
        """
        super().__init__(num_classes, matcher, weight_dict, losses)
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.focal_loss = True
        self.losses_dict = {}
        self._current_frame_idx = 0

    def initialize_for_single_clip(self, gt_instances: List[Instances]):
        """
        Reset per-clip state before matching a new clip.

        Args:
        gt_instances (List[Instances]): Ground truth instances for each frame in the clip.
        """
        self.gt_instances = gt_instances
        self.num_samples = 0
        self.sample_device = None
        self._current_frame_idx = 0
        self.losses_dict = {}

    def _step(self):
        """
        Increments the current frame index by 1.
        """
        self._current_frame_idx += 1

    def calc_loss_for_track_scores(self, track_instances: Instances):
        """
        Compute the label loss on ``track_instances.track_scores`` for the most
        recently matched frame (index ``_current_frame_idx - 1``).

        Args:
        track_instances (Instances): Track instances for which loss is calculated.
        """
        frame_id = self._current_frame_idx - 1
        gt_instances = self.gt_instances[frame_id]
        outputs = {
            'pred_logits': track_instances.track_scores[None],
        }
        device = track_instances.track_scores.device

        num_tracks = len(track_instances)
        src_idx = torch.arange(num_tracks, dtype=torch.long, device=device)
        tgt_idx = track_instances.matched_gt_idxes  # -1 for FP tracks and disappeared tracks

        track_losses = self.get_loss('labels',
                                     outputs=outputs,
                                     gt_instances=[gt_instances],
                                     indices=[(src_idx, tgt_idx)],
                                     num_boxes=1)

        self.losses_dict.update(
            {'frame_{}_track_{}'.format(frame_id, key): value for key, value in
             track_losses.items()})

    def get_num_boxes(self, num_samples):
        """
        Average the sample count across distributed workers (clamped to >= 1).

        Args:
        num_samples (int): Number of samples on this worker.

        Returns:
        float: Number of boxes used to normalise the losses.
        """
        num_boxes = torch.as_tensor(num_samples, dtype=torch.float, device=self.sample_device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
        return num_boxes

    def get_loss(self, loss, outputs, gt_instances, indices, num_boxes, **kwargs):
        """
        Dispatch to the loss function named ``loss``.

        Args:
        loss (str): Loss type to compute ('labels', 'cardinality' or 'boxes').
        outputs (dict): Model outputs.
        gt_instances (List[Instances]): Ground truth instances.
        indices (list): Indices for matching.
        num_boxes (int): Number of boxes.

        Returns:
        dict: Computed loss.
        """
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, gt_instances, indices, num_boxes, **kwargs)

    def loss_boxes(self, outputs, gt_instances: List[Instances], indices: List[tuple], num_boxes):
        """
        Compute the L1 regression loss and the GIoU loss for matched boxes.

        Pairs whose target index is -1 are dropped, as are targets whose
        ``obj_ids`` is -1.

        Args:
        outputs (dict): Model outputs.
        gt_instances (List[Instances]): Ground truth instances.
        indices (List[tuple]): Matching indices between predictions and ground truths.
        num_boxes (int): Number of boxes.

        Returns:
        dict: Bounding box losses ('loss_bbox', 'loss_giou').
        """
        filtered_idx = []
        for src_per_img, tgt_per_img in indices:
            keep = tgt_per_img != -1
            filtered_idx.append((src_per_img[keep], tgt_per_img[keep]))
        indices = filtered_idx
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([gt_per_img.boxes[i] for gt_per_img, (_, i) in zip(gt_instances, indices)], dim=0)

        target_obj_ids = torch.cat([gt_per_img.obj_ids[i] for gt_per_img, (_, i) in zip(gt_instances, indices)], dim=0)
        mask = (target_obj_ids != -1)

        loss_bbox = F.l1_loss(src_boxes[mask], target_boxes[mask], reduction='none')
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes[mask]),
            box_ops.box_cxcywh_to_xyxy(target_boxes[mask])))
        logging.debug("loss bbox %s", loss_bbox.shape)
        logging.debug("loss gious %s", loss_giou.shape)

        losses = {}
        losses['loss_bbox'] = (loss_bbox).sum() / num_boxes
        losses['loss_giou'] = (loss_giou).sum() / num_boxes

        return losses

    def loss_labels(self, outputs, gt_instances: List[Instances], indices, num_boxes, log=False):
        """
        Classification loss (sigmoid focal loss when ``self.focal_loss``).

        Slots without a ground-truth match (target index -1) are supervised as
        background, i.e. class index ``self.num_classes``.

        Args:
        outputs (dict): Model outputs.
        gt_instances (List[Instances]): Ground truth instances.
        indices (list): Matching indices between predictions and ground truths.
        num_boxes (int): Number of boxes.
        log (bool): Whether to log the class error.

        Returns:
        dict: Label losses.
        """
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        labels = []
        for gt_per_img, (_, J) in zip(gt_instances, indices):
            # Unmatched slots (J == -1: false positives / disappeared tracks) get
            # the background index, the same value target_classes is filled with.
            # The one-hot slice [:, :, :-1] below turns that into an all-zero target
            # for any num_classes.
            labels_per_img = torch.full_like(J, self.num_classes)
            if len(gt_per_img) > 0:
                labels_per_img[J != -1] = gt_per_img.labels[J[J != -1]]
            labels.append(labels_per_img)
        target_classes_o = torch.cat(labels)
        target_classes[idx] = target_classes_o
        if self.focal_loss:
            gt_labels_target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[:, :, :-1]
            gt_labels_target = gt_labels_target.to(src_logits)
            loss_ce = sigmoid_focal_loss(src_logits.flatten(1),
                                         gt_labels_target.flatten(1),
                                         alpha=0.25,
                                         gamma=2,
                                         num_boxes=num_boxes, mean_in_dim1=False)
            loss_ce = loss_ce.sum()
        else:
            loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]

        return losses

    def match_for_single_frame(self, outputs: dict):
        """
        Match track slots to the current frame's ground truth and record its losses.

        Slots that already carry an identity (``obj_idxes >= 0``) keep it. The
        remaining slots (``obj_idxes == -1``) are matched by ``self.matcher``
        against ground-truth objects not yet tracked. Losses for the final
        layer, and for each entry of ``outputs['aux_outputs']`` if present,
        are added to ``self.losses_dict``.

        Args:
        outputs (dict): Model outputs for the current frame; must contain 'track_instances'.

        Returns:
        Instances: The track instances with updated ``obj_idxes``, ``matched_gt_idxes`` and ``iou``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
        gt_instances_i = self.gt_instances[self._current_frame_idx]  # gt instances of i-th image.
        track_instances: Instances = outputs_without_aux['track_instances']
        pred_logits_i = track_instances.pred_logits  # predicted logits of i-th image.
        pred_boxes_i = track_instances.pred_boxes  # predicted boxes of i-th image.

        # Map ground-truth object id -> row index within gt_instances_i.
        obj_idxes = gt_instances_i.obj_ids
        obj_idxes_list = obj_idxes.detach().cpu().numpy().tolist()
        obj_idx_to_gt_idx = {obj_idx: gt_idx for gt_idx, obj_idx in enumerate(obj_idxes_list)}
        outputs_i = {
            'pred_logits': pred_logits_i.unsqueeze(0),
            'pred_boxes': pred_boxes_i.unsqueeze(0),
        }

        # Slots with an identity point at the GT row carrying the same id, or at -1
        # when that object is absent from this frame (a disappeared track).
        num_disappear_track = 0
        for j in range(len(track_instances)):
            obj_id = track_instances.obj_idxes[j].item()
            if obj_id >= 0:
                if obj_id in obj_idx_to_gt_idx:
                    track_instances.matched_gt_idxes[j] = obj_idx_to_gt_idx[obj_id]
                else:
                    num_disappear_track += 1
                    track_instances.matched_gt_idxes[j] = -1
            else:
                track_instances.matched_gt_idxes[j] = -1

        full_track_idxes = torch.arange(len(track_instances), dtype=torch.long).to(pred_logits_i.device)
        matched_track_idxes = (track_instances.obj_idxes >= 0)
        prev_matched_indices = torch.stack(
            [full_track_idxes[matched_track_idxes], track_instances.matched_gt_idxes[matched_track_idxes]], dim=1).to(
            pred_logits_i.device)

        unmatched_track_idxes = full_track_idxes[track_instances.obj_idxes == -1]

        # GT rows already claimed by an existing track are excluded from new matching.
        tgt_indexes = track_instances.matched_gt_idxes
        tgt_indexes = tgt_indexes[tgt_indexes != -1]

        tgt_state = torch.zeros(len(gt_instances_i)).to(pred_logits_i.device)
        tgt_state[tgt_indexes] = 1
        untracked_tgt_indexes = torch.arange(len(gt_instances_i)).to(pred_logits_i.device)[tgt_state == 0]
        untracked_gt_instances = gt_instances_i[untracked_tgt_indexes]

        def match_for_single_decoder_layer(unmatched_outputs, matcher):
            # Run the matcher on the unassigned slots vs. untracked GT, then map the
            # local indices back to full slot indices / full GT row indices.
            new_track_indices = matcher(unmatched_outputs,
                                        [untracked_gt_instances])

            src_idx = new_track_indices[0][0]
            tgt_idx = new_track_indices[0][1]
            new_matched_indices = torch.stack([unmatched_track_idxes[src_idx], untracked_tgt_indexes[tgt_idx]],
                                              dim=1).to(pred_logits_i.device)
            return new_matched_indices

        unmatched_outputs = {
            'pred_logits': track_instances.pred_logits[unmatched_track_idxes].unsqueeze(0),
            'pred_boxes': track_instances.pred_boxes[unmatched_track_idxes].unsqueeze(0),
        }
        new_matched_indices = match_for_single_decoder_layer(unmatched_outputs, self.matcher)

        # Newly matched slots inherit the GT object id.
        track_instances.obj_idxes[new_matched_indices[:, 0]] = gt_instances_i.obj_ids[new_matched_indices[:, 1]].long()
        track_instances.matched_gt_idxes[new_matched_indices[:, 0]] = new_matched_indices[:, 1]

        # Record IoU between each active slot's predicted box and its GT box.
        active_idxes = (track_instances.obj_idxes >= 0) & (track_instances.matched_gt_idxes >= 0)
        active_track_boxes = track_instances.pred_boxes[active_idxes]
        if len(active_track_boxes) > 0:
            gt_boxes = gt_instances_i.boxes[track_instances.matched_gt_idxes[active_idxes]]
            active_track_boxes = box_ops.box_cxcywh_to_xyxy(active_track_boxes)
            gt_boxes = box_ops.box_cxcywh_to_xyxy(gt_boxes)
            track_instances.iou[active_idxes] = matched_boxlist_iou(Boxes(active_track_boxes), Boxes(gt_boxes))

        matched_indices = torch.cat([new_matched_indices, prev_matched_indices], dim=0)

        # GT objects plus disappeared tracks form the normaliser used in forward().
        self.num_samples += len(gt_instances_i) + num_disappear_track
        self.sample_device = pred_logits_i.device
        for loss in self.losses:
            new_track_loss = self.get_loss(loss,
                                           outputs=outputs_i,
                                           gt_instances=[gt_instances_i],
                                           indices=[(matched_indices[:, 0], matched_indices[:, 1])],
                                           num_boxes=1)
            self.losses_dict.update(
                {'frame_{}_{}'.format(self._current_frame_idx, key): value for key, value in new_track_loss.items()})

        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                unmatched_outputs_layer = {
                    'pred_logits': aux_outputs['pred_logits'][0, unmatched_track_idxes].unsqueeze(0),
                    'pred_boxes': aux_outputs['pred_boxes'][0, unmatched_track_idxes].unsqueeze(0),
                }
                new_matched_indices_layer = match_for_single_decoder_layer(unmatched_outputs_layer, self.matcher)
                matched_indices_layer = torch.cat([new_matched_indices_layer, prev_matched_indices], dim=0)
                for loss in self.losses:
                    if loss == 'masks':
                        continue
                    l_dict = self.get_loss(loss,
                                           aux_outputs,
                                           gt_instances=[gt_instances_i],
                                           indices=[(matched_indices_layer[:, 0], matched_indices_layer[:, 1])],
                                           num_boxes=1)
                    self.losses_dict.update(
                        {'frame_{}_aux{}_{}'.format(self._current_frame_idx, i, key): value for key, value in
                         l_dict.items()})
        self._step()
        return track_instances

    def forward(self, outputs, input_data: dict):
        """
        Normalise the accumulated losses by the averaged sample count.

        Args:
        outputs (dict): Model outputs; its 'losses_dict' entry is popped and normalised.
        input_data (dict): Input data for the model (unused here).

        Returns:
        dict: The normalised losses.
        """
        losses = outputs.pop("losses_dict")
        num_samples = self.get_num_boxes(self.num_samples)
        for loss_name, loss in losses.items():
            losses[loss_name] /= num_samples
        return losses

# Base class for runtime tracking during inference
class RuntimeTrackerBase(object):
    """Inference-time identity bookkeeping for tracked queries.

    New identities are handed out from a running counter (``max_obj_id``).
    A tracked slot whose score stays below ``filter_score_thresh`` for
    ``miss_tolerance`` updates loses its identity.
    """

    def __init__(self, score_thresh=0.7, filter_score_thresh=0.6, miss_tolerance=5):
        """
        Configure the thresholds used when updating track identities.

        Args:
        score_thresh (float): Minimum score for a slot to start a new track / count as seen.
        filter_score_thresh (float): Score below which an existing track counts as missed.
        miss_tolerance (int): Number of missed updates before a track's identity is dropped.
        """
        self.score_thresh = score_thresh
        self.filter_score_thresh = filter_score_thresh
        self.miss_tolerance = miss_tolerance
        self.max_obj_id = 0

    def clear(self):
        """
        Restart identity numbering from 0.
        """
        self.max_obj_id = 0

    def update(self, track_instances: Instances):
        """
        Assign new identities and age out missed tracks, modifying ``track_instances`` in place.

        Args:
        track_instances (Instances): Instances of tracks to be updated.
        """
        scores = track_instances.scores
        obj_idxes = track_instances.obj_idxes
        disappear_time = track_instances.disappear_time

        # Confidently scored slots count as seen in this update.
        disappear_time[scores >= self.score_thresh] = 0

        for slot in range(len(track_instances)):
            unassigned = obj_idxes[slot] == -1
            if unassigned and scores[slot] >= self.score_thresh:
                # Birth: give this slot a fresh identity.
                obj_idxes[slot] = self.max_obj_id
                self.max_obj_id += 1
                continue

            still_tracked = obj_idxes[slot] >= 0
            low_score = scores[slot] < self.filter_score_thresh
            if not (still_tracked and low_score):
                continue

            # Miss: age the track and drop its identity once the tolerance is reached.
            disappear_time[slot] += 1
            if disappear_time[slot] >= self.miss_tolerance:
                obj_idxes[slot] = -1

# Post-process module to convert the model's output into the expected format by the COCO API
class TrackerPostProcess(nn.Module):
    """ This module converts the model's output into the format expected by the COCO API """

    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def forward(self, track_instances: Instances, target_size) -> Instances:
        """
        Replace raw predictions with absolute-pixel boxes, scores and labels.

        Scores/labels are the max over the sigmoid class probabilities. Boxes
        are converted from normalised (cx, cy, w, h) to (x0, y0, x1, y1) and
        scaled by the image size. The 'pred_logits' and 'pred_boxes' fields
        are removed afterwards.

        Args:
        track_instances (Instances): Track instances from the model.
        target_size (tuple): (height, width) of the image.

        Returns:
        Instances: The same instances object with 'boxes', 'scores' and 'labels' set.
        """
        logits = track_instances.pred_logits
        normalised_boxes = track_instances.pred_boxes

        best_scores, best_labels = logits.sigmoid().max(-1)

        corner_boxes = box_ops.box_cxcywh_to_xyxy(normalised_boxes)
        height, width = target_size
        scale = torch.Tensor([width, height, width, height]).to(corner_boxes)

        track_instances.boxes = corner_boxes * scale[None, :]
        track_instances.scores = best_scores
        track_instances.labels = best_labels
        for field_name in ('pred_logits', 'pred_boxes'):
            track_instances.remove(field_name)
        return track_instances

# Function to clone modules for iterative bounding box refinement or decoding layers
def _get_clones(module, N):
    """
    Clones a module N times.
    
    Args:
    module (nn.Module): The module to clone.
    N (int): Number of clones to create.
    
    Returns:
    nn.ModuleList: List of cloned modules.
    """
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

# Module for softmax splatting of features with optical flow integration
class SoftmaxSplattingModule(nn.Module):
    """Fuse multi-level feature maps and an optical-flow map with softmax-normalised weights.

    All feature maps are bilinearly resized to the smallest spatial size among
    them. A 2-channel flow is resized to that size and tiled along the channel
    axis to the feature channel count; a flow with any other channel count
    contributes zeros. The output is the weighted sum of the resized features
    plus the weighted, tiled flow.
    """

    def __init__(self, num_feature_levels):
        """
        Initializes the SoftmaxSplattingModule.

        Args:
        num_feature_levels (int): Number of feature levels (one learnable weight logit each).
        """
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_feature_levels))

    def forward(self, features: List[torch.Tensor], optical_flow: torch.Tensor) -> torch.Tensor:
        """
        Forward method for softmax splatting.

        Args:
        features (List[torch.Tensor]): Feature maps [B, C, H_i, W_i] with a shared C.
        optical_flow (torch.Tensor): Flow as [2, H, W] or [B or 1, 2, H, W], or a
            NestedTensor wrapping such a tensor.

        Returns:
        torch.Tensor: Aggregated feature map [B, C, min H_i, min W_i].
        """
        weights = F.softmax(self.weights, dim=0)
        weights = weights.view(-1, 1, 1, 1)

        min_height = min(f.size(2) for f in features)
        min_width = min(f.size(3) for f in features)
        target_size = (min_height, min_width)

        resized_features = [F.interpolate(f, size=target_size, mode='bilinear', align_corners=False)
                            for f in features]
        reference = resized_features[0]

        if not torch.is_tensor(optical_flow):
            # NestedTensor input: use its underlying batched tensor.
            optical_flow = optical_flow.tensors
        if optical_flow.dim() == 3:
            optical_flow = optical_flow.unsqueeze(0)
        # Co-locate the flow with the features so the final stack does not fail on device mismatch.
        optical_flow = optical_flow.to(device=reference.device)
        optical_flow_resized = F.interpolate(optical_flow, size=target_size, mode='bilinear', align_corners=False)

        num_channels = reference.size(1)
        if optical_flow_resized.size(1) == 2:
            # Tile (dx, dy) enough times to cover all channels, then trim. For even
            # channel counts this equals repeating C // 2 times; odd counts no longer
            # produce a channel mismatch.
            repeats = (num_channels + 1) // 2
            optical_flow_expanded = optical_flow_resized.repeat(1, repeats, 1, 1)[:, :num_channels]
        else:
            optical_flow_expanded = torch.zeros_like(reference)
        # Broadcast a single flow map over a feature batch (no-op when batches already match).
        optical_flow_expanded = optical_flow_expanded.expand_as(reference)

        # NOTE(review): the flow reuses the last level's weight (weights[-1]), which the
        # zip below also applies to the last feature level — confirm this sharing is intended.
        weighted_flow = weights[-1] * optical_flow_expanded
        weighted_features = [w * f for w, f in zip(weights, resized_features)]
        all_features = weighted_features + [weighted_flow]
        aggregated_features = torch.sum(torch.stack(all_features, dim=0), dim=0)
        return aggregated_features

# Simple module for computing optical flow using OpenCV's Farneback method
class SimpleOpticalFlowModuleOpenCV(nn.Module):
    """Dense optical flow between two frames using OpenCV's Farneback method.

    OpenCV operates on host memory, so inputs are detached and copied to the
    CPU. The returned flow tensor is also on the CPU.
    """

    def __init__(self):
        super(SimpleOpticalFlowModuleOpenCV, self).__init__()

    @staticmethod
    def _to_gray(frame):
        """Convert a CPU frame ([3, H, W], [1, H, W] or [H, W, 3]) to a contiguous grayscale array."""
        if frame.size(0) == 1:
            # Already single-channel: drop the channel axis instead of calling cvtColor.
            return np.ascontiguousarray(frame[0].numpy())
        if frame.size(0) == 3:
            frame = frame.permute(1, 2, 0)
        # permute() yields a non-contiguous view; hand OpenCV a contiguous buffer.
        return cv2.cvtColor(np.ascontiguousarray(frame.numpy()), cv2.COLOR_RGB2GRAY)

    def forward(self, frame1, frame2):
        """
        Compute optical flow using OpenCV's Farneback method.

        Args:
        frame1 (torch.Tensor): First frame, [C, H, W] or [1, C, H, W] (C == 3 or 1).
        frame2 (torch.Tensor): Second frame, same layout as ``frame1``.

        Returns:
        torch.Tensor: Computed optical flow, float32 [2, H, W] on the CPU.

        Raises:
        ValueError: If a frame is not 3-D after removing a size-1 batch dimension.
        """
        # .numpy() requires a CPU tensor without autograd history.
        frame1_np = frame1.detach().cpu()
        frame2_np = frame2.detach().cpu()

        if frame1_np.ndim == 4 and frame1_np.shape[0] == 1:
            frame1_np = torch.squeeze(frame1_np, dim=0)
        if frame2_np.ndim == 4 and frame2_np.shape[0] == 1:
            frame2_np = torch.squeeze(frame2_np, dim=0)

        if frame1_np.ndim != 3 or frame2_np.ndim != 3:
            raise ValueError(f"Expected 3 dimensions [C, H, W], got {frame1_np.dim()} and {frame2_np.dim()}")

        frame1_gray = self._to_gray(frame1_np)
        frame2_gray = self._to_gray(frame2_np)

        # NOTE(review): OpenCV documents 8-bit single-channel inputs for Farneback; the
        # grayscale arrays keep the input dtype here — confirm float frames are accepted.
        # Positional parameters: pyr_scale=0.5, levels=3, winsize=15, iterations=3,
        # poly_n=5, poly_sigma=1.2, flags=0.
        flow = cv2.calcOpticalFlowFarneback(frame1_gray, frame2_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)

        # OpenCV returns [H, W, 2]; convert to channels-first [2, H, W].
        return torch.from_numpy(flow).permute(2, 0, 1).float()

# Module for computing optical flow using RAFT (Recurrent All-Pairs Field Transforms)
class SimpleOpticalFlowModule(nn.Module):
    """Frozen, pretrained RAFT network used to estimate optical flow between two images."""

    def __init__(self, device='cuda', checkpoint_path='mote_exp_mot17/raft-things.pth'):
        """
        Build RAFT, load pretrained weights and move it to ``device`` in eval mode.

        Args:
        device (str): Device for the RAFT model and for images from ``load_image``.
        checkpoint_path (str): Path to the RAFT state dict (saved from a DataParallel wrapper).
        """
        super(SimpleOpticalFlowModule, self).__init__()
        self.device = device
        args = {
            'small': False,
            'mixed_precision': False,
            'alternate_corr': True,
            'dropout' : False,
            'corr_levels': None,
            'corr_radius': None
        }
        # The checkpoint keys carry the DataParallel 'module.' prefix, so load into the
        # wrapper and then unwrap it.
        self.raft_model = torch.nn.DataParallel(RAFT(args))
        # map_location='cpu' lets the checkpoint load even without the GPU it was saved
        # from; the weights are moved to `device` just below.
        self.raft_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.raft_model = self.raft_model.module
        self.raft_model.to(device)
        self.raft_model.eval()

    def load_image(self, imfile):
        """
        Load an image file as a float RGB tensor of shape [1, 3, H, W] on ``self.device``.

        Raises:
        FileNotFoundError: If OpenCV cannot read ``imfile``.
        """
        # cv2 (already a dependency of this file) replaces the previously
        # unimported PIL `Image`; imread yields BGR, so convert to RGB.
        bgr = cv2.imread(imfile, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {imfile}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.uint8)
        img = torch.from_numpy(rgb).permute(2, 0, 1).float()
        return img[None].to(self.device)

    def forward(self, img1, img2):
        """
        Compute optical flow between two images using the RAFT model.

        Args:
        img1 (torch.Tensor): First image, [3, H, W] or batched [N, 3, H, W].
        img2 (torch.Tensor): Second image, same layout as ``img1``.

        Returns:
        torch.Tensor: RAFT's upsampled flow for the padded inputs.
        """
        # Add a batch dimension only for unbatched inputs.
        if img1.dim() == 3:
            img1 = img1.unsqueeze(0)
        if img2.dim() == 3:
            img2 = img2.unsqueeze(0)

        # raft_model is a registered submodule, so a parent's .train() also switches it
        # to training mode; keep the frozen network in eval mode.
        if self.raft_model.training:
            self.raft_model.eval()

        padder = InputPadder(img1.shape, mode='sintel')
        img1, img2 = padder.pad(img1, img2)

        with torch.no_grad():
            _, flow_up = self.raft_model(img1, img2, iters=20, test_mode=True)
        # NOTE(review): flow_up keeps the padded spatial size (no padder.unpad) — confirm callers expect that.
        return flow_up

# Main MOTE (Multi-Object Tracking with Deformable DETR) class
class MOTE(nn.Module):
    """Multi-object tracker built on Deformable DETR with flow-guided feature splatting.

    A clip is processed frame by frame: optical flow between consecutive frames is
    estimated by ``SimpleOpticalFlowModule``, the projected backbone features are
    passed with that flow to ``SoftmaxSplattingModule``, and the transformer decodes
    the track queries carried over from the previous frame by ``track_embed``.
    """

    def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels, criterion, track_embed,
                 aux_loss=True, with_box_refine=False, two_stage=False, memory_bank=None, use_checkpoint=False):
        """
        Initializes the MOTE model.

        Args:
        backbone (nn.Module): Backbone network for feature extraction.
        transformer (nn.Module): Transformer module for sequence processing.
        num_classes (int): Number of object classes.
        num_queries (int): Number of object queries.
        num_feature_levels (int): Number of feature levels.
        criterion (nn.Module): Criterion module for calculating losses.
        track_embed (nn.Module): Embedding module for track management.
        aux_loss (bool): Whether to use auxiliary decoding losses.
        with_box_refine (bool): Whether to refine bounding boxes iteratively.
        two_stage (bool): Whether to use a two-stage approach.
        memory_bank (nn.Module): Optional memory bank module for track history.
        use_checkpoint (bool): Whether to use checkpointing for memory efficiency.
        """
        super().__init__()
        self.num_queries = num_queries
        self.track_embed = track_embed
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.num_classes = num_classes
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.num_feature_levels = num_feature_levels
        self.use_checkpoint = use_checkpoint
        if not two_stage:
            self.query_embed = nn.Embedding(num_queries, hidden_dim * 2)
        if num_feature_levels > 1:
            # One 1x1 projection per backbone output level.
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            # Extra levels come from stride-2 3x3 convs; the first one reads the deepest
            # backbone map (in_channels is left over from the loop above).
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage = two_stage

        # Bias the classifier towards a 1% prior foreground probability.
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
        if with_box_refine:
            # Independent heads per decoder layer, shared with the decoder for refinement.
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            # The same head object is repeated, i.e. weights are shared across layers.
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None
        if two_stage:
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)
        self.post_process = TrackerPostProcess()
        self.track_base = RuntimeTrackerBase()
        self.criterion = criterion
        self.memory_bank = memory_bank
        self.mem_bank_len = 0 if memory_bank is None else memory_bank.max_his_length
        self.softmax_splatting = SoftmaxSplattingModule(num_feature_levels)
        self.optical_flow_module = SimpleOpticalFlowModule()

    def _generate_empty_tracks(self):
        """
        Generates an empty track instance with initialized fields for tracking.

        Returns:
        Instances: One entry per detection query, with ids set to -1 and
        scores/boxes/memory zeroed.

        Raises:
        Exception: Re-raises any failure of ``transformer.reference_points``;
        ``ref_pts`` is required by the transformer call in ``_forward_single_image``.
        """
        logging.debug("Generating empty tracks for MOTE.")
        track_instances = Instances((1, 1))
        num_queries, dim = self.query_embed.weight.shape
        device = self.query_embed.weight.device
        logging.debug(f"Query Embed Weight Shape: {self.query_embed.weight.shape}")
        logging.debug(f"Device: {device}")

        try:
            # The first half of each query embedding is used to predict reference points.
            track_instances.ref_pts = self.transformer.reference_points(self.query_embed.weight[:, :dim // 2])
            logging.debug("Reference points generated successfully.")
        except Exception as e:
            # Continuing without ref_pts only defers the failure to an opaque error later.
            logging.error(f"Failed to generate reference points: {str(e)}")
            raise

        track_instances.query_pos = self.query_embed.weight
        track_instances.output_embedding = torch.zeros((num_queries, dim >> 1), device=device)
        track_instances.obj_idxes = torch.full((len(track_instances),), -1, dtype=torch.long, device=device)
        track_instances.matched_gt_idxes = torch.full((len(track_instances),), -1, dtype=torch.long, device=device)
        track_instances.disappear_time = torch.zeros((len(track_instances),), dtype=torch.long, device=device)
        track_instances.iou = torch.zeros((len(track_instances),), dtype=torch.float, device=device)
        track_instances.scores = torch.zeros((len(track_instances),), dtype=torch.float, device=device)
        track_instances.track_scores = torch.zeros((len(track_instances),), dtype=torch.float, device=device)
        track_instances.pred_boxes = torch.zeros((len(track_instances), 4), dtype=torch.float, device=device)
        track_instances.pred_logits = torch.zeros((len(track_instances), self.num_classes), dtype=torch.float, device=device)

        mem_bank_len = self.mem_bank_len
        track_instances.mem_bank = torch.zeros((len(track_instances), mem_bank_len, dim // 2), dtype=torch.float32, device=device)
        # True marks an empty memory slot.
        track_instances.mem_padding_mask = torch.ones((len(track_instances), mem_bank_len), dtype=torch.bool, device=device)
        track_instances.save_period = torch.zeros((len(track_instances),), dtype=torch.float32, device=device)

        logging.debug("Empty tracks generated and configured.")
        return track_instances.to(device)

    def clear(self):
        """
        Clears the tracking base by resetting all tracks.
        """
        self.track_base.clear()

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        """
        Creates auxiliary outputs for all decoder layers except the last one.

        Args:
        outputs_class (torch.Tensor): Per-layer class prediction outputs.
        outputs_coord (torch.Tensor): Per-layer bounding box prediction outputs.

        Returns:
        list: One ``{'pred_logits', 'pred_boxes'}`` dict per non-final layer.
        """
        return [{'pred_logits': a, 'pred_boxes': b, }
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    def _forward_single_image(self, samples, optical_flow, track_instances: Instances):
        """
        Forward pass for a single image.

        Args:
        samples (NestedTensor): Input samples containing images and masks.
        optical_flow (torch.Tensor | NestedTensor | None): Optical flow for the current
            frame. ``None`` (or, defensively, an ``Instances`` object) is replaced by
            zero flow of shape ``(B, 2, H, W)`` matching ``samples``.
        track_instances (Instances): Track instances from the previous frame.

        Returns:
        dict: ``pred_logits``, ``pred_boxes``, ``ref_pts``, ``hs``,
        ``splatted_features`` and, when ``aux_loss`` is set, ``aux_outputs``.
        """
        logging.debug("Starting _forward_single_image.")
        features, pos = self.backbone(samples)
        device = self.query_embed.weight.device

        # Normalise the flow argument into a plain tensor on the model device.
        if isinstance(optical_flow, Instances):
            logging.warning("optical_flow is an Instances object, initializing with zeros")
            optical_flow = None
        elif optical_flow is None:
            logging.warning("optical_flow is None, initializing with zeros")
        if optical_flow is None:
            optical_flow = torch.zeros((samples.tensors.shape[0], 2, *samples.tensors.shape[2:]),
                                       device=samples.tensors.device)
        elif isinstance(optical_flow, NestedTensor):
            optical_flow = optical_flow.tensors
        optical_flow = optical_flow.to(device)

        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            assert mask is not None, f"Mask at level {l} cannot be None"
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            logging.debug(f"Feature level {l}: src shape {src.shape}, mask shape {mask.shape}")

        splatted_features = self.softmax_splatting(srcs, optical_flow)
        logging.debug(f"Splatted features requires_grad after computation: {splatted_features.requires_grad}")
        logging.debug(f"Splatted Features Shape: {splatted_features.shape}")

        if self.num_feature_levels > len(srcs):
            # Build the extra (coarser) levels that the backbone does not provide.
            _len_srcs = len(srcs)
            logging.debug(f"Additional feature levels processing for levels > {_len_srcs}")
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                # backbone[1] is used as the positional-encoding module for the new level.
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

                logging.debug(f"Extended src shape {src.shape} at level {l}")
        hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, track_instances.query_pos, ref_pts=track_instances.ref_pts)
        logging.debug("Transformer outputs generated.")

        outputs_classes = []
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            reference = init_reference if lvl == 0 else inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](hs[lvl])
            tmp = self.bbox_embed[lvl](hs[lvl])
            # Box deltas are predicted in logit space relative to the reference.
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
            logging.debug(f"Level {lvl}: Outputs coordinates shape {outputs_coord.shape}")

        outputs_class = torch.stack(outputs_classes)
        outputs_coord = torch.stack(outputs_coords)
        logging.debug("Stacked output classes and coordinates.")

        ref_pts_all = torch.cat([init_reference[None], inter_references[:, :, :, :2]], dim=0)
        # NOTE(review): index 5 is hard-coded and needs at least 5 decoder layers
        # (ref_pts_all has one more entry than inter_references) — confirm intended layer.
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 'ref_pts': ref_pts_all[5]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        out['hs'] = hs[-1]
        out['splatted_features'] = splatted_features

        logging.debug("Completed processing of _forward_single_image.")
        return out

    def _post_process_single_image(self, frame_res, track_instances, is_last):
        """
        Post-processes the model output for a single image, updating track instances.

        Args:
        frame_res (dict): Results from the forward pass of a single image.
        track_instances (Instances): Track instances from the previous frame.
        is_last (bool): Whether this is the last frame in the sequence.

        Returns:
        dict: ``frame_res`` with ``track_instances`` set to the queries for the next
        frame (``None`` on the last frame).
        """
        with torch.no_grad():
            if self.training:
                track_scores = frame_res['pred_logits'][0, :].sigmoid().max(dim=-1).values
            else:
                track_scores = frame_res['pred_logits'][0, :, 0].sigmoid()

        track_instances.scores = track_scores
        track_instances.pred_logits = frame_res['pred_logits'][0]
        track_instances.pred_boxes = frame_res['pred_boxes'][0]
        track_instances.output_embedding = frame_res['hs'][0]

        if self.training:
            # In training mode, track id will be assigned by the matcher.
            frame_res['track_instances'] = track_instances
            track_instances = self.criterion.match_for_single_frame(frame_res)
        else:
            # In inference mode, each track will be assigned a unique global id by the track base.
            self.track_base.update(track_instances)

        if self.memory_bank is not None:
            track_instances = self.memory_bank(track_instances)
            if self.training:
                self.criterion.calc_loss_for_track_scores(track_instances)

        tmp = {}
        tmp['init_track_instances'] = self._generate_empty_tracks()
        tmp['splatted_features'] = frame_res['splatted_features']
        tmp['track_instances'] = track_instances

        if not is_last:
            logging.debug('Processing non-final frame in the sequence.')
            out_track_instances = self.track_embed(tmp, tmp['splatted_features'])
            frame_res['track_instances'] = out_track_instances
        else:
            frame_res['track_instances'] = None

        return frame_res

    @torch.no_grad()
    def inference_single_image(self, img, ori_img_size, optical_flow=None, track_instances=None):
        """
        Runs inference on a single image, applying the tracking logic.

        Args:
        img (Tensor | NestedTensor): The input image(s).
        ori_img_size (tuple): The original image size ``(h, w)`` before preprocessing.
        optical_flow (Tensor, optional): Precomputed optical flow; zeros if omitted.
        track_instances (Instances, optional): Track instances from the previous frame.

        Returns:
        dict: ``track_instances`` for the current frame and, when available,
        ``ref_pts`` scaled to the original image size.
        """
        if not isinstance(img, NestedTensor):
            img = nested_tensor_from_tensor_list(img)
        if track_instances is None:
            track_instances = self._generate_empty_tracks()

        logging.debug("Running forward pass for single image with tracking.")

        if optical_flow is None:
            logging.warning("Optical flow is not provided. Initializing with zeros.")
            optical_flow = torch.zeros((img.tensors.shape[0], 2, *img.tensors.shape[2:]), device=img.tensors.device)

        res = self._forward_single_image(img, optical_flow, track_instances=track_instances)
        logging.debug("Running post-process for the current frame.")
        res = self._post_process_single_image(res, track_instances, False)

        track_instances = res['track_instances']
        track_instances = self.post_process(track_instances, ori_img_size)

        ret = {'track_instances': track_instances}
        if 'ref_pts' in res:
            ref_pts = res['ref_pts']
            img_h, img_w = ori_img_size
            scale_fct = torch.Tensor([img_w, img_h]).to(ref_pts)
            ref_pts = ref_pts * scale_fct[None]
            ret['ref_pts'] = ref_pts

        return ret

    @staticmethod
    def _zero_flow_like(frame):
        """Return all-zero flow of shape ``(1, 2, H, W)`` for a single unbatched frame.

        Assumes ``frame`` is one image whose last two dims are ``(H, W)`` (it is the
        tensor later wrapped by ``nested_tensor_from_tensor_list([frame])``). This
        matches the ``(B, 2, H, W)`` layout of the other zero-flow fallbacks.
        """
        return torch.zeros((1, 2, *frame.shape[-2:]), device=frame.device)

    def _estimate_flow(self, previous_frame, frame, frame_index):
        """Best-effort flow between two consecutive frames; zero flow on any failure."""
        logging.debug(f"Computing optical flow between frames {frame_index-1} and {frame_index}.")
        try:
            optical_flow = self.optical_flow_module(previous_frame, frame)
            if optical_flow is None:
                logging.error("Optical flow computation returned None.")
                raise ValueError("Optical flow computation failed.")
            logging.debug(f"Optical flow computed with shape {optical_flow.shape}.")
        except Exception as e:
            # Deliberate fallback: a flow failure should not abort the whole clip.
            logging.error(f"Exception during optical flow computation: {str(e)}")
            optical_flow = self._zero_flow_like(frame)
        return optical_flow

    def _forward_single_image_checkpointed(self, frame, optical_flow, track_instances, keys):
        """Run ``_forward_single_image`` under gradient checkpointing.

        Args:
        frame (Tensor): Single unbatched frame; wrapped into a NestedTensor inside the
            checkpointed function.
        optical_flow (Tensor): Flow tensor, passed through unchanged
            (``_forward_single_image`` accepts a plain tensor).
        track_instances (Instances): Current track queries; their fields are passed as
            explicit checkpoint inputs in ``keys`` order.
        keys (list): Field names of ``track_instances``.

        Returns:
        dict: Same keys as ``_forward_single_image``.
        """
        def fn(frame, optical_flow, *args):
            frame = nested_tensor_from_tensor_list([frame])  # Convert here within the checkpoint
            tmp = Instances((1, 1), **dict(zip(keys, args)))
            frame_res = self._forward_single_image(frame, optical_flow, tmp)
            return (
                frame_res['pred_logits'],
                frame_res['pred_boxes'],
                frame_res['ref_pts'],
                frame_res['hs'],
                frame_res['splatted_features'],
                *[aux['pred_logits'] for aux in frame_res.get('aux_outputs', [])],
                *[aux['pred_boxes'] for aux in frame_res.get('aux_outputs', [])]
            )

        args = [frame, optical_flow] + [track_instances.get(k) for k in keys]
        params = tuple(p for p in self.parameters() if p.requires_grad)
        tmp = checkpoint.CheckpointFunction.apply(fn, len(args), *args, *params)
        # fn returns 5 fixed tensors followed by aux logits then aux boxes, so the aux
        # count follows from the output length (0 when aux_loss is disabled).
        num_aux_outputs = (len(tmp) - 5) // 2
        frame_res = {
            'pred_logits': tmp[0],
            'pred_boxes': tmp[1],
            'ref_pts': tmp[2],
            'hs': tmp[3],
            'splatted_features': tmp[4],
        }
        if self.aux_loss:
            frame_res['aux_outputs'] = [{
                'pred_logits': tmp[5 + i],
                'pred_boxes': tmp[5 + num_aux_outputs + i],
            } for i in range(num_aux_outputs)]
        return frame_res

    def forward(self, data: dict):
        """
        Forward pass for the entire sequence of frames in a clip.

        Args:
        data (dict): ``imgs`` (list of per-frame image tensors) and, in training,
            ``gt_instances``.

        Returns:
        dict: Per-frame ``pred_logits``/``pred_boxes`` plus ``losses_dict`` (training)
        or the final ``track_instances`` (inference).
        """
        if self.training:
            self.criterion.initialize_for_single_clip(data['gt_instances'])

        frames = data['imgs']  # List of image tensors.
        outputs = {
            'pred_logits': [],
            'pred_boxes': [],
        }

        track_instances = self._generate_empty_tracks()
        keys = list(track_instances._fields.keys())
        previous_frame_tensors = frames[0]

        for frame_index, frame in enumerate(frames):
            frame.requires_grad = False
            is_last = frame_index == len(frames) - 1
            logging.debug(f"Processing frame {frame_index}/{len(frames)}, last frame: {is_last}")

            if frame_index > 0:
                optical_flow = self._estimate_flow(previous_frame_tensors, frame, frame_index)
            else:
                logging.debug("Initializing dummy optical flow for the first frame.")
                optical_flow = self._zero_flow_like(frame)

            if self.use_checkpoint and not is_last:
                frame_res = self._forward_single_image_checkpointed(frame, optical_flow, track_instances, keys)
            else:
                current_frame = nested_tensor_from_tensor_list([frame])
                frame_res = self._forward_single_image(current_frame, optical_flow, track_instances)

            frame_res = self._post_process_single_image(frame_res, track_instances, is_last)
            track_instances = frame_res['track_instances']
            outputs['pred_logits'].append(frame_res['pred_logits'])
            outputs['pred_boxes'].append(frame_res['pred_boxes'])

            previous_frame_tensors = frame  # Update for the next iteration

        if not self.training:
            outputs['track_instances'] = track_instances
        else:
            outputs['losses_dict'] = self.criterion.losses_dict

        return outputs


def build(args):
    """
    Assemble the MOTE tracker together with its training criterion.

    The sub-modules are constructed in a fixed order (backbone, transformer,
    query-interaction layer, matcher, optional memory bank, criterion, model),
    which also fixes the order in which their parameters are initialised.

    Args:
    args (Namespace): Model/training configuration (dataset, loss coefficients,
        decoder depth, sampler lengths, memory-bank settings, ...).

    Returns:
    tuple: ``(model, criterion, postprocessors)``; ``postprocessors`` is an empty dict.
    """
    num_classes_by_dataset = {
        'coco': 91,
        'coco_panoptic': 250,
        'e2e_mot': 1,
        'e2e_dance': 1,
        'e2e_joint': 1,
        'e2e_static_mot': 1,
    }
    assert args.dataset_file in num_classes_by_dataset
    num_classes = num_classes_by_dataset[args.dataset_file]
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_deforamble_transformer(args)
    d_model = transformer.d_model
    ffn_dim = args.dim_feedforward
    query_interaction_layer = build_query_interaction_layer(
        args, args.query_interaction_layer, d_model, ffn_dim, d_model * 2)

    img_matcher = build_matcher(args)
    clip_len = max(args.sampler_lengths)

    # Loss weights: main per-frame losses first, then per-layer auxiliary losses,
    # then (with a memory bank) the per-frame track-score loss.
    weight_dict = {}
    for frame_idx in range(clip_len):
        weight_dict[f"frame_{frame_idx}_loss_ce"] = args.cls_loss_coef
        weight_dict[f"frame_{frame_idx}_loss_bbox"] = args.bbox_loss_coef
        weight_dict[f"frame_{frame_idx}_loss_giou"] = args.giou_loss_coef

    if args.aux_loss:
        for frame_idx in range(clip_len):
            for layer_idx in range(args.dec_layers - 1):
                prefix = f"frame_{frame_idx}_aux{layer_idx}"
                weight_dict[f"{prefix}_loss_ce"] = args.cls_loss_coef
                weight_dict[f"{prefix}_loss_bbox"] = args.bbox_loss_coef
                weight_dict[f"{prefix}_loss_giou"] = args.giou_loss_coef

    has_memory_bank = args.memory_bank_type is not None and len(args.memory_bank_type) > 0
    memory_bank = None
    if has_memory_bank:
        memory_bank = build_memory_bank(args, d_model, ffn_dim, d_model * 2)
        weight_dict.update({
            f"frame_{frame_idx}_track_loss_ce": args.cls_loss_coef
            for frame_idx in range(clip_len)
        })

    criterion = ClipMatcher(num_classes, matcher=img_matcher, weight_dict=weight_dict,
                            losses=['labels', 'boxes'])
    criterion.to(device)

    model = MOTE(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        num_feature_levels=args.num_feature_levels,
        criterion=criterion,
        track_embed=query_interaction_layer,
        aux_loss=args.aux_loss,
        with_box_refine=args.with_box_refine,
        two_stage=args.two_stage,
        memory_bank=memory_bank,
        use_checkpoint=args.use_checkpoint,
    )
    return model, criterion, {}
