from typing import List, Tuple
from alpha_integrate.synthetic_data.exceptions import InvalidPrefixExpression, ImaginaryUnitException
import sympy as sp
from sympy.core.parameters import evaluate
from alpha_integrate.synthetic_data.timeout import timeout

#import time
from alpha_integrate.synthetic_data.params.fb_params import OPERATORS, SYMBOLS, CONSTANTS, OPERATOR_TO_SYMPY

class TokensToSympy:
    """
    Convert prefix-notation token sequences into sympy expressions.

    Token meaning comes from the tables imported from ``fb_params``:
    ``OPERATORS`` maps an operator token to its arity, ``SYMBOLS`` and
    ``CONSTANTS`` map leaf tokens to the values returned for them, and
    ``OPERATOR_TO_SYMPY`` maps further operator tokens to callables that are
    applied to the first argument only. Integers are encoded as an
    ``INT<sign>`` token followed by one or more digit tokens.
    Every method is static; the class only serves as a namespace.
    """

    def __init__(self):
        pass

    @staticmethod
    def to_sympy(token: str, args: List[sp.Expr], eval_bool: bool = False) -> sp.Expr:
        """
        Build the sympy node for one operator token applied to its arguments.

        Args:
            token: operator token: 'add', 'sub', 'mul', 'div', 'pow', 'sqrt',
                or any key of OPERATOR_TO_SYMPY (called with ``args[0]`` only).
            args: already-converted operand expressions, in prefix order.
            eval_bool: flag passed to sympy's ``evaluate`` context manager; with
                the default False the node is built without sympy's automatic
                evaluation, so the tree mirrors the token sequence.

        Returns:
            The sympy expression for ``token`` applied to ``args``.

        Raises:
            InvalidPrefixExpression: if ``token`` is not a known operator.
        """
        with evaluate(eval_bool):
            if token == 'add':
                return sp.Add(args[0], args[1])
            if token == 'sub':
                # a - b is represented as a + (-1) * b
                return sp.Add(args[0], sp.Mul(args[1], -1))
            if token == 'mul':
                return sp.Mul(args[0], args[1])
            if token == 'div':
                # a / b is represented as a * b**(-1)
                return sp.Mul(args[0], sp.Pow(args[1], -1))
            if token == 'pow':
                return sp.Pow(args[0], args[1])
            if token == 'sqrt':
                return sp.Pow(args[0], sp.Rational(1, 2))
            if token in OPERATOR_TO_SYMPY:
                return OPERATOR_TO_SYMPY[token](args[0])
            # Raise (not return) so an unknown token can never end up inside
            # the expression tree or be mistaken for a valid result.
            raise InvalidPrefixExpression(f"Unknown token in prefix expression: {token}, with arguments {args}")

    @staticmethod
    def parse_int(seq_expr: List[str]) -> Tuple[sp.Integer, int]:
        '''
        Parse an integer from a sequence whose first token is INT<sign>.

        The first token carries the sign as its 4th character ('+' or '-');
        it is followed by one or more digit tokens whose concatenation is the
        magnitude, e.g. ['INT-', '1', '2', ...] -> -12.

        Returns:
            (value, consumed): the parsed sympy Integer and the number of
            tokens consumed, counting the leading INT token.

        Raises:
            InvalidPrefixExpression: if the sign is missing or not '+'/'-',
                or if no digit token follows the INT token.
        '''
        try:
            sign = seq_expr[0][3]
        except (IndexError, TypeError) as err:
            raise InvalidPrefixExpression(f'Invalid integer representation: {TokensToSympy.prefix_as_string(seq_expr)}') from err
        if sign not in ('+', '-'):
            raise InvalidPrefixExpression(f'Invalid integer representation: {TokensToSympy.prefix_as_string(seq_expr)}')

        # isdecimal (not isdigit): isdigit also accepts e.g. superscript
        # digits, which int()/sp.Integer cannot parse.
        index = 1
        while index < len(seq_expr) and seq_expr[index].isdecimal():
            index += 1

        if index == 1:
            raise InvalidPrefixExpression(f'Invalid integer representation: {TokensToSympy.prefix_as_string(seq_expr)}')

        digits = ''.join(seq_expr[1:index])
        # Return the token count, not len(digits): a digit token may hold
        # more than one character.
        return sp.Integer(sign + digits), index

    @staticmethod
    def _seq_to_sp_direct(seq_expr: List[str]) -> Tuple[sp.Expr, List[str]]:
        '''
        Recursively parse one prefix expression from the front of a token list.

        Returns:
            (expr, remaining): the sympy expression for the leading
            sub-expression and the list of tokens that follow it.

        Raises:
            InvalidPrefixExpression: on an empty list, an operator without
                enough arguments, a malformed integer, or an unknown token.
            ImaginaryUnitException: if the imaginary-unit token 'I' is reached.
        '''
        if len(seq_expr) == 0:
            raise InvalidPrefixExpression("Empty prefix list.")

        t = seq_expr[0]
        if t in OPERATORS:
            if len(seq_expr) == 1:
                raise InvalidPrefixExpression(f'Operator {t} has no arguments.')

            # Consume exactly OPERATORS[t] sub-expressions, left to right.
            args = []
            rest = seq_expr[1:]
            for _ in range(OPERATORS[t]):
                arg, rest = TokensToSympy._seq_to_sp_direct(rest)
                args.append(arg)

            return TokensToSympy.to_sympy(t, args), rest

        if t in SYMBOLS:
            return SYMBOLS[t], seq_expr[1:]

        if t in CONSTANTS:
            return CONSTANTS[t], seq_expr[1:]

        if t.startswith('INT'):
            val, consumed = TokensToSympy.parse_int(seq_expr)
            return val, seq_expr[consumed:]

        if t == 'I':
            raise ImaginaryUnitException('Imaginary unit not supported.')

        raise InvalidPrefixExpression(f'Invalid object {t} in prefix list.')

    @staticmethod
    @timeout(1)
    def seq_to_sp_direct(seq_expr: List[str]) -> sp.Expr:
        '''
        Convert a complete prefix token sequence into a sympy expression.

        The call is wrapped by the project's ``timeout(1)`` decorator
        (presumably a 1-second limit; see that module for exact behavior).

        Raises:
            InvalidPrefixExpression: if the sequence is malformed or tokens are
                left over after one full expression has been parsed.
            ImaginaryUnitException: if the imaginary-unit token 'I' appears.
        '''
        sp_expr, remaining = TokensToSympy._seq_to_sp_direct(seq_expr)
        if len(remaining) != 0:
            raise InvalidPrefixExpression(f'Invalid prefix list. Remaining: {TokensToSympy.prefix_as_string(remaining)}')
        return sp_expr

    @staticmethod
    def prefix_as_string(seq_expr: List[str]) -> str:
        '''
        Render a token sequence as text for error messages.

        Each token is converted with str() and followed by one space, so a
        non-empty result ends with a trailing space and an empty input
        yields ''.
        '''
        return ''.join(str(t) + ' ' for t in seq_expr)