"""recur_cnn.py
Recurrent CNN models.
"""
import torch
import torch.nn as nn


class RCNN(nn.Module):
    """Convolutional classifier whose middle block is applied recurrently.

    ``first_layers`` runs once, then the weight-shared ``recur_block`` is
    applied ``depth - 3`` times.  After every application the current feature
    map goes through ``last_layers`` and ``linear`` to produce one set of
    logits (a "thought").  The logits of the final application are the model
    output; all of them are kept in ``self.thoughts``.

    Args:
        width: Channel count of the recurrent block.  The first convolution
            uses ``width // 2`` channels and the last one ``2 * width``.
        depth: Nominal depth; the recurrent block runs ``depth - 3`` times, so
            it must be at least 4.
        in_channels: Number of channels of the input images.
        num_classes: Length of the logit vector.
        dataset: Selects the input size of the linear head.  ``"CIFAR10"``
            (case-insensitive) gives ``8 * width`` features, i.e. a 2x2 final
            feature map (e.g. 32x32 inputs); any other value gives
            ``72 * width`` features, i.e. a 6x6 map (e.g. 64x64 inputs).

    Raises:
        ValueError: If ``depth`` is smaller than 4.
    """

    def __init__(self, width=64, depth=4, in_channels=3, num_classes=10, dataset="CIFAR10"):
        super().__init__()
        # Fail at construction time: with fewer than one recurrent iteration
        # forward() has no logits to return.
        if depth < 4:
            raise ValueError(f"depth must be at least 4 (got {depth}); "
                             "the recurrent block runs depth - 3 times")
        self.dataset = dataset
        self.width = width
        self.depth = depth
        self.iters = depth - 3
        self.num_classes = num_classes
        self.in_channels = in_channels

        half_width = width // 2
        # Two unpadded 3x3 convolutions: each trims 2 pixels per spatial dim.
        self.first_layers = nn.Sequential(
            nn.Conv2d(in_channels, half_width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(half_width, width, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Padded 3x3 convolution keeps the spatial size, so it can be re-applied.
        self.recur_block = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.last_layers = nn.Sequential(
            nn.MaxPool2d(3),
            nn.Conv2d(width, 2 * width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3),
        )

        head_features = 8 * width if dataset.upper() == "CIFAR10" else 72 * width
        self.linear = nn.Linear(head_features, num_classes)

    def forward(self, x):
        """Return the logits produced after the last recurrent iteration.

        Side effect: ``self.thoughts`` is set to a tensor of shape
        ``(iters, batch, num_classes)`` holding the logits of every iteration.
        """
        out = self.first_layers(x)
        thoughts = []
        for _ in range(self.iters):
            out = self.recur_block(out)
            features = self.last_layers(out)
            # flatten copies when needed, whereas view raises on layouts that
            # cannot be expressed as a view.
            thoughts.append(self.linear(features.flatten(start_dim=1)))
        # Stacking keeps the dtype and device of the logits; writing them into
        # a pre-allocated default-dtype buffer would silently cast them.
        self.thoughts = torch.stack(thoughts)
        return self.thoughts[-1]

class RCNN_MNIST(nn.Module):
    """Recurrent CNN variant with a single-convolution stem.

    ``first_layers`` (one unpadded 3x3 convolution) runs once, then the
    weight-shared ``recur_block`` is applied ``depth - 3`` times.  After every
    application the feature map goes through ``last_layers`` and ``linear`` to
    produce one set of logits (a "thought").  The logits of the final
    application are the model output; all of them are kept in
    ``self.thoughts``.

    Args:
        width: Channel count of the stem and the recurrent block; the last
            convolution uses ``2 * width`` channels.
        depth: Nominal depth; the recurrent block runs ``depth - 3`` times, so
            it must be at least 4.
        in_channels: Number of channels of the input images.
        num_classes: Length of the logit vector.
        dataset: Stored on the instance only; it does not affect the
            architecture.  The linear head always takes ``8 * width``
            features, i.e. a 2x2 final feature map (e.g. 28x28 inputs).

    Raises:
        ValueError: If ``depth`` is smaller than 4.
    """

    def __init__(self, width=16, depth=4, in_channels=1, num_classes=10, dataset="CIFAR10"):
        super().__init__()
        # Fail at construction time: with fewer than one recurrent iteration
        # forward() has no logits to return.
        if depth < 4:
            raise ValueError(f"depth must be at least 4 (got {depth}); "
                             "the recurrent block runs depth - 3 times")
        self.dataset = dataset
        self.width = width
        self.depth = depth
        self.iters = depth - 3
        self.num_classes = num_classes
        self.in_channels = in_channels

        self.first_layers = nn.Sequential(
            nn.Conv2d(in_channels, int(width), kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Padded 3x3 convolution keeps the spatial size, so it can be re-applied.
        self.recur_block = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.last_layers = nn.Sequential(
            nn.MaxPool2d(3),
            nn.Conv2d(width, 2 * width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3),
        )

        self.linear = nn.Linear(8 * width, num_classes)

    def forward(self, x):
        """Return the logits produced after the last recurrent iteration.

        Side effect: ``self.thoughts`` is set to a tensor of shape
        ``(iters, batch, num_classes)`` holding the logits of every iteration.
        """
        out = self.first_layers(x)
        thoughts = []
        for _ in range(self.iters):
            out = self.recur_block(out)
            features = self.last_layers(out)
            # flatten copies when needed, whereas view raises on layouts that
            # cannot be expressed as a view.
            thoughts.append(self.linear(features.flatten(start_dim=1)))
        # Stacking keeps the dtype and device of the logits; writing them into
        # a pre-allocated default-dtype buffer would silently cast them.
        self.thoughts = torch.stack(thoughts)
        return self.thoughts[-1]

class RCNN_EMNIST(nn.Module):
    """Recurrent CNN variant with a two-convolution stem and an ``8 * width`` head.

    ``first_layers`` runs once, then the weight-shared ``recur_block`` is
    applied ``depth - 3`` times.  After every application the feature map goes
    through ``last_layers`` and ``linear`` to produce one set of logits (a
    "thought").  The logits of the final application are the model output; all
    of them are kept in ``self.thoughts``.

    Args:
        width: Channel count of the recurrent block.  The first convolution
            uses ``width // 2`` channels and the last one ``2 * width``.
        depth: Nominal depth; the recurrent block runs ``depth - 3`` times, so
            it must be at least 4.
        in_channels: Number of channels of the input images.
        num_classes: Length of the logit vector.
        dataset: Stored on the instance only; it does not affect the
            architecture.  The linear head always takes ``8 * width``
            features, i.e. a 2x2 final feature map (e.g. 28x28 inputs).

    Raises:
        ValueError: If ``depth`` is smaller than 4.
    """

    def __init__(self, width=128, depth=4, in_channels=1, num_classes=10, dataset="CIFAR10"):
        super().__init__()
        # Fail at construction time: with fewer than one recurrent iteration
        # forward() has no logits to return.
        if depth < 4:
            raise ValueError(f"depth must be at least 4 (got {depth}); "
                             "the recurrent block runs depth - 3 times")
        self.dataset = dataset
        self.width = width
        self.depth = depth
        self.iters = depth - 3
        self.num_classes = num_classes
        self.in_channels = in_channels

        half_width = width // 2
        # Two unpadded 3x3 convolutions: each trims 2 pixels per spatial dim.
        self.first_layers = nn.Sequential(
            nn.Conv2d(in_channels, half_width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(half_width, width, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Padded 3x3 convolution keeps the spatial size, so it can be re-applied.
        self.recur_block = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )
        self.last_layers = nn.Sequential(
            nn.MaxPool2d(3),
            nn.Conv2d(width, 2 * width, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3),
        )

        self.linear = nn.Linear(8 * width, num_classes)

    def forward(self, x):
        """Return the logits produced after the last recurrent iteration.

        Side effect: ``self.thoughts`` is set to a tensor of shape
        ``(iters, batch, num_classes)`` holding the logits of every iteration.
        """
        out = self.first_layers(x)
        thoughts = []
        for _ in range(self.iters):
            out = self.recur_block(out)
            features = self.last_layers(out)
            # flatten copies when needed, whereas view raises on layouts that
            # cannot be expressed as a view.
            thoughts.append(self.linear(features.flatten(start_dim=1)))
        # Stacking keeps the dtype and device of the logits; writing them into
        # a pre-allocated default-dtype buffer would silently cast them.
        self.thoughts = torch.stack(thoughts)
        return self.thoughts[-1]

def recur_cnn_4(num_outputs=10):
    """Build an ``RCNN`` with ``depth=4`` and ``num_outputs`` output classes."""
    model = RCNN(depth=4, num_classes=num_outputs)
    return model


def recur_cnn_5(num_outputs=10):
    """Build an ``RCNN`` with ``depth=5`` and ``num_outputs`` output classes."""
    model = RCNN(depth=5, num_classes=num_outputs)
    return model


def recur_cnn_6(num_outputs=10):
    """Build an ``RCNN`` with ``depth=6`` and ``num_outputs`` output classes."""
    model = RCNN(depth=6, num_classes=num_outputs)
    return model


def recur_cnn_7(num_outputs=10):
    """Build an ``RCNN`` with ``depth=7`` and ``num_outputs`` output classes."""
    model = RCNN(depth=7, num_classes=num_outputs)
    return model


def recur_cnn_8(num_outputs=10):
    """Build an ``RCNN`` with ``depth=8`` and ``num_outputs`` output classes."""
    model = RCNN(depth=8, num_classes=num_outputs)
    return model

def recur_cnn_4_mnist(num_outputs=10):
    """Build an ``RCNN_MNIST`` with ``depth=4`` and ``num_outputs`` output classes."""
    model = RCNN_MNIST(depth=4, num_classes=num_outputs)
    return model


def recur_cnn_5_mnist(num_outputs=10):
    """Build an ``RCNN_MNIST`` with ``depth=5`` and ``num_outputs`` output classes."""
    model = RCNN_MNIST(depth=5, num_classes=num_outputs)
    return model


def recur_cnn_6_mnist(num_outputs=10):
    """Build an ``RCNN_MNIST`` with ``depth=6`` and ``num_outputs`` output classes."""
    model = RCNN_MNIST(depth=6, num_classes=num_outputs)
    return model


def recur_cnn_7_mnist(num_outputs=10):
    """Build an ``RCNN_MNIST`` with ``depth=7`` and ``num_outputs`` output classes."""
    model = RCNN_MNIST(depth=7, num_classes=num_outputs)
    return model


def recur_cnn_8_mnist(num_outputs=10):
    """Build an ``RCNN_MNIST`` with ``depth=8`` and ``num_outputs`` output classes."""
    model = RCNN_MNIST(depth=8, num_classes=num_outputs)
    return model


def recur_cnn_4_emnist(num_outputs=47):
    """Build an ``RCNN_EMNIST`` with ``depth=4`` and ``num_outputs`` output classes."""
    model = RCNN_EMNIST(depth=4, num_classes=num_outputs)
    return model


def recur_cnn_5_emnist(num_outputs=47):
    """Build an ``RCNN_EMNIST`` with ``depth=5`` and ``num_outputs`` output classes."""
    model = RCNN_EMNIST(depth=5, num_classes=num_outputs)
    return model


def recur_cnn_6_emnist(num_outputs=47):
    """Build an ``RCNN_EMNIST`` with ``depth=6`` and ``num_outputs`` output classes."""
    model = RCNN_EMNIST(depth=6, num_classes=num_outputs)
    return model


def recur_cnn_7_emnist(num_outputs=47):
    """Build an ``RCNN_EMNIST`` with ``depth=7`` and ``num_outputs`` output classes."""
    model = RCNN_EMNIST(depth=7, num_classes=num_outputs)
    return model


def recur_cnn_8_emnist(num_outputs=47):
    """Build an ``RCNN_EMNIST`` with ``depth=8`` and ``num_outputs`` output classes."""
    model = RCNN_EMNIST(depth=8, num_classes=num_outputs)
    return model


def recur_cnn_5_tinyimagenet(num_outputs=200):
    """Build an ``RCNN`` with ``depth=5`` and ``dataset="TINYIMAGENET"``."""
    model = RCNN(dataset="TINYIMAGENET", depth=5, num_classes=num_outputs)
    return model


def recur_cnn_6_tinyimagenet(num_outputs=200):
    """Build an ``RCNN`` with ``depth=6`` and ``dataset="TINYIMAGENET"``."""
    model = RCNN(dataset="TINYIMAGENET", depth=6, num_classes=num_outputs)
    return model


def recur_cnn_7_tinyimagenet(num_outputs=200):
    """Build an ``RCNN`` with ``depth=7`` and ``dataset="TINYIMAGENET"``."""
    model = RCNN(dataset="TINYIMAGENET", depth=7, num_classes=num_outputs)
    return model


def recur_cnn_8_tinyimagenet(num_outputs=200):
    """Build an ``RCNN`` with ``depth=8`` and ``dataset="TINYIMAGENET"``."""
    model = RCNN(dataset="TINYIMAGENET", depth=8, num_classes=num_outputs)
    return model
