import string

from nlpaug.augmenter.char.random import RandomCharAug
from nlpaug.util import Action, Method, Doc
import numpy as np


class MyRandomCharAug(RandomCharAug):
    """Random character augmenter that perturbs every token independently.

    Unlike picking a bounded number of words/characters, each character of
    each token is selected by its own Bernoulli draw with probability
    ``aug_char_p``. The ``insert``, ``substitute`` and ``delete`` actions are
    overridden here; everything else is inherited from ``RandomCharAug``.

    NOTE(review): ``_get_aug_idxes`` is overridden with a ``(tokens, aug_p)``
    signature. Any inherited action that calls it with the base-class
    argument list (presumably ``swap``) would not work with this override -
    confirm against the installed nlpaug version.
    """

    def __init__(self, action=Action.SUBSTITUTE, name='RandomChar_Aug', aug_char_min=1, aug_char_max=10,
                 aug_char_p=0.3, aug_word_p=0.3, aug_word_min=1, aug_word_max=10, include_upper_case=True,
                 include_lower_case=True, include_numeric=True, min_char=4, swap_mode='adjacent',
                 spec_char='!@#$%^&*()_+', stopwords=None, tokenizer=None, reverse_tokenizer=None, verbose=0,
                 stopwords_regex=None, candidates=None):
        """Forward every argument unchanged to ``RandomCharAug.__init__``.

        See the base class for the meaning of the individual parameters; this
        subclass only reads ``aug_char_p`` directly.
        """
        super().__init__(
            action=action, name=name, aug_char_min=aug_char_min, aug_char_max=aug_char_max,
            aug_char_p=aug_char_p, aug_word_p=aug_word_p, aug_word_min=aug_word_min,
            aug_word_max=aug_word_max, include_upper_case=include_upper_case,
            include_lower_case=include_lower_case, include_numeric=include_numeric, min_char=min_char,
            swap_mode=swap_mode, spec_char=spec_char, stopwords=stopwords, tokenizer=tokenizer,
            reverse_tokenizer=reverse_tokenizer, verbose=verbose, stopwords_regex=stopwords_regex,
            candidates=candidates)

        # Pool of replacement characters that insert/substitute sample from.
        self.model = self.get_model()

    def _get_aug_idxes(self, tokens, aug_p):
        """Return the indexes of ``tokens`` selected for augmentation.

        One independent ``binomial(1, aug_p)`` draw is made per element.

        Returns:
            A 1-D integer ndarray of selected positions in ascending order.
            It is empty (never ``None``) when nothing was selected.
        """
        return np.nonzero([np.random.binomial(1, aug_p) for _ in tokens])[0]

    def _emit_result(self, doc):
        """Rebuild the text from ``doc``; add the change logs if ``include_detail`` is set."""
        augmented = self.reverse_tokenizer(doc.get_augmented_tokens())
        if self.include_detail:
            return augmented, doc.get_change_logs()
        return augmented

    def insert(self, data):
        """Insert a randomly sampled character before each selected character.

        Args:
            data: Text to augment. Falsy or whitespace-only input is returned
                unchanged.

        Returns:
            The augmented text, or ``(text, change_logs)`` when
            ``self.include_detail`` is truthy.
        """
        if not data or not data.strip():
            return data

        change_seq = 0

        doc = Doc(data, self.tokenizer(data))

        for token_i, token in enumerate(doc.get_original_tokens()):
            chars = self.token2char(token)
            aug_char_idxes = self._get_aug_idxes(chars, self.aug_char_p)
            # Nothing selected: the token is untouched, so do not log a change
            # (same rule as ``delete``).
            if len(aug_char_idxes) == 0:
                continue

            # Walk from the highest index down so that an insertion never
            # shifts a position that still has to be processed.
            for char_i in sorted(aug_char_idxes, reverse=True):
                chars.insert(char_i, self.sample(self.model, 1)[0])

            # No capitalization alignment, as this augmenter tries to simulate random errors.

            new_token = ''.join(chars)
            change_seq += 1
            doc.add_change_log(token_i, new_token=new_token, action=Action.INSERT,
                               change_seq=self.parent_change_seq + change_seq)

        return self._emit_result(doc)

    def substitute(self, data):
        """Replace each selected character with a randomly sampled one.

        Args:
            data: Text to augment. Falsy or whitespace-only input is returned
                unchanged.

        Returns:
            The augmented text, or ``(text, change_logs)`` when
            ``self.include_detail`` is truthy.
        """
        if not data or not data.strip():
            return data

        change_seq = 0

        doc = Doc(data, self.tokenizer(data))

        for token_i, token in enumerate(doc.get_original_tokens()):
            chars = self.token2char(token)
            aug_char_idxes = self._get_aug_idxes(chars, self.aug_char_p)
            # Nothing selected: the token is untouched, so do not log a change
            # (same rule as ``delete``).
            if len(aug_char_idxes) == 0:
                continue

            # A set gives O(1) membership tests instead of scanning the array
            # once per character.
            selected = {int(char_i) for char_i in aug_char_idxes}
            substitute_token = ''.join(
                self.sample(self.model, 1)[0] if char_i in selected else char
                for char_i, char in enumerate(chars)
            )

            # No capitalization alignment, as this augmenter tries to simulate random errors.

            change_seq += 1
            doc.add_change_log(token_i, new_token=substitute_token, action=Action.SUBSTITUTE,
                               change_seq=self.parent_change_seq + change_seq)

        return self._emit_result(doc)

    def delete(self, data):
        """Remove each selected character from its token.

        Args:
            data: Text to augment. Falsy or whitespace-only input is returned
                unchanged.

        Returns:
            The augmented text, or ``(text, change_logs)`` when
            ``self.include_detail`` is truthy.
        """
        if not data or not data.strip():
            return data

        change_seq = 0

        doc = Doc(data, self.tokenizer(data))

        for token_i, token in enumerate(doc.get_original_tokens()):
            chars = self.token2char(token)
            aug_char_idxes = self._get_aug_idxes(chars, self.aug_char_p)
            if len(aug_char_idxes) == 0:
                continue

            # Delete from the highest index down so the remaining indexes stay valid.
            for char_i in sorted(aug_char_idxes, reverse=True):
                del chars[char_i]

            # No capitalization alignment, as this augmenter tries to simulate random errors.

            delete_token = ''.join(chars)
            change_seq += 1
            doc.add_change_log(token_i, new_token=delete_token, action=Action.DELETE,
                               change_seq=self.parent_change_seq + change_seq)

        return self._emit_result(doc)

