import torch.nn as nn
import math
import torch
import torch.nn.functional as F

# Public API of this module: a wildcard import exposes only the `nett` factory.
__all__ = ["nett"]


class Net(nn.Module):
    """Two-layer network with a squared-ReLU hidden activation.

    ``fc1`` maps the last axis of the input from ``input_dim`` to
    ``num_neurons`` features. The result goes through ReLU, is squared
    element-wise, has its axes 1 and 2 merged, and is finally mapped by
    ``fc2`` from ``patches * num_neurons`` features to 2 outputs.

    Args:
        input_dim: Size of the last axis of the tensor fed to ``fc1``.
        num_neurons: Number of output features of ``fc1``.
        patches: Multiplier used for the input width of ``fc2``; presumably
            the size of axis 1 of the input (number of patches per sample)
            -- confirm against callers.
    """

    def __init__(self, input_dim, num_neurons, patches):
        super().__init__()
        flat_width = patches * num_neurons

        # fc1 is built before fc2 so that the default parameter
        # initialisation consumes the RNG in the same order for a given seed.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=num_neurons)
        self.fc2 = nn.Linear(in_features=flat_width, out_features=2)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``fc2``.

        Args:
            x: Tensor whose last axis has size ``input_dim``. Merging axes 1
                and 2 requires at least three axes; the merged axis must
                have size ``patches * num_neurons`` for ``fc2`` to accept it
                (e.g. a ``(batch, patches, input_dim)`` input).

        Returns:
            The ``fc2`` output; shape ``(batch, 2)`` for a 3-D input.
        """
        # Squared ReLU: clamp negatives to zero, then square element-wise.
        hidden = torch.relu(self.fc1(x)).pow(2)
        # Merge axes 1 and 2 into one feature axis for the final layer.
        merged = torch.flatten(hidden, start_dim=1, end_dim=2)
        return self.fc2(merged)

def nett(input_dim, neurons, patches):
    """Build and return a ``Net`` instance.

    Args:
        input_dim: Forwarded to ``Net`` as ``input_dim``.
        neurons: Forwarded to ``Net`` as ``num_neurons``.
        patches: Forwarded to ``Net`` as ``patches``.

    Returns:
        A newly constructed ``Net``.
    """
    model = Net(input_dim=input_dim, num_neurons=neurons, patches=patches)
    return model





