from dataclasses import dataclass, field
import sympy as sp
from sympy import simplify
from sympy.core.symbol import Wild
from sympy.core.singleton import S
from sympy.core.relational import Ne
from alpha_integrate.synthetic_data.solver import solve_with_timeout as solve

# patterns copied from sympy.integrals.manualintegrate

def make_wilds(symbol):
    """Build the four Wilds shared by the trig power patterns.

    ``a`` and ``b`` match any sub-expression free of ``symbol`` (the argument
    coefficients); ``m`` and ``n`` must additionally be ``sympy.Integer``
    instances (the powers).

    Returns:
        tuple: ``(a, b, m, n)``.
    """
    def _is_sympy_integer(candidate):
        return isinstance(candidate, sp.Integer)

    coeff_wilds = [Wild(label, exclude=[symbol]) for label in ('a', 'b')]
    power_wilds = [
        Wild(label, exclude=[symbol], properties=[_is_sympy_integer])
        for label in ('m', 'n')
    ]
    return (*coeff_wilds, *power_wilds)


def sincos_pattern(symbol):
    """Return the pattern ``sin(a*symbol)**m * cos(b*symbol)**n`` and its Wilds.

    Returns:
        tuple: ``(pattern, a, b, m, n)`` with the Wilds from ``make_wilds``.
    """
    wilds = make_wilds(symbol)
    a, b, m, n = wilds
    return (sp.sin(a*symbol)**m * sp.cos(b*symbol)**n, *wilds)

def tansec_pattern(symbol):
    """Return the pattern ``tan(a*symbol)**m * sec(b*symbol)**n`` and its Wilds.

    Returns:
        tuple: ``(pattern, a, b, m, n)`` with the Wilds from ``make_wilds``.
    """
    wilds = make_wilds(symbol)
    a, b, m, n = wilds
    return (sp.tan(a*symbol)**m * sp.sec(b*symbol)**n, *wilds)

def cotcsc_pattern(symbol):
    """Return the pattern ``cot(a*symbol)**m * csc(b*symbol)**n`` and its Wilds.

    Returns:
        tuple: ``(pattern, a, b, m, n)`` with the Wilds from ``make_wilds``.
    """
    wilds = make_wilds(symbol)
    a, b, m, n = wilds
    return (sp.cot(a*symbol)**m * sp.csc(b*symbol)**n, *wilds)


@dataclass
class Method:
    """Base class for a single rewrite or integration step."""

    # Rule identifier, e.g. 'PowerRule'.
    name: str
    # Rule-specific parameters supplied by subclasses (e.g. UMethod's [u_var, u_func]).
    args: list = field(default_factory=list)

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''
        Try to rewrite ``subexpr`` with this method.

        Returns the rewritten expression, or None when the method does not
        apply. The base implementation never applies.
        '''
        return None


class ConstantMethod(Method):
    """ConstantRule: the integral of a constant c is c*symbol."""

    def __init__(self):
        super().__init__('ConstantRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        # Only unevaluated integrals are candidates.
        if not isinstance(subexpr, sp.Integral):
            return None

        integrand = subexpr.function
        symbol = subexpr.limits[0][0]
        # is_constant() with no arguments tests constancy in every free symbol,
        # so an integrand containing any symbol at all is not treated as constant.
        return integrand * symbol if integrand.is_constant() else None
            
class PowerMethod(Method):
    """PowerRule: integral of symbol**k is symbol**(k+1)/(k+1) for k != -1."""

    def __init__(self):
        super().__init__('PowerRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        symbol = subexpr.limits[0][0]
        base, expt = subexpr.function.as_base_exp()

        # The base must be the bare variable and the exponent free of it.
        if base != symbol or symbol in expt.free_symbols:
            return None
        # k == -1 (the logarithmic case) is rejected here.
        new_expt = expt + 1
        if simplify(new_expt) == 0:
            return None
        return base**new_expt / new_expt

class ExpMethod(Method):
    """ExpRule: integral of c**symbol is c**symbol / log(c) for a constant c > 0, c != 1."""

    def __init__(self):
        super().__init__('ExpRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return c**symbol/log(c) for an integrand c**symbol, else None.'''
        if not isinstance(subexpr, sp.Integral):
            return None

        integrand = subexpr.args[0]
        symbol = subexpr.args[1][0]
        base, expt = integrand.as_base_exp()

        # Cheap structural test first: the exponent must be exactly the variable.
        if expt != symbol:
            return None
        # ``base > 0`` raises TypeError for non-real constants such as ``I`` and
        # for relationals SymPy cannot decide; ``is_positive`` instead yields
        # False/None, which simply means "rule does not apply".
        if base.is_constant() and base != 1 and base.is_positive:
            return base**symbol / sp.log(base)
        return None
    
