import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, sys
try:
    from .layers import *
except:
    sys.path.insert(0, '..')
    sys.path.insert(0, '../model')
    from layers import *




class FeMPN(torch.nn.Module):
    """Feature Message Passing Network.

    Embeds every field of the input, flattens the embeddings, runs them
    through an MLP and projects the result to one score per sample.  An
    auxiliary branch (``mpn_1`` followed by ``mpn_2``) produces a second
    score from the same MLP output when requested.

    :param field_dims: sequence with one entry per field; forwarded to
        ``FeaturesEmbedding`` and its length gives the number of fields
    :param embed_dim: embedding width used for each field
    :param dnn_layers: layer sizes handed to ``MLP``; the last entry is the
        width consumed by both output heads
    :param dropout: dropout probability applied before the main head
    :param use_bn: forwarded to ``MLP`` as ``use_bn``
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=[200, 200, 200], dropout=0.5, use_bn=True):
        super().__init__()
        hidden_width = dnn_layers[-1]

        # Plain bookkeeping attributes.
        self.field_dims = field_dims
        self.num_fields = len(field_dims)
        self.embed_dim = embed_dim
        self.layers = dnn_layers

        # One embedding matrix that includes both U and V.
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)

        # Main branch: MLP over the flattened embeddings, then a scalar head.
        self.mlp = MLP(embed_dim * self.num_fields, dnn_layers, use_bn=use_bn)
        self.concat_dims = hidden_width
        self.out = nn.Linear(hidden_width, 1)
        self.dropout = torch.nn.Dropout(p=dropout)

        # MPN branch: message passing on the MLP output, then a scalar head.
        self.mpn_1 = MessagePassing(hidden_width, hidden_width)
        self.mpn_2 = nn.Linear(hidden_width, 1)

    def forward(self, x, training=False):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: when true, also evaluate the MPN branch.  This is an
            explicit call argument; it is not read from ``self.training``.
        :return: the main-head output, or the tuple ``(main, mpn)`` when
            ``training`` is true
        """
        flat_width = self.embed_dim * self.num_fields
        embedded = self.feature_embedding(x).view((-1, flat_width))
        hidden = self.mlp(embedded)

        # Dropout is applied only on the way into the main head; the MPN
        # branch below consumes the undropped MLP output.
        main_score = self.out(self.dropout(hidden)).squeeze(1)
        if not training:
            return main_score

        mpn_score = self.mpn_2(self.mpn_1(hidden)).squeeze(1)
        return (main_score, mpn_score)

class FMMPN(torch.nn.Module):
    """Factorization-machine scorer with an auxiliary MPN branch.

    The main output adds three pieces: the result of ``FactorizationMachine``
    on the field embeddings, a learned bias, and the sum of a width-1
    embedding per field.  The MPN branch (``mpn_1`` followed by ``mpn_2``)
    post-processes that main output when requested.

    :param field_dims: sequence with one entry per field
    :param embed_dim: embedding width used for the factorization-machine input
    :param dnn_layers: stored as ``self.layers``; not otherwise used here
    :param dropout: accepted but not used by this class
    :param use_bn: accepted but not used by this class
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=[200, 200, 200], dropout=0.5, use_bn=False):
        super().__init__()
        # Plain bookkeeping attributes.
        self.field_dims = field_dims
        self.num_fields = len(field_dims)
        self.embed_dim = embed_dim
        self.layers = dnn_layers
        self.final_layer = False

        # One embedding matrix that includes both U and V, plus a width-1
        # embedding per feature and a single learned bias.
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)
        self.linear_embedding = FeaturesEmbedding(field_dims, 1)
        self.bias = nn.Parameter(torch.zeros((1,)))

        self.fm = FactorizationMachine()

        # MPN branch operating on the scalar score.
        self.mpn_1 = MessagePassing(1, 1, head=1)
        self.mpn_2 = nn.Linear(1, 1)

    def forward(self, x, training=False):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: when true, also evaluate the MPN branch.  This is an
            explicit call argument; it is not read from ``self.training``.
        :return: the main score, or the tuple ``(main, mpn)`` when
            ``training`` is true
        """
        # Each embedding layer receives its own copy of the indices.
        dense_emb = self.feature_embedding(x.clone())
        linear_emb = self.linear_embedding(x.clone())
        fm_term = self.fm(dense_emb)

        # The width-1 embeddings are summed over dims 1 and 2.
        main_score = fm_term + self.bias + torch.sum(linear_emb, dim=[1, 2])
        if not training:
            return main_score

        # The MPN branch sees the score with an extra dim inserted at index 1.
        mpn_hidden = self.mpn_1(main_score.unsqueeze(1))
        mpn_score = self.mpn_2(mpn_hidden).squeeze(1)
        return (main_score, mpn_score)


