# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.

from functools import partial
import numpy as np
import os
import time
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, Subset
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm

from megatron.training import get_args, get_tokenizer, print_rank_0
from megatron import core
from megatron.training.arguments import core_transformer_config_from_args
from megatron.core.datasets.retro.utils import get_blocks_by_rank
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.legacy.model import BertModel
from megatron.training.training import setup_model_and_optimizer
from pretrain_bert import model_provider, get_batch, loss_func, forward_step

from .dataset import BertEmbeddingDataset
from .external_libs import h5py
from .huggingface import HuggingfaceEmbedder


def collate_batch(samples):
    """Merge samples of differing lengths into a single batch.

    Each numpy array field is padded at its end, up to the longest array
    seen for that field across ``samples``: 'text' arrays are filled with
    the tokenizer's pad_id, every other array with 0. Fields that are not
    numpy arrays are passed through unchanged. The padded samples are then
    combined with torch's ``default_collate``.

    The field names are taken from the first sample, so every sample is
    expected to carry the same keys.
    """

    field_names = list(samples[0].keys())
    tokenizer = get_tokenizer()

    # Longest array per field; None flags a field that is not an array.
    longest = dict.fromkeys(field_names, 0)
    for sample in samples:
        for name in field_names:
            value = sample[name]
            if isinstance(value, np.ndarray):
                longest[name] = max(longest[name], len(value))
            else:
                longest[name] = None

    def pad_field(name, value):
        """Pad an array field to the batch-wide length recorded for `name`."""
        if not isinstance(value, np.ndarray):
            return value
        pad_width = (0, longest[name] - len(value))
        fill_value = tokenizer.pad_id if name == "text" else 0
        return np.pad(value, pad_width, mode="constant",
                      constant_values=fill_value)

    padded_samples = [
        {name: pad_field(name, sample[name]) for name in field_names}
        for sample in samples
    ]

    # Stack the now equal-length samples into batch tensors.
    return default_collate(padded_samples)


def get_data_loader(dataset, batch_size):
    """Build a sequential, batched data loader over a dataset.

    Samples are visited in their original order and grouped into batches of
    ``batch_size``; the final batch is kept even when it is smaller. Batches
    are assembled with ``collate_batch``, and the number of worker processes
    is read from the global args (``num_workers``).
    """

    args = get_args()

    # Visit samples in order, keeping the trailing partial batch.
    in_order = SequentialSampler(dataset)
    batches = BatchSampler(
        sampler=in_order,
        batch_size=batch_size,
        drop_last=False,
    )

    return DataLoader(
        dataset,
        batch_sampler=batches,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=collate_batch,
    )


def embed_data_loader(models, data_loader, tag):
    '''Run the model over every batch of a data loader and stack the results.

    All entries of `models` are switched to eval mode, but only `models[0]`
    is used for the forward pass. For each batch, the first element returned
    by `forward_step` is moved to the CPU and converted to a numpy array;
    the per-batch arrays are concatenated along axis 0 and returned.

    `tag`, when not None, is appended to the progress bar label. The progress
    bar is only shown on rank 0.
    '''

    # forward_step is invoked directly, so model parallelism is unsupported.
    args = get_args()
    no_model_parallel = (
        args.tensor_model_parallel_size == 1
        and args.pipeline_model_parallel_size == 1
    )
    assert no_model_parallel, \
        "since we call forward_step directly, only tp == pp == 1 allowed."

    batch_iterator = iter(data_loader)

    for model in models:
        model.eval()

    # Progress bar: one tick per batch.
    n_batches = len(data_loader)
    label = "  embed%s" % ("" if tag is None else " / '%s'" % tag)
    progress = tqdm(
        range(n_batches),
        desc=label,
        miniters=n_batches // 10,
        disable=torch.distributed.get_rank() != 0,
    )

    # forward_step pulls the next batch from the iterator on each call.
    chunks = []
    with torch.no_grad():
        for _ in progress:
            output = forward_step(batch_iterator, models[0])
            chunks.append(output[0].detach().cpu().numpy())

    return np.concatenate(chunks, axis=0)


class TextDataset(torch.utils.data.Dataset):
    '''Minimal dataset wrapping an in-memory list of strings.

    Each item is returned as a single-key dict: {"text": <string>}.
    '''

    def __init__(self, texts):
        # Must be a real list, and every entry must be a str.
        assert isinstance(texts, list)
        assert all(isinstance(text, str) for text in texts)
        self.texts = texts

    def __len__(self):
        '''Number of strings held.'''
        return len(self.texts)

    def __getitem__(self, i):
        '''Return the i-th string, wrapped in a dict under the "text" key.'''
        return dict(text=self.texts[i])


