import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import torch

class SimpleNet(nn.Module):
    """Fully connected classifier: flatten -> two ReLU hidden layers -> linear head.

    Args:
        in_features: number of values in one flattened input sample
            (default 32*32*3).
        out_features: number of output logits.
        hidden_num: width of both hidden layers. Note that the default of
            100000 builds a 100000 x 100000 hidden Linear (about 1e10 weights);
            pass a smaller width for practical use.
    """

    def __init__(self, in_features=32*32*3, out_features=10, hidden_num=100000):
        super(SimpleNet, self).__init__()
        trunk = [
            nn.Flatten(),
            nn.Linear(in_features, hidden_num),
            nn.ReLU(inplace=False),
            nn.Linear(hidden_num, hidden_num),
            nn.ReLU(inplace=False),
        ]
        # Hidden feature extractor; kept under the name ``main`` for checkpoint compatibility.
        self.main = nn.Sequential(*trunk)
        # Classification head mapping hidden features to logits.
        self.fc = nn.Linear(hidden_num, out_features)

    def forward(self, x):
        """Return logits of shape (batch, out_features) for input ``x``."""
        features = self.main(x)
        return self.fc(features)

class SimpleCNN(nn.Module):
    """Three-conv network with average pooling and a linear classifier.

    Layout: [Conv -> ReLU -> AvgPool(2)] x 2 -> Conv -> ReLU ->
    global average pool -> Flatten -> Linear.

    Args:
        kernel_size: spatial size of the kernel used by all three convolutions.
            Padding is ``kernel_size // 2``, which preserves spatial size for
            odd kernels (kernel_size=3 -> padding=1).
        out_features: number of output logits.
        hidden_num: channel width of every conv layer.
        in_channels: channels of the input image (default 3).
    """

    def __init__(self, kernel_size=3, out_features=10, hidden_num=64, in_channels=3):
        super(SimpleCNN, self).__init__()
        # Previously ``kernel_size`` was ignored (convs were hard-coded to 3);
        # the default still yields kernel 3 / padding 1.
        pad = kernel_size // 2
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, hidden_num, kernel_size=kernel_size, stride=1, padding=pad),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
            nn.Conv2d(hidden_num, hidden_num, kernel_size=kernel_size, stride=1, padding=pad),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=1),
            nn.Conv2d(hidden_num, hidden_num, kernel_size=kernel_size, stride=1, padding=pad),
            nn.ReLU(),
            # Global pooling makes the head independent of input resolution.
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(hidden_num, out_features),
        )

    def forward(self, x):
        """Return logits of shape (batch, out_features)."""
        return self.main(x)

class VGG_Block(nn.Module):
    """Stack of ``n`` (Conv3x3 -> BatchNorm2d -> ReLU) units.

    The first conv maps ``inplane`` -> ``plane`` channels and the remaining
    ones map ``plane`` -> ``plane``. Stride 1 with padding 1 keeps the spatial
    size unchanged. With ``n == 0`` the block is the identity.

    Args:
        inplane: number of input channels.
        plane: number of output channels.
        n: number of conv units to stack.
    """

    def __init__(self, inplane, plane, n):
        super(VGG_Block, self).__init__()
        layer = []
        for _ in range(n):
            layer += [nn.Conv2d(inplane, plane, kernel_size=3, stride=1, padding=1, bias=False),
                      nn.BatchNorm2d(plane),
                      nn.ReLU()]
            inplane = plane
        # The layer list was previously built and discarded, leaving the module
        # with no parameters and no forward; register it so it is actually used.
        self.main = nn.Sequential(*layer)

    def forward(self, x):
        """Apply the conv stack to ``x``."""
        return self.main(x)


class Simple_cifar(nn.Module):
    """VGG-style classifier: three conv stages followed by a three-layer MLP head.

    Stage widths are ``filter_num``, ``2*filter_num`` and ``4*filter_num``.
    Each stage stacks ``n`` (Conv3x3 -> BatchNorm2d -> ReLU) units. The first
    two stages are followed by a 2x2 max-pool. The features are then pooled
    to 7x7, flattened, and passed through
    Linear -> BN1d -> ReLU -> Linear -> BN1d -> ReLU -> Linear.

    Args:
        num_classes: number of output logits.
        filter_num: channel width of the first stage.
        n: number of conv units per stage; must be >= 1.
        in_channels: channels of the input image (default 3).
        fc_dim: width of the two hidden fully connected layers (default 4096).

    Raises:
        ValueError: if ``n < 1``. With no conv units, the flattened features
            (``in_channels * 49``) could never match the head's expected
            input size (``4 * filter_num * 49``).
    """

    def __init__(self, num_classes=10, filter_num=64, n=1, in_channels=3, fc_dim=4096):
        super(Simple_cifar, self).__init__()
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        inplane = in_channels
        plane = filter_num
        layers = []
        for stage in range(3):
            layers += self._conv_stage(inplane, plane, n)
            inplane = plane
            if stage < 2:
                # Halve the resolution between stages and double the width.
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                plane = 2 * plane
        layers += [nn.AdaptiveAvgPool2d((7, 7)),
                   nn.Flatten(),
                   # The Linear layers have no bias because BatchNorm1d follows each one.
                   nn.Linear(plane * 7 * 7, fc_dim, bias=False),
                   nn.BatchNorm1d(fc_dim),
                   nn.ReLU(True),
                   nn.Linear(fc_dim, fc_dim, bias=False),
                   nn.BatchNorm1d(fc_dim),
                   nn.ReLU(True),
                   nn.Linear(fc_dim, num_classes)]
        # One flat Sequential, so state_dict keys ("main.<idx>...") match earlier versions.
        self.main = nn.Sequential(*layers)

    @staticmethod
    def _conv_stage(inplane, plane, n):
        """Return ``n`` Conv3x3-BN-ReLU units mapping ``inplane`` -> ``plane`` channels."""
        layers = []
        for _ in range(n):
            layers += [nn.Conv2d(inplane, plane, kernel_size=3, stride=1, padding=1, bias=False),
                       nn.BatchNorm2d(plane),
                       nn.ReLU()]
            inplane = plane
        return layers

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        return self.main(x)