import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class NN(nn.Module):
    """1-D CNN classifier.

    Four stages of Conv1d -> ReLU -> MaxPool1d(3) -> Dropout(0.3)
    (channels 1 -> 8 -> 16 -> 32 -> 64, kernels 13/11/9/7, no padding),
    followed by Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.

    Args:
        num_class: number of output logits.
        fc1_in_features: size of the flattened last conv stage
            (64 * remaining length), i.e. the input size of ``fc1``. The
            default 12416 (= 64 * 194) fits, for example, an input of shape
            (batch, 1, 16000). Pass a different value for other input lengths.
    """

    def __init__(self, num_class, *, fc1_in_features=12416):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=13, stride=1)
        self.dropout1 = nn.Dropout(0.3)

        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=11, stride=1)
        self.dropout2 = nn.Dropout(0.3)

        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=9, stride=1)
        self.dropout3 = nn.Dropout(0.3)

        self.conv4 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=7, stride=1)
        self.dropout4 = nn.Dropout(0.3)

        self.fc1 = nn.Linear(fc1_in_features, 256)
        self.dropout5 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.dropout6 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(128, num_class)

    def forward(self, x):
        """Map a (batch, 1, length) tensor to (batch, num_class) logits."""
        x = self.dropout1(F.max_pool1d(F.relu(self.conv1(x)), kernel_size=3))
        x = self.dropout2(F.max_pool1d(F.relu(self.conv2(x)), kernel_size=3))
        x = self.dropout3(F.max_pool1d(F.relu(self.conv3(x)), kernel_size=3))
        x = self.dropout4(F.max_pool1d(F.relu(self.conv4(x)), kernel_size=3))

        # (batch, 64, L') -> (batch, 64 * L'); must equal fc1.in_features.
        x = torch.flatten(x, 1)

        x = self.dropout5(F.relu(self.fc1(x)))
        x = self.dropout6(F.relu(self.fc2(x)))
        return self.fc3(x)
    
