import sympy as sp
from sympy.core.symbol import Wild
from sympy.core.singleton import S
from sympy.functions.elementary.trigonometric import TrigonometricFunction
from sympy.integrals.manualintegrate import integral_steps
from sympy.integrals.manualintegrate import *

from alpha_integrate.synthetic_data.timeout import timeout
from alpha_integrate.synthetic_data.params.step_params import *
from alpha_integrate.synthetic_data.params.tokenizer_params import SYMBOLS
from alpha_integrate.synthetic_data.int_steps import int_steps

import random

# Build a fixed set of Wild symbols in one place so pattern-matching code does not have to repeat these definitions
def get_wilds(symbol):
    """Return the Wild symbols ``(a, b, c, d, e, f, n)`` used for pattern matching.

    ``a`` through ``e`` cannot match anything containing ``symbol``, ``f`` is
    unrestricted, and ``n`` only matches odd integers.
    """
    coefficient_wilds = [Wild(label, exclude=[symbol]) for label in ('a', 'b', 'c', 'd', 'e')]
    anything = Wild('f')
    odd_integer = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd])
    return (*coefficient_wilds, anything, odd_integer)

# patterns copied from sympy.integrals.manualintegrate

def make_wilds(symbol):
    """Return the Wild symbols ``(a, b, m, n)`` used by the trig power patterns.

    ``a`` and ``b`` stand for coefficients free of ``symbol``; ``m`` and ``n``
    stand for exponents that are free of ``symbol`` and are ``sp.Integer``
    instances.
    """
    coeff_a = Wild('a', exclude=[symbol])
    coeff_b = Wild('b', exclude=[symbol])
    exp_m, exp_n = (
        Wild(label, exclude=[symbol], properties=[lambda expr: isinstance(expr, sp.Integer)])
        for label in ('m', 'n')
    )
    return coeff_a, coeff_b, exp_m, exp_n


def sincos_pattern(symbol):
    """Return ``sin(a*symbol)**m * cos(b*symbol)**n`` followed by its Wilds ``a, b, m, n``."""
    wa, wb, wm, wn = make_wilds(symbol)
    sine_part = sp.sin(wa * symbol) ** wm
    cosine_part = sp.cos(wb * symbol) ** wn
    return sine_part * cosine_part, wa, wb, wm, wn

def tansec_pattern(symbol):
    """Return ``tan(a*symbol)**m * sec(b*symbol)**n`` followed by its Wilds ``a, b, m, n``."""
    wa, wb, wm, wn = make_wilds(symbol)
    tangent_part = sp.tan(wa * symbol) ** wm
    secant_part = sp.sec(wb * symbol) ** wn
    return tangent_part * secant_part, wa, wb, wm, wn

def cotcsc_pattern(symbol):
    """Return ``cot(a*symbol)**m * csc(b*symbol)**n`` followed by its Wilds ``a, b, m, n``."""
    wa, wb, wm, wn = make_wilds(symbol)
    cotangent_part = sp.cot(wa * symbol) ** wm
    cosecant_part = sp.csc(wb * symbol) ** wn
    return cotangent_part * cosecant_part, wa, wb, wm, wn

