import collections
import math
import os
import time

import numpy as np
import pandas as pd
import scipy
import scipy.stats

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class RNN_recom_layer_one_layer(tf.keras.layers.Layer):
    """Encode an item-id history with one GRU and embed the target item.

    The last column of ``inputs`` is treated as the target item id and all
    preceding columns as the history. The history goes through
    ``Embedding -> GRU``. The target is embedded with the *same* embedding
    table and projected by a dense layer. The two encodings are
    concatenated along the last axis.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Embedding dimensionality.
        rnn_units: Number of GRU units (size of the history encoding).
        ffn_units: Output size of the dense projection of the target.
        emb_weights: Optional initial embedding matrix, passed to
            ``layers.Embedding`` as ``weights=[emb_weights]`` (presumably of
            shape ``(vocab_size, emb_dim)`` -- confirm with caller).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, vocab_size, emb_dim, rnn_units, ffn_units, emb_weights=None, **kwargs):
        super().__init__(**kwargs)

        self.emb_weights = emb_weights
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.rnn_units = rnn_units
        self.ffn_units = ffn_units

        # Pass `weights` only when pretrained embeddings were supplied; this
        # is the only difference between the "pretrained" and "random init"
        # cases.
        emb_kwargs = {}
        if emb_weights is not None:
            emb_kwargs["weights"] = [emb_weights]
        # mask_zero=True: id 0 is treated as padding and masked in the GRU.
        self.emb_layer = layers.Embedding(
            self.vocab_size, self.emb_dim, mask_zero=True, **emb_kwargs
        )
        # The embedding layer is shared with the target branch in `call`.
        self.model = tf.keras.Sequential([
            self.emb_layer,
            layers.GRU(self.rnn_units),
        ])

        self.ffn = layers.Dense(ffn_units)

    def call(self, inputs):
        """Return ``concat([GRU(embed(history)), ffn(embed(target))], -1)``.

        Args:
            inputs: Tensor of item ids, rank >= 2. ``inputs[:, :-1]`` is the
                history and ``inputs[:, -1]`` the target id.

        Returns:
            Tensor whose last axis is ``rnn_units + ffn_units`` wide.
        """
        seq = inputs[:, :-1]
        target = inputs[:, -1]

        # GRU has return_sequences=False, so this is its final output.
        seq_emb = self.model(seq)
        target_emb = self.ffn(self.emb_layer(target))

        return tf.concat((seq_emb, target_emb), axis=-1)
    

class RNN_recom_layer_two_layers(tf.keras.layers.Layer):
    """Encode an item-id history with two stacked GRUs and embed the target.

    Same contract as the one-layer variant, except that the history goes
    through ``Embedding -> GRU(return_sequences=True) -> GRU``. The target
    (last column of ``inputs``) is embedded with the shared embedding table
    and projected by a dense layer. The two encodings are concatenated
    along the last axis.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Embedding dimensionality.
        rnn_units: Number of units in each GRU.
        ffn_units: Output size of the dense projection of the target.
        emb_weights: Optional initial embedding matrix, passed to
            ``layers.Embedding`` as ``weights=[emb_weights]`` (presumably of
            shape ``(vocab_size, emb_dim)`` -- confirm with caller).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, vocab_size, emb_dim, rnn_units, ffn_units, emb_weights=None, **kwargs):
        super().__init__(**kwargs)

        self.emb_weights = emb_weights
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.rnn_units = rnn_units
        self.ffn_units = ffn_units

        # Pass `weights` only when pretrained embeddings were supplied; this
        # is the only difference between the "pretrained" and "random init"
        # cases.
        emb_kwargs = {}
        if emb_weights is not None:
            emb_kwargs["weights"] = [emb_weights]
        # mask_zero=True: id 0 is treated as padding and masked in the GRUs.
        self.emb_layer = layers.Embedding(
            self.vocab_size, self.emb_dim, mask_zero=True, **emb_kwargs
        )
        # The first GRU returns the full sequence so the second can consume
        # it; the embedding layer is shared with the target branch in `call`.
        self.model = tf.keras.Sequential([
            self.emb_layer,
            layers.GRU(self.rnn_units, return_sequences=True),
            layers.GRU(self.rnn_units),
        ])

        self.ffn = layers.Dense(ffn_units)

    def call(self, inputs):
        """Return ``concat([GRUs(embed(history)), ffn(embed(target))], -1)``.

        Args:
            inputs: Tensor of item ids, rank >= 2. ``inputs[:, :-1]`` is the
                history and ``inputs[:, -1]`` the target id.

        Returns:
            Tensor whose last axis is ``rnn_units + ffn_units`` wide.
        """
        seq = inputs[:, :-1]
        target = inputs[:, -1]

        # Final output of the second GRU (return_sequences=False).
        seq_emb = self.model(seq)
        target_emb = self.ffn(self.emb_layer(target))

        return tf.concat((seq_emb, target_emb), axis=-1)
    
    
