import torch
import torch.nn as nn
import torch.nn.functional as F

class net(nn.Module):
    """Two-branch MLP that embeds "mono" and "non-mono" feature columns separately.

    How the network is built:

    - The columns listed in ``mono_feature`` each pass through their own
      ``Linear`` + ReLU input layer.
    - The remaining columns (the non-mono ones) pass through a separate
      ``Linear`` + ReLU input layer.
    - The two embeddings are concatenated and mixed by one hidden
      ``Linear`` + ReLU layer.
    - That result is mapped to a single output per row.

    NOTE(review): the "mono" naming suggests features the output should be
    monotone in. Nothing in this class enforces that; presumably a caller
    regularizes it using the tensors returned by ``reg_forward``. Confirm
    against callers.

    Args:
        input_size: Total number of columns in the input.
        mono_size: Number of mono columns; must equal ``len(mono_feature)``.
        mono_feature: Column indices (within ``range(input_size)``) of the mono
            features. Every other column is treated as non-mono.
        mono_hidden_num: Width of the mono input layer.
        non_mono_hidden_num: Width of the non-mono input layer.
        mix_hidden_num: Width of the mixing hidden layer(s). Defaults to 8,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``input_size``, ``mono_size`` and ``mono_feature`` are
            mutually inconsistent (count mismatch, duplicate or out-of-range
            indices). Without this check the problem would only surface later
            as a shape error on the first forward pass.
    """

    def __init__(self, input_size, mono_size, mono_feature, mono_hidden_num=16,
                 non_mono_hidden_num=16, mix_hidden_num=8):
        super().__init__()

        self.input_size = input_size
        self.mono_size = mono_size
        self.mono_feature = mono_feature

        self.non_mono_size = input_size - mono_size
        # Complement of mono_feature, in ascending column order. The previous
        # set-difference result had no guaranteed order, yet this order is
        # baked into non_mono_in's weights. int() normalizes numpy/tensor
        # scalars so membership tests compare by value.
        mono_set = {int(i) for i in mono_feature}
        self.non_mono_feature = [i for i in range(input_size) if i not in mono_set]

        if (len(mono_feature) != mono_size
                or len(self.non_mono_feature) != self.non_mono_size):
            raise ValueError(
                f"inconsistent feature split: input_size={input_size}, "
                f"mono_size={mono_size}, mono_feature={list(mono_feature)!r} "
                f"leaves {len(self.non_mono_feature)} non-mono columns, "
                f"expected {self.non_mono_size}"
            )

        # Creation order is kept identical to the original so seeded parameter
        # initialization and state_dict keys are unchanged.
        self.mono_in = nn.Linear(self.mono_size, mono_hidden_num, bias=True)
        self.non_mono_in = nn.Linear(self.non_mono_size, non_mono_hidden_num, bias=True)

        # Hidden layer that mixes the mono and non-mono embeddings.
        self.mix_hidden_1 = nn.Linear(mono_hidden_num + non_mono_hidden_num,
                                      mix_hidden_num, bias=True)
        # NOTE: mix_hidden_2 is registered but not used by forward/reg_forward.
        # It is kept only so existing checkpoints (state_dict keys) still load.
        self.mix_hidden_2 = nn.Linear(mix_hidden_num, mix_hidden_num, bias=True)

        # Output layer: one scalar per row.
        self.fc_out = nn.Linear(mix_hidden_num, 1, bias=True)

    def _mix(self, x_mono, x_non_mono):
        """Apply both branch layers, the mixing layer and the output head.

        Shared by ``forward`` and ``reg_forward``. Returns shape ``(rows, 1)``.
        """
        h_mono = F.relu(self.mono_in(x_mono))
        h_non_mono = F.relu(self.non_mono_in(x_non_mono))
        h = F.relu(self.mix_hidden_1(torch.cat([h_mono, h_non_mono], dim=1)))
        # mix_hidden_2 is intentionally skipped (disabled in this design).
        return self.fc_out(h)

    def forward(self, x):
        """Return the network output, shape ``(batch, 1)``, for ``x``.

        ``x`` is split column-wise via ``x[:, columns]``, and the branch
        outputs are concatenated on dim 1. It is therefore expected to be 2-D,
        ``(batch, input_size)``.
        """
        return self._mix(x[:, self.mono_feature], x[:, self.non_mono_feature])

    def reg_forward(self, feature_num=None, num=512):
        """Run the network on random inputs, tracking gradients on the mono columns.

        Args:
            feature_num: Number of columns in the random input. Defaults to
                ``self.input_size``. Any explicit value must still cover
                every index in ``mono_feature`` / ``non_mono_feature``.
            num: Number of random rows to sample, drawn uniformly from [0, 1).

        Returns:
            ``(in_list, out_list)``: two one-element lists.

            - ``in_list`` holds the sampled mono input, shape
              ``(num, mono_size)``, with ``requires_grad=True``.
            - ``out_list`` holds the corresponding network output, shape
              ``(num, 1)``.

            NOTE(review): presumably consumed by the caller to compute
            d(out)/d(mono input) for a regularizer; confirm against callers.
        """
        if feature_num is None:
            feature_num = self.input_size

        # Sample on the model's own device/dtype so this still works after
        # .to(device) or .double(); the old code always used CPU default dtype.
        ref = self.mono_in.weight
        input_feature = torch.rand(num, feature_num, device=ref.device, dtype=ref.dtype)
        input_mono = input_feature[:, self.mono_feature]
        input_non_mono = input_feature[:, self.non_mono_feature]
        # Advanced indexing yields a fresh leaf tensor, so grad can be enabled
        # on it directly.
        input_mono.requires_grad_(True)

        x_out = self._mix(input_mono, input_non_mono)
        return [input_mono], [x_out]




