import torch
import torch.nn as nn
import numpy as np

class FeaturesEmbedding(torch.nn.Module):
    """Embed multi-field categorical indices with a single shared table.

    The table has ``sum(field_dims)`` rows; each field owns a contiguous
    slice of it.  The ``offsets`` buffer stores the first row of every slice
    so that per-field indices can be shifted into table rows before lookup.
    """

    def __init__(self, field_dims, embed_dim):
        """
        :param field_dims: sequence of ints, number of distinct values per field
        :param embed_dim:  size of each embedding vector
        """
        super().__init__()
        self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim)
        # np.int64 is used explicitly: the ``np.long`` alias is not available
        # in every NumPy release and is not guaranteed to be 64-bit.
        offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        self.register_buffer('offsets', torch.from_numpy(offsets))

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return:  tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        # Out-of-place add: the caller's index tensor must not be modified,
        # otherwise re-using the same batch would shift the indices twice.
        return self.embedding(x + self.offsets)

class FactorizationMachine(nn.Module):
    """Pairwise interaction term: 0.5 * ((sum of fields)^2 - sum of fields^2)."""

    def __init__(self, reduce_dim=True):
        """
        :param reduce_dim: when true, the result is also summed over the
                           embedding axis
        """
        super().__init__()
        self.reduce_dim = reduce_dim

    def forward(self, x):
        """
        :param x:   FloatTensor B*F*E
        :return:    FloatTensor B if ``reduce_dim`` else B*E
        """
        field_total = x.sum(dim=1)                                  # B*E
        interaction = field_total.pow(2) - x.pow(2).sum(dim=1)      # B*E
        if not self.reduce_dim:
            return 0.5 * interaction                                # B*E
        return 0.5 * interaction.sum(dim=1)                         # B

class CrossNet(torch.nn.Module):
    """Stack of cross layers: ``x <- x0 * (w_l . x) + b_l + x``."""

    def __init__(self, input_size, nlayers):
        """
        :param input_size: width of the input vector
        :param nlayers:    number of cross layers
        """
        super().__init__()
        self.nlayers = nlayers
        # One bias-free projection to a scalar and one bias vector per layer.
        projections = [
            torch.nn.Linear(input_size, 1, bias=False) for _ in range(nlayers)
        ]
        biases = [
            torch.nn.Parameter(torch.zeros((input_size, ))) for _ in range(nlayers)
        ]
        self.w = torch.nn.ModuleList(projections)
        self.b = torch.nn.ParameterList(biases)

    def forward(self, x):
        """
        :param x:   FloatTensor B*(FxE)
        :return:    FloatTensor B*(FxE)
        """
        x0 = x
        crossed = x
        for depth in range(self.nlayers):
            projection, bias = self.w[depth], self.b[depth]
            # Scalar per sample (B*1) that rescales the original input.
            gate = projection(crossed)
            crossed = x0 * gate + bias + crossed
        return crossed


class MLP(nn.Module):
    """Feed-forward stack of ``Linear -> [BatchNorm1d] -> ReLU`` stages."""

    def __init__(self, input_size, sizes, use_bn=False):
        """
        :param input_size: width of the input
        :param sizes:      iterable with the output width of each stage
        :param use_bn:     insert ``BatchNorm1d`` after every linear layer
        """
        super().__init__()
        stack = []
        width = input_size
        for hidden in sizes:
            stage = [nn.Linear(width, hidden)]
            if use_bn:
                stage.append(nn.BatchNorm1d(hidden))
            stage.append(nn.ReLU())
            stack.extend(stage)
            # The next stage consumes what this one produces.
            width = hidden
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """
        :param x:   FloatTensor B*ninput
        :return:    FloatTensor B*noutput
        """
        return self.mlp(x)



