import copy
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, sys
try:
    from .layers import *
except:
    sys.path.insert(0, '..')
    sys.path.insert(0, '../model')
    from layers import *


class FM(torch.nn.Module):
    """Factorization-machine scorer with an optional message-passing head.

    The score is the sum of three terms: the output of ``FactorizationMachine``
    applied to the field embeddings, the sum of per-feature scalar (linear)
    embeddings, and a learned scalar bias. When ``use_mpn`` is set, a
    ``MessagePassing`` layer followed by a ``Linear(1, 1)`` is built and used
    only on calls made with ``training=True``.

    ``dnn_layers`` and ``dropout`` are accepted and stored/ignored respectively;
    this class builds no MLP or dropout layer from them.
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=[200, 200, 200], dropout=0.5, use_mpn=False, use_topk=False):
        super().__init__()
        # Plain (non-module) configuration attributes.
        self.field_dims = field_dims
        self.num_fields = len(field_dims)
        self.embed_dim = embed_dim
        self.layers = dnn_layers
        self.final_layer = False
        self.use_mpn = use_mpn

        # One embedding matrix holds the factors of every field (both U and V).
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)
        # Width-1 embedding: one scalar weight per feature for the linear term.
        self.linear_embedding = FeaturesEmbedding(field_dims, 1)
        self.bias = nn.Parameter(torch.zeros(1))

        self.fm = FactorizationMachine()

        # Message-passing head is only created on request.
        if use_mpn:
            self.mpn_1 = MessagePassing(1, 1, head=1, use_topk=use_topk)
            self.mpn_2 = nn.Linear(1, 1)

    def forward(self, x, training=False):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: when truthy and ``use_mpn`` is set, also return the
            message-passing output
        :return: ``out``, or the tuple ``(out, mpn_out)`` when both ``use_mpn``
            and ``training`` are truthy
        """
        # Each embedding receives its own copy of the indices so the caller's
        # tensor is never touched by the lookup.
        factors = self.feature_embedding(x.clone())
        weights = self.linear_embedding(x.clone())
        interaction = self.fm(factors)

        out = interaction + weights.sum(dim=[1, 2]) + self.bias

        if not (self.use_mpn and training):
            return out

        normalized = F.normalize(out.unsqueeze(1))
        mpn_out = self.mpn_2(self.mpn_1(normalized)).squeeze(1)
        return (out, mpn_out)


class DeepFM(torch.nn.Module):
    """Model combining an MLP over field embeddings with a factorization-machine term.

    Both branches read the same ``FeaturesEmbedding`` table. The MLP branch
    flattens the per-field embeddings, runs them through ``MLP``, applies
    dropout and projects to one value with a linear layer. The FM branch feeds
    the un-flattened embeddings to ``FactorizationMachine(reduce_dim=True)``.
    The two results and a learned scalar bias are summed.

    :param field_dims: sequence with the number of distinct values per field
    :param embed_dim: width of each field embedding
    :param dnn_layers: hidden sizes handed to ``MLP``; the last entry is the
        input width of the final linear projection
    :param dropout: drop probability applied to the MLP output
    :param use_bn: forwarded to ``MLP`` as ``use_bn``
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=[200, 200, 200], dropout=0.5, use_bn=False):
        super().__init__()
        self.field_dims  = field_dims
        self.num_fields  = len(field_dims)
        self.embed_dim   = embed_dim
        # Copy so this instance never aliases the shared default list (or a
        # list the caller keeps using).
        self.layers      = list(dnn_layers)
        self.final_layer = False
        self.use_bn      = use_bn

        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim) # use one embedding matrix to include both U and V 

        self.mlp = MLP(embed_dim * self.num_fields, self.layers, use_bn=use_bn)
        self.out = nn.Linear(self.layers[-1], 1)

        self.fm  = FactorizationMachine(reduce_dim=True)
        self.bias = nn.Parameter(torch.zeros((1,)))
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x, training=False):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: accepted so the call signature matches ``FM.forward``;
            not read here
        :return: the summed MLP, FM and bias terms
            (NOTE(review): presumably shape ``(batch_size,)`` -- depends on what
            ``FactorizationMachine(reduce_dim=True)`` returns; confirm)
        """
        # The embedding lookup in this file adds its offsets to its argument
        # in place, so give it a copy (as FM.forward does) to leave the
        # caller's index tensor untouched.
        vx = self.feature_embedding(x.clone())

        mlp_out = self.mlp(vx.view((-1, self.embed_dim * self.num_fields)))
        mlp_out = self.dropout(mlp_out)
        fm_out = self.fm(vx)

        out = self.out(mlp_out).squeeze(1)
        # Out-of-place add: avoids writing into the squeezed view of the
        # linear layer's output.
        return out + (fm_out + self.bias)


class FeaturesEmbedding(torch.nn.Module):
    """One embedding table shared by every field of a multi-field index input.

    Field ``i`` owns a contiguous run of ``field_dims[i]`` rows inside a single
    ``torch.nn.Embedding`` with ``sum(field_dims)`` rows. ``offsets[i]`` is the
    first row of field ``i``; it is added to the per-field index before lookup.

    :param field_dims: sequence with the number of distinct values per field
    :param embed_dim: width of each embedding vector
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = torch.nn.Embedding(int(sum(field_dims)), embed_dim)
        # Start row of each field: exclusive prefix sum of the field sizes.
        # np.int64 is used because the ``np.long`` alias no longer exists in
        # current NumPy releases; int64 is also what torch calls ``long``.
        offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        # Buffer (not parameter): moves with the module across devices and is
        # saved in the state dict, but is not trained.
        self.register_buffer('offsets', torch.from_numpy(offsets))
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        # Out-of-place add so the caller's index tensor is not modified;
        # ``offsets`` broadcasts along the batch dimension.
        return self.embedding(x + self.offsets)


if __name__ == '__main__':
    # Manual smoke run: build a model over four fields and print its output
    # for a two-row batch of indices.
    # NOTE(review): FeDPN is not defined in this file; presumably it is
    # supplied by the star import from ``layers`` -- confirm.
    field_sizes = [100, 200, 300, 400]
    model = FeDPN(field_sizes, 20)
    batch = torch.tensor([[0, 1, 2, 3], [2, 3, 4, 5]], dtype=torch.long)
    print(model(batch))
