import copy
from operator import index
from tokenize import group
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os, sys
try:
    from .layers import *
except ImportError:
    # Not imported as part of a package (e.g. run as a script or from inside the
    # model directory): fall back to a plain ``layers`` import via sys.path.
    # Only ImportError is caught so real errors raised inside layers.py surface.
    sys.path.insert(0, '..')
    sys.path.insert(0, '../model')
    from layers import *


def get_triu_indices(n, diag_offset=1):
    """Return ``(rows, cols)`` index arrays for the upper triangle of an ``(n, n)`` array.

    :param n:           side length of the square array
    :param diag_offset: diagonal offset forwarded to ``np.triu_indices``; the default
                        of 1 skips the main diagonal
    :return:            tuple ``(rows, cols)`` of integer ndarrays, in row-major order
    """
    rows, cols = np.triu_indices(n, k=diag_offset)
    return rows, cols


class InnnerProduct(torch.nn.Module):
    """Pairwise inner products between every distinct pair of field embeddings.

    For each field pair ``(i, j)`` with ``i < j``, taken in row-major
    upper-triangle order, the output holds ``sum_e x[b, i, e] * x[b, j, e]``.
    """

    def forward(self, x):
        """
        :param x:   FloatTensor B*F*E (batch, fields, embedding dim)
        :return:    FloatTensor B*(Fx(F-1)/2), one inner product per unordered field pair
        """
        nfield = x.size(1)
        # Build the (i < j) pair indices directly on x's device. This avoids a
        # numpy -> tensor conversion and host-to-device copy on every forward.
        # The ordering is row-major, identical to np.triu_indices(nfield, 1).
        vi_indices, vj_indices = torch.triu_indices(nfield, nfield, offset=1, device=x.device)
        vi, vj = x[:, vi_indices], x[:, vj_indices]             # B*(Fx(F-1)/2)*E
        inner_product = torch.sum(vi * vj, dim=-1)              # B*(Fx(F-1)/2)
        return inner_product

class PNN(torch.nn.Module):
    """Product-based neural network over categorical fields.

    Field embeddings are flattened and concatenated with their pairwise inner
    products (``InnnerProduct``). The result goes through an MLP, dropout and a
    final linear layer, producing one scalar per sample.
    """

    def __init__(self, field_dims, embed_dim=20, dnn_layers=None, dropout=0.5, use_bn=False, use_mpn=False, use_topk=False):
        """
        :param field_dims:  sequence with one entry per field; its length is the number
                            of fields (passed to ``FeaturesEmbedding`` - presumably
                            per-field vocabulary sizes, confirm in layers.py)
        :param embed_dim:   embedding size per field
        :param dnn_layers:  hidden sizes for the MLP; defaults to ``[200, 200, 200]``
        :param dropout:     dropout probability applied to the MLP output
        :param use_bn:      forwarded to ``MLP``
        :param use_mpn:     if True, also builds the ``mpn_1`` / ``mpn_2`` modules
        :param use_topk:    forwarded to ``MessagePassing`` when ``use_mpn`` is True
        """
        super().__init__()
        # Use a fresh list per instance. A shared mutable default would be exposed
        # as ``self.layers``, and any in-place edit would leak into later models.
        if dnn_layers is None:
            dnn_layers = [200, 200, 200]
        self.field_dims  = field_dims
        self.num_fields  = len(field_dims)
        self.embed_dim   = embed_dim
        self.layers      = dnn_layers

        # MLP input = F*(F-1)/2 pairwise inner products + F*E flattened embeddings.
        num_pairs = self.num_fields * (self.num_fields - 1) // 2
        self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim) # use one embedding matrix to include both U and V
        self.pnn               = InnnerProduct()
        self.mlp               = MLP(num_pairs + embed_dim * self.num_fields, dnn_layers, use_bn=use_bn)
        self.concat_dims       = concat_dims = self.layers[-1]
        self.dropout           = torch.nn.Dropout(p=dropout)

        self.out = nn.Linear(concat_dims, 1)
        # MPN setting
        self.use_mpn = use_mpn
        if use_mpn:
            # NOTE(review): mpn_1 / mpn_2 are registered as submodules, but forward()
            # never calls them. Confirm whether MPN is meant to be wired in here.
            self.mpn_1 = MessagePassing(dnn_layers[-1], dnn_layers[-1], use_topk=use_topk)
            self.mpn_2 = nn.Linear(concat_dims, 1)

    def forward(self, x, training=False):
        """
        :param x:        Long tensor of size ``(batch_size, num_fields)``
        :param training: unused, kept for call compatibility; dropout follows the
                         module's own mode (``model.train()`` / ``model.eval()``)
        :return:         tensor of size ``(batch_size,)`` holding raw scores
                         (no sigmoid is applied here)
        """
        vx      = self.feature_embedding(x)
        vx_prod = self.pnn(vx)                                            # B*(F(F-1)/2)
        # reshape (not view) so a non-contiguous embedding output does not raise.
        vx      = vx.reshape((-1, self.embed_dim * self.num_fields))      # B*(F*E)
        mlp_out = self.mlp(torch.cat([vx_prod, vx], dim=-1))
        vx      = self.dropout(mlp_out)
        vx      = self.out(vx).squeeze(1)
        return vx