class MyRandomCharAugg():
    """Standalone character-noise augmenter mixing substitute/insert/delete.

    Every character of every whitespace-separated word is selected with
    probability ``aug_char_p``; each selected character is then randomly
    substituted, followed by an inserted character, or deleted.
    """

    def __init__(self, aug_char_p=0.1):
        self.aug_char_p = aug_char_p
        # Replacement alphabet: A-Z, a-z, 0-9 (62 single-character strings).
        self.model = list(string.ascii_uppercase + string.ascii_lowercase + string.digits)

    def augment(self, data):
        """Return ``[augmented_text]``; falsy or blank ``data`` is returned as-is.

        Words are split on whitespace and re-joined with single spaces.
        """
        if not (data and data.strip()):
            return data

        noisy_words = []
        for word in data.split():
            # One Bernoulli draw per character decides whether it is touched.
            draws = [np.random.binomial(1, self.aug_char_p) for _ in word]
            hit_positions = np.nonzero(draws)[0]
            # One operation per touched character, consumed left to right.
            operations = np.random.choice(["substitute", "insert", "delete"], len(hit_positions))

            pieces = []
            cursor = 0
            for position, char in enumerate(word):
                if position not in hit_positions:
                    pieces.append(char)
                    continue

                operation = operations[cursor]
                cursor += 1
                if operation == "substitute":
                    pieces.append(np.random.choice(self.model))
                elif operation == "insert":
                    # The new character goes after the original one.
                    pieces.append(char)
                    pieces.append(np.random.choice(self.model))
                # "delete": emit nothing for this character.

            noisy_words.append("".join(pieces))

        return [" ".join(noisy_words)]

class MyRandomCharAug_spacedel():
    """Augmenter that randomly removes space characters from a text."""

    def __init__(self, aug_char_p=0.1):
        self.aug_char_p = aug_char_p

    def augment(self, data):
        """Return ``[text]`` with each ``' '`` dropped with probability ``aug_char_p``.

        Falsy or blank ``data`` is returned as-is. Only the literal space
        character is considered; other whitespace is left alone.
        """
        if not (data and data.strip()):
            return data

        space_positions = [pos for pos, char in enumerate(data) if char == ' ']
        # One Bernoulli draw per space decides whether that space is removed.
        draws = [np.random.binomial(1, self.aug_char_p) for _ in space_positions]
        dropped = {space_positions[hit] for hit in np.nonzero(draws)[0]}

        survivors = [char for pos, char in enumerate(data) if pos not in dropped]
        return [''.join(survivors)]

class MyRandomCharAug_spaceinsert():
    """Augmenter that randomly inserts a space after characters of each word."""

    def __init__(self, aug_char_p=0.1):
        self.aug_char_p = aug_char_p

    def augment(self, data):
        """Return ``[text]`` with a ``' '`` added after each selected character.

        Falsy or blank ``data`` is returned as-is. Words are split on
        whitespace and re-joined with single spaces; every character is
        selected independently with probability ``aug_char_p``.
        """
        if not (data and data.strip()):
            return data

        spaced_words = []
        for word in data.split():
            # One Bernoulli draw per character decides whether a space follows it.
            draws = [np.random.binomial(1, self.aug_char_p) for _ in word]
            split_after = np.nonzero(draws)[0]
            spaced_words.append("".join(
                char + ' ' if position in split_after else char
                for position, char in enumerate(word)
            ))

        return [" ".join(spaced_words)]