import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

from src.nlp.models.LlamaModel import LlamaModel


def precompute_llama_embeddings(
    model: LlamaModel,
    train_dataloader,
    val_dataloader,
):
    """Precompute and save backbone outputs for the "train" and "val" splits.

    Requires an initialized ``torch.distributed`` process group: each rank
    processes its own dataloaders and writes its own per-rank file.

    Args:
        model: Model whose backbone produces the embeddings.
        train_dataloader: Loader yielding ``(texts, target_indices)`` batches
            for the training split.
        val_dataloader: Loader yielding ``(texts, target_indices)`` batches
            for the validation split.
    """
    assert dist.is_initialized()
    splits = (("train", train_dataloader), ("val", val_dataloader))
    for split_name, loader in splits:
        _precompute_embeddings(model, loader, split_name)


def load_precomputed_llama_embeddings_dataset(split: str):
    """Load this rank's precomputed tensors for ``split`` as a ``TensorDataset``.

    The file is read from the current working directory and is named after
    the split and the process's ``torch.distributed`` rank, so the process
    group must already be initialized.

    Args:
        split: Either ``"train"`` or ``"val"``.

    Returns:
        A ``TensorDataset`` whose items are
        ``(embedding, attention_mask, target_ids, target_index)`` tuples.

    Raises:
        AssertionError: If ``split`` is not ``"train"`` or ``"val"``.
    """
    assert split in ("train", "val"), f"Invalid split name: {split}"
    filename = f"llama-next-token-precomputed-{split}-rank-{dist.get_rank()}.pt"
    saved = torch.load(filename)
    # Field order defines the tuple layout of each dataset item.
    field_order = ("embeddings", "attention_masks", "target_ids", "target_indices")
    return TensorDataset(*(saved[name] for name in field_order))


def _flush_to_cpu(pending: dict, collected: dict) -> None:
    """Move each key's buffered tensors to CPU as one chunk and clear the buffer.

    For every key whose list in ``pending`` is non-empty, the tensors are
    concatenated along dim 0, copied to CPU, appended to ``collected[key]``,
    and the pending list is emptied in place.
    """
    for key, buffer in pending.items():
        if buffer:
            collected[key].append(torch.cat(buffer, dim=0).cpu())
            buffer.clear()


def _precompute_embeddings(
    model: LlamaModel,
    dataloader,
    suffix: str,
    flush_every: int = 50,
):
    """Run the backbone over ``dataloader`` and save its outputs for this rank.

    Each batch's texts get the tokenizer's EOS token appended, are tokenized,
    and are split into shifted inputs (all tokens but the last) and
    next-token targets (all tokens but the first). Target positions that are
    padding are set to -100. The backbone's ``last_hidden_state`` for the
    shifted inputs is collected together with the shifted attention mask, the
    targets and the batch's ``target_indices``. The result is written to
    ``llama-next-token-precomputed-{suffix}-rank-{rank}.pt`` in the current
    working directory.

    Side effect: ``model`` and ``model.backbone`` are put into eval mode and
    are left that way.

    Args:
        model: Model providing ``tokenizer``, ``_tokenize``, ``device`` and
            ``backbone``.
        dataloader: Iterable of ``(texts, target_indices)`` batches, where
            ``target_indices`` must be a tensor.
        suffix: Split name embedded in the output filename.
        flush_every: Number of batches to buffer on the model's device before
            concatenating them and moving them to CPU. This is a trade-off
            between device memory and transfer overhead; it does not affect
            the saved result.

    Raises:
        ValueError: If ``flush_every`` is smaller than 1.
        RuntimeError: If ``dataloader`` yields no batches.
    """
    if flush_every < 1:
        raise ValueError(f"flush_every must be >= 1, got {flush_every}")

    # Key order is the order of the entries in the saved dict.
    keys = ("embeddings", "attention_masks", "target_ids", "target_indices")
    pending = {key: [] for key in keys}  # tensors still on the model's device
    collected = {key: [] for key in keys}  # concatenated CPU chunks

    model.eval()
    model.backbone.eval()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            texts, target_indices = batch

            texts = [t + model.tokenizer.eos_token for t in texts]
            inputs = model._tokenize(texts)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            input_ids = inputs["input_ids"]
            attention_mask = inputs["attention_mask"]

            # Explicitly save the target ids to use for next token prediction.
            target_ids = input_ids[:, 1:]
            target_mask = attention_mask[:, 1:]
            target_ids = target_ids.masked_fill(target_mask == 0, -100)

            # Shift inputs here s.t. the produced embeddings correspond
            # to what we'd have when using the backbone directly.
            input_ids = input_ids[:, :-1]
            attention_mask = attention_mask[:, :-1]

            # Run forward pass with backbone only.
            embeddings = model.backbone(
                input_ids=input_ids, attention_mask=attention_mask
            ).last_hidden_state

            pending["embeddings"].append(embeddings.detach())
            pending["attention_masks"].append(attention_mask.detach())
            pending["target_ids"].append(target_ids.detach())
            pending["target_indices"].append(target_indices.detach())

            # Periodically move buffered device tensors to CPU.
            if (i + 1) % flush_every == 0:
                _flush_to_cpu(pending, collected)

        # Handle the remainder that did not fill a whole flush interval.
        _flush_to_cpu(pending, collected)

    if not collected["embeddings"]:
        raise RuntimeError(
            f"Cannot precompute '{suffix}' embeddings: the dataloader yielded no batches."
        )

    # NOTE(review): concatenating along dim 0 across batches requires every
    # batch to share the same padded sequence length; this presumably relies
    # on model._tokenize padding to a fixed length — confirm.
    data = {key: torch.cat(chunks, dim=0) for key, chunks in collected.items()}
    # Note: masked target positions are already set to -100 in data["target_ids"].
    torch.save(data, f"llama-next-token-precomputed-{suffix}-rank-{dist.get_rank()}.pt")