# class FieldGroup(nn.Module):

#     def __init__(self, dim, field_index=4):
#         super().__init__()
#         self.field_index = field_index
#         self.dim = dim
#         self.register_buffer('padding_fea', torch.zeros((1,dim), dtype=torch.float32))
#         self.register_buffer('padding_ind', torch.zeros(1, dtype=torch.long))
#         self.register_buffer('range', torch.arange(10000))

#         self.register_buffer('zero', torch.zeros(1, dtype=torch.float32))
#         self.register_buffer('neg', torch.ones(1,dtype=torch.float32) * -1.e-9)

#     def map_idx_to_index(self, unique_out, counts):
#         max_counts = counts.max()

#         counts_range = self.range[:max_counts] + 1
#         mask_counts = torch.le(counts_range, counts.view(-1, 1))
#         counts_repeat = counts_range.repeat((unique_out.size(0), 1))
#         cumsum_counts = torch.cumsum(counts, dim=0).view(-1, 1)
#         counts_repeat_new = counts_repeat[1:, :] + cumsum_counts[:-1, :]
#         counts_repeat_new = torch.cat([counts_repeat[0:1], counts_repeat_new], axis=0)
#         counts_repeat_new *= mask_counts.int()
#         non_zero_index_map = torch.where(torch.reshape(counts_repeat_new, [-1]))
#         return counts_repeat_new.view(-1), non_zero_index_map[0]
    
#     def forward(self, x, features, y=None):
#         dim = features.size(-1)
        
#         fields = x[:, self.field_index]
#         sorted_fields, indices = torch.sort(fields)
        
#         unique_out, counts = torch.unique(fields, return_counts=True)
#         group_index, index_map = self.map_idx_to_index(unique_out, counts)
#         group_bias = torch.where(torch.eq(group_index, 0.), self.neg, self.zero)
        
#         padded_features = torch.cat([self.padding_fea, features], dim=0)
#         padded_indice = torch.cat([self.padding_ind, indices + 1], dim=0)
        
#         out_indice = padded_indice[group_index]
#         out_features = padded_features[out_indice].view(-1, counts.max(), dim)
        
#         if y is not None:
#             padding_y = torch.cat([self.padding_ind, y], dim=0)
#             out_y = padding_y[out_indice]
        
#             return out_features, out_y, group_bias.view(-1, counts.max()).unsqueeze(1).unsqueeze(1), index_map
#         else:
#             return out_features, group_bias.view(-1, counts.max()).unsqueeze(1).unsqueeze(1), index_map

