from dataclasses import dataclass
import math
import torch
from torch import autocast, nn
from torch.nn import functional as F

from tqdm import tqdm
import torch.utils.data as tud

from coarsebind_public.mol_encoder.models.loose_modules.activations import NewGELU


def get_stop_token_embs(x, idx, tokenizer, n_stop_tokens=1):
    """
    Gather the vector(s) of ``x`` located at each sequence's stop token.

    Args:
        x: batch X seq X hidden float tensor (indexed here as ``x[row, position]``).
        idx: batch X seq long tensor of token ids.
        tokenizer: object exposing ``stop_token`` (compared against ``idx``) and
            ``decode`` (only used to print diagnostics before raising).
        n_stop_tokens: number of consecutive positions, starting at the stop
            token, whose vectors are concatenated along the last dimension.

    Returns:
        batch X (hidden * n_stop_tokens) tensor.

    Raises:
        RuntimeError: the number of stop tokens found differs from the batch size.
        AssertionError: a stop token sits so close to the end of the sequence that
            ``n_stop_tokens`` positions cannot be read.
    """
    Is, Js = (idx == tokenizer.stop_token).nonzero(as_tuple=True)

    # Check the count first: with zero stop tokens ``Js.max()`` would fail on an
    # empty tensor with an unrelated message instead of reporting the real issue.
    if Is.shape[0] != x.shape[0]:
        print(torch.Size((Is.shape[0], x.shape[-1] * n_stop_tokens)), x.shape)
        for row in idx:
            print(tokenizer.decode(row.tolist(), special=True, end_at_stop=False, de_fim=False))
        raise RuntimeError(
            "Some smiles in the batch do not have stop tokens. Did some tokenizations fail?"
        )

    # Explicit raise (rather than a bare ``assert``) so the bounds check survives ``python -O``.
    if Js.numel() > 0 and int(Js.max()) + n_stop_tokens - 1 >= x.shape[1]:
        print("Cant extract enough stops, do you have spaces? Heres your strings.")
        for I in Is:
            print(tokenizer.decode(idx[I].tolist(), end_at_stop=False, de_fim=False, special=True))
        raise AssertionError(
            f"Cannot read {n_stop_tokens} position(s) starting at every stop token "
            f"within a sequence of length {x.shape[1]}"
        )

    # One gather per offset; the pieces are laid side by side along the hidden dimension.
    return torch.cat([x[Is, Js + K] for K in range(n_stop_tokens)], -1)


@dataclass
class SmilesTransformerConfig:
    """Hyper-parameters read by the rotary transformer modules defined in this file."""

    n_layer: int = 4  # number of RotaryBlock layers stacked in the transformer
    n_embd: int = 128  # hidden width; must be divisible by 2 * n_head (asserted by RotaryEmbedding)
    n_head: int = 4  # number of attention heads
    n_seq: int = 256  # maximum sequence length (size of the causal mask and the rotary tables)
    n_tok: int = 100  # vocabulary size (embedding rows and lm_head outputs)
    biases: bool = True  # Whether to use biases in the linear layers.
    norm_embed: bool = False  # Whether to normalize post-embed.
    n_stop_tokens: int = 1  # consecutive positions read/injected at the stop/injection token
    device: torch.device = torch.device("cpu")  # forwarded to RotaryEmbedding
    dtype: torch.dtype = torch.float  # forwarded to RotaryEmbedding


