import copy
from operator import concat
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, sys
try:
    from .layers import *
except:
    sys.path.insert(0, '..')
    sys.path.insert(0, '../model')
    from layers import *


class DCN(torch.nn.Module):
    """Model with parallel MLP and CrossNet branches over shared field embeddings.

    Every field is embedded by one ``FeaturesEmbedding``. The embeddings are
    flattened to ``embed_dim * num_fields`` values per example and fed to both
    an ``MLP`` and a ``CrossNet``. The two branch outputs are concatenated,
    passed through dropout, and projected to a single score per example by a
    linear layer.
    """

    def __init__(self, field_dims, batch_size, embed_dim=20, dnn_layers=None, use_bn=False, dcn_layers=3, dropout=0.5, use_mpn=False, use_topk=False):
        """
        :param field_dims: per-field sizes; ``len(field_dims)`` is the number of fields.
        :param batch_size: stored as ``self.batch_size``; not used inside this class.
        :param embed_dim: embedding width per field.
        :param dnn_layers: layer sizes for the MLP branch; defaults to
            ``[200, 200, 200]``. The last entry is assumed to be the MLP's
            output width (it sizes the final linear layer) -- TODO confirm
            against ``MLP``.
        :param use_bn: forwarded to ``MLP``.
        :param dcn_layers: forwarded to ``CrossNet`` as its layer count.
        :param dropout: dropout probability applied to the concatenated branches.
        :param use_mpn: stored as ``self.use_mpn``; not used inside this class.
        :param use_topk: accepted for interface compatibility; unused here.
        :raises ValueError: if ``dnn_layers`` is empty.
        """
        super().__init__()
        # Build a fresh list per instance. A literal list default would be one
        # object shared by every DCN constructed without this argument.
        dnn_layers = [200, 200, 200] if dnn_layers is None else list(dnn_layers)
        if not dnn_layers:
            raise ValueError("dnn_layers must contain at least one layer size")

        self.field_dims = field_dims
        self.batch_size = batch_size
        self.num_fields = len(field_dims)
        self.embed_dim  = embed_dim
        self.layers     = dnn_layers
        self.use_mpn    = use_mpn
        self.final_layer = False

        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim) # use one embedding matrix to include both U and V 

        self.mlp     = MLP(embed_dim * self.num_fields, dnn_layers, use_bn)
        self.cross   = CrossNet(embed_dim * self.num_fields, dcn_layers)

        # Width of cat([mlp_out, cross_out]). This assumes CrossNet preserves
        # its input width and the MLP emits dnn_layers[-1] features.
        self.concat_dims = concat_dims = self.layers[-1] + embed_dim * self.num_fields
        self.out     = nn.Linear(concat_dims, 1)
        self.dropout = torch.nn.Dropout(p=dropout)

    def update_linear(self, stl):
        """Blend non-embedding parameters toward the mean of ``stl``, in place.

        For every parameter whose name does not contain ``'emb'``, the new
        value is ``0.9 * current + 0.1 * mean(st[name] for st in stl)``.
        Runs under ``torch.no_grad()``.

        :param stl: sequence of mappings (e.g. ``state_dict()`` snapshots)
            keyed by parameter name. Each must contain every non-embedding
            parameter of this model. An empty sequence leaves the model
            unchanged.
        """
        if not stl:
            return
        with torch.no_grad():
            for name, param in self.named_parameters():
                # Substring match: anything with 'emb' in its name (such as
                # the feature_embedding table) is skipped.
                if 'emb' in name:
                    continue
                stacked = torch.stack([st[name] for st in stl], dim=0)
                # Move the mean to the parameter's device so snapshots kept
                # elsewhere (e.g. on CPU) can still be blended in.
                previous_param = stacked.mean(dim=0).to(param.device)
                param.copy_(previous_param * 0.1 + 0.9 * param)

    def forward(self, x, training=False):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: ignored. Dropout follows the module's
            ``train()``/``eval()`` mode instead.
        :return: tensor of size ``(batch_size,)`` holding one raw score per
            row (no sigmoid applied). This assumes ``feature_embedding``
            yields ``num_fields * embed_dim`` values per row of ``x``.
        """
        # Flatten per-field embeddings into one vector per example. Unlike
        # view, reshape also accepts a non-contiguous embedding output.
        vx = self.feature_embedding(x).reshape(-1, self.embed_dim * self.num_fields)

        concat_out = torch.cat([self.mlp(vx), self.cross(vx)], dim=1)
        return self.out(self.dropout(concat_out)).squeeze(1)


