from __future__ import annotations
from typing import NamedTuple, Type, Callable
from collections import namedtuple, defaultdict
from collections.abc import Mapping
from functools import reduce
import download
import sympy
from sympy import *
from sympy.integrals.manualintegrate import manualintegrate, _manualintegrate, heaviside_pattern
import traceback
import time
import gc
import multiprocessing

# Rules
from sympy.integrals.manualintegrate import manual_subs, contains_dont_know, IntegralInfo, ExpRule, AddRule, ConstantRule, ConstantTimesRule, ReciprocalRule, PowerRule, PiecewiseRule, DerivativeRule, DontKnowRule, RewriteRule, TrigRule, URule, ArctanRule, ArccothRule, ArctanhRule, EiRule, CiRule, ChiRule, SiRule, ShiRule, LiRule, ErfRule, FresnelSRule, FresnelCRule, UpperGammaRule, PolylogRule, EllipticFRule, EllipticERule, InverseHyperbolicRule, ArcsinRule, JacobiRule, GegenbauerRule, ChebyshevTRule, ChebyshevURule, LegendreRule, HermiteRule, LaguerreRule, AssocLaguerreRule

from sympy.core.logic import fuzzy_not

import sympy
from sympy import *

from sympy.integrals.manualintegrate import _manualintegrate, manual_diff, manual_subs

from sympy.core.add import Add
from sympy.core.cache import cacheit
from sympy.core.containers import Dict
from sympy.core.expr import Expr
from sympy.core.function import Derivative
from sympy.core.logic import fuzzy_not
from sympy.core.mul import Mul
from sympy.core.numbers import Integer, Number, E
from sympy.core.power import Pow
from sympy.core.relational import Eq, Ne
from sympy.core.singleton import S
from sympy.core.symbol import Dummy, Symbol, Wild
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.exponential import exp, log
from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, csch,
                                                   cosh, coth, sech, sinh, tanh, asinh, atanh)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.piecewise import Piecewise
from sympy.functions.elementary.trigonometric import (TrigonometricFunction,
                                                      cos, sin, tan, cot, csc, sec, acos, asin, atan, acot, acsc, asec)
from sympy.functions.special.delta_functions import Heaviside, DiracDelta
from sympy.functions.special.error_functions import (erf, erfi, fresnelc,
                                                     fresnels, Ci, Chi, Si, Shi, Ei, li)
from sympy.functions.special.gamma_functions import uppergamma
from sympy.functions.special.elliptic_integrals import elliptic_e, elliptic_f
from sympy.functions.special.polynomials import (chebyshevt, chebyshevu,
                                                 legendre, hermite, laguerre, assoc_laguerre, gegenbauer, jacobi,
                                                 OrthogonalPolynomial)
from sympy.functions.special.zeta_functions import polylog
from sympy.integrals.integrals import Integral
from sympy.logic.boolalg import And
from sympy.ntheory.factor_ import primefactors
from sympy.polys.polytools import degree, lcm_list, gcd_list, Poly
from sympy.simplify.radsimp import fraction
from sympy.simplify.simplify import simplify
from sympy.solvers.solvers import solve
from sympy.strategies.core import switch, null_safe, condition
#from core_orig import switch, do_one, null_safe, condition
from sympy.utilities.iterables import iterable
from sympy.utilities.misc import debug

def identity(x):
    """Return ``x`` unchanged (no-op strategy)."""
    return x

def do_one(*rules):
    """Build a strategy that applies the first rule producing a change.

    Each rule is called with the expression in turn. A falsy outcome is
    ignored. For an outcome that is exactly a ``tuple`` (a ``(rule, steps)``
    pair), its first element is compared with the input; any other outcome
    is compared directly. The first outcome that differs from the input is
    returned unchanged; if none does, the input expression is returned.
    """
    def do_one_rl(expr):
        for candidate in rules:
            outcome = candidate(expr)
            if not outcome:
                continue
            # Exact type check on purpose: Rule namedtuples subclass tuple and
            # must be compared as a whole, not by their first field.
            produced = outcome[0] if type(outcome) is tuple else outcome
            if produced != expr:
                return outcome
        return expr
    return do_one_rl

def Rule(name, props=""):
    """Create a namedtuple class for an integration step.

    The fields are the space-separated names in ``props`` followed by
    ``context`` and ``symbol``. Plain namedtuples compare equal whenever
    their contents match, regardless of class name, so equality is
    overridden to also require the same class.
    """
    cls = namedtuple(name, props + " context symbol")

    def __eq__(self, other):
        return self.__class__ == other.__class__ and tuple.__eq__(self, other)

    def __ne__(self, other):
        return not __eq__(self, other)

    cls.__eq__ = __eq__
    cls.__ne__ = __ne__
    return cls

# Step records produced by the rule functions in this module. Each class is a
# namedtuple (see Rule) whose fields are the listed properties followed by
# `context` and `symbol`. Instances are constructed positionally throughout
# this file, so the field order below is part of the contract.
# NOTE: these assignments rebind names (ConstantRule, URule, ArctanRule, ...)
# that were imported from sympy.integrals.manualintegrate at the top of the
# file; within this module those names refer to the namedtuple versions.

# Basic algebraic steps.
ConstantRule = Rule("ConstantRule", "constant")
ConstantTimesRule = Rule("ConstantTimesRule", "constant other substep")
PowerRule = Rule("PowerRule", "base exp")
AddRule = Rule("AddRule", "substeps")
URule = Rule("URule", "u_var u_func constant substep")
PartsRule = Rule("PartsRule", "u dv v_step second_step")
CyclicPartsRule = Rule("CyclicPartsRule", "parts_rules coefficient")
TrigRule = Rule("TrigRule", "func arg")
HyperbolicRule = Rule("HyperbolicRule", "func arg")
ExpRule = Rule("ExpRule", "base exp")
ReciprocalRule = Rule("ReciprocalRule", "func")
# Inverse-trig / quadratic-radical steps (built by inverse_trig_rule).
ArcsinRule = Rule("ArcsinRule")
ArcsinhRule = Rule("ArcsinhRule")
ReciprocalSqrtQuadraticRule = Rule("ReciprocalSqrtQuadraticRule", "a b c")
SqrtQuadraticDenomRule = Rule("SqrtQuadraticDenomRule", "a b c coeffs")
SqrtQuadraticRule = Rule("SqrtQuadraticRule", "a b c")
# Combinators, rewrites and fall-backs.
AlternativeRule = Rule("AlternativeRule", "alternatives")
DontKnowRule = Rule("DontKnowRule")
DerivativeRule = Rule("DerivativeRule")
RewriteRule = Rule("RewriteRule", "rewritten substep")
CompleteSquareRule = Rule("CompleteSquareRule", "rewritten substep")
PiecewiseRule = Rule("PiecewiseRule", "subfunctions")
HeavisideRule = Rule("HeavisideRule", "harg ibnd substep")
DiracDeltaRule = Rule("DiracDeltaRule", "n a b")
TrigSubstitutionRule = Rule("TrigSubstitutionRule",
                            "theta func rewritten substep restriction")
ArctanRule = Rule("ArctanRule", "a b c")
ArctanhRule = Rule("ArctanhRule", "a b c")
# Orthogonal polynomials (built by orthogonal_poly_rule; the leading fields
# are the polynomial's parameters that precede the variable argument).
JacobiRule = Rule("JacobiRule", "n a b")
GegenbauerRule = Rule("GegenbauerRule", "n a")
ChebyshevTRule = Rule("ChebyshevTRule", "n")
ChebyshevURule = Rule("ChebyshevURule", "n")
LegendreRule = Rule("LegendreRule", "n")
HermiteRule = Rule("HermiteRule", "n")
LaguerreRule = Rule("LaguerreRule", "n")
AssocLaguerreRule = Rule("AssocLaguerreRule", "n a")
# Special-function results (built by special_function_rule; the leading
# fields receive the matched wild values in a, b, c, d, e order).
CiRule = Rule("CiRule", "a b")
ChiRule = Rule("ChiRule", "a b")
EiRule = Rule("EiRule", "a b")
SiRule = Rule("SiRule", "a b")
ShiRule = Rule("ShiRule", "a b")
ErfRule = Rule("ErfRule", "a b c")
FresnelCRule = Rule("FresnelCRule", "a b c")
FresnelSRule = Rule("FresnelSRule", "a b c")
LiRule = Rule("LiRule", "a b")
PolylogRule = Rule("PolylogRule", "a b")
UpperGammaRule = Rule("UpperGammaRule", "a e")
EllipticFRule = Rule("EllipticFRule", "a d")
EllipticERule = Rule("EllipticERule", "a d")


class IntegralInfo(NamedTuple):
    """An integrand paired with the symbol it is integrated with respect to."""

    integrand: Expr  # the expression to integrate
    symbol: Symbol  # the integration variable


# Registry mapping a rule class to the function that evaluates it.
evaluators = {}


def evaluates(rule):
    """Decorator registering ``func`` as the evaluator for ``rule``.

    The decorated function gets a ``rule`` attribute, is stored in
    ``evaluators[rule]``, and is returned unchanged.
    """
    def register(func):
        func.rule = rule
        evaluators[rule] = func
        return func
    return register


def contains_dont_know(rule):
    """Return True if a DontKnowRule occurs anywhere in the rule tree.

    The tree is walked through tuple-valued fields (nested rules, or pairs
    such as Piecewise ``(rule, cond)`` entries) and through list-valued
    fields, one level of list deep. Any other value is a leaf.
    """
    if isinstance(rule, DontKnowRule):
        return True
    if not isinstance(rule, tuple):
        return False
    for field in rule:
        if isinstance(field, tuple):
            found = contains_dont_know(field)
        elif isinstance(field, list):
            found = any(contains_dont_know(item) for item in field)
        else:
            found = False
        if found:
            return True
    return False


def manual_diff(f, symbol):
    """Differentiate ``f`` in the shape find_substitutions works best with.

    SymPy's own derivatives of tan, cot, sec and csc are not written in
    terms of sec/csc/tan/cot, which hampers substitution matching. Those
    four are special-cased here. Sums are differentiated term by term, and
    ``Number * g`` products keep the numeric factor in front. Everything
    else falls back to ``f.diff(symbol)``.
    """
    if not f.args:
        return f.diff(symbol)
    inner = f.args[0]
    if isinstance(f, tan):
        return inner.diff(symbol) * sec(inner)**2
    if isinstance(f, cot):
        return -inner.diff(symbol) * csc(inner)**2
    if isinstance(f, sec):
        return inner.diff(symbol) * sec(inner) * tan(inner)
    if isinstance(f, csc):
        return -inner.diff(symbol) * csc(inner) * cot(inner)
    if isinstance(f, Add):
        return sum([manual_diff(term, symbol) for term in f.args])
    if isinstance(f, Mul) and len(f.args) == 2 and isinstance(f.args[0], Number):
        return f.args[0] * manual_diff(f.args[1], symbol)
    return f.diff(symbol)

def manual_subs(expr, *args):
    """
    A wrapper for `expr.subs(*args)` with additional logic for substitution
    of invertible functions.

    Accepts either a single mapping / iterable of ``(old, new)`` pairs or a
    single ``old, new`` pair. For every pair whose ``old`` is ``log(x)``,
    nontrivial powers ``x**a`` are first rewritten as ``exp(a*new)``, and
    ``x -> exp(new)`` is added to the substitutions.

    Raises:
        ValueError: if the arguments are not one iterable/mapping or one pair.
    """
    if len(args) == 1:
        sequence = args[0]
        if isinstance(sequence, (Dict, Mapping)):
            sequence = sequence.items()
        elif not iterable(sequence):
            raise ValueError("Expected an iterable of (old, new) pairs")
    elif len(args) == 2:
        sequence = [args]
    else:
        raise ValueError("subs accepts either 1 or 2 arguments")

    # Materialize once: the pairs are traversed twice (the scan below and the
    # final subs). A one-shot iterable such as a generator would otherwise be
    # exhausted by the scan, silently dropping every substitution.
    sequence = list(sequence)

    new_subs = []
    for old, new in sequence:
        if isinstance(old, log):
            # If log(x) = y, then exp(a*log(x)) = exp(a*y)
            # that is, x**a = exp(a*y). Replace nontrivial powers of x
            # before subs turns them into `exp(y)**a`, but
            # do not replace x itself yet, to avoid `log(exp(y))`.
            x0 = old.args[0]
            expr = expr.replace(lambda x: x.is_Pow and x.base == x0,
                                lambda x: exp(x.exp*new))
            new_subs.append((x0, exp(new)))

    return expr.subs(sequence + new_subs)

# The substitution search in find_substitutions is based on the method used
# by SIN, described in Joel Moses, "Symbolic Integration: The Stormy Decade"
# (Communications of the ACM, 1971).

# Inverse trigonometric function classes. find_substitutions treats their
# argument as a candidate u, and _parts_rule uses them as the "I" group of
# the LIATE ordering.
inverse_trig_functions = (atan, asin, acos, acot, acsc, asec)


def find_substitutions(integrand, symbol, u_var):
    """Collect candidate u-substitutions for ``integrand``.

    Every subterm ``u`` of the integrand is tried, plus one combined ``exp``
    term built from integer-linear exponents. For each, the integrand is
    divided by ``du/dx`` and ``u`` is replaced by ``u_var``. A candidate is
    kept when the result no longer mentions ``symbol``, does not raise the
    degree of a rational integrand, and is not just the integrand with
    ``symbol`` renamed to ``u_var``.

    Returns:
        A list of unique ``(u, constant, new_integrand)`` triples, where
        ``constant * new_integrand`` is the substituted integrand and
        ``constant`` does not depend on ``u_var``.
    """
    results = []
    # Loop invariants: computed once instead of once per candidate.
    integrand_is_rational = integrand.is_rational_function(symbol)
    renamed_integrand = integrand.subs(symbol, u_var)

    def test_subterm(u, u_diff):
        """Substitute ``u -> u_var``; return (constant, rest) or False."""
        if u_diff == 0:
            return False
        substituted = integrand / u_diff
        debug("substituted: {}, u: {}, u_var: {}".format(substituted, u, u_var))
        substituted = manual_subs(substituted, u, u_var).cancel()

        if symbol in substituted.free_symbols:
            return False
        # avoid increasing the degree of a rational function
        if integrand_is_rational and substituted.is_rational_function(u_var):
            deg_before = max([degree(t, symbol) for t in integrand.as_numer_denom()])
            deg_after = max([degree(t, u_var) for t in substituted.as_numer_denom()])
            if deg_after > deg_before:
                return False
        return substituted.as_independent(u_var, as_Add=False)

    def exp_subterms(term: Expr):
        """Return exp(...) subterms, merging exp(n*x) with integer n into
        a single exp(gcd(n...)*x) candidate."""
        linear_coeffs = []
        terms = []
        n = Wild('n', properties=[lambda n: n.is_Integer])
        for exp_ in term.find(exp):
            arg = exp_.args[0]
            if symbol not in arg.free_symbols:
                continue
            match = arg.match(n*symbol)
            if match:
                linear_coeffs.append(match[n])
            else:
                terms.append(exp_)
        if linear_coeffs:
            terms.append(exp(gcd_list(linear_coeffs)*symbol))
        return terms

    def possible_subterms(term):
        """Recursively list subexpressions of ``term`` worth trying as u."""
        if isinstance(term, (TrigonometricFunction, HyperbolicFunction,
                             *inverse_trig_functions,
                             exp, log, Heaviside)):
            return [term.args[0]]
        elif isinstance(term, (chebyshevt, chebyshevu,
                               legendre, hermite, laguerre)):
            return [term.args[1]]
        elif isinstance(term, (gegenbauer, assoc_laguerre)):
            return [term.args[2]]
        elif isinstance(term, jacobi):
            return [term.args[3]]
        elif isinstance(term, Mul):
            r = []
            for u in term.args:
                r.append(u)
                r.extend(possible_subterms(u))
            return r
        elif isinstance(term, Pow):
            r = [arg for arg in term.args if arg.has(symbol)]
            if term.exp.is_Integer:
                # Powers base**d for prime divisors d of the exponent.
                r.extend([term.base**d for d in primefactors(term.exp)
                          if 1 < d < abs(term.args[1])])
                if term.base.is_Add:
                    r.extend([t for t in possible_subterms(term.base)
                              if t.is_Pow])
            return r
        elif isinstance(term, Add):
            r = []
            for arg in term.args:
                r.append(arg)
                r.extend(possible_subterms(arg))
            return r
        return []

    # dict.fromkeys de-duplicates while keeping first-seen order.
    for u in list(dict.fromkeys(possible_subterms(integrand) + exp_subterms(integrand))):
        if u == symbol:
            continue
        u_diff = manual_diff(u, symbol)
        new_integrand = test_subterm(u, u_diff)
        if new_integrand is not False:
            constant, new_integrand = new_integrand
            if new_integrand == renamed_integrand:
                continue
            substitution = (u, constant, new_integrand)
            if substitution not in results:
                results.append(substitution)

    return results

