
from typing import List, Optional

import torch
from attrs import define, field
from torch import nn

# BaseParameters, CommonModelParams, AbstractModel and CommonModelMixin are
# project-internal base classes, assumed to be imported from elsewhere in
# this codebase.


@define(slots=False)
class RecurrentFeedbackNetworkParams(BaseParameters):
    use_lateral_connection: bool = True
    lateral_layer_params: Optional[List[BaseParameters]] = None
    use_feedback_connection: bool = True
    feedback_layer_params: Optional[List[BaseParameters]] = None
    layer_params: List[BaseParameters] = field(factory=list)  # per-instance default, never shared

class RecurrentFeedbackNetwork(AbstractModel, CommonModelMixin):
    @define(slots=False)
    class ModelParams(BaseParameters):
        # Factory defaults: both sub-param objects default to fresh instances
        # but can be supplied by the caller.
        common_params: CommonModelParams = field(factory=CommonModelParams)
        recurrent_feedback_net_params: RecurrentFeedbackNetworkParams = field(
            factory=RecurrentFeedbackNetworkParams)
    
    def __init__(self, params: BaseParameters) -> None:
        super().__init__(params)
        self.params: RecurrentFeedbackNetwork.ModelParams = params
        self.load_common_params()

        # Unpack the recurrent/feedback configuration for convenient access.
        rf_params = self.params.recurrent_feedback_net_params
        self.use_lateral_connection = rf_params.use_lateral_connection
        self.lateral_layer_params = rf_params.lateral_layer_params
        self.use_feedback_connection = rf_params.use_feedback_connection
        self.feedback_layer_params = rf_params.feedback_layer_params
        self.layer_params = rf_params.layer_params

        self._make_network()

    def _make_network(self):
        # Trace a dummy input through the layers so each layer can infer its
        # input shape from the output shape of the layer before it.
        x = torch.rand(1, *(self.params.common_params.input_size))
        layers = []
        for lp in self.layer_params:
            lp.common_params.input_size = x.shape[1:]
            l = lp.cls(lp)  # each params object carries its layer class in `cls`
            x = l(x)
            # Keep only the output if the layer returns an (output, extras) tuple.
            if isinstance(x, tuple):
                x = x[0]
            layers.append(l)
        self.trunk = nn.ModuleList(layers)
        # The lateral and feedback paths are not built yet (see TODOs below).
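
# Construction sketch (hypothetical): SomeLayer stands in for a real layer
# class from this codebase, and SomeLayer.get_params() is assumed to return a
# BaseParameters whose `cls` points back to SomeLayer, which is what
# `lp.cls(lp)` in _make_network relies on. CommonModelParams is likewise
# assumed to accept `input_size` as a keyword.
#
# p = RecurrentFeedbackNetwork.ModelParams(
#     common_params=CommonModelParams(input_size=(3, 32, 32)),
#     recurrent_feedback_net_params=RecurrentFeedbackNetworkParams(
#         layer_params=[SomeLayer.get_params(), SomeLayer.get_params()],
#     ),
# )
# net = RecurrentFeedbackNetwork(p)  # builds net.trunk by shape tracing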

@define(slots=False)
class MHSAParams(BaseParameters):
    common_params: CommonModelParams = field(factory=CommonModelParams)
    num_heads: int = 1


class MultiHeadSelfAttentionLayer(AbstractModel, CommonModelMixin):
    def __init__(self, params: MHSAParams) -> None:
        super().__init__(params)
        # CommonModelMixin is assumed to populate num_units, dropout_p and
        # bias here from common_params, as in RecurrentFeedbackNetwork above.
        self.load_common_params()
        self._make_network()

    def _make_network(self):
        self.heads = nn.MultiheadAttention(self.num_units, self.params.num_heads,
                                           dropout=self.dropout_p, bias=self.bias,
                                           batch_first=True)
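
# Standalone sketch of how this layer's attention runs as *self*-attention:
# query, key and value are all the same sequence (batch_first layout); the
# shapes below are illustrative.
_attn = nn.MultiheadAttention(embed_dim=64, num_heads=4, dropout=0.1,
                              bias=True, batch_first=True)
_x = torch.rand(2, 16, 64)                       # (batch, seq_len, embed_dim)
_out, _ = _attn(_x, _x, _x, need_weights=False)  # _out: (2, 16, 64)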


- add mask to consistency optimization
- add CO (consistency optimization) layerwise, starting from the input
- zero-initialize the weights that act on concatenated position embeddings (see sketch below)
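
A sketch of the last item, assuming "concatenated position embeddings" means the embeddings are concatenated channel-wise with the features before a linear projection; zeroing the weight columns that act on the embedding slice makes the projection initially ignore position (shapes are illustrative):

feat_dim, pos_dim, out_dim = 64, 16, 64
proj = nn.Linear(feat_dim + pos_dim, out_dim)
with torch.no_grad():
    proj.weight[:, feat_dim:].zero_()  # columns that consume the position embedding

x, pos = torch.rand(2, 10, feat_dim), torch.rand(2, 10, pos_dim)
y = proj(torch.cat([x, pos], dim=-1))  # initially equals proj on x alone (plus bias)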