class LabelMessagePassing(nn.Module):
    """Multi-head attention across the samples of a batch, split by label.

    Every sample attends separately to the positive samples (label 1) and to
    the negative samples (label 0) of the same batch; the two attended values
    are added and normalised.
    """

    def __init__(self, input_size, size, head=4, use_qkv=False, use_bn=False):
        """
        :param input_size: width of the input features (split over ``head``)
        :param size:       width of the output features (split over ``head``)
        :param head:       number of attention heads
        :param use_qkv:    project inputs with learned q/k/v layers; when
                           false the raw per-head input is used as q, k and v,
                           so the final reshape to ``size`` only keeps the
                           batch dimension if ``input_size == size``
        :param use_bn:     normalise with ``BatchNorm1d`` instead of ``LayerNorm``
        """
        super().__init__()
        self.input_size = input_size
        self.size = size
        self.head = head
        self.use_bn = use_bn
        self.use_qkv = use_qkv

        self.in_size = in_size = input_size // head
        self.out_size = out_size = size // head
        if self.use_qkv:
            self.q = nn.Linear(in_size, out_size)
            self.k = nn.Linear(in_size, out_size)
            self.v = nn.Linear(in_size, out_size)

        if use_bn:
            self.norm = nn.BatchNorm1d(size)
        else:
            self.norm = nn.LayerNorm(size)

        self.softmax = nn.Softmax(dim=-1)

    def padding_bias(self, label):
        """Build label masks and the additive attention biases derived from them.

        :param label: tensor of binary labels (1 = positive, 0 = negative)
        :return: ``(positive_w, positive_b, negative_w, negative_b)`` where the
                 ``*_w`` entries are boolean masks and the ``*_b`` entries are
                 ``-1e9`` at the positions that must be hidden from the
                 positive / negative attention respectively, ``0`` elsewhere
        """
        positive_w = label == 1
        negative_w = label == 0

        positive_b = -1e9 * negative_w.float()
        negative_b = -1e9 * positive_w.float()
        return positive_w, positive_b, negative_w, negative_b

    def forward(self, x, y=None):
        """
        :param x: FloatTensor viewed as ``(B, head, input_size // head)``
        :param y: binary labels, one per sample (``B`` values in total)
        :return:  FloatTensor reshaped to ``(-1, size)``
        :raises ValueError: if ``y`` is not given
        """
        if y is None:
            raise ValueError("LabelMessagePassing.forward requires labels `y`")

        _, pb, _, nb = self.padding_bias(y)
        x = x.view(-1, self.head, self.in_size)
        x = x.transpose(0, 1)                               # head, B, in_size
        if self.use_qkv:
            q = self.q(x)
            k = self.k(x)
            v = self.v(x)
        else:
            q = k = v = x

        qk = torch.bmm(q, k.transpose(1, 2))                # head, B, B
        scale = self.out_size ** -0.5
        # The bias is laid out along the key axis so that each query only sees
        # keys of the wanted label.  Indexing a single element here would add
        # the same constant to every logit and mask nothing.
        pos_qk = self.softmax((qk + pb.reshape(1, 1, -1)) * scale)
        neg_qk = self.softmax((qk + nb.reshape(1, 1, -1)) * scale)

        out = torch.bmm(pos_qk, v) + torch.bmm(neg_qk, v)
        out = out.transpose(0, 1)
        out = torch.reshape(out, (-1, self.size))
        return self.norm(out)

class BatchDPO(nn.Module):
    """Linear layer whose weight and bias are generated from the batch mean.

    The mean of the batch is mapped to ``routes`` coefficients, which are
    linearly expanded into an ``(input_size, size)`` weight matrix and a
    ``size`` bias applied to every sample of that batch.
    """

    def __init__(self, input_size, size, routes=4, use_bn=True, use_sigmoid=True):
        """
        :param input_size:  width of the input features
        :param size:        width of the output features
        :param routes:      number of routing coefficients
        :param use_bn:      normalise with ``BatchNorm1d`` instead of ``LayerNorm``
        :param use_sigmoid: build ``Sigmoid`` (else ``Softmax``) as ``self.act``
        """
        super().__init__()
        self.input_size = input_size
        self.size = size
        self.routes = routes

        self.cond_layer = nn.Linear(input_size, routes)
        self.lw = nn.Linear(routes, input_size * size, bias=False)
        self.lb = nn.Linear(routes, size, bias=False)

        self.norm = nn.BatchNorm1d(size) if use_bn else nn.LayerNorm(size)
        # NOTE(review): ``act`` is constructed here but never applied in
        # ``forward`` - confirm whether the routing coefficients were meant to
        # pass through it.
        self.act = nn.Sigmoid() if use_sigmoid else nn.Softmax(dim=-1)
        self.relu = nn.ReLU(inplace=True)

        # Each column of the generator weights is one route; initialise it as
        # if it were a stand-alone (input_size, size) weight / (1, size) bias.
        weight_bank = self.lw.weight.data
        bias_bank = self.lb.weight.data
        for route in range(routes):
            torch.nn.init.xavier_uniform_(weight_bank[:, route].view(input_size, size))
            torch.nn.init.xavier_uniform_(bias_bank[:, route].view(1, size))

    def forward(self, x):
        """
        :param x: FloatTensor B*input_size
        :return:  FloatTensor B*size
        """
        # One routing vector for the whole batch: 1*routes.
        routing = self.cond_layer(x.mean(dim=0, keepdim=True))
        weight = self.lw(routing).view(self.input_size, self.size)
        bias = self.lb(routing)                             # 1*size

        projected = torch.mm(x, weight) + bias
        return self.relu(self.norm(projected))


