import torch
import torch.nn as nn
import torch.nn.functional as F

####################################################################################
class ConvNet(nn.Module):
  """Small CNN returning one logit per sample plus L2-normalized conv features.

  The layer sizes imply inputs of shape (batch, 3, 28, 28): two unpadded
  3x3 stride-1 convolutions shrink 28x28 to 24x24 with 64 channels, and
  2x2 max pooling then yields 64 * 12 * 12 = 9216 inputs to ``fc1``.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 64, 3, 1)
    self.feature_size = 128
    # 9216 = 64 channels * 12 * 12 (24x24 conv2 map after 2x2 max pooling).
    self.fc1 = nn.Linear(9216, self.feature_size)
    self.final_layer = nn.Linear(self.feature_size, 1)

  def normalize_features(self, x):
    """Scale each sample of ``x`` to unit L2 norm over all non-batch dims.

    A sample whose norm is exactly zero (e.g. every activation zeroed by
    ReLU) is divided by 1 instead, so it stays all-zero rather than
    becoming NaN.

    Args:
      x: tensor whose first dimension is the batch dimension.

    Returns:
      Tensor with the same shape as ``x``.
    """
    norms = x.flatten(1).norm(dim=1)
    # Avoid 0/0 -> NaN, which would propagate through the rest of the net.
    norms = torch.where(norms == 0, torch.ones_like(norms), norms)
    # Reshape to (batch, 1, ..., 1) so the division broadcasts per sample
    # for inputs of any rank.
    norms = norms.reshape((x.shape[0],) + (1,) * (x.dim() - 1))
    return x / norms

  def forward(self, x):
    """Run the network on a batch.

    Args:
      x: input batch; the layer sizes assume shape (batch, 3, 28, 28).

    Returns:
      Tuple ``(logits, features)``. ``logits`` has shape (batch,), one
      score per sample. ``features`` is the per-sample L2-normalized
      conv2 activation (after ReLU), flattened to (batch, 64 * 24 * 24).
    """
    x = self.conv1(x) # feature1
    x = F.relu(x) # feature2
    x = self.conv2(x) # feature3
    x = F.relu(x) # feature4

    x = self.normalize_features(x)
    features = torch.flatten(x, 1)

    x = F.max_pool2d(x, 2) # feature5
    x = torch.flatten(x, 1) # feature6
    x = self.fc1(x) # feature7
    x = F.relu(x) # feature8
    x = self.final_layer(x) # feature9
    logits = x.flatten()

    return logits, features
  
####################################################################################
