"""Utilities common to several backends."""
import torch
import transformers
from torch.utils.data import DataLoader

from contextlib import nullcontext, contextmanager


def get_num_workers(cfg_impl):
    """Return the number of dataloader worker processes to use.

    A non-positive ``cfg_impl.threads`` disables worker processes (returns 0).
    Otherwise the torch thread count is split evenly across the visible CUDA
    devices (treated as one device when none are visible) and the result is
    capped at ``cfg_impl.threads``.
    """
    requested = cfg_impl.threads
    if not requested > 0:
        return 0

    available_threads = torch.get_num_threads()
    device_count = max(1, torch.cuda.device_count())
    threads_per_device = available_threads // device_count
    return min(threads_per_device, requested)


def group_parameters(model, cfg_train):
    """Split the model parameters into optimizer groups with and without weight decay.

    Args:
        model: Module whose ``named_parameters()`` are grouped.
        cfg_train: Training config. Reads ``limited_decay_keys`` (substrings of
            parameter names that must not be decayed) and, when keys are given,
            ``optim.weight_decay``.

    Returns:
        If ``limited_decay_keys`` is empty (or ``None``): a flat list of all
        parameters. Otherwise a list of two optimizer param-group dicts: the first
        holds every parameter whose name matches none of the keys and uses
        ``cfg_train.optim.weight_decay``; the second holds the matching parameters
        with ``weight_decay`` fixed to ``0.0``.
    """
    named_parameters = list(model.named_parameters())
    limited_decay_keys = cfg_train.limited_decay_keys
    if not limited_decay_keys:
        return [param for _, param in named_parameters]

    # Single pass: a parameter is exempt from decay as soon as any key occurs in its name.
    decayed, exempt = [], []
    for name, param in named_parameters:
        if any(key in name for key in limited_decay_keys):
            exempt.append(param)
        else:
            decayed.append(param)

    return [
        {"params": decayed, "weight_decay": cfg_train.optim.weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]


def update_ema(model, ema_model, momentum=0.995):
    """Update exponential moving average in parameters and buffers.

    Each tensor of ``ema_model`` is overwritten in place with
    ``momentum * ema + (1 - momentum) * source``, pairing tensors by iteration
    order of ``parameters()`` / ``buffers()``.

    Buffers that are neither floating point nor complex (integer counters,
    index tables, boolean masks) are copied verbatim from ``model`` instead:
    blending them in floating point and writing the result back into integer
    storage would truncate and corrupt their values.
    """
    with torch.no_grad():
        for param_source, param_target in zip(model.parameters(), ema_model.parameters()):
            param_target.copy_(momentum * param_target.data + (1 - momentum) * param_source.data)
        for buffer_source, buffer_target in zip(model.buffers(), ema_model.buffers()):
            if torch.is_floating_point(buffer_target) or torch.is_complex(buffer_target):
                buffer_target.copy_(momentum * buffer_target.data + (1 - momentum) * buffer_source.data)
            else:
                # Averaging is not meaningful for integer/bool state; track the source exactly.
                buffer_target.copy_(buffer_source)


def select_optimizer_context(optimizer_context):
    """Return a zero-argument factory for the compilation/fusion context named by ``optimizer_context``.

    Every branch returns a callable; calling it yields a fresh context manager, so
    the intended usage is ``with select_optimizer_context(name)(): ...``.

    Supported names:
        * ``None``: ``contextlib.nullcontext`` (no optimization, torchdynamo is not imported).
        * ``"inductor"``: ``torchdynamo.optimize("inductor")``.
        * ``"aot_autograd_speedup"``: ``torchdynamo.optimize(aot_autograd_speedup_strategy)``.
        * ``"nvfuser_global"``: ``torch.jit.fuser("fuser2")``.
        * ``"aot_nvfuser"`` / ``"aot_nvfuser_forced"``: the AOT-autograd strategy nested with
          ``torch.jit.fuser("fuser2")``; the forced variant passes ``nopython=True``.
        * any other key of torchdynamo's ``BACKENDS``: that backend, falling back to the
          uncompiled ``gm.forward`` when the backend returns ``None``.

    Could also support more of the craziness from
    https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/optimizations/backends.py

    Raises:
        ValueError: If ``optimizer_context`` is none of the names above.
    """
    if optimizer_context is None:
        return nullcontext

    # Imported lazily so that torchdynamo is only required when a context is actually requested.
    import torchdynamo
    from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
    from torchdynamo.optimizations import BACKENDS
    from torchdynamo.optimizations.backends import create_backend

    # NOTE(review): presumably `create_backend` registers this function in BACKENDS under its own
    # name, making "aot_autograd_nocut" selectable via the generic branch below - confirm against
    # the installed torchdynamo version.
    @create_backend
    def aot_autograd_nocut(subgraph, **kwargs):
        if not kwargs:
            # from functorch._src.aot_autograd import static_argnums
            from functorch.compile import default_decompositions
            from functorch.compile import default_partition
            from functorch.compile import ts_compile

            kwargs = {
                # these are taken from memory_efficient_fusion()
                "fw_compiler": ts_compile,
                "bw_compiler": ts_compile,
                "partition_fn": default_partition,
                "hasher_type": "StaticShapeHasher",
                "decompositions": default_decompositions,
                # "static_argnums": static_argnums,
            }

        # Fall back to the forward compiler when no backward compiler was supplied.
        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]

        def _wrapped_bw_compiler(*bw_args, **bw_kwargs):
            # stop TorchDynamo from trying to compile our generated backwards pass
            return torchdynamo.disable(bw_compiler(*bw_args, **bw_kwargs))

        kwargs["bw_compiler"] = _wrapped_bw_compiler

        from functorch.compile import aot_module_simplified

        return aot_module_simplified(subgraph.model, **kwargs)

    if optimizer_context == "inductor":
        return lambda: torchdynamo.optimize("inductor")
    elif optimizer_context == "aot_autograd_speedup":
        return lambda: torchdynamo.optimize(aot_autograd_speedup_strategy)
    elif optimizer_context == "nvfuser_global":
        # Return a factory like every other branch (not an already-built context manager),
        # so callers can uniformly write `with ctx():` and enter the context more than once.
        return lambda: torch.jit.fuser("fuser2")
    elif optimizer_context == "aot_nvfuser":

        @contextmanager
        def aot_nvfuser_ctx():
            with torchdynamo.optimize(aot_autograd_speedup_strategy), torch.jit.fuser("fuser2"):
                yield

        return aot_nvfuser_ctx
    elif optimizer_context == "aot_nvfuser_forced":

        @contextmanager
        def aot_nvfuser_ctx():
            with torchdynamo.optimize(aot_autograd_speedup_strategy, nopython=True), torch.jit.fuser("fuser2"):
                yield

        return aot_nvfuser_ctx

    elif optimizer_context in BACKENDS:

        def dynamo_compiler(gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]):
            compiled_block = BACKENDS[optimizer_context](gm, example_inputs)
            if compiled_block is not None:
                return compiled_block
            # Backend declined to compile this graph: run it uncompiled.
            return gm.forward

        @contextmanager
        def dynamo_backend_ctx():
            with torchdynamo.optimize(dynamo_compiler):
                yield

        return dynamo_backend_ctx

    raise ValueError(f"Invalid name for an optimization context {optimizer_context} given.")


