import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class FiDPO(torch.nn.Module):
    """Field-wise dynamic-parameter layer.

    For every sample a routing vector is computed from the input averaged
    over its last axis.  That vector linearly generates a per-sample
    ``(c_out, c_in)`` weight matrix and a ``c_out`` bias, which are applied to
    the input with a batched matmul.  A second, input-independent branch
    mixes the last axis with ``field_cross`` and projects dim 1 with
    ``shortcut``.  Both branches are summed, batch-normalised and passed
    through ReLU.

    Args:
        c_in: Size of dim 1 of the input.
        c_out: Size of dim 1 of the output.
        field_nums: Size of the last input dimension.
        head: Accepted for signature compatibility; not used.
        route_nums: Number of routing coefficients per sample.
        use_shortcut: Accepted for signature compatibility; not used (the
            static branch is always applied).
        use_bn: Accepted for signature compatibility; not used (batch norm
            is always applied).
        use_bias: Accepted for signature compatibility; not used.
    """

    def __init__(self, c_in, c_out, field_nums, head=1, route_nums=8, use_shortcut=False, use_bn=False, use_bias=True):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out

        # Router and the generators of the per-sample weight / bias.
        self.l1 = torch.nn.Linear(c_in, route_nums)
        self.lw = torch.nn.Linear(route_nums, c_in * c_out, bias=False)
        self.lb = torch.nn.Linear(route_nums, c_out, bias=False)
        self.bn = torch.nn.BatchNorm1d(c_out)
        self.relu = torch.nn.ReLU(inplace=True)
        self.sigmoid = torch.nn.Sigmoid()
        # Not used by forward(); kept so the attribute remains available.
        self.softmax = torch.nn.Softmax(dim=1)

        # Static branch.
        self.shortcut = torch.nn.Linear(c_in, c_out)
        self.field_cross = torch.nn.Linear(field_nums, field_nums, bias=False)

        # Column i of each generator holds the parameters contributed by
        # route i; re-initialise every column as its own Xavier matrix.  The
        # views share storage with the weights, so the init is in place.
        for i in range(route_nums):
            torch.nn.init.xavier_uniform_(self.lw.weight.data[:, i].view(c_in, c_out))
            torch.nn.init.xavier_uniform_(self.lb.weight.data[:, i].view(1, c_out))

    def forward(self, x):
        """Apply the dynamic and static branches.

        Args:
            x: Float tensor of shape ``(batch, c_in, field_nums)``.  It does
                not have to be contiguous.

        Returns:
            Tensor of shape ``(batch, c_out, field_nums)``.
        """
        # Per-sample routing coefficients from the mean over the last axis.
        routes = self.sigmoid(self.l1(x.mean(dim=2)))

        # Per-sample weight (batch, c_out, c_in) and bias (batch, c_out).
        w = self.lw(routes).view(-1, self.c_out, self.c_in)
        b = self.lb(routes)
        out = torch.bmm(w, x) + b.unsqueeze(2)

        # nn.Linear acts on the last dimension, so field_cross can be applied
        # to the 3-D tensor directly.  This avoids the flatten/unflatten
        # ``view`` round trip, which raised on non-contiguous inputs.
        field_x = self.field_cross(x).transpose(1, 2)       # (batch, field_nums, c_in)
        field_x = self.shortcut(field_x).transpose(1, 2)    # (batch, c_out, field_nums)

        out = out + field_x
        out = self.bn(out)
        return self.relu(out)

