from typing import Sequence, Tuple
import torch
import numpy as np


class PromptConvert(object):
    """Encode a batch of sequences into one padded token tensor, optionally with a prompt.

    input:    seqences -- items that are each passed to ``alphabet.encode``
              (presumably raw sequence strings; TODO confirm against callers,
              the ``Tuple[str, str]`` annotation does not match how they are used)
    output:   tokens -- int64 tensor of shape
              (batch, max_encoded_len + 2 + len(prompt_toks)).
              Each row is ``cls + encoded sequence + eos``, right-padded with
              ``pad_idx``. When ``prompt_toks`` is non-empty, the encoded prompt
              ids occupy the last ``len(prompt_toks)`` columns of every row
              (i.e. after any padding).
    """
    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.pad_idx = alphabet.padding_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx

    def __call__(self, seqences: Sequence[Tuple[str, str]], prompt_toks=None):
        """Return the padded token tensor for ``seqences``.

        Args:
            seqences: items each passed to ``alphabet.encode``.
            prompt_toks: optional sequence of prompt tokens; for each one only the
                first id returned by ``alphabet.encode`` is used. ``None`` (the
                default) or an empty sequence means no prompt columns.

        Returns:
            int64 tensor described in the class docstring. An empty batch yields
            a tensor of shape (0, 2 + len(prompt_toks)).
        """
        # None instead of a shared mutable [] default.
        if prompt_toks is None:
            prompt_toks = []
        num_prompt = len(prompt_toks)

        encoded_prompt = None
        if num_prompt:
            encoded_prompt = torch.tensor(
                [self.alphabet.encode(prompt_tok)[0] for prompt_tok in prompt_toks],
                dtype=torch.int64,
            )

        encoded_sequences = [self.alphabet.encode(sequence) for sequence in seqences]
        # default=0 keeps an empty batch from crashing inside max().
        max_len = max((len(seq) for seq in encoded_sequences), default=0)

        tokens = torch.full(
            (len(encoded_sequences), max_len + num_prompt + 2),
            self.pad_idx,
            dtype=torch.int64,
        )
        for i, encoded_sequence in enumerate(encoded_sequences):
            seq_len = len(encoded_sequence)
            tokens[i, 0] = self.cls_idx
            tokens[i, 1:seq_len + 1] = torch.tensor(encoded_sequence, dtype=torch.int64)
            tokens[i, seq_len + 1] = self.eos_idx

        if encoded_prompt is not None:
            # The same prompt ids are broadcast into the trailing columns of every row.
            tokens[:, -num_prompt:] = encoded_prompt
        return tokens