def rewriter(condition, rewrite, rule_called):
    """Strategy that rewrites an integrand.

    The returned callable takes an ``(integrand, symbol)`` integral. When
    ``condition(integrand, symbol)`` holds and ``rewrite(integrand, symbol)``
    yields something different, the rewritten form is integrated with
    ``integral_steps``. It returns ``(RewriteRule, steps)`` when that
    produced a truthy rule other than DontKnowRule, and ``None`` otherwise.
    ``steps`` starts with ``(integral, rule_called)``, followed by the
    sub-integration's steps.
    """
    def _rewriter(integral):
        integrand, symbol = integral
        debug("Integral: {} is rewritten with {} on symbol: {}".format(integrand, rewrite, symbol))
        if not condition(*integral):
            return None
        rewritten = rewrite(*integral)
        if rewritten != integrand:
            trace = [(integral, rule_called)]
            outcome = integral_steps(rewritten, symbol)
            substep = outcome[0]
            sub_trace = outcome[1]
            if type(sub_trace) is list:
                trace.extend(sub_trace)
            else:
                trace.append(sub_trace)
            if not isinstance(substep, DontKnowRule) and substep:
                return RewriteRule(rewritten, substep, integrand, symbol), trace
        return None
    return _rewriter

def proxy_rewriter(condition, rewrite, rule_called):
    """Strategy that rewrites an integrand based on some other criteria.

    The returned callable takes ``(criteria, integral)``, where ``criteria``
    is a list of extra positional arguments. ``condition`` and ``rewrite``
    are called with ``criteria + [integrand, symbol]``. Unlike ``rewriter``,
    the sub-integration result is wrapped in a RewriteRule even when it is
    a DontKnowRule. Returns ``(RewriteRule, steps)`` or ``None``.
    """
    def _proxy_rewriter(criteria):
        extra_args, integral = criteria
        integrand, symbol = integral
        debug("Integral: {} is rewritten with {} on symbol: {} and criteria: {}".format(integrand, rewrite, symbol, extra_args))
        call_args = extra_args + list(integral)
        if not condition(*call_args):
            return None
        rewritten = rewrite(*call_args)
        if rewritten != integrand:
            trace = [(integral, rule_called)]
            outcome = integral_steps(rewritten, symbol)
            substep = outcome[0]
            sub_trace = outcome[1]
            if type(sub_trace) is list:
                trace.extend(sub_trace)
            else:
                trace.append(sub_trace)
            return RewriteRule(rewritten, substep, integrand, symbol), trace
        return None
    return _proxy_rewriter

def multiplexer(conditions):
    """Apply the rule that matches the condition, else None.

    ``conditions`` maps predicates to rules. The rule paired with the first
    predicate (in mapping order) that accepts the expression is applied, and
    its result is returned.
    """
    def multiplexer_rl(expr):
        for predicate, handler in conditions.items():
            if not predicate(expr):
                continue
            return handler(expr)
        return None
    return multiplexer_rl

def alternatives(*rules):
    """Strategy that makes an AlternativeRule out of multiple possible results.

    Each rule is applied to the integral. Only results that are exactly a
    ``tuple`` (a ``(rule, steps)`` pair, not a bare Rule namedtuple) are
    considered. A result is kept when its rule differs from the integral,
    is truthy, is not a DontKnowRule, and was not already collected.

    Returns:
        ``None`` if nothing usable was found; the single ``(rule, steps)``
        pair if exactly one was found; otherwise ``(AlternativeRule,
        steps_lists)``. When some alternatives contain no DontKnowRule
        anywhere in their tree, only those are kept.
    """
    def _alternatives(integral):
        alts = []
        steps_list = []
        debug("List of Alternative Rules")
        for count, rule in enumerate(rules, start=1):
            debug("Rule {}: {}".format(count, rule))

            result = rule(integral)
            # Exact type check on purpose: Rule namedtuples subclass tuple.
            if type(result) is not tuple:
                continue
            candidate = result[0]
            if (candidate != integral and candidate
                    and not isinstance(candidate, DontKnowRule)
                    and candidate not in alts):
                alts.append(candidate)
                steps_list.append(result[1])

        if not alts:
            return None
        if len(alts) == 1:
            return alts[0], steps_list[0]

        # Prefer alternatives that are fully integrable.
        doable = [(alt, alt_steps) for alt, alt_steps in zip(alts, steps_list)
                  if not contains_dont_know(alt)]
        if doable:
            doable_rule = [alt for alt, _ in doable]
            doable_steps = [alt_steps for _, alt_steps in doable]
            return AlternativeRule(doable_rule, *integral), doable_steps
        return AlternativeRule(alts, *integral), steps_list
    return _alternatives

def constant_rule(integral):
    """Record the integral of an integrand treated as constant in the symbol.

    Returns ``(ConstantRule, steps)``, with the integrand used as both the
    constant and the context.
    """
    integrand, symbol = integral
    step_log = [(integral, "constant_rule")]
    return ConstantRule(integrand, integrand, symbol), step_log

def power_rule(integral):
    """Integrate ``x**n`` (n free of x) or ``b**x`` (b free of x).

    - ``x**-1`` gives ReciprocalRule; other ``x**n`` give PowerRule.
    - ``b**x`` gives ExpRule when ``log(b)`` is known nonzero.
    - It gives ``ConstantRule(1, ...)`` when ``log(b)`` is known zero.
    - Otherwise it gives a PiecewiseRule covering both cases.

    Returns ``(rule, steps)`` or ``None`` when neither shape applies.
    """
    integrand, symbol = integral
    base, expt = integrand.as_base_exp()
    step_log = [(integral, "power_rule")]

    if symbol not in expt.free_symbols and isinstance(base, Symbol):
        if simplify(expt + 1) == 0:
            return ReciprocalRule(base, integrand, symbol), step_log
        return PowerRule(base, expt, integrand, symbol), step_log

    if symbol in base.free_symbols or not isinstance(expt, Symbol):
        return None

    exp_step = ExpRule(base, expt, integrand, symbol)
    log_base = log(base)
    if fuzzy_not(log_base.is_zero):
        return exp_step, step_log
    if log_base.is_zero:
        return ConstantRule(1, 1, symbol), step_log
    return PiecewiseRule([
        (exp_step, Ne(log_base, 0)),
        (ConstantRule(1, 1, symbol), True),
    ], integrand, symbol), step_log

def exp_rule(integral):
    """Integrate ``exp(s)`` whose argument is a bare Symbol.

    Returns ``(ExpRule, steps)`` or ``None`` if the argument is not a Symbol.
    """
    integrand, symbol = integral
    exponent = integrand.args[0]
    if not isinstance(exponent, Symbol):
        return None
    return ExpRule(E, exponent, integrand, symbol), [(integral, "exp_rule")]


def orthogonal_poly_rule(integral):
    """Integrate an orthogonal polynomial in the integration symbol.

    Applies when the polynomial's variable argument is exactly the symbol
    and none of the preceding parameters contain it. The matching rule is
    built from those parameters plus ``(integrand, symbol)``.
    Returns ``(rule, steps)`` or ``None``.
    """
    rule_for_class = {
        jacobi: JacobiRule,
        gegenbauer: GegenbauerRule,
        chebyshevt: ChebyshevTRule,
        chebyshevu: ChebyshevURule,
        legendre: LegendreRule,
        hermite: HermiteRule,
        laguerre: LaguerreRule,
        assoc_laguerre: AssocLaguerreRule
    }
    # Position of the variable argument; every other class uses index 1.
    variable_position = {
        jacobi: 3,
        gegenbauer: 2,
        assoc_laguerre: 2
    }
    integrand, symbol = integral
    for poly_cls, rule_cls in rule_for_class.items():
        if not isinstance(integrand, poly_cls):
            continue
        idx = variable_position.get(poly_cls, 1)
        params = integrand.args[:idx]
        if integrand.args[idx] is symbol and not any(p.has(symbol) for p in params):
            return rule_cls(*params, integrand, symbol), [(integral, "orthogonal_poly_rule")]
    return None


# Lazily-populated table used by special_function_rule. Each entry is
# (expected integrand class, Wild-based pattern, optional constraint on the
# matched wild values or None, rule class to build on a match).
_special_function_patterns: list[tuple[Type, Expr, Callable | None, Type]] = []
# The Wild objects used by the patterns, in (a, b, c, d, e) order; matched
# values are passed to the rule class positionally in this order.
_wilds = []
# Placeholder symbol the patterns are written in; integrands are rewritten
# in terms of it before matching.
_symbol = Dummy('x')