class RotaryEmbedding(torch.nn.Module):
    """Token embedding table plus cached sin/cos tables for rotary position encoding."""

    def __init__(
        self,
        n_seq=256,
        n_embd: int = 128,
        n_tok: int = 512,
        n_head=8,
        norm_embed=False,
        device=torch.device("cpu"),
        dtype=torch.float,
        base=10000,
    ):
        """
        Eq. (34) of https://arxiv.org/pdf/2104.09864.pdf
        also inspired by https://blog.eleuther.ai/rotary-embeddings/
        The rotation is done after the hidden dimension is split into heads.
        so, the cached sin/cos tensors operate on a space (n_embd // n_head)

        Args:
            n_seq: Maximum sequence dimension.
            n_embd: embedding dimension (pre head split)
            n_tok: size of tokenspace.
            n_head: number of attention heads.
            norm_embed: unsupported; passing a truthy value raises.
            device: device on which the tables and the embedding are created.
            dtype: dtype of the token embedding weights.
            base: base of the geometric frequency progression.

        Raises:
            AssertionError: if ``n_embd`` is not divisible by ``2 * n_head``.
            Exception: if ``norm_embed`` is truthy.
        """
        super().__init__()
        assert n_embd % (2 * n_head) == 0
        if norm_embed:
            raise Exception("Depreciate soon.")

        head_dim = n_embd // n_head
        inv_freq = 1.0 / (
            base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
        )
        t = torch.arange(n_seq, device=device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)  # (nseq X n_embd//n_head)
        # Non-persistent buffers: they follow the module through .to()/.cuda()/.double()
        # (so no per-forward host-to-device copy) but stay out of state_dict, keeping
        # existing checkpoints loadable.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)
        self.n_head = n_head
        self.n_seq = n_seq
        self.n_embd = n_embd
        self.tok_emb = nn.Embedding(n_tok, n_embd, device=device, dtype=dtype)

    def forward(self, idx):
        """Look up the token embeddings for ``idx`` (no positional information is added here)."""
        return self.tok_emb(idx)

    def rotate(self, x):
        """
        Rotate along the embedding dimension: split the last dimension in two
        halves ``(a, b)`` and return ``(-b, a)``.
        """
        half = x.shape[-1] // 2
        return torch.cat([-x[..., half:], x[..., :half]], -1)

    def rotary_embed(self, q, k):
        """
        Args:
            q: A query (batch, n_head, seq, n_embd//n_head)
            k: A key. (batch, n_head, seq, n_embd//n_head)
        Returns:
            q,k (with the multiplicative rotary embedding applied.)
        """
        seq_len = q.shape[2]
        # .to() is a no-op once the module lives on q's device.
        cos = self.cos_cached[None, None, :seq_len, :].to(q.device)
        sin = self.sin_cached[None, None, :seq_len, :].to(q.device)
        return (q * cos) + (self.rotate(q) * sin), (k * cos) + (self.rotate(k) * sin)


class RotarySelfAttention(nn.Module):
    """
    Causal multi-head self-attention whose queries and keys are rotated by a
    rotary position embedding before the dot product.
    """

    def __init__(self, config):
        """Build the fused q/k/v projection, the output projection and the causal mask."""
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width, max_len = config.n_embd, config.n_seq
        # a single linear layer yields queries, keys and values for every head
        self.c_attn = nn.Linear(width, 3 * width, bias=config.biases)
        # projection applied after the heads are merged back together
        self.c_proj = nn.Linear(width, width, bias=config.biases)
        # lower-triangular mask: position i may only attend to positions <= i
        lower_triangle = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("bias", lower_triangle.view(1, 1, max_len, max_len))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x, rotary_embedding: RotaryEmbedding):
        """
        Args:
            x: (batch, seq, n_embd) input.
            rotary_embedding: object whose ``rotary_embed(q, k)`` rotates queries and keys.
        Returns:
            (batch, seq, n_embd) attention output after the output projection.
        """
        batch, steps, width = x.size()
        head_dim = width // self.n_head

        def to_heads(tensor):
            # (B, T, C) -> (B, nh, T, hs)
            return tensor.view(batch, steps, self.n_head, head_dim).transpose(1, 2)

        queries, keys, values = self.c_attn(x).split(self.n_embd, dim=2)
        keys, queries, values = to_heads(keys), to_heads(queries), to_heads(values)
        queries, keys = rotary_embedding.rotary_embed(queries, keys)

        # scaled dot product: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scores = (queries @ keys.transpose(-2, -1)) * (1.0 / math.sqrt(keys.size(-1)))
        blocked = self.bias[:, :, :steps, :steps] == 0
        scores = scores.masked_fill(blocked, float("-inf"))
        weights = F.softmax(scores, dim=-1)

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then merge the heads side by side
        mixed = weights @ values
        merged = mixed.transpose(1, 2).contiguous().view(batch, steps, width)
        return self.c_proj(merged)


