import itertools
import math
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, cast, Self
from collections.abc import Iterator, Sequence
from functools import partial

import numpy as np
import torch
from torch import Tensor, nn

import cirkit.symbolic.functional as SF
from cirkit.backend.torch.circuits import TorchCircuit, TorchConstantCircuit
from cirkit.backend.torch.layers import TorchInnerLayer, TorchInputLayer, TorchHadamardLayer, TorchLayer
from cirkit.backend.torch.parameters.nodes import TorchTensorParameter, TorchPointerParameter
from cirkit.pipeline import compile
from cirkit.symbolic.circuit import Circuit
from cirkit.symbolic.dtypes import DataType
from cirkit.symbolic.initializers import NormalInitializer, UniformInitializer
from cirkit.symbolic.layers import (
    CategoricalLayer,
    GaussianLayer,
    EmbeddingLayer,
)
from cirkit.symbolic.parameters import (
    ClampParameter,
    ExpParameter,
    LogSoftmaxParameter,
    Parameter,
    ScaledSigmoidParameter,
    TensorParameter,
    SoftmaxParameter,
)
from cirkit.templates.region_graph import (
    RegionGraph,
    LinearTree,
    RandomBinaryTree,
)
from cirkit.utils.scope import Scope
from initializers import ExpUniformInitializer
from layers import FourierLayer
from optimization.utils import TorchEmbeddingStiefelParameter, Concatenate
from pipeline import setup_pipeline_context

import geoopt
from cirkit.symbolic.layers import KroneckerLayer
from cirkit.backend.torch.layers.input import TorchEmbeddingLayer

from manifolds import StiefelT
from region_graphs import DecisionLeafQuadGraph, RandomQuadTree2, QuadTree, union_region_graphs


class PC(nn.Module, ABC):
    """Abstract base class for probabilistic circuit models.

    Subclasses provide an unnormalized ``log_score`` and the ``log_partition``; this
    class combines them into ``log_likelihood``, caches the log-partition function while
    the model is in evaluation mode, and counts (trainable) parameters.
    """

    def __init__(
        self, num_variables: int, image_shape: tuple[int, int, int] | None = None
    ) -> None:
        """
        Args:
            num_variables: Number of variables; must be greater than 1.
            image_shape: Optional image shape whose product must equal ``num_variables``.
        """
        assert num_variables > 1
        if image_shape is not None:
            assert np.prod(image_shape) == num_variables
        super().__init__()
        self.num_variables = num_variables
        self.image_shape = image_shape
        # Log-partition cached by ``train(False)``/``eval()``; None while training.
        # NOTE: only ``train()`` refreshes it, so parameter changes made while in eval
        # mode (e.g. ``load_state_dict``) require calling ``eval()`` again.
        self.__cache_log_z: Tensor | None = None

    def train(self, mode: bool = True) -> Self:
        """Set the training mode and refresh the cached log-partition function.

        Entering training mode drops the cache; entering evaluation mode computes
        ``log_partition()`` once (without gradients) for reuse by ``log_likelihood``.
        """
        # Switch the mode first, so that the cache is computed with ``self.training``
        # already reflecting evaluation mode (subclasses may branch on it).
        module = super().train(mode)
        if mode:
            self.__cache_log_z = None
        else:
            with torch.no_grad():
                self.__cache_log_z = self.log_partition()
        return module

    def forward(self, x: Tensor) -> Tensor:
        """Return the unnormalized log-score of ``x``."""
        return self.log_score(x)

    def log_likelihood(self, x: Tensor) -> Tensor:
        """Return ``log_score(x) - log Z``, reusing the cached ``log Z`` when available."""
        log_z = (
            self.log_partition() if self.__cache_log_z is None else self.__cache_log_z
        )
        log_score = self.log_score(x)
        return log_score - log_z

    def num_params(self, requires_grad: bool = True) -> int:
        """Total number of input-layer and inner-layer parameters."""
        return self.num_input_params(requires_grad) + self.num_sum_params(requires_grad)

    def num_input_params(self, requires_grad: bool = True) -> int:
        """Number of parameters in the input layers (complex entries count twice)."""
        return self._count_params(self.input_layers(), requires_grad)

    def num_sum_params(self, requires_grad: bool = True) -> int:
        """Number of parameters in the inner layers (complex entries count twice)."""
        return self._count_params(self.inner_layers(), requires_grad)

    @staticmethod
    def _count_params(layers: Iterator[nn.Module], requires_grad: bool) -> int:
        # Count scalar parameters of ``layers``; a complex entry counts as two reals.
        # When ``requires_grad`` is True, frozen parameters are skipped.
        params = itertools.chain.from_iterable(l.parameters() for l in layers)
        if requires_grad:
            params = filter(lambda p: p.requires_grad, params)
        return sum((2 * p.numel()) if p.is_complex() else p.numel() for p in params)

    @abstractmethod
    def layers(self) -> Iterator[TorchLayer]: ...

    @abstractmethod
    def input_layers(self) -> Iterator[TorchInputLayer]: ...

    @abstractmethod
    def inner_layers(self) -> Iterator[TorchInnerLayer]: ...

    @abstractmethod
    def log_partition(self) -> Tensor: ...

    @abstractmethod
    def log_score(self, x: Tensor) -> Tensor: ...