def prepare_pretraining_dataloader(dataset, tokenizer, cfg_train, cfg_impl):
    """Build the endlessly repeating dataloader used for pretraining.

    Args:
        dataset: Map-style dataset; must offer ``select`` when
            ``cfg_train.reverse_dataset_order`` is set.
        tokenizer: Tokenizer handed to the masked-LM collator.
        cfg_train: Training config. Reads ``objective`` (``name``, ``disable_mlm``,
            ``mlm_probability``, ``use_80_20_rule``, ``token_drop``) and
            ``reverse_dataset_order``.
        cfg_impl: Implementation config. Reads ``threads``, ``shuffle_in_dataloader``,
            ``microbatch_size``, ``pin_memory``, ``prefetch_factor`` and
            ``persistent_workers``.

    Returns:
        An ``InfiniteDataLoader`` over ``dataset`` that always drops the last
        incomplete batch.
    """
    num_workers = get_num_workers(cfg_impl)
    if cfg_train.objective.name == "masked-lm":
        collate_fn = PatchedDataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=not cfg_train.objective.disable_mlm,
            mlm_probability=cfg_train.objective.mlm_probability,
            pad_to_multiple_of=8,
            use_80_20_rule=cfg_train.objective.use_80_20_rule,
            token_drop=cfg_train.objective.token_drop,
        )
    else:
        # Any other objective uses the DataLoader's default collation.
        collate_fn = None

    # Reverse before building the sampler so sampler and loader refer to the same dataset object.
    if cfg_train.reverse_dataset_order:
        dataset = dataset.select(reversed(range(len(dataset))))

    if torch.distributed.is_initialized():
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            shuffle=cfg_impl.shuffle_in_dataloader,
            drop_last=True,
        )
    elif cfg_impl.shuffle_in_dataloader:
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    # Worker-only options are passed only when workers exist: newer PyTorch releases reject any
    # explicit `prefetch_factor` when `num_workers == 0`, while omitting it keeps the library default.
    worker_options = {}
    if num_workers > 0:
        worker_options["prefetch_factor"] = cfg_impl.prefetch_factor
        worker_options["persistent_workers"] = cfg_impl.persistent_workers

    repeated_dataloader = InfiniteDataLoader(
        dataset,
        sampler=sampler,
        batch_size=cfg_impl.microbatch_size,
        num_workers=num_workers,
        pin_memory=cfg_impl.pin_memory,
        drop_last=True,
        collate_fn=collate_fn,
        **worker_options,
    )
    return repeated_dataloader