class tRNN_recom_layer_one_layer(tf.keras.layers.Layer):
    """Encode a time-annotated item history with a ``TimeRNNCell`` RNN.

    Each history step's item embedding is concatenated with its per-step
    value from the second input row (called ``time`` here). The result is
    run through ``layers.RNN(TimeRNNCell(...))``, with padding (id 0)
    masked. The target item is embedded with the same table and projected
    by a dense layer. The two encodings are concatenated along the last
    axis.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Embedding dimensionality.
        rnn_units: Passed to ``TimeRNNCell`` as ``output_dim``.
        ffn_units: Output size of the dense projection of the target.
        emb_weights: Optional initial embedding matrix, passed to
            ``layers.Embedding`` as ``weights=[emb_weights]``.
        name: Optional layer name.
        **kwargs: Forwarded to ``TimeRNNCell`` (defined elsewhere).
    """

    def __init__(self, vocab_size, emb_dim, rnn_units, ffn_units, emb_weights=None, name=None, **kwargs):
        # `name` must be passed by keyword. The first positional parameter of
        # `tf.keras.layers.Layer.__init__` is `trainable`, so the previous
        # positional call bound `name` to `trainable` (the default None made
        # the layer non-trainable) and never applied the name.
        super().__init__(name=name)

        self.emb_weights = emb_weights
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.rnn_units = rnn_units

        # Pass `weights` only when pretrained embeddings were supplied; this
        # is the only difference between the "pretrained" and "random init"
        # cases.
        emb_kwargs = {}
        if emb_weights is not None:
            emb_kwargs["weights"] = [emb_weights]
        self.emb_layer = layers.Embedding(
            self.vocab_size, self.emb_dim, mask_zero=True, name='emb_layer', **emb_kwargs
        )
        self.tRNN_layer = layers.RNN(TimeRNNCell(output_dim=rnn_units, name='trnncell', **kwargs))

        self.ffn = layers.Dense(ffn_units, name='ffn_dense')

    def call(self, inputs):
        """Encode the time-annotated history and the target item.

        Args:
            inputs: Tensor indexed as ``inputs[:, row, position]``. Row 0
                holds item ids, with the target id in the last position.
                Row 1 holds one value per history position; its last
                position is ignored. Because row 1 is concatenated with the
                embeddings, ``inputs`` must share the embedding dtype
                (presumably float32 -- confirm with caller).

        Returns:
            ``concat([rnn_out, ffn(embed(target))], axis=-1)``.
        """
        seq = inputs[:, 0, :-1]
        target = inputs[:, 0, -1]
        # Add a trailing feature axis so time can be appended to each step.
        time = tf.expand_dims(inputs[:, 1, :-1], axis=-1)

        seq_emb = self.emb_layer(seq)
        # The mask is computed explicitly because the concat below does not
        # carry the embedding's mask through to the RNN.
        mask = self.emb_layer.compute_mask(seq)
        seq_emb_time = tf.concat((seq_emb, time), axis=-1)

        seq_emb_out = self.tRNN_layer(seq_emb_time, mask=mask)
        target_emb = self.ffn(self.emb_layer(target))

        return tf.concat((seq_emb_out, target_emb), axis=-1)
    
    
