import sys
import math
# from https://github.com/sungyubkim/GBML/blob/master/net/resnet.py
import torch.nn as nn
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small classifier: a two-convolution encoder followed by a linear head.

    Args:
        in_channels: Channels of the input images; the first conv keeps this
            channel count.
        out_channels: Channels produced by the encoder.
        out_features: Output size of the linear decoder (e.g. number of classes).
        size: Side length of the square input images; inputs are expected as
            ``(batch, in_channels, size, size)`` so the flattened encoder output
            matches the decoder.

    Raises:
        ValueError: if ``size`` is too small to pass through the encoder.
    """

    def __init__(self, in_channels, out_channels, out_features, size):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.out_features = out_features
        self.size = size

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=4, stride=4, padding=1),
            nn.BatchNorm2d(self.in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=2, stride=2, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
        )
        # Derive the flattened feature count from the real conv layers instead of
        # a closed-form guess, so the decoder's in_features matches what
        # forward() feeds it for every input size.
        side = self._encoded_side(self.size)
        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.out_channels * side * side, out_features=self.out_features),
        )
        self.init_params()

    def _encoded_side(self, size):
        """Return the spatial side length of the encoder output for a ``size x size`` input.

        Raises:
            ValueError: if any conv layer would produce an empty feature map.
        """
        side = size
        for layer in self.encoder:
            if isinstance(layer, nn.Conv2d):
                # Conv2d output-size formula along the height axis; the width is
                # identical because inputs are square.
                side = (
                    side
                    + 2 * layer.padding[0]
                    - layer.dilation[0] * (layer.kernel_size[0] - 1)
                    - 1
                ) // layer.stride[0] + 1
                # Check after every conv: a later padded conv could otherwise
                # turn an empty map back into a positive size.
                if side < 1:
                    raise ValueError(f"input size {size} is too small for the encoder")
        return side

    def init_params(self):
        """(Re-)initialize the convolution and batch-norm parameters.

        Conv weights get Kaiming-uniform init and zero bias. BatchNorm weights
        are set to 1 and biases to 0. The linear decoder keeps PyTorch's
        default initialization.
        """
        # Dispatch on module type: nn.Sequential names parameters by index
        # (e.g. 'encoder.0.weight'), so substring checks for 'conv'/'bn' in
        # parameter names would never match.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.BatchNorm2d):
                if module.weight is not None:
                    nn.init.constant_(module.weight, 1.0)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)

    def forward(self, x, is_pre=True):
        """Encode ``x`` and map the flattened features through the linear decoder.

        Args:
            x: Input batch of shape ``(batch, in_channels, size, size)``.
            is_pre: Unused; kept so existing call sites keep working.

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        out = self.encoder(x)
        # Flatten everything except the batch dimension for the linear head.
        out = self.decoder(out.reshape(out.shape[0], -1))
        return out


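# Number of model parameters for the current two-conv encoder + linear decoder:

# For the fashionmnist and mnist datasets, each image has 1*28*28 = 784 dimensions (size = 28).
# If in_channels = 1, out_channels = 1, out_features = 10, number of parameters is 196.
# If in_channels = 1, out_channels = 2, out_features = 10, number of parameters is 363.
# (The previously listed figures, 167 and 324, correspond to an earlier single-conv
# encoder with kernel_size=8, stride=8, padding=1.)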

# For the cifar-10 dataset, each image has 3*32*32 = 3072 dimensions (size = 32).
# If in_channels = 3, out_channels = 3, out_features = 10, number of parameters is 958.
# If in_channels = 3, out_channels = 4, out_features = 10, number of parameters is 1223.

# For imagenet downsampled to 32x32, each image has 3*32*32 = 3072 dimensions (size = 32),
# so the counts are the same as for cifar-10:
# If in_channels = 3, out_channels = 3, out_features = 10, number of parameters is 958.
# If in_channels = 3, out_channels = 4, out_features = 10, number of parameters is 1223.