class BertEmbedder:
    '''Embed text with Bert, through either the Megatron or Huggingface path.

    The Megatron model is always set up at construction time. When
    `embedder_type` is "huggingface", a HuggingfaceEmbedder is created as
    well and all embedding calls are delegated to it.
    '''

    def __init__(self, batch_size, max_bert_seq_length, embedder_type, warmup=True):

        args = get_args()
        assert args.output_bert_embeddings

        self.batch_size = batch_size
        self.max_bert_seq_length = max_bert_seq_length

        # Only the model list is kept; the optimizer and scheduler returned
        # alongside it are not used by the embedder.
        self.models, _, _ = setup_model_and_optimizer(
            model_provider, ModelType.encoder_or_decoder)

        # Select the embedding backend.
        if embedder_type == "huggingface":
            self.huggingface_embedder = HuggingfaceEmbedder(
                batch_size, max_bert_seq_length)
        elif embedder_type == "megatron":
            self.huggingface_embedder = None
        else:
            raise Exception("specialize for embedder type '%s'." % embedder_type)

        if not warmup:
            return

        # Warm-up JIT. The two cases are warmed up separately:
        #   1. batch_size == 1
        #   2. batch_size > 1
        single_text = "hi, bert."
        batch_dataset = TextDataset([
            "great fleas have lesser fleas, upon their backs to bite’em,",
            "and lesser fleas have lesser fleas, and so, ad infinitum,",
            "and those great fleas, themselves, in turn have greater fleas to go on,",
            "while those again have greater still, and greater still, and so on.",
        ])

        print_rank_0("bert / warmup single.")
        for _ in range(3):
            self.embed_text(single_text)

        print_rank_0("bert / warmup batch.")
        for _ in range(3):
            self.embed_text_dataset(batch_dataset)

    def embed_text_dataset(self, text_dataset, tag=None):
        '''Embed every sample of a text dataset.

        If a Huggingface embedder is configured, the call is forwarded to it
        (`tag` is not passed along). Otherwise the samples are tokenized via
        BertEmbeddingDataset, batched, and run through the Megatron model;
        `tag` then labels the progress bar.
        '''

        # Huggingface path.
        hf_embedder = self.huggingface_embedder
        if hf_embedder:
            return hf_embedder.embed_text_dataset(text_dataset)

        # Megatron path: tokenize -> batch -> embed.
        tokenized = BertEmbeddingDataset(text_dataset,
                                         self.max_bert_seq_length)
        loader = get_data_loader(tokenized, self.batch_size)
        return embed_data_loader(self.models, loader, tag)

    def embed_text(self, text):
        '''Embed one text string and return its embedding.

        Convenience wrapper for on-the-fly use (analysis, debugging): the
        string is wrapped in a one-element TextDataset and the first row of
        the result is returned. For large scale, use 'embed_text_dataset()'.
        '''

        return self.embed_text_dataset(TextDataset([text]))[0]


class DiskDataParallelBertEmbedder:
    '''Process embeddings in blocks & save to disk.

    Args:
        embedder: BertEmbedder used to compute the embeddings.
        block_size: Number of samples per on-disk block (forwarded to
            get_blocks_by_rank).
        embedding_dim: Expected size of axis 1 of the "data" array stored in
            each block file; used when validating existing blocks. Defaults
            to 1024, the value that was previously hard-coded.
    '''

    def __init__(self, embedder, block_size, embedding_dim=1024):
        assert isinstance(embedder, BertEmbedder)
        self.embedder = embedder
        self.block_size = block_size
        self.embedding_dim = embedding_dim

    def _embed_and_save_block(self, text_dataset, block_info):
        '''Embed one block's sample range and write it to the block's file.'''

        # block_info["range"] holds the arguments for range() that select
        # this block's samples.
        sub_dataset = Subset(text_dataset, range(*block_info["range"]))
        embeddings = self.embedder.embed_text_dataset(sub_dataset)

        # The context manager closes the file even if the write raises, so a
        # failed write does not leak an open handle.
        with h5py.File(block_info["path"], "w") as f:
            f.create_dataset("data", data=embeddings)

    def embed_text_blocks(self, name, dirname, text_dataset,
                          missing_embedding_blocks):
        '''Process a text dataset in blocks.

        Args:
            name: Label used in progress messages.
            dirname: Not used by this method; kept so the signature stays
                unchanged for existing callers.
            text_dataset: Dataset the blocks index into.
            missing_embedding_blocks: List of block descriptors (dicts with
                "range" and "path" entries), possibly containing None
                placeholders.
        '''

        n_blocks = len(missing_embedding_blocks)

        # Iterate blocks.
        for block_index, block_info in enumerate(missing_embedding_blocks):

            # Missing block lists are extended with None to have equal-length
            # lists. Skip the Nones.
            if block_info is not None:

                # Progress. (*note*: move world progress to here.)
                print_rank_0("embed '%s' block %d / %d ... %s." % (
                    name,
                    block_index,
                    n_blocks,
                    block_info["path"],
                ))

                self._embed_and_save_block(text_dataset, block_info)

            # Synchronize progress across all ranks. (for easier observation)
            # This stays outside the None check on purpose: every rank must
            # reach the barrier on every iteration, including ranks whose
            # entry is a None placeholder.
            print_rank_0(" > waiting for other ranks to finish block.")
            torch.distributed.barrier()

    def embed_text_dataset(self, name, dirname, text_dataset):
        '''Embed a text dataset, writing one file per block under `dirname`.

        Args:
            name: Label used in progress messages.
            dirname: Directory holding the block files; created if missing.
            text_dataset: Dataset to embed.
        '''

        # Dataset dir.
        os.makedirs(dirname, exist_ok=True)

        # Missing embedding blocks (stored on disk).
        # NOTE(review): presumably get_blocks_by_rank calls `validate` on
        # existing block files to decide whether they are usable - confirm.
        embedding_dim = self.embedding_dim

        def validate(f):
            assert f["data"].shape[1] == embedding_dim

        blocks = get_blocks_by_rank(
            dirname,
            len(text_dataset),
            self.block_size,
            validate=validate)

        # Prevent missing file race condition.
        torch.distributed.barrier()

        # Embed batches.
        self.embed_text_blocks(name, dirname, text_dataset, blocks.missing)