class ConstantTimesMethod(Method):
    """ConstantTimesRule: pull a constant factor out of the integral."""

    def __init__(self):
        super().__init__('ConstantTimesRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        symbol = subexpr.limits[0][0]
        integrand = subexpr.function
        # Split into (factor free of symbol) * (factor depending on symbol).
        coeff, rest = integrand.as_independent(symbol)

        # Nothing to pull out if the split is trivial or the factor is 0/1.
        if rest == integrand or coeff in (0, 1):
            return None
        # is_constant() also rejects factors containing other free symbols.
        if not coeff.is_constant():
            return None
        return coeff * sp.Integral(rest, symbol)
    
class ReciprocalMethod(Method):
    """ReciprocalRule: integral of 1/symbol is log(symbol)."""

    def __init__(self):
        super().__init__('ReciprocalRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        symbol = subexpr.limits[0][0]
        base, expt = subexpr.function.as_base_exp()
        is_reciprocal = base == symbol and expt == -1
        return sp.log(base) if is_reciprocal else None
    
class NestedPowMethod(Method):
    """NestedPowRule: integrate a numeric multiple of a power of one linear base.

    Adapted from ``sympy.integrals.manualintegrate`` (per the comment below).
    The integrand must be built only from a leading numeric coefficient and
    powers of linear expressions ``a + b*symbol`` that all normalize to the
    same monic base ``symbol + a/b`` (``b`` provably nonzero, every exponent
    free of ``symbol``). With that ``base`` and the summed exponent ``exp_``,
    the result is ``base*integrand/(exp_ + 1)``, or
    ``base*integrand*log(base)`` when ``exp_ == -1``.
    """
    def __init__(self):
        super().__init__('NestedPowRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return the antiderivative described above, or None if the rule does not apply.'''
        
        if isinstance(subexpr, sp.Integral):
            integrand = subexpr.args[0]
            symbol = subexpr.args[1][0]

            # copied from sympy.integrals.manualintegrate and edited
            # Linear pattern a + b*symbol; excluding 0 keeps b from matching zero.
            a_ = Wild('a', exclude=[symbol])
            b_ = Wild('b', exclude=[symbol, 0])
            pattern = a_+b_*symbol
            # Rebound (via nonlocal) by _get_base_exp to Ne(b, 0) on each linear match.
            generic_cond = S.true

            # Internal signal that the integrand does not have the required shape.
            class NoMatch(Exception):
                pass

            def _get_base_exp(expr: sp.Expr) -> tuple[sp.Expr, sp.Expr]:
                '''Return (base, exponent) with expr == c * base**exponent for a symbol-free c.

                Raises NoMatch when expr is not of that shape.
                '''
                # Symbol-free factors are rejected outright (a Mul's leading
                # numeric coefficient is stripped before recursing).
                if not expr.has_free(symbol):
                    raise NoMatch
                if expr.is_Mul:
                    # Drop the numeric coefficient; every remaining factor must
                    # reduce to the same base, and their exponents add up.
                    _, terms = expr.as_coeff_mul()
                    if not terms:
                        raise NoMatch
                    results = [_get_base_exp(term) for term in terms]
                    bases = {b for b, _ in results}
                    # Kept from the SymPy original; every base returned by this
                    # helper contains `symbol`, so S.One should not occur here.
                    bases.discard(S.One)
                    if len(bases) == 1:
                        return bases.pop(), sp.Add(*(e for _, e in results))
                    raise NoMatch
                if expr.is_Pow:
                    # (inner)**e with e free of symbol: the exponents multiply.
                    b, e = expr.base, expr.exp  
                    if e.has_free(symbol):
                        raise NoMatch
                    base_, sub_exp = _get_base_exp(b)
                    return base_, sub_exp * e
                # Leaf case: a + b*symbol == b*(symbol + a/b).
                match = expr.match(pattern)
                if match:
                    a, b = match[a_], match[b_]
                    base_ = symbol + a/b
                    nonlocal generic_cond
                    generic_cond = Ne(b, 0)
                    # Accept only when b is provably nonzero (Ne evaluates to
                    # S.true); otherwise fall through and reject.
                    if generic_cond is S.true:
                        return base_, S.One
                raise NoMatch


            try:
                base, exp_ = _get_base_exp(integrand)
            except NoMatch:
                return None

            # integrand == c * base**exp_, so m == c * base**(exp_ + 1).
            m = base * integrand
            if exp_ == -1:
                # c * base**-1 integrates to c * log(base); here m == c.
                return m * sp.log(base)
            else:
                return m / (exp_ + 1)
            
        return None
    
class ArcsinMethod(Method):
    """ArcsinRule: integral of 1/sqrt(1 - symbol**2) is asin(symbol)."""

    def __init__(self):
        super().__init__('ArcsinRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        x = subexpr.limits[0][0]
        # Exact structural comparison: only this canonical form is accepted.
        if subexpr.function != 1/sp.sqrt(1 - x**2):
            return None
        return sp.asin(x)
    
class ArcsinhMethod(Method):
    """ArcsinhRule: integral of 1/sqrt(symbol**2 + 1) is asinh(symbol)."""

    def __init__(self):
        super().__init__('ArcsinhRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        x = subexpr.limits[0][0]
        # Exact structural comparison: only this canonical form is accepted.
        if subexpr.function != 1/sp.sqrt(x**2 + 1):
            return None
        return sp.asinh(x)
    
class SinMethod(Method):
    """SinRule: integral of sin(symbol) is -cos(symbol)."""

    def __init__(self):
        super().__init__('SinRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        return -sp.cos(x) if subexpr.function == sp.sin(x) else None
    
class CosMethod(Method):
    """CosRule: integral of cos(symbol) is sin(symbol)."""

    def __init__(self):
        super().__init__('CosRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        return sp.sin(x) if subexpr.function == sp.cos(x) else None
    
class SecTanMethod(Method):
    """SecTanRule: integral of sec(symbol)*tan(symbol) is sec(symbol)."""

    def __init__(self):
        super().__init__('SecTanRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        matches = subexpr.function == sp.sec(x)*sp.tan(x)
        return sp.sec(x) if matches else None
    
class CscCotMethod(Method):
    """CscCotRule: integral of csc(symbol)*cot(symbol) is -csc(symbol)."""

    def __init__(self):
        super().__init__('CscCotRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        matches = subexpr.function == sp.csc(x)*sp.cot(x)
        return -sp.csc(x) if matches else None
    
class Sec2Method(Method):
    """Sec2Rule: integral of sec(symbol)**2 is tan(symbol)."""

    def __init__(self):
        super().__init__('Sec2Rule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        return sp.tan(x) if subexpr.function == sp.sec(x)**2 else None

class Csc2Method(Method):
    """Csc2Rule: integral of csc(symbol)**2 is -cot(symbol)."""

    def __init__(self):
        super().__init__('Csc2Rule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        return -sp.cot(x) if subexpr.function == sp.csc(x)**2 else None

class SinhMethod(Method):
    """SinhRule: integral of sinh(symbol) is cosh(symbol)."""

    def __init__(self):
        super().__init__('SinhRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        return sp.cosh(x) if subexpr.function == sp.sinh(x) else None
    
class CoshMethod(Method):
    """CoshRule: integral of cosh(symbol) is sinh(symbol)."""

    def __init__(self):
        super().__init__('CoshRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None
        x = subexpr.limits[0][0]
        return sp.sinh(x) if subexpr.function == sp.cosh(x) else None
    

class ArctanMethod(Method):
    """ArctanRule: integral of a/(b*symbol**2 + c) for real b, c with c/b > 0.

    Result: a/b * 1/sqrt(c/b) * atan(symbol/sqrt(c/b)).
    """

    def __init__(self):
        super().__init__('ArctanRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return the arctangent antiderivative, or None if the rule does not apply.'''
        if not isinstance(subexpr, sp.Integral):
            return None

        integrand = subexpr.args[0]
        symbol = subexpr.args[1][0]

        a_ = Wild('a', exclude=[symbol])
        b_ = Wild('b', exclude=[symbol])
        c_ = Wild('c', exclude=[symbol])

        match = integrand.match(a_ / (b_ * symbol ** 2 + c_))
        if not match:
            return None

        # Default any Wild missing from the match dict to zero (as the other
        # pattern-based methods in this module do) rather than raising KeyError;
        # b == 0 or c == 0 then simply fails the checks below.
        a, b, c = (match.get(w, S.Zero) for w in (a_, b_, c_))
        if not (b.is_extended_real and c.is_extended_real and b != 0):
            return None
        # Compare with ``is S.true`` rather than truthiness: an undecidable
        # relational must mean "not applicable", not a TypeError.
        if (c/b > 0) is not S.true:
            return None
        root = sp.sqrt(c/b)
        return a/b / root * sp.atan(symbol/root)
                    

class ReciprocalSqrtQuadraticMethod(Method):
    """ReciprocalSqrtQuadraticRule: integral of 1/sqrt(a + b*x + c*x**2) with c != 0.

    Result: log(2*sqrt(c)*sqrt(a + b*x + c*x**2) + b + 2*c*x) / sqrt(c).
    """

    def __init__(self):
        super().__init__('ReciprocalSqrtQuadraticRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        x = subexpr.limits[0][0]
        wa = Wild('a', exclude=[x])
        wb = Wild('b', exclude=[x])
        wc = Wild('c', exclude=[x, 0])

        radicand, power = subexpr.function.as_base_exp()
        found = radicand.match(wa + wb*x + wc*x**2)
        if not found:
            return None
        a, b, c = (found.get(w, S.Zero) for w in (wa, wb, wc))
        # The exponent must be exactly -1/2.
        if simplify(2*power + 1) != 0:
            return None
        quadratic = a + b*x + c*x**2
        return sp.log(2*sp.sqrt(c)*sp.sqrt(quadratic) + b + 2*c*x)/sp.sqrt(c)
    
class CiMethod(Method):
    """CiRule: integral of cos(a*x + b)/x with a != 0.

    Uses cos(a*x + b) = cos(b)*cos(a*x) - sin(b)*sin(a*x), giving
    cos(b)*Ci(a*x) - sin(b)*Si(a*x).
    """

    def __init__(self):
        super().__init__('CiRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        x = subexpr.limits[0][0]

        def _nonzero(value):
            return not value.is_zero

        wa = Wild('a', exclude=[x], properties=[_nonzero])
        wb = Wild('b', exclude=[x])
        # evaluate=False stops SymPy from auto-simplifying cos() of the Wild sum.
        target = sp.cos(wa*x + wb, evaluate=False)/x

        found = subexpr.function.match(target)
        if not found:
            return None
        a, b = (found.get(w, S.Zero) for w in (wa, wb))
        return sp.cos(b)*sp.Ci(a*x) - sp.sin(b)*sp.Si(a*x)
    
class EiMethod(Method):
    """EiRule: integral of exp(a*x + b)/x with a != 0 is exp(b)*Ei(a*x)."""

    def __init__(self):
        super().__init__('EiRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        x = subexpr.limits[0][0]

        def _nonzero(value):
            return not value.is_zero

        wa = Wild('a', exclude=[x], properties=[_nonzero])
        wb = Wild('b', exclude=[x])
        # evaluate=False stops SymPy from auto-simplifying exp() of the Wild sum.
        target = sp.exp(wa*x + wb, evaluate=False)/x

        found = subexpr.function.match(target)
        if not found:
            return None
        a, b = (found.get(w, S.Zero) for w in (wa, wb))
        return sp.exp(b)*sp.Ei(a*x)
    
class UpperGammaMethod(Method):
    """UpperGammaRule: integral of x**e * exp(a*x), a != 0, e not a nonnegative integer.

    Result: x**e * (-a*x)**(-e) * uppergamma(e + 1, -a*x) / a.
    """

    def __init__(self):
        super().__init__('UpperGammaRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return the upper-incomplete-gamma antiderivative, or None if not applicable.'''
        if not isinstance(subexpr, sp.Integral):
            return None

        integrand = subexpr.args[0]
        symbol = subexpr.args[1][0]

        a = Wild('a', exclude=[symbol], properties=[lambda x: not x.is_zero])
        # Powers known to be nonnegative integers are excluded from this rule.
        e = Wild('e', exclude=[symbol], properties=[lambda x: not (x.is_nonnegative and x.is_integer)])

        pattern = symbol**e*sp.exp(a*symbol, evaluate=False)

        match = integrand.match(pattern)
        if match:
            a, e = [match.get(i, S.Zero) for i in (a, e)]
            return symbol**e * (-a*symbol)**(-e) * sp.uppergamma(e + 1, -a*symbol)/a

        return None
    
class AddMethod(Method):
    """AddRule: split the integral of a sum into a sum of integrals."""

    def __init__(self):
        super().__init__('AddRule')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        integrand = subexpr.function
        if not integrand.is_Add:
            return None
        x = subexpr.limits[0][0]
        pieces = (sp.Integral(term, x) for term in integrand.as_ordered_terms())
        return sp.Add(*pieces)
    

class UMethod(Method):
    """URule: substitute ``u_var = u_func(symbol)`` into an integral.

    ``args`` holds ``[u_var, u_func]``. The result is the unevaluated integral
    in ``u_var``; back-substitution of ``u_func`` for ``u_var`` is not done here.
    """

    def __init__(self, u_var: sp.Symbol, u_func: sp.Expr):
        super().__init__('URule', [u_var, u_func])

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return Integral(g(u_var), u_var), or None if the substitution does not apply.'''
        if not isinstance(subexpr, sp.Integral):
            return None

        integrand = subexpr.args[0]
        symbol = subexpr.args[1][0]
        u_var, u_func = tuple(self.args)

        if not integrand.has(symbol):
            return None
        # The new integration variable must be a plain Symbol ...
        if not isinstance(u_var, sp.Symbol):
            return None
        # ... that does not already occur in the problem; otherwise the
        # substitution would conflate two different quantities.
        if integrand.has(u_var) or u_func.has(u_var):
            return None

        # du = u'(x) dx, so f(x) dx = f(x)/u'(x) du. A constant u_func has
        # u'(x) == 0 and cannot serve as a substitution.
        du_dx = sp.diff(u_func, symbol)
        if du_dx == 0:
            return None
        integrand = integrand / du_dx

        integrand = integrand.subs(u_func, u_var)

        # If direct replacement leaves symbol behind, invert u = u_func(x) and
        # substitute x back in terms of u; only a unique inverse is accepted.
        if integrand.has(symbol):
            try:
                solution = solve(u_func - u_var, symbol, dict=True)
            except:  # noqa: E722 - the timeout wrapper's exception type is not visible here
                return None
            if len(solution) != 1:
                return None
            integrand = integrand.subs(symbol, solution[0][symbol])

        if integrand.has(symbol):
            return None

        return sp.Integral(integrand, u_var)


class PartsMethod(Method):
    """PartsRule: integrate by parts with the given ``u`` and ``dv``.

    int u dv = u*v - int v du, where v = Integral(dv) is left unevaluated.
    """

    def __init__(self, u: sp.Expr, dv: sp.Expr):
        super().__init__('PartsRule', [u, dv])

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.Integral):
            return None

        x = subexpr.limits[0][0]
        u, dv = self.args
        # The chosen split must reproduce the integrand (up to simplify).
        if simplify(subexpr.function - u*dv) != 0:
            return None
        v = sp.Integral(dv, x)
        return u * v - sp.Integral(sp.diff(u, x)*v, x)
    

class PartialFractionsMethod(Method):
    """PartialFractions: rewrite ``subexpr`` with ``apart`` in one of its free symbols."""

    def __init__(self):
        super().__init__('PartialFractions')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return the first successful partial-fraction decomposition, or None.'''
        # Visit symbols in SymPy's canonical order: raw set iteration order
        # depends on string hashing and can change between interpreter runs,
        # which made multivariate results irreproducible.
        symbols = sorted(getattr(subexpr, 'free_symbols', ()), key=sp.default_sort_key)
        for symbol in symbols:
            try:
                return subexpr.apart(symbol)
            except Exception:
                # apart cannot decompose in this symbol; try the next one.
                continue
        return None

class CancelMethod(Method):
    """Cancel: rewrite ``subexpr`` with SymPy's ``cancel()``."""

    def __init__(self):
        super().__init__('Cancel')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return ``subexpr.cancel()``, or None if cancel() fails.'''
        try:
            return subexpr.cancel()
        except Exception:
            # Best effort: a failure means "not applicable". A bare ``except:``
            # would also swallow KeyboardInterrupt/SystemExit.
            return None
        
class ExpandMethod(Method):
    """Expand: rewrite ``subexpr`` with SymPy's default ``expand()``."""

    def __init__(self):
        super().__init__('Expand')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return ``subexpr.expand()``, or None if expand() fails.'''
        try:
            return subexpr.expand()
        except Exception:
            # Best effort: a failure means "not applicable". A bare ``except:``
            # would also swallow KeyboardInterrupt/SystemExit.
            return None
        
class Tan1Method(Method):
    """Tan1: rewrite tan(u) as sin(u)/cos(u)."""

    def __init__(self):
        super().__init__('Tan1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.tan):
            return None
        (arg,) = subexpr.args
        return sp.sin(arg)/sp.cos(arg)
    
class Cot1Method(Method):
    """Cot1: rewrite cot(u) as cos(u)/sin(u)."""

    def __init__(self):
        super().__init__('Cot1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.cot):
            return None
        (arg,) = subexpr.args
        return sp.cos(arg)/sp.sin(arg)
    
class Cos1Method(Method):
    """Cos1: rewrite 1/cos(u) as sec(u)."""

    def __init__(self):
        super().__init__('Cos1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        arg_wild = Wild('f')
        found = subexpr.match(1/sp.cos(arg_wild))
        return sp.sec(found[arg_wild]) if found else None

class Sec1Method(Method):
    """Sec1: rewrite sec(u) as (sec(u)**2 + tan(u)*sec(u)) / (sec(u) + tan(u)).

    This multiplies sec(u) by (sec(u) + tan(u))/(sec(u) + tan(u)).
    """

    def __init__(self):
        super().__init__('Sec1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.sec):
            return None
        u = subexpr.args[0]
        sec_u, tan_u = sp.sec(u), sp.tan(u)
        return (sec_u**2 + tan_u * sec_u) / (sec_u + tan_u)
    
class Csc1Method(Method):
    """Csc1: rewrite csc(u) as (csc(u)**2 + cot(u)*csc(u)) / (csc(u) + cot(u)).

    This multiplies csc(u) by (csc(u) + cot(u))/(csc(u) + cot(u)).
    """

    def __init__(self):
        super().__init__('Csc1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.csc):
            return None
        u = subexpr.args[0]
        csc_u, cot_u = sp.csc(u), sp.cot(u)
        return (csc_u**2 + cot_u * csc_u) / (csc_u + cot_u)

class Tanh1Method(Method):
    """Tanh1: rewrite tanh(u) as sinh(u)/cosh(u)."""

    def __init__(self):
        super().__init__('Tanh1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.tanh):
            return None
        (arg,) = subexpr.args
        return sp.sinh(arg)/sp.cosh(arg)
    
class Coth1Method(Method):
    """Coth1: rewrite coth(u) as cosh(u)/sinh(u)."""

    def __init__(self):
        super().__init__('Coth1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.coth):
            return None
        (arg,) = subexpr.args
        return sp.cosh(arg)/sp.sinh(arg)

class Sech1Method(Method):
    """Sech1: with t = tanh(u/2), rewrite sech(u) as (1 - t**2)/(1 + t**2)."""

    def __init__(self):
        super().__init__('Sech1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.sech):
            return None
        t = sp.tanh(subexpr.args[0]/2)
        return (1 - t**2)/(1 + t**2)

class Csch1Method(Method):
    """Csch1: with t = tanh(u/2), rewrite csch(u) as (1 - t**2)/(2*t)."""

    def __init__(self):
        super().__init__('Csch1')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        if not isinstance(subexpr, sp.csch):
            return None
        t = sp.tanh(subexpr.args[0]/2)
        return (1 - t**2)/(2*t)
    
class TrigExpandMethod(Method):
    """TrigExpand: rewrite ``subexpr`` with ``expand(trig=True)``."""

    def __init__(self):
        super().__init__('TrigExpand')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return ``subexpr.expand(trig=True)``, or None if expansion fails.'''
        try:
            return subexpr.expand(trig = True)
        except Exception:
            # Best effort: a failure means "not applicable". A bare ``except:``
            # would also swallow KeyboardInterrupt/SystemExit.
            return None
        
class SinCosEvenMethod(Method):
    """SinCosEven: half-angle rewrite of sin(a*x)**m * cos(b*x)**n for even m, n >= 0.

    Uses sin**2(t) = (1 - cos(2t))/2 and cos**2(t) = (1 + cos(2t))/2.
    """

    def __init__(self):
        super().__init__('SinCosEven')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        for sym in subexpr.free_symbols:
            pattern, *wilds = sincos_pattern(sym)
            found = subexpr.match(pattern)
            if not found:
                continue
            a, b, m, n = (found.get(w, S.Zero) for w in wilds)
            even_powers = m.is_even and n.is_even
            nonnegative = m.is_nonnegative and n.is_nonnegative
            if even_powers and nonnegative:
                sin_sq = (1 - sp.cos(2*a*sym)) / 2
                cos_sq = (1 + sp.cos(2*b*sym)) / 2
                return sin_sq**(m/2) * cos_sq**(n/2)
        return None
    
class SinOddCosMethod(Method):
    """SinOddCos: for odd m >= 3, rewrite sin(a*x)**(m-1) as (1 - cos(a*x)**2)**((m-1)/2)."""

    def __init__(self):
        super().__init__('SinOddCos')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        for sym in subexpr.free_symbols:
            pattern, *wilds = sincos_pattern(sym)
            found = subexpr.match(pattern)
            if not found:
                continue
            a, b, m, n = (found.get(w, S.Zero) for w in wilds)
            if not (m.is_odd and m >= 3):
                continue
            peeled = (1 - sp.cos(a*sym)**2)**((m - 1) / 2)
            return peeled * sp.sin(a*sym) * sp.cos(b*sym)**n
        return None
    
class CosOddSinMethod(Method):
    """CosOddSin: for odd n >= 3, rewrite cos(b*x)**(n-1) as (1 - sin(b*x)**2)**((n-1)/2)."""

    def __init__(self):
        super().__init__('CosOddSin')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        for sym in subexpr.free_symbols:
            pattern, *wilds = sincos_pattern(sym)
            found = subexpr.match(pattern)
            if not found:
                continue
            a, b, m, n = (found.get(w, S.Zero) for w in wilds)
            if not (n.is_odd and n >= 3):
                continue
            peeled = (1 - sp.sin(b*sym)**2)**((n - 1) / 2)
            return peeled * sp.cos(b*sym) * sp.sin(a*sym)**m
        return None
    
class SecEvenTanMethod(Method):
    """SecEvenTan: for even n >= 4, rewrite sec(b*x)**n as (1 + tan(b*x)**2)**(n/2 - 1) * sec(b*x)**2."""

    def __init__(self):
        super().__init__('SecEvenTan')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        for sym in subexpr.free_symbols:
            pattern, *wilds = tansec_pattern(sym)
            found = subexpr.match(pattern)
            if not found:
                continue
            a, b, m, n = (found.get(w, S.Zero) for w in wilds)
            if not (n.is_even and n >= 4):
                continue
            peeled = (1 + sp.tan(b * sym)**2)**(n / 2 - 1)
            return peeled * sp.sec(b * sym)**2 * sp.tan(a * sym)**m
        return None
    
class TanOddSecMethod(Method):
    """TanOddSec: for odd m, rewrite tan(a*x)**(m-1) as (sec(a*x)**2 - 1)**((m-1)/2).

    The result is (sec(a*x)**2 - 1)**((m-1)/2) * tan(a*x) * sec(b*x)**n.
    """

    def __init__(self):
        super().__init__('TanOddSec')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return the rewritten product, or None if no free symbol yields a useful match.'''
        for symbol in subexpr.free_symbols:
            pattern, a, b, m, n = tansec_pattern(symbol)
            match = subexpr.match(pattern)

            if match:
                a,b,m,n = tuple(match.get(w, S.Zero) for w in (a,b,m,n))
                # m == 1 gives (sec**2 - 1)**0 * tan * sec**n, i.e. the input
                # unchanged; skip that no-op step (SinOddCosMethod likewise
                # requires m >= 3).
                if m.is_odd and m != 1:
                    return ((sp.sec(a * symbol)**2 - 1)**((m - 1) / 2) * sp.tan(a * symbol) * sp.sec(b * symbol)**n)

        return None
    
class Tan2Method(Method):
    """Tan2: rewrite tan(a*x)**2 (with no sec factor) as sec(a*x)**2 - 1."""

    def __init__(self):
        super().__init__('Tan2')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        for sym in subexpr.free_symbols:
            pattern, *wilds = tansec_pattern(sym)
            found = subexpr.match(pattern)
            if not found:
                continue
            a, b, m, n = (found.get(w, S.Zero) for w in wilds)
            if m == 2 and n == 0:
                return sp.sec(a * sym)**2 - 1
        return None
    
class CotCscEvenMethod(Method):
    """CotCscEven: for even n >= 4, rewrite csc(b*x)**n as (1 + cot(b*x)**2)**(n/2 - 1) * csc(b*x)**2."""

    def __init__(self):
        super().__init__('CotCscEven')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        for sym in subexpr.free_symbols:
            pattern, *wilds = cotcsc_pattern(sym)
            found = subexpr.match(pattern)
            if not found:
                continue
            a, b, m, n = (found.get(w, S.Zero) for w in wilds)
            if not (n.is_even and n >= 4):
                continue
            peeled = (1 + sp.cot(b * sym)**2)**(n / 2 - 1)
            return peeled * sp.csc(b * sym)**2 * sp.cot(a * sym)**m
        return None

class CotOddCscMethod(Method):
    """CotOddCsc: for odd m, rewrite cot(a*x)**(m-1) as (csc(a*x)**2 - 1)**((m-1)/2).

    The result is (csc(a*x)**2 - 1)**((m-1)/2) * cot(a*x) * csc(b*x)**n.
    """

    def __init__(self):
        super().__init__('CotOddCsc')

    def apply(self, subexpr: sp.Expr) -> sp.Expr:
        '''Return the rewritten product, or None if no free symbol yields a useful match.'''
        for symbol in subexpr.free_symbols:
            pattern, a, b, m, n = cotcsc_pattern(symbol)
            match = subexpr.match(pattern)

            if match:
                a,b,m,n = tuple(match.get(w, S.Zero) for w in (a,b,m,n))
                # m == 1 gives (csc**2 - 1)**0 * cot * csc**n, i.e. the input
                # unchanged; skip that no-op step (SinOddCosMethod likewise
                # requires m >= 3).
                if m.is_odd and m != 1:
                    return ((sp.csc(a * symbol)**2 - 1)**((m - 1) / 2) * sp.cot(a * symbol) * sp.csc(b * symbol)**n)

        return None