import copy
from operator import index
from sre_constants import AT_BEGINNING
from tokenize import group
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, sys
try:
    from .layers import *
except:
    sys.path.insert(0, '..')
    sys.path.insert(0, '../model')
    from layers import *



class AutoInt(torch.nn.Module):
    """Field embeddings -> stacked self-attention layers -> linear output.

    Every categorical field is embedded into ``embed_dim`` dimensions, the
    per-field embeddings are passed through a stack of ``SelfAttention``
    layers, and the flattened result is projected to one value per sample.

    :param field_dims: sequence with the vocabulary size of each field
    :param embed_dim: embedding size of each field
    :param dnn_layers: stored on ``self.layers``; not used to build any
        sub-module in this class
    :param attention_layers: output width of each stacked attention layer;
        may be empty, in which case the embeddings feed the output directly
    :param head: number of attention heads in every attention layer
    :param dropout: dropout probability applied before the output layer
    :param use_bn: accepted for interface compatibility; not used here
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=(200, 200, 200),
                 attention_layers=(100, 100, 100), head=2, dropout=0.5, use_bn=False):
        super().__init__()
        # Copy the layer specs so that instances never alias the caller's
        # (or the default) sequence.
        dnn_layers = list(dnn_layers)
        attention_layers = list(attention_layers)

        self.field_dims = field_dims
        self.num_fields = len(field_dims)
        self.embed_dim = embed_dim
        self.layers = dnn_layers

        # One shared embedding table covering all fields.
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim)

        att_layers = []
        prev_layer = embed_dim
        for layer in attention_layers:
            att_layers.append(SelfAttention(prev_layer, layer, head))
            prev_layer = layer

        self.sa = nn.Sequential(*att_layers)
        self.dropout = torch.nn.Dropout(p=dropout)
        # ``prev_layer`` is the width of the last attention layer, or
        # ``embed_dim`` when the attention stack is empty.
        self.out = nn.Linear(prev_layer * self.num_fields, 1)

    def forward(self, x, training=False):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :param training: unused; dropout follows the module's own
            train/eval mode
        :return: tensor of size ``(batch_size,)``
        """
        vx = self.feature_embedding(x)
        vx = vx.reshape(-1, self.num_fields, self.embed_dim)
        out = self.sa(vx)
        out = torch.flatten(out, start_dim=1)
        out = self.dropout(out)
        out = self.out(out)
        return out.squeeze(1)


class SelfAttention(nn.Module):
    """Multi-head self-attention over the sequence axis with a linear shortcut.

    The last dimension of the input is split evenly across ``head`` heads.
    All heads share the same bias-free ``q``/``k``/``v`` projections, each
    mapping ``input_size // head`` features to ``size // head`` features.
    The concatenated head outputs are added to a linear projection of the
    input, optionally layer-normalised, and passed through ReLU.

    :param input_size: size of the last dimension of the input
    :param size: size of the last dimension of the output
    :param head: number of attention heads; must divide both sizes
    :param norm: apply ``LayerNorm`` before the final ReLU when true
    :raises ValueError: if ``head`` is not positive or does not divide
        ``input_size`` and ``size``
    """

    def __init__(self, input_size, size, head=2, norm=True):
        super().__init__()
        # Floor division would silently drop features and the layer would
        # then fail with an opaque shape error in forward(); fail early.
        if head <= 0 or input_size % head != 0 or size % head != 0:
            raise ValueError(
                "head must be positive and divide both input_size and size "
                "(got input_size={}, size={}, head={})".format(input_size, size, head)
            )
        self.input_size = input_size
        self.size = size
        self.head = head
        self.norm = norm

        self.in_size = in_size = input_size // head
        self.out_size = out_size = size // head
        self.q = nn.Linear(in_size, out_size, bias=False)
        self.k = nn.Linear(in_size, out_size, bias=False)
        self.v = nn.Linear(in_size, out_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)

        # Shortcut projection so the residual matches the output width.
        self.linear = nn.Linear(input_size, size)
        if norm:
            self.normalization = nn.LayerNorm(size)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: tensor of size ``(batch, length, input_size)``
        :return: tensor of size ``(batch, length, size)``
        """
        batch, length, _ = x.size()
        shortcut = self.linear(x)

        # (batch, length, input_size) -> (batch, head, length, in_size)
        heads = x.reshape(batch, length, self.head, self.in_size).transpose(1, 2)
        q = self.q(heads)
        k = self.k(heads)
        v = self.v(heads)

        # Scaled dot-product attention over the ``length`` axis.
        scores = torch.matmul(q, k.transpose(2, 3)) * self.out_size ** -0.5
        weights = self.softmax(scores)
        out = torch.matmul(weights, v)

        # (batch, head, length, out_size) -> (batch, length, size)
        out = out.transpose(1, 2).reshape(batch, length, self.size) + shortcut
        if self.norm:
            out = self.normalization(out)
        return self.relu(out)


class FeaturesEmbedding(torch.nn.Module):
    """Single embedding table shared by all categorical fields.

    Field ``i`` owns a contiguous slice of ``field_dims[i]`` rows in the
    table; per-field indices are shifted by the cumulative size of the
    preceding fields before the lookup.

    :param field_dims: sequence with the vocabulary size of each field
    :param embed_dim: size of each embedding vector
    """

    def __init__(self, field_dims, embed_dim):
        super().__init__()
        self.embedding = torch.nn.Embedding(int(sum(field_dims)), embed_dim)
        offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.int64)
        # A buffer follows the module across ``.to()`` / ``.cuda()`` instead of
        # being pinned to the GPU at construction time. It is non-persistent
        # so the state_dict layout stays the same as with a plain attribute.
        self.register_buffer("offsets", torch.from_numpy(offsets), persistent=False)
        torch.nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, x):
        """
        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: tensor of size ``(batch_size, num_fields, embed_dim)``
        """
        # Out-of-place add: the caller's index tensor must not be shifted,
        # otherwise reusing a batch would apply the offsets twice.
        return self.embedding(x + self.offsets)