class tRNN_recom_layer_two_layer(tf.keras.layers.Layer):
    """Two-stage encoder for a time-annotated item history.

    Each history embedding is concatenated with its per-step value from the
    second input row. The result then passes through
    ``RNN(BaseRNNCell) -> RNNOutputTransform -> GRU``, with padding (id 0)
    masked in both recurrent stages. The target item is embedded with the
    same table and projected by a dense layer. The two encodings are
    concatenated along the last axis.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Embedding dimensionality.
        rnn_units: Units for ``BaseRNNCell``, ``in_dim`` for
            ``RNNOutputTransform``, and units of the final GRU.
        ffn_units: Output size of the dense projection of the target.
        emb_weights: Optional initial embedding matrix, passed to
            ``layers.Embedding`` as ``weights=[emb_weights]``.
        cell_type: Forwarded to ``BaseRNNCell`` (defined elsewhere).
        **kwargs: Forwarded to ``RNNOutputTransform`` (defined elsewhere).
    """

    def __init__(self, vocab_size, emb_dim, rnn_units, ffn_units, emb_weights=None, cell_type='gru', **kwargs):
        super().__init__()

        self.emb_weights = emb_weights
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.rnn_units = rnn_units
        self.cell_type = cell_type

        # Pass `weights` only when pretrained embeddings were supplied; this
        # is the only difference between the "pretrained" and "random init"
        # cases.
        emb_kwargs = {}
        if emb_weights is not None:
            emb_kwargs["weights"] = [emb_weights]
        self.emb_layer = layers.Embedding(
            self.vocab_size, self.emb_dim, mask_zero=True, **emb_kwargs
        )
        # The first stage returns the full sequence so it can be transformed
        # and re-encoded by the second stage.
        self.rnn_layer1 = tf.keras.layers.RNN(BaseRNNCell(rnn_units, self.cell_type), return_sequences=True)
        self.transform_layer = RNNOutputTransform(in_dim=rnn_units, **kwargs)
        self.rnn_layer2 = layers.GRU(rnn_units)

        self.ffn = layers.Dense(ffn_units)

    def call(self, inputs):
        """Encode the time-annotated history and the target item.

        Args:
            inputs: Tensor indexed as ``inputs[:, row, position]``. Row 0
                holds item ids, with the target id in the last position.
                Row 1 holds one value per history position; its last
                position is ignored. Because row 1 is concatenated with the
                embeddings, ``inputs`` must share the embedding dtype
                (presumably float32 -- confirm with caller).

        Returns:
            ``concat([gru_out, ffn(embed(target))], axis=-1)``.
        """
        seq = inputs[:, 0, :-1]
        target = inputs[:, 0, -1]
        # Add a trailing feature axis so time can be appended to each step.
        time = tf.expand_dims(inputs[:, 1, :-1], axis=-1)

        seq_emb = self.emb_layer(seq)
        # The mask is computed once and passed explicitly to both recurrent
        # stages, since the concat/transform steps do not carry it along.
        mask = self.emb_layer.compute_mask(seq)
        seq_emb_time = tf.concat((seq_emb, time), axis=-1)

        seq_emb_out = self.rnn_layer1(seq_emb_time, mask=mask)
        seq_emb_out = self.transform_layer(seq_emb_out)
        seq_emb_out = self.rnn_layer2(seq_emb_out, mask=mask)
        target_emb = self.ffn(self.emb_layer(target))

        return tf.concat((seq_emb_out, target_emb), axis=-1)
    
    
    