import logging
import torch
from typing import List, Tuple, Set, Optional
from dataclasses import dataclass
from transformers import PreTrainedTokenizer

from leanfinder.retriever.arguments import DataArguments


# Module-level logger named after this module; nothing in the visible code logs through it.
logger = logging.getLogger(__name__)

# Most recently built tensor of non-special token ids; filled in by get_allowed_tensor().
_ALLOWED_TOKENS: Optional[torch.Tensor] = None
# The (tokenizer, dtype, device) request that _ALLOWED_TOKENS was built for.
_ALLOWED_TOKENS_KEY: Optional[tuple] = None


def get_allowed_tensor(tokenizer: PreTrainedTokenizer,
                       dtype: torch.dtype,
                       device: torch.device) -> torch.Tensor:
    """Return a 1-D tensor of every vocabulary id that is not a special token.

    The ids come from ``tokenizer.get_vocab().values()`` with everything in
    ``tokenizer.all_special_ids`` removed, in the order the vocabulary yields
    them.

    The result is cached at module level. The cache is reused only when the
    same tokenizer object, ``dtype`` and ``device`` are requested again; any
    other request rebuilds the tensor. Only the latest request is kept.

    Args:
        tokenizer: Object exposing ``all_special_ids`` and ``get_vocab()``.
        dtype: dtype of the returned tensor.
        device: Device the returned tensor is created on.

    Returns:
        The cached tensor itself (not a copy); callers should not modify it.
    """
    global _ALLOWED_TOKENS, _ALLOWED_TOKENS_KEY

    # Compare against the *requested* key rather than the tensor's own
    # attributes: a request for torch.device("cuda") yields a tensor on
    # "cuda:0", which would otherwise look like a miss on every call.
    key = _ALLOWED_TOKENS_KEY
    cache_hit = (
        _ALLOWED_TOKENS is not None
        and key is not None
        and key[0] is tokenizer
        and key[1] == dtype
        and key[2] == device
    )

    if not cache_hit:
        specials: Set[int] = set(tokenizer.all_special_ids)
        allowed_ids = [i for i in tokenizer.get_vocab().values() if i not in specials]
        _ALLOWED_TOKENS = torch.tensor(allowed_ids, dtype=dtype, device=device)
        # Holds a reference to the tokenizer so the identity check above stays
        # meaningful (a bare id() could be recycled after garbage collection).
        _ALLOWED_TOKENS_KEY = (tokenizer, dtype, device)

    return _ALLOWED_TOKENS


@dataclass
class TrainCollator:
    """Collate training examples into a padded query batch and passage batch.

    Calling the instance tokenizes the query and passage texts, optionally
    appends the tokenizer's EOS id, pads each side with ``tokenizer.pad`` and
    then corrupts the query ids with :meth:`augment_with_mask_and_random`.
    Passage ids are returned as tokenized.
    """

    # Read for: query_max_len, passage_max_len, append_eos_token, pad_to_multiple_of.
    data_args: DataArguments
    # Used to encode and pad the texts and to supply the EOS id.
    tokenizer: PreTrainedTokenizer

    def __call__(self, features: List[Tuple[str, List[str]]]):
        """Build the ``(queries, passages)`` batch pair for one training step.

        Args:
            features: One ``(query, passages)`` entry per example.

        Returns:
            A 2-tuple of whatever ``tokenizer.pad(..., return_tensors='pt')``
            returns for the queries and for the passages; the queries'
            ``'input_ids'`` entry has been augmented.
        """
        # NOTE(review): the annotation says queries/passages are plain strings,
        # but element [0] of each one is what gets tokenized -- presumably they
        # are tuples whose first field is the text; confirm against the dataset.
        query_texts = [example[0][0] for example in features]
        passage_texts = [
            passage[0] for example in features for passage in example[1]
        ]

        args = self.data_args
        q_encoded = self._tokenize(query_texts, args.query_max_len)
        d_encoded = self._tokenize(passage_texts, args.passage_max_len)

        if args.append_eos_token:
            eos_tail = [self.tokenizer.eos_token_id]
            for encoded in (q_encoded, d_encoded):
                encoded['input_ids'] = [ids + eos_tail for ids in encoded['input_ids']]

        q_batch = self._pad(q_encoded)
        d_batch = self._pad(d_encoded)

        # Only the query side is augmented.
        q_batch['input_ids'] = self.augment_with_mask_and_random(
            q_batch['input_ids'], self.tokenizer
        )
        return q_batch, d_batch

    def _tokenize(self, texts, max_len):
        """Encode ``texts`` unpadded, leaving one slot free when EOS will be appended."""
        if self.data_args.append_eos_token:
            limit = max_len - 1
        else:
            limit = max_len
        return self.tokenizer(
            texts,
            padding=False,
            truncation=True,
            max_length=limit,
            return_attention_mask=False,
            return_token_type_ids=False,
            add_special_tokens=True,
        )

    def _pad(self, encoded):
        """Pad an encoded batch and request attention masks and 'pt' tensors."""
        return self.tokenizer.pad(
            encoded,
            padding=True,
            pad_to_multiple_of=self.data_args.pad_to_multiple_of,
            return_attention_mask=True,
            return_tensors='pt',
        )

    def augment_with_mask_and_random(
        self,
        input_ids: torch.Tensor,
        tokenizer: PreTrainedTokenizer,
        prob: float = 0.15,
    ) -> torch.Tensor:
        """Randomly replace non-special token ids with the mask id or a random id.

        Each position whose id is not in ``tokenizer.all_special_ids`` is
        selected with probability ``prob``. A second uniform draw then decides
        what happens to a selected position: below 0.1 it becomes the mask
        token, from 0.1 up to (not including) 0.9 it becomes an id sampled
        uniformly from ``get_allowed_tensor``, otherwise it is left alone.

        Args:
            input_ids: Integer tensor of token ids. Modified in place.
            tokenizer: Supplies ``all_special_ids`` and ``mask_token_id``.
            prob: Per-position selection probability.

        Returns:
            The same ``input_ids`` tensor object that was passed in.

        Raises:
            ValueError: If ``tokenizer.mask_token_id`` is ``None``.
        """
        special_ids: Set[int] = set(tokenizer.all_special_ids)
        mask_id = tokenizer.mask_token_id
        if mask_id is None:
            raise ValueError("Tokenizer has no mask token defined")

        # First draw: which positions are candidates for corruption.
        selected = torch.rand_like(input_ids, dtype=torch.float) < prob
        if special_ids:
            special_tensor = torch.tensor(
                list(special_ids), dtype=input_ids.dtype, device=input_ids.device
            )
            selected &= ~torch.isin(input_ids, special_tensor)

        if not selected.any():
            return input_ids

        # Second draw: what to do with each candidate.
        # NOTE(review): this is a 10% mask / 80% random / 10% keep split, not
        # the 80/10/10 split used by BERT-style masking -- confirm it is intended.
        draw = torch.rand_like(input_ids, dtype=torch.float)

        to_mask = selected & (draw < 0.1)
        if to_mask.any():
            input_ids[to_mask] = mask_id

        to_randomize = selected & (draw >= 0.1) & (draw < 0.9)
        n_random = int(to_randomize.sum())
        if n_random > 0:
            vocab = get_allowed_tensor(tokenizer,
                                       dtype=input_ids.dtype,
                                       device=input_ids.device)
            picks = torch.randint(0, vocab.numel(), (n_random,), device=input_ids.device)
            input_ids[to_randomize] = vocab[picks]

        return input_ids