# forward_forward/ff_layer.py

import torch
import torch.nn as nn
from typing import Callable, Union

class FFLayer(nn.Module):
    """
    Forward-Forward adapter around an arbitrary block/layer.

    The wrapper holds the wrapped module together with the goodness and loss
    settings selected for it, and mirrors a few optional attributes of the
    wrapped module so they can be read directly from the wrapper. The stored
    settings are only recorded here; this class does not interpret them.
    """

    def __init__(
        self,
        name: str,
        layer: nn.Module,
        goodness_fn: Union[str, Callable] = "sum_of_squares",
        ff_loss_type: str = "bce",
    ):
        """
        Args:
            name (str): Unique name for the layer block.
            layer (nn.Module): Module to wrap (e.g., nn.Sequential).
            goodness_fn (str or Callable): Goodness metric, stored as given.
            ff_loss_type (str): Loss identifier, stored as given.

        Raises:
            ValueError: If `layer` is not a torch.nn.Module.
        """
        super().__init__()
        if not isinstance(layer, nn.Module):
            raise ValueError("Expected `layer` to be a torch.nn.Module.")

        self.name = name
        self.layer = layer  # assignment registers `layer` as a submodule
        self.goodness_fn = goodness_fn
        self.ff_loss_type = ff_loss_type

        # For compatibility with Trainer: expose the wrapped module's optional
        # attributes on the wrapper, using a fallback when one is missing.
        trainer_attrs = (
            ("predict", None),
            ("accepts_label", True),
            ("accepts_mode", True),
        )
        for attr_name, fallback in trainer_attrs:
            setattr(self, attr_name, getattr(layer, attr_name, fallback))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run `x` through the wrapped module (for inference/testing)."""
        output = self.layer(x)
        return output