def prepare_downstream_dataloader(dataset, tokenizer, mode, cfg_impl):
    """Build a finite dataloader for downstream fine-tuning or evaluation.

    Args:
        dataset: Dataset to iterate.
        tokenizer: Tokenizer used by ``transformers.DataCollatorWithPadding``.
        mode: ``"training"`` selects a distributed / random / sequential sampler
            according to the config and drops the last incomplete batch; any other
            value iterates sequentially and keeps the last batch.
        cfg_impl: Implementation config. Reads ``threads``, ``shuffle_in_dataloader``,
            ``microbatch_size``, ``pin_memory``, ``prefetch_factor`` and
            ``pad_to_multiple_of``.

    Returns:
        A ``torch.utils.data.DataLoader`` with padding collation.
    """
    is_training = mode == "training"
    if not is_training:
        sampler = torch.utils.data.SequentialSampler(dataset)
    elif torch.distributed.is_initialized():
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            shuffle=cfg_impl.shuffle_in_dataloader,
            drop_last=True,
        )
    elif cfg_impl.shuffle_in_dataloader:
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)
    # NOTE(review): an unused `batch_size_multiplier` (1 for training, 8 otherwise) used to be
    # computed here but was never applied, so evaluation also runs at `microbatch_size`.
    # Removed as dead code - confirm whether larger evaluation batches were intended.

    # Implementation details for dataloaders:
    collate_fn = transformers.DataCollatorWithPadding(tokenizer, pad_to_multiple_of=cfg_impl.pad_to_multiple_of)
    num_workers = get_num_workers(cfg_impl)

    # `prefetch_factor` is passed only when workers exist: newer PyTorch releases reject any
    # explicit value when `num_workers == 0`, while omitting it keeps the library default.
    worker_options = {}
    if num_workers > 0:
        worker_options["prefetch_factor"] = cfg_impl.prefetch_factor

    dataloader = DataLoader(
        dataset,
        batch_size=cfg_impl.microbatch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=cfg_impl.pin_memory,
        drop_last=is_training,
        persistent_workers=False,
        collate_fn=collate_fn,
        **worker_options,
    )
    return dataloader


"""This is a minor modification of huggingface's toking masking:"""
"""original source:
https://github.com/huggingface/transformers/blob/130b987880a9b1ade5c76dc1413c12c8924fda50/src/transformers/data/data_collator.py#L748
at commit f00f22a3e290fd377b979124dcf9800b3d73eb11"""