class FiMPN(torch.nn.Module):
    """Field Message Passing Network.

    Same main branch as an embedding + MLP scorer: embed every field,
    flatten, run the MLP and project to one score per sample.  The auxiliary
    branch passes the MLP output together with the first column of the input
    to ``FieldMessagePassing`` and projects the result with ``mpn_2``.

    :param field_dims: sequence with one entry per field
    :param embed_dim: embedding width used for each field
    :param dnn_layers: layer sizes handed to ``MLP``; the last entry is the
        width consumed by both output heads
    :param dropout: dropout probability applied before the main head
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=[200, 200, 200], dropout=0.5):
        super().__init__()
        last_width = dnn_layers[-1]
        n_fields = len(field_dims)

        # Plain bookkeeping attributes.
        self.field_dims = field_dims
        self.num_fields = n_fields
        self.embed_dim = embed_dim
        self.layers = dnn_layers

        # One embedding matrix that includes both U and V.
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)

        # Main branch.
        self.mlp = MLP(embed_dim * n_fields, dnn_layers)
        self.concat_dims = last_width
        self.out = nn.Linear(last_width, 1)
        self.dropout = torch.nn.Dropout(p=dropout)

        # MPN branch.
        self.mpn_1 = FieldMessagePassing(last_width, last_width)
        self.mpn_2 = nn.Linear(last_width, 1)

    def forward(self, x, training=False):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: when true, also evaluate the MPN branch.  This is an
            explicit call argument; it is not read from ``self.training``.
        :return: the main-head output, or the tuple ``(main, mpn)`` when
            ``training`` is true
        """
        embedded = self.feature_embedding(x)
        embedded = embedded.view((-1, self.embed_dim * self.num_fields))
        hidden = self.mlp(embedded)

        # Only the main head sees the dropped-out activations.
        dropped = self.dropout(hidden)
        main_score = self.out(dropped).squeeze(1)

        if training:
            # Column 0 of the input is handed to the field-level message
            # passing layer alongside the MLP output.
            first_field = x[:, 0]
            mpn_hidden = self.mpn_1(hidden, first_field)
            mpn_score = self.mpn_2(mpn_hidden).squeeze(1)
            return (main_score, mpn_score)
        return main_score

class FeaturesEmbedding(torch.nn.Module):
    """Single embedding table shared by all fields.

    Every field owns a contiguous slice of one ``torch.nn.Embedding`` with
    ``sum(field_dims)`` rows.  Per-field indices are shifted by the field's
    starting row (the exclusive cumulative sum of ``field_dims``) before the
    lookup.

    :param field_dims: sequence with the number of distinct values per field
    :param embed_dim: width of each embedding vector
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        # Starting row of each field: (0, d0, d0 + d1, ...).
        # np.int64 is used instead of np.long: that alias was removed in
        # NumPy 1.24 and, where it exists, its width is platform dependent.
        offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        # Registered as a buffer so it moves with the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer('offsets', torch.from_numpy(offsets))
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)`` holding
            per-field indices; it is left unmodified
        :return: tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        # Out-of-place add: shifting ``x`` in place would corrupt the caller's
        # tensor and make a second forward pass on it look up the wrong rows.
        shifted = x + self.offsets
        return self.embedding(shifted)


