from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from transformers.data.data_collator import DataCollatorMixin, pad_without_fast_tokenizer_warning
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


@dataclass
class My_collator(DataCollatorMixin):
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        mlm (`bool`, *optional*, defaults to `True`):
            Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
            with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
            tokens and the value to predict for the masked token.
        mlm_probability (`float`, *optional*, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
        return_tensors (`str`):
            The type of Tensor to return. Allowable values are "np", "pt" and "tf".

    <Tip>

    For best performance, this data collator should be used with a dataset having items that are dictionaries or
    BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
    [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.

    </Tip>"""
    tokenizer: PreTrainedTokenizerBase
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def torch_call(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Pad the standard model inputs (input_ids, attention_mask, ...) and convert to tensors.
        # Note that `tokenizer.pad` does not pad extra keys, so the per-token "label" mask is
        # expected to already match the padded sequence length.
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
        )

        # Build the labels from the input ids: positions where the "label" mask is 0, as well as
        # padding positions, are set to -100 so the loss ignores them; positions where the mask
        # is 1 keep their token id.
        mask = batch.pop("label")
        labels = batch["input_ids"].clone()
        if self.tokenizer.pad_token_id is not None:
            labels[labels == self.tokenizer.pad_token_id] = -100
        labels = labels * mask + (1 - mask) * (-100)
        batch["labels"] = labels
        return batch
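

# Minimal usage sketch (illustrative only, not part of the collator itself). It assumes a
# checkpoint with a pad token ("bert-base-uncased" is just a placeholder) and builds the "label"
# mask by supervising every non-special token; a real pipeline would compute this mask upstream
# according to its own task.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    collator = My_collator(tokenizer=tokenizer)

    features = []
    for text in ["hello world", "data collators are easy to customize"]:
        # Pad at tokenization time so the "label" mask already matches the input length
        # (tokenizer.pad does not pad extra keys).
        enc = tokenizer(text, padding="max_length", max_length=16, truncation=True)
        special = tokenizer.get_special_tokens_mask(enc["input_ids"], already_has_special_tokens=True)
        enc["label"] = [1 - s for s in special]
        features.append(enc)

    batch = collator(features)
    # batch["labels"] equals batch["input_ids"] where the mask is 1 and -100 everywhere else,
    # so the loss is only computed on the selected tokens.
    print(batch["input_ids"].shape, batch["labels"].shape)
    print(batch["labels"][0])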