# class SelfAttention(nn.Module):

#     def __init__(self, input_size, size, head=4):
#         super().__init__()
#         self.input_size = input_size
#         self.size = size
#         self.head = head

#         self.in_size = in_size = input_size // head
#         self.out_size = out_size = size // head
#         self.q = nn.Linear(in_size, out_size)
#         self.k = nn.Linear(in_size, out_size)
#         self.v = nn.Linear(in_size, out_size)
#         self.softmax = nn.Softmax(dim=-1)  

#         # self.norm = nn.LayerNorm(size)  
#         self.relu = nn.ReLU(inplace=True)

#     def forward(self, x, bias=None):
#         b, length, c = x.size()
#         x = x.view(b, length, self.head, self.in_size)
#         x = x.transpose(1,2)
#         q = self.q(x)
#         k = self.k(x)   
#         v = self.v(x)

#         qk = torch.einsum('bhkc, bhcg->bhkg',q, k.transpose(2,3))
#         qk *= self.out_size ** -0.5
#         if bias is not None:
#             qk += bias
#         qk = self.softmax(qk)
#         out = torch.einsum('bhkg, bhgc->bhkc',qk, v)
#         out = out.transpose(1,2) 
#         out = torch.reshape(out, (-1, length, self.size))
#         # out = self.norm(out)
#         out = self.relu(out)
#         return out

# class FGSelfAttention(nn.Module):

#     def __init__(self, input_size, size, field_index=4) -> None:
#         super().__init__()
#         self.field_index = field_index
#         self.input_size  = input_size
#         self.size        = size
#         self.field_group = FieldGroup(input_size,field_index)
#         self.attention   = SelfAttention(input_size, size, head=2)
    
#     def forward(self, x, features, y=None):
#         if y is not None:
#             out_features, out_y, group_bias, index_map = self.field_group(x, features, y)
#         else:
#             out_features, group_bias, index_map = self.field_group(x, features)

#         attention_output = self.attention(out_features, group_bias)
#         attention_output = attention_output.view(-1, self.size)

#         attention_output = attention_output[index_map]
#         if y is not None:
#             return attention_output, out_y[index_map]
#         return attention_output

# class FGDNN(nn.Module):

#     def __init__(self, field_dims, embed_dim=20, dnn_layers=[200, 200, 200], dropout=0.5, use_bn=True, field_index=2):
#         super().__init__()
#         self.field_dims  = field_dims
#         self.num_fields  = len(field_dims)
#         self.embed_dim   = embed_dim
#         self.layers      = dnn_layers
#         self.field_index = field_index

#         self.feature_embedding = FeaturesEmbedding(field_dims, self.embed_dim) # use one embedding matrix to include both U and V
#         self.mlp               = MLP(embed_dim * self.num_fields, dnn_layers, use_bn=use_bn)
#         self.concat_dims       = concat_dims = self.layers[-1]
#         self.out               = nn.Linear(concat_dims, 1)
#         self.dropout           = torch.nn.Dropout(p=dropout)

#         self.fg = FGSelfAttention(concat_dims, concat_dims // 2, field_index=0)
#         self.fg_out = nn.Linear(concat_dims//2, 1)
        
        
#     def forward(self, x, training=False, y=None):
#         """
#         :param x: Long tensor of size ``(batch_size, num_fields)``
#         """ 
#         vx      = self.feature_embedding(x)
#         vx      = vx.view((-1, self.embed_dim * self.num_fields))
#         mlp_out = self.mlp(vx)
#         vx      = self.dropout(mlp_out)
#         vx      = self.out(vx)

#         if y is not None:
#             vg, vy = self.fg(x, mlp_out, y.long())
#             vg = self.fg_out(vg)
#             return vx.squeeze(1), vg.squeeze(1), vy
#         else:
#             vg = self.fg(x, mlp_out)
#             vg = self.fg_out(vg)
#             return  vx.squeeze(1), vg.squeeze(1)