class EntMinAttention(nn.Module):
    """Parameter-free gate based on each value's squared deviation from the
    mean taken over dim 1."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, lbd=0.5):
        """
        :param x:   FloatTensor B*F*C
        :param lbd: constant added to the variance in the denominator
        :return:    FloatTensor B*F*C, ``x`` scaled by a sigmoid gate
        """
        _, f, _ = x.size()
        # Guard the single-row case, which would otherwise divide 0 by 0.
        n = max(f - 1, 1)
        d = (x - x.mean(dim=1, keepdim=True)).pow(2)        # B*F*C
        # keepdim=True keeps the variance as B*1*C so it broadcasts over dim 1
        # of ``d``; without it the B*C result would be aligned against the
        # wrong axes.
        v = d.sum(dim=1, keepdim=True) / n                  # B*1*C
        E_inv = d / (4 * (v + lbd)) + 0.5
        return x * torch.sigmoid(E_inv)

class MetaMBR(nn.Module):
    """Residual block that mixes information along the batch axis.

    Only active when the incoming batch has exactly ``batch_size`` rows; any
    other batch is returned untouched.
    """

    def __init__(self, input_size, size, batch_size):
        """
        :param input_size: width of the input features
        :param size:       output width of the feature projection
        :param batch_size: number of rows the batch-axis layer is built for
        """
        super().__init__()

        self.input_size = input_size
        self.size = size
        self.batch_size = batch_size

        self.fc = nn.Linear(input_size, size)
        self.batch_fc = nn.Linear(batch_size, batch_size)
        self.relu = nn.ReLU(inplace=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        :param x: 2-D FloatTensor B*C
        :return:  ``relu(x + mixed)`` when ``B == batch_size``, else ``x`` itself
        """
        rows, _ = x.size()
        if rows != self.batch_size:
            # The batch-axis layer has a fixed width; skip incompatible batches.
            return x

        # Work on C*B so the linear layer and the softmax act across samples.
        mixed = self.batch_fc(x.transpose(0, 1))
        mixed = self.softmax(mixed)
        mixed = self.fc(mixed.transpose(0, 1))
        return self.relu(x + mixed)
        

