import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Layers.layers import conv3x3_block, Conv2d, BatchNorm2d, Linear, sparse_initialize, ConvBlock, conv1x1


class fc(nn.Module):
    """Bias-free fully connected network with geometrically shrinking widths.

    The input is flattened, projected to ``N`` features, then passed through
    ``L - 2`` further hidden layers whose widths are ``int(N * 0.8**k)``,
    and finally mapped to ``num_classes`` outputs. The ``nonlinearity``
    module follows every layer except the final classifier.

    Args:
        input_shape: Shape of one input sample; its product is the number of
            features seen by the first layer.
        num_classes: Number of outputs of the final linear layer.
        L: Total number of linear layers (first layer + hidden + classifier).
        N: Width of the first hidden layer.
        nonlinearity: Activation module inserted after each non-final layer.
            The same instance is reused at every position.
    """

    def __init__(self, input_shape, num_classes, L=6, N=2000, nonlinearity=nn.ReLU()):
        super().__init__()

        def width_at(depth):
            # Width after `depth` shrink steps of factor 0.8, truncated to int.
            return int(N * (0.8**depth))

        flat_features = np.prod(input_shape)

        # Feature extractor: flatten, first projection, then shrinking layers.
        layers = [nn.Flatten(), Linear(flat_features, N, bias=False), nonlinearity]
        for depth in range(L - 2):
            layers += [
                Linear(width_at(depth), width_at(depth + 1), bias=False),
                nonlinearity,
            ]

        # Classifier head on top of the narrowest hidden layer.
        layers.append(Linear(width_at(L - 2), num_classes, bias=False))

        self.model = nn.Sequential(*layers)
        self._initialize_weights(False)

    def forward(self, input):
        """Run ``input`` through the sequential model and return its output."""
        return self.model(input)

    @torch.no_grad()
    def _initialize_weights(self, sparse_init):
        """Set every ``Linear`` weight to 1 plus Gaussian noise scaled by 1e-6.

        ``sparse_init`` is accepted but not used by this method.
        """
        linear_layers = (m for m in self.modules() if isinstance(m, Linear))
        for layer in linear_layers:
            layer.weight.fill_(1)
            layer.weight.add_(torch.randn_like(layer.weight) * 1e-6)


class conv(nn.Module):
    """Six 3x3 convolutions followed by a single linear classifier.

    Channel widths go ``channels -> N -> N -> 2N -> 2N -> 4N -> 4N`` and the
    ``nonlinearity`` module follows every convolution. The result is
    flattened and mapped to ``num_classes`` outputs.

    The classifier's input size is derived from ``input_shape`` and ``N``
    instead of being fixed; for the defaults (``N=32`` on a 3x32x32 input)
    it equals 51200.

    Args:
        input_shape: ``(channels, width, height)`` of one input sample.
        num_classes: Number of outputs of the final linear layer.
        L: Number of convolutional layers; only 6 is supported.
        N: Base channel count.
        nonlinearity: Activation module inserted after each convolution.
            The same instance is reused at every position.

    Raises:
        ValueError: If a spatial dimension is too small to survive six
            unpadded 3x3 convolutions.
    """

    def __init__(self, input_shape, num_classes, L=6, N=32, nonlinearity=nn.ReLU()):
        super(conv, self).__init__()
        channels, width, height = input_shape

        assert L == 6
        channel_plan = [
            (channels, N),
            (N, N),
            (N, 2 * N),
            (2 * N, 2 * N),
            (2 * N, 4 * N),
            (4 * N, 4 * N),
        ]

        # Assumes Conv2d keeps nn.Conv2d's defaults (stride 1, no padding), so
        # each 3x3 layer trims 2 pixels per spatial dimension. This matches the
        # 51200 (= 128 * 20 * 20) features of the default configuration.
        shrink = 2 * len(channel_plan)
        out_width = width - shrink
        out_height = height - shrink
        if out_width <= 0 or out_height <= 0:
            raise ValueError(
                "input spatial size {}x{} is too small for {} unpadded 3x3 "
                "convolutions".format(width, height, len(channel_plan))
            )

        modules = []
        for in_channels, out_channels in channel_plan:
            modules.append(Conv2d(in_channels, out_channels, kernel_size=3))
            modules.append(nonlinearity)

        # Linear classifier
        modules.append(nn.Flatten())
        modules.append(Linear(4 * N * out_width * out_height, num_classes))
        self.model = nn.Sequential(*modules)
        self._initialize_weights(False)

    def forward(self, input):
        """Run ``input`` through the sequential model and return its output."""
        return self.model(input)

    @torch.no_grad()
    def _initialize_weights(self, sparse_init):
        """Reset every ``Conv2d``: zero bias, weights 1 plus 1e-6-scaled noise.

        Only ``Conv2d`` modules are touched; the linear classifier keeps
        whatever initialization its constructor gave it. ``sparse_init`` is
        accepted but not used by this method.
        """
        for module in self.modules():
            if isinstance(module, Conv2d):
                module.bias.fill_(0)
                module.weight.fill_(1)
                module.weight += torch.randn_like(module.weight) * 1e-6


def mlp_fc6(groups=1, width_factor=1, **kwargs):
    """Build the ``fc`` model for 3x32x32 inputs and 10 output classes.

    Only ``groups == 1`` and ``width_factor == 1`` are accepted (checked via
    ``assert``). Any additional keyword arguments are ignored.
    """
    assert groups == 1
    assert width_factor == 1
    model = fc(input_shape=(3, 32, 32), num_classes=10)
    return model


def mlp_conv6(groups=1, width_factor=1, **kwargs):
    """Build the ``conv`` model for 3x32x32 inputs and 10 output classes.

    Only ``groups == 1`` and ``width_factor == 1`` are accepted (checked via
    ``assert``). Any additional keyword arguments are ignored.
    """
    assert groups == 1
    assert width_factor == 1
    model = conv(input_shape=(3, 32, 32), num_classes=10)
    return model


if __name__ == "__main__":
    # Smoke test: build each model defined in this module, run one random
    # 3x32x32 sample through it, and backpropagate from the summed output.
    # The previous version called `mlp_L6`, which is not defined here.
    for build_model in (mlp_fc6, mlp_conv6):
        model = build_model()
        x = torch.randn(1, 3, 32, 32)
        y = model(x)
        y.sum().backward()
        print(y)
        print(model)