
import torch
import torch.nn as nn
import torch.nn.functional as F

####################################################################################
class Net(nn.Module):
  """Small fully connected network: 2 -> 100 -> 10 -> 10 -> 1.

  The output of the third linear layer is L2-normalised per sample and
  then projected to a single scalar by a bias-free linear layer.
  """

  def __init__(self) -> None:
    super().__init__()

    # Layers
    self.fc1 = nn.Linear(2, 100)
    self.fc2 = nn.Linear(100, 10)
    self.fc3 = nn.Linear(10, 10)
    # No bias: the logit is a pure linear function of the unit-norm feature.
    self.final_layer = nn.Linear(10, 1, bias=False)

  def forward(self, x: torch.Tensor):
    """Run the network on a batch of 2-dimensional inputs.

    Args:
      x: Tensor whose elements can be arranged as rows of 2 values,
        e.g. shape ``(batch, 2)``. A single sample of shape ``(2,)`` is
        treated as a batch of one.

    Returns:
      A ``(logits, features)`` pair: ``logits`` is a 1-D tensor with one
      value per sample, ``features`` is the ``(batch, 10)`` L2-normalised
      output of ``fc3`` that was fed to the final layer.
    """
    # Flatten to (batch, 2) so the per-row normalisation along dim=1 below
    # is well defined for any input rank.
    x = x.reshape(-1, 2)

    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)

    # F.normalize divides each row by max(||row||_2, eps), which matches
    # x / ||x|| for non-zero rows and avoids NaNs for an all-zero row.
    features = F.normalize(x, p=2, dim=1)

    logits = self.final_layer(features).flatten()

    return logits, features
####################################################################################