class PatchedDataCollatorForLanguageModeling(transformers.DataCollatorForLanguageModeling):
    """Masked-LM collator with an optional 80/10/10 rule and optional random token dropping.

    Extra keyword arguments on top of the parent collator:
        use_80_20_rule: When true, masked positions become the mask token 80% of the time,
            a random token 10% of the time and stay unchanged 10% of the time. When false,
            every masked position becomes the mask token.
        token_drop: Fraction of positions removed from each sequence after masking;
            nothing is dropped unless it is greater than zero.
    """

    def __init__(self, *args, use_80_20_rule=True, token_drop=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_80_20_rule = use_80_20_rule
        self.token_drop = token_drop

        # Resolve the mask token id once instead of on every batch.
        mask_token_text = self.tokenizer.mask_token
        self.mask_token = self.tokenizer.convert_tokens_to_ids(mask_token_text)

    def torch_mask_tokens(self, inputs=None, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling.

        With ``use_80_20_rule``: 80% MASK, 10% random, 10% original. Without it: 100% MASK.
        ``inputs`` is modified in place; labels are ``-100`` everywhere except at masked positions.
        Returns the ``(inputs, labels)`` pair, shortened by ``drop_tokens`` when ``token_drop > 0``.
        """
        labels = inputs.clone()
        # Every position starts with probability `self.mlm_probability` of being selected for MLM.
        selection_probs = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is not None:
            special_tokens_mask = special_tokens_mask.bool()
        else:
            per_row_masks = [
                self.tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True) for row in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(per_row_masks, dtype=torch.bool)

        # Special tokens are never selected.
        selection_probs.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(selection_probs).bool()
        labels[~masked_indices] = -100  # Loss is only computed on masked tokens

        if not self.use_80_20_rule:
            # Every selected position is replaced with tokenizer.mask_token ([MASK])
            inputs[masked_indices] = self.mask_token
        else:
            # 80% of the selected positions are replaced with tokenizer.mask_token ([MASK])
            mask_draw = torch.bernoulli(torch.full(labels.shape, 0.8)).bool()
            indices_replaced = mask_draw & masked_indices
            inputs[indices_replaced] = self.mask_token

            # Half of the remainder (10% overall) is replaced with a random word
            random_draw = torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            indices_random = random_draw & masked_indices & ~indices_replaced
            random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=inputs.dtype)
            inputs[indices_random] = random_words[indices_random]

            # The other half (10% overall) keeps the original token

        if self.token_drop > 0:
            inputs, labels = self.drop_tokens(inputs, labels)
        return inputs, labels

    def drop_tokens(self, input_ids, labels):
        """Randomly drop positions from every sequence, keeping the survivors in their original order.

        Each row keeps ``int(seq_length * (1 - token_drop))`` positions, chosen by a random
        permutation per row; ``input_ids`` and ``labels`` are reduced with the same mask.

        Hou et al., "Token Dropping for Efficient BERT Pretraining" also discuss dropping tokens
        based on more advanced strategies, which might also be helpful. This is the simplest
        strategy, randomly dropping a bunch of tokens for all layers.
        """
        batch_size, seq_length = input_ids.shape[0], input_ids.shape[1]
        reduced_seq_length = int(seq_length * (1 - self.token_drop))
        # There is probably a faster way to do this, but this works for now?
        random_order = torch.argsort(torch.rand_like(input_ids, dtype=torch.float), dim=-1)
        kept_positions = random_order[:, :reduced_seq_length]
        # Mark the kept positions with the sentinel -1 and read the marks back as a boolean mask.
        marked = input_ids.scatter(1, kept_positions, -1)
        keep_mask = marked == -1
        kept_inputs = input_ids[keep_mask].view(batch_size, -1)
        kept_labels = labels[keep_mask].view(batch_size, -1)
        return kept_inputs, kept_labels


class InfiniteDataLoader(torch.utils.data.DataLoader):
    """DataLoader that restarts itself whenever the underlying dataset is exhausted.

    Iterating returns the loader itself; ``next`` yields batches indefinitely and
    ``epoch_counter`` counts how many times the data has been restarted. If the sampler
    offers ``set_epoch`` it is informed of the new epoch on every restart.

    Lazy copy-paste from https://gist.github.com/MFreidank/821cc87b012c53fade03b0c7aba13958.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize an iterator over the dataset.
        self.dataset_iterator = super().__iter__()
        self.epoch_counter = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # Dataset exhausted. The sampler must learn about the new epoch BEFORE the fresh
            # iterator is built: creating the iterator may already draw indices from the sampler
            # (e.g. to prefetch for worker processes), which would otherwise reuse the old epoch.
            self.epoch_counter += 1
            if hasattr(self.sampler, "set_epoch"):
                self.sampler.set_epoch(self.epoch_counter)
            self.dataset_iterator = super().__iter__()
            return next(self.dataset_iterator)