class MPC(PC):
    """Uniform mixture of ``num_components`` monotonic circuits.

    The components are concatenated into a single multi-output symbolic circuit that is
    compiled with the ``"lse-sum"`` semiring, alongside its integral, which serves as
    the partition function. The log mixing weight ``-log(num_components)`` is a buffer.
    """

    def __init__(
        self,
        num_variables: int,
        image_shape: tuple[int, int, int] | None = None,
        *,
        num_input_units: int,
        num_sum_units: int,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        num_components: int = 1,
        region_graph: str = "rnd-bt",
        structured_decomposable: bool = False,
        mono_clamp: bool = False,
        seed: int = 42,
    ) -> None:
        assert num_components > 0
        super().__init__(num_variables, image_shape)
        self._pipeline = setup_pipeline_context(semiring="lse-sum")
        circuit, int_circuit = self._build_circuits(
            num_input_units,
            num_sum_units,
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            num_components=num_components,
            region_graph=region_graph,
            structured_decomposable=structured_decomposable,
            mono_clamp=mono_clamp,
            seed=seed,
        )
        self._circuit = circuit
        self._int_circuit = int_circuit
        # Every component receives the same log-weight, -log(num_components).
        uniform_log_weight = -torch.log(torch.tensor(num_components))
        self.register_buffer("_mixing_log_weight", uniform_log_weight)

    def layers(self) -> Iterator[TorchLayer]:
        """Iterate over all layers of the compiled mixture circuit."""
        return iter(self._circuit.layers)

    def input_layers(self) -> Iterator[TorchInputLayer]:
        """Iterate over the input layers of the compiled mixture circuit."""
        return (
            cast(TorchInputLayer, layer)
            for layer in self._circuit.layers
            if isinstance(layer, TorchInputLayer)
        )

    def inner_layers(self) -> Iterator[TorchInnerLayer]:
        """Iterate over the inner (sum/product) layers of the compiled mixture circuit."""
        return (
            cast(TorchInnerLayer, layer)
            for layer in self._circuit.layers
            if isinstance(layer, TorchInnerLayer)
        )

    def log_partition(self) -> Tensor:
        """Log-sum-exp over components of mixing log-weight plus component log-integral."""
        mixing = cast(Tensor, self._mixing_log_weight)
        return torch.logsumexp(mixing + self._int_circuit(), dim=0)

    def log_score(self, x: Tensor) -> Tensor:
        """Log-sum-exp over components (dim 1) of mixing log-weight plus component output."""
        mixing = cast(Tensor, self._mixing_log_weight)
        return torch.logsumexp(mixing + self._circuit(x), dim=1)

    def _build_circuits(
        self,
        num_input_units: int,
        num_sum_units: int,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        num_components: int = 1,
        region_graph: str = "rnd-bt",
        structured_decomposable: bool = False,
        mono_clamp: bool = False,
        seed: int = 42,
    ) -> tuple[TorchCircuit, TorchConstantCircuit]:
        """Compile the mixture circuit and its integral.

        Returns:
            ``(circuit, int_circuit)``: the compiled concatenation of one monotonic
            circuit per component, and the compiled integral of that concatenation.
        """
        region_graphs = _build_region_graphs(
            region_graph,
            num_components,
            num_variables=self.num_variables,
            image_shape=self.image_shape,
            structured_decomposable=structured_decomposable,
            seed=seed,
        )
        components = _build_monotonic_sym_circuits(
            region_graphs,
            num_input_units,
            num_sum_units,
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            mono_clamp=mono_clamp,
        )

        with self._pipeline:
            # One symbolic circuit with an output per component, then its integral.
            merged = SF.concatenate(components)
            integrated = SF.integrate(merged)
            compiled = cast(TorchCircuit, compile(merged))
            compiled_int = cast(TorchConstantCircuit, compile(integrated))

        return compiled, compiled_int