class RotaryBlock(nn.Module):
    """Pre-norm transformer layer: causal rotary self-attention followed by a 4x-wide MLP."""

    def __init__(self, config):
        """Create the two layer norms, the attention module and the feed-forward stack."""
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.ln_1 = nn.LayerNorm(width)
        self.attn = RotarySelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlpf = nn.Sequential(
            nn.Linear(width, hidden, bias=config.biases),
            NewGELU(),
            nn.Linear(hidden, width, bias=config.biases),
        )

    def forward(self, x, rotary_embedding: RotaryEmbedding):
        """Apply attention and the MLP, each as a residual update on a layer-normed input."""
        attended = x + self.attn(self.ln_1(x), rotary_embedding)
        return attended + self.mlpf(self.ln_2(attended))


class RotarySmilesTransformer(nn.Module):
    """
    Rotary string transformer for a tokenized graph
    """

    def __init__(self, config: SmilesTransformerConfig):
        """
        Build the embedding, the stack of rotary blocks, the final layer norm and
        the language-model head from ``config``, then print the parameter count
        of the block stack plus final norm.
        """
        super().__init__()
        self.n_seq = config.n_seq
        self.n_tok = config.n_tok
        self.n_embd = config.n_embd
        self.n_stop_tokens = config.n_stop_tokens

        self.norm_embed = nn.LayerNorm(config.n_embd) if config.norm_embed else nn.Identity()
        self.emb = RotaryEmbedding(
            n_embd=config.n_embd,
            n_seq=config.n_seq,
            n_tok=config.n_tok,
            n_head=config.n_head,
            device=config.device,
            norm_embed=config.norm_embed,
            dtype=config.dtype,
        )

        layers = [RotaryBlock(config) for _ in range(config.n_layer)]
        final_norm = nn.LayerNorm(config.n_embd)
        self.transformer = nn.ModuleDict({"h": nn.ModuleList(layers), "ln_f": final_norm})
        self.lm_head = nn.Linear(config.n_embd, config.n_tok, bias=False)

        # only the block stack and final norm are counted (embedding and lm_head are excluded)
        n_params = 0
        for p in self.transformer.parameters():
            n_params += p.numel()
        print("number of parameters: %.2fM" % (n_params / 1e6,))

    def encode(self, idx, tokenizer):
        """
        Run the transformer on ``idx`` and return only the hidden vector(s) found
        at each sequence's [STOP] token, which MUST be the last token before [PAD].
        """
        hidden = self.xformer(idx)
        stop_vectors = get_stop_token_embs(
            hidden, idx, tokenizer, n_stop_tokens=self.n_stop_tokens
        )
        return stop_vectors

    def generate_greedy(self, prefix=torch.tensor([[1]]), stop_token=2, max_len=256):
        """
        Autoregressively generate. c.f. https://huggingface.co/blog/how-to-generate
        Note: the decode from forward is not autoregressive.
        this is and stops upon hitting a stop token.
        """
        generated = torch.clone(prefix)
        with torch.no_grad():
            while (generated.flatten()[-1].item() != stop_token) and generated.shape[
                1
            ] < self.n_seq:
                Y = self.forward(generated, decode=False, sampled=False)
                _, next_char = torch.topk(Y[0, generated.shape[1] - 1], k=1, dim=-1)
                generated = torch.cat([generated, next_char.unsqueeze(0)], 1)
        return generated

    def generate_topk(
        self,
        prefix=torch.tensor([[1]]),
        stop_token=2,
        inv_temp=2,
        k=10,
    ):
        """
        Args:
            inj_token: (int) if not none, will perform token injection for clip gen.
            inj_hidden: torch.float tensor if not none will be injected over inj_token
        https://arxiv.org/pdf/1805.04833.pdf
        """
        generated = torch.clone(prefix).to(self.lm_head.weight.device)
        with torch.no_grad():
            while (generated.flatten()[-1].item() != stop_token) and generated.shape[
                1
            ] < self.n_seq:
                Y = self.forward(generated)  # Y is
                logits, inds = torch.topk(Y[0, generated.shape[1] - 1], k=k, dim=-1)
                probs = F.softmax(logits * inv_temp, dim=-1)
                inds_of_inds = torch.multinomial(probs, num_samples=1).squeeze()
                generated = torch.cat(
                    [generated, (inds[inds_of_inds]).unsqueeze(0).unsqueeze(0)], 1
                )
        return generated[0].tolist()

    def generate_topk_batch(self, prefix=[[0]], stop_token=2, pad_token=0, inv_temp=2, k=10):
        """
        Top-k sampling for a batch of (possibly different-length) prefixes.

        Works for variable length prefixes.

        Args:
            prefix: list of token-id lists, one per batch row.
            stop_token: a row counts as finished once this token appears anywhere in it.
            pad_token: token written for rows already counted as finished.
            inv_temp: multiplier applied to the top-k logits before the softmax.
            k: number of highest-scoring tokens to sample from at each step.

        Returns:
            list of ``len(prefix)`` token lists, each of length ``self.n_seq``.

        NOTE(review): this runs outside ``torch.no_grad()`` — confirm that is intended.
        """
        batch_size = len(prefix)
        min_prefix_len = min([len(p) for p in prefix])
        # fill in a zero-ed out prefix tensor.
        # which will overwrite the new columns each iteration.
        prefix_t = torch.zeros(
            (batch_size, self.n_seq),
            device=self.lm_head.weight.device,
            dtype=torch.long,
        )
        for K, row in enumerate(prefix):
            prefix_t[K, : len(row)] = torch.tensor(
                row, device=self.lm_head.weight.device, dtype=torch.long
            )

        current_t = prefix_t.clone()
        # Each iteration reads the logits at column ``idx`` and writes the sample to
        # column ``idx + 1``; the first write therefore lands on column
        # ``min_prefix_len - 1``.
        # NOTE(review): for a single-token prefix this starts at idx == -1, i.e. the
        # logits of the last column — confirm this is intended.
        idx = min_prefix_len - 2
        has_stopped = []

        while len(has_stopped) < batch_size and idx < self.n_seq - 1:
            # Restore every prefix token with id > 0, undoing samples that were written
            # inside a longer prefix. Prefix tokens with id 0 are not restored.
            current_t[prefix_t > 0] = prefix_t[prefix_t > 0]
            x = self.emb(current_t)

            for block in self.transformer.h:
                x = block(x, self.emb)
            x = self.transformer.ln_f(x)
            logits = self.lm_head(x)
            logits, inds = torch.topk(logits[:, idx], k=k, dim=1)
            probs = F.softmax(logits * inv_temp, dim=1)
            inds_of_inds = torch.multinomial(probs, num_samples=1).squeeze()
            last_tokens = inds[torch.arange(batch_size), inds_of_inds]
            # rows already counted as finished only receive padding
            last_tokens[has_stopped] = pad_token

            current_t[:, idx + 1] = last_tokens

            idx += 1
            # Row index of every stop-token occurrence in the whole tensor (prefix
            # columns included); a row with several stop tokens appears several times.
            has_stopped = (current_t == stop_token).nonzero(as_tuple=True)[0]

        return current_t.tolist()

    # TODO: some consolidation here
    def xformer_blocks(
        self, x: torch.Tensor, apply_norm: bool = True, output_logits: bool = False
    ) -> torch.Tensor:
        for block in self.transformer.h:
            x = block(x, self.emb)
        if apply_norm:
            x = self.transformer.ln_f(x)

        if output_logits:
            return self.lm_head(x)
        else:
            return x

    def generate_topk_with_inj(
        self,
        prefix=[0],
        stop_token=2,
        inv_temp=1,
        k=50,  # only the topk logits can be randomly gen'd
        inj_token=None,
        inj_payload=None,
    ):
        """
        Like the above, but works in the embedding space rather than token space, so it can do
        clip injection.

        Now supporting multiple token injection.
        This will cause the prefix to implicitly lengthened.

        Args:
            inj_token: (int) if not none, will perform token injection for clip gen.
            inj_payload: torch.float tensor if not none will be injected over inj_token
                         [just n_hidden]
        https://arxiv.org/pdf/1805.04833.pdf
        """
        assert (
            len(prefix) <= self.n_seq
        ), f"Cannot forward sequence of length {len(prefix)}, n_seq is only {self.n_seq}"
        # Inject the payload
        if self.n_stop_tokens > 1:
            # the payload will be dim_embed * n_stop_tokens extra tokens in
            # the prefix are made here and the payload is chunked and appropriately injected.
            assert inj_payload.shape[-1] == self.n_embd * self.n_stop_tokens
            prefix_x = self.emb(
                torch.tensor(prefix, device=inj_payload.device, dtype=torch.long).unsqueeze(0)
            )
            # split up the tensor around the injection to accomodate var length payload.
            inj_index = prefix.index(inj_token)
            prefix_x = torch.cat(
                [prefix_x[:, :inj_index]]
                + list(inj_payload.unsqueeze(0).chunk(self.n_stop_tokens, dim=-1))
                + [prefix_x[:, inj_index + self.n_stop_tokens :]],
                1,
            )
        else:
            prefix_x = self.emb(
                torch.tensor(prefix, device=inj_payload.device, dtype=torch.long).unsqueeze(0)
            )
            prefix_x[0, prefix.index(inj_token)] = inj_payload
        # from now on embedded vectors will be generated by concatenation along the seq dim
        generated = []
        last_token = 0
        with torch.no_grad():
            while (last_token != stop_token) and len(generated) < self.n_seq - 1:
                if len(generated):
                    # concatenate the generated tokens onto the prefix.
                    gen_x = self.emb(
                        torch.tensor(
                            generated, device=inj_payload.device, dtype=torch.long
                        ).unsqueeze(0)
                    )
                    x = torch.cat([prefix_x, gen_x], 1)
                else:
                    x = prefix_x
                for block in self.transformer.h:
                    x = block(x, self.emb)
                x = self.transformer.ln_f(x)
                logits = self.lm_head(x)
                logits, inds = torch.topk(logits[0, len(prefix) + len(generated) - 1], k=k, dim=-1)
                probs = F.softmax(logits * inv_temp, dim=-1)
                inds_of_inds = torch.multinomial(probs, num_samples=1).squeeze()
                last_token = inds[inds_of_inds].item()
                generated.append(last_token)
        return prefix + generated

    def generate_top_k_with_inj_batch(
        self,
        prefix=[0],
        stop_token=2,
        pad_token=0,
        inv_temp=1,
        k=50,  # only the topk logits can be randomly gen'd
        inj_token=None,
        inj_payload=None,
        as_tensor=False,
    ):
        batch_size = inj_payload.size(0)
        assert inj_payload.dim() == 2
        assert k >= 1

        if self.n_stop_tokens > 1:
            # the payload will be dim_embed * n_stop_tokens extra tokens in
            # the prefix are made here and the payload is chunked and appropriately injected.
            assert inj_payload.shape[-1] == self.n_embd * self.n_stop_tokens
            prefix_x = (
                self.emb(torch.tensor(prefix, device=inj_payload.device, dtype=torch.long))
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
            # split up the tensor around the injection to accomodate var length payload.
            inj_index = prefix.index(inj_token)
            prefix_x = torch.cat(
                [prefix_x[:, :inj_index]]
                + list(inj_payload.unsqueeze(1).chunk(self.n_stop_tokens, dim=-1))
                + [prefix_x[:, inj_index + self.n_stop_tokens :]],
                1,
            )
        else:
            assert inj_payload.shape[-1] == self.n_embd
            prefix_x = (
                self.emb(torch.tensor(prefix, device=inj_payload.device, dtype=torch.long))
                .unsqueeze(0)
                .repeat(batch_size, 1, 1)
            )
            # Inject the payload
            prefix_x[:, prefix.index(inj_token), :] = inj_payload

        generated = torch.tensor([], dtype=torch.int64, device=inj_payload.device)
        has_stopped = []
        idx = 0
        # NOTE: Fixed issue where if a batch did not stop
        # before hitting (seq_len - len(prefix)) it would blow up batch
        while len(has_stopped) < batch_size and idx < self.n_seq - len(prefix):
            if idx > 0:
                gen_x = self.emb(generated)
                x = torch.cat([prefix_x, gen_x], 1)
            else:
                x = prefix_x
            # batch x len(seq) x num_tokens
            logits = self.xformer_blocks(x, apply_norm=True, output_logits=True)
            # logits->batch_size x k , inds-> batch_size x k
            logits_topk, inds_topk = torch.topk(logits[:, len(prefix) + idx - 1], k=k, dim=1)
            probs = F.softmax(logits_topk * inv_temp, dim=1)
            inds_of_inds = torch.multinomial(probs, num_samples=1).squeeze(-1)
            last_tokens = inds_topk[torch.arange(batch_size), inds_of_inds]
            # if any of the batch has stopped, set their last token to pad, don't want to generate anymore
            last_tokens[has_stopped] = pad_token
            if len(generated):
                generated = torch.cat([generated, last_tokens.clone().unsqueeze(1)], dim=1).long()
            else:
                generated = last_tokens.clone().unsqueeze(1).long()
            idx += 1
            has_stopped = (generated == stop_token).nonzero(as_tuple=True)[0]

        # if anything hasn't stopped yet by the time it reaches the threshold, add stop token
        num_not_stopped = batch_size - len(has_stopped)

        if num_not_stopped:
            # print(f"WARNING: {num_not_stopped} sequences did not stop before reaching max length. forcing stop.")
            # get all indices that are not in has_stopped
            not_stopped = torch.tensor([i for i in range(batch_size) if i not in has_stopped])
            # create a final last_token that's all pad_token except for the not_stopped
            final_pad_or_stopped = torch.tensor([pad_token] * batch_size, device=inj_payload.device)
            final_pad_or_stopped[not_stopped] = stop_token
            # set their last token to stop token
            generated[not_stopped, -1] = stop_token
            # generated = torch.cat([generated, final_pad_or_stopped.clone().unsqueeze(1)], dim=1).long()

        if as_tensor:
            return torch.cat(
                [
                    torch.tensor(prefix, dtype=torch.long, device=generated.device)
                    .unsqueeze(0)
                    .repeat(batch_size, 1),
                    generated,
                ],
                dim=1,
            )

        token_batch = [prefix + output for output in generated.tolist()]
        return token_batch

    # largely ripped from https://github.com/jarobyte91/pytorch_beam_search
    def generate_beam_search_batch(
        self,
        prefix=torch.tensor([[1]]),
        beam_width=5,
        predictions=20,
        batch_size=128,
        inj_token=None,  # Injection token for conditioning
        inj_payload=None,  # Payload to inject at injection token location
        stop_token=2,
        force_stop=False,
        progress_bar=0,
    ):
        """
        Beam search over a shared prefix with a per-row injected payload.

        Args:
            prefix: token ids shared by every row. It is passed to ``torch.tensor``
                and ``.index`` is called on it.
                NOTE(review): that usage looks like it expects a Python list, which
                the tensor default would not satisfy — confirm.
            beam_width: number of beams kept per payload row.
            predictions: number of tokens appended to the prefix (one before the
                loop, ``predictions - 1`` inside it).
            batch_size: chunk size used when forwarding the flattened beams.
            inj_token: placeholder token in ``prefix`` overwritten by the payload.
            inj_payload: one payload row per sequence; ``inj_payload.size(0)`` sets
                the number of sequences and a single position is overwritten.
                NOTE(review): ``n_stop_tokens > 1`` is not handled here — confirm.
            stop_token: only used when ``force_stop`` is set.
            force_stop: append ``stop_token`` to every returned beam.
            progress_bar: > 0 wraps the prediction loop in tqdm, > 1 also wraps the
                inner chunk loop.

        Returns:
            (tokens, scores): ``tokens`` is viewed as (-1, beam_width, seq_len);
            ``scores`` holds the summed log-softmax values of the kept beams.

        NOTE(review): beams are not terminated when they emit ``stop_token`` and the
        computation is not wrapped in ``torch.no_grad()`` — confirm both are intended.
        """
        # get the prefix embeddings and inject the payload
        prefix_x = (
            self.emb(torch.tensor(prefix, device=inj_payload.device, dtype=torch.long))
            .unsqueeze(0)
            .repeat(inj_payload.size(0), 1, 1)
        )
        prefix_x[:, prefix.index(inj_token), :] = inj_payload

        # prefix -> tensor
        batch_prefix = (
            torch.tensor(prefix, device=inj_payload.device, dtype=torch.long)
            .unsqueeze(0)
            .repeat(inj_payload.size(0), 1)
        )

        # scores for the first generated token, taken from the last prefix position
        next_probabilities = self.xformer_blocks(prefix_x, apply_norm=True, output_logits=True)[
            :, -1, :
        ]
        vocabulary_size = next_probabilities.shape[-1]
        probabilities, idx = next_probabilities.log_softmax(-1).topk(k=beam_width, dim=-1)

        # sets up our initial beam search state of (batch_size * beam_width, seq_len)
        X = (
            batch_prefix.repeat((beam_width, 1, 1))
            .transpose(0, 1)
            .flatten(end_dim=-2)
            .to(inj_payload.device)
        )

        next_chars = idx.reshape(-1, 1)
        # sets up beam search state but with embs so (batch_size * beam_width, seq_len, emb_dim)
        injected_X = (
            prefix_x.unsqueeze(1)
            .repeat(1, beam_width, 1, 1)
            .view(-1, prefix_x.size(1), prefix_x.size(2))
        )

        # concatenate the next token to the prefix
        X = torch.cat([X, next_chars], dim=-1)
        encoded_next_char = self.emb(next_chars)
        injected_X = torch.cat([injected_X, encoded_next_char], dim=1)

        predictions_iterator = range(predictions - 1)
        if progress_bar > 0:
            predictions_iterator = tqdm(predictions_iterator)

        # continues for up to predictions - 1 steps
        for i in predictions_iterator:
            # the DataLoader is only used to forward the beams in chunks of batch_size
            dataset = tud.TensorDataset(X, injected_X)
            loader = tud.DataLoader(dataset, batch_size=batch_size)
            iterator = iter(loader)

            if progress_bar > 1:
                iterator = tqdm(iterator)

            next_probabilities = []
            # only the embedded beams are forwarded; the token ids ``x`` are unused here
            for x, inj_x in iterator:
                next_probabilities.append(
                    self.xformer_blocks(inj_x, apply_norm=True, output_logits=True)[
                        :, -1, :
                    ].log_softmax(-1)
                )
            next_probabilities = torch.cat(next_probabilities, dim=0)
            next_probabilities = next_probabilities.view(-1, beam_width, vocabulary_size)
            # add each beam's running score to the scores of all of its continuations
            probabilities = probabilities.unsqueeze(-1) + next_probabilities

            # rank all beam_width * vocabulary_size continuations of a row together
            probabilities = probabilities.flatten(start_dim=1)
            probabilities, idx = probabilities.topk(k=beam_width, dim=-1)

            # a flat index decomposes into (source beam, next token)
            next_chars = torch.remainder(idx, vocabulary_size).flatten().unsqueeze(-1)
            best_candidates = (idx // vocabulary_size).long()
            # offset the per-row beam index into the flattened (row * beam_width) layout
            best_candidates += (
                torch.arange(X.size(0) // beam_width, device=X.device).unsqueeze(-1) * beam_width
            )
            X = X[best_candidates].flatten(end_dim=-2)
            X = torch.cat((X, next_chars), dim=1)

            encoded_next_char = self.emb(next_chars)
            injected_X = injected_X[best_candidates].flatten(end_dim=-3)
            # print(injected_X.shape, encoded_next_char.shape)
            injected_X = torch.cat([injected_X, encoded_next_char], dim=1)
        if force_stop:
            # append stop token to all sequences
            stop_token_tensor = torch.tensor(
                [[stop_token]], device=injected_X.device, dtype=torch.long
            )
            X = torch.cat([X, stop_token_tensor.repeat(X.size(0), 1)], dim=1)
        # modify the output to be (batch_size, beam_width, seq_len)
        return X.view(-1, beam_width, X.size(-1)), probabilities

    def xformer(self, idx):
        """
        Args:
            idx: torch longtensor of token indices.

        Returns encoding of all entries in batch.
        """
        _, t = idx.size()
        assert t <= self.n_seq, f"Cannot forward sequence of length {t}, n_seq is only {self.n_seq}"
        x = self.emb(idx)
        for block in self.transformer.h:
            x = block(x, self.emb)
        x = self.transformer.ln_f(x)
        return x

    def decode_logits(self, logits):
        probs = F.softmax(logits, dim=-1)
        _, idx_next = torch.topk(probs, k=1, dim=-1)
        return logits, idx_next.squeeze()

    def forward(self, idx):
        """
        Args:
            idx: torch longtensor of token indices.
        """
        x = self.xformer(idx)
        logits = self.lm_head(x)
        return logits

    def forward_with_stop_emb(self, idx, tokenizer):
        """
        Return both the logits and the stop-token embeddings.

        I made this a separate routine because of issues with torch DataParallel
        and functions with variable numbers of return values.
        Args:
            idx: torch longtensor of token indices.
            tokenizer: forwarded to ``get_stop_token_embs`` to locate the stop tokens.
        """
        hidden = self.xformer(idx)
        logits = self.lm_head(hidden)
        stop_embs = get_stop_token_embs(hidden, idx, tokenizer, n_stop_tokens=self.n_stop_tokens)
        return logits, stop_embs

    def forward_with_stop_emb_and_replacement(
        self, idx, injection, tokenizer, inject_token="[UNK]"
    ):
        """
        This is specifically for e2e-CLIP a-la clipCAP.
        It injects vectors in place of the embedding of ``inject_token`` and also
        returns the stop-emb.

        Args:
            idx: torch longtensor of token indices. (batch X seq)
            injection: float tensor indexed by batch row; its last dimension is
                split into ``n_stop_tokens`` chunks.
                (presumably batch X (emb_dim * n_stop_tokens) — confirm against callers)
            tokenizer: provides ``vocab[inject_token]`` and is forwarded to
                ``get_stop_token_embs``.
            inject_token: vocabulary key of the placeholder token.

        Now supports multiple injected tokens:
            ie: n_stop_tokens positions starting at the inject token will
               be overwritten by the injection.
               At training time these extra tokens are appropriately
               supplied and not split over.

        Returns:
            (logits, stop_embs)
        """
        _, t = idx.size()
        assert t <= self.n_seq, f"Cannot forward sequence of length {t}, n_seq is only {self.n_seq}"
        assert (idx >= 0).all()

        x = self.emb(idx)
        # Injection of the special tokens
        with autocast(enabled=False, device_type="cuda"):
            hole_Is, hole_Js = (idx == tokenizer.vocab[inject_token]).nonzero(as_tuple=True)
            if torch.numel(hole_Js) > 0:
                # Chunk the full injection and select rows exactly once per hole.
                # Indexing by hole_Is twice is only correct when every batch row
                # holds exactly one hole.
                chunks = injection.chunk(self.n_stop_tokens, dim=-1)
                for K, chunk in enumerate(chunks):
                    x[hole_Is, hole_Js + K] = chunk[hole_Is]
        # regular old xformer.
        for block in self.transformer.h:
            x = block(x, self.emb)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits, get_stop_token_embs(x, idx, tokenizer, self.n_stop_tokens)

    def forward_with_replacement(self, idx, injection, tokenizer, inject_token="[UNK]"):
        """
        This is specifically for e2e-CLIP a-la clipCAP.
        It injects tokens in place of special_token.
        Supports multiple injection possibility.
        Args:
            idx: torch longtensor of token indices. (batch X seq)
            injection: (batch X seq X emb_dim)
        """
        _, t = idx.size()
        assert injection.shape[-1] % self.n_embd == 0
        ntok_inj = injection.shape[-1] // self.n_embd
        assert ntok_inj == self.n_stop_tokens
        assert t <= self.n_seq, f"Cannot forward sequence of length {t}, n_seq is only {self.n_seq}"
        x = self.emb(idx)
        # Injection of the special tokens
        with autocast(enabled=False, device_type="cuda"):
            hole_Is, hole_Js = (idx == tokenizer.vocab[inject_token]).nonzero(as_tuple=True)

            # #DEBUG Check that what is being injected over is [UNK]
            # print('CHECKING STOP Injection')
            # for I,J in zip(hole_Is, hole_Js):
            #     print(tokenizer.decode(idx[I,J:J+self.n_stop_tokens].tolist(), special=True, end_at_stop = False, de_fim = False))
            # print('----')

            chunked_inj = injection.chunk(ntok_inj, dim=-1)
            for k in range(ntok_inj):
                x[hole_Is, hole_Js + k] = chunked_inj[k][hole_Is]
        # regular old xformer.
        for block in self.transformer.h:
            x = block(x, self.emb)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits
