import numpy as np
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Four-conv / one-dense convolutional classifier that also exposes its
    intermediate activations.

    Architecture (every convolution is 3x3, stride 1, no padding)::

        conv(32)-ReLU -> conv(32)-ReLU -> maxpool 2x2 -> BatchNorm2d
        conv(64)-ReLU -> conv(64)-ReLU -> maxpool 2x2 -> BatchNorm2d
        flatten -> linear(512)-BatchNorm1d-ReLU -> linear(num_classes)

    Args:
        input_channels: number of channels of the input images.
        num_classes: number of output logits.
        input_size: spatial size of the input images, either an int (square
            inputs) or a ``(height, width)`` pair. The default of 32 yields
            the 5x5 final feature map the dense layer was originally
            hard-coded for.

    Raises:
        ValueError: if ``input_size`` is too small to pass through the two
            (conv, conv, maxpool) stages.
    """

    def __init__(self, input_channels, num_classes, input_size=32):
        super().__init__()
        # model params
        self.input_channels = input_channels
        self.num_classes = num_classes
        # Attribute keeps its original spelling ("kernal") so existing
        # external readers of it keep working.
        self.kernal_size = 3

        try:
            in_h, in_w = input_size
        except TypeError:  # a single number -> square input
            in_h = in_w = input_size
        # Spatial extent of the last feature map, i.e. what is flattened into
        # the dense layer (5x5 for 32x32 inputs).
        self.conv2_H = self._feature_map_size(in_h, self.kernal_size)
        self.conv2_W = self._feature_map_size(in_w, self.kernal_size)

        # Channel widths: input, 4 conv layers, 1 dense layer, output.
        self.featuremap_layers_size = [
            input_channels, 32, 32, 64, 64, 512, self.num_classes,
        ]
        # NOTE(review): what these numbers mean is not evident from this class,
        # and the print looks like leftover debug output -- confirm with the
        # callers before changing or removing either.
        self.noParams = [15, 140, 250, 500, 2000, 100]
        print(self.noParams)

        sizes = self.featuremap_layers_size
        k = self.kernal_size
        # model arch
        self.feature_1 = nn.Sequential(
            nn.Conv2d(sizes[0], sizes[1], kernel_size=k, stride=1, bias=True),
            nn.ReLU(inplace=True),
        )
        self.feature_2 = nn.Sequential(
            nn.Conv2d(sizes[1], sizes[2], kernel_size=k, stride=1, bias=True),
            nn.ReLU(inplace=True),
        )
        self.maxpool_BN_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            # nn.Dropout(0.25),
            nn.BatchNorm2d(sizes[2]),
        )
        self.feature_3 = nn.Sequential(
            nn.Conv2d(sizes[2], sizes[3], kernel_size=k, stride=1, bias=True),
            nn.ReLU(inplace=True),
        )
        self.feature_4 = nn.Sequential(
            nn.Conv2d(sizes[3], sizes[4], kernel_size=k, stride=1, bias=True),
            nn.ReLU(inplace=True),
        )
        self.maxpool_BN_4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(sizes[4]),
            # nn.Dropout(0.25),
        )
        self.linear_layer = nn.Sequential(
            nn.Linear(self.conv2_W * self.conv2_H * sizes[4], sizes[5]),
            nn.BatchNorm1d(sizes[5]),
            nn.ReLU(inplace=True),
            # nn.Dropout(0.5),
        )
        self.classifier = nn.Linear(sizes[5], self.num_classes)

        # Names of every parameter selected by take_layer() (the conv and
        # linear weights), followed by the name of the very last parameter
        # (classifier.bias). The final two entries are what last_layer()
        # compares against.
        named = list(self.named_parameters())
        self.layers_names = [n for n, p in named if self.take_layer(n, p)]
        self.layers_names.append(named[-1][0])

    @staticmethod
    def _feature_map_size(size, kernel_size):
        """Return the spatial extent left after two (conv, conv, maxpool) stages.

        Each unpadded stride-1 conv shrinks the extent by ``kernel_size - 1``;
        each 2x2 / stride-2 max-pool floors it by half and needs at least 2.
        """
        size = int(size)
        for _ in range(2):
            size -= 2 * (kernel_size - 1)  # two convolutions
            if size < 2:
                raise ValueError(
                    "input_size is too small for this architecture"
                )
            size //= 2  # max-pool
        return size

    def forward(self, x):
        """Run the network and return the logits plus intermediate activations.

        Args:
            x: image batch shaped (N, input_channels, H, W); (H, W) must match
                the ``input_size`` given at construction.

        Returns:
            Tuple ``(logits, f_x1, f_x2, f_x3, f_x4, L_x1)``: the classifier
            output, the four post-ReLU conv activations (taken before any
            pooling / BatchNorm), and the post-ReLU dense-layer activation.
        """
        f_x1 = self.feature_1(x)
        f_x2 = self.feature_2(f_x1)
        M_x2 = self.maxpool_BN_2(f_x2)
        f_x3 = self.feature_3(M_x2)
        f_x4 = self.feature_4(f_x3)
        M_x4 = self.maxpool_BN_4(f_x4)
        # Same result as view(N, -1) here, but flatten(1) does not require a
        # contiguous tensor and also copes with an empty batch.
        flat = M_x4.flatten(1)
        L_x1 = self.linear_layer(flat)
        logits = self.classifier(L_x1)
        return logits, f_x1, f_x2, f_x3, f_x4, L_x1

    def take_layer(self, name, param):
        """Return True for parameters with more than one dimension.

        In this model those are the conv and linear weights; biases and
        BatchNorm affine parameters are 1-D and are skipped. ``name`` is
        unused but kept for interface compatibility.
        """
        return len(param.shape) > 1

    def last_layer(self, name):
        """Return True if ``name`` refers to the final classifier.

        Uses substring containment against the last two entries of
        ``layers_names`` (``classifier.weight`` and ``classifier.bias``), so
        a fragment such as ``"classifier"`` matches as well.
        """
        return name in self.layers_names[-1] or name in self.layers_names[-2]

    def flatten_layer(self, name):
        """Return True if ``name`` belongs to the dense layer that consumes
        the flattened conv features (``linear_layer.0``)."""
        return "linear_layer.0" in name
        