class SOS(PC):
    """Uniform mixture of ``num_squares`` squared (optionally complex-valued) circuits.

    Each component's log-score is twice the real part of its circuit output, computed
    in the ``"complex-lse-sum"`` semiring. The partition function integrates
    ``conjugate(c) * c`` for every component ``c``. The log mixing weight
    ``-log(num_squares)`` is stored as a buffer.
    """

    def __init__(
        self,
        num_variables: int,
        image_shape: tuple[int, int, int] | None = None,
        *,
        num_input_units: int,
        num_sum_units: int,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        num_squares: int = 1,
        region_graph: str = "rnd-bt",
        structured_decomposable: bool = False,
        complex: bool = False,
        use_tucker: bool = False,
        seed: int = 42,
    ) -> None:
        assert num_squares > 0
        super().__init__(num_variables, image_shape)
        self.num_squares = num_squares
        self._pipeline = setup_pipeline_context(semiring="complex-lse-sum")
        self._circuit, self._int_sq_circuit = self._build_circuits(
            num_input_units,
            num_sum_units,
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            num_squares=num_squares,
            region_graph=region_graph,
            structured_decomposable=structured_decomposable,
            complex=complex,
            use_tucker=use_tucker,
            seed=seed,
        )
        # Uniform mixture over the squared components.
        self.register_buffer(
            "_mixing_log_weight", -torch.log(torch.tensor(num_squares))
        )

    def layers(self) -> Iterator[TorchLayer]:
        """Iterate over all layers of the (unsquared) compiled circuit."""
        return iter(self._circuit.layers)

    def input_layers(self) -> Iterator[TorchInputLayer]:
        """Iterate over the input layers of the compiled circuit."""
        return map(
            partial(cast, TorchInputLayer),
            filter(lambda l: isinstance(l, TorchInputLayer), self._circuit.layers)
        )

    def inner_layers(self) -> Iterator[TorchInnerLayer]:
        """Iterate over the inner layers of the compiled circuit."""
        # Cast to TorchInnerLayer, matching the filter and the declared return type.
        return map(
            partial(cast, TorchInnerLayer),
            filter(lambda l: isinstance(l, TorchInnerLayer), self._circuit.layers)
        )

    def log_partition(self) -> Tensor:
        """Log-sum-exp over components of mixing log-weight plus log-integral of ``|c|^2``."""
        log_z = self._int_sq_circuit().real
        return torch.logsumexp(cast(Tensor, self._mixing_log_weight) + log_z, dim=0)

    def log_score(self, x: Tensor) -> Tensor:
        """Log-sum-exp over components (dim 1) of mixing log-weight plus ``2 * Re(c(x))``."""
        log_score = 2.0 * self._circuit(x).real
        return torch.logsumexp(cast(Tensor, self._mixing_log_weight) + log_score, dim=1)

    def _build_circuits(
        self,
        num_input_units: int,
        num_sum_units: int,
        *,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        num_squares: int = 1,
        region_graph: str = "rnd-bt",
        structured_decomposable: bool = False,
        complex: bool = False,
        seed: int = 42,
        use_tucker: bool = False,
    ) -> tuple[TorchCircuit, TorchConstantCircuit]:
        """Compile the concatenated circuit and the integral of its squared components.

        Returns:
            ``(circuit, int_sq_circuit)``: the compiled concatenation of one non-monotonic
            circuit per component, and the compiled integral of the concatenation of
            ``conjugate(c) * c`` for every component ``c``.
        """
        # Build the region graphs
        rgs = _build_region_graphs(
            region_graph,
            num_squares,
            num_variables=self.num_variables,
            image_shape=self.image_shape,
            structured_decomposable=structured_decomposable,
            seed=seed,
        )

        # Build one symbolic circuit for each region graph
        sym_circuits = _build_non_monotonic_sym_circuits(
            rgs,
            num_input_units,
            num_sum_units,
            model_name='sos',
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            complex=complex,
            use_tucker=use_tucker,
        )

        with self._pipeline:
            # Merge the symbolic circuits into a single one having multiple outputs
            sym_circuit = SF.concatenate(sym_circuits)

            # Square each symbolic circuit and merge them into a single one having multiple outputs
            sym_sq_circuits = [
                SF.multiply(SF.conjugate(sc), sc)  # If it is real, just adds extra torch.conj calls
                for sc in sym_circuits
            ]
            sym_sq_circuit = SF.concatenate(sym_sq_circuits)

            # Integrate the squared circuits (by integrating the merged symbolic representation)
            sym_int_sq_circuit = SF.integrate(sym_sq_circuit)

            # Compile the symbolic circuits
            circuit = cast(TorchCircuit, compile(sym_circuit))
            int_sq_circuit = cast(TorchConstantCircuit, compile(sym_int_sq_circuit))

        return circuit, int_sq_circuit


class ExpSOS(PC):
    """Product of a monotonic circuit and a squared (optionally complex) circuit.

    The log-score is ``2 * Re(c(x)) + Re(m(x))``, where ``c`` is the non-monotonic
    circuit and ``m`` is the monotonic one, both compiled in the ``"complex-lse-sum"``
    semiring. The partition function is the integral of ``m * conjugate(c) * c``, or
    ``m * c * c`` when ``complex`` is False.
    """

    def __init__(
        self,
        num_variables: int,
        image_shape: tuple[int, int, int] | None = None,
        *,
        num_input_units: int,
        num_sum_units: int,
        mono_num_input_units: int = 2,
        mono_num_sum_units: int = 2,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        region_graph: str = "rnd-bt",
        structured_decomposable: bool = False,
        mono_clamp: bool = False,
        complex: bool = False,
        seed: int = 42,
    ) -> None:
        super().__init__(num_variables, image_shape)
        self._pipeline = setup_pipeline_context(semiring="complex-lse-sum")
        # Compile the non-monotonic circuit, the monotonic circuit and the integral of
        # their product.
        self._circuit, self._mono_circuit, self._int_circuit = self._build_circuits(
            num_input_units,
            num_sum_units,
            mono_num_input_units,
            mono_num_sum_units,
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            region_graph=region_graph,
            structured_decomposable=structured_decomposable,
            mono_clamp=mono_clamp,
            complex=complex,
            seed=seed,
        )

    def layers(self) -> Iterator[TorchLayer]:
        """Iterate over the layers of the non-monotonic, then the monotonic circuit."""
        return itertools.chain(self._circuit.layers, self._mono_circuit.layers)

    def input_layers(self) -> Iterator[TorchInputLayer]:
        """Iterate over the input layers of both circuits."""
        return itertools.chain(
            map(
                partial(cast, TorchInputLayer),
                filter(lambda l: isinstance(l, TorchInputLayer), self._circuit.layers)
            ),
            map(
                partial(cast, TorchInputLayer),
                filter(lambda l: isinstance(l, TorchInputLayer), self._mono_circuit.layers)
            )
        )

    def inner_layers(self) -> Iterator[TorchInnerLayer]:
        """Iterate over the inner layers of both circuits."""
        # Cast to TorchInnerLayer, matching the filter and the declared return type.
        return itertools.chain(
            map(
                partial(cast, TorchInnerLayer),
                filter(lambda l: isinstance(l, TorchInnerLayer), self._circuit.layers)
            ),
            map(
                partial(cast, TorchInnerLayer),
                filter(lambda l: isinstance(l, TorchInnerLayer), self._mono_circuit.layers),
            )
        )

    def log_partition(self) -> Tensor:
        """Real part of the compiled integral of the product circuit."""
        return self._int_circuit().real

    def log_score(self, x: Tensor) -> Tensor:
        """``2 * Re(c(x)) + Re(m(x))``, squeezed along dim 1."""
        sq_log_score = 2.0 * self._circuit(x).real
        mono_log_score = self._mono_circuit(x).real
        return (sq_log_score + mono_log_score).squeeze(dim=1)

    def _build_circuits(
        self,
        num_input_units: int,
        num_sum_units: int,
        mono_num_input_units: int = 2,
        mono_num_sum_units: int = 2,
        *,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        region_graph: str = "rnd-bt",
        structured_decomposable: bool = False,
        mono_clamp: bool = False,
        complex: bool = False,
        seed: int = 42,
    ) -> tuple[TorchCircuit, TorchCircuit, TorchConstantCircuit]:
        """Compile the non-monotonic circuit, the monotonic circuit and their product integral.

        Both circuits are built on the same (single) region graph.

        Returns:
            ``(circuit, mono_circuit, int_circuit)``.
        """
        # Build the region graphs
        rgs = _build_region_graphs(
            region_graph,
            1,
            num_variables=self.num_variables,
            image_shape=self.image_shape,
            structured_decomposable=structured_decomposable,
            seed=seed,
        )
        assert len(rgs) == 1

        # Build one symbolic circuit for each region graph
        sym_circuits = _build_non_monotonic_sym_circuits(
            rgs,
            num_input_units,
            num_sum_units,
            model_name='expsos',
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            complex=complex,
        )

        sym_mono_circuits = _build_monotonic_sym_circuits(
            rgs,
            mono_num_input_units,
            mono_num_sum_units,
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            mono_clamp=mono_clamp,
        )
        assert len(sym_circuits) == 1
        assert len(sym_mono_circuits) == 1
        (sym_circuit,) = sym_circuits
        (sym_mono_circuit,) = sym_mono_circuits

        with self._pipeline:
            # Square the symbolic circuit and make the product with the monotonic circuit
            if complex:
                # Apply the conjugate operator if the circuit is complex
                sym_prod_circuit = SF.multiply(
                    SF.multiply(sym_mono_circuit, SF.conjugate(sym_circuit)),
                    sym_circuit,
                )
            else:
                sym_prod_circuit = SF.multiply(
                    SF.multiply(sym_mono_circuit, sym_circuit), sym_circuit
                )

            # Integrate the overall product circuit
            sym_int_circuit = SF.integrate(sym_prod_circuit)

            # Compile the symbolic circuits
            circuit = cast(TorchCircuit, compile(sym_circuit))
            mono_circuit = cast(TorchCircuit, compile(sym_mono_circuit))
            int_circuit = cast(TorchConstantCircuit, compile(sym_int_circuit))

        return circuit, mono_circuit, int_circuit


