import torch
import torch.nn as nn
from torchvision.models import resnet
from collections.abc import Sequence

from utils.functions import check_basic_block_structure, check_bottleneck_structure


class SequentialModel(nn.Module):
    def __init__(self):
        super(SequentialModel, self).__init__()
        self.layers = []
        self.dims = [[3, 32, 32], 10]
        # IMPORTANT: include this line at the end of init of each daughter class
        # self.layer_input_shapes = self.get_layer_input_shapes()

    def forward(self, x):
        return self.layers(x)

    def forward_up_to_k_layer(self, x, k):
        out = x
        for i, layer in enumerate(self.layers):
            if i >= k:
                return out
            out = layer(out)
        return out

    def get_layer_input_shapes(self):
        """This function calculates the input shapes for each layer.
        This is required to compute the proper upper bound for CNN and ResNet layers.

        Returns
        -------
            The list of input shapes for each layer in `layers` and the shape of the output layer.
            Use `input_shapes[i]` to get the input shape for layer with index `i`.
            If the layer itself is a Sequential layer or a ResNet BasicBlock, it will have nested input shapes.
        """
        return SequentialModel._get_layer_input_shapes(self.layers, self.dims[0])[0]

    def _get_layer_input_shapes(layers: Sequence, inp_shape: list[int] | int) -> list:
        # check for 1D or nD input
        if isinstance(inp_shape, int):
            input_shape = [inp_shape]
        else:
            input_shape = inp_shape

        # here batchsize = 1
        t = torch.zeros([1] + input_shape)
        input_shapes = []
        last_shape = input_shape

        for layer in layers:
            if isinstance(layer, nn.Sequential):
                res, last_shape = SequentialModel._get_layer_input_shapes(layer, last_shape)
                input_shapes.append(res)
                t = layer(t)
            elif isinstance(layer, resnet.BasicBlock) or isinstance(layer, resnet.Bottleneck):
                # treat the resnet block case separately
                if isinstance(layer, resnet.BasicBlock):
                    assert check_basic_block_structure(layer)
                else:
                    assert check_bottleneck_structure(layer)

                # do it manually here
                # input of prev layer
                resnet_shapes = [last_shape]

                t = layer.conv1(t)
                last_shape = list(t.shape)[1:]

                # input to conv2
                resnet_shapes.append(last_shape)

                # skip batch norms and relu layers as they do not change the shape
                t = layer.conv2(t)
                last_shape = list(t.shape)[1:]

                # add 2nd convolution output for the bottleneck module case
                if isinstance(layer, resnet.Bottleneck):
                    # input to conv3
                    resnet_shapes.append(last_shape)

                    t = layer.conv3(t)
                    last_shape = list(t.shape)[1:]

                # the downsample layer will have the last_shape input shape
                # no need to propogate further as there is no change in shape in other layers
                input_shapes.append(resnet_shapes)
            else:
                # any default layer
                input_shapes.append(last_shape)
                t = layer(t)
                last_shape = list(t.shape)[1:]

        return input_shapes, last_shape