class MessagePassing(nn.Module):
    """Multi-head attention across the samples of a batch with a residual
    projection, optionally restricted to the top-k scoring samples."""

    def __init__(self, input_size, size, k=1, use_topk=True, head=4, use_qkv=True, use_bn=False):
        """
        :param input_size: width of the input features (split over ``head``)
        :param size:       width of the output features (split over ``head``)
        :param k:          number of highest-scoring samples kept per query
        :param use_topk:   attend only to the ``k`` best samples
        :param head:       number of attention heads
        :param use_qkv:    stored on the module; not read by ``forward``
        :param use_bn:     stored on the module; not read by ``forward``
        """
        super().__init__()
        self.input_size = input_size
        self.size = size
        self.head = head
        self.use_bn = use_bn
        self.use_qkv = use_qkv
        self.use_topk = use_topk
        self.topk = k

        self.in_size = in_size = input_size // head
        self.out_size = out_size = size // head
        self.q = nn.Linear(in_size, out_size)
        self.k = nn.Linear(in_size, out_size)
        # With top-k the values are projected from full-width gathered rows,
        # otherwise from the per-head slices.
        value_width = input_size if use_topk else in_size
        self.v = nn.Linear(value_width, out_size)
        self.softmax = nn.Softmax(dim=-1)

        self.fc = nn.Linear(input_size, size)
        self.norm = nn.LayerNorm(size)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: 2-D FloatTensor B*input_size
        :return:  FloatTensor B*size
        """
        batch, width = x.size()
        residual = x.clone()
        heads = x.view(-1, self.head, self.in_size).transpose(0, 1)    # h, b, in_size
        query = self.q(heads)
        key = self.k(heads)

        scores = torch.bmm(query, key.transpose(1, 2))                  # h, b, b
        scale = self.out_size ** -0.5
        if self.use_topk:
            scores, neighbours = torch.topk(scores, self.topk, dim=-1)  # h, b, k
            # Fetch the full-width rows of the selected samples.
            gathered = residual[neighbours.view(-1)]
            gathered = gathered.view(self.head, batch, self.topk, width)
            value = self.v(gathered)                                    # h, b, k, out_size
            attn = self.softmax(scores * scale)                         # h, b, k
            mixed = (attn.unsqueeze(3) * value).mean(dim=2)
        else:
            value = self.v(heads)                                       # h, b, out_size
            attn = self.softmax(scores * scale)                         # h, b, b
            mixed = torch.bmm(attn, value)

        mixed = mixed.transpose(0, 1)
        mixed = torch.reshape(mixed, (-1, self.size))
        mixed = self.norm(mixed + self.fc(residual))
        return self.relu(mixed)

class FieldMessagePassing(torch.nn.Module):
    """Multi-head attention across the samples of a batch, restricted to
    samples that carry the same field id."""

    def __init__(self, input_size, size, head=4, use_qkv=True, use_bn=False):
        """
        :param input_size: width of the input features (split over ``head``)
        :param size:       width of the output features (split over ``head``)
        :param head:       number of attention heads
        :param use_qkv:    stored on the module; not read by ``forward``
        :param use_bn:     stored on the module; not read by ``forward``
        """
        super().__init__()
        self.input_size = input_size
        self.size = size
        self.head = head
        self.use_bn = use_bn
        self.use_qkv = use_qkv

        self.in_size = in_size = input_size // head
        self.out_size = out_size = size // head
        self.q = nn.Linear(in_size, out_size)
        self.k = nn.Linear(in_size, out_size)
        self.v = nn.Linear(in_size, out_size)
        self.softmax = nn.Softmax(dim=-1)
        self.norm = nn.LayerNorm(size)
        self.relu = nn.ReLU(inplace=True)

    def fields_sim(self, fields):
        """Additive attention bias that hides samples of a different field.

        :param fields: tensor with one field id per sample (``b`` values)
        :return:       float tensor ``[b, b]``: ``0`` where the two samples
                       share a field id, ``-1e9`` elsewhere
        """
        fields = fields.reshape(-1)
        # Row vector against column vector gives the full pairwise comparison;
        # comparing a column with itself would always be true and mask nothing.
        same_field = fields[None, :] == fields[:, None]     # [1, b] == [b, 1]
        bias = torch.zeros(same_field.shape, dtype=torch.float, device=fields.device)
        return bias.masked_fill(~same_field, -1e9)

    def forward(self, x, fields):
        """
        :param x:      FloatTensor viewed as ``(B, head, input_size // head)``
        :param fields: one field id per sample
        :return:       FloatTensor reshaped to ``(-1, size)``
        """
        x = x.view(-1, self.head, self.in_size)
        x = x.transpose(0, 1)                               # head, B, in_size
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # dot-product similarity
        qk = torch.bmm(q, k.transpose(1, 2))                # head, B, B
        qk = qk * self.out_size ** -0.5

        # Same mask for every head; a sample always matches itself, so no
        # row is ever fully masked.
        qk = qk + self.fields_sim(fields).unsqueeze(0)

        qk = self.softmax(qk)
        out = torch.bmm(qk, v)
        out = out.transpose(0, 1)
        out = torch.reshape(out, (-1, self.size))
        out = self.norm(out)
        out = self.relu(out)
        return out