class NN2D(nn.Module):
    """Small 2-D CNN classifier.

    Two stages of Conv2d(3x3) -> ReLU -> MaxPool2d(3) -> Dropout(0.3)
    (channels 1 -> 8 -> 16), followed by
    Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.

    Args:
        num_class: number of output logits.
        fc1_in_features: size of the flattened second conv stage
            (16 * H' * W'), i.e. the input size of ``fc1``. The default
            384 (= 16 * 4 * 6) fits, for example, an input of shape
            (batch, 1, 50, 64). Pass a different value for other input sizes.
    """

    def __init__(self, num_class, *, fc1_in_features=384):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.3)

        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
        self.dropout2 = nn.Dropout(0.3)

        self.fc1 = nn.Linear(fc1_in_features, 256)
        self.dropout5 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.dropout6 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(128, num_class)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to (batch, num_class) logits."""
        x = self.dropout1(F.max_pool2d(F.relu(self.conv1(x)), kernel_size=3))
        x = self.dropout2(F.max_pool2d(F.relu(self.conv2(x)), kernel_size=3))

        # (batch, 16, H', W') -> (batch, 16 * H' * W'); must equal fc1.in_features.
        x = torch.flatten(x, 1)

        x = self.dropout5(F.relu(self.fc1(x)))
        x = self.dropout6(F.relu(self.fc2(x)))
        return self.fc3(x)
    
    
class NN2DMEL(nn.Module):
    """Small 2-D CNN classifier (same layout as NN2D, larger fc1 input).

    Two stages of Conv2d(3x3) -> ReLU -> MaxPool2d(3) -> Dropout(0.3)
    (channels 1 -> 8 -> 16), followed by
    Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear.

    Args:
        num_class: number of output logits.
        fc1_in_features: size of the flattened second conv stage
            (16 * H' * W'), i.e. the input size of ``fc1``. The default
            1664 (= 16 * 13 * 8) fits, for example, an input of shape
            (batch, 1, 128, 86). Pass a different value for other input sizes.
    """

    def __init__(self, num_class, *, fc1_in_features=1664):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.3)

        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
        self.dropout2 = nn.Dropout(0.3)

        self.fc1 = nn.Linear(fc1_in_features, 256)
        self.dropout5 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, 128)
        self.dropout6 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(128, num_class)

    def forward(self, x):
        """Map a (batch, 1, H, W) tensor to (batch, num_class) logits."""
        x = self.dropout1(F.max_pool2d(F.relu(self.conv1(x)), kernel_size=3))
        x = self.dropout2(F.max_pool2d(F.relu(self.conv2(x)), kernel_size=3))

        # (batch, 16, H', W') -> (batch, 16 * H' * W'); must equal fc1.in_features.
        x = torch.flatten(x, 1)

        x = self.dropout5(F.relu(self.fc1(x)))
        x = self.dropout6(F.relu(self.fc2(x)))
        return self.fc3(x)
    
class ReLU_full_grad(torch.autograd.Function):
    """ReLU in the forward pass, identity in the backward pass.

    Forward zeroes out negative entries exactly like ReLU. Backward ignores
    the ReLU mask and hands back a copy of the incoming gradient, so gradient
    also reaches entries that were zeroed in the forward pass.
    Use via ``ReLU_full_grad.apply(x)``.
    """

    @staticmethod
    def forward(ctx, input):
        # Elementwise max(input, 0); nothing needs to be saved for backward.
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the upstream gradient straight through (as a fresh tensor).
        passthrough = grad_output.clone()
        return passthrough


class SeparableConvBlock(nn.Module):
    """Depthwise-separable conv block that preserves spatial size.

    Pipeline: depthwise 3x3 conv (padding 1, one group per channel)
    -> BatchNorm -> ReLU -> pointwise 1x1 conv -> BatchNorm -> ReLU
    -> Dropout(0.2).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order is kept as-is: it fixes both state_dict key order
        # and the order in which parameters are randomly initialised.
        self.depthwise_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels
        )
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.batch_norm1 = nn.BatchNorm2d(in_channels)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Apply the depthwise -> pointwise pipeline to an NCHW tensor."""
        pipeline = (
            self.depthwise_conv,
            self.batch_norm1,
            self.relu,
            self.pointwise_conv,
            self.batch_norm2,
            self.relu,
            self.dropout,
        )
        for stage in pipeline:
            x = stage(x)
        return x

class DS_CNN_BP(nn.Module):
    """Depthwise-separable CNN classifier.

    Layers: a (10, 4) stride-2 Conv2d stem with BatchNorm/ReLU/Dropout,
    four ``SeparableConvBlock`` stages (32 -> 64 -> 128 -> 128 -> 128
    channels, spatial size preserved), a 3x3 stride-2 average pool, a flatten
    and a linear head.

    ``fc`` has a fixed input size of 4992 (= 128 * 13 * 3). That matches an
    input of shape (batch, input_channel, 49, 10): the stem maps 49x10 to 25x5
    and the pool maps 25x5 to 13x3.

    The constructor also records three size counters, ``gradient_num``,
    ``activation_num`` and ``error_num``. They are not read in this class;
    presumably they feed a memory/cost estimate elsewhere (confirm with the
    caller). They assume the 49x10 input above.

    Args:
        input_channel: number of input channels.
        num_classes: number of output logits.
    """

    def __init__(self, input_channel=1, num_classes=12):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, 32, kernel_size=(10, 4), stride=(2, 2), padding=(5, 1))
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)

        self.sep_conv1 = SeparableConvBlock(32, 64)
        self.sep_conv2 = SeparableConvBlock(64, 128)
        self.sep_conv3 = SeparableConvBlock(128, 128)
        self.sep_conv4 = SeparableConvBlock(128, 128)

        self.avg_pool = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2), padding=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(4992, num_classes)

        self.activation_num = 0
        self.gradient_num = 0
        self.error_num = 0

        # Weight-gradient estimate. Convention: out_ch * (3*3 * in_ch) per conv
        # plus in_ch * 3*3 for each depthwise conv.
        # NOTE(review): conv1 really has a (10, 4) kernel and the pointwise
        # convs are 1x1, and biases / BatchNorm parameters are not counted, so
        # this is not the true parameter count. Kept unchanged (apart from
        # input_channel / num_classes) to preserve existing values; confirm intent.
        self.gradient_num += 32*3*3*input_channel #conv1
        self.gradient_num += 64*(3*3*32) + 32*(3*3*1) #sep_conv1
        self.gradient_num += 128*(3*3*64) + 64*(3*3*1) #sep_conv2
        self.gradient_num += 128*(3*3*128) + 128*(3*3*1) #sep_conv3
        self.gradient_num += 128*(3*3*128) + 128*(3*3*1) #sep_conv4
        self.gradient_num += num_classes*4992 #fc

        # Activations: the 32/64/128/128-channel 25x5 feature maps entering
        # sep_conv1..sep_conv4, the network input, and the logits.
        self.activation_num += (32+64+128+128)*25*5
        self.activation_num += input_channel*49*10
        self.activation_num += num_classes

        # Error (backpropagated activation-gradient) estimate: outputs of conv1,
        # of sep_conv1..sep_conv4, and the flattened pooled features.
        self.error_num += 32*25*5
        self.error_num += (64+128+128+128)*25*5
        self.error_num += 4992

    def forward(self, x):
        """Map a (batch, input_channel, 49, 10) tensor to (batch, num_classes) logits."""
        x = self.conv1(x)       # (B, 32, 25, 5)
        x = self.batch_norm1(x)
        x = self.relu(x)
        x = self.dropout(x)

        x = self.sep_conv1(x)   # (B, 64, 25, 5)
        x = self.sep_conv2(x)   # (B, 128, 25, 5)
        x = self.sep_conv3(x)   # (B, 128, 25, 5)
        x = self.sep_conv4(x)   # (B, 128, 25, 5)

        x = self.avg_pool(x)    # (B, 128, 13, 3)
        x = self.flatten(x)     # (B, 4992)
        x = self.fc(x)          # (B, num_classes)
        return x
    
class DS_CNN_BP_v2(nn.Module):
    """Depthwise-separable CNN classifier, 64-channel variant with global pooling.

    Layers: a (10, 4) stride-2 Conv2d stem with 64 channels plus
    BatchNorm/ReLU/Dropout, four ``SeparableConvBlock(64, 64)`` stages (spatial
    size preserved), a (25, 5) average pool, a flatten and ``Linear(64, num_classes)``.

    The (25, 5) pool reduces the 25x5 map produced by a
    (batch, input_channel, 49, 10) input to 1x1, giving 64 features.

    The constructor also records three size counters, ``gradient_num``,
    ``activation_num`` and ``error_num``. They are not read in this class;
    presumably they feed a memory/cost estimate elsewhere (confirm with the
    caller). They assume the 49x10 input above.

    Args:
        input_channel: number of input channels.
        num_classes: number of output logits.
    """

    def __init__(self, input_channel=1, num_classes=12):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=(10, 4), stride=(2, 2), padding=(5, 1))
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)

        self.sep_conv1 = SeparableConvBlock(64, 64)
        self.sep_conv2 = SeparableConvBlock(64, 64)
        self.sep_conv3 = SeparableConvBlock(64, 64)
        self.sep_conv4 = SeparableConvBlock(64, 64)

        # Averages the whole 25x5 map down to 1x1.
        self.avg_pool = nn.AvgPool2d((25, 5))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64, num_classes)

        self.activation_num = 0
        self.gradient_num = 0
        self.error_num = 0

        # Weight-gradient estimate. Convention: out_ch * (3*3 * in_ch) per conv
        # plus in_ch * 3*3 for each depthwise conv. Channel counts match this
        # model's layers (64 everywhere); the previous figures still used the
        # 32/128 channels of the other DS-CNN variant.
        # NOTE(review): conv1 really has a (10, 4) kernel and the pointwise
        # convs are 1x1, and biases / BatchNorm parameters are not counted, so
        # this is not the true parameter count; confirm intent.
        self.gradient_num += 64*3*3*input_channel #conv1
        self.gradient_num += 64*(3*3*64) + 64*(3*3*1) #sep_conv1
        self.gradient_num += 64*(3*3*64) + 64*(3*3*1) #sep_conv2
        self.gradient_num += 64*(3*3*64) + 64*(3*3*1) #sep_conv3
        self.gradient_num += 64*(3*3*64) + 64*(3*3*1) #sep_conv4
        self.gradient_num += num_classes*64 #fc

        # Activations: the four 64-channel 25x5 feature maps entering
        # sep_conv1..sep_conv4, the network input, and the logits.
        self.activation_num += (64+64+64+64)*25*5
        self.activation_num += input_channel*49*10
        self.activation_num += num_classes

        # Error estimate: outputs of conv1 (64 channels), of
        # sep_conv1..sep_conv4, and the 64 pooled features.
        self.error_num += 64*25*5
        self.error_num += (64+64+64+64)*25*5
        self.error_num += 64

    def forward(self, x):
        """Map a (batch, input_channel, 49, 10) tensor to (batch, num_classes) logits."""
        x = self.conv1(x)       # (B, 64, 25, 5)
        x = self.batch_norm1(x)
        x = self.relu(x)
        x = self.dropout(x)

        x = self.sep_conv1(x)   # (B, 64, 25, 5)
        x = self.sep_conv2(x)   # (B, 64, 25, 5)
        x = self.sep_conv3(x)   # (B, 64, 25, 5)
        x = self.sep_conv4(x)   # (B, 64, 25, 5)

        x = self.avg_pool(x)    # (B, 64, 1, 1)
        x = self.flatten(x)     # (B, 64)
        x = self.fc(x)          # (B, num_classes)
        return x