from abc import ABC
import dataclasses
import math
import logging
from collections import deque, Counter
import typing
from types import SimpleNamespace

from enum import Enum
import random
import torch
from modeling_autoencoder_v2 import (
    Autoencoder,
    AutoencoderOutput,
)
import numpy as np

from modeling_autoencoder_v2 import Autoencoder


def mask_non_unique(indices: torch.Tensor) -> torch.Tensor:
    """
    Flag repeated occurrences of index values.

    The tensor is scanned in flattened (row-major) order:
    - the first occurrence of each value maps to False;
    - every later occurrence of that same value maps to True.

    Uniqueness is judged across the WHOLE tensor (e.g. across every row of a batch),
    not per row. The result is a bool tensor with the same shape as ``indices``.

    Memory is O(numel * number_of_distinct_values) because of the one-hot comparison.
    """
    flat = indices.reshape(-1)
    distinct = torch.unique(indices)

    # [N, U] bool: row n has a single True, in the column of flat[n]'s value
    one_hot = distinct.unsqueeze(0) == flat.unsqueeze(1)

    # running count of each value down the rows, kept only in the column of the value at
    # that row, so the row sum is "this is the k-th occurrence of my value"
    running_counts = one_hot.long().cumsum(dim=0) * one_hot
    occurrence_number = running_counts.long().sum(dim=1)  # [N]

    repeated = occurrence_number != 1
    return repeated.reshape(indices.size())


def contr_loss(
    start_from: int,
    source_vectors: torch.FloatTensor,
    target_vectors: torch.FloatTensor,
    hinge_margin: float,
    mask_match: bool = False,
    source_mask: torch.BoolTensor | None = None,
    target_mask: torch.BoolTensor | None = None,
    exact_match: float = 0.0,
    source_level_vectors: torch.LongTensor | None = None,
    target_level_vectors: torch.LongTensor | None = None,
    weigh_by_idf: bool = False,
    collection_frequencies: dict[int, int] | None = None,
    collection_size: int | None = None,
    codebook_size: int | None = None,
    source_indices: torch.LongTensor | None = None,
    target_indices: torch.LongTensor | None = None,
    use_topk: int = -1,
    loss: str = "colbert_softmax",
    softmax_loss_temperature: float = 1.0,
    one_against_all: bool = False,
) -> torch.FloatTensor:
    """
    Computes the ColBERT loss between two sets of vectors.
    The loss is computed as follows:
    a) for each source vector s, for each target sequence T, we find the vector with the max similarity
    b) for each source sequence S, we compute the S-T score as the mean of the max similarities across S
    c) for each source sequence S, the margin loss is between each positive and all the negatives

    additional options:

    mask match: if True, the (source|target)_mask is used to mask the vectors.
    This is meant to be used with mask_non_unique
    - source side masking doesn't compute the loss for a specific source vector
    - target side masking excludes the target vector from the max similarity by setting it's value to 0.0 (ok because the cosine is capped to 0.0)
        even if the target vector is selected, the gradient for the target vector will be 0.0

    exact match: float in [0.0, 1.0] that controls how exact the match should be
    - 0.0: no limit on matching, we compute the max over the full T in step a)
    - 1.0: the max similarity is computed only for positions corresponding to fully matching levels -> matching indices
    - 0.0 < exact_match < 1.0:

    weigh by idf: if True, the cosine similarity is multiplied by the IDF of the corresponding latent token
    - the IDF is computed as the log of the inverse of the frequency of the latent token in the collection
    - the IDF is rescaled to [0.0, 1.0] by dividing by the max IDF
    """

    bsz1, tsz1, hsz1 = source_vectors.size()
    bsz2, tsz2, hsz2 = target_vectors.size()
    assert hsz1 == hsz2

    if exact_match:
        lsz1 = source_level_vectors.size(-1)
        lsz2 = target_level_vectors.size(-1)
        assert source_level_vectors is not None
        assert target_level_vectors is not None
        assert source_level_vectors.size() == (bsz1, tsz1, lsz1)
        assert target_level_vectors.size() == (bsz2, tsz2, lsz2)
        assert lsz1 == lsz2

    if mask_match:
        assert (source_mask is not None) or (target_mask is not None)
        if source_mask is not None:
            assert source_mask.size() == (bsz1, tsz1)
        if target_mask is not None:
            assert target_mask.size() == (bsz2, tsz2)

    if weigh_by_idf:
        assert collection_frequencies is not None
        assert collection_size is not None
        assert codebook_size is not None
        assert source_indices is not None
        assert target_indices is not None
        assert source_indices.size() == (bsz1, tsz1)
        assert target_indices.size() == (bsz2, tsz2)

    # normalize vectors so we bound the dot prod to [-1, +1]
    # the dot prod between unit vectors is the cosine similarity
    source_vectors = source_vectors / torch.norm(source_vectors, dim=-1, keepdim=True)
    target_vectors = target_vectors / torch.norm(target_vectors, dim=-1, keepdim=True)

    # I just love einsum
    # cosine
    scores = torch.einsum("bth,BTh->btBT", source_vectors, target_vectors)
    scores = scores.clamp_min(0.0)
    # clamp because there are no negative scores in BM25/TF-IDF

    # # mse distances
    # scores = 1.0 - torch.cdist(
    #     source_vectors.view(1, -1, source_vectors.size(-1)),
    #     target_vectors.view(1, -1, target_vectors.size(-1)),
    # ).view(
    #     source_vectors.size(0),
    #     source_vectors.size(1),
    #     target_vectors.size(0),
    #     target_vectors.size(1),
    # ).clamp_min(0.0)

    scores = scores.masked_fill(
        target_mask.view(1, 1, *target_mask.size()) == 0,
        -2.0,
    )

    scores_red1 = scores.max(dim=3).values  # btB
    scores_red1 = scores_red1.masked_fill(
        source_mask.view(*source_mask.size(), 1) == 0,
        -2.0,
    )

    scores_red2 = scores_red1.sum(dim=1)  # bB

    loss = torch.nn.functional.cross_entropy(
        scores_red2,
        start_from + torch.arange(scores_red2.size(0)).to(scores_red2.device),
    )
    return loss

    # if mask_match and source_mask is not None:
    #     # masking the loss on queries
    #     max_scores = max_scores.masked_fill(source_mask.unsqueeze(2) == 0, -2.0)
    #     max_scores = max_scores.sum(1)
    #     # step b) mean over queries
    # else:
    #     # step b)
    #     max_scores = max_scores.sum(1)

    # assert max_scores.size() == (bsz1, bsz2)

    # with torch.autocast(device_type=source_vectors.device.type, dtype=torch.float64):

    #     if loss == "simclr":
    #         # step c) simclr loss
    #         sb1, st1, sb2, st2 = scores.size()
    #         scores = scores.view(sb1 * st1, sb2 * st2)
    #         labels = torch.arange(sb1 * st1, device=scores.device)
    #         loss_sim = torch.nn.functional.cross_entropy(
    #             scores / softmax_loss_temperature,  # temperature
    #             labels + start_from,
    #         )
    #         return loss_sim

    #     elif loss == "colbert_softmax":
    #         # step c) softmax loss
    #         loss_sim = torch.nn.functional.cross_entropy(
    #             max_scores / softmax_loss_temperature,  # temperature
    #             torch.arange(
    #                 start_from,
    #                 start_from + bsz1,
    #                 device=source_vectors.device,
    #                 dtype=torch.long,
    #             ),
    #         )
    #         return loss_sim

    # indices to identify positives and negatives
    gather_indices = torch.arange(
        start_from, start_from + bsz1, device=source_vectors.device, dtype=torch.long
    ).unsqueeze(1)
    assert gather_indices.size() == (bsz1, 1)

    # we take the diagonal to get the positives
    positives = max_scores.gather(1, gather_indices)
    assert positives.size() == (bsz1, 1)

    # we mask off-diagonal elements to get the negatives
    negatives = max_scores.scatter(1, gather_indices, 0.0)
    assert negatives.size() == (bsz1, bsz2)

    # step c) hinge loss
    loss_sim = torch.nn.functional.relu(hinge_margin - positives + negatives)

    if use_topk > 0:
        loss_sim, _ = loss_sim.topk(use_topk, dim=1)

    return loss_sim.mean()