class FiDPN(torch.nn.Module):
    """Embedding -> stacked ``FiDPO`` layers -> dropout -> linear scoring head.

    Args:
        field_dims: Sequence with the cardinality of each categorical field.
        embed_dim: Embedding size per field; also the ``c_in`` of the first
            ``FiDPO`` layer.
        layers: Output channel count of each ``FiDPO`` layer, in order.
        dropout: Dropout probability applied before the final linear layer.
        log_ebd: Accepted for signature compatibility; not used.
    """

    def __init__(self, field_dims, embed_dim=20, layers=(200, 200, 200), dropout=0.5, log_ebd=False):
        super().__init__()
        self.field_dims = field_dims
        self.num_fields = len(field_dims)

        self.embed_dim = embed_dim
        # Private copy: the default is immutable and a caller-supplied list
        # is not aliased by the module.
        self.layers = list(layers)

        # One embedding matrix shared by all fields.
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)
        DPN_layers = []
        c_in = self.embed_dim
        for width in self.layers:
            DPN_layers.append(FiDPO(c_in, width, self.num_fields, route_nums=8))
            c_in = width
        self.out = nn.Linear(self.layers[-1] * self.num_fields, 1)
        self.DPN = torch.nn.ModuleList(DPN_layers)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x, training=False):
        """Score a batch of samples.

        Args:
            x: Long tensor of size ``(batch_size, num_fields)``.
            training: Accepted for signature compatibility; not used.
                Dropout and batch norm follow the module's train/eval mode.

        Returns:
            Tensor of size ``(batch_size,)`` with one raw score per sample.
        """
        vx = self.feature_embedding(x)  # (batch, num_fields, embed_dim)
        # FiDPO expects (batch, channels, num_fields).  The axes must be
        # swapped with transpose; a plain ``view`` to that shape only
        # reinterprets memory and interleaves fields with embedding
        # components.  NOTE: weights trained with the former ``view`` layout
        # will not reproduce their outputs after this fix.
        vx = vx.transpose(1, 2).contiguous()
        for layer in self.DPN:
            vx = layer(vx)

        vx = vx.view((-1, self.num_fields * self.layers[-1]))
        vx = self.dropout(vx)
        return self.out(vx).squeeze(1)

class FeDPO(torch.nn.Module):
    """Dynamic-parameter fully connected layer driven by a condition tensor.

    A routing vector is computed from ``cond_tensor`` and linearly generates a
    per-sample ``(c_in, c_out)`` weight matrix and ``c_out`` bias, which are
    applied to ``x``.  The result is batch-normalised and passed through ReLU.

    Args:
        c_in: Feature size of ``x``.
        c_out: Feature size of the output.
        cond_size: Feature size of ``cond_tensor``.
        head: Accepted but not used.
        route_nums: Number of routing coefficients per sample.
        use_shortcut: When true, add a static ``Linear(c_in, c_out)`` of
            ``x`` to the dynamic output.
        use_bn: Accepted but not used (batch norm is always applied).
        use_bias: Accepted but not used.
        act: ``'sigmoid'`` or ``'softmax'`` to squash the routing vector; any
            other value leaves it unactivated.
    """

    def __init__(self, c_in, c_out, cond_size, head=1, route_nums=8, use_shortcut=False, use_bn=False, use_bias=True, act='sigmoid'):
        super().__init__()
        self.use_shortcut = use_shortcut
        self.c_in, self.c_out = c_in, c_out

        # Router plus the generators of the per-sample weight and bias.
        self.l1 = torch.nn.Linear(cond_size, route_nums)
        self.lw = torch.nn.Linear(route_nums, c_in * c_out, bias=False)
        self.lb = torch.nn.Linear(route_nums, c_out, bias=False)
        self.bn = torch.nn.BatchNorm1d(c_out)
        self.relu = torch.nn.ReLU(inplace=True)

        # Pick the routing activation; unknown names mean "no activation".
        if act == 'softmax':
            gate = torch.nn.Softmax(dim=1)
        elif act == 'sigmoid':
            gate = torch.nn.Sigmoid()
        else:
            gate = None
        self.act = gate

        # Re-initialise, route by route, the column of each generator that
        # belongs to that route.  The views alias the weights' storage.
        for route in range(route_nums):
            weight_col = self.lw.weight.data.select(1, route)
            bias_col = self.lb.weight.data.select(1, route)
            torch.nn.init.xavier_uniform_(weight_col.view(c_in, c_out))
            torch.nn.init.xavier_uniform_(bias_col.view(1, c_out))

        if use_shortcut:
            self.shortcut = torch.nn.Linear(c_in, c_out)

    def forward(self, x, cond_tensor):
        """Transform ``x`` with parameters generated from ``cond_tensor``.

        Args:
            x: Tensor of shape ``(batch, c_in)``.
            cond_tensor: Tensor of shape ``(batch, cond_size)``.

        Returns:
            Tensor of shape ``(batch, c_out)``.
        """
        gate = self.l1(cond_tensor)
        if self.act is not None:
            gate = self.act(gate)

        kernel = self.lw(gate).view([-1, self.c_in, self.c_out])
        offset = self.lb(gate)
        # Per-sample vector-matrix product: (b, i) x (b, i, j) -> (b, j).
        out = torch.einsum('bi, bij->bj', x, kernel) + offset

        if self.use_shortcut:
            out += self.shortcut(x)

        return self.relu(self.bn(out))


               
