from cirkit.pipeline import PipelineContext
from cirkit.symbolic.layers import LayerOperator
from initializers import compile_exp_uniform_initializer
from parameters import compile_negate_parameter, compile_fourier_integral_parameter, compile_expand_parameter
from layers import compile_fourier_layer
from operators import integrate_fourier_layer, multiply_fourier_layers, conjugate_fourier_layer


def setup_pipeline_context(
    *,
    backend: str = "torch",
    semiring: str = "lse-sum",
    fold: bool = True,
    optimize: bool = True,
) -> PipelineContext:
    """Create a ``PipelineContext`` extended with this project's custom rules.

    The context is constructed with the given compilation options. It then
    registers the following rules imported by this module, in this order:

    * parameter compilation rules: negate, Fourier integral, expand;
    * the Fourier layer compilation rule;
    * the exp-uniform initializer compilation rule;
    * operator rules for conjugation, multiplication and integration.

    Args:
        backend: Backend name forwarded to ``PipelineContext``.
        semiring: Semiring name forwarded to ``PipelineContext``.
        fold: Forwarded to ``PipelineContext`` as ``fold``.
        optimize: Forwarded to ``PipelineContext`` as ``optimize``.

    Returns:
        The configured ``PipelineContext``.
    """
    context = PipelineContext(
        backend=backend, semiring=semiring, fold=fold, optimize=optimize
    )

    # Parameter-node compilation rules, registered in a fixed order.
    for parameter_rule in (
        compile_negate_parameter,
        compile_fourier_integral_parameter,
        compile_expand_parameter,
    ):
        context.add_parameter_compilation_rule(parameter_rule)

    context.add_layer_compilation_rule(compile_fourier_layer)
    context.add_initializer_compilation_rule(compile_exp_uniform_initializer)

    # Map each symbolic layer operator to the function that implements it.
    operator_rules = (
        (LayerOperator.CONJUGATION, conjugate_fourier_layer),
        (LayerOperator.MULTIPLICATION, multiply_fourier_layers),
        (LayerOperator.INTEGRATION, integrate_fourier_layer),
    )
    for operator, rule in operator_rules:
        context.add_operator_rule(operator, rule)

    return context
