from .custom.models import CustomRecurrentSpikingModel
import torch
import stork
from .custom.custom_connections import ChannelAttentionConnection
from .custom.custom_connections import ChannelAttentionConnection_multiHead

class torch_style_model(torch.nn.Module):
    """Plain PyTorch wrapper that runs a model's ``propagate_seq`` stages in order.

    The stages are stored in a ``torch.nn.ModuleList`` so that PyTorch registers
    them as submodules. Their parameters then appear in ``parameters()`` and
    ``state_dict()`` under the ``module_list.<index>`` prefix.

    NOTE(review): this assumes every entry of ``model.propagate_seq`` is a
    ``torch.nn.Module`` that can be called with a single argument and returns
    the input for the next stage. ``ModuleList`` raises ``TypeError`` for
    non-Module entries. Confirm that this holds for the stork groups and
    connections passed in here.
    """

    def __init__(self, model):
        """Capture the ordered stages from ``model.propagate_seq``.

        Args:
            model: Any object exposing an iterable ``propagate_seq`` attribute
                of modules (e.g. a ``CustomRecurrentSpikingModel``).
        """
        super().__init__()
        # The attribute name is kept as ``module_list`` so that state_dict
        # keys stay compatible with existing checkpoints.
        stages = model.propagate_seq
        self.module_list = torch.nn.ModuleList(stages)

    def forward(self, x):
        """Feed ``x`` through each stored stage, chaining outputs to inputs.

        Args:
            x: Input to the first stage.

        Returns:
            The output of the last stage, or ``x`` itself when there are no
            stages.
        """
        out = x
        for stage in self.module_list:
            out = stage(out)
        return out




#
# def stork2torch(model:CustomRecurrentSpikingModel):
#     """Convert stork model to a PyTorch-style model."""
#
#     class torch_style_model(torch.nn.Module):
#         def __init__(self, model):
#             super(torch_style_model, self).__init__()
#             self.model = model
#             self.propagate_seq = model.propagate_seq
#
#         def forward(self, x):
#             for g_or_c in self.propagate_seq:
#                 x = g_or_c(x)
#             return x
#
#
#     # torch_model =
#     # for g_or_c in model.propagate_seq: