"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch

from lavis.common.registry import registry
from lavis.models.albef_models.albef_retrieval import AlbefRetrieval
from lavis.models.base_model import concat_all_gather


@registry.register_model("albef_retrieval_modified")
class AlbefRetrievalModified(AlbefRetrieval):
    """
    ALBEF retrieval model whose feature-queue update accepts any batch size.

    Only ``_dequeue_and_enqueue`` is overridden: the image/text/index queues are
    filled as circular buffers, so a batch that reaches the end of a queue wraps
    around to the start. As a result ``queue_size`` does not need to be a
    multiple of the (globally gathered) batch size.

    Model types and configuration are inherited unchanged from
    ``AlbefRetrieval``, e.g.:
        - coco: fine-tuned ALBEF base model on COCO dataset (Karpathy split).
        - flickr: fine-tuned ALBEF base model on Flickr30k dataset.

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("albef_retrieval_modified", "coco")
        >>> model = load_model("albef_retrieval_modified", "flickr")
    """

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
        """Push the current batch into the feature queues (FIFO, circular).

        Features are gathered across processes with ``concat_all_gather`` and
        written column-wise starting at ``self.queue_ptr``, wrapping around the
        end of the queue when necessary, so ``queue_size`` need not be a
        multiple of the batch size. If the gathered batch is larger than
        ``queue_size``, only its newest ``queue_size`` samples are kept (the
        older ones would be overwritten within this same call anyway).

        Runs under ``torch.no_grad()``: the queues are persistent buffers, and
        in-place writes of grad-tracking tensors would attach autograd history
        to them across iterations.

        Args:
            image_feat (Tensor): Image features of the current batch, one
                sample per row.
            text_feat (Tensor): Text features of the current batch, one sample
                per row.
            idxs (Tensor, optional): Indices of the current batch samples; when
                given they are enqueued into ``self.idx_queue`` as well.
        """
        # Gather from all ranks first. Every rank issues the collectives in the
        # same order (image, text, idxs).
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)
        if idxs is not None:
            idxs = concat_all_gather(idxs)

        batch_size = image_feats.shape[0]
        ptr = int(self.queue_ptr)

        if batch_size > self.queue_size:
            # Only the newest queue_size samples can survive a FIFO update.
            # Skipping the oldest `overflow` rows and advancing the start by the
            # same amount leaves exactly the state that writing the whole batch
            # one column at a time would produce.
            overflow = batch_size - self.queue_size
            ptr = (ptr + overflow) % self.queue_size
            image_feats = image_feats[overflow:]
            text_feats = text_feats[overflow:]
            if idxs is not None:
                idxs = idxs[overflow:]
            batch_size = self.queue_size

        self._write_to_queue(self.image_queue, image_feats, ptr)
        self._write_to_queue(self.text_queue, text_feats, ptr)
        if idxs is not None:
            self._write_to_queue(self.idx_queue, idxs, ptr)

        # Advance the pointer past the written columns, wrapping to the start.
        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size

    def _write_to_queue(self, queue, values, ptr):
        """Write the rows of ``values`` into the columns of ``queue`` from ``ptr`` on.

        Columns are filled circularly. Whatever does not fit before column
        ``self.queue_size`` continues at column 0. Assumes
        ``0 <= ptr < self.queue_size`` and ``values.shape[0] <= self.queue_size``;
        ``_dequeue_and_enqueue`` guarantees both.

        Args:
            queue (Tensor): Queue buffer, indexed as ``queue[:, column]``.
            values (Tensor): Samples to enqueue, one per row.
            ptr (int): First column to write.
        """
        num_rows = values.shape[0]
        # Number of rows that fit contiguously before the end of the queue.
        head = min(num_rows, self.queue_size - ptr)
        queue[:, ptr:ptr + head] = values[:head].T
        if head < num_rows:
            # Wrap-around: the remainder continues at the beginning of the queue.
            queue[:, :num_rows - head] = values[head:].T