def extract_vectors(
    which: typing.Literal[
        "encoder",
        "encoder_downproject",
        "fsq_prequant",
        "fsq",
        "fsq_upproject",
        "contriever_emb",
    ],
    output: "AutoencoderOutput",
    autoencoder: typing.Optional["Autoencoder"] = None,
    start_idx: int = 0,
    end_idx: int | None = None,
) -> torch.Tensor:

    match which:
        case "contriver_emb":
            return output[1][start_idx:end_idx]

        case "encoder":
            return output.encoder_outputs[start_idx:end_idx]

        case "encoder_downproject":
            assert autoencoder is not None
            return autoencoder.downproject(
                extract_vectors(
                    "encoder",
                    output=output,
                    autoencoder=autoencoder,
                    start_idx=start_idx,
                    end_idx=end_idx,
                )
            )

        case "fsq_prequant":
            assert autoencoder is not None
            return extract_vectors(
                "encoder_downproject",
                output=output,
                autoencoder=autoencoder,
                start_idx=start_idx,
                end_idx=end_idx,
            ).tanh()

        case "fsq":
            return output.quantizer_outputs[0][start_idx:end_idx]

        case "fsq_upproject":
            assert autoencoder is not None

            levels_vectors = extract_vectors(
                "fsq",
                output=output,
                autoencoder=autoencoder,
                start_idx=start_idx,
                end_idx=end_idx,
            )

            return autoencoder.upproject(levels_vectors)

        case _:
            raise ValueError(f"Unknown vector extraction method: {which}")


def extract_latent_embeddings(
    which: typing.Literal[
        "fsq", "fsq_upproject", "fsq_upproject_embeddings", "contriever_emb"
    ],
    autoencoder: "Autoencoder",
) -> torch.Tensor:
    """
    Build one vector per codebook entry of the autoencoder's quantizer.

    The ids ``0 .. codebook_size - 1`` are created on ``autoencoder.device()`` and mapped:
    - "fsq" / "contriever_emb": ``quantizer.indices_to_codes`` (a leading dim is added for the
      call and removed from the result)
    - "fsq_upproject": the "fsq" codes passed through ``autoencoder.upproject``
    - "fsq_upproject_embeddings": the "fsq_upproject" vectors plus ``autoencoder.embed2(ids)``

    NOTE(review): "contriever_emb" currently returns exactly the "fsq" codes; confirm intended.

    Raises ValueError for any other ``which``.
    """

    codebook_ids = torch.arange(autoencoder.quantizer.codebook_size).to(
        autoencoder.device()
    )

    if which in ("contriever_emb", "fsq"):
        batched_codes = autoencoder.quantizer.indices_to_codes(codebook_ids.unsqueeze(0))
        return batched_codes.squeeze(0)

    if which == "fsq_upproject":
        return autoencoder.upproject(extract_latent_embeddings("fsq", autoencoder))

    if which == "fsq_upproject_embeddings":
        upprojected = extract_latent_embeddings("fsq_upproject", autoencoder)
        return upprojected + autoencoder.embed2(codebook_ids)

    raise ValueError(f"Unknown latent embedding extraction method: {which}")


@dataclasses.dataclass()
class StreamingTFIDFResults:
    """
    Metrics produced by ``StreamingTFIDF.step``.

    ``step`` builds one instance per source row and returns their average over the batch.
    The class can be constructed positionally, so the field order is part of its interface.
    """

    # BM25 score of the source against its paired (positive) target, and the highest BM25
    # score against any other target of the batch
    positive_tf_idf: float
    max_negative_tf_idf: float

    # number of distinct non-padding tokens shared by the source and its positive target
    positive_intersection_length: float

    # distinct / total non-padding token counts of the source and of the target
    source_uniq_len: float
    source_len: float
    target_uniq_len: float
    target_len: float

    # mean absolute difference between the source and target distinct-token distributions
    freq_divergence: float

    # batch-averaged overlap with the positive target and with a negative target
    # (the next row of the batch): multiset, distinct-token and same-position overlaps
    new_positive_intersection_length: float
    new_negative_intersection_length: float
    positive_unique_intersection_length: float
    negative_unique_intersection_length: float
    positive_same_pos_intersection_length: float
    negative_same_pos_intersection_length: float