def special_function_rule(integral):
    """Match the integrand against known special-function integral forms.

    On first use, the module-level _wilds and _special_function_patterns are
    populated. The integrand is then rewritten in terms of _symbol and tried
    against each pattern whose class matches. Returns ``(rule, steps)`` for
    the first pattern that matches and satisfies its constraint, else None.
    """
    integrand, symbol = integral
    # One-time lazy initialisation of the shared pattern table.
    if not _special_function_patterns:
        a = Wild('a', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        b = Wild('b', exclude=[_symbol])
        c = Wild('c', exclude=[_symbol])
        d = Wild('d', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        e = Wild('e', exclude=[_symbol], properties=[
            lambda x: not (x.is_nonnegative and x.is_integer)])
        _wilds.extend((a, b, c, d, e))
        # patterns consist of a SymPy class, a wildcard expr, an optional
        # condition coded as a lambda (when Wild properties are not enough),
        # followed by an applicable rule
        linear_pattern = a*_symbol + b
        quadratic_pattern = a*_symbol**2 + b*_symbol + c
        _special_function_patterns.extend((
            (Mul, exp(linear_pattern, evaluate=False)/_symbol, None, EiRule),
            (Mul, cos(linear_pattern, evaluate=False)/_symbol, None, CiRule),
            (Mul, cosh(linear_pattern, evaluate=False)/_symbol, None, ChiRule),
            (Mul, sin(linear_pattern, evaluate=False)/_symbol, None, SiRule),
            (Mul, sinh(linear_pattern, evaluate=False)/_symbol, None, ShiRule),
            (Pow, 1/log(linear_pattern, evaluate=False), None, LiRule),
            (exp, exp(quadratic_pattern, evaluate=False), None, ErfRule),
            (sin, sin(quadratic_pattern, evaluate=False), None, FresnelSRule),
            (cos, cos(quadratic_pattern, evaluate=False), None, FresnelCRule),
            (Mul, _symbol**e*exp(a*_symbol, evaluate=False), None, UpperGammaRule),
            (Mul, polylog(b, a*_symbol, evaluate=False)/_symbol, None, PolylogRule),
            (Pow, 1/sqrt(a - d*sin(_symbol, evaluate=False)**2),
             lambda a, d: a != d, EllipticFRule),
            (Pow, sqrt(a - d*sin(_symbol, evaluate=False)**2),
             lambda a, d: a != d, EllipticERule),
        ))
    _integrand = integrand.subs(symbol, _symbol)
    for type_, pattern, constraint, rule in _special_function_patterns:
        if isinstance(_integrand, type_):
            match = _integrand.match(pattern)
            if match:
                # Matched values in (a, b, c, d, e) order; wilds absent from
                # the match (get() returns None) are skipped, so the arity of
                # wild_vals equals the number of wilds the pattern bound.
                wild_vals = tuple(match.get(w) for w in _wilds
                                  if match.get(w) is not None)
                if constraint is None or constraint(*wild_vals):
                    # The rule receives the wild values, then the original
                    # (not _symbol-substituted) integrand and symbol.
                    args = wild_vals + (integrand, symbol)
                    return rule(*args), [(integral, "special_function_rule")]


def _add_degenerate_step(generic_cond, generic_step, degenerate_step):
    if degenerate_step is None:
        return generic_step
    if isinstance(generic_step, PiecewiseRule):
        subfunctions = [(substep, (cond & generic_cond).simplify())
                        for substep, cond in generic_step.subfunctions]
    else:
        subfunctions = [(generic_step, generic_cond)]
    if isinstance(degenerate_step, PiecewiseRule):
        subfunctions += degenerate_step.subfunctions
    else:
        subfunctions.append((degenerate_step, S.true))
    return PiecewiseRule(subfunctions, generic_step.context, generic_step.symbol)


def inverse_trig_rule(integral: IntegralInfo, degenerate=True):
    """
    Integrate ``(a + b*x + c*x**2)**(-1/2)`` and ``(a + b*x + c*x**2)**(1/2)``.

    Set degenerate=False on recursive call where coefficient of quadratic term
    is assumed non-zero.

    When ``degenerate`` is true and ``c`` is not known to be nonzero, the
    result is combined (via _add_degenerate_step) with a step for the
    ``c == 0`` case.

    Returns:
        ``(rule, steps)``, or ``None`` when the base is not a quadratic in the
        symbol or the exponent is neither -1/2 nor 1/2.
    """
    integrand, symbol = integral
    # ``expt`` rather than ``exp`` so SymPy's ``exp`` is not shadowed.
    base, expt = integrand.as_base_exp()
    a = Wild('a', exclude=[symbol])
    b = Wild('b', exclude=[symbol])
    c = Wild('c', exclude=[symbol, 0])
    match = base.match(a + b*symbol + c*symbol**2)

    if not match:
        return

    # Only exponents -1/2 and 1/2 are handled. Decide this before building the
    # degenerate (c == 0) step, which may recurse into other rules, so that
    # unsupported powers exit without that work.
    is_reciprocal_sqrt = simplify(2*expt + 1) == 0
    if not is_reciprocal_sqrt and expt != S.Half:
        return

    def make_inverse_trig(RuleClass, a, sign_a, c, sign_c, h):
        """Build the step chain for 1/sqrt(sign_a*a + sign_c*c*(x-h)**2),
        with a > 0 and c > 0.

        The chain is RuleClass on the standard form, scaled by 1/sqrt(c),
        u-substituted when the scaled base is not the symbol itself, and
        wrapped in CompleteSquareRule when shifted (h != 0).
        """
        u_var = Dummy("u")
        rewritten = 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2)  # a>0, c>0
        quadratic_base = sqrt(c/a)*(symbol-h)
        constant = 1/sqrt(c)
        u_func = None
        if quadratic_base is not symbol:
            u_func = quadratic_base
            quadratic_base = u_var
        standard_form = 1/sqrt(sign_a + sign_c*quadratic_base**2)
        substep = RuleClass(standard_form, quadratic_base)
        if constant != 1:
            substep = ConstantTimesRule(constant, standard_form, substep, constant*standard_form, symbol)
        if u_func is not None:
            substep = URule(u_var, u_func, None, substep, rewritten, symbol)
        if h != 0:
            substep = CompleteSquareRule(rewritten, substep, integrand, symbol)
        return substep

    steps = [(integral, "inverse_trig_rule")]
    a, b, c = [match.get(i, S.Zero) for i in (a, b, c)]
    generic_cond = Ne(c, 0)
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        degenerate_step = ConstantRule(a ** expt, a ** expt, symbol)
    else:
        linear_res = sqrt_linear_rule(IntegralInfo((a + b * symbol) ** expt, symbol))
        if linear_res is None:
            # Rule functions signal failure with None. Indexing it would raise
            # TypeError; without a degenerate step, _add_degenerate_step just
            # keeps the generic step.
            degenerate_step = None
        else:
            degenerate_step = linear_res[0]
            if type(linear_res[1]) == list:
                steps.extend(linear_res[1])
            else:
                steps.append(linear_res[1])

    if is_reciprocal_sqrt:
        h, k = -b/(2*c), a - b**2/(4*c)  # rewrite base to k + c*(symbol-h)**2
        step = general_rule = ReciprocalSqrtQuadraticRule(a, b, c, integrand, symbol)
        if k.is_real and c.is_real:
            # list of ((rule, base_exp, a, sign_a, b, sign_b), condition)
            possibilities = []
            for args, cond in (  # don't apply ArccoshRule to x**2-1
                    ((ArcsinRule, k, 1, -c, -1, h), And(k > 0, c < 0)),  # 1-x**2
                    ((ArcsinhRule, k, 1, c, 1, h), And(k > 0, c > 0)),  # 1+x**2
            ):
                if cond is S.true:
                    return make_inverse_trig(*args), steps
                if cond is not S.false:
                    possibilities.append((args, cond))
            if possibilities:
                rules = [(make_inverse_trig(*args), cond) for args, cond in possibilities]
                if not k.is_positive:  # conditions are not thorough, need fall back rule
                    rules.append((general_rule, S.true))
                step = PiecewiseRule(rules, integrand, symbol)
            else:
                step = general_rule
        return _add_degenerate_step(generic_cond, step, degenerate_step), steps

    # Remaining case: expt == 1/2.
    step = SqrtQuadraticRule(a, b, c, integrand, symbol)
    return _add_degenerate_step(generic_cond, step, degenerate_step), steps


def add_rule(integral):
    """Integrate a sum term by term.

    Each term of the integrand (in ``as_ordered_terms`` order) is integrated
    with ``integral_steps``, and its step log is appended to this rule's log.

    Returns:
        ``(AddRule, steps)`` when every term produced a rule. Otherwise
        ``None``, the failure convention used by ``mul_rule`` and the other
        rules here, so the caller can try another strategy.
    """
    integrand, symbol = integral
    steps = [(integral, "add_rule")]
    results = []
    for g in integrand.as_ordered_terms():
        add_res = integral_steps(g, symbol)
        term_rule = add_res[0]
        if term_rule is None:
            # The whole sum fails. Return a bare None, not a (None, steps)
            # pair: do_one treats any truthy tuple whose first item differs
            # from its input as a success.
            return None
        results.append(term_rule)
        if type(add_res[1]) == list:
            steps.extend(add_res[1])
        else:
            steps.append(add_res[1])

    return AddRule(results, integrand, symbol), steps


def mul_rule(integral: IntegralInfo):
    """Pull a factor independent of the symbol out of the integral.

    The integrand is split with ``as_independent(symbol)``. When the
    independent coefficient is not 1, the remaining factor is integrated
    and the result is wrapped in a ConstantTimesRule. Returns
    ``(rule, steps)`` or ``None``.
    """
    integrand, symbol = integral

    coeff, remainder = integrand.as_independent(symbol)
    if coeff == 1:
        return None

    trace = [(integral, "mul_rule")]
    outcome = integral_steps(remainder, symbol)
    inner_rule = outcome[0]
    inner_trace = outcome[1]
    if type(inner_trace) is list:
        trace.extend(inner_trace)
    else:
        trace.append(inner_trace)

    if inner_rule is None:
        return None
    return ConstantTimesRule(coeff, remainder, inner_rule, integrand, symbol), trace


def _parts_rule(integrand, symbol):
    """Choose ``u`` and ``dv`` for integration by parts using LIATE ordering.

    Returns ``(u, dv, v, du, v_step)`` when a split is accepted and ``dv``
    can be integrated without a DontKnowRule; otherwise returns None.
    ``v_step`` is the rule from ``integral_steps(dv)``, and ``v`` is its
    evaluation through ``_manualintegrate``.
    NOTE(review): ``_manualintegrate`` is imported from
    sympy.integrals.manualintegrate, while the rules built in this module are
    local namedtuple Rules; confirm the two are compatible.
    """
    # LIATE rule:
    # log, inverse trig, algebraic, trigonometric, exponential
    def pull_out_algebraic(integrand):
        # u = product of the algebraic factors of a (canceled, combined) Mul;
        # dv = the rest. Returns None when there are no such factors.
        integrand = integrand.cancel().together()
        # iterating over Piecewise args would not work here
        algebraic = ([] if isinstance(integrand, Piecewise) or not integrand.is_Mul
                     else [arg for arg in integrand.args if arg.is_algebraic_expr(symbol)])
        if algebraic:
            u = Mul(*algebraic)
            dv = (integrand / u).cancel()
            return u, dv

    def pull_out_u(*functions):
        # Factory: the returned rule takes as u the product of the top-level
        # args that are instances of `functions`, with dv = integrand / u.
        def pull_out_u_rl(integrand):
            if any(integrand.has(f) for f in functions):
                args = [arg for arg in integrand.args
                        if any(isinstance(arg, cls) for cls in functions)]
                if args:
                    u = reduce(lambda a,b: a*b, args)
                    dv = integrand / u
                    return u, dv

        return pull_out_u_rl

    liate_rules = [pull_out_u(log), pull_out_u(*inverse_trig_functions),
                   pull_out_algebraic, pull_out_u(sin, cos),
                   pull_out_u(exp)]


    dummy = Dummy("temporary")
    # we can integrate log(x) and atan(x) by setting dv = 1
    # (multiplying by a dummy makes the integrand a Mul so the pull_out rules
    # can split it; the dummy is later substituted by 1)
    if isinstance(integrand, (log, *inverse_trig_functions)):
        integrand = dummy * integrand

    for index, rule in enumerate(liate_rules):
        result = rule(integrand)

        if result:
            u, dv = result

            # Don't pick u to be a constant if possible
            if symbol not in u.free_symbols and not u.has(dummy):
                return

            u = u.subs(dummy, 1)
            dv = dv.subs(dummy, 1)

            # Don't pick a non-polynomial algebraic to be differentiated
            if rule == pull_out_algebraic and not u.is_polynomial(symbol):
                return
            # Don't trade one logarithm for another
            if isinstance(u, log):
                rec_dv = 1/dv
                if (rec_dv.is_polynomial(symbol) and
                        degree(rec_dv, symbol) == 1):
                    return

            # Can integrate a polynomial times OrthogonalPolynomial
            if rule == pull_out_algebraic:
                if dv.is_Derivative or dv.has(TrigonometricFunction) or \
                        isinstance(dv, OrthogonalPolynomial):
                    v_step, _ = integral_steps(dv, symbol)
                    if contains_dont_know(v_step):
                        return
                    else:
                        du = u.diff(symbol)
                        v = _manualintegrate(v_step)
                        return u, dv, v, du, v_step

            # make sure dv is amenable to integration
            accept = False
            if index < 2:  # log and inverse trig are usually worth trying
                accept = True
            elif (rule == pull_out_algebraic and dv.args and
                  all(isinstance(a, (sin, cos, exp))
                      for a in dv.args)):
                accept = True
            else:
                # Accept dv only if a later LIATE class would pick exactly
                # dv as its u, i.e. dv is a "simpler" factor class.
                for lrule in liate_rules[index + 1:]:
                    r = lrule(integrand)
                    if r and r[0].subs(dummy, 1).equals(dv):
                        accept = True
                        break

            if accept:
                du = u.diff(symbol)
                v_step, _ = integral_steps(simplify(dv), symbol)
                if not contains_dont_know(v_step):
                    v = _manualintegrate(v_step)
                    return u, dv, v, du, v_step


def parts_rule(integral):
    """Apply integration by parts, including cyclic integration by parts.

    ``integral`` is unpacked as ``(integrand, symbol)``.  The leading numeric
    coefficient is split off before calling ``_parts_rule`` (which is
    sensitive to constants) and re-applied with ``ConstantTimesRule``.

    Returns ``(rule, parts_steps)`` or ``None``.
    ``parts_steps`` is a trace of ``(IntegralInfo, "parts_rule", u)``
    entries.  In the non-cyclic case it is followed by the trace of the
    final remaining integral.  ``None`` is returned when integration by
    parts does not apply, when ``v`` is an unevaluated ``Integral``, or
    when the same ``u`` has been used too often.
    """
    #    print("parts_rule", integral)
    integrand, symbol = integral
    constant, integrand = integrand.as_coeff_Mul()

    result = _parts_rule(integrand, symbol)

    # One (u, dv, v, du, v_step) tuple per successive parts application.
    steps = []
    if result:
        u, dv, v, du, v_step = result
        debug("u : {}, dv : {}, v : {}, du : {}, v_step: {}".format(u, dv, v, du, v_step))
        steps.append(result)

        # v came back as an unevaluated Integral (presumably it could not be
        # integrated); give up on parts.
        if isinstance(v, Integral):
            return

        # Set a limit on the number of times u can be used
        # (counted across calls in the module-level _parts_u_cache, keyed by
        # u with symbol replaced by _cache_dummy, to avoid infinite repetition).
        if isinstance(u, (sin, cos, exp, sinh, cosh)):
            cachekey = u.xreplace({symbol: _cache_dummy})
            if _parts_u_cache[cachekey] > 2:
                return
            _parts_u_cache[cachekey] += 1

        # Try cyclic integration by parts a few times
        parts_steps = [(integral, "parts_rule", u)]
        for _ in range(4):
            debug("Cyclic integration {} with v: {}, du: {}, integrand: {}".format(_, v, du, integrand))
            # Ratio of the remaining integrand v*du to the original integrand.
            coefficient = ((v * du) / integrand).cancel()
            if coefficient == 1:
                break

            # A symbol-free ratio means the original integral has reappeared,
            # so it can be solved algebraically with CyclicPartsRule.  The
            # sign alternates with the number of parts applications.
            if symbol not in coefficient.free_symbols:
                rule = CyclicPartsRule(
                    [PartsRule(u, dv, v_step, None, None, None)
                     for (u, dv, v, du, v_step) in steps],
                    (-1) ** len(steps) * coefficient,
                    integrand, symbol
                )
                if (constant != 1) and rule:
                    rule = ConstantTimesRule(constant, integrand, rule,
                                             constant * integrand, symbol)
                return rule, parts_steps

            # _parts_rule is sensitive to constants, factor it out
            next_constant, next_integrand = (v * du).as_coeff_Mul()
            result = _parts_rule(next_integrand, symbol)

            if result:
                u, dv, v, du, v_step = result
                # Fold the factored-out constant back into u and du.
                u *= next_constant
                parts_steps.append((IntegralInfo(next_integrand, symbol), "parts_rule", u))
                du *= next_constant
                steps.append((u, dv, v, du, v_step))
            else:
                break

    def make_second_step(steps, integrand):
        # Recursively chain one PartsRule per recorded parts application.
        # Once the list is exhausted, the remaining integrand is handed to
        # integral_steps.  Returns (rule, trace).
        if steps:
            u, dv, v, du, v_step = steps[0]
            mss_mss_res = make_second_step(steps[1:], v * du)
            mss_mss_res_steps = []
            # make_second_step always returns a 2-tuple, so this check
            # appears to be always true.
            if len(mss_mss_res) == 2:
                mss_mss_res_steps = mss_mss_res[1]
            return PartsRule(u, dv, v_step,
                             mss_mss_res[0],
                             integrand, symbol), mss_mss_res_steps
        mss_res = integral_steps(integrand, symbol)
        return mss_res[0], mss_res[1]

    # Non-cyclic case: build the nested PartsRule chain from the first step
    # and append the trace of the innermost remaining integral.
    if steps:
        u, dv, v, du, v_step = steps[0]
        mss_pr_res = make_second_step(steps[1:], v * du)
        parts_steps.extend(mss_pr_res[1])
        rule = PartsRule(u, dv, v_step,
                         mss_pr_res[0],
                         integrand, symbol)
        if (constant != 1) and rule:
            rule = ConstantTimesRule(constant, integrand, rule,
                                     constant * integrand, symbol)
        return rule, parts_steps


def trig_rule(integral):
    """Integrate a single trigonometric function.

    * ``sin``/``cos`` of a plain ``Symbol``, and ``sec(symbol)**2`` /
      ``csc(symbol)**2``, map directly to a ``TrigRule``.
    * ``tan``, ``cot``, ``sec`` and ``csc`` are first rewritten:
      ``tan``/``cot`` become quotients of sin/cos, and ``sec``/``csc``
      become ``(f**2 + g*f)/(f + g)``.  The rewritten integrand is then
      solved recursively.

    Returns ``(rule, steps)``, or ``None`` when nothing applies.
    """
    #    print("trig_rule", integral)
    integrand, symbol = integral
    trace = [(integral, "trig_rule")]

    if isinstance(integrand, (sin, cos)):
        arg = integrand.args[0]
        if isinstance(arg, Symbol):
            name = 'sin' if isinstance(integrand, sin) else 'cos'
            return TrigRule(name, arg, integrand, symbol), trace
        return None  # perhaps a substitution can deal with it

    if integrand == sec(symbol)**2:
        return TrigRule('sec**2', symbol, integrand, symbol), trace
    if integrand == csc(symbol)**2:
        return TrigRule('csc**2', symbol, integrand, symbol), trace

    fargs = integrand.args
    if isinstance(integrand, tan):
        rewritten = sin(*fargs) / cos(*fargs)
    elif isinstance(integrand, cot):
        rewritten = cos(*fargs) / sin(*fargs)
    elif isinstance(integrand, (sec, csc)):
        x = fargs[0]
        outer, partner = (sec, tan) if isinstance(integrand, sec) else (csc, cot)
        rewritten = ((outer(x)**2 + partner(x) * outer(x)) /
                     (outer(x) + partner(x)))
    else:
        return None

    solved = integral_steps(rewritten, symbol)
    sub_steps = solved[1]
    if type(sub_steps) is list:
        trace.extend(sub_steps)
    else:
        trace.append(sub_steps)
    return RewriteRule(rewritten, solved[0], integrand, symbol), trace

def trig_product_rule(integral):
    """Integrate symbol-free multiples of ``sec*tan`` or ``-csc*cot``.

    The integrand is divided by each product in turn.  If the quotient does
    not involve ``symbol``, a ``TrigRule`` is returned, wrapped in a
    ``ConstantTimesRule`` when the quotient is not 1.

    Returns ``(rule, steps)``, or ``None`` when neither product fits.
    """
    #    print("trig_product_rule", integral)
    integrand, symbol = integral

    candidates = (
        ('sec*tan', sec(symbol) * tan(symbol)),
        ('csc*cot', -csc(symbol) * cot(symbol)),
    )
    for label, product in candidates:
        factor = integrand / product
        if symbol in factor.free_symbols:
            continue
        rule = TrigRule(label, symbol, product, symbol)
        if factor != 1 and rule:
            rule = ConstantTimesRule(factor, product, rule, integrand, symbol)
        return rule, [(integral, "trig_product_rule")]
    return None


def quadratic_denom_rule(integral):
    """Integrate rational functions whose denominator is quadratic in ``symbol``.

    Three shapes are tried in order:

    1. ``a/(b*x**2 + c)``: an ``ArctanRule``.  When ``b`` and ``c`` are real
       and ``c/b > 0`` is not known to hold, a partial-fraction (two
       logarithm) step is built for the negative case.  It is combined with
       the arctan step in a ``PiecewiseRule`` unless the sign is decided.
    2. ``a/(b*x**2 + c*x + d)``: complete the square with
       ``u = x + c/(2*b)`` and integrate the shifted integrand in ``u``.
    3. ``(a*x + b)/(c*x**2 + d*x + e)``: write the numerator as
       ``const*(2*c*x + d) + numer2``.  The first part integrates to a
       logarithm of the denominator; ``numer2/denominator`` is integrated
       recursively.

    Returns ``(rule, steps)``.  ``steps`` starts with this rule's trace entry,
    followed by the traces of every recursive ``integral_steps`` call.
    Returns ``None`` when no shape applies or a leading coefficient is zero.
    """
    #    print("quadratic_denom_rule", integral)
    integrand, symbol = integral

    def _trace(*results):
        # Fresh trace: this rule's entry followed by each sub-result's steps.
        # A sub-result's steps may be a list of entries or a single entry.
        steps = [(integral, "quadratic_denom_rule")]
        for res in results:
            sub_steps = res[1]
            if isinstance(sub_steps, list):
                steps.extend(sub_steps)
            else:
                steps.append(sub_steps)
        return steps

    a = Wild('a', exclude=[symbol])
    b = Wild('b', exclude=[symbol])
    c = Wild('c', exclude=[symbol])

    # Shape 1: a/(b*x**2 + c)
    match = integrand.match(a / (b * symbol ** 2 + c))

    if match:
        a, b, c = match[a], match[b], match[c]
        general_rule = ArctanRule(a, b, c, integrand, symbol)
        if b.is_extended_real and c.is_extended_real:
            positive_cond = c/b > 0
            if positive_cond is S.true:
                return general_rule, _trace()
            # c/b may be negative.  With constant = sqrt(-c/b), the integrand
            # is coeff*(1/(x - constant) - 1/(x + constant)), i.e. two logs.
            coeff = a/(2*sqrt(-c)*sqrt(b))
            constant = sqrt(-c/b)
            r1 = 1/(symbol-constant)
            r2 = 1/(symbol+constant)
            log_steps = [ReciprocalRule(symbol-constant, r1, symbol),
                         ConstantTimesRule(-1, r2, ReciprocalRule(symbol+constant, r2, symbol), -r2, symbol)]
            rewritten = sub = r1 - r2
            negative_step = AddRule(log_steps, sub, symbol)
            if coeff != 1:
                rewritten = Mul(coeff, sub, evaluate=False)
                negative_step = ConstantTimesRule(coeff, sub, negative_step, rewritten, symbol)
            negative_step = RewriteRule(rewritten, negative_step, integrand, symbol)
            if positive_cond is S.false:
                return negative_step, _trace()
            return PiecewiseRule([(general_rule, positive_cond), (negative_step, S.true)],
                                 integrand, symbol), _trace()
        return general_rule, _trace()

    # Shape 2: a/(b*x**2 + c*x + d) -- complete the square.
    d = Wild('d', exclude=[symbol])
    match2 = integrand.match(a / (b * symbol ** 2 + c * symbol + d))
    if match2:
        b, c = match2[b], match2[c]
        if b.is_zero:
            return
        u = Dummy('u')
        u_func = symbol + c/(2*b)
        integrand2 = integrand.subs(symbol, u - c / (2*b))
        qdr_res_1 = integral_steps(integrand2, u)
        next_step = qdr_res_1[0]
        if not next_step:
            return
        # The URule's context is the integrand in ``symbol``, as for the other
        # URules in this module; ``integrand2`` is expressed in ``u``.
        return URule(u, u_func, None, next_step, integrand, symbol), _trace(qdr_res_1)

    # Shape 3: (a*x + b)/(c*x**2 + d*x + e).  The numerator equals
    # const*(2*c*x + d) + numer2.  The first term is const times the
    # denominator's derivative over the denominator (log via u = denominator).
    e = Wild('e', exclude=[symbol])
    match3 = integrand.match((a* symbol + b) / (c * symbol ** 2 + d * symbol + e))
    if match3:
        a, b, c, d, e = match3[a], match3[b], match3[c], match3[d], match3[e]
        if c.is_zero:
            return
        denominator = c * symbol**2 + d * symbol + e
        const = a/(2*c)
        numer1 = (2*c*symbol+d)
        numer2 = - const*d + b
        u = Dummy('u')
        qdr_res_2 = integral_steps(u**(-1), u)
        step1 = URule(u,
                      denominator,
                      const,
                      qdr_res_2[0],
                      integrand,
                      symbol)
        if const != 1:
            step1 = ConstantTimesRule(const,
                                      numer1/denominator,
                                      step1,
                                      const*numer1/denominator,
                                      symbol)
        # The 1/u sub-trace is kept whatever ``const`` is.  Previously the
        # trace was reset when const != 1, silently dropping it.
        if numer2.is_zero:
            return step1, _trace(qdr_res_2)
        qdr_res_3 = integral_steps(numer2/denominator, symbol)
        step2 = qdr_res_3[0]
        substeps = AddRule([step1, step2], integrand, symbol)
        rewritten = const*numer1/denominator+numer2/denominator
        return RewriteRule(rewritten, substeps, integrand, symbol), _trace(qdr_res_2, qdr_res_3)

    return


def sqrt_linear_rule(integral: IntegralInfo):
    """
    Substitute common (a+b*x)**(1/n)

    Every non-integer power ``(a_i + b_i*x)**(p/q)`` with a linear base must
    be proportional to one representative ``a0 + b0*x`` (same ``a/b`` ratio,
    nonnegative ``b0/b_i``).  The substitution ``u = (a0 + b0*x)**(1/q0)``,
    with ``q0`` the lcm of all ``q``, then turns the integrand into an
    integrand in ``u``.  If ``b0`` is not known to be nonzero, the
    degenerate ``b = 0`` integrand is added through a ``PiecewiseRule``.

    Returns ``(rule, steps)``.  ``steps`` starts with this rule's entry,
    followed by the trace of the substituted integral and, if present, of
    the degenerate case.  Returns ``None`` in four cases:
    * no such power is present;
    * a non-rational power of an ``x``-dependent base appears;
    * the bases are incompatible;
    * the substituted integral cannot be solved.
    """
    #    print("sqrt_linear_rule", integral)
    integrand, x = integral

    def _absorb(trace, sub_steps):
        # A sub-trace may be a list of entries or a single entry.
        if isinstance(sub_steps, list):
            trace.extend(sub_steps)
        else:
            trace.append(sub_steps)

    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x, 0])
    a0 = b0 = 0
    bases, qs, bs = [], [], []
    for pow_ in integrand.find(Pow):  # collect all (a+b*x)**(p/q)
        base, exp_ = pow_.base, pow_.exp
        if exp_.is_Integer or x not in base.free_symbols:  # skip 1/x and sqrt(2)
            continue
        if not exp_.is_Rational:  # exclude x**pi
            return
        match = base.match(a+b*x)
        if not match:  # skip non-linear
            continue  # for sqrt(x+sqrt(x)), although base is non-linear, we can still substitute sqrt(x)
        a1, b1 = match[a], match[b]
        if a0*b1 != a1*b0 or not (b0/b1).is_nonnegative:  # cannot transform sqrt(x) to sqrt(x+1) or sqrt(-x)
            return
        if b0 == 0 or (b0/b1 > 1) is S.true:  # choose the latter of sqrt(2*x) and sqrt(x) as representative
            a0, b0 = a1, b1
        bases.append(base)
        bs.append(b1)
        qs.append(exp_.q)
    if b0 == 0:  # no such pattern found
        return
    q0: Integer = lcm_list(qs)
    u_x = (a0 + b0*x)**(1/q0)
    u = Dummy("u")
    substituted = integrand.subs({base**(S.One/q): (coeff/b0)**(S.One/q)*u**(q0/q)
                                  for base, coeff, q in zip(bases, bs, qs)}).subs(x, (u**q0-a0)/b0)
    slr_res_1 = integral_steps(substituted*u**(q0-1)*q0/b0, u)
    substep = slr_res_1[0]
    if contains_dont_know(substep):
        return
    # Keep the substituted integral's trace.  Previously it was built and
    # then discarded by resetting ``steps`` on this (the only) success path.
    steps = [(integral, "sqrt_linear_rule")]
    _absorb(steps, slr_res_1[1])
    step = URule(u, u_x, None, substep, integrand, x)
    generic_cond = Ne(b0, 0)
    if generic_cond is not S.true:  # possible degenerate case
        simplified = integrand.subs({coeff: 0 for coeff in bs})
        slr_res_2 = integral_steps(simplified, x)
        degenerate_step = slr_res_2[0]
        _absorb(steps, slr_res_2[1])
        step = PiecewiseRule([(step, generic_cond), (degenerate_step, S.true)], integrand, x)
    return step, steps