class OrthogonalSOS(PC):
    """Mixture of squared circuits whose weights are constrained to a Stiefel manifold.

    After compilation, the weights of inner (non-Hadamard) layers and of embedding input
    layers are wrapped into ``geoopt.ManifoldParameter`` objects with a ``StiefelT``
    manifold (see ``_add_manifolds``). The mixing logits over the ``num_squares``
    squared circuits are non-trainable (``requires_grad=False``) and yield uniform
    mixing weights at initialization.
    """

    def __init__(
        self,
        num_variables: int,
        image_shape: tuple[int, int, int] | None = None,
        *,
        num_input_units: int,
        num_sum_units: int,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        num_squares: int = 1,
        region_graph: str = "rnd-bt",
        num_repetitions: int = 1,
        max_patch_size: int = 8,
        structured_decomposable: bool = False,
        complex: bool = False,
        seed: int = 42,
        benchmark: bool = False,
        use_tucker: bool = False,
    ) -> None:
        assert num_squares > 0
        assert num_repetitions > 0
        super().__init__(num_variables, image_shape)
        self.num_squares = num_squares
        self.num_repetitions = num_repetitions
        self.num_sum_units = num_sum_units
        self.num_input_units = num_input_units
        self.benchmark = benchmark
        self._pipeline = setup_pipeline_context(semiring="complex-lse-sum")
        self._circuit, _int_sq_circuit = self._build_circuits(
            num_input_units,
            num_sum_units,
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            num_squares=num_squares,
            region_graph=region_graph,
            structured_decomposable=structured_decomposable,
            num_repetitions=num_repetitions,
            max_patch_size=max_patch_size,
            complex=complex,
            seed=seed,
            use_tucker=use_tucker,
        )

        # Fixed (non-trainable) mixing logits, all equal -> uniform mixing weights.
        self._logits_mixing_weights = torch.nn.Parameter(torch.zeros((num_squares, 1)) - math.log(num_squares), requires_grad=False)

        # It does not have new parameters and we do not use it during training;
        # it is None when the integral circuit is not built (see _build_circuits).
        self._int_sq_circuit = _int_sq_circuit

        self._add_manifolds()

    @property
    def _mixing_weights(self):
        # Mixing weights as probabilities over the squared circuits (softmax over dim 0).
        return torch.softmax(self._logits_mixing_weights, dim=0)

    def layers(self) -> Iterator[TorchLayer]:
        """Iterate over all layers of the compiled circuit."""
        return iter(self._circuit.layers)

    def input_layers(self) -> Iterator[TorchInputLayer]:
        """Iterate over the input layers of the compiled circuit."""
        return map(
            partial(cast, TorchInputLayer),
            filter(lambda l: isinstance(l, TorchInputLayer), self._circuit.layers)
        )

    def inner_layers(self) -> Iterator[TorchInnerLayer]:
        """Iterate over the inner layers of the compiled circuit."""
        # Cast to TorchInnerLayer, matching the filter and the declared return type.
        return map(
            partial(cast, TorchInnerLayer),
            filter(lambda l: isinstance(l, TorchInnerLayer), self._circuit.layers)
        )

    @torch.no_grad()
    def log_partition(self) -> Tensor:
        """Log-normalization constant of the mixture (computed without gradients).

        When benchmarking in training mode, or when no integral circuit was built, the
        per-component log-partition is taken to be 0 instead of being evaluated
        (presumably relying on the orthogonality constraints — NOTE(review) confirm).
        """
        # ``self.training`` is the mode flag; ``self.train`` is a bound method and is
        # always truthy, which made benchmark mode skip the integral even at evaluation.
        if (self.benchmark and self.training) or (self._int_sq_circuit is None):  # There is still some unnecessary overhead but only during evaluation
            log_z = 0.
        else:
            log_z = cast(TorchConstantCircuit, self._int_sq_circuit)().real

        _mixing_log_weight = self._logits_mixing_weights - torch.logsumexp(self._logits_mixing_weights, dim=0)
        return torch.logsumexp(_mixing_log_weight + log_z, dim=0)

    def log_score(self, x: Tensor) -> Tensor:
        """Log-sum-exp over components (dim 1) of normalized mixing log-weight plus ``2 * Re(c(x))``."""
        log_score = 2.0 * self._circuit(x).real
        _mixing_log_weight = self._logits_mixing_weights - torch.logsumexp(self._logits_mixing_weights, dim=0)
        return torch.logsumexp(_mixing_log_weight + log_score, dim=1)

    def _build_circuits(
        self,
        num_input_units: int,
        num_sum_units: int,
        *,
        input_layer: str,
        input_layer_kwargs: dict[str, Any] | None = None,
        num_squares: int = 1,
        region_graph: str = "rnd-bt",
        num_repetitions: int = 1,
        max_patch_size: int = 8,
        structured_decomposable: bool = False,
        complex: bool = False,
        use_tucker: bool = False,
        seed: int = 42,
    ) -> tuple[TorchCircuit, TorchConstantCircuit | None]:
        """Compile the concatenated circuit and, when supported, its squared integral.

        Returns:
            ``(circuit, int_sq_circuit)``. ``int_sq_circuit`` is None when Tucker layers
            are requested, when any symbolic circuit contains a ``KroneckerLayer``, or
            when ``num_repetitions > 1``.
        """
        # Build the region graphs
        rgs = _build_region_graphs(
            region_graph,
            num_squares,
            num_variables=self.num_variables,
            image_shape=self.image_shape,
            num_repetitions=num_repetitions,
            max_patch_size=max_patch_size,
            structured_decomposable=structured_decomposable,
            seed=seed,
        )

        # Build one symbolic circuit for each region graph
        sym_circuits = _build_non_monotonic_sym_circuits(
            rgs,
            num_input_units,
            num_sum_units,
            model_name='osos',
            input_layer=input_layer,
            input_layer_kwargs=input_layer_kwargs,
            complex=complex,
            use_tucker=use_tucker,
        )

        # Treat the model as Tucker-based if any symbolic circuit has a Kronecker layer
        use_tucker = use_tucker or any([
            any([isinstance(l, KroneckerLayer) for l in sc.topological_ordering()])
            for sc in sym_circuits
        ])

        with self._pipeline:
            # Merge the symbolic circuits into a single one having multiple outputs
            sym_circuit = SF.concatenate(sym_circuits)
            circuit = cast(TorchCircuit, compile(sym_circuit))

            # The integral circuit is not built in these configurations
            if use_tucker or num_repetitions > 1:
                return circuit, None

            # Square each symbolic circuit and merge them into a single one having multiple outputs
            sym_sq_circuits = [
                SF.multiply(SF.conjugate(sc), sc)  # If it is real, just adds extra torch.conj calls
                for sc in sym_circuits
            ]
            sym_sq_circuit = SF.concatenate(sym_sq_circuits)

            # Integrate the squared circuits (by integrating the merged symbolic representation)
            sym_int_sq_circuit = SF.integrate(sym_sq_circuit)

            # Compile the symbolic circuits
            int_sq_circuit = cast(TorchConstantCircuit, compile(sym_int_sq_circuit))

        return circuit, int_sq_circuit

    def _add_manifolds(self):
        # Replace the tensor parameters of inner (non-Hadamard) layers and embedding
        # layers of the compiled circuit with Stiefel-manifold parameters.

        # A simple utility to inject the Stiefel manifold in a layer parameter
        def _add_manifold(pnode: TorchTensorParameter):
            pnode._ptensor = geoopt.ManifoldParameter(
                pnode._ptensor.data,
                manifold=StiefelT()
            )
            # Project the initial values onto the manifold
            with torch.no_grad():
                pnode._ptensor.proj_()

        # Collect the layers and their torch tensor parameters
        layer_pnodes: dict[TorchInnerLayer | TorchEmbeddingLayer, list[TorchTensorParameter]] = defaultdict(list)
        for layer in self._circuit.layers:
            if not isinstance(layer, (TorchInnerLayer, TorchEmbeddingLayer)):
                continue
            if isinstance(layer, TorchHadamardLayer):
                continue
            assert len(layer.params) == 1, list(layer.params.keys())
            pname = list(layer.params.keys())[0]
            layer_pnodes[layer].extend(
                cast(TorchTensorParameter, m) for m in layer.params[pname].modules() if hasattr(m, '_ptensor')
            )

        # Inject the Stiefel manifold to the parameters
        for layer, pnodes in layer_pnodes.items():
            if isinstance(layer, TorchInnerLayer):
                for pnode in pnodes:
                    _add_manifold(pnode)
                continue
            assert isinstance(layer, TorchEmbeddingLayer)
            # Check whether each squared circuit has a structure obtained by combining
            # several other circuits, i.e., when num_repetitions>1.
            # In that case, use a parameterization that ensures orthogonality
            # of such circuits
            if self.num_repetitions == 1:
                for pnode in pnodes:
                    _add_manifold(pnode)
                continue
            # Sanity checks on the folded address book of the embedding weight
            for i in range(self.num_squares):
                assert not hasattr(layer.weight._address_book, f'_in_fold_idx_{i}_0')
            assert (getattr(layer.weight._address_book, f'_in_fold_idx_{self.num_squares}_0') == torch.arange(
                self.num_squares * self.num_repetitions * self.num_variables
            )).all()
            # Make sure the pointer parameter do not apply a fold-wise permutation
            pointer_pnodes: list[TorchPointerParameter] = []
            for n in layer.weight.nodes:
                assert isinstance(n, TorchPointerParameter)
                assert n._fold_idx is None
                pointer_pnodes.append(n)
            # One Stiefel parameter per pointer node, each covering a contiguous slice
            # of num_repetitions * num_variables entries of the layer's scope_idx
            stiefel_params: list[TorchEmbeddingStiefelParameter] = []
            scope_idx_stride = self.num_repetitions * self.num_variables
            for i, pointer_pnode in enumerate(pointer_pnodes):
                data = pointer_pnode.deref()._ptensor.data
                shape = data.shape
                assert shape[0] == self.num_variables * self.num_repetitions
                assert shape[1] == self.num_input_units
                num_states = shape[-1]
                # TODO: monkey patch that changes the type of layer.weight
                #       from TorchTensorParameter to TorchEmbeddingStiefelParameter
                stiefel_params.append(TorchEmbeddingStiefelParameter(
                    data,
                    num_variables=self.num_variables,
                    num_repetitions=self.num_repetitions,
                    num_units=self.num_input_units,
                    num_states=num_states,
                    fold_idx=layer.scope_idx[i * scope_idx_stride:(i + 1) * scope_idx_stride]
                ))
            if len(stiefel_params) == 1:
                layer.weight = stiefel_params[0]
            else:
                layer.weight = Concatenate(*stiefel_params)

