from alpha_integrate.synthetic_data.params.tokenizer_params import EXPRESSION_TOKENS
from alpha_integrate.synthetic_data.params.step_params import RULE_TOKENS

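'''
A tokenized sample has the layout
START <EXPR> SUBEXPR <SUBEXPR> RULE <RULE> PARAM1 <PARAM1> PARAM2 <PARAM2> END
where the PARAM1 section is present only when the rule has a first parameter,
and the PARAM2 section only when it also has a second one.
'''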


# Structural marker tokens that delimit the sections of a tokenized sample.
START = 'START'
SUBEXPR = 'SUBEXPR'
RULE = 'RULE'
PARAM1 = 'PARAM1'
PARAM2 = 'PARAM2'
END = 'END'
PAD = 'PAD'

# Full vocabulary: expression tokens, then rule tokens, then the markers above.
# A token's id is its position in this list.
VOCABULARY = (
    EXPRESSION_TOKENS
    + RULE_TOKENS
    + [START, SUBEXPR, RULE, PARAM1, PARAM2, END, PAD]
)
ID2WORD = dict(enumerate(VOCABULARY))
# NOTE(review): if a token occurs more than once in VOCABULARY, its later id
# wins here, so WORD2ID and ID2WORD would not be exact inverses.
WORD2ID = {word: idx for idx, word in enumerate(VOCABULARY)}

#RULE_TOKENS = [WORD2ID[w] for w in ALL_STEPS]

# print tokens and their ids
#for i, s in ID2WORD.items():
#    print(f'{i}: {s}')

def tokenize(line):
    """Tokenize one data line made of double-tab-separated fields.

    Expected fields: expression, sub-expression, rule, result. The rule field
    holds a rule token optionally followed by its parameters, separated by
    single tabs. The result field is not tokenized, but a line without it
    raises IndexError.

    Returns the ``(token_ids, start_id, end_id)`` tuple produced by
    ``_tokenize``.
    """
    fields = line.split('\t\t')
    expr, subexpr = fields[0], fields[1]
    rule_parts = fields[2].split('\t')
    # Unused, but indexing it keeps the "result field must be present" check.
    _result = fields[3]
    return _tokenize(expr, subexpr, *rule_parts)


def _tokenize(expr, subexpr, rule, param1='', param2=''):
    """Encode one sample as a list of vocabulary ids.

    Layout:
        START <expr> SUBEXPR <subexpr> RULE <rule>
        [PARAM1 <param1> [PARAM2 <param2>]] END

    ``expr``, ``subexpr``, ``param1`` and ``param2`` are whitespace-separated
    token strings; ``rule`` is looked up as a single token. The PARAM1 section
    is emitted only when ``param1`` is non-empty. The PARAM2 section is
    emitted only when both ``param1`` and ``param2`` are non-empty; otherwise
    ``param2`` is ignored.

    Returns ``(ids, start_id, end_id)``. ``start_id`` is the index of the
    SUBEXPR marker and ``end_id`` equals ``len(ids)``, so
    ``ids[start_id:end_id]`` spans SUBEXPR through END. A token missing from
    the vocabulary raises KeyError.
    """
    def encode(text):
        # Whitespace-split a token string and map each piece to its id.
        return [WORD2ID[token] for token in text.split()]

    ids = [WORD2ID[START], *encode(expr)]
    start_id = len(ids)  # position of the SUBEXPR marker appended next
    ids += [WORD2ID[SUBEXPR], *encode(subexpr), WORD2ID[RULE], WORD2ID[rule]]
    if param1 != '':
        ids += [WORD2ID[PARAM1], *encode(param1)]
        if param2 != '':
            ids += [WORD2ID[PARAM2], *encode(param2)]
    ids.append(WORD2ID[END])
    return ids, start_id, len(ids)

def detokenize(tokenized):
    """Map a sequence of vocabulary ids back to their token strings.

    Returns a list with one token string per id; an id that is not in the
    vocabulary raises KeyError.
    """
    lookup = ID2WORD.__getitem__
    return list(map(lookup, tokenized))