def sqrt_quadratic_rule(integral: IntegralInfo, degenerate=True):
    """Integrate ``poly(x) * sqrt(a + b*x + c*x**2)**n`` for odd integer ``n``.

    For ``n > 0`` the integrand is rewritten as ``poly*s**((n+1)/2)/sqrt(s)``,
    where ``s`` is the quadratic.  ``n == -1`` is handled directly, and
    ``n < -1`` is not supported.  When ``degenerate`` is true and ``c`` is
    not known to be nonzero, the ``c == 0`` integrand is solved separately
    and combined via ``_add_degenerate_step``.

    Returns ``(rule, steps)``, or ``None`` when the integrand does not have
    this form or ``n < -1``.
    """
    #    print("sqrt_quadratic_rule", integral)
    integrand, x = integral
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x])
    c = Wild('c', exclude=[x, 0])
    f = Wild('f')
    n = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd])
    match = integrand.match(f*sqrt(a+b*x+c*x**2)**n)
    if not match:
        return
    a, b, c, f, n = match[a], match[b], match[c], match[f], match[n]
    f_poly = f.as_poly(x)
    if f_poly is None:
        return
    if n < -1:
        # todo: handle n < -1 case.  Rejected before any degenerate-case
        # integration so no work is spent on an integrand we give up on.
        return

    steps = [(integral, "sqrt_quadratic_rule")]

    def _absorb(sub_steps):
        # A sub-trace may be a list of entries or a single entry.
        if isinstance(sub_steps, list):
            steps.extend(sub_steps)
        else:
            steps.append(sub_steps)

    generic_cond = Ne(c, 0)
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        sqr_res_1 = integral_steps(f*sqrt(a)**n, x)
        degenerate_step = sqr_res_1[0]
        _absorb(sqr_res_1[1])
    else:
        sqr_res_2 = sqrt_linear_rule(IntegralInfo(f*sqrt(a+b*x)**n, x))
        if sqr_res_2 is None:
            # sqrt_linear_rule returns None when it cannot handle the
            # degenerate integrand; indexing it used to raise TypeError.
            # A None degenerate step is also passed to _add_degenerate_step
            # by dirac_delta_rule.
            degenerate_step = None
        else:
            degenerate_step = sqr_res_2[0]
            _absorb(sqr_res_2[1])

    def sqrt_quadratic_denom_rule(numer_poly: Poly, integrand: Expr):
        # Integrate numer_poly/sqrt(a+b*x+c*x**2).
        denom = sqrt(a+b*x+c*x**2)
        deg = numer_poly.degree()
        if deg <= 1:
            # integrand == (d+e*x)/sqrt(a+b*x+c*x**2)
            e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr())
            # rewrite numerator to A*(2*c*x+b) + B
            A = e/(2*c)
            B = d-A*b
            pre_substitute = (2*c*x+b)/denom
            constant_step = linear_step = None
            if A != 0:
                u = Dummy("u")
                pow_rule = PowerRule(u, -S.Half, 1/sqrt(u), u)
                linear_step = URule(u, a+b*x+c*x**2, None, pow_rule, pre_substitute, x)
                if A != 1:
                    linear_step = ConstantTimesRule(A, pre_substitute, linear_step, A*pre_substitute, x)
            if B != 0:
                constant_step, _ = inverse_trig_rule(IntegralInfo(1/denom, x), degenerate=False)
                if B != 1:
                    constant_step = ConstantTimesRule(B, 1/denom, constant_step, B/denom, x)
            if linear_step and constant_step:
                add = Add(A*pre_substitute, B/denom, evaluate=False)
                step = RewriteRule(add, AddRule([linear_step, constant_step], add, x), integrand, x)
            else:
                step = linear_step or constant_step
        else:
            coeffs = numer_poly.all_coeffs()
            step = SqrtQuadraticDenomRule(a, b, c, coeffs, integrand, x)
        return step

    if n > 0:  # rewrite poly * sqrt(s)**(2*k-1) to poly*s**k / sqrt(s)
        numer_poly = f_poly * (a+b*x+c*x**2)**((n+1)/2)
        rewritten = numer_poly.as_expr()/sqrt(a+b*x+c*x**2)
        substep = sqrt_quadratic_denom_rule(numer_poly, rewritten)
        generic_step = RewriteRule(rewritten, substep, integrand, x)
    else:  # n == -1: n is odd, and n < -1 was rejected above
        generic_step = sqrt_quadratic_denom_rule(f_poly, integrand)
    return_rules = _add_degenerate_step(generic_cond, generic_step, degenerate_step)
    return return_rules, steps

def hyperbolic_rule(integral: tuple[Expr, Symbol]):
    """Integrate a single hyperbolic function whose argument is ``symbol``.

    * ``sinh``/``cosh``: a direct ``HyperbolicRule``.
    * ``tanh``/``coth``: rewritten as ``sinh/cosh`` or ``cosh/sinh``, then
      solved by substituting ``u`` for the denominator and integrating ``1/u``.
    * ``sech``/``csch``: rewritten in terms of ``tanh`` and solved with
      ``u = tanh(x/2)``.  The inner step is an arctan step for ``sech`` and
      ``1/u`` for ``csch``.

    Returns ``(rule, [(integral, "hyperbolic_rule")])`` or ``None``.
    """
    #    print("hyperbolic_rule", integral)
    integrand, symbol = integral
    if not isinstance(integrand, HyperbolicFunction) or integrand.args[0] != symbol:
        return None

    trace = [(integral, "hyperbolic_rule")]
    kind = integrand.func
    if kind == sinh:
        return HyperbolicRule('sinh', symbol, integrand, symbol), trace
    if kind == cosh:
        return HyperbolicRule('cosh', symbol, integrand, symbol), trace

    u = Dummy('u')
    if kind in (tanh, coth):
        top, bottom = (sinh, cosh) if kind == tanh else (cosh, sinh)
        rewritten = top(symbol)/bottom(symbol)
        via_u = URule(u, bottom(symbol), None,
                      ReciprocalRule(u, 1/u, u), rewritten, symbol)
        return RewriteRule(rewritten, via_u, integrand, symbol), trace

    rewritten = integrand.rewrite(tanh)
    if kind == sech:
        inner = ArctanRule(S(2), S.One, S.One, 2/(u**2 + 1), u)
    elif kind == csch:
        inner = ReciprocalRule(u, 1/u, u)
    else:
        return None
    via_u = URule(u, tanh(symbol/2), None, inner, rewritten, symbol)
    return RewriteRule(rewritten, via_u, integrand, symbol), trace

@cacheit
def make_wilds(symbol):
    """Return the Wilds ``(a, b, m, n)`` shared by the trig power patterns.

    All four exclude ``symbol``.  ``m`` and ``n`` additionally only match
    SymPy ``Integer`` values (used as exponents).
    """
    def integer_only(value):
        return isinstance(value, Integer)

    coeff_a, coeff_b = (Wild(tag, exclude=[symbol]) for tag in ('a', 'b'))
    exp_m, exp_n = (Wild(tag, exclude=[symbol], properties=[integer_only])
                    for tag in ('m', 'n'))
    return coeff_a, coeff_b, exp_m, exp_n

@cacheit
def sincos_pattern(symbol):
    """Return ``(sin(a*x)**m * cos(b*x)**n, a, b, m, n)`` for ``x = symbol``."""
    wa, wb, wm, wn = make_wilds(symbol)
    return sin(wa*symbol)**wm * cos(wb*symbol)**wn, wa, wb, wm, wn

@cacheit
def tansec_pattern(symbol):
    """Return ``(tan(a*x)**m * sec(b*x)**n, a, b, m, n)`` for ``x = symbol``."""
    wa, wb, wm, wn = make_wilds(symbol)
    return tan(wa*symbol)**wm * sec(wb*symbol)**wn, wa, wb, wm, wn

