from dataclasses import dataclass
from typing import List, Union, Any, Dict, Optional, Mapping, Tuple
from random import random, randint
import numpy as np
import torch

from transformers import PreTrainedTokenizerBase, BatchEncoding
from transformers.data.data_collator import DataCollatorMixin, _torch_collate_batch


@dataclass
class SpanCollatorForDiffusionModeling:
    """Collate token sequences and build a span-based conditioning mask.

    Each example is first padded at random positions, then the batch is
    padded to a common length with ``tokenizer.pad``. Every row is cut into
    between 1 and ``max_number_of_spans`` contiguous spans; each span is
    independently marked masked (``False``) with probability ``mask_prob`` or
    conditioning (``True``) otherwise. Rows that end up with no conditioning
    position at all are resampled until they have at least one.

    Attributes:
        tokenizer: Supplies ``pad_token_id`` and the ``pad`` method.
        max_number_of_spans: Upper bound on spans per row; must be >= 1.
        padding_prob: Probability of inserting one more pad token at each
            insertion point; must be < 1, otherwise insertion never stops.
        mask_prob: Probability that a span is masked; must be < 1, otherwise
            no row could ever keep a conditioning position and resampling
            would never terminate.
        max_length: Stored on the instance; not read by any method here.
    """

    tokenizer: PreTrainedTokenizerBase
    max_number_of_spans: int
    padding_prob: float = 0.0
    mask_prob: float = 0.5
    max_length: int = 256

    def __post_init__(self):
        # Reject settings that would crash inside ``randint`` or make one of
        # the ``while`` loops below spin forever.
        if self.max_number_of_spans < 1:
            raise ValueError(
                f"max_number_of_spans must be >= 1, got {self.max_number_of_spans}"
            )
        if self.padding_prob >= 1:
            raise ValueError(f"padding_prob must be < 1, got {self.padding_prob}")
        if self.mask_prob >= 1:
            raise ValueError(f"mask_prob must be < 1, got {self.mask_prob}")

    def torch_call(
        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Build a batch with ``input_ids`` and a boolean ``conditioning_mask``.

        Args:
            examples: Token-id sequences, or dicts / ``BatchEncoding`` objects
                whose ``"input_ids"`` entry holds them.

        Returns:
            ``{"input_ids": padded tensor from tokenizer.pad,
            "conditioning_mask": bool tensor of the same shape}``.
        """
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        examples = self.insert_random_padding(examples=examples)
        examples = self.insert_necessary_padding(examples=examples)
        mask = torch.tensor(self.mask_spans(examples=examples), dtype=torch.bool)
        # Rows where every span was masked carry no conditioning position;
        # redraw their masks until each row keeps at least one.
        needs_resample = ~mask.any(dim=-1)
        while torch.any(needs_resample):
            mask[needs_resample] = torch.tensor(
                self.mask_spans(examples=examples[needs_resample]),
                dtype=torch.bool,
            )
            needs_resample = ~mask.any(dim=-1)
        return {"input_ids": examples, "conditioning_mask": mask}

    def insert_random_padding(self, examples: List[List[int]]) -> List[List[int]]:
        """Return new lists with runs of pad tokens inserted at random.

        Before every token, and once more after the last one, pad tokens are
        appended while ``random() < padding_prob`` (a geometric run length).
        """
        pad_id = self.tokenizer.pad_token_id
        new_examples = []
        for ex in examples:
            new_ex = []
            for token in ex:
                while random() < self.padding_prob:
                    new_ex.append(pad_id)
                new_ex.append(token)
            while random() < self.padding_prob:
                new_ex.append(pad_id)
            new_examples.append(new_ex)
        return new_examples

    def mask_spans(self, examples):
        """Return one list of 0/1 flags per example (1 = conditioning).

        A row of length ``n`` is split at ``k`` distinct cut points drawn from
        ``1 .. n-1``, where ``k`` is uniform in
        ``[0, min(max_number_of_spans, n) - 1]``, giving ``k + 1`` spans. Each
        span is filled with 0 with probability ``mask_prob``, else with 1.
        """
        conditioning_mask = []
        for ex in examples:
            num_spans = randint(0, min(self.max_number_of_spans - 1, len(ex) - 1))
            spans_idx = list(
                np.sort(
                    np.random.choice(
                        np.arange(1, len(ex)), size=num_spans, replace=False
                    )
                )
            )
            mask = []
            for start_idx, end_idx in zip([0] + spans_idx, spans_idx + [len(ex)]):
                if random() < self.mask_prob:
                    mask += [0] * (end_idx - start_idx)
                else:
                    mask += [1] * (end_idx - start_idx)
            conditioning_mask.append(mask)
        return conditioning_mask

    def insert_necessary_padding(self, examples):
        """Pad ``examples`` to a common length and return the ``input_ids`` tensor."""
        examples = [{"input_ids": e} for e in examples]
        return self.tokenizer.pad(examples, return_tensors="pt")["input_ids"]

    def __call__(self, examples):
        return self.torch_call(examples)


@dataclass
class MLMDataCollator(DataCollatorMixin):
    """Collate examples and draw an independent per-token conditioning mask.

    Token ids are returned unchanged. Alongside them a boolean
    ``conditioning_mask`` is produced: each non-special position is ``False``
    (selected) with probability ``mlm_probability`` and ``True`` otherwise;
    special-token positions are always ``True``.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.5
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def torch_call(
        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """Pad ``examples`` and attach a ``conditioning_mask``.

        Mapping inputs are padded with ``tokenizer.pad``; other inputs are
        collated with ``_torch_collate_batch``. ``special_tokens_mask`` and
        ``attention_mask`` are removed from the returned batch; any other keys
        produced by padding are passed through.
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], Mapping):
            batch = self.tokenizer.pad(
                examples,
                return_tensors="pt",
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
        else:
            batch = {
                "input_ids": _torch_collate_batch(
                    examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of
                )
            }

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        # The attention mask is not part of this collator's output.
        batch.pop("attention_mask", None)

        batch["input_ids"], batch["conditioning_mask"] = self.torch_mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )

        return batch

    def torch_mask_tokens(
        self, inputs: Any, special_tokens_mask: Optional[Any] = None
    ) -> Tuple[Any, Any]:
        """Sample a conditioning mask for a 2-D tensor of token ids.

        Args:
            inputs: Token-id tensor, iterated row by row.
            special_tokens_mask: Optional tensor of the same shape, non-zero at
                special-token positions. When ``None`` it is computed per row
                with ``tokenizer.get_special_tokens_mask(...,
                already_has_special_tokens=True)``.

        Returns:
            ``(inputs, conditioning_mask)``: ``inputs`` itself, unmodified, and a
            bool tensor of the same shape that is ``False`` at positions drawn
            with probability ``mlm_probability`` and ``True`` elsewhere,
            including every special-token position.
        """
        probability_matrix = torch.full(inputs.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = torch.tensor(
                [
                    self.tokenizer.get_special_tokens_mask(
                        row, already_has_special_tokens=True
                    )
                    for row in inputs.tolist()
                ],
                dtype=torch.bool,
            )
        else:
            special_tokens_mask = special_tokens_mask.bool()

        # Special tokens are never selected, so they always stay conditioning.
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        conditioning_mask = ~masked_indices

        return inputs, conditioning_mask


@dataclass
class FullyRandomCollator:
    """Pad a batch and hide a uniformly random subset of positions per row.

    For each row a count ``k`` is drawn uniformly from ``[0, seq_len - 1]``
    and the first ``k`` indices of a random permutation are set to ``False``
    in ``conditioning_mask``; every other position is ``True``. Positions are
    drawn over the padded length, so pad positions can be selected as well.
    """

    tokenizer: PreTrainedTokenizerBase
    # A real dataclass field (same default as before) so it can be set via the
    # constructor; forwarded to ``_torch_collate_batch``.
    pad_to_multiple_of: Optional[int] = 8

    def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]):
        """Return ``{"input_ids": padded ids, "conditioning_mask": bool tensor}``."""
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        input_ids = _torch_collate_batch(
            examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of
        )
        batch_size, seq_len = input_ids.size(0), input_ids.size(-1)
        num_tokens_to_mask = torch.randint(high=seq_len, size=(batch_size,))
        conditioning_mask = torch.ones_like(input_ids, dtype=torch.bool)
        for row, count in enumerate(num_tokens_to_mask):
            hidden = torch.randperm(seq_len, dtype=torch.long)[:count]
            conditioning_mask[row].index_fill_(0, hidden, False)
        return {"input_ids": input_ids, "conditioning_mask": conditioning_mask}


@dataclass
class PrefixCollator:
    """Pad a batch and mark a random-length prefix of each row as conditioning.

    For each row a prefix length ``k`` is drawn uniformly from
    ``[0, seq_len - 1]``; positions ``< k`` are ``True`` in
    ``conditioning_mask`` and the remaining positions are ``False``.
    """

    tokenizer: PreTrainedTokenizerBase
    # A real dataclass field (same default as before) so it can be set via the
    # constructor; forwarded to ``_torch_collate_batch``.
    pad_to_multiple_of: Optional[int] = 8

    def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]):
        """Return ``{"input_ids": padded ids, "conditioning_mask": bool tensor}``."""
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        input_ids = _torch_collate_batch(
            examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of
        )
        batch_size, seq_len = input_ids.size(0), input_ids.size(-1)
        prefix_lengths = torch.randint(high=seq_len, size=(batch_size,))
        # Broadcast (1, seq_len) positions against (batch, 1) prefix lengths:
        # one comparison instead of a per-row Python loop, same result.
        positions = torch.arange(seq_len)
        conditioning_mask = positions.unsqueeze(0) < prefix_lengths.unsqueeze(1)
        return {"input_ids": input_ids, "conditioning_mask": conditioning_mask}


@dataclass
class CombinedCollator:
    """Pad a batch once and split it between two mask-generating collators.

    The first ``int(p * len(examples))`` rows (clamped to the batch size) get
    their ``conditioning_mask`` from ``collators[0]``; the remaining rows get it
    from ``collators[1]``. Padding uses ``collators[0].tokenizer`` and
    ``collators[0].pad_to_multiple_of``. Each sub-collator re-collates its slice
    with its own settings, so both are assumed to yield masks of the same padded
    length (true when they share ``pad_to_multiple_of``) — TODO confirm for new
    collator pairs.
    """

    collators: Tuple
    p: float = 0.5

    def __call__(self, examples: List[Union[List[int], Any, Dict[str, Any]]]):
        """Return ``{"input_ids": padded ids, "conditioning_mask": bool tensor}``."""
        batch_size = len(examples)
        # Clamp so an out-of-range ``p`` cannot produce negative-slice splits.
        num_first = min(max(int(self.p * batch_size), 0), batch_size)
        if isinstance(examples[0], (dict, BatchEncoding)):
            examples = [e["input_ids"] for e in examples]
        input_ids = _torch_collate_batch(
            examples,
            self.collators[0].tokenizer,
            pad_to_multiple_of=self.collators[0].pad_to_multiple_of,
        )
        conditioning_mask = torch.ones_like(input_ids, dtype=torch.bool)
        parts = (
            (self.collators[0], slice(None, num_first)),
            (self.collators[1], slice(num_first, None)),
        )
        for collator, rows in parts:
            part = input_ids[rows]
            # Sub-collators index ``examples[0]``; an empty slice (e.g. a batch
            # of one, p == 0 or p == 1) would raise IndexError, so skip it.
            if part.size(0) == 0:
                continue
            conditioning_mask[rows] = collator(part.numpy())["conditioning_mask"]
        return {"input_ids": input_ids, "conditioning_mask": conditioning_mask}


if __name__ == "__main__":
    from transformers import AutoTokenizer

    # Load the tokenizer once and share it; neither collator mutates it.
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    prefix_fn = PrefixCollator(tokenizer=tokenizer)
    random_fn = FullyRandomCollator(tokenizer=tokenizer)
    collate_fn = CombinedCollator(collators=(prefix_fn, random_fn))
    print(collate_fn([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
