import torch.nn as nn
import torch
from functools import partial
from bypass.core.activation import ActivationForBypass, ActivationForActivionChange


class DNNMNIST(nn.Module):
    """Fully connected classifier mapping 784 input features to 10 logits.

    Layer widths are 784 -> 512 -> 256 -> 128 -> 64 -> 64 -> 10. Each of the
    first four stages is followed by a ReLU wrapped in ``ActivationForBypass``
    (constructed with the width of the stage it follows). ``fc4`` is two
    ``nn.Linear`` layers applied back to back with no activation in between.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept stable so parameter order, state_dict
        # keys and seeded initialisation do not change.
        self.proj1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Sequential(nn.Linear(128, 64), nn.Linear(64, 64))
        self.proj2 = nn.Linear(64, 10)

        self.relu1 = ActivationForBypass(512, nn.ReLU())
        self.relu2 = ActivationForBypass(256, nn.ReLU())
        self.relu3 = ActivationForBypass(128, nn.ReLU())
        self.relu4 = ActivationForBypass(64, nn.ReLU())

    def forward(self, x):
        """Return the 10 logits for every 784-element sample in ``x``.

        Args:
            x: Tensor whose element count is a multiple of 784 (for example
                ``(N, 28, 28)`` or ``(N, 1, 28, 28)``). It is cast to
                floating point before the first layer.

        Returns:
            Tensor of shape ``(N, 10)``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as a
        # transposed or permuted batch; for contiguous inputs it is still a
        # zero-copy view.
        flat = x.float().reshape(-1, 784)
        h1 = self.relu1(self.proj1(flat))
        h2 = self.relu2(self.fc2(h1))
        h3 = self.relu3(self.fc3(h2))
        h4 = self.relu4(self.fc4(h3))
        return self.proj2(h4)
class simplecnn(nn.Module):
    """Small CNN: two 5x5 convolutions followed by two linear layers.

    ``fc1`` expects ``9 * 24 * 24`` features. With the un-padded, stride-1
    5x5 convolutions used here that corresponds to a 3-channel 32x32 input
    (32 -> 28 -> 24 per spatial side). The last hidden activation is a ReLU
    wrapped in ``ActivationForBypass``; the two convolutional activations are
    plain ``nn.ReLU``.
    """

    # Per-sample shape produced by conv2 and the matching fc1 input width.
    _FEATURE_SHAPE = (9, 24, 24)
    _NUM_FEATURES = 9 * 24 * 24

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(self._NUM_FEATURES, 50)
        self.fc2 = nn.Linear(50, 10)

        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = ActivationForBypass(50, nn.ReLU())

    def forward(self, x):
        """Return 10 logits per input image.

        Args:
            x: Image tensor with 3 channels whose spatial size yields a
                ``(9, 24, 24)`` feature map after both convolutions.

        Returns:
            Tensor of shape ``(N, 10)``.

        Raises:
            RuntimeError: If the convolutional feature map does not have the
                shape ``fc1`` was sized for.
        """
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))

        # Flattening with a bare -1 batch dimension would silently merge
        # samples whenever a wrongly sized input happens to have a total
        # element count divisible by the feature width, so check explicitly.
        # RuntimeError matches what the flattening itself raises for
        # incompatible sizes.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                "simplecnn expected a feature map of shape "
                f"{self._FEATURE_SHAPE} after conv2, got {tuple(x.shape[-3:])}"
            )

        # reshape also handles non-contiguous feature maps (e.g. produced
        # from channels_last inputs), where view would raise.
        x = x.reshape(-1, self._NUM_FEATURES)
        x = self.relu3(self.fc1(x))
        return self.fc2(x)

if __name__ == '__main__':
    # Smoke test: attach a BypassDetector to a fresh DNNMNIST, push one dummy
    # sample through it and print the detector's summary.
    from bypass.core.detect import BypassDetector, register_detector, remove_detector

    model = DNNMNIST()
    detect = BypassDetector()
    register_detector(model, detect)
    try:
        dummy_input = torch.ones([1, 28, 28])
        # Call the module itself rather than .forward(): nn.Module.__call__
        # runs registered hooks, which a direct .forward() call skips for the
        # top-level module.
        model(dummy_input)

        detect.summarize()
    finally:
        # Always detach the detector, even if the forward pass or the
        # summary raised.
        # from torchsummary import summary
        # # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # device='cpu'
        # summary(model.to(device=device), (1, 28,28),device='cpu')
        remove_detector(model)
    print(1)