@cacheit
def cotcsc_pattern(symbol):
    """Return ``(cot(a*x)**m * csc(b*x)**n, a, b, m, n)`` for ``x = symbol``."""
    wa, wb, wm, wn = make_wilds(symbol)
    return cot(wa*symbol)**wm * csc(wb*symbol)**wn, wa, wb, wm, wn

@cacheit
def heaviside_pattern(symbol):
    """Return ``(Heaviside(m*x + b) * g, m, b, g)`` for ``x = symbol``.

    ``m`` and ``b`` exclude ``symbol``; ``g`` matches any cofactor.  This
    definition shadows the ``heaviside_pattern`` imported from
    ``sympy.integrals.manualintegrate`` at the top of the module.
    """
    slope, offset = (Wild(tag, exclude=[symbol]) for tag in ('m', 'b'))
    rest = Wild('g')
    return Heaviside(slope*symbol + offset) * rest, slope, offset, rest

def uncurry(func):
    """Adapt ``func(*args)`` into a one-argument callable taking the args tuple."""
    def call_with_unpacked(packed_args):
        return func(*packed_args)
    return call_with_unpacked

def trig_rewriter(rewrite, rule_called):
    """Wrap a trig rewrite as a rule taking ``(a, b, m, n, integrand, symbol)``.

    ``rewrite`` is called with those six values.  If it returns an
    expression different from ``integrand``, the rewritten integral is
    solved recursively and ``(RewriteRule, steps)`` is returned.  ``steps``
    starts with ``(IntegralInfo(rewritten, symbol), rule_called)``.
    Otherwise the rule returns ``None``.
    """
    def rewrite_rule(packed):
        a, b, m, n, integrand, symbol = packed
        rewritten = rewrite(a, b, m, n, integrand, symbol)
        if rewritten == integrand:
            return None
        trace = [(IntegralInfo(rewritten, symbol), rule_called)]
        solved = integral_steps(rewritten, symbol)
        if type(solved[1]) is list:
            trace.extend(solved[1])
        else:
            trace.append(solved[1])
        return RewriteRule(rewritten, solved[0], integrand, symbol), trace
    return rewrite_rule

# Condition/rewrite pairs for powers of trig functions.  Each *_condition
# takes the packed (a, b, m, n, integrand, symbol) tuple (via uncurry).  Each
# rewriter comes from trig_rewriter.  They are paired in the multiplexer
# dicts of trig_sincos_rule, trig_tansec_rule and trig_cotcsc_rule.

# sin**m * cos**n with m, n even and nonnegative: half-angle identities
# sin(t)**2 = (1 - cos(2t))/2 and cos(t)**2 = (1 + cos(2t))/2.
sincos_botheven_condition = uncurry(
    lambda a, b, m, n, i, s: m.is_even and n.is_even and
                             m.is_nonnegative and n.is_nonnegative)

sincos_botheven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (((1 - cos(2*a*symbol)) / 2) ** (m / 2)) *
                                    (((1 + cos(2*b*symbol)) / 2) ** (n / 2)) ),"trig_sincos_rule")

# Odd m >= 3: sin**m = (1 - cos**2)**((m-1)/2) * sin, leaving one sin factor
# (a form suited to substituting u = cos).
sincos_sinodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd and m >= 3)

sincos_sinodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - cos(a*symbol)**2)**((m - 1) / 2) *
                                    sin(a*symbol) *
                                    cos(b*symbol) ** n), "trig_sincos_rule")

# Odd n >= 3: cos**n = (1 - sin**2)**((n-1)/2) * cos, leaving one cos factor.
sincos_cosodd_condition = uncurry(lambda a, b, m, n, i, s: n.is_odd and n >= 3)

sincos_cosodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - sin(b*symbol)**2)**((n - 1) / 2) *
                                    cos(b*symbol) *
                                    sin(a*symbol) ** m), "trig_sincos_rule")

# tan**m * sec**n with even n >= 4: sec**n = (1 + tan**2)**(n/2 - 1) * sec**2.
tansec_seceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
tansec_seceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + tan(b*symbol)**2) ** (n/2 - 1) *
                                    sec(b*symbol)**2 *
                                    tan(a*symbol) ** m ), "trig_tansec_rule")

# Odd m: tan**m = (sec**2 - 1)**((m-1)/2) * tan.
tansec_tanodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
tansec_tanodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (sec(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                    tan(a*symbol) *
                                    sec(b*symbol) ** n ), "trig_tansec_rule")

# Plain tan**2 (m == 2, n == 0): tan**2 = sec**2 - 1.
tan_tansquared_condition = uncurry(lambda a, b, m, n, i, s: m == 2 and n == 0)
tan_tansquared = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( sec(a*symbol)**2 - 1), "trig_tansec_rule")

# cot**m * csc**n with even n >= 4: csc**n = (1 + cot**2)**(n/2 - 1) * csc**2.
cotcsc_csceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
cotcsc_csceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + cot(b*symbol)**2) ** (n/2 - 1) *
                                    csc(b*symbol)**2 *
                                    cot(a*symbol) ** m ), "trig_cotcsc_rule")

# Odd m: cot**m = (csc**2 - 1)**((m-1)/2) * cot.
cotcsc_cotodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
cotcsc_cotodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (csc(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                    cot(a*symbol) *
                                    csc(b*symbol) ** n ), "trig_cotcsc_rule")

def trig_sincos_rule(integral):
    """Handle ``sin(a*x)**m * cos(b*x)**n`` via the sin/cos power rewrites.

    The candidates are: both powers even and nonnegative, odd ``m >= 3``,
    and odd ``n >= 3``.  They are combined with ``multiplexer`` and applied
    to ``(a, b, m, n, integrand, symbol)``; unmatched Wilds default to
    ``S.Zero``.

    Returns ``None`` when the integrand has no ``sin``/``cos`` or does not
    match the pattern.
    """
    #    print("trig_sincos_rule", integral)
    integrand, symbol = integral

    if not (integrand.has(sin) or integrand.has(cos)):
        return None

    pattern, *wilds = sincos_pattern(symbol)
    found = integrand.match(pattern)
    if not found:
        return None

    dispatch = {
        sincos_botheven_condition: sincos_botheven,
        sincos_sinodd_condition: sincos_sinodd,
        sincos_cosodd_condition: sincos_cosodd,
    }
    packed = (*(found.get(w, S.Zero) for w in wilds), integrand, symbol)
    return multiplexer(dispatch)(packed)

def trig_tansec_rule(integral):
    """Handle ``tan(a*x)**m * sec(b*x)**n`` via the tan/sec power rewrites.

    ``1/cos(symbol)`` is first replaced by ``sec(symbol)``.  The candidates
    (odd ``m``; even ``n >= 4``; plain ``tan**2``) are combined with
    ``multiplexer`` and applied to ``(a, b, m, n, integrand, symbol)``.

    Returns ``None`` when no ``tan``/``sec`` is present or the pattern does
    not match.
    """
    #    print("trig_tansec_rule", integral)
    integrand, symbol = integral
    integrand = integrand.subs({1 / cos(symbol): sec(symbol)})

    if not (integrand.has(tan) or integrand.has(sec)):
        return None

    pattern, *wilds = tansec_pattern(symbol)
    found = integrand.match(pattern)
    if not found:
        return None

    dispatch = {
        tansec_tanodd_condition: tansec_tanodd,
        tansec_seceven_condition: tansec_seceven,
        tan_tansquared_condition: tan_tansquared,
    }
    packed = (*(found.get(w, S.Zero) for w in wilds), integrand, symbol)
    return multiplexer(dispatch)(packed)

def trig_cotcsc_rule(integral):
    """Handle ``cot(a*x)**m * csc(b*x)**n`` via the cot/csc power rewrites.

    ``1/sin``, ``1/tan`` and ``cos/sin`` of ``symbol`` are first normalised
    to ``csc`` and ``cot`` so more integrands fit the pattern.  The
    candidates (odd ``m``; even ``n >= 4``) are combined with
    ``multiplexer`` and applied to ``(a, b, m, n, integrand, symbol)``.

    Returns ``None`` when no ``cot``/``csc`` is present or the pattern does
    not match.
    """
    #    print("trig_cotcsc_rule", integral)
    integrand, symbol = integral
    integrand = integrand.subs({
        1 / sin(symbol): csc(symbol),
        1 / tan(symbol): cot(symbol),
        # cos/sin == cot.  The key used to be cos/tan, which equals
        # cos**2/sin rather than cot, so the substitution changed the
        # integral's value.
        cos(symbol) / sin(symbol): cot(symbol)
    })

    if not (integrand.has(cot) or integrand.has(csc)):
        return None

    pattern, a, b, m, n = cotcsc_pattern(symbol)
    match = integrand.match(pattern)
    if not match:
        return None
    return multiplexer({
        cotcsc_cotodd_condition: cotcsc_cotodd,
        cotcsc_csceven_condition: cotcsc_csceven
    })(tuple(
        [match.get(i, S.Zero) for i in (a, b, m, n)] +
        [integrand, symbol]))

def trig_sindouble_rule(integral):
    """Rewrite a ``sin(2*x)`` factor as ``2*sin(x)*cos(x)`` and integrate.

    The integrand is multiplied by ``2*sin(x)*cos(x)/sin(2*x)`` (equal to 1)
    and solved recursively.  The sub-rule is returned directly, not wrapped
    in a ``RewriteRule``.

    Returns ``(rule, steps)``, or ``None`` when there is no ``sin(2*symbol)``
    factor.
    """
    #    print("trig_sindouble_rule", integral)
    integrand, symbol = integral
    double_angle = sin(2*symbol)
    cofactor = Wild('a', exclude=[double_angle])
    if not integrand.match(double_angle*cofactor):
        return None

    expanded = integrand * (2*sin(symbol)*cos(symbol)/double_angle)
    trace = [(integral, "trig_sindouble_rule")]
    solved = integral_steps(expanded, symbol)
    if type(solved[1]) is list:
        trace.extend(solved[1])
    else:
        trace.append(solved[1])
    return solved[0], trace

def trig_powers_products_rule(integral):
    """Try the trig power/product sub-rules, each wrapped with ``null_safe``.

    The sub-rules are combined with ``do_one`` in this order:
    sin/cos, tan/sec, cot/csc, then the ``sin(2x)`` expansion.
    """
    #    print("trig_powers_products_rule", integral)
    ordered = (trig_sincos_rule, trig_tansec_rule,
               trig_cotcsc_rule, trig_sindouble_rule)
    return do_one(*[null_safe(rule) for rule in ordered])(integral)

def trig_substitution_rule(integral):
    """Integrate expressions containing ``a + b*x**2`` by trig substitution.

    For each subexpression matching ``a + b*x**2`` (with ``a``, ``b`` nonzero
    and free of ``x``), choose ``x = f(theta)`` from the signs of ``a`` and
    ``b``:

    * ``a > 0, b > 0``: ``x = sqrt(a)/sqrt(b)*tan(theta)``, valid for all x;
    * ``a > 0, b < 0``: ``x = C*sin(theta)``, ``C = sqrt(a)/sqrt(-b)``,
      valid for ``-C < x < C``;
    * ``a < 0, b > 0``: ``x = C*sec(theta)``, ``C = sqrt(-a)/sqrt(b)``,
      valid for ``x < -C`` or ``x > C``.

    Returns ``(TrigSubstitutionRule, steps)`` for the first substitution
    whose transformed integral is free of ``x`` and fully solvable, or
    ``None`` if there is none.
    """
    #    print("trig_substitution_rule", integral)
    integrand, symbol = integral
    A = Wild('a', exclude=[0, symbol])
    B = Wild('b', exclude=[0, symbol])
    theta = Dummy("theta")
    target_pattern = A + B*symbol**2

    matches = integrand.find(target_pattern)
    for expr in matches:
        match = expr.match(target_pattern)
        a = match.get(A, S.Zero)
        b = match.get(B, S.Zero)

        a_positive = ((a.is_number and a > 0) or a.is_positive)
        b_positive = ((b.is_number and b > 0) or b.is_positive)
        a_negative = ((a.is_number and a < 0) or a.is_negative)
        b_negative = ((b.is_number and b < 0) or b.is_negative)
        x_func = None
        if a_positive and b_positive:
            # a**2 + b*x**2. Assume sec(theta) > 0, -pi/2 < theta < pi/2
            x_func = (sqrt(a)/sqrt(b)) * tan(theta)
            # Do not restrict the domain: tan(theta) takes on any real
            # value on the interval -pi/2 < theta < pi/2 so x takes on
            # any value
            restriction = True
        elif a_positive and b_negative:
            # a**2 - b*x**2. Assume cos(theta) > 0, -pi/2 < theta < pi/2
            constant = sqrt(a)/sqrt(-b)
            x_func = constant * sin(theta)
            restriction = And(symbol > -constant, symbol < constant)
        elif a_negative and b_positive:
            # b*x**2 - a**2. Assume sin(theta) > 0, 0 < theta < pi
            constant = sqrt(-a)/sqrt(b)
            x_func = constant * sec(theta)
            # |sec(theta)| >= 1, so x = constant*sec(theta) only reaches
            # |x| >= constant, which is also where b*x**2 + a >= 0.  The
            # interval -constant < x < constant used previously is exactly
            # where the square root is not real.
            restriction = Or(symbol < -constant, symbol > constant)
        if x_func:
            # Manually simplify sqrt(trig(theta)**2) to trig(theta)
            # Valid due to assumed domain restriction
            substitutions = {}
            for f in [sin, cos, tan,
                      sec, csc, cot]:
                substitutions[sqrt(f(theta)**2)] = f(theta)
                substitutions[sqrt(f(theta)**(-2))] = 1/f(theta)

            replaced = integrand.subs(symbol, x_func).trigsimp()
            replaced = manual_subs(replaced, substitutions)
            if not replaced.has(symbol):
                # Multiply by dx/dtheta to change the integration variable.
                replaced *= manual_diff(x_func, theta)
                replaced = replaced.trigsimp()
                secants = replaced.find(1/cos(theta))
                if secants:
                    replaced = replaced.xreplace({
                        1/cos(theta): sec(theta)
                    })
                steps = [(integral, "trig_substitution_rule", replaced)]
                tsubsr_res = integral_steps(replaced, theta)
                substep = tsubsr_res[0]
                if isinstance(tsubsr_res[1], list):
                    steps.extend(tsubsr_res[1])
                else:
                    steps.append(tsubsr_res[1])
                if not contains_dont_know(substep):
                    return TrigSubstitutionRule(
                        theta, x_func, replaced, substep, restriction,
                        integrand, symbol), steps

def heaviside_rule(integral):
    """Integrate ``Heaviside(m*x + b) * g`` for a nonzero cofactor ``g``.

    ``g`` is integrated recursively, and its antiderivative is packaged in a
    ``HeavisideRule`` together with ``-b/m``, the point where the Heaviside
    argument vanishes.

    Returns ``(rule, steps)``, or ``None`` when the pattern does not match
    or ``g`` is zero.
    """
    #    print("heaviside_rule", integral)
    integrand, symbol = integral
    pattern, m, b, g = heaviside_pattern(symbol)
    found = integrand.match(pattern)
    if not found or found[g] == 0:
        return None

    # f = Heaviside(m*x + b)*g
    trace = [(integral, "heaviside_rule")]
    g_result = integral_steps(found[g], symbol)
    sub_steps = g_result[1]
    if type(sub_steps) is list:
        trace.extend(sub_steps)
    else:
        trace.append(sub_steps)
    antiderivative = _manualintegrate(g_result[0])
    slope, offset = found[m], found[b]
    return HeavisideRule(slope*symbol + offset, -offset/slope, antiderivative,
                         integrand, symbol), trace


