from neurobench.models import NeuroBenchModel
from ..custom.custom_connections import multihead_attention_operation

class StorkModel(NeuroBenchModel):
    """The StorkModel class wraps a network for use as a NeuroBenchModel.

    The wrapped network is expected to provide ``eval()``, ``predict()``,
    ``children()`` and a ``groups`` sequence (presumably a stork model built
    on torch.nn.Module -- TODO confirm against the callers).
    """

    def __init__(self, net):
        """
        Initializes the StorkModel class.

        Args:
            net: The network to wrap. ``eval()`` is called on it here.

        """
        super().__init__(net)

        self.net = net
        self.net.eval()

        # Register the custom attention operation on top of the layer types
        # inherited from the base class, so connection_layers() collects it.
        extra_layers = (multihead_attention_operation,)
        self.supported_layers = self.supported_layers + extra_layers

    def __call__(self, batch):
        """
        Runs the wrapped network's ``predict`` on a batch.

        Args:
            batch: Input passed unchanged to ``net.predict`` (presumably a
                PyTorch tensor of shape (batch, timesteps, features*)).

        Returns:
            preds: The prediction, detached and moved to the CPU; either a
                tensor to be compared with targets or passed to
                NeuroBenchPostProcessors.

        """
        prediction = self.net.predict(batch)
        return prediction.detach().cpu()

    def __net__(self):
        """Gives access to the wrapped network object."""
        return self.net

    def activation_layers(self):
        """Returns the entries of ``net.groups`` except the first and the last.

        NOTE(review): presumably the first and last groups are the input and
        readout groups, leaving the hidden groups -- confirm.
        """
        groups = self.net.groups
        return groups[1:-1]

    def connection_layers(self):
        """Collects the connection layers of the wrapped network.

        A module counts as a connection layer when it is an instance of one
        of the types in ``self.supported_layers``. Matching modules are
        returned in depth-first order and are not searched any further.
        """
        layer_types = self.supported_layers

        def iter_connection_layers(module):
            """Yields the connection layers found below ``module``."""
            for sub_module in module.children():
                if isinstance(sub_module, layer_types):
                    yield sub_module
                elif list(sub_module.children()):
                    # Only containers (modules with children of their own)
                    # are descended into; unsupported leaves are skipped.
                    yield from iter_connection_layers(sub_module)

        return list(iter_connection_layers(self.__net__()))