import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, sys
try:
    from .layers import *
except:
    sys.path.insert(0, '..')
    sys.path.insert(0, '../model')
    from layers import *


class DNN(torch.nn.Module):
    """Feed-forward network over concatenated field embeddings.

    Input indices are embedded by ``FeaturesEmbedding``, flattened into one
    vector of size ``embed_dim * num_fields`` per sample, passed through
    ``MLP`` and projected to a single value per sample by a linear layer.

    Two snapshots of ``named_parameters()`` are stored as plain lists:

    * ``asmg_params``  -- taken before the embedding is registered, so it
      only contains what was registered up to that point (``mlp``, ``out``).
    * ``model_params`` -- taken last, so it additionally contains the
      embedding parameters.
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=None, dropout=0.5, use_bn=False, **kwargs):
        """
        :param field_dims: sized collection with one entry per input field;
            handed unchanged to ``FeaturesEmbedding``
        :param embed_dim: embedding size used for every field
        :param dnn_layers: layer widths handed to ``MLP``; ``None`` (default)
            means ``[200, 200, 200]``
        :param dropout: probability used to build ``self.dropout``
        :param use_bn: forwarded to ``MLP`` as ``use_bn``
        :param kwargs: accepted for call-site compatibility and ignored
        """
        super().__init__()

        # A list literal as default would be one object shared by every
        # instance (and exposed through ``self.layers``); build a private
        # list per instance instead.
        if dnn_layers is None:
            dnn_layers = [200, 200, 200]
        else:
            dnn_layers = list(dnn_layers)

        self.field_dims  = field_dims
        self.num_fields  = len(field_dims)
        self.embed_dim   = embed_dim
        self.layers      = dnn_layers

        self.mlp         = MLP(embed_dim * self.num_fields, dnn_layers, use_bn=use_bn)
        # Width of the last MLP layer is the input width of the output layer.
        self.concat_dims = concat_dims = self.layers[-1]
        # NOTE(review): this module is constructed but never applied in
        # ``forward`` -- confirm whether that is intentional.
        self.dropout     = torch.nn.Dropout(p=dropout)
        self.out         = nn.Linear(concat_dims, 1)
        # Snapshot taken BEFORE the embedding exists: excludes embedding weights.
        self.asmg_params = list(self.named_parameters())

        # One embedding module shared by all fields.
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)
        # Snapshot taken AFTER the embedding exists: includes every parameter.
        self.model_params = list(self.named_parameters())

    def forward(self, x, training=True):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: currently unused; kept so existing callers that pass
            it keep working
        :return: tensor of size ``(batch_size,)`` -- the raw output of the
            final linear layer (no activation is applied here)
        """
        embedded = self.feature_embedding(x)
        # Flatten the per-field embeddings into one vector per sample.
        flat = embedded.view((-1, self.embed_dim * self.num_fields))
        hidden = self.mlp(flat)
        # (batch, 1) -> (batch,)
        return self.out(hidden).squeeze(1)