def dirac_delta_rule(integral: IntegralInfo):
    """Integrate ``DiracDelta(a + b*x)`` or ``DiracDelta(a + b*x, n)``.

    ``n`` (``0`` when omitted) must be a nonnegative integer, and the first
    argument must be linear in ``x``.  If ``b`` is not known to be nonzero,
    the constant ``DiracDelta(a, n)`` case is attached through
    ``_add_degenerate_step``.

    Returns ``(rule, steps)`` or ``None``.
    """
    #    print("dirac_delta_rule", integral)
    integrand, x = integral
    n = S.Zero if len(integrand.args) == 1 else integrand.args[1]
    if not (n.is_Integer and n >= 0):
        return None
    intercept, slope = Wild('a', exclude=[x]), Wild('b', exclude=[x, 0])
    found = integrand.args[0].match(intercept + slope*x)
    if not found:
        return None
    a, b = found[intercept], found[slope]
    generic_cond = Ne(b, 0)
    trace = [(integral, "dirac_delta_rule")]
    degenerate_step = (None if generic_cond is S.true
                       else ConstantRule(DiracDelta(a, n), integrand, x))
    generic_step = DiracDeltaRule(n, a, b, integrand, x)
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step), trace


def substitution_rule(integral):
    """Try every u-substitution proposed by ``find_substitutions``.

    Each candidate ``(u_func, c, substituted)`` is solved recursively in a
    dummy ``u``.  Candidates whose sub-rule contains a don't-know step are
    skipped.  When the constant factor ``c`` is not 1, the sub-rule is
    wrapped in ``ConstantTimesRule``.  If ``c``'s denominator has free
    symbols, each factor of it that is not known to be nonzero is set to
    zero, and that case is integrated separately.  The cases are combined
    in a ``PiecewiseRule``.

    Returns one of:
    * ``(URule, pre_steps)`` for a single successful substitution;
    * ``(AlternativeRule, pre_steps)`` when several succeed;
    * ``None`` when none succeeds.

    ``pre_steps`` concatenates the traces of all successful candidates,
    including their zero-denominator special cases.
    """
    #    print("substitution_rule", integral)
    integrand, symbol = integral

    u_var = Dummy("u")
    substitutions = find_substitutions(integrand, symbol, u_var)
    count = 0  # only used to number the debug output
    if substitutions:
        debug("List of Substitution Rules")
        ways = []
        # Shared trace across all candidate substitutions.
        pre_steps = []
        for u_func, c, substituted in substitutions:
            subsr_res_1 = integral_steps(substituted, u_var)
            subrule = subsr_res_1[0]
            count = count + 1
            debug("Rule {}: {}".format(count, subrule))

            if contains_dont_know(subrule):
                continue
            pre_steps.extend([(integral, "substitution_rule", u_func)])
            if type(subsr_res_1[1]) == list:
                pre_steps.extend(subsr_res_1[1])
            else:
                pre_steps.append(subsr_res_1[1])

            if simplify(c - 1) != 0:
                _, denom = c.as_numer_denom()
                if subrule:
                    subrule = ConstantTimesRule(c, substituted, subrule, c * substituted, u_var)

                # A symbolic denominator could vanish: integrate those cases
                # separately, guarded by Eq(factor, 0).
                if denom.free_symbols:
                    piecewise = []
                    could_be_zero = []

                    if isinstance(denom, Mul):
                        could_be_zero = denom.args
                    else:
                        could_be_zero.append(denom)

                    for expr in could_be_zero:
                        if not fuzzy_not(expr.is_zero):
                            subsr_res_2 = integral_steps(manual_subs(integrand, expr, 0), symbol)
                            substep = subsr_res_2[0]
                            if type(subsr_res_2[1]) == list:
                                pre_steps.extend(subsr_res_2[1])
                            else:
                                pre_steps.append(subsr_res_2[1])
                            if substep:
                                piecewise.append((
                                    substep,
                                    Eq(expr, 0)
                                ))
                    # Generic case last, as the catch-all branch.
                    piecewise.append((subrule, True))
                    subrule = PiecewiseRule(piecewise, substituted, symbol)

            ways.append(URule(u_var, u_func, c,
                              subrule,
                              integrand, symbol))
        if len(ways) > 1:
            return AlternativeRule(ways, integrand, symbol), pre_steps
        elif ways:
            return ways[0], pre_steps


# Generic rewrite rules built with ``rewriter`` (defined outside this chunk).
# Judging by the arguments, each call takes an applicability predicate, a
# rewrite function and the trace name.

# Rational integrands: partial-fraction decomposition via apart().
partial_fractions_rule = rewriter(
    lambda integrand, symbol: integrand.is_rational_function(),
    lambda integrand, symbol: integrand.apart(symbol),
    "partial_fractions_rule")

# Unconditional: try cancel() on any integrand.  (Narrower predicates,
# integrand.is_algebraic_expr() and isinstance(integrand, Mul), are not used.)
cancel_rule = rewriter(
    lambda integrand, symbol: True,
    lambda integrand, symbol: integrand.cancel(),
    "cancel_rule")

# Expand the integrand when it is a Pow or Mul, or when every argument is a
# Pow or polynomial in symbol (vacuously true for integrands with no args).
distribute_expand_rule = rewriter(
    lambda integrand, symbol: (
            all(arg.is_Pow or arg.is_polynomial(symbol) for arg in integrand.args)
            or isinstance(integrand, Pow)
            or isinstance(integrand, Mul)),
    lambda integrand, symbol: integrand.expand(),
    "distribute_expand_rule")

trig_expand_rule = rewriter(
    # If there are trig functions with different arguments, expand them
    lambda integrand, symbol: (
            len({a.args[0] for a in integrand.atoms(TrigonometricFunction)}) > 1),
    lambda integrand, symbol: integrand.expand(trig=True),
    "trig_expand_rule")

def derivative_rule(integral):
    """Pick a rule for integrating an unevaluated ``Derivative`` integrand.

    * The integration variable does not occur in the differentiated
      expression: the whole derivative is treated as a constant.
    * It occurs and is one of the differentiation variables: a
      ``DerivativeRule`` (one differentiation order is removed).
    * Otherwise: give up with ``DontKnowRule``.

    Returns ``(rule, steps)``, where ``steps`` is a one-element list of
    ``(integral, rule_name)``.
    """
    derivative = integral[0]
    var = integral.symbol
    inner_symbols = derivative.expr.free_symbols

    if var not in inner_symbols:
        return ConstantRule(integral.integrand, *integral), [(integral, "derivative_rule")]
    if var in derivative.variables:
        return DerivativeRule(*integral), [(integral, "derivative_rule")]
    return DontKnowRule(derivative, var), [(integral, "dont_know_rule")]

def rewrites_rule(integral):
    """Rewrite an integrand of the form ``1/cos(symbol)`` as ``sec(symbol)``.

    When the integrand matches, the rewritten form is integrated recursively
    with ``integral_steps``. The result is ``(RewriteRule, steps)``, where
    ``steps`` starts with ``(integral, "rewrites_rule")`` followed by the
    sub-integration's steps. Returns ``None`` when the integrand does not
    match, so the enclosing ``alternatives`` strategy skips this rule.
    """
    integrand, symbol = integral

    # ``Expr.match`` returns ``{}`` on a successful match against a pattern
    # with no Wild symbols, and ``None`` on failure. ``{}`` is falsy, so a
    # plain truthiness test would never let this rule fire.
    if integrand.match(1/cos(symbol)) is None:
        return None

    rewritten = integrand.subs(1/cos(symbol), sec(symbol))
    steps = [(integral, "rewrites_rule")]
    rewr_res = integral_steps(rewritten, symbol)
    sub_steps = rewr_res[1]
    # Sub-steps are normally a list, but a single (info, name) pair is also accepted.
    if isinstance(sub_steps, list):
        steps.extend(sub_steps)
    else:
        steps.append(sub_steps)
    return RewriteRule(rewritten, rewr_res[0], integrand, symbol), steps

def fallback_rule(integral):
    """Last-resort strategy: mark the integral as not integrable by these rules.

    Returns ``(DontKnowRule, [(integral, "dont_know_rule")])``.
    """
    give_up_step = (integral, "dont_know_rule")
    return DontKnowRule(*integral), [give_up_step]

# Loop guard for integral_steps. While an integrand is being worked on, its
# key maps to None, and re-entering the same integrand (a cycle) gives up
# instead of recursing forever. Keys replace the integration symbol with
# _cache_dummy, so the same integrand in different variables shares one entry.
_integral_cache: dict[Expr, Expr | None] = dict()
# Per-"u" counter recorded by integration by parts to avoid choosing the same
# u over and over.
_parts_u_cache: dict[Expr, int] = defaultdict(int)
# Shared placeholder symbol used to build _integral_cache keys.
_cache_dummy = Dummy("z")

def integral_steps(integrand, symbol, **options):
    """Return the rule tree for an integral together with the applied steps.

    Explanation
    ===========

    This is a variant of SymPy's ``manualintegrate.integral_steps``: it tries
    to mirror what a student would do by hand. In addition to the rule tree,
    it also returns the trail of ``(IntegralInfo, rule_name)`` entries
    produced by the strategy functions in this file.

    Strategies are tried in this order:

    1. ``special_function_rule``.
    2. A dispatch on the integrand's type (see ``key``).
    3. Trig/hyperbolic rules, then a set of alternatives: rewrites,
       substitution, partial fractions, cancellation, integration by parts,
       expansion, trig power/product handling and trig expansion.
       Trig substitution is tried last.
    4. ``fallback_rule`` (DontKnowRule).

    Cycles are broken with ``_integral_cache``. An integrand that is
    re-entered while it is still being worked on yields ``DontKnowRule``.

    Returns
    =======

    (rule, steps)
        ``rule`` is the first step. Most rules have substeps, and the whole
        tree can be evaluated with ``_manualintegrate``. ``steps`` is what the
        chosen strategy returned alongside the rule: normally a list of
        ``(IntegralInfo, rule_name)`` pairs, though callers also accept a
        single pair.
    """
    cachekey = integrand.xreplace({symbol: _cache_dummy})
    if cachekey in _integral_cache:
        cached = _integral_cache[cachekey]
        if cached is None:
            # Stop this attempt, because it leads around in a loop.
            return DontKnowRule(integrand, symbol), [(IntegralInfo(integrand, symbol), "dont_know_rule")]
        # TODO: This is for future development, as currently _integral_cache
        # gets no values other than None. ``xreplace`` takes a mapping, and the
        # second element is a (here empty) steps list to match the normal
        # return shape.
        return cached.xreplace({_cache_dummy: symbol}), []
    _integral_cache[cachekey] = None

    integral = IntegralInfo(integrand, symbol)

    def key(integral):
        """Dispatch key: Number if free of ``symbol``, a grouping class, or the exact type."""
        integrand = integral.integrand

        if symbol not in integrand.free_symbols:
            return Number
        for cls in (Symbol, TrigonometricFunction, OrthogonalPolynomial):
            if isinstance(integrand, cls):
                return cls
        return type(integrand)

    def integral_is_subclass(*klasses):
        def _integral_is_subclass(integral):
            k = key(integral)
            return k and issubclass(k, klasses)
        return _integral_is_subclass

    try:
        result = do_one(
            null_safe(special_function_rule),
            null_safe(switch(key, {
                Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule),
                            null_safe(sqrt_linear_rule),
                            null_safe(quadratic_denom_rule)),
                Symbol: power_rule,
                exp: exp_rule,
                Add: add_rule,
                Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule),
                            null_safe(heaviside_rule), null_safe(quadratic_denom_rule),
                            null_safe(sqrt_linear_rule),
                            null_safe(sqrt_quadratic_rule)),
                Derivative: derivative_rule,
                TrigonometricFunction: trig_rule,
                Heaviside: heaviside_rule,
                DiracDelta: dirac_delta_rule,
                OrthogonalPolynomial: orthogonal_poly_rule,
                Number: constant_rule
            })),
            do_one(
                null_safe(trig_rule),
                null_safe(hyperbolic_rule),
                null_safe(alternatives(
                    rewrites_rule,
                    substitution_rule,
                    condition(
                        integral_is_subclass(Mul, Pow),
                        partial_fractions_rule),
                    condition(
                        integral_is_subclass(Mul, Pow),
                        cancel_rule),
                    condition(
                        integral_is_subclass(Mul, log,
                                             *inverse_trig_functions),
                        parts_rule),
                    condition(
                        integral_is_subclass(Mul, Pow),
                        distribute_expand_rule),
                    trig_powers_products_rule,
                    trig_expand_rule
                )),
                null_safe(trig_substitution_rule)
            ),
            fallback_rule)(integral)
    finally:
        # Always release the loop-guard entry, even when a rule raises or a
        # SIGALRM-driven timeout interrupts the search. A stale None would
        # otherwise make every later attempt at this integrand (in the same
        # process) give up immediately.
        _integral_cache.pop(cachekey, None)
    rule, steps = result
    return rule, steps

@evaluates(ConstantRule)
def eval_constant(constant, integrand, symbol):
    """Integrate a constant: the result is ``constant * symbol``."""
    antiderivative = constant * symbol
    return antiderivative

@evaluates(ConstantTimesRule)
def eval_constanttimes(constant, other, substep, integrand, symbol):
    """Pull ``constant`` out and integrate the remaining factor via ``substep``."""
    inner = _manualintegrate(substep)
    return constant * inner

@evaluates(PowerRule)
def eval_power(base, exp, integrand, symbol):
    """Power rule. The ``exp == -1`` case falls back to ``log(base)``."""
    general_case = (base**(exp + 1))/(exp + 1)
    return Piecewise(
        (general_case, Ne(exp, -1)),
        (log(base), True),
    )

@evaluates(ExpRule)
def eval_exp(base, exp, integrand, symbol):
    """Integrate ``base**(...)``: divide the integrand by ``log(base)``."""
    scale = log(base)
    return integrand / scale

@evaluates(AddRule)
def eval_add(substeps, integrand, symbol):
    """Integrate a sum term by term."""
    return sum(_manualintegrate(term_step) for term_step in substeps)

@evaluates(URule)
def eval_u(u_var, u_func, constant, substep, integrand, symbol):
    """Integrate in ``u_var``, then substitute ``u_func`` back in for it."""
    in_u = _manualintegrate(substep)
    reciprocal_u = u_func.is_Pow and u_func.exp == -1
    if reciprocal_u:
        # With u = 1/g, write log(u) as -log(g) instead of leaving log(1/g).
        in_u = in_u.subs(log(u_var), -log(u_func.base))
    return in_u.subs(u_var, u_func)

@evaluates(PartsRule)
def eval_parts(u, dv, v_step, second_step, integrand, symbol):
    """Integration by parts: ``u*v - (integral evaluated by second_step)``."""
    u_times_v = u * _manualintegrate(v_step)
    remainder = _manualintegrate(second_step)
    return u_times_v - remainder

@evaluates(CyclicPartsRule)
def eval_cyclicparts(parts_rules, coefficient, integrand, symbol):
    """Resolve a cyclic chain of integrations by parts.

    The alternating sum of ``u*v`` over ``parts_rules`` is divided by
    ``1 - coefficient``, where ``coefficient`` is the multiple of the
    original integral that reappears at the end of the chain.
    """
    denominator = 1 - coefficient
    terms = [(-1)**position * step.u * _manualintegrate(step.v_step)
             for position, step in enumerate(parts_rules)]
    return Add(*terms) / denominator

@evaluates(TrigRule)
def eval_trig(func, arg, integrand, symbol):
    """Antiderivative for the basic trig form named by ``func``.

    Returns None for an unrecognized ``func``.
    """
    antiderivatives = {
        'sin': lambda t: -cos(t),
        'cos': lambda t: sin(t),
        'sec*tan': lambda t: sec(t),
        'csc*cot': lambda t: csc(t),
        'sec**2': lambda t: tan(t),
        'csc**2': lambda t: -cot(t),
    }
    build = antiderivatives.get(func)
    return None if build is None else build(arg)


@evaluates(HyperbolicRule)
def eval_hyperbolic(func: str, arg: Expr, integrand, symbol):
    """Antiderivative of sinh/cosh. Returns None for any other ``func``."""
    partner = {'sinh': cosh, 'cosh': sinh}.get(func)
    return None if partner is None else partner(arg)


