import torch as t
import torch.nn as nn
import numpy as np


class CNNblock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU -> Dropout, optionally followed by 2x2 average pooling.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels produced by the convolution.
        filter_size: square convolution kernel size. The convolution uses stride 1
            and PyTorch's default zero padding, so each spatial side shrinks by
            ``filter_size - 1``.
        pl: when exactly equal to ``True``, append ``nn.AvgPool2d(2, 2)``, which
            halves each spatial side (floor division).

    Input must be 4-D ``(batch, input_dim, H, W)`` because ``BatchNorm2d`` requires it.
    """

    def __init__(self, input_dim, output_dim, filter_size, pl=False):
        super().__init__()
        self.pl = pl
        use_pool = self.pl == True  # noqa: E712 -- equality (not truthiness) is the existing contract

        # Registration order is kept as conv1, bn, drop, [pool], relu so that
        # named_children()/repr/state_dict ordering stays the same.
        self.conv1 = nn.Conv2d(input_dim, output_dim, filter_size, 1)
        self.bn = nn.BatchNorm2d(output_dim)
        self.drop = nn.Dropout()  # PyTorch default drop probability p=0.5
        if use_pool:
            self.pool = nn.AvgPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> bn -> relu -> dropout, then pooling if enabled."""
        activated = self.relu(self.bn(self.conv1(x)))
        regularized = self.drop(activated)
        return self.pool(regularized) if self.pl == True else regularized  # noqa: E712

class CNN_0(nn.Module):
    """Three-block CNN classifier (8 -> 32 -> 64 channels, 4x4 kernels) with a linear head.

    ``layer1`` and ``layer2`` end in 2x2 average pooling; ``layer3`` does not.
    Every convolution uses stride 1 and no padding, so the smallest square
    input that survives all three blocks is 25x25
    (25 -> 22 -> 11 -> 8 -> 4 -> 1). Smaller inputs make a 4x4 kernel larger
    than its feature map and raise at forward time.

    Args:
        img_dim: number of input channels.
        num_classes: length of the output logits vector. Defaults to 1000,
            the value this class previously hard-coded.
    """

    def __init__(self, img_dim, num_classes=1000):
        super().__init__()
        self.num_feature = 8
        self.filter_size = 4

        self.layer1 = self.block_make(img_dim, self.num_feature, self.filter_size, True)
        self.layer2 = self.block_make(self.num_feature, self.num_feature * 4, self.filter_size, True)
        self.layer3 = self.block_make(self.num_feature * 4, self.num_feature * 8, self.filter_size, False)

        # Global average pooling collapses any remaining spatial extent to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.num_feature * 8, num_classes)

        # He-normal weights and zero biases for every Conv2d/Linear. Other modules
        # (e.g. BatchNorm) keep their constructor defaults. nn.init functions
        # already run under no_grad, so `.data` is unnecessary.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def block_make(self, input_dim, output_dim, filter_size, pl):
        """Wrap a single CNNblock in nn.Sequential.

        The wrapper keeps parameter names as ``layerN.0.*`` in the state_dict.
        """
        return nn.Sequential(CNNblock(input_dim, output_dim, filter_size, pl))

    def forward(self, x):
        """Map a ``(batch, img_dim, H, W)`` batch to ``(batch, num_classes)`` logits."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        out = t.flatten(out, 1)
        return self.fc(out)

class CNN_1(nn.Module):
    """Four-block CNN classifier (32 -> 64 -> 128 -> 256 channels, 3x3 kernels) with a linear head.

    ``layer1``..``layer3`` end in 2x2 average pooling; ``layer4`` does not.
    Every convolution uses stride 1 and no padding, so the smallest square
    input that survives all four blocks is 38x38
    (38 -> 36 -> 18 -> 16 -> 8 -> 6 -> 3 -> 1). For example, 32x32 inputs fail
    (32 -> 30 -> 15 -> 13 -> 6 -> 4 -> 2, then a 3x3 kernel meets a 2x2 map).

    Args:
        img_dim: number of input channels.
        num_classes: length of the output logits vector. Defaults to 1000,
            the value this class previously hard-coded.
    """

    def __init__(self, img_dim, num_classes=1000):
        super().__init__()
        self.num_feature = 32
        self.filter_size_1 = 3

        self.layer1 = self.block_make(img_dim, self.num_feature, self.filter_size_1, True)
        self.layer2 = self.block_make(self.num_feature, self.num_feature * 2, self.filter_size_1, True)
        self.layer3 = self.block_make(self.num_feature * 2, self.num_feature * 4, self.filter_size_1, True)
        self.layer4 = self.block_make(self.num_feature * 4, self.num_feature * 8, self.filter_size_1, False)

        # Global average pooling collapses any remaining spatial extent to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.num_feature * 8, num_classes)

        # He-normal weights and zero biases for every Conv2d/Linear. Other modules
        # (e.g. BatchNorm) keep their constructor defaults. nn.init functions
        # already run under no_grad, so `.data` is unnecessary.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def block_make(self, input_dim, output_dim, filter_size, pl):
        """Wrap a single CNNblock in nn.Sequential.

        The wrapper keeps parameter names as ``layerN.0.*`` in the state_dict.
        """
        return nn.Sequential(CNNblock(input_dim, output_dim, filter_size, pl))

    def forward(self, x):
        """Map a ``(batch, img_dim, H, W)`` batch to ``(batch, num_classes)`` logits."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = t.flatten(out, 1)
        return self.fc(out)