@timeout(4)
def decompose_steps(expression: sp.Expr, symbol: sp.Symbol, variable_list: list = None, steps = None) -> list:

    '''
    Decompose the integration of ``expression`` with respect to ``symbol`` into
    a flat list of individual rule applications.

    Each element of the returned list is a 4-tuple
    ``(integral, subexpression, rule, result)``:
    - ``integral``: the whole expression before the rule is applied,
    - ``subexpression``: the part of ``integral`` at which the rule is applied,
    - ``rule``: a tuple whose first entry is the rule name, optionally followed
      by rule parameters (the new symbol and substitution for a URule, ``u``
      and ``dv`` for a PartsRule),
    - ``result``: the whole expression after the rule is applied.

    Parameters
    ----------
    expression: the integrand.
    symbol: the integration variable.
    variable_list: list of ``(name, expression)`` pairs recording the variables
        in use and the substitutions introduced by change-of-variable steps.
        It is shared with, and appended to by, the recursive calls. If omitted
        or empty it is initialised from the free symbols of ``expression``.
    steps: a precomputed rule tree for ``expression``; when ``None`` it is
        computed with ``int_steps``.

    Returns
    -------
    The list of step tuples, or ``[None]`` when the integration could not be
    decomposed (unsupported or unrecognised rule, unresolved rewrite, an
    exception while computing the steps, infinities in the integrand, ...).
    '''

    tokenizable_steps = []

    integral = sp.Integral(expression, symbol)

    # A fresh list per call that omits it: a shared mutable default would leak
    # variables and u-substitutions between unrelated calls.
    if variable_list is None:
        variable_list = []

    if len(variable_list) == 0:
        for s in expression.free_symbols:
            variable_list.append((s.name, s))

    # Cheap rejection first, before paying for int_steps.
    if expression.has(sp.oo) or expression.has(sp.zoo):
        return [None]

    if steps is None:
        try:
            steps = int_steps(expression, symbol)
        except Exception:
            return [None]

    if steps.contains_dont_know():
        return [None]

    if isinstance(steps, (CyclicPartsRule, CompleteSquareRule)):
        # these rules are not supported by the decomposition
        return [None]

    elif isinstance(steps, SIMPLE_STEPS) or isinstance(steps, SPECIAL_SIMPLE_STEPS):
        # leaf rules: one step evaluates the whole integral
        tokenizable_steps.append((integral, integral, (steps.__class__.__name__,), steps.eval()))

    elif isinstance(steps, AddRule):
        # split into a sum of integrals, then integrate the terms one by one
        integrand = expression
        symbol = steps.variable
        integral_terms = integrand.as_ordered_terms()
        integrals = [sp.Integral(term, symbol) for term in integral_terms]
        old_integral = integral
        integral = sp.Add(*integrals)
        tokenizable_steps.append((old_integral, old_integral, (steps.__class__.__name__,), integral))

        substeps = steps.substeps

        if len(integral_terms) != len(substeps):
            # in some cases it doesn't work (e.g (sqrt(x) + 2)/sqrt(x)), just ignore for now
            return [None]

        for i, substep in enumerate(substeps):
            decomposed_substep = decompose_steps(integral_terms[i], symbol, variable_list, substep)
            if None in decomposed_substep:
                return [None]
            for _, e_subexpr, e_rule, e_res in decomposed_substep:
                old_integral = integral
                # only the i-th term is replaced by its latest partial result
                integrals[i] = e_res
                integral = sp.Add(*integrals)
                tokenizable_steps.append((old_integral, e_subexpr, e_rule, integral))
            # if the integral has no sp.Integral in it, we can break
            if not integral.has(sp.Integral):
                break

    elif isinstance(steps, ConstantTimesRule):
        integrand = expression
        symbol = steps.variable
        constant = steps.constant
        other = steps.other
        old_integral = integral
        substep = steps.substep
        integral = constant*sp.Integral(other, symbol)

        # there are some cases where the further integrand is the same as the original integrand
        # in that case we don't want to add the step
        if integrand != substep.integrand:
            tokenizable_steps.append((old_integral, old_integral, (steps.__class__.__name__, ), integral))

        decomposed_substep = decompose_steps(other, symbol, variable_list, substep)
        if None in decomposed_substep:
            return [None]

        for _, e_subexpr, e_rule, e_res in decomposed_substep:
            old_integral = integral
            integral = constant*e_res
            tokenizable_steps.append((old_integral, e_subexpr, e_rule, integral))

    elif isinstance(steps, URule):
        u_var = steps.u_var
        u_func = steps.u_func

        # Names already taken, either by a free symbol of the integral or by an
        # earlier change of variables; the new symbol must avoid all of them.
        # Neither list changes inside the search loop, so build them once.
        free_symbol_names = [phi.name for phi in integral.free_symbols]
        variable_list_names = [varchange[0] for varchange in variable_list]

        new_symbol = None
        for s in SYMBOLS.keys():
            candidate = sp.Symbol(s)
            if not (candidate.name in variable_list_names or candidate.name in free_symbol_names):
                new_symbol = candidate
                u_func = u_func.subs(u_var, new_symbol)
                variable_list.append((new_symbol.name, u_func))
                break

        if new_symbol is None:
            # every symbol name in SYMBOLS is already taken
            return [None]

        substep = steps.substep

        old_integral = integral
        new_integrand = substep.integrand.subs(u_var, new_symbol)
        integral = sp.Integral(new_integrand, new_symbol)

        tokenizable_steps.append((old_integral, old_integral, (steps.__class__.__name__, new_symbol, u_func), integral))

        # the rule tree is recomputed for the integrand in the new variable
        decomposed_substep = decompose_steps(new_integrand, new_symbol, variable_list)
        if None in decomposed_substep:
            return [None]

        tokenizable_steps += decomposed_substep

    elif isinstance(steps, PartsRule):
        u = steps.u
        dv = steps.dv
        v_step = steps.v_step
        second_step = steps.second_step
        symbol = steps.variable

        du = sp.diff(u, symbol)
        old_integral = integral

        v = sp.Integral(dv, symbol)
        integral_part = sp.Integral(v*du, symbol)
        integral = u*v - integral_part

        tokenizable_steps.append((old_integral, old_integral, (steps.__class__.__name__, u, dv), integral))

        # first integrate dv to obtain v
        decomposed_v_step = decompose_steps(dv, symbol, variable_list, v_step)
        if None in decomposed_v_step:
            return [None]

        for _, e_subexpr, e_rule, e_res in decomposed_v_step:
            old_integral = integral
            integral = u*e_res - sp.Integral(e_res*du, symbol)
            tokenizable_steps.append((old_integral, e_subexpr, e_rule, integral))
            v = e_res

        # v step might have had some change of variables, we would like to replace these back
        # iterate through variable list in reverse order and replace the variables
        # until we reach symbol or beginning of the list
        for varchange in variable_list[::-1]:
            varname, func = varchange
            s = sp.Symbol(varname)
            if s == symbol:
                break
            v = v.subs(s, func)

        # then integrate v*du (rule tree recomputed from scratch)
        decomposed_next_steps = decompose_steps(v*du, symbol, variable_list)
        if None in decomposed_next_steps:
            return [None]
        for _, e_subexpr, e_rule, e_res in decomposed_next_steps:
            old_integral = integral
            integral = u*v - e_res
            tokenizable_steps.append((old_integral, e_subexpr, e_rule, integral))

    elif isinstance(steps, AlternativeRule):
        # Try the alternatives in random order and keep the first one that
        # decomposes. Shuffle a copy so the caller's rule tree is not mutated.
        # NOTE(review): a failed alternative may already have appended entries
        # to variable_list (see the URule branch).
        alternatives = list(steps.alternatives)
        random.shuffle(alternatives)

        for alternative in alternatives:
            decomposed_alternative = decompose_steps(expression, symbol, variable_list, alternative)
            if None not in decomposed_alternative:
                tokenizable_steps += decomposed_alternative
                break
        else:
            # no alternative could be decomposed
            return [None]

    elif isinstance(steps, RewriteRule):
        # Identify which kind of rewrite was performed, record it as a named
        # step, then continue with the rewritten integrand.
        integrand = expression
        symbol = steps.variable
        rewritten = steps.rewritten
        args = integrand.args

        rewritten_integral = sp.Integral(rewritten, symbol)

        # we check if the length increased at the end to see if we could resolve the rewrite
        curr_len = len(tokenizable_steps)
        rw_wait = True

        # get some substitutions for the integrand for later use: reciprocal trig
        # functions are expressed through sec / csc / cot so that the tan-sec and
        # cot-csc patterns below can match them
        integrand_sub1 = integrand.subs({1 / sp.cos(symbol): sp.sec(symbol)})
        integrand_sub2 = integrand.subs({1 / sp.sin(symbol): sp.csc(symbol), 1 / sp.tan(symbol): sp.cot(symbol), sp.cos(symbol) / sp.tan(symbol): sp.cot(symbol) })

        # first try to see if it matches some patterns for quadratic denominators with/out square roots
        # for now we omit this part
        """
        a,b,c,d,e,f,n = get_wilds(symbol)
        match1 = integrand.match(a / (b * symbol ** 2 + c))

        if rw_wait and match1:
            a,b,c = match1[a], match1[b], match1[c]

            if b.is_extended_real and c.is_extended_real:
                positive_cond = c/b > 0
                if positive_cond is S.false:
                    coeff = a/(2*sp.sqrt(-c)*sp.sqrt(b))
                    constant = sp.sqrt(-c/b)
                    r1 = 1/(symbol-constant)
                    r2 = 1/(symbol+constant)
                    sub = r1 - r2
                    rw = Mul(coeff, sub, evaluate=False) if coeff != 1 else sub
                    if rewritten == rw:
                        rule = 'QuadraticDenom1'
                        rw_wait = False
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))   

        a,b,c,d,e,f,n = get_wilds(symbol)
        match2 = integrand.match((a* symbol + b) / (c * symbol ** 2 + d * symbol + e))

        if rw_wait and match2:
            a, b, c, d, e = match2[a], match2[b], match2[c], match2[d], match2[e]
            if not c.is_zero:
                denominator = c * symbol**2 + d * symbol + e
                const =  a/(2*c)
                numer1 =  (2*c*symbol+d)
                numer2 = - const*d + b

                if rewritten == const*numer1/denominator+numer2/denominator:
                    rule = 'QuadraticDenom2'
                    rw_wait = False
                    tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                
        a,b,c,d,e,f,n = get_wilds(symbol)
        match3 = integrand.match(f*sqrt(a+b*symbol+c*symbol**2)**n)

        if rw_wait and match3:
            a, b, c, f, n = match3[a], match3[b], match3[c], match3[f], match3[n]
            f_poly = f.as_poly(symbol)
            if f_poly is not None:
                if n == -1:
                    numer_poly = f_poly 
                    denom = sp.sqrt(a+b*symbol+c*symbol**2)
                    deg = numer_poly.degree()
                    if deg <= 1:
                        e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr())
                        A = e/(2*c)
                        B = d-A*b
                        pre_substitute = (2*c*symbol+b)/denom
                        if A != 0 and B!= 0 and rewritten == Add(A*pre_substitute, B/denom, evaluate=False):
                            rule = 'SqrtQuadraticDenom1'
                            rw_wait = False
                            tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                    
                elif n > 0:
                    numer_poly = f_poly * (a+b*symbol+c*symbol**2)**((n+1)/2)
                    if rewritten == numer_poly.as_expr()/sqrt(a+b*symbol+c*symbol**2):
                        rule = 'SqrtQuadratic1'
                        rw_wait = False
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
        """
        if rw_wait:
            try:
                expanded = integrand.expand()
            except Exception:
                expanded = None
            if expanded is not None and rewritten == expanded:
                rule = 'Expand'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait:
            try:
                apart = integrand.apart(symbol)
            except Exception:
                apart = None

            if apart is not None and rewritten == apart:
                rule = 'PartialFractions'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait:
            try:
                cancelled = integrand.cancel()
            except Exception:
                cancelled = None

            if cancelled is not None and rewritten == cancelled:
                rule = 'Cancel'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.tan):
            if rewritten == (sp.sin(*args) / sp.cos(*args)):
                rule = 'Tan1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.cot):
            if rewritten == (sp.cos(*args) / sp.sin(*args)):
                rule = 'Cot1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and integrand == 1/sp.cos(symbol):
            if rewritten == sp.sec(symbol):
                rule = 'Cos1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.sec):
            arg = args[0]
            if rewritten == ((sp.sec(arg)**2 + sp.tan(arg) * sp.sec(arg)) / (sp.sec(arg) + sp.tan(arg))):
                rule = 'Sec1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.csc):
            arg = args[0]
            if rewritten == ((sp.csc(arg)**2 + sp.cot(arg) * sp.csc(arg)) / (sp.csc(arg) + sp.cot(arg))):
                rule = 'Csc1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.tanh):
            arg = args[0]
            if rewritten == (sp.sinh(arg) / sp.cosh(arg)):
                rule = 'Tanh1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.coth):
            if rewritten == (sp.cosh(*args) / sp.sinh(*args)):
                rule = 'Coth1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.sech):
            arg = args[0]
            if rewritten == (1-sp.tanh(arg/2)**2)/(1+sp.tanh(arg/2)**2):
                rule = 'Sech1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and isinstance(integrand, sp.csch):
            arg = args[0]
            if rewritten == (1-sp.tanh(arg/2)**2)/(2*sp.tanh(arg/2)):
                rule = 'Csch1'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        # trig functions with more than one distinct argument
        if rw_wait and len({t.args[0] for t in integrand.atoms(TrigonometricFunction)}) > 1:
            if rewritten == integrand.expand(trig = True):
                rule = 'TrigExpand'
                tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                rw_wait = False

        if rw_wait and any(integrand.has(f) for f in (sp.sin, sp.cos)):
            pattern, a, b, m, n = sincos_pattern(symbol)
            match = integrand.match(pattern)

            if match:
                a,b,m,n = tuple(match.get(w, S.Zero) for w in (a,b,m,n))

                if rw_wait and m.is_even and n.is_even and m.is_nonnegative and n.is_nonnegative:
                    if rewritten == (((((1 - sp.cos(2*a*symbol)) / 2))**(m/2)) * ((((1 + sp.cos(2*b*symbol)) / 2))**(n/2))):
                        rule = 'SinCosEven'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

                if rw_wait and m.is_odd and m >= 3:
                    if rewritten == ( (1 - sp.cos(a*symbol)**2)**((m - 1) / 2) * sp.sin(a*symbol) * sp.cos(b*symbol) ** n):
                        rule = 'SinOddCos'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

                if rw_wait and n.is_odd and n >= 3:
                    if rewritten == ( (1 - sp.sin(b*symbol)**2)**((n - 1) / 2) * sp.cos(b*symbol) * sp.sin(a*symbol) ** m):
                        rule = 'CosOddSin'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

        if rw_wait and any(integrand_sub1.has(f) for f in (sp.sec, sp.tan)):

            pattern, a, b, m, n = tansec_pattern(symbol)
            match = integrand_sub1.match(pattern)

            if match:
                a, b, m, n = tuple(match.get(w, sp.S.Zero) for w in (a, b, m, n))

                if rw_wait and n.is_even and n >= 4:
                    if rewritten == ((1 + sp.tan(b * symbol)**2)**(n / 2 - 1) * sp.sec(b * symbol)**2 * sp.tan(a * symbol)**m):
                        rule = 'SecEvenTan'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

                if rw_wait and m.is_odd:
                    if rewritten == ((sp.sec(a * symbol)**2 - 1)**((m - 1) / 2) * sp.tan(a * symbol) * sp.sec(b * symbol)**n):
                        rule = 'TanOddSec'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

                if rw_wait and m == 2 and n == 0:
                    if rewritten == sp.sec(a * symbol)**2 - 1:
                        rule = 'Tan2'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

        if rw_wait and any(integrand_sub2.has(f) for f in (sp.cot, sp.csc)):
            pattern, a, b, m, n = cotcsc_pattern(symbol)
            match = integrand_sub2.match(pattern)

            if match:
                a, b, m, n = tuple(match.get(w, sp.S.Zero) for w in (a, b, m, n))

                # The recorded subexpression is the original integrand, as for every
                # other rewrite: integrand_sub2 is only used for matching and need
                # not occur literally inside `integral`.
                if rw_wait and n.is_even and n >= 4:
                    if rewritten == ((1 + sp.cot(b * symbol)**2)**(n / 2 - 1) * sp.csc(b * symbol)**2 * sp.cot(a * symbol)**m):
                        rule = 'CotCscEven'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

                if rw_wait and m.is_odd:
                    if rewritten == ((sp.csc(a * symbol)**2 - 1)**((m - 1) / 2) * sp.cot(a * symbol) * sp.csc(b * symbol)**n):
                        rule = 'CotOddCsc'
                        tokenizable_steps.append((integral, integrand, (rule,), rewritten_integral))
                        rw_wait = False

        # couldn't figure out this rewrite
        cond = (curr_len == len(tokenizable_steps))
        assert cond == rw_wait
        if curr_len == len(tokenizable_steps):
            return [None]

        substep = steps.substep
        decomposed_substep = decompose_steps(rewritten, symbol, variable_list, substep)
        if None in decomposed_substep:
            return [None]

        integral = rewritten_integral

        for _, e_subexpr, e_rule, e_res in decomposed_substep:
            old_integral = integral
            integral = e_res
            tokenizable_steps.append((old_integral, e_subexpr, e_rule, integral))

    else:
        # unrecognized step type
        return [None]

    return tokenizable_steps


def steps_to_string(tokenizable_steps: list) -> str:

    '''
    Convert a list of tokenizable steps into a human-readable string.

    Each step ``(integral, subexpr, rule, result)`` becomes one line of the form
    ``"{integral} -> {result} using {rule} at {subexpr}"``, terminated by a
    newline. An empty list yields the empty string.
    '''

    # str.join builds the result in one pass instead of repeated concatenation
    return ''.join(
        f"{integral} -> {result} using {rule} at {subexpr}\n"
        for integral, subexpr, rule, result in tokenizable_steps
    )

def list_of_rules(tokenizable_steps: list) -> list:
    '''
    Return the rule name of every step, in order.

    The rule name is the first entry of each step's rule tuple (the third
    element of the step).
    '''
    return [entry[2][0] for entry in tokenizable_steps]