@evaluates(ArctanRule)
def eval_arctan(a, b, c, integrand, symbol):
    """Antiderivative of ``a/(b*symbol**2 + c)`` expressed with ``atan``."""
    root = sqrt(c / b)
    return a / b * 1 / root * atan(symbol / root)


@evaluates(ArctanhRule)
def eval_arctanh(a, b, c, integrand, symbol):
    """Antiderivative of ``a/(b*symbol**2 + c)`` expressed with ``atanh``."""
    root = sqrt(-c / b)
    return - a / b * 1 / root * atanh(symbol / root)

@evaluates(ReciprocalRule)
def eval_reciprocal(func, integrand, symbol):
    """Integral of the reciprocal form: ``log(func)``."""
    antiderivative = log(func)
    return antiderivative

@evaluates(ArcsinRule)
def eval_arcsin(integrand, symbol):
    """Integral of ``1/sqrt(1 - symbol**2)``: ``asin(symbol)``."""
    antiderivative = asin(symbol)
    return antiderivative


@evaluates(ArcsinhRule)
def eval_arcsinh(integrand, x):
    """Integral of ``1/sqrt(1 + x**2)``: ``asinh(x)``."""
    antiderivative = asinh(x)
    return antiderivative


@evaluates(ReciprocalSqrtQuadraticRule)
def eval_reciprocal_sqrt_quadratic(a, b, c, integrand, x):
    """Logarithmic antiderivative of ``1/sqrt(a + b*x + c*x**2)``."""
    root_c = sqrt(c)
    radicand = a+b*x+c*x**2
    return log(2*root_c*sqrt(radicand)+b+2*c*x)/root_c


@evaluates(SqrtQuadraticDenomRule)
def eval_sqrt_quadratic_denom(a, b, c, coeffs: list[Expr], integrand, x):
    """Integrate ``P(x)/sqrt(a + b*x + c*x**2)`` for a polynomial ``P``.

    ``coeffs`` holds the coefficients of ``P`` from the highest degree down:
    ``coeffs[i]`` multiplies ``x**(len(coeffs)-1-i)``. The list is copied,
    not mutated. The result is ``Q(x)*sqrt(a+b*x+c*x**2) + I0``. ``Q``
    collects the reduction coefficients, and ``I0`` integrates the leftover
    constant over the square root via ``inverse_trig_rule``.
    NOTE(review): assumes ``len(coeffs) >= 2`` (``coeffs[-2]`` is read).
    """
    # Integrate poly/sqrt(a+b*x+c*x**2) using recursion.
    # coeffs are coefficients of the polynomial.
    # Let I_n = x**n/sqrt(a+b*x+c*x**2), then
    # I_n = A * x**(n-1)*sqrt(a+b*x+c*x**2) - B * I_{n-1} - C * I_{n-2}
    # where A = 1/(n*c), B = (2*n-1)*b/(2*n*c), C = (n-1)*a/(n*c)
    # See https://github.com/sympy/sympy/pull/23608 for proof.
    result_coeffs = []
    coeffs = coeffs.copy()
    # Each pass removes the current top-degree term (degree n) and pushes its
    # B and C contributions down onto degrees n-1 and n-2.
    for i in range(len(coeffs)-2):
        n = len(coeffs)-1-i
        coeff = coeffs[i]/(c*n)
        result_coeffs.append(coeff)
        coeffs[i+1] -= (2*n-1)*b/2*coeff
        coeffs[i+2] -= (n-1)*a*coeff
    # What remains is the linear numerator e*x + d.
    d, e = coeffs[-1], coeffs[-2]
    s = sqrt(a+b*x+c*x**2)
    # e*x + d = e/(2c)*(2*c*x + b) + (d - b*e/(2c)). The first part integrates
    # to (e/c)*s (the ``e/c`` term below); ``constant`` is the remainder over s.
    constant = d-b*e/(2*c)
    if constant == 0:
        I0 = 0
    else:
        # NOTE(review): meaning of degenerate=False is defined by inverse_trig_rule (not visible here).
        step, _ = inverse_trig_rule(IntegralInfo(1/s, x), degenerate=False)
        I0 = constant*_manualintegrate(step)
    return Add(*(result_coeffs[i]*x**(len(coeffs)-2-i)
                 for i in range(len(result_coeffs))), e/c)*s + I0


@evaluates(SqrtQuadraticRule)
def eval_sqrt_quadratic(a, b, c, integrand, x):
    """Rebuild the (non-degenerate) sqrt-quadratic rule for ``integrand`` and evaluate it."""
    info = IntegralInfo(integrand, x)
    rebuilt_step, _unused_steps = sqrt_quadratic_rule(info, degenerate=False)
    return _manualintegrate(rebuilt_step)


@evaluates(AlternativeRule)
def eval_alternative(alternatives, integrand, symbol):
    """Evaluate only the first of several equivalent alternatives."""
    preferred = alternatives[0]
    return _manualintegrate(preferred)


@evaluates(CompleteSquareRule)
@evaluates(RewriteRule)
def eval_rewrite(rewritten, substep, integrand, symbol):
    """A rewrite does not change the value: evaluate the substep directly."""
    inner = _manualintegrate(substep)
    return inner


@evaluates(PiecewiseRule)
def eval_piecewise(substeps, integrand, symbol):
    """Evaluate each branch's substep and keep its condition."""
    pieces = [(_manualintegrate(branch_step), branch_cond)
              for branch_step, branch_cond in substeps]
    return Piecewise(*pieces)

@evaluates(TrigSubstitutionRule)
def eval_trigsubstitution(theta, func, rewritten, substep, restriction, integrand, symbol):
    """Undo a trigonometric substitution ``symbol = func(theta)``.

    ``substep`` (an integral in ``theta``) is evaluated. Then ``sin``, ``cos``
    and ``tan`` of ``theta``, and ``theta`` itself, are replaced by
    expressions in ``symbol``. These come from a right triangle whose sides
    are read off ``symbol = func``. The result is a single-branch
    ``Piecewise`` guarded by ``restriction``.

    NOTE(review): the ``assert`` checks are stripped under ``python -O``. If
    the remaining trig function is not sin, cos or tan, the triangle
    variables are never bound and an UnboundLocalError follows.
    """
    # Rewrite sec/csc/cot in terms of cos/sin/tan so that exactly one
    # basic trig function of theta remains in func.
    func = func.subs(sec(theta), 1/cos(theta))
    func = func.subs(csc(theta), 1/sin(theta))
    func = func.subs(cot(theta), 1/tan(theta))

    trig_function = list(func.find(TrigonometricFunction))
    assert len(trig_function) == 1
    trig_function = trig_function[0]
    # Solve symbol == func for that trig function: relation[0] == numer/denom.
    relation = solve(symbol - func, trig_function)
    assert len(relation) == 1
    numer, denom = fraction(relation[0])

    # Build the triangle's sides from the ratio and recover theta.
    if isinstance(trig_function, sin):
        opposite = numer
        hypotenuse = denom
        adjacent = sqrt(denom**2 - numer**2)
        inverse = asin(relation[0])
    elif isinstance(trig_function, cos):
        adjacent = numer
        hypotenuse = denom
        opposite = sqrt(denom**2 - numer**2)
        inverse = acos(relation[0])
    elif isinstance(trig_function, tan):
        opposite = numer
        adjacent = denom
        hypotenuse = sqrt(denom**2 + numer**2)
        inverse = atan(relation[0])

    substitution = [
        (sin(theta), opposite/hypotenuse),
        (cos(theta), adjacent/hypotenuse),
        (tan(theta), opposite/adjacent),
        (theta, inverse)
    ]
    return Piecewise(
        (_manualintegrate(substep).subs(substitution).trigsimp(), restriction)
    )

@evaluates(DerivativeRule)
def eval_derivativerule(integrand, symbol):
    """Integrate a ``Derivative`` by lowering the order in ``symbol`` by one.

    ``integrand`` is expected to be a ``Derivative``. Only the first
    ``(variable, count)`` entry for ``symbol`` is decremented.
    """
    pairs = list(integrand.variable_count)
    hit = next((pos for pos, (var, _count) in enumerate(pairs) if var == symbol), None)
    if hit is not None:
        var, count = pairs[hit]
        pairs[hit] = (var, count-1)
    return Derivative(integrand.expr, *pairs)

@evaluates(HeavisideRule)
def eval_heaviside(harg, ibnd, substep, integrand, symbol):
    """Integrate ``Heaviside(harg)*g(symbol)`` given ``substep``, an antiderivative of g.

    The antiderivative must be continuous at the switch point ``ibnd``, so
    the value of ``substep`` there is subtracted.
    """
    value_at_switch = substep.subs(symbol, ibnd)
    return Heaviside(harg)*(substep - value_at_switch)


@evaluates(DiracDeltaRule)
def eval_dirac_delta(n, a, b, integrand, x):
    """Integrate the n-th derivative of ``DiracDelta(a + b*x)``."""
    linear = a+b*x
    if n != 0:
        return DiracDelta(linear, n-1)/b
    return Heaviside(linear)/b


@evaluates(JacobiRule)
def eval_jacobi(n, a, b, integrand, symbol):
    """Antiderivative of ``jacobi(n, a, b, symbol)``, with the degenerate cases separated."""
    order_sum = n + a + b
    general = 2*jacobi(n + 1, a - 1, b - 1, symbol)/order_sum
    degree_one = (a + b + 2)*symbol**2/4 + (a - b)*symbol/2
    return Piecewise(
        (general, Ne(order_sum, 0)),
        (symbol, Eq(n, 0)),
        (degree_one, Eq(n, 1)))

@evaluates(GegenbauerRule)
def eval_gegenbauer(n, a, integrand, symbol):
    """Antiderivative of ``gegenbauer(n, a, symbol)``; ``a == 1`` is handled via chebyshevt."""
    general = gegenbauer(n + 1, a - 1, symbol)/(2*(a - 1))
    a_is_one = chebyshevt(n + 1, symbol)/(n + 1)
    return Piecewise(
        (general, Ne(a, 1)),
        (a_is_one, Ne(n, -1)),
        (S.Zero, True))

@evaluates(ChebyshevTRule)
def eval_chebyshevt(n, integrand, symbol):
    """Antiderivative of ``chebyshevt(n, symbol)``; ``|n| == 1`` gives ``symbol**2/2``."""
    upper = chebyshevt(n + 1, symbol)/(n + 1)
    lower = chebyshevt(n - 1, symbol)/(n - 1)
    return Piecewise(((upper - lower)/2, Ne(Abs(n), 1)),
                     (symbol**2/2, True))

@evaluates(ChebyshevURule)
def eval_chebyshevu(n, integrand, symbol):
    """Antiderivative of ``chebyshevu(n, symbol)``: ``chebyshevt(n + 1, symbol)/(n + 1)``."""
    general = chebyshevt(n + 1, symbol)/(n + 1)
    return Piecewise(
        (general, Ne(n, -1)),
        (S.Zero, True))

@evaluates(LegendreRule)
def eval_legendre(n, integrand, symbol):
    """Antiderivative of ``legendre(n, symbol)``: ``(P_{n+1} - P_{n-1})/(2n + 1)``."""
    difference = legendre(n + 1, symbol) - legendre(n - 1, symbol)
    return difference/(2*n + 1)

@evaluates(HermiteRule)
def eval_hermite(n, integrand, symbol):
    """Antiderivative of ``hermite(n, symbol)``: ``H_{n+1}/(2(n + 1))``."""
    next_poly = hermite(n + 1, symbol)
    return next_poly/(2*(n + 1))

@evaluates(LaguerreRule)
def eval_laguerre(n, integrand, symbol):
    """Antiderivative of ``laguerre(n, symbol)``: ``L_n - L_{n+1}``."""
    current, following = laguerre(n, symbol), laguerre(n + 1, symbol)
    return current - following

@evaluates(AssocLaguerreRule)
def eval_assoclaguerre(n, a, integrand, symbol):
    """Antiderivative of ``assoc_laguerre(n, a, symbol)``: ``-L_{n+1}^{(a-1)}``."""
    shifted = assoc_laguerre(n + 1, a - 1, symbol)
    return -shifted

@evaluates(CiRule)
def eval_ci(a, b, integrand, symbol):
    """Antiderivative whose derivative is ``cos(a*symbol + b)/symbol``."""
    ax = a*symbol
    return cos(b)*Ci(ax) - sin(b)*Si(ax)

@evaluates(ChiRule)
def eval_chi(a, b, integrand, symbol):
    """Antiderivative whose derivative is ``cosh(a*symbol + b)/symbol``."""
    ax = a*symbol
    return cosh(b)*Chi(ax) + sinh(b)*Shi(ax)

@evaluates(EiRule)
def eval_ei(a, b, integrand, symbol):
    """Antiderivative whose derivative is ``exp(a*symbol + b)/symbol``."""
    ax = a*symbol
    return exp(b)*Ei(ax)

@evaluates(SiRule)
def eval_si(a, b, integrand, symbol):
    """Antiderivative whose derivative is ``sin(a*symbol + b)/symbol``."""
    ax = a*symbol
    return sin(b)*Ci(ax) + cos(b)*Si(ax)

@evaluates(ShiRule)
def eval_shi(a, b, integrand, symbol):
    """Antiderivative whose derivative is ``sinh(a*symbol + b)/symbol``."""
    ax = a*symbol
    return sinh(b)*Chi(ax) + cosh(b)*Shi(ax)

@evaluates(ErfRule)
def eval_erf(a, b, c, integrand, symbol):
    """Antiderivative of ``exp(a*symbol**2 + b*symbol + c)``.

    For real ``a`` the result branches on ``a < 0`` (erf form) versus
    otherwise (erfi form). When ``a`` is not known to be real, the erfi form
    is returned directly.
    """
    erfi_form = sqrt(S.Pi/a)/2 * exp(c - b**2/(4*a)) * \
        erfi((2*a*symbol + b)/(2*sqrt(a)))
    if not a.is_extended_real:
        return erfi_form
    erf_form = (sqrt(S.Pi/(-a))/2 * exp(c - b**2/(4*a)) *
                erf((-2*a*symbol - b)/(2*sqrt(-a))))
    return Piecewise(
        (erf_form, a < 0),
        (erfi_form, True))

@evaluates(FresnelCRule)
def eval_fresnelc(a, b, c, integrand, symbol):
    """FresnelCRule antiderivative, written in terms of fresnelc and fresnels."""
    prefactor = sqrt(S.Pi/(2*a))
    phase = b**2/(4*a) - c
    t = (2*a*symbol + b)/sqrt(2*a*S.Pi)
    return prefactor * (cos(phase)*fresnelc(t) + sin(phase)*fresnels(t))

@evaluates(FresnelSRule)
def eval_fresnels(a, b, c, integrand, symbol):
    """FresnelSRule antiderivative, written in terms of fresnels and fresnelc."""
    prefactor = sqrt(S.Pi/(2*a))
    phase = b**2/(4*a) - c
    t = (2*a*symbol + b)/sqrt(2*a*S.Pi)
    return prefactor * (cos(phase)*fresnels(t) - sin(phase)*fresnelc(t))

@evaluates(LiRule)
def eval_li(a, b, integrand, symbol):
    """Antiderivative whose derivative is ``1/log(a*symbol + b)``."""
    inner = a*symbol + b
    return li(inner)/a

@evaluates(PolylogRule)
def eval_polylog(a, b, integrand, symbol):
    """Antiderivative whose derivative is ``polylog(b, a*symbol)/symbol``."""
    order = b + 1
    return polylog(order, a*symbol)

@evaluates(UpperGammaRule)
def eval_uppergamma(a, e, integrand, symbol):
    """Upper-incomplete-gamma antiderivative (presumably of ``symbol**e*exp(a*symbol)``)."""
    t = -a*symbol
    return symbol**e * t**(-e) * uppergamma(e + 1, t)/a