def _build_region_graphs(
    name: str,
    k: int,
    num_variables: int | None = None,
    image_shape: tuple[int, int, int] | None = None,
    num_repetitions: int = 1,
    max_patch_size: int = 8,
    structured_decomposable: bool = False,
    seed: int = 42,
) -> Sequence[RegionGraph]:
    """Build ``k`` region graphs of the family identified by ``name``.

    Variable-based families ("bt", "rnd-bt", "rnd-lt", "lt") require ``num_variables``;
    image-based families ("qt-2", "qt"/"qt-4", "rnd-qt-2", "dl-qg") require
    ``image_shape``.

    Seeding: for "bt", "rnd-bt" and "rnd-lt" the i-th region graph uses ``seed`` when
    ``structured_decomposable`` is True and ``seed + 123 * i`` otherwise. "lt" uses the
    builder's default seed, "rnd-qt-2" passes ``seed`` to every region graph, and
    "dl-qg" ignores both ``num_repetitions`` and ``seed``.

    Raises:
        AssertionError: If the required ``num_variables``/``image_shape`` is missing.
        NotImplementedError: If ``name`` is not a known region graph family.
    """

    def component_seed(i: int) -> int:
        # Same seed for all components when structured decomposability is requested,
        # otherwise a distinct, deterministic seed per component.
        return seed if structured_decomposable else seed + i * 123

    if name == "bt":
        assert num_variables is not None
        return [
            _build_bt_region_graph(num_variables, num_repetitions, seed=component_seed(i))
            for i in range(k)
        ]
    if name == "rnd-bt":
        assert num_variables is not None
        return [
            _build_rnd_bt_region_graph(num_variables, num_repetitions, seed=component_seed(i))
            for i in range(k)
        ]
    if name == "rnd-lt":
        assert num_variables is not None
        return [
            _build_lt_region_graph(
                num_variables, num_repetitions, random=True, seed=component_seed(i)
            )
            for i in range(k)
        ]
    if name == "lt":
        assert num_variables is not None
        return [_build_lt_region_graph(num_variables, num_repetitions, random=False) for _ in range(k)]
    if name == "qt-2":
        assert image_shape is not None
        return [
            _build_qt_region_graph(image_shape, num_repetitions=num_repetitions, num_patch_splits=2) for _ in range(k)
        ]
    if name in ["qt", "qt-4"]:
        assert image_shape is not None
        return [
            _build_qt_region_graph(image_shape, num_repetitions=num_repetitions, num_patch_splits=4) for _ in range(k)
        ]
    if name == "rnd-qt-2":
        assert image_shape is not None
        # NOTE(review): unlike the other randomized families, every rnd-qt-2 graph uses
        # the same seed regardless of ``structured_decomposable`` — confirm intended.
        return [
            _build_rnd_qt_2_region_graph(image_shape, num_repetitions, seed=seed) for _ in range(k)
        ]
    if name == "dl-qg":
        assert image_shape is not None
        return [
            _build_dl_qg_region_graph(image_shape, max_patch_size) for _ in range(k)
        ]
    raise NotImplementedError(f"Unknown region graph: '{name}'")


