import sympy as sp
from typing import List, Tuple
from alpha_integrate.synthetic_data.exceptions import InvalidPrefixExpression, UnknownSympyExpression, InfinityError
from alpha_integrate.synthetic_data.params.tokenizer_params import *
from alpha_integrate.synthetic_data.timeout import timeout

class ExpressionTokenizer:

    def __init__(self):
        '''
        Create a tokenizer; no instance attributes are set here
        '''

    @staticmethod
    def number_to_seq(number: sp.Integer) -> List[str]:
        '''
        Spell an integer as a sign token followed by one token per digit

        The first token is INTMINUS for negative values and INTPLUS otherwise
        (zero included). The characters of str() applied to the absolute
        value follow, in order.
        '''

        if number < 0:
            sign_token, magnitude = INTMINUS, -number
        else:
            sign_token, magnitude = INTPLUS, number

        # Unpacking the string yields its characters one by one
        return [sign_token, *str(magnitude)]
    
    @staticmethod
    def rational_to_seq(rational: sp.Rational) -> List[str]:
        '''
        Spell a sympy rational as RATIONAL, numerator tokens, denominator tokens

        Numerator (`rational.p`) and denominator (`rational.q`) are each
        encoded with `number_to_seq`, so both carry their own sign token.
        '''

        numerator_tokens = ExpressionTokenizer.number_to_seq(rational.p)
        denominator_tokens = ExpressionTokenizer.number_to_seq(rational.q)

        return [RATIONAL, *numerator_tokens, *denominator_tokens]


    @staticmethod
    @timeout(1)
    def sp_to_seq(sp_expr: sp.Expr) -> List[str]:
        '''
        Traverse the sympy expression and convert it to a list of tokens

        The result is in prefix order: every operator token is followed by the
        tokens of its arguments. Operators with more than two arguments are
        written as a right-nested chain of two-argument applications.

        Raises:
            InfinityError: if sp.oo or sp.zoo occurs anywhere in the expression.
            InvalidPrefixExpression: if an integral carries a limit that has
                more than just the integration variable (e.g. bounds).
            UnknownSympyExpression: if a node matches none of the supported
                kinds.

        A symbol whose name is not a key of SYMBOLS fails in the SYMBOLS lookup
        (a KeyError if SYMBOLS is a plain dict).
        '''

        # `has` already searches every sub-expression, so a single check on the
        # root covers the whole tree; there is no need to repeat it per node.
        if sp_expr.has(sp.oo) or sp_expr.has(sp.zoo):
            raise InfinityError(f'Infinity in expression {sp_expr} is not supported.')

        # Recursion goes through the undecorated helper so that the timeout
        # decorator is entered once per conversion, not once per sub-expression.
        # NOTE(review): assumes `timeout(1)` bounds one call - confirm in
        # alpha_integrate.synthetic_data.timeout.
        return ExpressionTokenizer._sp_to_seq(sp_expr)

    @staticmethod
    def _sp_to_seq(sp_expr: sp.Expr) -> List[str]:
        '''
        Recursive worker of sp_to_seq: tokenize one node and its sub-expressions

        Assumes the caller has already rejected expressions containing
        infinities.
        '''

        to_seq = ExpressionTokenizer._sp_to_seq

        if isinstance(sp_expr, sp.Symbol):
            return [SYMBOLS[sp_expr.name]]

        # Integer is tested before Rational because sympy integers are also rationals
        if isinstance(sp_expr, sp.Integer):
            return ExpressionTokenizer.number_to_seq(sp_expr)

        if isinstance(sp_expr, sp.Rational):
            return ExpressionTokenizer.rational_to_seq(sp_expr)

        if isinstance(sp_expr, sp.Integral):
            # args[0] is the integrand, every further arg is one limit tuple
            limits = sp_expr.args[1:]
            # One INTEGRAL token per limit, then the integrand, then the variables
            seq = len(limits) * [INTEGRAL] + to_seq(sp_expr.args[0])
            for limit in limits:
                if len(limit) > 1:
                    raise InvalidPrefixExpression(f'Integral with multiple variables {sp_expr} is not supported.')
                seq += to_seq(limit[0])

            return seq

        if sp_expr == sp.E:
            return [EXP]
        if sp_expr == sp.pi:
            return [PI]
        if sp_expr == sp.I:
            return [I]

        # Single-argument operators: token followed by the first argument
        for operators in (UNARY_OPERATORS, SPECIAL_FUNCTIONS):
            for sympy_operator, token in operators.items():
                if isinstance(sp_expr, sympy_operator):
                    return [token] + to_seq(sp_expr.args[0])

        # Multi-argument operators: the token precedes every argument except
        # the last one, e.g. three arguments give [tok, a, tok, b, c].
        # `i == 0` keeps the token for a node that has a single argument.
        for operators in (SPECIAL_FUNCTIONS2, BINARY_OPERATORS):
            for sympy_operator, token in operators.items():
                if isinstance(sp_expr, sympy_operator):
                    last = len(sp_expr.args) - 1
                    seq = []
                    for i, arg in enumerate(sp_expr.args):
                        if i == 0 or i < last:
                            seq.append(token)
                        seq += to_seq(arg)

                    return seq

        raise UnknownSympyExpression
    
    @staticmethod
    def write_infix(token: str, *args: str) -> str:
        '''
        Write an infix expression given an operator and its arguments

        Args:
            token: operator, function, symbol or constant token.
            *args: already rendered argument strings, in order.

        Returns:
            `(a) token (b)` for binary operators and RATIONAL, `token(a)` for
            unary operators and special functions, `Integral(a, b)` for
            INTEGRAL, `token(a, b, ...)` for two-or-more-argument special
            functions, and the token itself for symbols and the constants
            EXP, PI and I.

        Raises:
            InvalidPrefixExpression: if the number of arguments does not fit
                the token, or if the token is not recognised at all.
        '''

        if token in BINARY_OPERATORS.values() or token == RATIONAL:
            if len(args) != 2:
                raise InvalidPrefixExpression(f'Binary operator {token} requires 2 arguments.')
            return f'({args[0]}) {token} ({args[1]})'

        if token in UNARY_OPERATORS.values() or token in SPECIAL_FUNCTIONS.values():
            if len(args) != 1:
                raise InvalidPrefixExpression(f'Unary operator {token} requires 1 argument.')
            return f'{token}({args[0]})'

        if token in SYMBOLS.values():
            return token

        if token == INTEGRAL:
            if len(args) != 2:
                raise InvalidPrefixExpression(f'Integral operator {token} requires 2 arguments.')
            return f'Integral({args[0]}, {args[1]})'

        if token in SPECIAL_FUNCTIONS2.values():
            if len(args) < 2:
                raise InvalidPrefixExpression(f'Special function {token} requires at least 2 arguments.')
            return f'{token}({", ".join(args)})'

        # The constants are leaves just like symbols (the sequence parser treats
        # them that way too). Checked after everything else so tokens that were
        # already handled above keep their previous result.
        if token in (EXP, PI, I):
            return token

        # Previously an unmatched token fell off the end and returned None,
        # which contradicts the declared `str` return type.
        raise InvalidPrefixExpression(f'Unknown token {token}.')

    
    @staticmethod
    def _seq_to_string(seq_expr: List[str]) -> Tuple[str, List[str]]:
        '''
        Render the first complete expression of a prefix token list

        Returns a pair: the infix string of the expression that starts at
        seq_expr[0], and the tokens left over after it. Raises
        InvalidPrefixExpression for an empty list, an operator without
        arguments, a sign token without digits, or an unknown leading token.
        '''

        if len(seq_expr) < 1:
            raise InvalidPrefixExpression("Empty prefix list.")

        head, tail = seq_expr[0], seq_expr[1:]

        takes_one = head in UNARY_OPERATORS.values() or head in SPECIAL_FUNCTIONS.values()
        takes_two = (
            head in BINARY_OPERATORS.values()
            or head in [RATIONAL, INTEGRAL]
            or head in SPECIAL_FUNCTIONS2.values()
        )

        # Operators: consume one or two sub-expressions from the tail
        if takes_one or takes_two:
            if len(tail) == 0:
                raise InvalidPrefixExpression(f'Operator {head} in {seq_expr} has no arguments.')

            operands = []
            rest = tail
            for _ in range(1 if takes_one else 2):
                operand, rest = ExpressionTokenizer._seq_to_string(rest)
                operands.append(operand)

            return ExpressionTokenizer.write_infix(head, *operands), rest

        # Symbols and constants stand for themselves
        if head in SYMBOLS.values() or head in (EXP, PI, I):
            return f'{head}', tail

        # Integers: a sign token followed by a run of digit tokens
        if head == INTPLUS or head == INTMINUS:
            prefix = '-' if head != INTPLUS else ''

            digits = []
            for candidate in tail:
                if candidate not in DIGITS.values():
                    break
                digits.append(candidate)

            if not digits:
                raise InvalidPrefixExpression(f'Invalid integer representation {seq_expr}.')

            return prefix + ''.join(digits), tail[len(digits):]

        raise InvalidPrefixExpression(f'Unknown token {head} in {seq_expr}.')
        
    @staticmethod
    def seq_to_string(seq_expr: List[str]) -> str:
        '''
        Render a complete prefix token list as one infix string

        The whole list must form exactly one expression; tokens left over
        after it raise InvalidPrefixExpression.
        '''

        rendered, leftover = ExpressionTokenizer._seq_to_string(seq_expr)

        if len(leftover) == 0:
            return rendered

        raise InvalidPrefixExpression(f'Invalid prefix list. Remaining: {leftover}')
    
    @staticmethod
    @timeout(1)
    def seq_to_sp(seq_expr: List[str]) -> sp.Expr:
        '''
        Build a sympy expression from its token sequence

        The sequence is first rendered to an infix string with seq_to_string,
        and that string is then parsed by sp.sympify.
        '''

        infix_text = ExpressionTokenizer.seq_to_string(seq_expr)
        return sp.sympify(infix_text)