@evaluates(EllipticFRule)
def eval_elliptic_f(a, d, integrand, symbol):
    """Antiderivative whose derivative is ``1/sqrt(a - d*sin(symbol)**2)``."""
    parameter = d/a
    return elliptic_f(symbol, parameter)/sqrt(a)

@evaluates(EllipticERule)
def eval_elliptic_e(a, d, integrand, symbol):
    """Antiderivative whose derivative is ``sqrt(a - d*sin(symbol)**2)``."""
    parameter = d/a
    return elliptic_e(symbol, parameter)*sqrt(a)

@evaluates(DontKnowRule)
def eval_dontknowrule(integrand, symbol):
    """No rule applied: leave the integral unevaluated."""
    unevaluated = Integral(integrand, symbol)
    return unevaluated

def _manualintegrate(rule):
    """Evaluate ``rule`` with the evaluator registered for its class.

    Evaluators are registered with the ``@evaluates`` decorator, and the
    rule's fields are passed to them positionally.

    Raises ValueError when no evaluator is registered for the rule's class.
    """
    evaluator = evaluators.get(rule.__class__)
    if evaluator:
        return evaluator(*rule)
    raise ValueError(f"Cannot evaluate rule {rule!r}")

from sympy import *
import sys
import os
import numpy as np
import sympy as sp
import torch
import csv
from logging import getLogger
from download import download
from tqdm import tqdm
import random
import time
import io
import signal

# Put the local 'SymbolicMathematics' directory first on the import path so
# that the ``src.*`` imports below resolve against it.
sys.path.insert(0, 'SymbolicMathematics')

from src.envs.char_sp import CharSPEnvironment
from src.envs.sympy_utils import remove_root_constant_terms, reduce_coefficients, reindex_coefficients
from src.utils import AttrDict
from src.envs import build_env
from src.model import build_modules

from src.utils import to_cuda
from src.envs.sympy_utils import simplify as smfy

logger = getLogger()

# Settings passed to ``build_env`` to construct the 'char_sp' environment.
# That environment converts dataset expressions (prefix -> infix -> SymPy)
# in process_1.
# NOTE(review): values presumably mirror the configuration the dataset was
# generated with; confirm against the SymbolicMathematics setup.
params = AttrDict({

    'env_name': 'char_sp',
    'int_base': 10,
    'balanced': False,
    'positive': True,
    'precision': 10,
    'n_variables': 1,
    'n_coefficients': 0,
    'leaf_probs': '0.75,0,0.25,0',
    'max_len': 512,
    'max_int': 5,
    'max_ops': 15,
    'max_ops_G': 15,
    'clean_prefix_expr': True,
    'reload_data': "prim_fwd.test",
    'rewrite_functions': '',
    'tasks': 'prim_fwd',
    'operators': 'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',
})

env = build_env(params)
# The environment's variable 'x'; every integral below is taken with respect to it.
x = env.local_dict['x']

# Unpack the dataset archive on first run (when the training split is missing).
if not os.path.isfile('prim_fwd.train'):
    import tarfile
    fname = 'prim_fwd.tar.gz'
    # "r:*" lets tarfile detect the compression, which covers both the
    # gzipped and the plain tar case. The context manager closes the archive
    # even if extraction fails.
    with tarfile.open(fname, "r:*") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters (3.12, backported to 3.8.17+/3.9.17+/3.10.12+/3.11.4+)
            # refuse absolute paths, "..", and links that escape the target directory.
            tar.extractall(filter="data")
        else:
            tar.extractall()

# Training split extracted above. Each line holds two '|'-separated fields;
# the second field is a tab-separated pair that becomes one ``data`` item.
path = 'prim_fwd.train'
# When True, read at most MAX_TRAIN_LINES lines; otherwise read the whole file.
train = True
MAX_TRAIN_LINES = 5000000
if path is not None:
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset file not found: {path}")
    logger.info(f"Loading data from {path} ...")
    limit = MAX_TRAIN_LINES if train else None
    data = []
    malformed = 0
    # Parse while streaming, so the raw split lines and the parsed pairs are
    # never held in memory at the same time.
    with io.open(path, mode='r', encoding='utf-8') as f:
        for i, line in tqdm(enumerate(f)):
            if limit is not None and i == limit:
                break
            fields = line.rstrip().split('|')
            if len(fields) != 2:
                # Skip lines without exactly one '|' instead of aborting the load.
                malformed += 1
                continue
            xy = fields[1].split('\t')
            if len(xy) == 2:
                data.append(xy)
            else:
                malformed += 1
    if malformed:
        logger.warning(f"Skipped {malformed} malformed lines.")
    logger.info(f"Loaded {len(data)} equations from the disk.")

class TimeoutException(Exception):
    """Signals that a per-item time budget (armed with ``signal.alarm``) ran out."""


def timeout_handler(signum, frame):
    """SIGALRM handler: turn the alarm into a ``TimeoutException`` in the main thread."""
    raise TimeoutException()


# Install the handler at import time; check_1 re-installs it in each worker call.
signal.signal(signal.SIGALRM, timeout_handler)

def check_piecewise_match(intg, intg_act):
    """Return True when ``intg`` and ``intg_act`` appear to differ by a constant.

    The difference is simplified with ``smfy(..., 20)``. If the result is a
    ``Piecewise``, it counts as a match when any branch expression is
    constant; branch conditions are ignored. Otherwise the whole difference
    must be constant. An undecidable ``is_constant()`` (None) is treated as
    "not constant".
    """
    diff_intg = smfy(intg-intg_act, 20)
    if isinstance(diff_intg, Piecewise):
        return any(piece.expr.is_constant() for piece in diff_intg.args)
    return bool(diff_intg.is_constant())

def process_1(fun, integral=False):
    """Convert a prefix token list into a SymPy expression via ``env``.

    Unless ``integral`` is true, the first two tokens are dropped before
    conversion. Any failure yields the sentinel string "conversion_failed".
    """
    try:
        tokens = fun if integral else fun[2:]
        infix = env.prefix_to_infix(tokens)
        return env.infix_to_sympy(infix)
    except Exception:
        return "conversion_failed"

def _rule_matches_reference(rule, reference_rule, func):
    """Return True when ``rule`` is a real answer that agrees with ``reference_rule``.

    ``rule`` comes from this file's ``integral_steps`` and ``reference_rule``
    from SymPy's. ``rule`` must mention some rule other than ``DontKnowRule``.
    It then matches if any of these hold, tried in this order:

    * the two rule trees print identically;
    * both evaluate to the same printed antiderivative;
    * d/dx of our antiderivative minus ``func`` simplifies to a constant;
    * ``check_piecewise_match`` accepts the two antiderivatives.
    """
    rule_text = str(rule)
    if "rule" not in rule_text.replace("DontKnowRule", "").lower():
        return False
    if str(reference_rule) == rule_text:
        return True
    # Evaluate each antiderivative once (the reference first, as before)
    # instead of re-running _manualintegrate for every comparison.
    reference_result = sympy.integrals.manualintegrate._manualintegrate(reference_rule)
    our_result = _manualintegrate(rule)
    if str(reference_result) == str(our_result):
        return True
    if smfy(diff(our_result) - func, 20).is_constant():
        return True
    return check_piecewise_match(our_result, reference_result)


def check_1(check_1_input):
    """Compare this file's ``integral_steps`` against SymPy's for one dataset item.

    ``check_1_input`` is an ``(id, item)`` pair. ``item[0]`` is a
    space-separated prefix expression that is converted with ``process_1``.
    Each ``integral_steps`` call is limited to 300 seconds via SIGALRM.

    Returns ``(id, func, status, rule, reference)``. ``status`` is one of:

    * the steps from our ``integral_steps`` when the result matches SymPy's;
    * ``-1``: a timeout (the rule/reference slots are ``-1`` too);
    * ``-2``: KeyboardInterrupt while parsing;
    * an exception object, with slots 4/5 set to ``-3`` (parsing),
      ``-4`` (SymPy's integral_steps), ``-5`` (ours), ``-7`` (comparison)
      or ``-8`` (anything else);
    * ``"conversion_failed"``: ``process_1`` could not convert the item;
    * ``-6``: SymPy's reference contains no rule besides ``DontKnowRule``;
    * ``"rule_mismatch"``.
    """
    signal.signal(signal.SIGALRM, timeout_handler)
    item_id, func = check_1_input
    manual = sympy.integrals.manualintegrate
    try:
        # Stage 1: convert the prefix expression into SymPy.
        try:
            func, _ = func
            func = func.split()
            func = process_1(func, False)
        except (TimeoutException):
            print("Timeout Occurred ", item_id, func)
            return (item_id, func, -1, -1, -1)
        except (KeyboardInterrupt):
            print("Keyboard Interupt ", item_id, func)
            return (item_id, func, -2, -2, -2)
        except Exception as e:
            print("check error Occurred ", e, item_id, func)
            return (item_id, func, e, -3, -3)

        if func == "conversion_failed":
            return (item_id, func, "conversion_failed", -4, -4)

        # Stage 2: SymPy's own integral_steps is the reference.
        try:
            signal.alarm(300)
            rules_act = manual.integral_steps(func, x)
            signal.alarm(0)
            # Shape kept as [1, rule]: some result tuples return this list as-is.
            rules_actual = [1, rules_act]
            if "DontKnowRule" in str(rules_actual[1]):
                if "rule" not in (str(rules_actual[1]).replace("DontKnowRule", "")).lower():
                    print(item_id, "dont_know_rule")
                    return (item_id, func, -6, -6, rules_actual)
        except (TimeoutException):
            print("Timeout Occurred Original", item_id, func)
            return (item_id, func, -1, -1, -1)
        except Exception as e:
            print(item_id, func, e, -4, -4)
            return (item_id, func, e, -4, -4)
        finally:
            # Cancel the alarm on every exit path. A pending alarm left over
            # from a failed call would otherwise fire during a later item in
            # this worker. Also drop cache state an interrupted call left behind.
            signal.alarm(0)
            manual._parts_u_cache.clear()
            manual._integral_cache.clear()

        # Stage 3: this file's integral_steps.
        try:
            signal.alarm(300)
            rule, steps = integral_steps(func, x)
            signal.alarm(0)
        except (TimeoutException):
            print("Timeout Occurred ", item_id, func)
            return (item_id, func, -1, -1, -1)
        except Exception as e:
            print(item_id, func, e)
            print("==========================================================================================")
            return (item_id, func, e, -5, -5)
        finally:
            signal.alarm(0)
            _parts_u_cache.clear()
            _integral_cache.clear()

        # Stage 4: compare against the reference.
        try:
            if _rule_matches_reference(rule, rules_actual[1], func):
                return (item_id, func, steps, rule, rules_actual[1])
            if "rule" not in (str(rules_actual[1]).replace("DontKnowRule", "")).lower():
                print(item_id, "dont_know_rule")
                return (item_id, func, -6, rule, rules_actual)
            print(item_id, "rule_mismatch")
            return (item_id, func, "rule_mismatch", rule, rules_actual)
        except Exception as e:
            return (item_id, func, e, -7, -7)
    except Exception as e:
        print("Overall Timeout Occurred ", item_id, func)
        return (item_id, func, e, -8, -8)

def multi_run_wrapper(args):
    """Forward an ``(id, item)`` pair to ``check_1``.

    ``check_1`` takes the pair as ONE tuple argument, so the two elements
    are re-packed rather than passed positionally (that raises TypeError).
    """
    return check_1((args[0], args[1]))

starttime = time.time()
length = len(data)
filename = "dataset_check.csv"

# Resume support: ids already written by a previous run are stored in
# check.ids as comma-separated integers (write_to_csv appends "id,").
# Empty tokens are skipped, so a trailing comma or an empty file is fine.
# A set makes each membership test in whether_id_processed O(1); a list
# would be scanned once for every one of the millions of dataset indices.
ids_processed = set()
if os.path.isfile('check.ids'):
    with io.open("check.ids", 'r', encoding='utf-8') as fff:
        for line in fff:
            ids_processed.update(
                int(token) for token in line.strip().split(',') if token.strip())

print(f"{len(ids_processed)} completed out of {length}")

import csv
import multiprocessing as mp
from multiprocessing import Pool, TimeoutError
from joblib import Parallel, delayed
import threading
import queue

# NOTE(review): disabled alternative kept as a bare string literal; it never
# runs. It called mp.Pool.apply_async + result.get(timeout=30) per item. Its
# inline comment says 5 minutes / 300 seconds, but the code uses 30 seconds.
# The live pipeline in main() uses Pool.imap_unordered with SIGALRM timeouts.
'''def process_input(input_data):
    # Process the input here
    index, input_value = input_data
    try:
        with mp.Pool(processes=80) as pool:
            result = pool.apply_async(multi_run_wrapper, (input_data,))
            output = result.get(timeout=30)  # Timeout set to 5 minutes (300 seconds)
    except TimeoutError:
        output = "Timeout occurred"
    return output'''

def call_with_multiprocessing(target_function, parameters, n_jobs=-2, backend='loky'):
    """Run ``target_function(**kwargs)`` for every kwargs dict in ``parameters`` with joblib.

    Returns the list of results in input order. Progress over ``parameters``
    is shown with tqdm.
    """
    runner = Parallel(n_jobs=n_jobs, backend=backend)
    jobs = (delayed(target_function)(**kwargs) for kwargs in tqdm(parameters))
    return runner(jobs)

def whether_id_processed(index_to_check, index_to_store):
    """Return ``index_to_store`` unless ``index_to_check`` is in ``ids_processed``.

    Returns None for an id that a previous run already handled.
    """
    already_done = index_to_check in ids_processed
    return None if already_done else index_to_store

def write_to_csv(q):
    """Drain result tuples from ``q`` into ``filename``, recording ids in check.ids.

    Runs until it receives ``None``. For each item, the CSV row is written
    and flushed before the id (``item[0]``) is appended to check.ids and
    flushed. A killed run therefore does not leave ids in the resume file
    whose rows never reached the OS. ``q.task_done()`` is called for every
    real item, even when writing it fails, so ``q.join()`` in ``main`` cannot
    hang. A failed item is reported and its id is left out of check.ids, so
    it is retried on the next run.
    """
    with open(filename, 'a', newline='') as csvfile, io.open("check.ids", 'a', encoding='utf-8') as ff:
        writer = csv.writer(csvfile)
        while True:
            item = q.get()
            if item is None:
                break
            try:
                writer.writerow(item)
                csvfile.flush()
                ff.write(str(item[0]))
                ff.write(',')
                ff.flush()
            except Exception as e:
                print("write_to_csv failed:", e)
            finally:
                q.task_done()

def main(data):
    """Check every dataset item whose index is not in ``ids_processed``, in parallel.

    A process pool runs ``check_1`` on ``(index, item)`` pairs. A single
    writer thread (``write_to_csv``) appends each result to ``filename`` and
    its id to check.ids. The writer is always told to stop, even if the pool
    raises, so the non-daemon thread cannot keep the interpreter alive.
    """
    q = queue.Queue()
    writer_thread = threading.Thread(target=write_to_csv, args=(q,))
    writer_thread.start()
    try:
        # Each check is a single set lookup, so a plain comprehension is far
        # cheaper than dispatching millions of tiny tasks to worker processes.
        processed = set(ids_processed)
        data_to_process_indices = [i for i in range(len(data)) if i not in processed]
        print(len(data), len(data_to_process_indices))
        if data_to_process_indices:
            print(data_to_process_indices[0], data_to_process_indices[-1])
        print("Data and indices to process finalised!")

        if data_to_process_indices:
            # Lazily pair indices with items instead of copying them into a second list.
            pending = ((i, data[i]) for i in data_to_process_indices)
            with Pool() as pool:
                for result in pool.imap_unordered(check_1, pending):
                    q.put(result)

        # Wait for all items to be written to the CSV file.
        q.join()
    finally:
        # Signal the writer thread to exit.
        q.put(None)
        writer_thread.join()

if __name__ == "__main__":
    # Run the whole check and report the wall-clock time measured from starttime.
    main(data)
    elapsed = time.time() - starttime
    print(f'That took {elapsed} seconds')