def _build_bt_region_graph(
    num_variables: int,
    num_repetitions: int = 1,
    seed: int = 42
) -> RegionGraph:
    """Union of ``num_repetitions`` ``RandomBinaryTree`` region graphs.

    Each tree has depth ``ceil(log2(num_variables))`` and is built with the same
    ``seed``; the trees are then merged with ``union_region_graphs``.
    """
    depth = int(np.ceil(np.log2(num_variables)))
    trees = [
        RandomBinaryTree(num_variables, depth=depth, seed=seed)
        for _repetition in range(num_repetitions)
    ]
    return union_region_graphs(trees)


def _build_rnd_bt_region_graph(
    num_variables: int,
    num_repetitions: int = 1,
    seed: int = 42
) -> RegionGraph:
    """Single ``RandomBinaryTree`` of depth ``ceil(log2(num_variables))``.

    Repetitions are delegated to ``RandomBinaryTree`` via its ``num_repetitions``.
    """
    depth = int(np.ceil(np.log2(num_variables)))
    return RandomBinaryTree(
        num_variables,
        depth=depth,
        num_repetitions=num_repetitions,
        seed=seed,
    )


def _build_lt_region_graph(
    num_variables: int, num_repetitions: int = 1, random: bool = False, seed: int = 42
) -> RegionGraph:
    """``LinearTree`` region graph; ``random`` is forwarded as its ``randomize`` flag."""
    return LinearTree(
        num_variables,
        num_repetitions=num_repetitions,
        randomize=random,
        seed=seed,
    )


