import numpy as np

def check_and_convert(value):
    """Return ``float(value)``, or ``None`` if *value* cannot be converted.

    Malformed strings (``ValueError``) and values of a type ``float`` does
    not accept, such as ``None`` (``TypeError``), both yield ``None``
    instead of raising, so callers can use the result as a validity check.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return None
    
def numeric_token_search(enc, reserved_tail=100):
    """Return the ids of tokens whose decoded text is a run of ASCII digits.

    Parameters
    ----------
    enc : object
        Encoder exposing ``n_vocab`` and ``decode(list_of_ids) -> str``
        (presumably a tiktoken ``Encoding`` -- TODO confirm).
    reserved_tail : int, default 100
        Number of ids at the top of the vocabulary to skip. 100 was
        hard-coded originally, presumably to avoid special / unused ids at
        the end of the vocabulary -- verify for the encoder in use.

    Returns
    -------
    list of int
        Token ids, in increasing order, whose decoded text consists only of
        the characters ``0``-``9`` (tokens representing whole numbers).
    """
    def _is_ascii_digits(text):
        # str.isdigit() also accepts characters such as '²', which float()
        # rejects; restrict to plain 0-9 so every token forms a valid number.
        return text != "" and all("0" <= ch <= "9" for ch in text)

    return [
        token
        for token in range(enc.n_vocab - reserved_tail)
        if _is_ascii_digits(enc.decode([token]))
    ]


def _random_position(L, lo_frac, hi_frac):
    """Draw an index uniformly from [round(lo_frac*L), round(hi_frac*L)], clamped to [0, L-1].

    Replaces the deprecated ``np.random.random_integers`` (inclusive bounds,
    called with float arguments) with ``np.random.randint`` on integer
    bounds, still using the global ``np.random`` state. Requires ``L >= 1``.
    """
    # round(0.9 * L) == L for L in 1..4, which would index past the end.
    hi = min(int(np.round(hi_frac * L)), L - 1)
    lo = min(int(np.round(lo_frac * L)), hi)
    # randint's `high` is exclusive; +1 keeps the inclusive upper bound.
    return np.random.randint(lo, hi + 1)


def token_attack(x, scale, enc, numeric_tokens, additem=3, iteration=5):
    """Perturb the token sequence of ``str(x)`` while keeping its value near ``x``.

    Each numeric token of ``enc.encode(str(x))`` is tried, up to *iteration*
    times, against a replacement drawn at random from the opposite end of
    *numeric_tokens*. A replacement is kept when the decoded sequence so
    far is a valid number within *scale* of *x*; otherwise the original
    token is kept. Sign, decimal-point and other non-numeric tokens are
    copied unchanged. If ``str(x)`` has a fractional part (and no exponent),
    up to *additem* random numeric tokens are appended afterwards.

    Parameters
    ----------
    x : number
        Value to perturb (presumably a normalized number).
    scale : float
        Maximum allowed ``|value - x|`` for accepting a replacement, where
        value is the decoded prefix built so far.
    enc : object
        Encoder with ``encode(str) -> list[int]`` and
        ``decode(list[int]) -> str`` (presumably tiktoken -- TODO confirm).
    numeric_tokens : list of int
        Candidate replacement ids, e.g. from ``numeric_token_search``.
    additem : int
        Number of trailing numeric tokens to try appending.
    iteration : int
        Replacement attempts per numeric token.

    Returns
    -------
    (float, list of int)
        The decoded poisoned value and its token ids.
    """
    x_str = str(x)
    target = enc.encode(x_str)
    decimal = enc.encode('.')[0]
    minus = enc.encode('-')[0]
    L = len(numeric_tokens)

    # First-occurrence position of each numeric token (what
    # numeric_tokens.index(term) returned), built once per call.
    position = {}
    for i, tok in enumerate(numeric_tokens):
        position.setdefault(tok, i)

    poisoned_token = []
    for term in target:
        p = position.get(term)
        # Minus sign, decimal point and any other non-numeric token (e.g.
        # the 'e' of scientific notation) are kept unchanged.
        if term == minus or term == decimal or p is None:
            poisoned_token.append(term)
            continue

        for _ in range(iteration):
            # Draw the replacement from the opposite end of numeric_tokens.
            if p < 0.5 * L:
                p_ = _random_position(L, 0.7, 0.9)
            else:
                p_ = _random_position(L, 0.1, 0.3)
            poisoned_token.append(numeric_tokens[p_])

            # The check is on the decoded prefix built so far, not on the
            # final number.
            value = check_and_convert(enc.decode(poisoned_token))
            if value is not None and np.abs(value - x) < scale:
                break  # accept the replacement
            del poisoned_token[-1]
        else:
            # No acceptable replacement (including iteration == 0): keep the
            # original token instead of dropping it.
            poisoned_token.append(term)

    # Append extra digits only after a fractional part in plain decimal
    # notation; after an exponent they would rewrite the exponent itself.
    has_fraction = decimal in target[:-1]
    if has_fraction and 'e' not in x_str.lower() and L > 0:
        for _ in range(additem):
            poisoned_token.append(numeric_tokens[_random_position(L, 0.1, 0.9)])
            if check_and_convert(enc.decode(poisoned_token)) is None:
                del poisoned_token[-1]

    return float(enc.decode(poisoned_token)), poisoned_token


def poison_token_sequence(a, scale, enc, numeric_tokens, additem=3, iteration=5):
    """Apply ``token_attack`` to every element of *a*, in order.

    Parameters
    ----------
    a : iterable of numbers
        Values to poison (presumably a 1-D array or sequence).
    scale, enc, numeric_tokens, additem, iteration
        Forwarded unchanged to ``token_attack`` for each element.

    Returns
    -------
    numpy.ndarray
        The poisoned values, one per element of *a*. The per-element
        poisoned token lists returned by ``token_attack`` are discarded.
    """
    poisoned_values = [
        token_attack(x, scale, enc, numeric_tokens, additem, iteration)[0]
        for x in a
    ]
    return np.array(poisoned_values)