class MaskedConverter(object):
    """Encode, pad and BERT-style mask a batch of sequences.

    input:    raw_batch -- items each passed to ``alphabet.encode``
              (presumably raw sequence strings; TODO confirm against callers)
    output:   (origin_tokens, masked_tokens, target_tokens)
        origin_tokens: int64, (batch, max_encoded_len + 2 + len(prompt_toks)).
            ``cls + encoded sequence + eos`` right-padded with ``pad_idx``; the
            encoded prompt ids (if any) fill the last ``len(prompt_toks)`` columns.
        masked_tokens: same layout as ``origin_tokens``, but the selected
            positions are replaced by ``mask_idx``, replaced by a random token,
            or left unchanged.
        target_tokens: int64, (batch, max_encoded_len + 2). Holds the original
            id at every selected position and ``pad_idx`` everywhere else
            (including the cls/eos columns). Has no prompt columns.
    """
    def __init__(self, alphabet):
        self.alphabet = alphabet
        self.pad_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx

        # Expected fraction of each sequence's tokens selected for prediction.
        self.mask_prob = 0.15
        # Probability that a selected token is replaced by a random token ...
        self.random_token_prob = 0.1
        # ... or left unchanged; otherwise it becomes ``mask_idx``.
        self.leave_unmasked_prob = 0.1

        # Uniform distribution over the first len(alphabet) - 2 ids for random
        # replacements; the last two ids get weight 0.
        # NOTE(review): which tokens those two are depends on the alphabet
        # layout, and this array must have len(alphabet.all_toks) entries to be
        # usable as ``p`` below -- confirm both.
        weights = np.zeros(len(self.alphabet))
        weights[:len(self.alphabet) - 2] = 1
        self.weights = weights / weights.sum()

    def _sample_mask(self, sequence_length):
        """Choose which positions of one sequence to corrupt.

        Returns three boolean numpy arrays of length ``sequence_length``:
        ``mask`` (positions to corrupt), ``unmask`` (selected but kept
        unchanged) and ``rand_mask`` (positions that get a random token).
        ``mask`` and ``unmask`` are disjoint, and their union is the full
        selection. ``rand_mask`` is a subset of ``mask``.
        """
        selected = np.full(sequence_length, False)
        # The uniform offset rounds mask_prob * length up or down at random.
        # The result is clamped so sampling without replacement cannot overflow.
        num_mask = min(int(self.mask_prob * sequence_length + np.random.rand()), sequence_length)
        selected[np.random.choice(sequence_length, num_mask, replace=False)] = True

        rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
        rand_or_unmask = selected & (np.random.rand(sequence_length) < rand_or_unmask_prob)
        # Use empty masks, not None, so a zero probability never breaks the
        # boolean arithmetic below or in the caller.
        no_positions = np.full(sequence_length, False)
        if self.random_token_prob == 0.0:
            unmask, rand_mask = rand_or_unmask, no_positions
        elif self.leave_unmasked_prob == 0.0:
            unmask, rand_mask = no_positions, rand_or_unmask
        else:
            unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
            decision = np.random.rand(sequence_length) < unmask_prob
            unmask = rand_or_unmask & decision
            rand_mask = rand_or_unmask & ~decision

        mask = selected & ~unmask
        return mask, unmask, rand_mask

    def __call__(self, raw_batch: Sequence[Tuple[str, str]], prompt_toks=None):
        """Return ``(origin_tokens, masked_tokens, target_tokens)`` for ``raw_batch``.

        Args:
            raw_batch: items each passed to ``alphabet.encode``.
            prompt_toks: optional sequence of prompt tokens; for each one only the
                first id returned by ``alphabet.encode`` is used. ``None`` (the
                default) or an empty sequence means no prompt columns.

        Returns:
            The three int64 tensors described in the class docstring. An empty
            batch yields tensors with zero rows.
        """
        if prompt_toks is None:
            prompt_toks = []
        num_prompt = len(prompt_toks)

        encoded_prompt = None
        if num_prompt:
            encoded_prompt = torch.tensor(
                [self.alphabet.encode(prompt_tok)[0] for prompt_tok in prompt_toks],
                dtype=torch.int64,
            )

        encoded_sequences = [self.alphabet.encode(sequence) for sequence in raw_batch]
        max_len = max((len(seq) for seq in encoded_sequences), default=0)
        batch_size = len(encoded_sequences)
        width = max_len + 2 + num_prompt

        origin_tokens = torch.full((batch_size, width), self.pad_idx, dtype=torch.int64)
        masked_tokens = torch.full((batch_size, width), self.pad_idx, dtype=torch.int64)
        # Targets only span cls + sequence + eos; prompt columns are never predicted.
        target_tokens = torch.full((batch_size, max_len + 2), self.pad_idx, dtype=torch.int64)

        for i, encoded_sequence in enumerate(encoded_sequences):
            seq_len = len(encoded_sequence)
            mask, unmask, rand_mask = self._sample_mask(seq_len)

            # Explicit int64 so an empty sequence does not become a float array.
            original = np.asarray(encoded_sequence, dtype=np.int64)
            masked = original.copy()
            masked[mask] = self.mask_idx
            # rand_mask is a subset of mask, so random ids overwrite mask_idx there.
            masked[rand_mask] = np.random.choice(
                len(self.alphabet.all_toks),
                int(rand_mask.sum()),
                p=self.weights,
            )

            original_t = torch.from_numpy(original)
            origin_tokens[i, 0] = self.cls_idx
            origin_tokens[i, 1:seq_len + 1] = original_t
            origin_tokens[i, seq_len + 1] = self.eos_idx

            masked_tokens[i, 0] = self.cls_idx
            masked_tokens[i, 1:seq_len + 1] = torch.from_numpy(masked)
            masked_tokens[i, seq_len + 1] = self.eos_idx

            # Every selected position (masked, randomized or left unchanged) is a target.
            selected = torch.from_numpy(mask | unmask)
            target_tokens[i, 1:seq_len + 1][selected] = original_t[selected]

        if encoded_prompt is not None:
            origin_tokens[:, -num_prompt:] = encoded_prompt
            masked_tokens[:, -num_prompt:] = encoded_prompt
        return origin_tokens, masked_tokens, target_tokens
