import sympy as sp
from alpha_integrate.synthetic_data.params.step_params import *
from typing import List, Tuple
from alpha_integrate.synthetic_data.method import *
from itertools import product
import timeout_decorator

@timeout_decorator.timeout(2)
def apply_step(step: Tuple[sp.Expr, sp.Expr, Tuple]) -> List[sp.Expr]:

    '''
    Apply one rewriting step and return every expression it can produce.

    ``step`` is a tuple ``(expr, subexpr, rule)``:
      * ``expr``    - the full expression being rewritten,
      * ``subexpr`` - the subexpression the rule is applied to,
      * ``rule``    - a rule tuple ``(rule_name, *args)`` as accepted by ``apply_rule``.

    When the rule applies and ``subexpr`` occurs in ``expr``, the list built by
    ``replace_subexpr`` is returned: the unmodified expression plus every variant in
    which occurrences of ``subexpr`` are replaced by the rule's result. Otherwise
    ``[expr]`` is returned.

    The ``timeout_decorator.timeout(2)`` wrapper aborts calls running longer than
    2 seconds; the library raises its ``TimeoutError`` in that case.
    NOTE(review): timeout_decorator uses signals by default, which only work in the
    main thread - confirm callers do not invoke this from worker threads.
    '''

    expr, subexpr, rule = step

    rewritten, applied = apply_rule(rule, subexpr)
    if not applied:
        return [expr]

    candidates, found = replace_subexpr(expr, subexpr, rewritten)
    # replace_subexpr reports False when subexpr does not occur in expr.
    return candidates if found else [expr]



# Rules whose method class takes no constructor arguments: rule name -> method class.
# Resolved at import time, so a missing class name fails at import rather than on first use.
_NO_ARG_RULE_METHODS = {
    'ConstantRule': ConstantMethod,
    'PowerRule': PowerMethod,
    'ExpRule': ExpMethod,
    'ConstantTimesRule': ConstantTimesMethod,
    'ReciprocalRule': ReciprocalMethod,
    'NestedPowRule': NestedPowMethod,
    'ArcsinRule': ArcsinMethod,
    'ArcsinhRule': ArcsinhMethod,
    'SinRule': SinMethod,
    'CosRule': CosMethod,
    'SecTanRule': SecTanMethod,
    'CscCotRule': CscCotMethod,
    'Sec2Rule': Sec2Method,
    'Csc2Rule': Csc2Method,
    'SinhRule': SinhMethod,
    'CoshRule': CoshMethod,
    'ArctanRule': ArctanMethod,
    'ReciprocalSqrtQuadraticRule': ReciprocalSqrtQuadraticMethod,
    'CiRule': CiMethod,
    'EiRule': EiMethod,
    'UpperGammaRule': UpperGammaMethod,
    'AddRule': AddMethod,
    'PartialFractions': PartialFractionsMethod,
    'Cancel': CancelMethod,
    'Expand': ExpandMethod,
    'Tan1': Tan1Method,
    'Cot1': Cot1Method,
    'Cos1': Cos1Method,
    'Sec1': Sec1Method,
    'Csc1': Csc1Method,
    'Tanh1': Tanh1Method,
    'Coth1': Coth1Method,
    'Sech1': Sech1Method,
    'Csch1': Csch1Method,
    'TrigExpand': TrigExpandMethod,
    'SinCosEven': SinCosEvenMethod,
    'SinOddCos': SinOddCosMethod,
    'CosOddSin': CosOddSinMethod,
    'SecEvenTan': SecEvenTanMethod,
    'TanOddSec': TanOddSecMethod,
    'Tan2': Tan2Method,
    'CotCscEven': CotCscEvenMethod,
    'CotOddCsc': CotOddCscMethod,
}