def _build_qt_region_graph(
    image_shape: tuple[int, int, int], num_repetitions: int = 1, num_patch_splits: int = 2
) -> RegionGraph:
    """``QuadTree`` region graph over an image of shape ``image_shape``."""
    return QuadTree(
        image_shape,
        num_repetitions=num_repetitions,
        num_patch_splits=num_patch_splits,
    )


def _build_rnd_qt_2_region_graph(
    image_shape: tuple[int, int, int], num_repetitions: int = 1, seed: int = 42
) -> RegionGraph:
    """``RandomQuadTree2`` region graph over an image of shape ``image_shape``."""
    return RandomQuadTree2(
        image_shape,
        num_repetitions=num_repetitions,
        seed=seed,
    )


def _build_dl_qg_region_graph(
    image_shape: tuple[int, int, int], max_patch_size: int = 8
) -> RegionGraph:
    """``DecisionLeafQuadGraph`` region graph with patches of at most ``max_patch_size``."""
    graph = DecisionLeafQuadGraph(image_shape, max_patch_size=max_patch_size)
    return graph


def _build_monotonic_sym_circuits(
    region_graphs: Sequence[RegionGraph],
    num_input_units: int,
    num_sum_units: int,
    *,
    input_layer: str,
    input_layer_kwargs: dict[str, Any] | None = None,
    mono_clamp: bool = False,
) -> list[Circuit]:
    """Build one monotonic symbolic circuit per region graph.

    Args:
        region_graphs: Region graphs to build circuits from (one circuit each, in order).
        num_input_units: Number of units of each input layer.
        num_sum_units: Number of units of each sum layer.
        input_layer: One of "embedding", "categorical" or "gaussian".
        input_layer_kwargs: Extra input-layer options; "num_categories" is required for
            "categorical" and "num_states" for "embedding".
        mono_clamp: If True, sum weights are clamped from below (``vmin=1e-19``);
            otherwise they are the exponential of an unconstrained tensor.

    Returns:
        The list of symbolic circuits.

    Raises:
        AssertionError: If ``input_layer`` is unsupported or a required key is missing
            from ``input_layer_kwargs``.
    """
    if input_layer_kwargs is None:
        input_layer_kwargs = {}

    def weight_factory_clamp(shape: tuple[int, ...]) -> Parameter:
        # Non-negative sum weights: uniform init in [0.01, 0.99], clamped at 1e-19.
        return Parameter.from_unary(
            ClampParameter(shape, vmin=1e-19),
            TensorParameter(*shape, initializer=UniformInitializer(0.01, 0.99)),
        )

    def weight_factory_exp(shape: tuple[int, ...]) -> Parameter:
        # Positive sum weights: exponential of an unconstrained tensor.
        return Parameter.from_unary(
            ExpParameter(shape),
            TensorParameter(*shape, initializer=ExpUniformInitializer(0.0, 1.0)),
        )

    def categorical_layer_factory(
        scope: Scope, num_units: int,
    ) -> CategoricalLayer:
        # Categorical distributions parameterized by log-softmax-normalized logits.
        assert "num_categories" in input_layer_kwargs
        return CategoricalLayer(
            scope,
            num_units,
            num_categories=input_layer_kwargs["num_categories"],
            logits_factory=lambda shape: Parameter.from_unary(
                LogSoftmaxParameter(shape),
                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
            ),
        )

    def embedding_layer_factory(
        scope: Scope, num_units: int,
    ) -> EmbeddingLayer:
        # Embedding weights normalized with a softmax.
        assert "num_states" in input_layer_kwargs
        return EmbeddingLayer(
            scope,
            num_units,
            num_states=input_layer_kwargs["num_states"],
            weight_factory=lambda shape: Parameter.from_unary(
                SoftmaxParameter(shape),
                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
            ),
        )

    def gaussian_layer_factory(
        scope: Scope, num_units: int,
    ) -> GaussianLayer:
        # Unconstrained means; standard deviations passed through a scaled sigmoid
        # with vmin=1e-5 and vmax=1.0.
        return GaussianLayer(
            scope,
            num_units,
            mean_factory=lambda shape: Parameter.from_input(
                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0))
            ),
            stddev_factory=lambda shape: Parameter.from_sequence(
                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
                ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0),
            ),
        )

    # The input factory, sum-product layer type and weight factory do not depend on
    # the region graph, so select them once.
    assert input_layer in ["embedding", "categorical", "gaussian"]
    weight_factory = weight_factory_clamp if mono_clamp else weight_factory_exp
    if input_layer == "categorical":
        sum_product, input_factory = "cp-t", categorical_layer_factory
    elif input_layer == "gaussian":
        sum_product, input_factory = "cp", gaussian_layer_factory
    else:
        sum_product, input_factory = "cp-t", embedding_layer_factory

    return [
        rg.build_circuit(
            num_input_units=num_input_units,
            num_sum_units=num_sum_units,
            input_factory=input_factory,
            sum_product=sum_product,
            sum_weight_factory=weight_factory,
        )
        for rg in region_graphs
    ]


