# forward_forward/models/model_factory.py (Updated with Class Grouping)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Tuple, Any, Optional, Union
from collections import OrderedDict

from forward_forward.ff_layer import FFLayer
from forward_forward.models.layers.class_grouping import (
    ClassGroupingManager, 
    create_class_grouping_manager,
    FleaBlockWithEncoding
)
from forward_forward.models.layers.normalizer import Normalizer


class MixPool2d(nn.Module):
    """2-D pooling that averages a max-pool and an avg-pool over the same windows.

    The output is ``(max_pool(x) + avg_pool(x)) / 2``. Both branches share
    ``kernel_size``, ``stride``, ``padding`` and ``ceil_mode``.

    Args:
        kernel_size: Pooling window size.
        stride: Window stride; ``None`` means ``kernel_size``.
        padding: Implicit padding applied by both branches.
        dilation: Forwarded to the max-pool branch only (``nn.AvgPool2d`` has no
            dilation argument). Values other than 1 can therefore give the two
            branches different output shapes, and ``forward`` would then fail.
        return_indices: Must be ``False``. Index output cannot be combined with
            the averaged result.
        ceil_mode: Forwarded to both branches.

    Raises:
        ValueError: If ``return_indices`` is true.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 return_indices=False, ceil_mode=False):
        super(MixPool2d, self).__init__()

        if return_indices:
            # MaxPool2d would return (values, indices); that tuple cannot be added
            # to the average-pool tensor in forward().
            raise ValueError("MixPool2d does not support return_indices=True")

        # Initialize both pooling layers
        self.max_pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding,
                                     dilation=dilation, ceil_mode=ceil_mode)
        self.avg_pool = nn.AvgPool2d(kernel_size, stride=stride, padding=padding,
                                     ceil_mode=ceil_mode)

        # Store parameters for reference
        self.kernel_size = kernel_size
        self.stride = stride if stride is not None else kernel_size
        self.padding = padding
        self.dilation = dilation
        self.ceil_mode = ceil_mode

    def forward(self, x):
        """Return the element-wise mean of max- and average-pooled ``x``."""
        max_pooled = self.max_pool(x)
        avg_pooled = self.avg_pool(x)
        return (max_pooled + avg_pooled) / 2.0

    def extra_repr(self):
        return f'kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}'


class AdaptivePool2d(nn.Module):
    """2-D pooling whose reduction depends on the module's mode.

    While ``self.training`` is true the input is average-pooled; in evaluation
    mode it is max-pooled. Both use the same window parameters.

    Args:
        kernel_size: Pooling window size.
        stride: Window stride; ``None`` lets the functional op default it to
            ``kernel_size``.
        padding: Implicit padding on both sides.
    """

    def __init__(self, kernel_size: Union[int, Tuple[int, int]],
                 stride: Optional[Union[int, Tuple[int, int]]] = None,
                 padding: Union[int, Tuple[int, int]] = 0):
        super().__init__()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pool_fn = F.avg_pool2d if self.training else F.max_pool2d
        return pool_fn(x, self.kernel_size, self.stride, self.padding)

# Maps the "type" string of a layer spec to the class LayerBuilder.build_layer
# instantiates; the remaining spec keys (minus excluded ones) become constructor
# kwargs. Types missing from this mapping are rejected with ValueError. It also
# contains the class-grouping-enabled FleaBlock and the custom pooling modules
# defined in this file.
LAYER_REGISTRY = {
    # Linear / convolution layers
    "Linear": nn.Linear,
    "Conv1d": nn.Conv1d,
    "Conv2d": nn.Conv2d,
    # Activations; FleaBlock "activation_function" names are also resolved here
    "ReLU": nn.ReLU,
    "ReLU6": nn.ReLU6,
    "LeakyReLU": nn.LeakyReLU,
    "Identity": nn.Identity,
    "ELU": nn.ELU,
    "PReLU": nn.PReLU,
    "Sigmoid": nn.Sigmoid,
    "Tanh": nn.Tanh,
    # Pooling; the "GAP*" entries are adaptive average pools, and no default
    # output_size is filled in for them, so the spec must provide it
    "MaxPool2d": nn.MaxPool2d,
    "AvgPool2d": nn.AvgPool2d,
    "GAP1d": nn.AdaptiveAvgPool1d,
    "GAP2d": nn.AdaptiveAvgPool2d,
    # Reshaping, regularisation and normalisation
    "Flatten": nn.Flatten,
    "Dropout": nn.Dropout,
    "Dropout1d": nn.Dropout1d,
    "Dropout2d": nn.Dropout2d,
    "BatchNorm1d": nn.BatchNorm1d,
    "BatchNorm2d": nn.BatchNorm2d,
    # FleaBlock is built with class-grouping support
    "FleaBlock": FleaBlockWithEncoding,
    "Normalizer": Normalizer,
    "LayerNorm": nn.LayerNorm,
    "AdaptiveMaxPool2d": nn.AdaptiveMaxPool2d,
    "AdaptivePool2d": AdaptivePool2d,
    "MixPool2d": MixPool2d,
}

# Spec keys that LayerBuilder.build_layer strips before calling a layer
# constructor in Forward-Forward models. "trainable" and "ff_loss_type" are read
# by ModelBuilder, and "skip_from" by SkipConnectionBuilder. The remaining keys
# (learning rates, weight decay, scheduler, ...) are presumably consumed by
# training code elsewhere; verify against the trainer.
FF_SPECIFIC_PARAMS = {
    "trainable", "requires_labels", "ff_loss_type", "margin_tau", "skip_from",
    "conv_lr", "conv_weight_decay", "scale_lr", "scale_weight_decay", "scheduler",
    "class_groups"  # presumably read from the raw spec by create_class_grouping_manager
}

# Backprop models additionally strip FleaBlock-oriented keys. The filter applies
# to every layer type, so e.g. an ELU spec's "alpha" is dropped as well.
BP_SPECIFIC_PARAMS = FF_SPECIFIC_PARAMS | {
    "label_dimension", "alpha", "activation_function", 
    "g_power", "dropout"
}


# Skip-connection helpers (CenterCrop, SpaceToDepth, SkipConnection), model
# containers (BaseModel, ForwardForwardModel, BackpropModel) and the
# PoolingFactory used by the builders below.




class CenterCrop(nn.Module):
    """Crop the central ``h_out x w_out`` window from the last two dims of a 4-D tensor.

    Used by skip connections to match a larger skip tensor to the current
    spatial size. When the size difference is odd, the extra row/column is
    removed from the bottom/right.

    Raises:
        ValueError: If the input is smaller than the target in either dimension.
            Without this check, negative offsets would silently slice out a
            window of the wrong size.
    """

    def __init__(self, h_out: int, w_out: int):
        super().__init__()
        self.h_out = h_out
        self.w_out = w_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, height, width = x.shape
        if height < self.h_out or width < self.w_out:
            raise ValueError(
                f"CenterCrop cannot crop input of size ({height}, {width}) "
                f"to larger target ({self.h_out}, {self.w_out})"
            )
        top = (height - self.h_out) // 2
        left = (width - self.w_out) // 2
        return x[:, :, top:top + self.h_out, left:left + self.w_out]


class SpaceToDepth(nn.Module):
    """Fold non-overlapping spatial blocks into the channel dimension.

    A ``(B, C, H, W)`` input becomes ``(B, C * bh * bw, H // bh, W // bw)``.
    Without ``target_size`` both block sides equal ``block_size``. With
    ``target_size = (th, tw)`` the blocks are ``(H // th, W // tw)`` (possibly
    non-square) and ``block_size`` is ignored. Output channel
    ``c * bh * bw + i * bw + j`` holds in-block pixel ``(i, j)`` of channel ``c``.
    """

    def __init__(self, block_size: int, target_size: Optional[Tuple[int, int]] = None):
        super().__init__()
        self.block_size = block_size
        self.target_size = target_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, channels, height, width = x.shape

        if self.target_size is None:
            side = self.block_size
            assert height % side == 0 and width % side == 0, (
                f"Input size ({height}, {width}) must be divisible by block_size {self.block_size}"
            )
            block_h = block_w = side
            out_h, out_w = height // side, width // side
        else:
            out_h, out_w = self.target_size
            block_h, block_w = height // out_h, width // out_w
            assert height % out_h == 0 and width % out_w == 0, (
                f"Input size ({height}, {width}) must be divisible by target size ({out_h}, {out_w})"
            )

        tiles = x.view(batch, channels, out_h, block_h, out_w, block_w)
        # Move the two in-block axes next to the channel axis so they merge into it.
        tiles = tiles.permute(0, 1, 3, 5, 2, 4).contiguous()
        return tiles.view(batch, channels * block_h * block_w, out_h, out_w)


class SkipConnection(nn.Module):
    """Combine the current activation with a detached, resized earlier activation.

    Args:
        skip_from: Name of the earlier output to fetch (resolved by BaseModel).
        skip_type: ``'cat'`` (concatenate on dim 1), ``'add'`` (element-wise sum)
            or ``'jump'`` (return only the skip tensor).
        pool: Module that resizes the skip tensor; defaults to ``nn.Identity``.

    The skip tensor is detached, so no gradient flows back through this path.
    """

    def __init__(self, skip_from: str, skip_type: str = 'cat', pool: Optional[nn.Module] = None):
        super().__init__()
        self.skip_from = skip_from
        self.skip_type = skip_type
        self.pool = nn.Identity() if pool is None else pool

    def forward(self, x: torch.Tensor, skip_input: torch.Tensor) -> torch.Tensor:
        resized = self.pool(skip_input).detach()
        mode = self.skip_type

        if mode == 'cat':
            return torch.cat([x, resized], dim=1)
        if mode == 'add':
            return x + resized
        if mode == 'jump':
            return resized
        raise ValueError(f"Unknown skip_type {self.skip_type}")


class BaseModel(nn.Module):
    """Run named layers in order, routing earlier outputs into skip connections.

    Each layer's output is recorded under its name (the network input under
    ``"input"``). A ``SkipConnection`` layer receives the current tensor and
    the recorded output named by its ``skip_from`` attribute.
    """

    def __init__(self, layers: nn.ModuleDict):
        super().__init__()
        self.layers = layers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        history = {"input": x}
        hidden = x

        for layer_name, module in self.layers.items():
            if isinstance(module, SkipConnection):
                hidden = module(hidden, history[module.skip_from])
            else:
                hidden = module(hidden)
            history[layer_name] = hidden

        return hidden


class ForwardForwardModel(BaseModel):
    """Layer graph for Forward-Forward training.

    Args:
        layers: Ordered name -> module mapping executed by ``BaseModel.forward``.
        trainable_names: Names of the layers ModelBuilder wrapped in ``FFLayer``.
    """

    def __init__(self, layers: nn.ModuleDict, trainable_names: List[str]):
        super(ForwardForwardModel, self).__init__(layers=layers)
        self.trainable_names = trainable_names


class BackpropModel(BaseModel):
    """Layer graph trained end-to-end with backpropagation.

    Args:
        layers: Ordered name -> module mapping executed by ``BaseModel.forward``.
        num_classes: Number of target classes the model was built for.
    """

    def __init__(self, layers: nn.ModuleDict, num_classes: int):
        super(BackpropModel, self).__init__(layers=layers)
        self.num_classes = num_classes


class PoolingFactory:
    """Factory for the module a skip connection uses to resize its skip input."""

    @staticmethod
    def create_pooling_layer(
        curr_shape: Tuple[int, ...],
        skip_shape: Tuple[int, ...],
        pool_type: str
    ) -> nn.Module:
        """Create a module mapping the skip tensor's spatial size onto the current one.

        Only the last two dimensions of each shape are inspected.

        Args:
            curr_shape: Shape of the current activation (target spatial size).
            skip_shape: Shape of the skip-source activation (input spatial size).
            pool_type: ``"center"``, ``"space_to_depth"``, ``"avg"`` or ``"max"``.

        Returns:
            ``nn.Identity`` when the spatial sizes already match. Otherwise a
            ``CenterCrop``, ``SpaceToDepth``, ``nn.AvgPool2d`` or ``nn.MaxPool2d``.

        Raises:
            AssertionError: If the sizes are incompatible with the requested method.
            ValueError: If ``pool_type`` is unknown.
        """
        h_in, w_in = skip_shape[-2:]
        h_out, w_out = curr_shape[-2:]

        if (h_in, w_in) == (h_out, w_out):
            return nn.Identity()

        if pool_type == "center":
            # A crop can only shrink. A smaller skip tensor would be sliced with
            # negative offsets and silently yield the wrong size.
            assert h_in >= h_out and w_in >= w_out, (
                f"Cannot center-crop from ({h_in}, {w_in}) to larger size ({h_out}, {w_out})"
            )
            return CenterCrop(h_out, w_out)

        elif pool_type == "space_to_depth":
            assert h_in % h_out == 0 and w_in % w_out == 0, (
                f"Cannot use space_to_depth: input size ({h_in}, {w_in}) "
                f"not evenly divisible by output size ({h_out}, {w_out})"
            )

            block_h = h_in // h_out
            block_w = w_in // w_out

            # For now, we require square blocks for simplicity
            assert block_h == block_w, (
                f"Space-to-depth currently requires square downsampling, "
                f"got block size ({block_h}, {block_w})"
            )

            return SpaceToDepth(block_size=block_h, target_size=(h_out, w_out))

        # Traditional pooling methods
        factor_h = h_in // h_out
        factor_w = w_in // w_out

        assert factor_h == factor_w, (
            f"Expected square downsampling from {skip_shape} to {curr_shape}"
        )
        # A factor of 0 (skip smaller than current) would otherwise pass the bit
        # trick, because 0 & -1 == 0, and produce a kernel-size-0 pool.
        assert factor_h > 0 and (factor_h & (factor_h - 1)) == 0, (
            "Downsampling factor must be power of 2"
        )

        if pool_type == 'avg':
            return nn.AvgPool2d(kernel_size=factor_h, stride=factor_h)
        elif pool_type == 'max':
            return nn.MaxPool2d(kernel_size=factor_h, stride=factor_h)
        else:
            raise ValueError(f"Unknown pool_type {pool_type}")


class LayerBuilder:
    """Builds layers and sequential blocks from spec dictionaries.

    Every builder instantiates its module and runs it once on a probe tensor to
    validate the configuration. It returns ``(module, probe_output)`` so the
    caller can infer the next layer's input size.
    """

    @staticmethod
    def build_layer(
        spec: Dict[str, Any],
        x: torch.Tensor,
        num_classes: int,
        exclude_params: Optional[set] = None,
        layer_name: str = ""  # Used for class grouping
    ) -> Tuple[nn.Module, torch.Tensor]:
        """Build the layer described by ``spec`` and run it on ``x``.

        Args:
            spec: Layer config. ``spec["type"]`` must be a key of ``LAYER_REGISTRY``.
                All keys except ``type``, ``name`` and ``exclude_params`` are
                passed to the layer constructor.
            x: Probe input, used to infer defaults (e.g. ``in_features``) and to
                test the layer.
            num_classes: Number of classes, used for FleaBlock defaults and
                class grouping.
            exclude_params: Spec keys that must not reach the constructor.
            layer_name: Name given to ``create_class_grouping_manager`` and
                shown in its log line.

        Returns:
            ``(layer, out)``, where ``out`` is the probe output. When the layer
            returns a tuple, its second element is used.

        Raises:
            ValueError: For an unknown layer type or a missing required argument.
            RuntimeError: If the layer fails on the probe input.
        """
        if exclude_params is None:
            exclude_params = set()

        layer_type = spec["type"]
        LayerClass = LAYER_REGISTRY.get(layer_type)
        if LayerClass is None:
            raise ValueError(f"Unsupported layer type: {layer_type}")

        # "type"/"name" describe the spec itself and are never constructor arguments.
        exclude_params = exclude_params | {"type", "name"}
        kwargs = {k: v for k, v in spec.items() if k not in exclude_params}

        # Class grouping for Flea layers; the manager is derived from the full,
        # unfiltered spec. "FleaConv1d" is not registered in LAYER_REGISTRY, so
        # only FleaBlock reaches this branch today.
        if layer_type in ("FleaConv1d", "FleaBlock"):
            class_grouping_manager = create_class_grouping_manager(
                layer_name, spec, num_classes
            )
            if class_grouping_manager is not None:
                kwargs["class_grouping_manager"] = class_grouping_manager
                print(f"Layer {layer_name}: Using class grouping with {class_grouping_manager.get_num_classes()} effective classes")

        LayerBuilder._set_layer_defaults(layer_type, kwargs, x, num_classes)

        print(f"Building layer: {layer_type} with args: {kwargs} and input shape: {x.shape}")
        layer = LayerClass(**kwargs)

        # Probe the layer; a layer returning a tuple contributes its second element.
        try:
            out = layer(x)
            x = out[1] if isinstance(out, tuple) else out
        except Exception as e:
            raise RuntimeError(
                f"Failed to apply layer {layer_type} with args {kwargs} "
                f"and {x.shape = }\n\nError: {e}"
            ) from e

        return layer, x

    @staticmethod
    def _set_layer_defaults(
        layer_type: str,
        kwargs: Dict[str, Any],
        x: torch.Tensor,
        num_classes: int
    ) -> None:
        """Fill in constructor defaults for ``layer_type`` in ``kwargs`` (in place).

        Raises:
            ValueError: If a Linear spec lacks ``out_features``, or a Conv2d /
                FleaBlock spec lacks ``out_channels``.
        """
        if layer_type == "Linear":
            kwargs.setdefault("in_features", x.shape[1])
            if "out_features" not in kwargs:
                raise ValueError("Missing 'out_features' for Linear layer.")

        elif layer_type == "AdaptiveMaxPool2d":
            # NOTE(review): confirm the installed torch accepts output_size=None;
            # the documented way to keep the input size is (None, None).
            kwargs.setdefault("output_size", None)

        elif layer_type in {"Conv2d", "FleaBlock"}:
            kwargs.setdefault("in_channels", x.shape[1])
            if "out_channels" not in kwargs:
                raise ValueError(f"Missing 'out_channels' for {layer_type} layer.")

            if layer_type == "FleaBlock":
                LayerBuilder._set_flea_conv_defaults(kwargs, x, num_classes)
            else:
                # Only arguments nn.Conv2d accepts. A stray default such as
                # "n_partitions" makes the constructor raise TypeError.
                kwargs.setdefault("kernel_size", 3)
                kwargs.setdefault("padding", 1)
                kwargs.setdefault("bias", False)

        elif layer_type == "Flatten":
            kwargs.setdefault("start_dim", 1)

    @staticmethod
    def _set_flea_conv_defaults(
        kwargs: Dict[str, Any],
        x: torch.Tensor,
        num_classes: int
    ) -> None:
        """Fill in FleaBlock defaults in ``kwargs`` (in place).

        ``activation_function`` given as a registry name is replaced by a new
        instance of the registered class, with ``nn.ReLU`` for unknown names. An
        ``nn.Module`` instance is kept as given.
        """
        # Per the original note, a class_grouping_manager (when present)
        # overrides label_dimension inside the layer; not verifiable from here.
        kwargs.setdefault("label_dimension", num_classes)
        kwargs.setdefault("kernel_size", 3)
        kwargs.setdefault("input_shape", x.shape[2])
        kwargs.setdefault("alpha", 1)

        if "activation_function" in kwargs:
            activation = kwargs["activation_function"]
            # A ready-made module is never a registry key; looking it up would
            # silently replace it with ReLU.
            if not isinstance(activation, nn.Module):
                kwargs["activation_function"] = LAYER_REGISTRY.get(
                    activation, nn.ReLU
                )()

    @staticmethod
    def build_flea_conv_replacement(
        spec: Dict[str, Any],
        x: torch.Tensor
    ) -> Tuple[nn.Module, torch.Tensor]:
        """Build a standard-layer stand-in for FleaBlock (used by backprop models).

        The block is Normalizer("std") -> Conv2d (no bias) -> BatchNorm2d
        (affine=False) -> activation, followed by Dropout2d when
        ``spec["dropout"] > 0``.

        Returns:
            ``(block, out)``, where ``out`` is the block applied to ``x``.

        Raises:
            KeyError: If ``spec`` lacks ``out_channels``.
            RuntimeError: If the block fails on the probe input.
        """
        in_channels = x.shape[1]
        out_channels = spec["out_channels"]
        kernel_size = spec.get("kernel_size", 3)
        stride = spec.get("stride", 1)
        padding = spec.get("padding", 1)
        dropout_prob = spec.get("dropout", 0.0)
        activation_name = spec.get("activation_function", "ReLU")

        # Unknown activation names fall back to ReLU.
        activation_fn = LAYER_REGISTRY.get(activation_name, nn.ReLU)()

        layers = [
            ("norm", Normalizer("std", zero_mean=True, eps=1e-6)),
            ("conv", nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False
            )),
            ("batchnorm", nn.BatchNorm2d(out_channels, affine=False)),
            ("activation", activation_fn),
        ]

        if dropout_prob > 0:
            layers.append(("dropout", nn.Dropout2d(p=dropout_prob)))

        block = nn.Sequential(OrderedDict(layers))

        print(f"Building FleaBlock replacement: {in_channels}->{out_channels}, "
              f"activation={activation_name}, dropout={dropout_prob}, input shape: {x.shape}")

        try:
            x = block(x)
        except Exception as e:
            raise RuntimeError(
                f"Failed to apply FleaBlock replacement with {x.shape = }\n\nError: {e}"
            ) from e

        return block, x

    @staticmethod
    def build_block(
        layer_specs: List[Dict[str, Any]],
        block_name: str,
        x: torch.Tensor,
        num_classes: int,
        exclude_params: Optional[set] = None
    ) -> Tuple[nn.Module, torch.Tensor]:
        """Build an ``nn.Sequential`` from ``layer_specs``, threading the probe through.

        Sub-layers are named by their spec's ``name`` or ``{block_name}_layer_{i}``.

        Returns:
            ``(block, out)``, where ``out`` is the probe output of the last sub-layer.
        """
        block_layers = []
        for i, spec in enumerate(layer_specs):
            sub_name = spec.get("name", f"{block_name}_layer_{i}")
            layer, x = LayerBuilder.build_layer(spec, x, num_classes, exclude_params, sub_name)
            block_layers.append((sub_name, layer))
        return nn.Sequential(OrderedDict(block_layers)), x


class SkipConnectionBuilder:
    """Builds SkipConnection modules and the probe tensor that follows them."""

    @staticmethod
    def build_skip_connection(
        layer_cfg: Dict[str, Any],
        x: torch.Tensor,
        tensor_shapes: Dict[str, Tuple[int, ...]]
    ) -> Tuple[SkipConnection, torch.Tensor]:
        """Build a skip connection and a zero probe tensor shaped like its output.

        Args:
            layer_cfg: Spec with these keys:
                ``skip_from``: name of an earlier layer, or ``"input"``.
                ``skip_type`` (optional): ``"add"`` (default), ``"cat"`` or ``"jump"``.
                ``pool_type`` (optional): ``"avg"`` (default); see PoolingFactory.
            x: Probe tensor produced by the previous layer.
            tensor_shapes: Output shape of every layer built so far, keyed by name.

        Returns:
            ``(layer, probe)``, where ``probe`` is a zero tensor with the shape
            the skip connection produces. It keeps the batch size of ``x``.

        Raises:
            ValueError: If ``skip_type`` is unknown.
            KeyError: If ``skip_from`` names no earlier layer.
            AssertionError: On incompatible shapes (from PoolingFactory, or an
                ``"add"`` channel mismatch).
        """
        skip_from = layer_cfg["skip_from"]
        skip_type = layer_cfg.get("skip_type", "add")
        pool_type = layer_cfg.get("pool_type", "avg")

        # Fail at build time; otherwise an unknown type only surfaces in forward().
        if skip_type not in ("add", "cat", "jump"):
            raise ValueError(f"Unknown skip_type {skip_type}")

        curr_shape = x.shape
        skip_shape = tensor_shapes[skip_from]

        pool = PoolingFactory.create_pooling_layer(curr_shape, skip_shape, pool_type)

        layer = SkipConnection(
            skip_from=skip_from,
            skip_type=skip_type,
            pool=pool,
        )

        # Channels of the skip tensor after resizing; space_to_depth folds each
        # spatial block into the channel dimension.
        skip_channels = skip_shape[1]
        if pool_type == "space_to_depth":
            block_h = skip_shape[-2] // curr_shape[-2]
            block_w = skip_shape[-1] // curr_shape[-1]
            skip_channels *= block_h * block_w

        if skip_type == "jump":
            out_channels = skip_channels
        elif skip_type == "cat":
            out_channels = curr_shape[1] + skip_channels
        else:  # "add"
            if pool_type == "space_to_depth":
                assert curr_shape[1] == skip_channels, (
                    f"Channel mismatch for 'add' skip with space_to_depth: "
                    f"{curr_shape[1]} != {skip_channels}"
                )
            else:
                assert curr_shape[1] == skip_channels, (
                    f"Channel mismatch for 'add' skip: {curr_shape[1]} != {skip_channels}"
                )
            out_channels = curr_shape[1]

        # Keep the incoming probe's batch size. A batch of 1 makes later
        # training-mode probes such as BatchNorm1d fail.
        x = torch.zeros(
            curr_shape[0], out_channels, *curr_shape[2:],
            dtype=x.dtype, device=x.device,
        )
        return layer, x


class ModelBuilder:
    """Assembles Forward-Forward and backpropagation models from architecture configs.

    An architecture is a list of specs. Each spec's ``type`` selects how it is built:
    - ``"skip_connection"``: built by SkipConnectionBuilder.
    - ``"block"``: an ``nn.Sequential`` built from the specs in ``"layers"``.
    - anything else: a single ``LAYER_REGISTRY`` layer.

    Entries are named by ``spec["name"]`` or ``layer_{idx}``. A zero probe tensor
    with batch size 2 is propagated during construction to infer input sizes.
    """

    @staticmethod
    def build_model_from_config(
        architecture: List[Dict[str, Any]],
        input_shape: Tuple[int, ...],
        num_classes: int,
        model_type: str = "ff"
    ) -> nn.Module:
        """Build a model of the requested kind.

        Args:
            architecture: Layer specs, in execution order.
            input_shape: Shape of one sample, without the batch dimension.
            num_classes: Number of target classes.
            model_type: ``"ff"`` for Forward-Forward or ``"bp"`` for
                backpropagation (built with the default classifier head).

        Raises:
            ValueError: If ``model_type`` is not ``"ff"`` or ``"bp"``.
        """
        if model_type == "ff":
            return ModelBuilder._build_ff_model(architecture, input_shape, num_classes)
        elif model_type == "bp":
            return ModelBuilder._build_bp_model(architecture, input_shape, num_classes)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    @staticmethod
    def _build_ff_model(
        architecture: List[Dict[str, Any]],
        input_shape: Tuple[int, ...],
        num_classes: int
    ) -> ForwardForwardModel:
        """Build a Forward-Forward model.

        Two kinds of top-level layer are wrapped in ``FFLayer`` and listed in
        ``trainable_names``: those whose spec sets ``trainable``, and every
        ``FleaBlock``. The loss comes from ``ff_loss_type`` (default ``"bce"``).
        Layers inside ``"block"`` entries and skip connections are never wrapped.
        """
        named_layers = OrderedDict()
        trainable_names = []
        x = torch.zeros(2, *input_shape)
        tensor_shapes = {"input": x.shape}

        for idx, layer_cfg in enumerate(architecture):
            name = layer_cfg.get("name", f"layer_{idx}")
            type_ = layer_cfg["type"]

            if type_ == "skip_connection":
                layer, x = SkipConnectionBuilder.build_skip_connection(
                    layer_cfg, x, tensor_shapes
                )
            elif type_ == "block":
                layer, x = LayerBuilder.build_block(
                    layer_cfg["layers"], name, x, num_classes, FF_SPECIFIC_PARAMS
                )
            else:
                layer, x = LayerBuilder.build_layer(
                    layer_cfg, x, num_classes, FF_SPECIFIC_PARAMS, name
                )

                # Wrap trainable layers with FFLayer
                trainable = layer_cfg.get("trainable", False) or type_ == "FleaBlock"
                if trainable:
                    ff_loss_type = layer_cfg.get("ff_loss_type", "bce")
                    layer = FFLayer(name=name, layer=layer, ff_loss_type=ff_loss_type)
                    trainable_names.append(name)

            named_layers[name] = layer
            tensor_shapes[name] = x.shape

        return ForwardForwardModel(nn.ModuleDict(named_layers), trainable_names)

    @staticmethod
    def _build_bp_model(
        architecture: List[Dict[str, Any]],
        input_shape: Tuple[int, ...],
        num_classes: int,
        add_classifier: bool = True,
        global_pool_size: Optional[int] = None
    ) -> BackpropModel:
        """Build a backpropagation model.

        Top-level ``FleaBlock`` specs are replaced by the standard-layer block
        from ``LayerBuilder.build_flea_conv_replacement``. Keys in
        ``BP_SPECIFIC_PARAMS`` are stripped from all other layer specs.

        Args:
            architecture: Layer specs, in execution order.
            input_shape: Shape of one sample, without the batch dimension.
            num_classes: Number of target classes.
            add_classifier: Append a pool/flatten/linear head (see ``_add_classifier``).
            global_pool_size: Output size of the head's adaptive average pool;
                ``None`` means 1.
        """
        named_layers = OrderedDict()
        x = torch.zeros(2, *input_shape)
        tensor_shapes = {"input": x.shape}

        for idx, layer_cfg in enumerate(architecture):
            name = layer_cfg.get("name", f"layer_{idx}")
            type_ = layer_cfg["type"]

            if type_ == "skip_connection":
                layer, x = SkipConnectionBuilder.build_skip_connection(
                    layer_cfg, x, tensor_shapes
                )
            elif type_ == "block":
                layer, x = LayerBuilder.build_block(
                    layer_cfg["layers"], name, x, num_classes, BP_SPECIFIC_PARAMS
                )
            elif type_ == "FleaBlock":
                layer, x = LayerBuilder.build_flea_conv_replacement(layer_cfg, x)
            else:
                # Pass the layer name, as the FF builder does.
                layer, x = LayerBuilder.build_layer(
                    layer_cfg, x, num_classes, BP_SPECIFIC_PARAMS, name
                )

            named_layers[name] = layer
            tensor_shapes[name] = x.shape

        if add_classifier:
            ModelBuilder._add_classifier(named_layers, x, num_classes, global_pool_size)

        return BackpropModel(nn.ModuleDict(named_layers), num_classes)

    @staticmethod
    def _add_classifier(
        named_layers: OrderedDict,
        x: torch.Tensor,
        num_classes: int,
        global_pool_size: Optional[int] = None
    ) -> None:
        """Append a classification head to ``named_layers`` (mutated in place).

        The head is built from the probe output ``x``:
        - ``"global_pool"`` (``nn.AdaptiveAvgPool2d``) when ``x`` is 4-D;
        - ``"flatten"`` when more than 2 dims remain;
        - ``"classifier"`` (``nn.Linear`` to ``num_classes``).

        Existing entries with these names would be replaced at their original
        position.
        """
        if len(x.shape) == 4:  # [batch, channels, height, width]
            pool_size = global_pool_size if global_pool_size is not None else 1
            pool_layer = nn.AdaptiveAvgPool2d(pool_size)
            named_layers["global_pool"] = pool_layer
            x = pool_layer(x)

        if len(x.shape) > 2:
            flatten_layer = nn.Flatten()
            named_layers["flatten"] = flatten_layer
            x = flatten_layer(x)

        classifier = nn.Linear(x.shape[1], num_classes)
        named_layers["classifier"] = classifier


# Convenience functions to maintain backward compatibility
def build_model_from_config(
    architecture: List[Dict[str, Any]],
    input_shape: Tuple[int, ...],
    num_classes: int,
) -> nn.Module:
    """Build a Forward-Forward model from an architecture config.

    Kept for backward compatibility; delegates to
    ``ModelBuilder.build_model_from_config`` with ``model_type="ff"``.
    """
    builder_kwargs = dict(
        architecture=architecture,
        input_shape=input_shape,
        num_classes=num_classes,
        model_type="ff",
    )
    return ModelBuilder.build_model_from_config(**builder_kwargs)

def build_bp_model_from_config(
    architecture: List[Dict[str, Any]],
    input_shape: Tuple[int, ...],
    num_classes: int,
    add_classifier: bool = True,
    global_pool_size: Optional[int] = None
) -> nn.Module:
    """Build a backpropagation model from an architecture config.

    ``add_classifier`` and ``global_pool_size`` are forwarded unchanged to
    ``ModelBuilder._build_bp_model``.
    """
    return ModelBuilder._build_bp_model(
        architecture,
        input_shape,
        num_classes,
        add_classifier=add_classifier,
        global_pool_size=global_pool_size,
    )


# Legacy function aliases for backward compatibility
def _build_block(layer_specs, block_name, x, num_classes):
    """Legacy alias for ``LayerBuilder.build_block`` with Forward-Forward exclusions."""
    return LayerBuilder.build_block(
        layer_specs=layer_specs,
        block_name=block_name,
        x=x,
        num_classes=num_classes,
        exclude_params=FF_SPECIFIC_PARAMS,
    )


def _build_layer(spec, x, num_classes):
    """Legacy alias for ``LayerBuilder.build_layer`` with Forward-Forward exclusions."""
    return LayerBuilder.build_layer(
        spec=spec,
        x=x,
        num_classes=num_classes,
        exclude_params=FF_SPECIFIC_PARAMS,
    )


def _build_bp_block(layer_specs, block_name, x, num_classes):
    """Legacy alias for ``LayerBuilder.build_block`` with backprop exclusions."""
    return LayerBuilder.build_block(
        layer_specs=layer_specs,
        block_name=block_name,
        x=x,
        num_classes=num_classes,
        exclude_params=BP_SPECIFIC_PARAMS,
    )


def _build_bp_layer(spec, x, num_classes):
    """Legacy backprop layer builder.

    FleaBlock specs become the standard-layer replacement; every other spec is
    built by ``LayerBuilder.build_layer`` with backprop exclusions.
    """
    if spec["type"] != "FleaBlock":
        return LayerBuilder.build_layer(spec, x, num_classes, BP_SPECIFIC_PARAMS)
    return LayerBuilder.build_flea_conv_replacement(spec, x)


def _build_flea_conv_replacement(spec, x):
    """Legacy alias for ``LayerBuilder.build_flea_conv_replacement``."""
    return LayerBuilder.build_flea_conv_replacement(spec=spec, x=x)
