# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Prefix/suffix retrieval model (PSLM) and its GPT backbone (GPTRetrieval), all of it in this single file.

Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""

import math
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from collections.abc import Callable

import numpy as np
import torch
import torch.nn as nn
from typing_extensions import Self

import torch.distributed as dist

from tqdm import tqdm

from litgpt.model import GPT, Block
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer

####################################################################################################
# Helper functions.
####################################################################################################


def cumulative_mean_pooling(hidden_states: torch.Tensor):
    """Return the running mean of ``hidden_states`` along the sequence dimension (dim 1).

    Output position ``t`` holds the average of input positions ``0..t``.

    Args:
        hidden_states: tensor shaped ``(bsz, seq_len, d)``.

    Returns:
        Tensor of the same shape containing the cumulative means.
    """
    seq_len = hidden_states.size(1)
    running_sum = hidden_states.cumsum(dim=1)
    # Token counts 1..seq_len shaped (1, seq_len, 1) so they broadcast over batch and feature dims.
    counts = torch.arange(1, seq_len + 1, device=running_sum.device).view(1, seq_len, 1)
    return running_sum.true_divide(counts)


class PSLM(nn.Module):
    """Prefix/suffix retrieval model trained with a (contrastive) objective.

    Holds two ``GPTRetrieval`` backbones: ``prefix_model`` reads the input left to
    right, while ``suffix_model`` reads the reversed input, so its causal attention
    behaves anti-causally with respect to the original token order. With
    ``suffix_is_prefix`` both roles share a single backbone.
    """

    def __init__(
        self,
        config: Config,
        objective: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        gradient_checkpointing: bool = False,
        negatives_cross_device: bool = False,
        negatives_cross_device_group_size: Optional[int] = None,
        suffix_is_prefix: bool = False,
        batch_prefix_and_suffix: bool = False,
        flip_rope_embedding_suffix: bool = False,
        add_suf_pre_tokens: bool = False,
        tokenizer: Optional[Tokenizer] = None,
        nope_pos_embeddings: bool = False,
        keep_eos: bool = False,
    ) -> None:
        """
        Args:
            config: backbone configuration. ``config.nope_pos_embeddings`` is overwritten
                with ``nope_pos_embeddings``.
            objective: loss object. Called as ``objective([prefix, suffix])``, or through
                ``objective.distributed_loss`` when ``negatives_cross_device`` is set.
            gradient_checkpointing: forwarded to the backbones.
            negatives_cross_device: all-gather suffix embeddings across the ranks of a
                process group. Requires ``torch.distributed`` to be initialized.
            negatives_cross_device_group_size: number of ranks per gather group. ``None``
                (or any falsy value such as the former default ``False``) means a single
                group spanning all ranks.
            suffix_is_prefix: use the same backbone for prefix and suffix.
            batch_prefix_and_suffix: run prefix and suffix through ``prefix_model`` as one
                concatenated batch.
            flip_rope_embedding_suffix: forwarded to the backbones (reverses cos/sin tables
                on suffix passes).
            add_suf_pre_tokens: prepend the tokenizer's prefix/suffix meta tokens.
            tokenizer: provides eos/pad ids and the meta-token ids.
            nope_pos_embeddings: stored on ``config`` and forwarded to the backbones.
            keep_eos: find padding from the attention masks instead of treating every
                ``eos_id`` as padding.

        Raises:
            RuntimeError: ``negatives_cross_device`` is set but ``torch.distributed`` is not
                initialized.
            ValueError: the world size is not divisible by the gather-group size.
        """
        super().__init__()
        # HACK: added config to the class so `litgpt.optim.get_param_groups` won't complain
        self.config = config
        config.nope_pos_embeddings = bool(nope_pos_embeddings)

        self.prefix_model = GPTRetrieval(
            config,
            gradient_checkpointing,
            flip_rope_embedding_suffix,
            add_suf_pre_tokens,
            nope_pos_embeddings,
        )  # model with causal attention
        self.suffix_is_prefix = suffix_is_prefix
        if suffix_is_prefix:
            print("--------------Setting Suffix as Prefix--------------")
            self.suffix_model = self.prefix_model
        else:
            self.suffix_model = GPTRetrieval(
                config, gradient_checkpointing, flip_rope_embedding_suffix, add_suf_pre_tokens, nope_pos_embeddings
            )  # causal attention over the reversed sequence (anti-causal w.r.t. the original order)

        self.batch_prefix_and_suffix = batch_prefix_and_suffix

        self.objective = objective
        self.max_seq_length = config.block_size
        self.keep_eos = keep_eos

        self.negatives_cross_device = negatives_cross_device
        self.negatives_cross_device_group_size = negatives_cross_device_group_size

        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
            self.grp_rank = self.rank  # with a single group, the group rank equals the global rank
        else:
            print("WARNING: Cannot do retrieval training without distributed initialization, undefined behavior")
            # Single-process fallbacks: `forward` reads these attributes.
            self.rank = 0
            self.world_size = 1
            self.grp_rank = 0

        if self.negatives_cross_device:
            if not dist.is_initialized():
                raise RuntimeError("negatives_cross_device=True requires torch.distributed to be initialized")

            # Special process groups used to all-gather the embedding features produced by different
            # devices. Default (None / False / 0): one group containing every rank.
            if not self.negatives_cross_device_group_size:
                self.negatives_cross_device_group_size = self.world_size
            group_size = self.negatives_cross_device_group_size

            if self.world_size % group_size != 0:
                raise ValueError(
                    f"world_size ({self.world_size}) must be divisible by negatives_cross_device_group_size ({self.negatives_cross_device_group_size})"
                )
            # All groups must be declared in the same order on every rank; each rank then keeps
            # the handle of the group it belongs to.
            all_rank_groups = [list(range(start, start + group_size)) for start in range(0, self.world_size, group_size)]
            self.negatives_cross_device_all_groups = [dist.new_group(ranks=grp_ranks) for grp_ranks in all_rank_groups]
            self.negatives_cross_device_group = self.negatives_cross_device_all_groups[self.rank // group_size]
            # dist.get_group_rank breaks torch.compile, so compute the rank within the group manually.
            self.grp_rank = self.rank % group_size

        self.add_suf_pre_tokens = add_suf_pre_tokens
        if self.add_suf_pre_tokens:
            print("--------------Adding Suffix and Prefix Tokens--------------")
            print(f"This uses a max sequence length of self.max_seq_length {self.max_seq_length}")
        self.tokenizer = tokenizer
        self.nope_pos_embeddings = nope_pos_embeddings

    @torch._dynamo.disable(recursive=False)
    def forward(
        self,
        idx: Tuple[torch.Tensor, Optional[torch.Tensor]],
        position_ids: Optional[torch.Tensor] = None,
        attn_mask: Optional[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]] = None,
        output_hidden_states: Optional[bool] = None,
        mean_pooling: bool = False,
        fixed_length: bool = False,
        track_memory_finegrained: bool = False,
    ) -> dict:
        """Encode prefixes and suffixes and, when an objective is set, compute loss/accuracy.

        Args:
            idx: ``(prefix_input_ids, suffix_input_ids)``. With ``suffix_input_ids=None`` the
                suffix model encodes the reversed prefix and per-token features are paired.
                Otherwise the suffix inputs are encoded separately (their batch size may
                differ, e.g. with hard negatives) and one vector per sequence is kept (the
                last token, or the mean with ``mean_pooling``).
            position_ids: forwarded to the backbones. Unsupported with ``batch_prefix_and_suffix``.
            attn_mask: ``(prefix_mask, suffix_mask)``, True on real tokens. ``None`` is treated as
                ``(None, None)``. Masks are required when ``keep_eos`` or ``add_suf_pre_tokens``
                is set.
            output_hidden_states: forwarded to the backbones.
            mean_pooling: average each sequence's token features into a single vector.
            fixed_length: skip the pad-token reordering/removal steps.
            track_memory_finegrained: add CUDA memory statistics under ``outputs["memory_stats"]``.

        Returns:
            dict with ``"loss"`` and ``"accuracy"`` (``None`` when there is no objective), or
            whatever ``self.objective`` returns in the non-cross-device case; plus
            ``"memory_stats"`` when memory tracking is on.
        """
        memory_stats = {}
        curr_device = idx[0].device

        def check_memory():
            # CUDA memory statistics for `curr_device`, in GiB.
            gib = 1024**3
            return {
                "memory_allocated_per_gpu": torch.cuda.memory_allocated(curr_device) / gib,
                "max_memory_allocated_per_gpu": torch.cuda.max_memory_allocated(curr_device) / gib,
                "memory_reserved_per_gpu": torch.cuda.memory_reserved(curr_device) / gib,
                "max_memory_reserved_per_gpu": torch.cuda.max_memory_reserved(curr_device) / gib,
            }

        #### Prepare inputs to the models ##########################################################
        prefix_idx, suffix_input_ids = idx
        prefix_attn_mask, suffix_attn_mask = attn_mask if attn_mask is not None else (None, None)

        bsz, _ = prefix_idx.size()
        # suffix could have a different batch size (e.g. when finetuning with hard negatives)
        suffix_bsz = bsz if suffix_input_ids is None else suffix_input_ids.size(0)
        # tokens the suffix model will read (reversed below)
        tmp_idx = prefix_idx if suffix_input_ids is None else suffix_input_ids
        # Mask aligned with `tmp_idx`, captured before any meta-token column is prepended below.
        tmp_attn_mask = prefix_attn_mask if suffix_input_ids is None else suffix_attn_mask

        if self.add_suf_pre_tokens:
            prefix_attn_mask_add = torch.tensor([[True]]).repeat(prefix_attn_mask.size(0), 1)
            prefix_attn_mask = torch.cat([prefix_attn_mask_add.to(prefix_attn_mask.device), prefix_attn_mask], dim=1)
            prefix_token_idx = torch.tensor([[self.tokenizer.prefix_token_id]]).repeat(prefix_idx.size(0), 1)
            prefix_idx = torch.cat((prefix_token_idx.to(prefix_idx.device), prefix_idx), dim=1)
            if prefix_idx.size(1) > self.max_seq_length:
                prefix_idx = prefix_idx[:, : self.max_seq_length]
                prefix_attn_mask = prefix_attn_mask[:, : self.max_seq_length]
                # overriding the input for suffix with the (truncated) prefix tokens
                tmp_idx = prefix_idx[:, 1:] if suffix_input_ids is None else suffix_input_ids
            if suffix_attn_mask is not None:
                suffix_attn_mask_add = torch.tensor([[True]]).repeat(suffix_attn_mask.size(0), 1)
                suffix_attn_mask = torch.cat(
                    [suffix_attn_mask_add.to(suffix_attn_mask.device), suffix_attn_mask], dim=1
                )
                if suffix_attn_mask.size(1) > self.max_seq_length:
                    suffix_attn_mask = suffix_attn_mask[:, : self.max_seq_length]
            assert prefix_idx[0][0] == self.tokenizer.prefix_token_id

        suffix_attn_mask = prefix_attn_mask if suffix_input_ids is None else suffix_attn_mask

        if not fixed_length and suffix_bsz > 1:
            # Several suffix sequences may contain padding: move pad tokens to the front so that they
            # end up at the end once the sequence is reversed.
            if self.keep_eos:
                # masks are True on tokens, so invert to flag padding; slice so the width matches
                # `tmp_idx` (otherwise argsort could yield out-of-range gather indices)
                pad_mask = ~tmp_attn_mask[:, : tmp_idx.size(1)].cpu()
            else:
                pad_mask = (tmp_idx == self.tokenizer.eos_id).cpu()  # eos_id doubles as the pad token
            # stable descending argsort: padding first, original token order otherwise preserved
            sorted_indices = torch.argsort(pad_mask, descending=True, stable=True)
            tmp_idx = torch.gather(tmp_idx, 1, sorted_indices.to(tmp_idx.device))

        # Reverse the sequence so the causal suffix model reads it anti-causally.
        suffix_idx = torch.flip(tmp_idx, [1])

        if self.add_suf_pre_tokens:
            suffix_token_idx = torch.tensor([[self.tokenizer.suffix_token_id]]).repeat(suffix_idx.size(0), 1)
            suffix_idx = torch.cat((suffix_token_idx.to(suffix_idx.device), suffix_idx), dim=1)
            if suffix_idx.size(1) > self.max_seq_length:
                suffix_idx = suffix_idx[:, : self.max_seq_length]
            # confirm that the suffix_idx's first token is the suffix token
            assert suffix_idx[0][0] == self.tokenizer.suffix_token_id

        #### Call the models to create features ##########################################################
        if not self.batch_prefix_and_suffix:
            prefix_output = self.prefix_model(prefix_idx, position_ids, output_hidden_states=output_hidden_states)
            suffix_output = self.suffix_model(
                suffix_idx, position_ids, output_hidden_states=output_hidden_states, is_suffix=True
            )
            prefix_hidden_states_bsz_T_d = prefix_output["hidden_states"][-1]
            suffix_hidden_states_bsz_T_d = suffix_output["hidden_states"][-1]
        else:
            # Concatenate both sets of inputs along the batch dimension, run one forward pass, then
            # split at `bsz`.
            combined_idx = torch.cat((prefix_idx, suffix_idx), dim=0)

            assert position_ids is None, "position_ids is not supported with batch_prefix_and_suffix"

            combined_output = self.prefix_model(combined_idx, None, output_hidden_states=output_hidden_states)
            combined_hidden_states_bsz_T_d = combined_output["hidden_states"][-1]
            prefix_hidden_states_bsz_T_d = combined_hidden_states_bsz_T_d[:bsz]
            suffix_hidden_states_bsz_T_d = combined_hidden_states_bsz_T_d[bsz:]

        if track_memory_finegrained:
            memory_stats.update({f"after_model_fwd/{k}": v for k, v in check_memory().items()})

        #### Process the features in prep for the loss ###############################################
        if not fixed_length:
            # Drop pad-token hidden states; the loss must not see them.
            if bsz > 1:
                prefix_mask = (prefix_idx != self.tokenizer.eos_id) if not self.keep_eos else prefix_attn_mask
                prefix_hidden_states_bsz_T_d = [prefix_hidden_states_bsz_T_d[i][prefix_mask[i]] for i in range(bsz)]
            if suffix_bsz > 1:
                suffix_mask = (suffix_idx != self.tokenizer.eos_id) if not self.keep_eos else suffix_attn_mask
                suffix_hidden_states_bsz_T_d = [
                    suffix_hidden_states_bsz_T_d[i][suffix_mask[i]] for i in range(suffix_bsz)
                ]

        if self.add_suf_pre_tokens:
            # Remove the meta-token hidden states. They were causally injected into the rest of the
            # sequence, so the meta tokens still contribute to the loss.
            prefix_hidden_states_bsz_T_d = [x_T_d[1:, :] for x_T_d in prefix_hidden_states_bsz_T_d]
            suffix_hidden_states_bsz_T_d = [x_T_d[1:, :] for x_T_d in suffix_hidden_states_bsz_T_d]

        if suffix_input_ids is None:
            # Per-token pairing: drop the last position on both sides. After the suffix is re-reversed
            # below, prefix row i summarizes tokens[:i+1] and suffix row i summarizes tokens[i+1:].
            prefix_hidden_states_bsz_T_d = [x_T_d[:-1, :] for x_T_d in prefix_hidden_states_bsz_T_d]
            suffix_hidden_states_bsz_T_d = [x_T_d[:-1, :] for x_T_d in suffix_hidden_states_bsz_T_d]
        if mean_pooling:
            prefix_hidden_states_bsz_T_d = [x_T_d.mean(dim=0).unsqueeze(0) for x_T_d in prefix_hidden_states_bsz_T_d]
            suffix_hidden_states_bsz_T_d = [x_T_d.mean(dim=0).unsqueeze(0) for x_T_d in suffix_hidden_states_bsz_T_d]
        if suffix_input_ids is None:
            # undo the reversal done before the forward pass so suffix rows line up with prefix rows
            suffix_hidden_states_bsz_T_d = [torch.flip(x_T_d, [0]) for x_T_d in suffix_hidden_states_bsz_T_d]
        # Separate suffix inputs (e.g. query-doc finetuning): keep the last token (or the mean-pooled vector).
        prefix_model_output_bsz_d = (
            prefix_hidden_states_bsz_T_d
            if suffix_input_ids is None or mean_pooling
            else [x_T_d[-1, :].unsqueeze(0) for x_T_d in prefix_hidden_states_bsz_T_d]
        )
        suffix_model_output_bsz_d = (
            suffix_hidden_states_bsz_T_d
            if suffix_input_ids is None or mean_pooling
            else [x_T_d[-1, :].unsqueeze(0) for x_T_d in suffix_hidden_states_bsz_T_d]
        )

        outputs = {
            "loss": None,
            "accuracy": None,
        }
        if self.objective is not None:
            if self.negatives_cross_device:
                if track_memory_finegrained:
                    memory_stats.update({f"before_dist_gather/{k}": v for k, v in check_memory().items()})
                all_gathered_suffixes, t_sizes = self._dist_gather_tensor(
                    suffix_model_output_bsz_d, bsz, pad_tensor=suffix_input_ids is None
                )
                if track_memory_finegrained:
                    memory_stats.update({f"after_dist_gather/{k}": v for k, v in check_memory().items()})
                loss, accuracy = self.objective.distributed_loss(
                    [prefix_model_output_bsz_d, all_gathered_suffixes],
                    t_sizes=t_sizes,
                    group_rank=self.grp_rank,
                    track_memory_finegrained=track_memory_finegrained,
                    memory_stats=memory_stats,
                )
                if track_memory_finegrained:
                    memory_stats.update({f"after_loss_fn/{k}": v for k, v in check_memory().items()})
                outputs["loss"], outputs["accuracy"] = loss, accuracy
            else:
                outputs = self.objective([prefix_model_output_bsz_d, suffix_model_output_bsz_d])

        if track_memory_finegrained:
            outputs["memory_stats"] = memory_stats
        return outputs

    @torch._dynamo.disable(recursive=True)
    def _dist_gather_tensor(
        self, t: Optional[Union[torch.Tensor, List[torch.Tensor]]], bsz: int, pad_tensor: bool = True
    ):
        """All-gather ``t`` across this rank's ``negatives_cross_device_group``.

        Args:
            t: a tensor, or a list of tensors concatenated along dim 0 first. Rows are embeddings.
            bsz: micro batch size (NOTE: not the contrastive batch size). With ``pad_tensor``,
                every rank pads its rows to ``bsz * self.max_seq_length``.
            pad_tensor: pad before gathering, because ranks may hold different row counts and
                all_gather needs identical shapes.

        Returns:
            ``(all_tensors, all_t_sizes)``: every group member's rows concatenated in group-rank
            order with padding removed, and a 1-D tensor of each member's row count.
            ``(None, None)`` when ``t`` is ``None``.
        """
        if t is None:
            return None, None
        if isinstance(t, list):
            t = torch.cat(t, dim=0)

        # HACK 1: the token/contrastive batch size differs across devices, so record each rank's row
        # count. This locates each rank's rows in the gathered tensor (used to pair prefixes with
        # suffixes and to build labels).
        t_size = torch.tensor([t.size(0)], device=t.device)

        # HACK 2: pad to a common size on every device so all_gather won't complain.
        if pad_tensor:
            pad_len = (bsz * self.max_seq_length) - t.size(0)
            t = torch.nn.functional.pad(t, (0, 0, 0, pad_len), mode="constant", value=torch.nan)

        # GATHER 1: autograd-aware gather of the features within the group
        all_tensors = dist.nn.functional._AllGather.apply(self.negatives_cross_device_group, t)

        # GATHER 2: gather the row counts within the group
        all_t_sizes = [torch.empty_like(t_size) for _ in range(self.negatives_cross_device_group_size)]
        dist.all_gather(all_t_sizes, t_size, group=self.negatives_cross_device_group)
        all_t_sizes[self.grp_rank] = t_size
        all_t_sizes = torch.cat(all_t_sizes, dim=0)

        if pad_tensor:
            # Unpad each rank's chunk using its known row count. Filtering NaN rows instead would also
            # drop genuine NaN embeddings and desynchronize the rows from `all_t_sizes`.
            all_tensors = torch.cat(
                [chunk[:n_rows] for chunk, n_rows in zip(all_tensors, all_t_sizes.tolist())], dim=0
            )
        else:
            all_tensors = torch.cat(all_tensors, dim=0)

        return all_tensors, all_t_sizes

    @torch.no_grad()
    def encode(
        self,
        sentences: Union[List[str], str],
        batch_size: int = 256,
        max_length: int = 512,
        instruction: str = "",  # ignore this
        normalize_embeddings: bool = False,
        convert_to_tensor: bool = False,
        recast: bool = False,
        encoding_mode: str = "prefix",
        add_bos: bool = False,
        add_eos: bool = False,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        pooling_method: str = "lasttoken",  # One of ['cls', 'lasttoken', 'mean', 'weightedmean']
        **kwargs,
    ) -> np.ndarray:
        """Embed ``sentences`` with the prefix model (``encoding_mode="prefix"``) or suffix model.

        Args:
            sentences: one string or a list of strings.
            batch_size: sentences per forward pass (multiplied by the visible CUDA device count).
            max_length: tokenization / truncation length.
            instruction: optional ``str.format`` template with a ``{text}`` field.
            normalize_embeddings: L2-normalize the pooled embeddings.
            convert_to_tensor: return a torch tensor instead of a float32 NumPy array.
            recast: cast pooled embeddings back to the hidden-state dtype.
            encoding_mode: ``"prefix"`` or ``"suffix"``.
            add_bos, add_eos: forwarded to the tokenizer.
            device: device the input ids are moved to.
            pooling_method: see ``pooling``.

        Returns:
            Embeddings of shape ``[len(sentences), d]`` (a single row when ``sentences`` was a str).

        Raises:
            ValueError: unknown ``encoding_mode``.
        """
        print(f"#########INSIDE ENCODE METHOD: encoding_mode: {encoding_mode}############")
        print(f"Instruction: {instruction}")
        if encoding_mode not in ("prefix", "suffix"):
            raise ValueError(f"Unknown encoding_mode: {encoding_mode!r} (expected 'prefix' or 'suffix')")
        num_gpus = torch.cuda.device_count()
        if num_gpus > 1:
            batch_size *= num_gpus

        input_was_string = False
        if isinstance(sentences, str):
            sentences = [sentences]
            input_was_string = True

        if len(sentences) == 0:
            empty = torch.empty((0, self.config.n_embd))
            return empty if convert_to_tensor else empty.numpy()

        # With meta tokens, inputs are truncated to this length after the meta token is prepended.
        # Kept separate from `max_length` so every batch is tokenized with the caller's limit.
        trunc_length = 2048 if self.add_suf_pre_tokens else max_length

        def pad_right(x, pad_id, max_len):
            # pad right based on the longest sequence
            n = max_len - len(x)
            return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))

        all_embeddings = []
        first_print = True

        if self.tokenizer.pad_id is None:
            self.tokenizer.pad_id = -1
        for start_index in tqdm(range(0, len(sentences), batch_size), desc="Batches", disable=len(sentences) < 256):
            sentences_batch = sentences[start_index : start_index + batch_size]
            if first_print:
                print(instruction.format(text=sentences_batch[0]) if instruction else sentences_batch[0])
                first_print = False
            sentences_batch = [
                self.tokenizer.encode(
                    instruction.format(text=s) if instruction else s, eos=add_eos, bos=add_bos, max_length=max_length
                )
                for s in sentences_batch
            ]
            input_ids_len = [len(s) for s in sentences_batch]
            max_len_sentence = max(input_ids_len)

            input_ids = torch.stack(
                [pad_right(x, pad_id=self.tokenizer.pad_id, max_len=max_len_sentence) for x in sentences_batch]
            ).to(device)
            input_ids = input_ids[:, :max_length]  # truncate to max_length

            attn_mask = (input_ids != self.tokenizer.pad_id).float()
            # Padding flags aligned with `input_ids`, taken before any meta-token column is prepended.
            pad_mask = attn_mask == 0
            # replace pad tokens with eos tokens for the forward pass
            input_ids[input_ids == self.tokenizer.pad_id] = self.tokenizer.eos_id

            if self.add_suf_pre_tokens:
                prefix_attn_mask_add = torch.tensor([[True]]).repeat(attn_mask.size(0), 1)
                attn_mask = torch.cat([prefix_attn_mask_add.to(attn_mask.device), attn_mask], dim=1)
                if encoding_mode == "prefix":
                    prefix_token_idx = torch.tensor([[self.tokenizer.prefix_token_id]]).repeat(input_ids.size(0), 1)
                    input_ids = torch.cat((prefix_token_idx.to(input_ids.device), input_ids), dim=1)
                if input_ids.size(1) > trunc_length:
                    input_ids = input_ids[:, :trunc_length]
                    attn_mask = attn_mask[:, :trunc_length]

            if encoding_mode == "prefix":
                outputs = self.prefix_model(input_ids, output_hidden_states=True)
            else:
                if batch_size > 1:
                    # Padding may be present: move pad tokens to the front so they sit at the end after
                    # the sequence is reversed. Slice to the current width so gather indices stay in range.
                    sorted_indices = torch.argsort(
                        pad_mask[:, : input_ids.size(1)].cpu(), descending=True, stable=True
                    )
                    input_ids = torch.gather(input_ids, 1, sorted_indices.to(input_ids.device))

                input_ids = torch.flip(input_ids, [1])
                if self.add_suf_pre_tokens:
                    # add suffix token at the beginning of the sequence
                    suffix_token_idx = torch.tensor([[self.tokenizer.suffix_token_id]]).repeat(input_ids.size(0), 1)
                    input_ids = torch.cat((suffix_token_idx.to(input_ids.device), input_ids), dim=1)

                if input_ids.size(1) > trunc_length:
                    input_ids = input_ids[:, :trunc_length]
                    attn_mask = attn_mask[:, :trunc_length]

                outputs = self.suffix_model(input_ids, output_hidden_states=True, is_suffix=True)
            last_hidden_state = outputs["hidden_states"][-1]

            embeddings = self.pooling(last_hidden_state, attn_mask, recast=recast, pooling_method=pooling_method)
            # Normalize can change the dtype (https://discuss.pytorch.org/t/tensor-in-float16-is-transformed-into-float32-after-torch-norm/110891)
            if normalize_embeddings:
                in_dtype = embeddings.dtype
                embeddings = torch.nn.functional.normalize(embeddings, dim=-1).to(in_dtype)
            embeddings = cast(torch.Tensor, embeddings)
            if convert_to_tensor:
                all_embeddings.append(embeddings)
            else:
                # NumPy does not support bfloat16
                all_embeddings.append(embeddings.cpu().to(torch.float32).numpy())

        all_embeddings = (
            torch.cat(all_embeddings, dim=0) if convert_to_tensor else np.concatenate(all_embeddings, axis=0)
        )
        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings

    def encode_queries(self, queries: Union[List[str], str], **kwargs) -> np.ndarray:
        """Embed queries with the prefix model."""
        print("#########INSIDE ENCODE QUERIES METHOD############")
        return self.encode(queries, encoding_mode="prefix", **kwargs)

    def encode_corpus(self, corpus: list[str] | list[dict[str, str]], **kwargs) -> np.ndarray:
        """Embed documents with the suffix model.

        Dict documents are flattened to ``"<title> <text>"`` (or just ``text`` without a title).
        """
        print("#########INSIDE ENCODE CORPUS METHOD############")
        if isinstance(corpus, dict):
            corpus = [corpus]
        if isinstance(corpus, list) and len(corpus) > 0 and isinstance(corpus[0], dict):
            corpus = [doc["title"] + " " + doc["text"] if "title" in doc else doc["text"] for doc in corpus]
        return self.encode(corpus, encoding_mode="suffix", **kwargs)

    def pooling(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor = None,
        recast: bool = False,
        pooling_method: str = "lasttoken",
    ) -> torch.Tensor:
        """Reduce per-token features to one vector per sequence.

        Args:
            hidden_state: [b, n, d]
            attention_mask: [b, n], non-zero on real tokens. Not modified.
            recast: cast the result back to ``hidden_state``'s dtype.
            pooling_method: ``"lasttoken"``, ``"mean"`` or ``"weightedmean"``.

        Returns:
            [b, d] embeddings.

        Raises:
            NotImplementedError: unknown ``pooling_method``.
        """
        # In case the model is distributed across multiple devices; hidden_state may end up on diff device
        hidden_state = hidden_state.to(attention_mask.device)
        if pooling_method == "lasttoken":
            b, n, d = hidden_state.size()
            # Index of the last non-zero mask entry of each row. The simpler
            # `argmin(attention_mask, 1) - 1` fails when 1) the row is all 1s or 2) 0s precede the 1s.
            reversed_mask = torch.flip(attention_mask, dims=(1,))
            argmax_reverse = torch.argmax(reversed_mask, dim=1, keepdim=False)
            gather_indices = attention_mask.size(1) - argmax_reverse - 1
            # Empty sequences would give -1, which crashes gather; clamp them to 0.
            gather_indices = torch.clamp(gather_indices, min=0)
            # Turn indices from shape [b] -> [b, 1, d]
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, d)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (b, 1, d)
            # Gather along seq len: [b, n, d] -> [b, d]. Masking zeroes the rows that clamping
            # pointed at index 0 for empty sequences.
            input_mask_expanded = attention_mask.unsqueeze(-1).expand((b, n, d)).float()
            embedding = torch.gather(hidden_state * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
        elif pooling_method in ["mean", "weightedmean"]:
            if pooling_method == "weightedmean":
                # out-of-place so the caller's mask is left untouched
                attention_mask = attention_mask * attention_mask.cumsum(dim=1)  # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
            s = torch.sum(hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
            denom = attention_mask.sum(dim=1, keepdim=True).float()
            embedding = s / denom
        else:
            raise NotImplementedError(f"Unknown pooling method: {pooling_method}")
        # Recasting performs slightly worse but saves 50% space
        if recast:
            return embedding.to(hidden_state.dtype)
        return embedding


class GPTRetrieval(GPT):
    """GPT backbone for PSLM that returns hidden states instead of logits.

    ``self.transformer`` is rebuilt as a ModuleDict holding the token embedding (``wte``),
    ``config.n_layer`` ``Block``s (``h``) and the final norm (``ln_f``). On suffix passes
    with ``flip_rope_embedding_suffix`` set, the cos/sin tables are reversed along the
    sequence dimension.
    """

    def __init__(
        self,
        config: Config,
        gradient_checkpointing: bool = False,
        flip_rope_embedding_suffix: bool = False,
        add_suf_pre_tokens: bool = False,
        nope_pos_embeddings: bool = False,
    ) -> None:
        super().__init__(config, objective=None)
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None
        self.gradient_checkpointing = gradient_checkpointing
        self.flip_rope_embedding_suffix = flip_rope_embedding_suffix
        self.add_suf_pre_tokens = add_suf_pre_tokens
        self.nope_pos_embeddings = nope_pos_embeddings

    def forward(
        self,
        idx: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        is_suffix: bool = False,
    ) -> dict:
        """Run the transformer blocks over ``idx`` and return hidden states.

        Args:
            idx: token ids; dim 1 is the sequence length.
            position_ids: when given, selects rows of the cos/sin tables and of ``mask_cache``.
            output_hidden_states: also return the input of every block.
            is_suffix: marks a suffix pass (enables the cos/sin flip when configured).

        Returns:
            ``{"hidden_states": tuple}``. The last element is always the final normed output.
            With ``output_hidden_states``, it is preceded by the input of every block.

        Raises:
            ValueError: the sequence is longer than ``self.max_seq_length``.
            TypeError: ``position_ids`` is given but ``mask_cache`` has not been set.
        """
        T = idx.size(1)  # sequence length
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        if position_ids is not None:  # use the kv cache
            cos = self.cos.index_select(0, position_ids)
            sin = self.sin.index_select(0, position_ids)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, position_ids)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        # Reverse the rotary tables for suffix passes. Done after the selection above so it applies
        # to both branches (moving it into them would be faster).
        if is_suffix and self.flip_rope_embedding_suffix:
            cos = torch.flip(cos, [0])
            sin = torch.flip(sin, [0])

        all_hidden_states = () if output_hidden_states else None
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        if self.config.scale_embeddings:
            x = x * (self.config.n_embd**0.5)

        for block in self.transformer.h:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (x,)
            if not self.gradient_checkpointing:
                x = block(x, cos, sin, mask, position_ids)
            else:
                x = self.config.checkpoint(block, x, cos, sin, mask, position_ids)
        x = self.transformer.ln_f(x)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (x,)
        else:
            all_hidden_states = (x,)

        return {"hidden_states": all_hidden_states}