def _build_non_monotonic_sym_circuits(
    region_graphs: Sequence[RegionGraph],
    num_input_units: int,
    num_sum_units: int,
    *,
    model_name: str,
    input_layer: str,
    input_layer_kwargs: dict[str, Any] | None = None,
    complex: bool = False,
    use_tucker: bool = False,
) -> list[Circuit]:
    """Build one non-monotonic symbolic circuit per region graph.

    Sum weights, and embedding-layer weights, are plain tensor parameters
    (``Parameter.from_input``). No positivity reparameterisation is applied,
    so the resulting circuits are not constrained to be monotonic.

    Args:
        region_graphs: Region graphs to build circuits from. One circuit is
            returned per region graph, in the same order.
        num_input_units: Forwarded to ``RegionGraph.build_circuit``.
        num_sum_units: Forwarded to ``RegionGraph.build_circuit``.
        model_name: Model identifier. ``"osos"`` selects a
            ``NormalInitializer(0.0, 2.0)`` weight initializer for
            non-Fourier input layers.
        input_layer: One of ``"categorical"``, ``"embedding"``,
            ``"gaussian"`` or ``"fourier"``.
        input_layer_kwargs: Extra input-layer options.
            ``"num_categories"`` is required for categorical inputs,
            ``"num_states"`` for embedding inputs, and ``"period"`` for
            Fourier inputs. ``"period"`` is indexed by the scope's variable.
        complex: If True, weight tensors use ``DataType.COMPLEX``;
            otherwise they use ``DataType.REAL``.
        use_tucker: If True, use ``"tucker"`` sum-product layers instead of
            the default for the chosen input layer.

    Returns:
        The list of symbolic circuits, one per entry of ``region_graphs``.

    Raises:
        AssertionError: If ``input_layer`` is unknown, or a required
            ``input_layer_kwargs`` entry is missing when a layer is built.
    """
    if input_layer_kwargs is None:
        input_layer_kwargs = {}

    # The dtype only depends on ``complex``, so compute it once.
    weight_dtype = DataType.COMPLEX if complex else DataType.REAL

    def weight_factory(shape: tuple[int, ...]) -> Parameter:
        """Create an unconstrained weight tensor of ``shape``.

        The initializer is chosen from the input layer kind, ``use_tucker``,
        ``model_name`` and the number of variables.
        """
        if input_layer == "fourier":
            if use_tucker:
                initializer = NormalInitializer(0.0, 2.0)
            else:
                # Scale the spread by 1/sqrt(size of the last axis).
                initializer = NormalInitializer(0.0, 1 / shape[-1] ** 0.5)
        elif model_name == "osos":
            initializer = NormalInitializer(0.0, 2.0)
        elif region_graphs[0].num_variables <= 2:
            # NOTE(review): always inspects the first region graph; assumes all
            # region graphs share the same number of variables.
            initializer = UniformInitializer(-2.0, 2.0)
        else:
            initializer = UniformInitializer(0.0, 1.0)
        return Parameter.from_input(
            TensorParameter(*shape, initializer=initializer, dtype=weight_dtype)
        )

    def categorical_layer_factory(scope: Scope, num_units: int) -> CategoricalLayer:
        """Categorical input layer with Normal(0, 1)-initialised free logits."""
        assert "num_categories" in input_layer_kwargs, (
            "categorical input layers require input_layer_kwargs['num_categories']"
        )
        return CategoricalLayer(
            scope,
            num_units,
            num_categories=input_layer_kwargs["num_categories"],
            logits_factory=lambda shape: Parameter.from_input(
                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0))
            ),
        )

    def embedding_layer_factory(scope: Scope, num_units: int) -> EmbeddingLayer:
        """Embedding input layer whose weights come from ``weight_factory``."""
        assert "num_states" in input_layer_kwargs, (
            "embedding input layers require input_layer_kwargs['num_states']"
        )
        return EmbeddingLayer(
            scope,
            num_units,
            num_states=input_layer_kwargs["num_states"],
            weight_factory=weight_factory,
        )

    def gaussian_layer_factory(scope: Scope, num_units: int) -> GaussianLayer:
        """Gaussian input layer.

        Means are free tensors; standard deviations are squashed into
        [1e-5, 1.0] by ``ScaledSigmoidParameter``.
        """
        return GaussianLayer(
            scope,
            num_units,
            mean_factory=lambda shape: Parameter.from_input(
                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0))
            ),
            stddev_factory=lambda shape: Parameter.from_sequence(
                TensorParameter(*shape, initializer=NormalInitializer(0.0, 1.0)),
                ScaledSigmoidParameter(shape, vmin=1e-5, vmax=1.0),
            ),
        )

    def fourier_layer_factory(scope: Scope, num_units: int) -> FourierLayer:
        """Fourier input layer whose period is looked up by the scope's (first) variable."""
        # Same up-front check as the other factories, instead of a bare KeyError.
        assert "period" in input_layer_kwargs, (
            "fourier input layers require input_layer_kwargs['period']"
        )
        return FourierLayer(
            scope,
            num_units,
            period=input_layer_kwargs["period"][list(scope)[0]],
        )

    # input layer kind -> (default sum-product layer type, input layer factory).
    # Resolved once, so an unknown kind fails even when region_graphs is empty.
    input_layer_choices = {
        "categorical": ("cp-t", categorical_layer_factory),
        "embedding": ("cp-t", embedding_layer_factory),
        "gaussian": ("cp", gaussian_layer_factory),
        "fourier": ("cp", fourier_layer_factory),
    }
    assert input_layer in input_layer_choices, (
        f"unknown input layer {input_layer!r}; "
        f"expected one of {sorted(input_layer_choices)}"
    )
    sum_product, input_factory = input_layer_choices[input_layer]
    if use_tucker:
        sum_product = "tucker"

    return [
        rg.build_circuit(
            num_input_units=num_input_units,
            num_sum_units=num_sum_units,
            input_factory=input_factory,
            sum_product=sum_product,
            sum_weight_factory=weight_factory,
        )
        for rg in region_graphs
    ]