class FeDPN(torch.nn.Module):
    """Embedding -> stacked ``FeDPO`` layers -> dropout -> linear scoring head.

    The field embeddings are flattened into one vector per sample, and every
    ``FeDPO`` layer is conditioned on its own input.

    Args:
        field_dims: Sequence with the cardinality of each categorical field.
        embed_dim: Embedding size per field.
        layers: Output width of each ``FeDPO`` layer, in order.
        dropout: Dropout probability applied before the final linear layer.
        log_ebd: Accepted for signature compatibility; not used.
    """

    def __init__(self, field_dims, embed_dim=20, layers=(300, 300), dropout=0.5, log_ebd=False):
        super().__init__()
        self.field_dims = field_dims
        self.num_fields = len(field_dims)

        self.embed_dim = embed_dim
        # Private copy: the default is immutable and a caller-supplied list
        # is not aliased by the module.
        self.layers = list(layers)

        # One embedding matrix shared by all fields.
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)
        DPN_layers = []
        c_in = self.num_fields * self.embed_dim
        for width in self.layers:
            # cond_size equals c_in because forward() feeds each layer its
            # own input as the condition tensor.
            DPN_layers.append(FeDPO(c_in, width, c_in, route_nums=4))
            c_in = width
        self.out = nn.Linear(self.layers[-1], 1)
        self.DPN = torch.nn.ModuleList(DPN_layers)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x, training=False):
        """Score a batch of samples.

        Args:
            x: Long tensor of size ``(batch_size, num_fields)``.
            training: Accepted for signature compatibility; not used.
                Dropout and batch norm follow the module's train/eval mode.

        Returns:
            Tensor of size ``(batch_size,)`` with one raw score per sample.
        """
        vx = self.feature_embedding(x)
        # Concatenate all field embeddings into one vector per sample.
        vx = vx.view((-1, self.num_fields * self.embed_dim))
        for layer in self.DPN:
            vx = layer(vx, vx)
        vx = self.dropout(vx)
        return self.out(vx).squeeze(1)

class FeaturesEmbedding(torch.nn.Module):
    """Single embedding table shared by all categorical fields.

    Each field owns a contiguous slice of the table.  ``offsets[i]`` is the
    first row of field ``i``, so a per-field index is turned into a table row
    by adding the field's offset.

    Args:
        field_dims: Sequence with the cardinality of each field.
        embed_dim: Size of every embedding vector.
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        # Explicit 64-bit dtype: the `np.long` alias is not a dependable
        # int64 across NumPy versions.
        offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        # A buffer follows the module across .to()/.cuda()/.cpu() instead of
        # being pinned to CUDA at construction time (the former
        # `.cuda(async=True)` is also a SyntaxError on Python >= 3.7).
        # persistent=False keeps it out of state_dict, as before.
        self.register_buffer('offsets', torch.from_numpy(offsets), persistent=False)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

    def forward(self, x):
        """Look up the embeddings of a batch of per-field indices.

        Args:
            x: Long tensor of size ``(batch_size, num_fields)``; it is not
                modified.

        Returns:
            Tensor of size ``(batch_size, num_fields, embed_dim)``.
        """
        # Out-of-place add: `x += offsets` would shift the caller's tensor and
        # accumulate offsets if the same batch were passed in again.
        return self.embedding(x + self.offsets)


if __name__ == '__main__':
    # Smoke run: build a four-field FeDPN and print its scores for a batch of two.
    model = FeDPN([100, 200, 300, 400], 20)
    batch = torch.tensor([[0, 1, 2, 3], [2, 3, 4, 5]], dtype=torch.long)
    print(model(batch))