def apply_rule(rule: Tuple, subexpr: sp.Expr) -> Tuple[sp.Expr, bool]:

    '''
    Apply a single named rule to ``subexpr``.

    ``rule`` is a tuple ``(rule_name, *args)``. Extra ``args`` are only used by
    ``'URule'`` (``u_rule, u_func``) and ``'PartsRule'`` (``u, dv``); every other rule
    ignores them.

    Returns ``(result, True)`` when the rule's method returns a non-None result, and
    ``(subexpr, False)`` when it returns None or ``rule_name`` is not recognised.

    Raises ValueError if ``'URule'`` or ``'PartsRule'`` is not given exactly two args.
    '''

    rule_name, *args = rule

    if rule_name in _NO_ARG_RULE_METHODS:
        method = _NO_ARG_RULE_METHODS[rule_name]()
    elif rule_name == 'URule':
        u_rule, u_func = args
        method = UMethod(u_rule, u_func)
    elif rule_name == 'PartsRule':
        u, dv = args
        method = PartsMethod(u, dv)
    else:
        # Unknown rule: report "not applied" rather than raising.
        return subexpr, False

    res = method.apply(subexpr)
    # A None result from a method means the rule did not apply to subexpr.
    if res is None:
        return subexpr, False
    return res, True



def generate_combinations(lists: List[List[sp.Expr]]) -> List[Tuple[sp.Expr, ...]]:

    '''
    Return the Cartesian product of the input lists.

    Each combination is a tuple with one element from every input list, in input
    order, e.g. ``[[a, b], [c]] -> [(a, c), (b, c)]``. An empty outer list yields
    ``[()]``, and any empty inner list yields ``[]``.
    '''

    return list(product(*lists))
    

def replace_subexpr(expr: sp.Expr, subexpr: sp.Expr, replacement: sp.Expr) -> Tuple[List[sp.Expr], bool]:

    '''
    Enumerate the ways of substituting ``replacement`` for ``subexpr`` inside ``expr``.

    Each occurrence of ``subexpr`` (matched with ``==`` while walking ``.args``) may
    independently be replaced or left alone. Every combination is returned, so the list
    always includes the (rebuilt) unmodified expression as well. The boolean is True when
    at least one occurrence was found; when it is False the list holds only ``expr``.

    Iterated integrals get one extra kind of match. Suppose ``subexpr`` is an integral
    with the same integrand as ``expr``, and its limits are the leading (innermost) limits
    of ``expr``. Then ``expr == Integral(subexpr, *remaining_limits)``, so ``subexpr`` is
    replaced and only the remaining outer limits are kept.
    '''

    if expr == subexpr:
        return [expr, replacement], True

    if isinstance(expr, sp.Integral) and isinstance(subexpr, sp.Integral):
        # Integral args are (integrand, limit_1, ..., limit_m); limit_1 is innermost.
        inner_limits = subexpr.args[1:]
        n_inner = len(inner_limits)
        outer_limits = expr.args[1 + n_inner:]
        # Integral(f, l1..lm) == Integral(Integral(f, l1..ln), l(n+1)..lm): drop *all*
        # n inner limits, not just the first one, or they would be integrated twice.
        if (
            expr.args[0] == subexpr.args[0]
            and outer_limits
            and expr.args[1:1 + n_inner] == inner_limits
        ):
            return [expr, sp.Integral(replacement, *outer_limits)], True

    if expr.is_Atom:
        return [expr], False

    # For each argument, the alternatives it may take ([arg] when nothing matched inside it).
    options = []
    for arg in expr.args:
        variants, found = replace_subexpr(arg, subexpr, replacement)
        options.append(variants if found else [arg])

    # Rebuild expr for every combination of per-argument alternatives.
    results = [expr.func(*combo) for combo in generate_combinations(options)]
    return results, len(results) > 1
    

def is_subexpr(expr: sp.Expr, subexpr: sp.Expr):

    '''
    Report whether ``subexpr`` appears anywhere in the expression tree of ``expr``.

    ``expr`` itself counts as a match. Nodes are compared with ``==``; the search
    descends through ``.args`` depth-first, left to right, and stops at atoms.
    Returns True on the first match, False otherwise.
    '''

    if expr == subexpr:
        return True
    # Atoms have no children left to search.
    if expr.is_Atom:
        return False
    return any(is_subexpr(child, subexpr) for child in expr.args)


if __name__ == "__main__":
    # Intentionally empty: running this file directly performs no work.
    ...

