import torch

import model
import modules


class SimpleCnn(model.Model):
    """Small residual CNN for square EXPECTED_DIM x EXPECTED_DIM inputs.

    The feature extractor is a bias-enabled stem convolution followed by one
    residual block per consecutive pair of entries in get_channels(); a
    block uses stride 2 exactly when it increases the channel count. The
    final feature map is average-pooled over its full extent and passed
    through two HIDDEN-wide Linear + ReLU + Dropout layers.
    create_classifier() supplies the output layer on top of those features.
    """

    NAME = "simple-cnn"
    # Required spatial size: both input spatial dimensions must equal this.
    EXPECTED_DIM = 32
    # Output channels of the stem block, then of each residual block in turn.
    CHANNELS = [16, 16, 32, 64]
    # Kernel size of the main-path convolutions (padding is KERNEL // 2).
    KERNEL = 3
    # Width of the two fully connected hidden layers.
    HIDDEN = 256

    def create_layers(self, input_size):
        """Return list of torch.nn.Module objects, layers of this network.

        Parameters:
        ===========
        input_size: tuple of int dimensions of input, (channels, height,
            width). Both spatial dimensions must equal EXPECTED_DIM, so
            their relative order does not matter.

        Raises:
        =======
        ValueError: if a spatial dimension differs from EXPECTED_DIM.
        """
        in_channels, height, width = input_size
        # Raise explicitly instead of asserting so the check survives -O.
        if not height == width == SimpleCnn.EXPECTED_DIM:
            raise ValueError(
                "{} expects {}x{} inputs, got {}x{}".format(
                    SimpleCnn.NAME, SimpleCnn.EXPECTED_DIM,
                    SimpleCnn.EXPECTED_DIM, height, width
                )
            )

        channels = self.get_channels()
        hidden = self.get_hidden_size()

        # Modules are created in network order (stem, residual blocks, head),
        # which is also the order their parameters get randomly initialised.
        layers = [
            self._create_block(
                in_c=in_channels,
                out_c=channels[0],
                bias=True
            )
        ]

        # Track the spatial size through the residual stack rather than
        # hard-coding it, so an overridden get_channels() with a different
        # downsampling schedule still yields a consistent head.
        dim = height
        for in_c, out_c in zip(channels[:-1], channels[1:]):
            # Downsample exactly where the channel count grows.
            stride = 2 if out_c > in_c else 1
            layers.append(
                self._create_resblock(in_c=in_c, out_c=out_c, stride=stride)
            )
            # A conv with odd kernel k and padding k // 2 maps dim to
            # ceil(dim / stride); the stem block (stride 1) keeps dim as is.
            dim = (dim - 1) // stride + 1

        layers += [
            modules.AssertShape(channels[-1], dim, dim),
            torch.nn.BatchNorm2d(channels[-1]),
            torch.nn.ReLU(),
            # The pooling window spans the whole remaining feature map.
            torch.nn.AvgPool2d(dim),
            # presumably flattens the pooled map to (batch, channels[-1]);
            # confirm against modules.Reshape.
            modules.Reshape(channels[-1]),
            torch.nn.Linear(channels[-1], hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5)
        ]
        return layers

    def get_channels(self):
        """Return the channel schedule (stem output, then each res block)."""
        return SimpleCnn.CHANNELS

    def get_hidden_size(self):
        """Return the width of the fully connected hidden layers."""
        return SimpleCnn.HIDDEN

    def create_classifier(self, targets):
        """Return the classifier module of this network.

        Parameters:
        ===========
        targets: int number of classes to predict.
        """
        return torch.nn.Linear(self.get_hidden_size(), targets)

    def _create_block(self, in_c, out_c, activation=None, stride=1, bias=False):
        """Return a basic convolutional block: Conv2d -> BatchNorm2d -> act.

        Parameters:
        ===========
        in_c: int number of input channels
        out_c: int number of output channels
        activation: optional torch.nn.Module applied after batch norm; when
            None, an empty torch.nn.Sequential (identity) is used instead.
        stride: int stride of the convolution.
        bias: bool whether the convolution has a bias term. Defaults to
            False, presumably because the following BatchNorm2d (affine by
            default) provides its own learnable shift.
        """
        # Compare against None explicitly: Module containers such as
        # Sequential define __len__, so a truthiness test is unreliable.
        if activation is None:
            activation = torch.nn.Sequential()
        return torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=in_c,
                out_channels=out_c,
                kernel_size=SimpleCnn.KERNEL,
                padding=SimpleCnn.KERNEL//2,
                stride=stride,
                bias=bias
            ),
            torch.nn.BatchNorm2d(out_c),
            activation
        )

    def _create_resblock(self, in_c, out_c, stride):
        """Return a residual block mapping in_c channels to out_c channels.

        Parameters:
        ===========
        in_c: int number of input channels
        out_c: int number of output channels
        stride: int stride of the first convolution (and of the projection
            shortcut, when one is needed).

        The shortcut is an identity when the residual branch preserves both
        channel count and spatial size, and a bias-free strided 3x3
        convolution projection otherwise. How the pieces are combined is
        defined by modules.ResBlock.
        """
        pre_norm = torch.nn.BatchNorm2d(in_c)
        first = self._create_block(
            in_c=in_c,
            out_c=out_c,
            activation=torch.nn.ReLU(),
            stride=stride
        )
        second = self._create_block(
            in_c=out_c,
            out_c=out_c
        )
        if stride == 1 and in_c == out_c:
            shortcut = torch.nn.Sequential()
        else:
            # A projection is required whenever the branch changes channels
            # or spatial size; an identity shortcut would not line up. Built
            # only when needed, so no discarded module is allocated.
            shortcut = torch.nn.Conv2d(
                in_channels=in_c,
                out_channels=out_c,
                kernel_size=3,
                padding=1,
                stride=stride,
                bias=False
            )
        return modules.ResBlock(
            pre_norm,
            first,
            second,
            shortcut=shortcut
        )