class StreamingTFIDF:
    """
    Streaming BM25 and token-overlap statistics between source and target latent sequences.

    Target document frequencies are kept over a sliding window of the most recent
    ``history_size`` calls to ``step`` (one call == one batch of targets). Each ``step``
    compares every source row with its paired target (same batch index, the positive) and
    with the other targets of the batch (the negatives), and returns a
    ``StreamingTFIDFResults`` averaged over the source rows.
    """

    # how many batches (step calls) of target documents the frequency window keeps
    history_size: int

    # latent padding token, excluded from every count. -1 == not set
    padding_token: int

    # per-token document frequencies, merged over the window
    collection_target_frequencies: Counter

    # how many target documents are currently in the window
    seen_docs: int

    # per-batch document frequencies, oldest first
    _all_targets_frequencies: deque[Counter]

    # per-batch document counts, aligned with _all_targets_frequencies
    _all_seen_batch_sizes: deque[int]

    def __init__(
        self,
        history_size: int = 250,
        padding_token: int = -1,
        k1: float = 0.9,
        b: float = 0.4,
    ):
        """
        @param history_size: number of most recent batches kept in the frequency window
        @param padding_token: latent token id ignored in every statistic
        @param k1: BM25 term-frequency saturation parameter
        @param b: BM25 document-length normalization parameter
        """
        self.history_size = history_size
        self.collection_target_frequencies = Counter()
        self.padding_token = padding_token
        self.seen_docs = 0
        self._all_targets_frequencies = deque()
        self._all_seen_batch_sizes = deque()
        self.bm25_k1 = k1
        self.bm25_b = b
        # mean non-padding target length of the latest batch (BM25 avgdl); None before step()
        self.average_len = None

    def _update_history(self, batch_frequencies: Counter, batch_size: int) -> None:
        """Add one batch of document frequencies to the window, then evict the oldest
        batches until at most ``history_size`` remain."""
        self._all_targets_frequencies.append(batch_frequencies)
        self._all_seen_batch_sizes.append(batch_size)
        self.seen_docs += batch_size
        for key, value in batch_frequencies.items():
            self.collection_target_frequencies[key] += value
        assert len(self._all_seen_batch_sizes) == len(self._all_targets_frequencies)

        while len(self._all_targets_frequencies) > self.history_size:
            old_batch = self._all_targets_frequencies.popleft()
            for key, value in old_batch.items():
                self.collection_target_frequencies[key] -= value
            self.seen_docs -= self._all_seen_batch_sizes.popleft()
        assert len(self._all_seen_batch_sizes) == len(self._all_targets_frequencies)

    def _non_padding_len(self, row: torch.Tensor) -> int:
        """Number of non-padding entries in ``row`` (flattened)."""
        tokens = row.reshape(-1).tolist()
        return len(tokens) - tokens.count(self.padding_token)

    def _bm25_term(self, latent_token: int, term_count: int, doc_len: int) -> float:
        """
        BM25 contribution of one latent token to one document's score.

        The IDF uses the window's document frequency and ``seen_docs``; the TF is length
        normalized with ``self.average_len``. Returns 0.0 on a division by zero (e.g.
        ``average_len == 0``) or a math domain error.
        """
        doc_freq = self.collection_target_frequencies[latent_token]
        try:
            idf = math.log(
                (self.seen_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0
            )
            tf = (term_count * (self.bm25_k1 + 1.0)) / (
                term_count
                + self.bm25_k1
                * (1.0 - self.bm25_b + self.bm25_b * doc_len / self.average_len)
            )
        except (ZeroDivisionError, ValueError):
            return 0.0
        return tf * idf

    @staticmethod
    def _freq_divergence(source_counter: Counter, target_counter: Counter) -> float:
        """Mean absolute difference between the normalized token distributions of two
        counters over the union of their tokens (NaN, with a numpy warning, if both are empty)."""
        source_tot = sum(source_counter.values())
        target_tot = sum(target_counter.values())
        source_freq = {k: v / source_tot for k, v in source_counter.items()}
        target_freq = {k: v / target_tot for k, v in target_counter.items()}
        all_keys = set(source_freq.keys()).union(set(target_freq.keys()))
        diffs = [source_freq.get(key, 0) - target_freq.get(key, 0) for key in all_keys]
        return float(np.mean(np.abs(diffs)))

    def _overlap_statistics(
        self, source_indices: torch.Tensor, target_indices: torch.Tensor
    ) -> dict[str, float]:
        """
        Batch-averaged overlap of each source row with its paired target and with the next
        target row (cyclically), used as a negative.

        Keys are the matching ``StreamingTFIDFResults`` field names:
        - ``new_*``: multiset intersection sizes;
        - ``*_unique_*``: distinct-token intersection sizes;
        - ``*_same_pos_*``: positions in the common prefix where a non-padding source token
          equals the target token.
        """
        pad = self.padding_token
        num_sources = source_indices.size(0)
        num_targets = target_indices.size(0)
        min_len = min(source_indices.size(1), target_indices.size(1))

        totals = dict.fromkeys(
            (
                "new_positive_intersection_length",
                "new_negative_intersection_length",
                "positive_unique_intersection_length",
                "negative_unique_intersection_length",
                "positive_same_pos_intersection_length",
                "negative_same_pos_intersection_length",
            ),
            0.0,
        )

        for batch_idx in range(num_sources):
            neg_idx = (batch_idx + 1) % num_targets

            source_counter = Counter(source_indices[batch_idx].reshape(-1).tolist())
            target_counter = Counter(target_indices[batch_idx].reshape(-1).tolist())
            source_counter.pop(pad, None)
            target_counter.pop(pad, None)
            # padding is not removed here: every intersection below is restricted to
            # source_counter's keys, which contain no padding
            target_counter_neg = Counter(target_indices[neg_idx].reshape(-1).tolist())

            totals["new_positive_intersection_length"] += float(
                sum((source_counter & target_counter).values())
            )
            totals["positive_unique_intersection_length"] += float(
                len(source_counter.keys() & target_counter.keys())
            )
            totals["new_negative_intersection_length"] += float(
                sum((source_counter & target_counter_neg).values())
            )
            totals["negative_unique_intersection_length"] += float(
                len(source_counter.keys() & target_counter_neg.keys())
            )

            source_prefix = source_indices[batch_idx][:min_len].reshape(-1)
            is_token = source_prefix != pad
            positive_prefix = target_indices[batch_idx][:min_len].reshape(-1)
            negative_prefix = target_indices[neg_idx][:min_len].reshape(-1)
            totals["positive_same_pos_intersection_length"] += float(
                ((source_prefix == positive_prefix) & is_token).sum().item()
            )
            totals["negative_same_pos_intersection_length"] += float(
                ((source_prefix == negative_prefix) & is_token).sum().item()
            )

        return {key: value / num_sources for key, value in totals.items()}

    def step(
        self,
        target_indices: torch.LongTensor,
        source_indices: torch.LongTensor,
        source_mask: torch.LongTensor,
        target_mask: torch.LongTensor,
    ) -> StreamingTFIDFResults:
        """
        Add a batch of targets to the frequency window and score the batch.

        @param target_indices: latent token ids of the targets, batch on dim 0 (presumably
            [batch, seq_len]; TODO confirm). Row i is the positive of source row i; the other
            rows are its negatives.
        @param source_indices: latent token ids of the sources, batch on dim 0
        @param source_mask: optional; positions where it is 0 are treated as padding
        @param target_mask: optional; positions where it is 0 are treated as padding
        @return: StreamingTFIDFResults averaged over the source rows. ``max_negative_tf_idf``
            is -1e6 when the batch holds a single target (no negatives).
        Raises ZeroDivisionError on an empty batch.
        """
        pad = self.padding_token

        # masked-out positions become padding so every count below ignores them
        if source_mask is not None:
            source_indices = source_indices.masked_fill(source_mask == 0, pad)
        if target_mask is not None:
            target_indices = target_indices.masked_fill(target_mask == 0, pad)

        num_sources = source_indices.size(0)
        num_targets = target_indices.size(0)

        # raw token counts of each target (padding removed), reused below as the negatives,
        # and this batch's document frequencies (each target contributes a token once, hence
        # .keys(): passing the Counter itself would add its counts)
        target_term_counts = []
        batch_document_frequencies = Counter()
        for batch_idx in range(num_targets):
            term_counts = Counter(target_indices[batch_idx].reshape(-1).tolist())
            term_counts.pop(pad, None)
            target_term_counts.append(term_counts)
            batch_document_frequencies.update(term_counts.keys())

        self._update_history(batch_document_frequencies, num_targets)

        overlap = self._overlap_statistics(source_indices, target_indices)

        # BM25 avgdl: mean non-padding length over ALL targets of this batch, computed once so
        # every document is normalized against the same value (previously a running mean of
        # the documents processed so far)
        target_lens = [self._non_padding_len(target_indices[i]) for i in range(num_targets)]
        self.average_len = sum(target_lens) / len(target_lens)

        results = []
        for batch_idx in range(num_sources):
            source_tokens = source_indices[batch_idx].reshape(-1).tolist()
            target_tokens = target_indices[batch_idx].reshape(-1).tolist()
            source_len = len(source_tokens) - source_tokens.count(pad)
            target_len = len(target_tokens) - target_tokens.count(pad)

            # distinct non-padding tokens; a Counter (all counts 1) so switching to raw counts
            # only means dropping the set()
            source_counter = Counter(set(source_tokens))
            target_counter = Counter(set(target_tokens))
            source_counter.pop(pad, None)
            target_counter.pop(pad, None)

            freq_divergence = self._freq_divergence(source_counter, target_counter)

            # NOTE(review): the positive is scored on distinct tokens (tf == 1, length ==
            # number of distinct tokens) while the negatives use raw token counts and
            # average_len uses raw lengths, so positive and negative scores are not on the
            # same basis; confirm this is intended.
            positive_len = sum(target_counter.values())
            tf_idf_di = 0.0
            for latent_token in source_counter:
                if latent_token in target_counter:
                    tf_idf_di += self._bm25_term(
                        latent_token, target_counter[latent_token], positive_len
                    )

            # highest-scoring negative target in the current batch
            max_tfidf = -1e6
            for neg_idx, negative_counter in enumerate(target_term_counts):
                if neg_idx == batch_idx:
                    continue  # this is the positive document
                negative_len = sum(negative_counter.values())
                tf_idf_curr = 0.0
                for latent_token in source_counter:
                    if latent_token in negative_counter:
                        tf_idf_curr += self._bm25_term(
                            latent_token, negative_counter[latent_token], negative_len
                        )
                if tf_idf_curr > max_tfidf:
                    max_tfidf = tf_idf_curr

            results.append(
                StreamingTFIDFResults(
                    positive_tf_idf=tf_idf_di,
                    max_negative_tf_idf=max_tfidf,
                    positive_intersection_length=float(
                        sum((source_counter & target_counter).values())
                    ),
                    source_uniq_len=float(len(source_counter)),
                    source_len=float(source_len),
                    target_uniq_len=float(len(target_counter)),
                    target_len=float(target_len),
                    freq_divergence=freq_divergence,
                    **overlap,
                )
            )

        def batch_mean(field: str) -> float:
            return sum(getattr(r, field) for r in results) / len(results)

        return StreamingTFIDFResults(
            positive_tf_idf=batch_mean("positive_tf_idf"),
            max_negative_tf_idf=batch_mean("max_negative_tf_idf"),
            positive_intersection_length=batch_mean("positive_intersection_length"),
            source_uniq_len=batch_mean("source_uniq_len"),
            source_len=batch_mean("source_len"),
            target_uniq_len=batch_mean("target_uniq_len"),
            target_len=batch_mean("target_len"),
            freq_divergence=batch_mean("freq_divergence"),
            **overlap,
        )


def entropy_regularization(
    sequence_vectors: torch.Tensor,
    codebook_vectors: torch.Tensor,
    mask: torch.Tensor,
    distance="cosine",
    softmax_temperature: float = 0.25,
    softmax_midpoint: float = 0.95,
    invert: bool = False,
    random_select_entropy_subset: int = 0,
):
    """
    Codebook-usage regularizer: 1 - normalized entropy of the average soft assignment.

    Each sequence vector gets a softmax distribution over the codebook computed from its
    NEGATIVE EUCLIDEAN distances (``torch.cdist``) divided by ``softmax_temperature``; each
    probability is floored at 1e-6. These distributions are averaged over valid positions
    (``mask``) and then over the batch. The result is ``1 - H(p) / log(C)``: 0 for perfectly
    uniform codebook usage, approaching 1 as usage collapses onto a single code.

    @param sequence_vectors: [B, T, H] the query or document vectors
    @param codebook_vectors: [C, H] the codebook vectors
    @param mask: [B, T] valid-position mask, or None to use every position
    @param distance: currently ignored; Euclidean distance is always used
    @param softmax_temperature: softmax temperature; too low a value can produce NaNs
    @param softmax_midpoint: must be 0.0 (midpoints are not supported). NOTE(review): the
        default 0.95 fails this assertion, so callers must pass 0.0 explicitly.
    @param invert: currently unused
    @param random_select_entropy_subset: currently unused
    @return: 0-dim tensor. A zero tensor is returned (and an error logged) when the inputs or
        the averaged distribution contain NaN, e.g. when a batch row has an all-zero mask.
    @raises AssertionError: if ``softmax_midpoint != 0.0``
    @raises ValueError: if the normalized entropy is NaN
    """

    if torch.isnan(sequence_vectors).any():
        logging.error("source_level_vectors has NaN")
        return torch.zeros_like(sequence_vectors[0, 0, 0])

    elif torch.isnan(codebook_vectors).any():
        logging.error("codebook_level_vectors has NaN")
        return torch.zeros_like(sequence_vectors[0, 0, 0])

    if mask is None:
        # no mask given: every position is valid (the averaging below needs a mask)
        mask = torch.ones(
            sequence_vectors.shape[:2],
            dtype=sequence_vectors.dtype,
            device=sequence_vectors.device,
        )
    sequence_vectors = sequence_vectors * mask.unsqueeze(-1)

    assert (
        softmax_midpoint == 0.0
    ), "Softmax midpoint is not yet supported for MSE distance"

    num_codes = codebook_vectors.size(0)
    # [B, T, C] negative Euclidean distances to every codebook vector
    logits = -torch.cdist(
        sequence_vectors.reshape(1, -1, sequence_vectors.size(-1)),
        codebook_vectors.unsqueeze(0),
    ).view(
        sequence_vectors.size(0),
        sequence_vectors.size(1),
        num_codes,
    )
    logits = logits / softmax_temperature

    # dtype the result is returned in (input dtype promoted with float32)
    out_dtype = torch.promote_types(logits.dtype, torch.float32)

    # The softmax/entropy math is done in float64 via an explicit cast: torch.autocast only
    # casts a fixed list of ops (mostly matmul/conv) to its dtype and, on CPU, rejects float64,
    # so wrapping this section in autocast(dtype=float64) never upcast it.
    logits = logits.double()
    mask64 = mask.double()

    # prob dist over the codebook for every position, floored at 1e-6
    pdist = logits.log_softmax(dim=-1).clamp_min(math.log(1e-6)).exp()

    pdist = (pdist * mask64.unsqueeze(-1)).sum(1)  # [B, C]
    pdist = pdist / mask64.sum(1).unsqueeze(-1)  # mean over valid positions
    pdist = pdist.mean(0)  # [C], mean over the batch
    if torch.isnan(pdist).any():
        logging.error("pdist has NaN")
        return torch.zeros_like(sequence_vectors[0, 0, 0])

    assert pdist.size() == (num_codes,)

    # entropy normalized to [0, 1], treating 0 * log(0) as 0
    norm_entropy = -torch.sum(
        pdist * torch.nan_to_num(torch.log(pdist), nan=0.0, neginf=0.0)
    ) / math.log(num_codes)

    if torch.isnan(norm_entropy):
        logging.error("NaN norm entropy: %s", norm_entropy)
        raise ValueError("NaN norm entropy")

    neg_norm_entropy = 1.0 - norm_entropy

    return neg_norm_entropy.to(out_dtype)


def rerank(
    query_indices: torch.Tensor,
    documents_indices: torch.Tensor,
    embedding_table: torch.Tensor,
):
    """
    Score candidate documents for each query with ColBERT-style MaxSim over embedded ids.

    @param query_indices: (batch_size, len_query) token ids
    @param documents_indices: (batch_size, num_docs, len_doc) token ids
    @param embedding_table: (vocab_size, hidden_size) lookup table
    @return: (batch_size, num_docs) scores. For every query token, take the best dot product
        with any token of the document, then sum over the query tokens.
    """
    lookup = torch.nn.functional.embedding
    queries = lookup(query_indices, embedding_table)  # (batch, len_query, hidden)
    documents = lookup(documents_indices, embedding_table)  # (batch, num_docs, len_doc, hidden)

    # (batch, len_query, num_docs, len_doc) token-to-token dot products
    token_similarities = torch.einsum("bqh,brdh->bqrd", queries, documents)
    best_match_per_query_token = token_similarities.max(dim=-1).values
    return best_match_per_query_token.sum(dim=1)


def get_bigbatch(input):
    """
    Merge lists of sub-batch tensors into a single batch by concatenating along dim 0.

    @param input: mapping with "input_ids" and "attention_mask" keys, each a sequence of tensors
    @return: SimpleNamespace with the concatenated ``input_ids`` and ``attention_mask``
    """
    return SimpleNamespace(
        input_ids=torch.cat(input["input_ids"], dim=0),
        attention_mask=torch.cat(input["attention_mask"], dim=0),
    )


class DatasetType(Enum):
    """
    Input/output batch layouts selectable in ``gen_input_and_output_batch``.

    Member names appear to encode "inputs 2 outputs" with Q = queries and D = documents
    (e.g. QD2D tokenizes documents + queries as input and documents + documents as output);
    confirm the remaining layouts against that function. Values are fixed integers; do not
    renumber.
    """

    QD2D = 1
    QD2QD = 2
    QDQD2QDDQ = 3
    QD2DQ = 4
    Q2D = 5
    Q2Q = 6
    D2D = 7
    QD2DQ_RLHF = 8
    QDD2DQD = 9


def add_special_task_descriptor(documents, task_descriptor):
    """
    Prefix every document with a task descriptor.

    @param documents: iterable of strings
    @param task_descriptor: string prepended verbatim (no separator is added)
    @return: a new list in input order; ``documents`` is not modified
    """
    return [task_descriptor + document for document in documents]


def mask_word(document, mask_percentage, autoencoder):
    """Replace a random subset of the space-separated words of one document.

    The document is split on single spaces (so consecutive spaces yield empty
    "words" that may also be masked). ``int(n_words * mask_percentage)``
    distinct positions are drawn with ``random.sample`` and replaced with
    ``autoencoder.tokenizer.mask_token``; the words are re-joined with single spaces.
    """
    tokens = document.split(" ")
    n_to_mask = int(len(tokens) * mask_percentage)
    chosen_positions = random.sample(range(len(tokens)), n_to_mask)
    for position in chosen_positions:
        tokens[position] = autoencoder.tokenizer.mask_token
    return " ".join(tokens)


def mask_words(documents, mask_percentage, autoencoder):
    """Randomly mask words in each document of a list.

    Every document is split on arbitrary whitespace, ``int(n_words *
    mask_percentage)`` distinct word positions are drawn with ``random.sample``
    and replaced by ``autoencoder.tokenizer.mask_token``. The words are then
    re-joined with single spaces. Documents are processed in order, so the
    random draws happen in document order.
    """

    def _mask_single(text):
        words = text.split()
        how_many = int(len(words) * mask_percentage)
        for pos in random.sample(range(len(words)), how_many):
            words[pos] = autoencoder.tokenizer.mask_token
        return " ".join(words)

    return [_mask_single(text) for text in documents]


def gen_input_and_output_batch(
    autoencoder,
    QD_DATASET,
    documents,
    queries,
    mask_docs,
    mask_queries,
    hard_neg_docs,
    SEQ_LEN,
    MASK_PERC,
):
    """Tokenize the encoder-input and decoder-target batches for one (sub-)batch.

    The composition of both batches is selected by ``QD_DATASET`` (a
    ``DatasetType``); each branch below documents its layout. Several layouts
    prepend a task prefix ("query: ", "document: ", "rec: ") with
    ``add_special_task_descriptor``. Some randomly mask document words with
    ``mask_words(documents, MASK_PERC, autoencoder)``.

    Args:
        autoencoder: model whose ``tokenize(texts, max_length=..., ...)`` method
            produces the batches.
        QD_DATASET: the ``DatasetType`` selecting the layout.
        documents, queries: lists of raw texts (the callers in this file keep
            them the same length).
        mask_docs, mask_queries: caller-supplied alternative document / query
            texts; only used (as encoder inputs) by ``DatasetType.QD2DQ``.
        hard_neg_docs: optional hard-negative documents. They are only tokenized
            by the ``DatasetType.QD2DQ`` branch and ignored by every other layout.
        SEQ_LEN: ``max_length`` passed to the tokenizer.
        MASK_PERC: fraction of words replaced by the mask token where masking
            is used.

    Returns:
        ``(autoenc_input_batch, autoenc_output_batch, hard_neg_docs_batch)``.
        ``hard_neg_docs_batch`` is ``None`` unless the QD2DQ branch built it.

    Raises:
        ValueError: if ``QD_DATASET`` is not one of the supported layouts.
    """
    hard_neg_docs_batch = None
    if QD_DATASET == DatasetType.QD2D:
        # input: documents + queries -> target: documents + documents
        autoenc_input_batch = autoencoder.tokenize(
            documents + queries, max_length=SEQ_LEN, pad_to_max_length=True
        )
        autoenc_output_batch = autoencoder.tokenize(
            documents + documents, max_length=SEQ_LEN, pad_to_max_length=True
        )
    elif QD_DATASET == DatasetType.QD2QD:
        # input: masked documents + queries -> target: documents + queries
        masked_documents = mask_words(documents, MASK_PERC, autoencoder)
        autoenc_input_batch = autoencoder.tokenize(
            masked_documents + queries, max_length=SEQ_LEN, pad_to_max_length=True
        )
        autoenc_output_batch = autoencoder.tokenize(
            documents + queries, max_length=SEQ_LEN, pad_to_max_length=True
        )
    elif QD_DATASET == DatasetType.QDQD2QDDQ:
        # Reconstruction (masked docs, queries) plus prefixed cross tasks.
        # masked_documents was previously never assigned here (the call was
        # commented out), so this branch raised NameError.
        masked_documents = mask_words(documents, MASK_PERC, autoencoder)
        task_q_to_d_q = add_special_task_descriptor(queries, "query: ")
        task_q_to_d_d = add_special_task_descriptor(documents, "query: ")
        task_d_to_q_d = add_special_task_descriptor(documents, "document: ")
        task_q_to_q_q = add_special_task_descriptor(queries, "document: ")
        autoenc_input_batch = autoencoder.tokenize(
            masked_documents + queries + task_q_to_d_q + task_d_to_q_d,
            max_length=SEQ_LEN,
            # pad like the target batch so attention masks of different
            # sub-batches have equal width and can be concatenated
            pad_to_max_length=True,
        )
        autoenc_output_batch = autoencoder.tokenize(
            documents + queries + task_q_to_d_d + task_q_to_q_q,
            max_length=SEQ_LEN,
            pad_to_max_length=True,
        )
    elif QD_DATASET == DatasetType.QD2DQ:
        # input: "query: "+mask_queries, "document: "+mask_docs
        task_q_to_d_q = add_special_task_descriptor(mask_queries, "query: ")
        task_d_to_q_d = add_special_task_descriptor(mask_docs, "document: ")
        # target: "query: "+documents, "document: "+queries
        task_q_to_d_d = add_special_task_descriptor(documents, "query: ")
        task_d_to_q_q = add_special_task_descriptor(queries, "document: ")
        if hard_neg_docs is not None:
            hard_neg_docs = add_special_task_descriptor(hard_neg_docs, "document: ")
            hard_neg_docs_batch = autoencoder.tokenize(
                hard_neg_docs,
                max_length=SEQ_LEN,
                pad_to_max_length=True,
            )
        autoenc_input_batch = autoencoder.tokenize(
            task_q_to_d_q + task_d_to_q_d,
            max_length=SEQ_LEN,
            pad_to_max_length=True,
        )
        autoenc_output_batch = autoencoder.tokenize(
            task_q_to_d_d + task_d_to_q_q,
            max_length=SEQ_LEN,
            pad_to_max_length=True,
        )
    elif QD_DATASET == DatasetType.QDD2DQD:
        # First half of the pairs as query<->document tasks, every document as
        # a "rec: " reconstruction task.
        masked_documents = mask_words(documents, MASK_PERC, autoencoder)
        half_len = len(documents) // 2
        task_q_to_d_q = add_special_task_descriptor(queries[:half_len], "query: ")
        task_q_to_d_d = add_special_task_descriptor(documents[:half_len], "query: ")
        task_d_to_q_d = add_special_task_descriptor(
            masked_documents[:half_len], "document: "
        )
        task_d_to_q_q = add_special_task_descriptor(queries[:half_len], "document: ")
        task_d_to_d_in = add_special_task_descriptor(masked_documents, "rec: ")
        task_d_to_d_out = add_special_task_descriptor(documents, "rec: ")
        autoenc_input_batch = autoencoder.tokenize(
            task_q_to_d_q + task_d_to_q_d + task_d_to_d_in,
            max_length=SEQ_LEN,
        )
        autoenc_output_batch = autoencoder.tokenize(
            task_q_to_d_d + task_d_to_q_q + task_d_to_d_out,
            max_length=SEQ_LEN,
        )
    elif QD_DATASET == DatasetType.QD2DQ_RLHF:
        # The prefixed inputs appear twice: once with the matching targets and
        # once with targets rotated by one position (mismatched pairs).
        task_q_to_d_q = add_special_task_descriptor(queries, "query: ")
        task_q_to_d_d = add_special_task_descriptor(documents, "query: ")
        task_d_to_q_d = add_special_task_descriptor(documents, "document: ")
        task_d_to_q_q = add_special_task_descriptor(queries, "document: ")
        task_q_to_d_d_wrong = add_special_task_descriptor(
            documents[1:] + documents[:1], "query: "
        )
        task_d_to_q_q_wrong = add_special_task_descriptor(
            queries[1:] + queries[:1], "document: "
        )

        autoenc_input_batch = autoencoder.tokenize(
            task_q_to_d_q + task_d_to_q_d + task_q_to_d_q + task_d_to_q_d,
            max_length=SEQ_LEN,
        )
        autoenc_output_batch = autoencoder.tokenize(
            task_q_to_d_d + task_d_to_q_q + task_q_to_d_d_wrong + task_d_to_q_q_wrong,
            max_length=SEQ_LEN,
        )
    elif QD_DATASET == DatasetType.Q2D:
        # input: queries -> target: documents
        autoenc_input_batch = autoencoder.tokenize(queries, max_length=SEQ_LEN)
        autoenc_output_batch = autoencoder.tokenize(documents, max_length=SEQ_LEN)
    elif QD_DATASET == DatasetType.Q2Q:
        # input: queries -> target: queries
        autoenc_input_batch = autoencoder.tokenize(queries, max_length=SEQ_LEN)
        autoenc_output_batch = autoencoder.tokenize(queries, max_length=SEQ_LEN)
    elif QD_DATASET == DatasetType.D2D:
        # input: masked documents -> target: documents
        masked_documents = mask_words(documents, MASK_PERC, autoencoder)
        autoenc_input_batch = autoencoder.tokenize(masked_documents, max_length=SEQ_LEN)
        autoenc_output_batch = autoencoder.tokenize(documents, max_length=SEQ_LEN)
    else:
        # Previously fell through and failed with UnboundLocalError at return.
        raise ValueError(f"Unsupported dataset type: {QD_DATASET!r}")
    return autoenc_input_batch, autoenc_output_batch, hard_neg_docs_batch


class GradCache(ABC):
    """Contrastive training step for the autoencoder, optionally with a gradient cache.

    The mode is selected by ``params["grad_cache_accumulate_gradients"]``:

    * ``1`` -- no cache. ``create_cache`` and ``calculate_contr_loss_in_subbatch``
      return immediately, and ``get_total_loss_and_backward`` encodes the whole
      batch with gradients (``get_total_loss_and_backward_single``).
    * anything else -- gradient-cache mode, three calls per step:

      1. ``create_cache`` encodes the batch in sub-batches under
         ``torch.no_grad()``. It stores the contrastive vectors as leaf tensors
         with ``requires_grad=True``, and keeps the tokenized sub-batches.
      2. ``calculate_contr_loss_in_subbatch`` evaluates the contrastive loss on
         the cached vectors and back-propagates it, filling their ``.grad``.
      3. ``get_total_loss_and_backward_cached`` re-encodes the stored
         sub-batches with gradients. It back-propagates the dot product of the
         fresh vectors with the cached ``.grad``, together with the supervised
         and entropy losses.

    ``set_docs_and_queries`` must be called before each step.
    """

    def __init__(
        self,
        params,
        autoencoder,
        token_counter,
    ):
        """Read hyper-parameters from ``params`` and precompute the codebook codes.

        Args:
            params: mapping of configuration values; each key read below must be
                present (a missing key raises ``KeyError``).
            autoencoder: the model being trained; called as
                ``autoencoder(input_batch, output_batch, steps=1, encoding_only=...)``.
            token_counter: object whose ``step(...)`` method is called with the
                quantizer indices and attention masks of every (sub-)batch.
        """
        self.QD_DATASET = params["dataset_type"]
        self.SEQ_LEN = params["seq_len"]
        self.CONTR_VECTOR_NAME = params["contr_vector_name"]
        self.CONTR_LEVEL_VECTOR_NAME = params["contr_level_vector_name"]
        self.grad_cache_batch = params["grad_cache_batch"]
        self.BATCH_SIZE = params["batch_size"]
        self.CONTR_LOSS_TYPE = params["contr_loss"]
        self.HINGE_MARGIN = params["hinge_margin"]
        self.SOFTMAX_TEMPERATURE = params["softmax_temp"]
        self.EXACT_MATCH_COLBERT = params["exact_match_colbert"]
        self.ACCUMULATE_GRADIENTS = params["accumulate_gradients"]
        self.GRAD_CACHE_ACCUMULATE_GRADIENTS = params["grad_cache_accumulate_gradients"]
        self.SUP_LOSS_ALPHA = params["sup_loss_alpha"]
        self.CONTR_LOSS_ALPHA = params["contr_loss_alpha"]
        self.ENTROPY_LOSS_ALPHA = params["entropy_loss_alpha"]
        self.ENTROPY_VECTOR_NAME = params["entropy_vec_name"]
        self.ENTROPY_SOFTMAX = params["entropy_softmax"]
        self.ENTROPY_SOFTMAX_MIDPOINT = params["entropy_softmax_midpoint"]
        self.ENTROPY_DISTANCE = params["entropy_distance"]
        self.RANDOM_SELECT_ENTROPY_SUBSET = params["random_select_entropy_subset"]
        self.autoencoder = autoencoder
        self.token_counter = token_counter
        indices = torch.arange(
            self.autoencoder.quantizer.codebook_size, device=autoencoder.device()
        )
        # codebook code for every index of the quantizer
        self.codes = self.autoencoder.quantizer.indices_to_codes(indices)
        self.CONTR_MASK_NON_UNIQUE = params["contr_mask_non_unique"]
        self.MASK_PERC = params["crop_mask_doc_perc"]
        self.hard_neg_docs = None

    def reset_cache(self):
        """Reset every per-step cache to an empty list."""
        self.source_cache = []
        self.target_cache = []
        self.hard_neg_cache = []
        self.source_level_cache = []
        self.target_level_cache = []
        self.hard_neg_level_cache = []
        self.source_mask_cache = []
        self.target_mask_cache = []
        self.source_quantizer_indices_cache = []
        self.target_quantizer_indices_cache = []
        self.autoenc_input_subbatches = []
        self.autoenc_output_subbatches = []
        self.source_masks = []
        self.target_masks = []
        self.hard_neg_masks = []

    def set_docs_and_queries(
        self, documents, queries, mask_docs, mask_queries, hard_neg_docs
    ):
        """Store the raw texts of the current step.

        ``documents``, ``queries``, ``mask_docs`` and ``mask_queries`` must have
        the same length (checked with ``assert``). ``hard_neg_docs`` may be ``None``.
        """
        self.documents = documents
        self.mask_docs = mask_docs
        self.queries = queries
        self.mask_queries = mask_queries
        self.hard_neg_docs = hard_neg_docs
        assert len(self.documents) == len(self.queries)
        assert len(self.documents) == len(self.mask_docs)
        assert len(self.documents) == len(self.mask_queries)

    def create_cache(self):
        """Gradient-cache phase 1: encode every sub-batch without gradients.

        The documents are processed ``grad_cache_batch // 2`` at a time. For each
        sub-batch the first half of the encoded rows is cached as "source" and
        the second half as "target", together with level vectors, attention
        masks, quantizer indices and the tokenized batches themselves (reused
        by ``get_total_loss_and_backward_cached``). Afterwards ``source_cache``,
        ``target_cache`` (and ``hard_neg_cache``) are concatenated leaf tensors
        with ``requires_grad=True``. No-op when
        ``GRAD_CACHE_ACCUMULATE_GRADIENTS == 1``.
        """
        if self.GRAD_CACHE_ACCUMULATE_GRADIENTS == 1:
            return
        self.reset_cache()
        step = self.grad_cache_batch // 2
        for i in range(0, len(self.documents), step):
            stop = i + step
            with torch.no_grad():
                (
                    autoenc_input_subbatch,
                    autoenc_output_subbatch,
                    hard_neg_docs_subbatch,
                ) = gen_input_and_output_batch(
                    autoencoder=self.autoencoder,
                    QD_DATASET=self.QD_DATASET,
                    documents=self.documents[i:stop],
                    queries=self.queries[i:stop],
                    mask_docs=self.mask_docs[i:stop],
                    mask_queries=self.mask_queries[i:stop],
                    hard_neg_docs=(
                        self.hard_neg_docs[i:stop]
                        if self.hard_neg_docs is not None
                        else None
                    ),
                    SEQ_LEN=self.SEQ_LEN,
                    MASK_PERC=self.MASK_PERC,
                )
                cur_minibatch_size = len(autoenc_input_subbatch["input_ids"])
                output_d2d: AutoencoderOutput = self.autoencoder(
                    autoenc_input_subbatch,
                    autoenc_output_subbatch,
                    steps=1,
                    encoding_only=True,
                )
                if self.hard_neg_docs is not None:
                    # NOTE(review): gen_input_and_output_batch only builds a
                    # hard-negative batch for DatasetType.QD2DQ; with any other
                    # dataset type hard_neg_docs_subbatch is None here.
                    output_hard_neg: AutoencoderOutput = self.autoencoder(
                        hard_neg_docs_subbatch,
                        hard_neg_docs_subbatch,
                        steps=1,
                        encoding_only=True,
                    )
                    hard_neg_vectors = extract_vectors(
                        self.CONTR_VECTOR_NAME,
                        output=output_hard_neg,
                        autoencoder=self.autoencoder,
                    )
                    hard_neg_level_vectors = extract_vectors(
                        self.CONTR_LEVEL_VECTOR_NAME,
                        output=output_hard_neg,
                        autoencoder=self.autoencoder,
                    )
                    self.hard_neg_cache.append(hard_neg_vectors)
                    self.hard_neg_level_cache.append(hard_neg_level_vectors)
            half = cur_minibatch_size // 2
            source_mask = autoenc_input_subbatch["attention_mask"][:half]
            target_mask = autoenc_input_subbatch["attention_mask"][half:]
            hard_neg_mask = (
                hard_neg_docs_subbatch["attention_mask"]
                if self.hard_neg_docs is not None
                else None
            )
            source_vectors = extract_vectors(
                self.CONTR_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                end_idx=half,
            )
            target_vectors = extract_vectors(
                self.CONTR_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                start_idx=half,
            )
            source_level_vectors = extract_vectors(
                self.CONTR_LEVEL_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                end_idx=half,
            )
            target_level_vectors = extract_vectors(
                self.CONTR_LEVEL_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                start_idx=half,
            )
            quantizer_indices = output_d2d.quantizer_outputs[1].reshape(
                -1, self.autoencoder.effective_length
            )
            source_indices = quantizer_indices[:half]
            target_indices = quantizer_indices[half:]
            self.source_cache.append(source_vectors)
            self.target_cache.append(target_vectors)
            self.source_masks.append(source_mask)
            self.target_masks.append(target_mask)
            self.hard_neg_masks.append(hard_neg_mask)
            self.source_level_cache.append(source_level_vectors)
            self.target_level_cache.append(target_level_vectors)
            self.source_quantizer_indices_cache.append(source_indices)
            self.target_quantizer_indices_cache.append(target_indices)
            self.source_mask_cache.append(source_mask)
            self.target_mask_cache.append(target_mask)
            self.autoenc_input_subbatches.append(autoenc_input_subbatch)
            self.autoenc_output_subbatches.append(autoenc_output_subbatch)
        self.source_cache = torch.cat(self.source_cache, dim=0)
        self.target_cache = torch.cat(self.target_cache, dim=0)
        self.source_masks = torch.cat(self.source_masks, dim=0)
        self.target_masks = torch.cat(self.target_masks, dim=0)
        self.hard_neg_masks = (
            torch.cat(self.hard_neg_masks, dim=0)
            if self.hard_neg_docs is not None
            else None
        )
        # Leaf tensors that collect the contrastive-loss gradient in phase 2.
        self.source_cache.requires_grad_(True)
        self.target_cache.requires_grad_(True)
        self.source_cache.retain_grad()
        self.target_cache.retain_grad()
        self.source_level_cache = torch.cat(self.source_level_cache, dim=0)
        self.target_level_cache = torch.cat(self.target_level_cache, dim=0)
        if self.hard_neg_docs is not None:
            # NOTE(review): hard_neg_cache.grad is filled by
            # calculate_contr_loss_in_subbatch, but the hard negatives are not
            # re-encoded in get_total_loss_and_backward_cached, so that gradient
            # never reaches the model parameters in cached mode.
            self.hard_neg_cache = torch.cat(self.hard_neg_cache, dim=0)
            self.hard_neg_cache.requires_grad_(True)
            self.hard_neg_cache.retain_grad()
            self.hard_neg_level_cache = torch.cat(self.hard_neg_level_cache, dim=0)
        self.source_quantizer_indices_cache = torch.cat(
            self.source_quantizer_indices_cache, dim=0
        )
        self.target_quantizer_indices_cache = torch.cat(
            self.target_quantizer_indices_cache, dim=0
        )
        self.source_mask_cache = torch.cat(self.source_mask_cache, dim=0)
        self.target_mask_cache = torch.cat(self.target_mask_cache, dim=0)

    def calculate_contr_loss_in_subbatch(self):
        """Gradient-cache phase 2: contrastive loss on the cached vectors.

        Source rows are processed ``grad_cache_batch`` at a time. Each chunk is
        scored against all cached targets, followed by the cached hard
        negatives when present. Each chunk's loss is back-propagated
        immediately, accumulating gradients into ``source_cache.grad`` and
        ``target_cache.grad`` (and ``hard_neg_cache.grad``). The summed loss value
        is stored in ``self.contr_loss_val``. No-op when
        ``GRAD_CACHE_ACCUMULATE_GRADIENTS == 1``.
        """
        if self.GRAD_CACHE_ACCUMULATE_GRADIENTS == 1:
            return
        self.contr_loss_val = 0.0
        if self.hard_neg_docs is not None:
            # Hard negatives go AFTER the targets so that target row k still
            # matches source row k (contr_loss pairs them via start_from).
            target_cache_incl_hard_neg = torch.cat(
                (self.target_cache, self.hard_neg_cache), dim=0
            )
            target_level_cache_incl_hard_neg = torch.cat(
                (self.target_level_cache, self.hard_neg_level_cache), dim=0
            )
            target_masks_incl_hard_neg = torch.cat(
                (self.target_masks, self.hard_neg_masks), dim=0
            )
        else:
            target_cache_incl_hard_neg = self.target_cache
            target_level_cache_incl_hard_neg = self.target_level_cache
            target_masks_incl_hard_neg = self.target_masks
        n_sources = self.source_cache.size(0)
        for i in range(0, n_sources, self.grad_cache_batch):
            stop = i + self.grad_cache_batch
            contr_loss_val = contr_loss(
                start_from=i,
                source_vectors=self.source_cache[i:stop],
                target_vectors=target_cache_incl_hard_neg,
                hinge_margin=self.HINGE_MARGIN,
                # NOTE(review): the single (non-cached) path uses
                # self.CONTR_MASK_NON_UNIQUE here; confirm which is intended.
                mask_match=True,
                source_mask=self.source_masks[i:stop],
                target_mask=target_masks_incl_hard_neg,
                exact_match=self.EXACT_MATCH_COLBERT,
                source_level_vectors=self.source_level_cache[i:stop],
                target_level_vectors=target_level_cache_incl_hard_neg,
                weigh_by_idf=False,
                source_indices=self.source_quantizer_indices_cache[i:stop],
                target_indices=self.target_quantizer_indices_cache,
                use_topk=-1,
                loss=self.CONTR_LOSS_TYPE,
                softmax_loss_temperature=self.SOFTMAX_TEMPERATURE,
            )
            self.contr_loss_val += contr_loss_val.item()
            # The target cache is shared by every chunk, so keep the graph
            # alive until the last chunk has been back-propagated.
            retain_graph = i < n_sources - self.grad_cache_batch
            contr_loss_val.backward(retain_graph=retain_graph)

    def get_total_loss_and_backward(self):
        """Run the backward pass of the configured mode and return its loss dict."""
        if self.GRAD_CACHE_ACCUMULATE_GRADIENTS == 1:
            return self.get_total_loss_and_backward_single()
        else:
            return self.get_total_loss_and_backward_cached()

    def get_total_loss_and_backward_cached(self):
        """Gradient-cache phase 3: re-encode with gradients and back-propagate.

        Requires ``create_cache`` and ``calculate_contr_loss_in_subbatch`` to have
        run for the current batch. It reuses the tokenized sub-batches stored by
        ``create_cache``, so randomly masked inputs are identical to those whose
        cached gradients are chained here. It also reads ``source_cache.grad``,
        ``target_cache.grad``, ``source_masks``, ``target_masks`` and
        ``contr_loss_val``.

        Returns:
            dict with ``sup_loss``, ``contr_loss``, ``entr_loss`` and ``tot_loss``
            (floats divided by ``GRAD_CACHE_ACCUMULATE_GRADIENTS``). It also
            contains the quantizer indices and ``output_d2d`` of the last
            sub-batch, the last ``step_stats``, and ``quant_outputs`` (the
            quantizer indices of every sub-batch concatenated along dim 0).
        """
        total_batch_loss = 0.0
        total_batch_ae_loss = 0.0
        total_batch_entr_loss = 0.0
        quant_outputs = []
        step = self.grad_cache_batch // 2
        cached_subbatches = zip(
            self.autoenc_input_subbatches, self.autoenc_output_subbatches
        )
        for chunk_idx, (autoenc_input_subbatch, autoenc_output_subbatch) in enumerate(
            cached_subbatches
        ):
            # document offset of this sub-batch (same stride as create_cache)
            i = chunk_idx * step
            cur_minibatch_size = len(autoenc_input_subbatch["input_ids"])
            half = cur_minibatch_size // 2
            output_d2d: AutoencoderOutput = self.autoencoder(
                autoenc_input_subbatch,
                autoenc_output_subbatch,
                steps=1,
                # full forward only when the supervised loss needs its output
                encoding_only=self.SUP_LOSS_ALPHA == 0.0,
            )
            source_vectors_with_grad = extract_vectors(
                self.CONTR_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                end_idx=half,
            )
            target_vectors_with_grad = extract_vectors(
                self.CONTR_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                start_idx=half,
            )
            if self.SUP_LOSS_ALPHA > 0.0:
                sup_loss = self.autoencoder.loss(autoenc_output_subbatch, output_d2d)
            else:
                sup_loss = torch.zeros(1, device=self.autoencoder.device())

            # Surrogate whose gradient w.r.t. the fresh vectors equals the
            # cached contrastive-loss gradient (the GradCache chain rule).
            contr_loss_val = (
                source_vectors_with_grad * self.source_cache.grad[i : i + step]
            ).sum() + (
                target_vectors_with_grad * self.target_cache.grad[i : i + step]
            ).sum()
            codebook_vec = extract_latent_embeddings(
                self.ENTROPY_VECTOR_NAME, self.autoencoder
            )
            entr_loss1 = entropy_regularization(
                sequence_vectors=extract_vectors(
                    which=self.ENTROPY_VECTOR_NAME,
                    output=output_d2d,
                    autoencoder=self.autoencoder,
                    end_idx=half,
                ),
                mask=autoenc_input_subbatch.attention_mask[:half],
                codebook_vectors=codebook_vec,
                distance=self.ENTROPY_DISTANCE,
                softmax_temperature=self.ENTROPY_SOFTMAX,
                softmax_midpoint=self.ENTROPY_SOFTMAX_MIDPOINT,
                random_select_entropy_subset=self.RANDOM_SELECT_ENTROPY_SUBSET,
            )
            entr_loss2 = entropy_regularization(
                sequence_vectors=extract_vectors(
                    which=self.ENTROPY_VECTOR_NAME,
                    output=output_d2d,
                    autoencoder=self.autoencoder,
                    start_idx=half,
                ),
                mask=autoenc_input_subbatch.attention_mask[half:],
                codebook_vectors=codebook_vec,
                distance=self.ENTROPY_DISTANCE,
                softmax_temperature=self.ENTROPY_SOFTMAX,
                softmax_midpoint=self.ENTROPY_SOFTMAX_MIDPOINT,
                random_select_entropy_subset=self.RANDOM_SELECT_ENTROPY_SUBSET,
            )
            entr_loss = (entr_loss1 + entr_loss2) / 2.0
            loss = (
                (sup_loss / self.GRAD_CACHE_ACCUMULATE_GRADIENTS) * self.SUP_LOSS_ALPHA
                + contr_loss_val * self.CONTR_LOSS_ALPHA
                + (entr_loss / self.GRAD_CACHE_ACCUMULATE_GRADIENTS)
                * self.ENTROPY_LOSS_ALPHA
            )
            retain_graph = i < len(self.documents) - step
            loss.backward(retain_graph=retain_graph)

            total_batch_ae_loss += sup_loss.item()
            total_batch_entr_loss += entr_loss.item()
            total_batch_loss = (
                self.SUP_LOSS_ALPHA * total_batch_ae_loss
                + self.ENTROPY_LOSS_ALPHA * total_batch_entr_loss
                + self.CONTR_LOSS_ALPHA * self.contr_loss_val
            )

            documents_indices = output_d2d.quantizer_outputs[1][half:].long()
            queries_indices = output_d2d.quantizer_outputs[1][:half].long()
            # save some info for logging and tfidf calculation
            step_stats: StreamingTFIDFResults = self.token_counter.step(
                target_indices=documents_indices,
                source_indices=queries_indices,
                source_mask=self.source_masks[i : i + half],
                target_mask=self.target_masks[i : i + half],
            )
            quant_outputs.append(output_d2d.quantizer_outputs[1])

        quant_outputs = torch.cat(quant_outputs, dim=0)
        return {
            "sup_loss": total_batch_ae_loss / self.GRAD_CACHE_ACCUMULATE_GRADIENTS,
            "contr_loss": self.contr_loss_val / self.GRAD_CACHE_ACCUMULATE_GRADIENTS,
            "entr_loss": total_batch_entr_loss / self.GRAD_CACHE_ACCUMULATE_GRADIENTS,
            "tot_loss": total_batch_loss / self.GRAD_CACHE_ACCUMULATE_GRADIENTS,
            "quant_indices": output_d2d.quantizer_outputs[1],
            "output_d2d": output_d2d,
            "step_stats": step_stats,
            "quant_outputs": quant_outputs,
        }

    def get_total_loss_and_backward_single(self):
        """Compute all losses on the full batch with gradients and back-propagate.

        Used when ``GRAD_CACHE_ACCUMULATE_GRADIENTS == 1``. The batch is rebuilt
        and re-encoded ``ACCUMULATE_GRADIENTS`` times (for dataset types with
        random word masking the inputs differ between repetitions). Each
        backward pass is scaled by ``1 / ACCUMULATE_GRADIENTS``.

        Returns:
            dict with ``sup_loss``, ``contr_loss``, ``entr_loss`` and ``tot_loss``
            averaged over the repetitions (Python floats). It also contains the
            quantizer indices, ``output_d2d`` and ``step_stats`` of the last
            repetition.
        """
        tot_sup_loss = 0.0
        tot_contr_loss = 0.0
        tot_entr_loss = 0.0
        tot_loss = 0.0
        for _ in range(self.ACCUMULATE_GRADIENTS):
            autoenc_input_batch, autoenc_output_batch, hard_neg_docs_subbatch = (
                gen_input_and_output_batch(
                    autoencoder=self.autoencoder,
                    QD_DATASET=self.QD_DATASET,
                    documents=self.documents,
                    queries=self.queries,
                    mask_docs=self.mask_docs,
                    mask_queries=self.mask_queries,
                    hard_neg_docs=self.hard_neg_docs,
                    SEQ_LEN=self.SEQ_LEN,
                    MASK_PERC=self.MASK_PERC,
                )
            )
            output_d2d: AutoencoderOutput = self.autoencoder(
                autoenc_input_batch,
                autoenc_output_batch,
                steps=1,
                # the supervised loss needs the full forward, as in the cached path
                encoding_only=self.SUP_LOSS_ALPHA == 0.0,
            )
            # Split at the real batch size (not the configured BATCH_SIZE) so a
            # short batch is still divided into its source and target halves.
            half = len(autoenc_input_batch["input_ids"]) // 2
            source_vectors_with_grad = extract_vectors(
                self.CONTR_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                end_idx=half,
            )
            source_level_vectors_with_grad = extract_vectors(
                self.CONTR_LEVEL_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                end_idx=half,
            )
            target_vectors_with_grad = extract_vectors(
                self.CONTR_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                start_idx=half,
            )
            target_level_vectors_with_grad = extract_vectors(
                self.CONTR_LEVEL_VECTOR_NAME,
                output=output_d2d,
                autoencoder=self.autoencoder,
                start_idx=half,
            )
            quantizer_indices = output_d2d.quantizer_outputs[1].reshape(
                -1, self.autoencoder.effective_length
            )
            source_indices = quantizer_indices[:half]
            target_indices = quantizer_indices[half:]
            source_mask = autoenc_input_batch["attention_mask"][:half]
            target_mask = autoenc_input_batch["attention_mask"][half:]
            self.source_quantizer_indices_cache = source_indices
            self.source_masks = source_mask
            self.target_masks = target_mask
            self.target_quantizer_indices_cache = target_indices

            if self.hard_neg_docs is not None:
                output_hard_neg: AutoencoderOutput = self.autoencoder(
                    hard_neg_docs_subbatch,
                    hard_neg_docs_subbatch,
                    steps=1,
                    encoding_only=True,
                )
                hard_neg_vectors = extract_vectors(
                    self.CONTR_VECTOR_NAME,
                    output=output_hard_neg,
                    autoencoder=self.autoencoder,
                )
                hard_neg_level_vectors = extract_vectors(
                    self.CONTR_LEVEL_VECTOR_NAME,
                    output=output_hard_neg,
                    autoencoder=self.autoencoder,
                )
                hard_neg_mask = hard_neg_docs_subbatch["attention_mask"]
                # Hard negatives go AFTER the targets (as in the cached path) so
                # target row k is still the positive for source row k.
                target_incl_hard_neg = torch.cat(
                    (target_vectors_with_grad, hard_neg_vectors), dim=0
                )
                target_level_incl_hard_neg = torch.cat(
                    (target_level_vectors_with_grad, hard_neg_level_vectors), dim=0
                )
                target_masks_incl_hard_neg = torch.cat(
                    (target_mask, hard_neg_mask), dim=0
                )
            else:
                target_incl_hard_neg = target_vectors_with_grad
                target_level_incl_hard_neg = target_level_vectors_with_grad
                target_masks_incl_hard_neg = target_mask
            contr_loss_val = contr_loss(
                start_from=0,
                source_vectors=source_vectors_with_grad,
                target_vectors=target_incl_hard_neg,
                hinge_margin=self.HINGE_MARGIN,
                mask_match=self.CONTR_MASK_NON_UNIQUE,
                source_mask=source_mask,
                target_mask=target_masks_incl_hard_neg,
                exact_match=self.EXACT_MATCH_COLBERT,
                source_level_vectors=source_level_vectors_with_grad,
                target_level_vectors=target_level_incl_hard_neg,
                weigh_by_idf=False,
                source_indices=source_indices,
                target_indices=target_indices,
                use_topk=-1,
                loss=self.CONTR_LOSS_TYPE,
                softmax_loss_temperature=self.SOFTMAX_TEMPERATURE,
            )
            if self.SUP_LOSS_ALPHA > 0.0:
                sup_loss = self.autoencoder.loss(autoenc_output_batch, output_d2d)
            else:
                sup_loss = torch.zeros(1, device=self.autoencoder.device())
            codebook_vec = extract_latent_embeddings(
                self.ENTROPY_VECTOR_NAME, self.autoencoder
            )
            entr_loss1 = entropy_regularization(
                sequence_vectors=extract_vectors(
                    which=self.ENTROPY_VECTOR_NAME,
                    output=output_d2d,
                    autoencoder=self.autoencoder,
                    end_idx=half,
                ),
                codebook_vectors=codebook_vec,
                mask=source_mask,
                distance=self.ENTROPY_DISTANCE,
                softmax_temperature=self.ENTROPY_SOFTMAX,
                softmax_midpoint=self.ENTROPY_SOFTMAX_MIDPOINT,
                random_select_entropy_subset=self.RANDOM_SELECT_ENTROPY_SUBSET,
            )
            entr_loss2 = entropy_regularization(
                sequence_vectors=extract_vectors(
                    which=self.ENTROPY_VECTOR_NAME,
                    output=output_d2d,
                    autoencoder=self.autoencoder,
                    start_idx=half,
                ),
                codebook_vectors=codebook_vec,
                mask=target_mask,
                distance=self.ENTROPY_DISTANCE,
                softmax_temperature=self.ENTROPY_SOFTMAX,
                softmax_midpoint=self.ENTROPY_SOFTMAX_MIDPOINT,
                random_select_entropy_subset=self.RANDOM_SELECT_ENTROPY_SUBSET,
            )
            entr_loss = (entr_loss1 + entr_loss2) / 2.0
            step_loss = (
                sup_loss * self.SUP_LOSS_ALPHA
                + contr_loss_val * self.CONTR_LOSS_ALPHA
                + entr_loss * self.ENTROPY_LOSS_ALPHA
            )
            (step_loss / self.ACCUMULATE_GRADIENTS).backward()
            # Accumulate (previously overwritten) so the returned values are
            # true averages over the ACCUMULATE_GRADIENTS repetitions.
            tot_sup_loss += sup_loss.item()
            tot_contr_loss += contr_loss_val.item()
            tot_entr_loss += entr_loss.item()
            tot_loss += step_loss.item()

            documents_indices = output_d2d.quantizer_outputs[1][half:].long()
            queries_indices = output_d2d.quantizer_outputs[1][:half].long()
            # save some info for logging and tfidf calculation
            step_stats: StreamingTFIDFResults = self.token_counter.step(
                target_indices=documents_indices,
                source_indices=queries_indices,
                source_mask=source_mask,
                target_mask=target_mask,
            )
        return {
            "sup_loss": tot_sup_loss / self.ACCUMULATE_GRADIENTS,
            "contr_loss": tot_contr_loss / self.ACCUMULATE_GRADIENTS,
            "entr_loss": tot_entr_loss / self.ACCUMULATE_GRADIENTS,
            "tot_loss": tot_loss / self.ACCUMULATE_GRADIENTS,
            "quant_indices": output_d2d.quantizer_outputs[1],
            "output_d2d": output_d2d,
            "step_stats": step_stats,
            "quant_outputs": output_d2d.quantizer_outputs[1],
        }
