# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""Definition of ``PrefixNet``, a wrapper that embeds prefix and suffix token sequences
with a shared ``GPTRetrieval`` model, all of it in this single file.

Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
"""

import math
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from collections.abc import Callable

import numpy as np
import torch
import torch.nn as nn
from typing_extensions import Self

import torch.distributed as dist

from tqdm import tqdm

from litgpt.model import GPT, Block
from litgpt.config import Config
from litgpt.tokenizer import Tokenizer

from litgpt.retrieval_attn_utils import get_ltor_masks_and_position_ids
from litgpt.multiple_negative_ranking_loss import cos_sim
from litgpt.retrieval_model import GPTRetrieval


class PrefixNet(nn.Module):
    """Embed prefix and suffix token sequences with one shared ``GPTRetrieval`` model.

    ``forward`` mean-pools the last hidden layer of every sequence and hands the
    pooled prefix/suffix representations to ``objective``. ``encode`` turns raw
    sentences into pooled embeddings without tracking gradients.
    """

    def __init__(
        self,
        config: "Config",
        objective: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        gradient_checkpointing: bool = False,
        checkpoint_dir: Optional[str] = None,
        keep_eos: bool = False,
        nope_pos_embeddings: bool = False,
    ) -> None:
        """Build the shared encoder.

        Args:
            config: Model configuration. Its ``nope_pos_embeddings`` attribute is
                overwritten in place with ``nope_pos_embeddings``.
            objective: Callable invoked by ``forward`` with a single argument, the
                list ``[prefix_embeddings, suffix_embeddings]``. May be ``None``,
                in which case ``forward`` returns a dict with empty entries.
            gradient_checkpointing: Forwarded to ``GPTRetrieval``.
            checkpoint_dir: Forwarded to ``GPTRetrieval``.
            keep_eos: If ``True``, ``forward`` selects tokens with the supplied
                attention masks instead of dropping every EOS token.
            nope_pos_embeddings: Value written to ``config.nope_pos_embeddings``.
        """
        super().__init__()
        # NOTE: this mutates the caller's config object, exactly as before.
        config.nope_pos_embeddings = bool(nope_pos_embeddings)
        self.prefix_model = GPTRetrieval(config, checkpoint_dir, gradient_checkpointing)  # model with causal attention

        self.objective = objective
        self.max_seq_length = config.block_size
        self.keep_eos = keep_eos

    def _mean_pool_sequences(
        self,
        hidden_states_bsz_T_d: torch.Tensor,
        token_ids: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        bsz: int,
    ) -> List[torch.Tensor]:
        """Mean-pool each sequence of a batch into its own ``[1, d]`` tensor."""
        if bsz > 1:
            # Drop the pad tokens' hidden states as we don't want to calculate loss on them.
            if self.keep_eos:
                if attn_mask is None:
                    raise ValueError("attn_mask is required when keep_eos=True and the batch size is > 1")
                # Cast so 0/1 integer or float masks select tokens instead of being
                # interpreted as row indices (or rejected) by tensor indexing.
                keep = attn_mask.bool()
            else:
                keep = token_ids != self.prefix_model.tokenizer.eos_id
            sequences = [hidden_states_bsz_T_d[i][keep[i]] for i in range(bsz)]
        else:
            # A single sequence is used unfiltered (iterating the batch dim yields [T, d]).
            sequences = hidden_states_bsz_T_d
        # Mean-pool representation of every sequence.
        return [x_T_d.mean(dim=0).unsqueeze(0) for x_T_d in sequences]

    def forward(
        self,
        idx: Tuple[torch.Tensor, torch.Tensor],
        attn_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> dict:
        """Encode a (prefix, suffix) batch and evaluate the objective on the pooled outputs.

        Args:
            idx: ``(prefix_ids, suffix_ids)``; the prefix ids must be 2-D ``[bsz, T]``.
            attn_mask: ``(prefix_mask, suffix_mask)``. Only consulted when
                ``keep_eos`` is set and ``bsz > 1``; may be ``None`` otherwise.
            position_ids: Passed unchanged to both model calls.
            output_hidden_states: Passed unchanged to both model calls. The model
                output must expose ``"hidden_states"``.
                # NOTE(review): presumably this has to be truthy for that key to exist - confirm.

        Returns:
            Whatever ``objective`` returns, or ``{"loss": None, "accuracy": None}``
            when no objective is configured.

        Raises:
            ValueError: If ``keep_eos`` is set, ``bsz > 1`` and a mask is missing.
        """
        prefix_idx, suffix_input_ids = idx
        prefix_attn_mask, suffix_attn_mask = attn_mask if attn_mask is not None else (None, None)
        prefix_output = self.prefix_model(prefix_idx, position_ids, output_hidden_states=output_hidden_states)
        suffix_output = self.prefix_model(suffix_input_ids, position_ids, output_hidden_states=output_hidden_states)
        bsz, _ = prefix_idx.size()

        prefix_model_output_bsz_d = self._mean_pool_sequences(
            prefix_output["hidden_states"][-1], prefix_idx, prefix_attn_mask, bsz
        )
        suffix_model_output_bsz_d = self._mean_pool_sequences(
            suffix_output["hidden_states"][-1], suffix_input_ids, suffix_attn_mask, bsz
        )

        outputs = {
            "loss": None,
            "accuracy": None,
        }
        if self.objective is not None:
            outputs = self.objective([prefix_model_output_bsz_d, suffix_model_output_bsz_d])

        return outputs

    def encode(
        self,
        sentences: Union[List[str], Tuple[str], List[Dict], str],
        max_length: Optional[int] = None,
        add_eos: bool = False,
        to_numpy: bool = True,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        batch_size: int = 1024,
        pooling_method: str = "mean",
        **kwargs,
    ):
        """Tokenize ``sentences`` and return one pooled embedding per sentence.

        Args:
            sentences: A single item or a list/tuple of items accepted by the tokenizer.
            max_length: Forwarded to the tokenizer and used to truncate the padded batch.
            add_eos: Forwarded to the tokenizer as ``eos``.
            to_numpy: Return a float32 NumPy array instead of a CPU tensor.
            device: Device the token ids are moved to before the forward pass.
            batch_size: Number of sentences per forward pass.
            pooling_method: Forwarded to ``pooling``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Embeddings stacked along dim 0, as ``numpy.ndarray`` or ``torch.Tensor``.

        Raises:
            ValueError: If ``sentences`` is an empty list or tuple.

        Side effects: switches the model to eval mode and sets the tokenizer's
        ``pad_id`` to ``-1`` when it was ``None``.
        """
        print("#########INSIDE LM-ENCODE METHOD############")
        self.prefix_model.eval()
        tokenizer = self.prefix_model.tokenizer
        if tokenizer.pad_id is None:
            tokenizer.pad_id = -1
        pad_id = tokenizer.pad_id

        if not isinstance(sentences, (tuple, list)):
            sentences = [sentences]
        if not sentences:
            raise ValueError("`sentences` must contain at least one item")

        token_ids = [tokenizer.encode(s, eos=add_eos, max_length=max_length) for s in sentences]
        # Pad right based on the longest sequence.
        max_len_sentence = max(len(ids) for ids in token_ids)
        all_input_ids = torch.stack(
            [torch.cat((ids, ids.new_full((max_len_sentence - len(ids),), pad_id))) for ids in token_ids]
        )
        all_input_ids = all_input_ids[:, :max_length]  # truncate to max_length

        outputs = []
        with torch.no_grad():
            # tqdm infers the number of batches from the range; a floor division would under-count it.
            for i in tqdm(range(0, len(all_input_ids), batch_size), desc="Encoding..."):
                if i == 0:
                    print(sentences[i])
                input_ids = all_input_ids[i:i + batch_size].to(device)
                # Build the mask before the pad ids are overwritten below.
                attn_mask = (input_ids != pad_id).float()
                input_ids[input_ids == pad_id] = tokenizer.eos_id  # replace pad tokens with eos tokens for forward pass
                output = self.prefix_model(input_ids, output_hidden_states=True)
                last_hidden_state = output["hidden_states"][-1]
                embeddings = self.pooling(last_hidden_state, attn_mask, pooling_method=pooling_method)
                outputs.append(embeddings.cpu())

        output = torch.cat(outputs, dim=0)
        if to_numpy:
            return output.float().detach().cpu().numpy()
        return output

    def pooling(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        recast: bool = False,
        pooling_method: str = "lasttoken",
    ) -> torch.Tensor:
        """Reduce per-token hidden states to one vector per sequence.

        Args:
            hidden_state: [b, n, d]
            attention_mask: [b, n], non-zero where a token should be used. ``None``
                means every position is used. The tensor is never modified.
            recast: Cast the result back to ``hidden_state``'s dtype.
            pooling_method: ``"lasttoken"``, ``"mean"`` or ``"weightedmean"``.

        Returns:
            Tensor of shape [b, d]. Fully masked rows pool to zeros.

        Raises:
            NotImplementedError: For an unknown ``pooling_method``.
        """
        if attention_mask is None:
            attention_mask = torch.ones(hidden_state.shape[:2], dtype=torch.float, device=hidden_state.device)
        # In case the model is distributed across multiple devices; hidden_state may end up on diff device
        hidden_state = hidden_state.to(attention_mask.device)
        if pooling_method == 'lasttoken':
            # Get the last `1` in the attention mask of each item
            # Often it is just `torch.argmin(attention_mask, 1, keepdim=False) - 1`
            # except when 1) There's all 1's 2) There's 0's before the 1's
            reversed_mask = torch.flip(attention_mask, dims=(1,))
            last_index = attention_mask.size(1) - torch.argmax(reversed_mask, dim=1, keepdim=False) - 1
            last_index = torch.clamp(last_index, min=0)
            rows = torch.arange(hidden_state.size(0), device=hidden_state.device)
            # The picked position can be a masked one (e.g. an all-zero row), so the
            # selected vector is still multiplied by its mask value.
            picked_mask = attention_mask[rows, last_index].unsqueeze(-1).float()
            embedding = hidden_state[rows, last_index] * picked_mask
        elif pooling_method in ('mean', 'weightedmean'):
            weights = attention_mask
            if pooling_method == 'weightedmean':
                # [0,1,1,1,0,0] -> [0,1,2,3,0,0]; out of place so the caller's mask survives
                weights = attention_mask * attention_mask.cumsum(dim=1)
            weights = weights.float()
            summed = torch.sum(hidden_state * weights.unsqueeze(-1), dim=1)
            # Clamp so a fully masked row yields zeros rather than NaN.
            total_weight = weights.sum(dim=1, keepdim=True).clamp(min=1e-9)
            embedding = summed / total_weight
        else:
            raise NotImplementedError(f"Unknown pooling method: {pooling_method}")
        # Recasting performs slightly worse but saves 50% space
        if recast:
            return embedding.to(hidden_state.dtype)
        return embedding