import torch
import torch.nn as nn
import torch.nn.functional as F

class net2(nn.Module):
    """Small fully-connected network that reorders input columns before an MLP.

    The columns listed in ``mono_feature`` are placed first, followed by the
    remaining columns of ``range(input_size)`` in ascending index order. The
    concatenation is passed through
    ``Linear(input_size, 4) -> ReLU -> Linear(4, 3) -> ReLU -> Linear(3, 1)``.

    Args:
        input_size: Number of input features (last dimension of ``x``).
        mono_size: Stored as ``self.mono_size`` and used to compute
            ``self.non_mono_size``. It is not checked against
            ``len(mono_feature)``.
        mono_feature: Column indices that are placed first. "mono" presumably
            means monotonic features; confirm with the training code.
        mono_hidden_num: Currently unused (hidden widths are fixed at 4 and 3).
        non_mono_hidden_num: Currently unused.
    """

    def __init__(self, input_size, mono_size, mono_feature, mono_hidden_num=5, non_mono_hidden_num=10):
        super(net2, self).__init__()
        self.input_size = input_size
        self.mono_size = mono_size
        self.mono_feature = mono_feature

        self.non_mono_size = input_size - mono_size
        # Complement of mono_feature within range(input_size), in ascending
        # order. It is built from range() rather than by iterating a set, so
        # the column order (and thus which fc_in weight column each feature
        # feeds) does not depend on set iteration order.
        mono_set = set(mono_feature)
        self.non_mono_feature = [i for i in range(input_size) if i not in mono_set]

        ## The weights of the NN
        # NOTE: hidden widths are hard-coded; mono_hidden_num and
        # non_mono_hidden_num are accepted but not used.
        self.fc_in = nn.Linear(self.input_size, 4, bias=True)
        self.fc2 = nn.Linear(4, 3, bias=True)
        self.fc_out = nn.Linear(3, 1, bias=True)

    def _mlp(self, x_mono, x_non_mono):
        """Concatenate the two column groups along dim 1 and apply the MLP.

        Shared by ``forward`` and ``reg_forward`` so both use the same network.
        Returns a tensor whose last dimension is 1.
        """
        x = self.fc_in(torch.cat([x_mono, x_non_mono], dim=1))
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return self.fc_out(x)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: 2-D tensor indexed as ``x[:, cols]``; presumably
                ``(batch, input_size)``, since the reordered columns must
                total ``input_size`` to fit ``fc_in``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        x_mono = x[:, self.mono_feature]
        x_non_mono = x[:, self.non_mono_feature]
        return self._mlp(x_mono, x_non_mono)

    def reg_forward(self, feature_num, num=512):
        """Run the network on random inputs whose mono columns track gradients.

        Samples ``num`` rows uniformly from [0, 1) and enables
        ``requires_grad`` on the mono columns. The caller can then
        differentiate the output with respect to them; presumably this feeds
        a monotonicity regulariser (confirm with the caller).

        Args:
            feature_num: Number of sampled columns. It must equal
                ``input_size`` for ``fc_in`` to accept the concatenation.
            num: Number of random samples.

        Returns:
            ``(in_list, out_list)``: single-element lists holding the mono
            input tensor (``requires_grad=True``) and the ``(num, 1)`` output.
        """
        # Sample on the parameters' device/dtype so this still works after
        # .to(device) / .double() instead of failing with a mismatch.
        ref = self.fc_in.weight
        input_feature = torch.rand(num, feature_num, device=ref.device, dtype=ref.dtype)
        input_mono = input_feature[:, self.mono_feature]
        input_non_mono = input_feature[:, self.non_mono_feature]
        # input_feature does not require grad, so the indexed result is a leaf
        # tensor and requires_grad can be switched on directly.
        input_mono.requires_grad = True

        x_out = self._mlp(input_mono, input_non_mono)

        return [input_mono], [x_out]
