from datetime import timedelta

from tqdm import tqdm
import pandas as pd
import numpy as np
import scipy
import scipy.stats

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras import layers

from time_layers import affine_coupling_layer, rand_feat_layer, TimeRNNCell, BaseRNNCell, RNNOutputTransform


class sin_cos_layer(layers.Layer):
    """Replace the trailing time channel with learned sine/cosine features.

    The last channel of the (rank-3) input is read as one time value per
    step. It is mapped to ``sin(w1 * t + b1)`` (``sin_dim`` values) and
    ``cos(w2 * t + b2)`` (``cos_dim`` values); both are appended after the
    remaining input channels.
    """

    def __init__(self, sin_dim, cos_dim, name=None):
        super().__init__(name=name)
        self.sin_dim = sin_dim
        self.cos_dim = cos_dim
        # Creation order (w1, w2, b1, b2) is kept so the layer's weight list
        # keeps the same ordering.
        self.w1 = self._vector_weight(sin_dim, 'w1')
        self.w2 = self._vector_weight(cos_dim, 'w2')
        self.b1 = self._vector_weight(sin_dim, 'b1')
        self.b2 = self._vector_weight(cos_dim, 'b2')

    def _vector_weight(self, size, label):
        """Create a trainable 1-D weight of length ``size``."""
        return self.add_weight(
            shape=[size], initializer='random_normal', trainable=True, name=label)

    def call(self, inputs):
        """Return ``inputs`` without its last channel, plus sin/cos features."""
        features = inputs[:, :, :-1]
        # Keep the channel axis (size 1) so the per-feature weights broadcast
        # across it.
        stamp = inputs[:, :, -1:]
        sin_part = tf.math.sin(
            tf.math.add(tf.math.multiply(stamp, self.w1), self.b1))
        cos_part = tf.math.cos(
            tf.math.add(tf.math.multiply(stamp, self.w2), self.b2))
        return tf.concat((features, sin_part, cos_part), axis=-1)


class Config:
    """Validated, read-only view over an experiment configuration dict.

    Every option is exposed as a property that falls back to a default when
    the key is absent. Options with a closed set of values raise
    ``ValueError`` when the configured value is not in that set.
    """

    def __init__(self, config_dict=None):
        # A missing or empty mapping is replaced by a fresh empty dict.
        self.config_dict = config_dict if config_dict else {}

    def _choice(self, key, default, valid, label):
        """Return ``config_dict[key]`` (or ``default``) if it is in ``valid``.

        Raises:
            ValueError: if the value is not one of ``valid``.
        """
        value = self.config_dict.get(key, default)
        if value not in valid:
            raise ValueError(f'{value} is not valid {label} {valid}')
        return value

    @property
    def data(self):
        """Dataset name: ``'weather'`` (default) or ``'wikipedia'``."""
        return self._choice('data', 'weather', ['weather', 'wikipedia'], 'data')

    @property
    def base_model(self):
        """Base architecture name; defaults to ``'CNN'``."""
        return self._choice(
            'base_model', 'CNN',
            ['CNN', 'RNN', 'Attention', 'FFT', 'VAR'], 'base model')

    @property
    def time_method(self):
        """Time-encoding method; defaults to ``'no-time'``.

        Read from the ``'time_method'`` key (it was previously read from
        ``'base_model'``, which made any configured base model raise here).
        """
        return self._choice(
            'time_method', 'no-time',
            ['no-time', 'time', 'trigo', 'deep-time', 'RF'], 'time method')

    @property
    def experiment_case(self):
        """Sampling scenario; defaults to ``'regular'``."""
        return self._choice(
            'experiment_case', 'regular',
            ['regular', 'rand_in', 'rand_out'], 'experiment case')

    @property
    def random_output(self):
        """Value of ``'random_output'``; defaults to ``False``."""
        return self.config_dict.get('random_output', False)

    @property
    def num_epoch(self):
        """Value of ``'num_epoch'``; defaults to 10."""
        return self.config_dict.get('num_epoch', 10)

    @property
    def num_repetition(self):
        """Value of ``'num_repetition'``; defaults to 10.

        Read from its own key (it was previously an alias of ``'num_epoch'``).
        """
        return self.config_dict.get('num_repetition', 10)

    @property
    def debug(self):
        """Value of ``'debug'``; defaults to ``False``."""
        return self.config_dict.get('debug', False)


def maptofloat32(*args):
    """Lazily convert each positional argument to a float32 NumPy array.

    Returns a generator yielding one array per argument, in order; callers
    typically unpack it directly.
    """
    def _converted():
        for values in args:
            yield np.array(values).astype(np.float32)

    return _converted()


def weather_regular(dataset, target, start_index, end_index, history_size,
                    target_size, step):
    """Build regularly strided history windows, each with one future label.

    For every position ``i`` in ``[start_index + history_size, end_index)``
    the window holds the ``dataset`` rows at ``i - history_size``,
    ``i - history_size + step``, ... (all below ``i``) and the label is
    ``target[i + target_size]``.

    Args:
        dataset: indexable feature array; rows are selected with a ``range``.
        target: indexable sequence the labels are read from.
        start_index: first usable row; the first window ends just before
            ``start_index + history_size``.
        end_index: exclusive upper bound for ``i``; ``None`` means
            ``len(dataset) - target_size``.
        history_size: number of rows spanned by each window before striding.
        target_size: offset from ``i`` to the label row.
        step: stride between rows kept inside a window.

    Returns:
        The ``maptofloat32`` result for ``(data, labels, data_time)``, i.e.
        three float32 arrays when unpacked.
    """
    data = []
    labels = []
    data_time = []
    start_index = start_index + history_size
    if end_index is None:
        end_index = len(dataset) - target_size

    for i in tqdm(range(start_index, end_index)):
        indices = range(i - history_size, i, step)
        data.append(dataset[indices])
        # Reversed row positions, shifted so the values decrease along the
        # window and the last one equals ``target_size``.
        data_time.append(
            np.array(indices)[::-1] - i + history_size + target_size)
        labels.append(target[i + target_size])

    return maptofloat32(data, labels, data_time)


def weather_randin(dataset, target, start_index, end_index, history_size, target_size, step):
    """Build irregularly sampled history windows, each with one future label.

    For every position ``i`` in ``[start_index + history_size, end_index)``,
    ``int(history_size / step)`` distinct rows are drawn from the
    ``history_size`` rows preceding ``i``. Sampling weights follow a normal
    density centred on a randomly chosen row of that range. The label is
    ``target[i + target_size]``.

    Returns:
        The ``maptofloat32`` result for ``(data, labels, data_time)``.
    """
    first = start_index + history_size
    last = len(dataset) - target_size if end_index is None else end_index
    n_keep = int(history_size * 1.0 / step)

    windows, targets, offsets = [], [], []
    for i in tqdm(range(first, last)):
        candidates = np.array(list(range(i - history_size, i)))
        centre = np.random.choice(candidates, 1)[0]
        weights = scipy.stats.norm.pdf((candidates - centre) * 0.05)
        weights = weights / np.sum(weights)
        picked = np.sort(np.random.choice(
            candidates, n_keep, replace=False, p=weights))
        windows.append(dataset[picked])
        # NOTE(review): the time values come from the reversed index array,
        # so the k-th time value is derived from the (n-1-k)-th sampled row.
        # For irregular samples that is not the k-th row's own offset --
        # confirm this pairing is intended.
        offsets.append(
            np.array(list(picked))[::-1] - i + history_size + target_size)
        targets.append(target[i + target_size])
    return maptofloat32(windows, targets, offsets)


def weather_randout(dataset, target, start_index, end_index, history_size,
                    num_target, min_target, max_target, step):
    """Build regular history windows paired with randomly chosen horizons.

    For every position ``i`` in ``[start_index + history_size, end_index)``,
    ``num_target`` distinct horizons are drawn from
    ``range(min_target, max_target)``. Each horizon yields one sample: the
    strided window before ``i``, the label ``target[i + horizon]``, and time
    values that decrease along the window down to ``horizon``.

    Returns:
        The ``maptofloat32`` result for ``(data, labels, data_time)``.
    """
    first = start_index + history_size
    last = end_index if end_index is not None else len(dataset) - max_target
    horizon_pool = list(range(min_target, max_target))

    windows, targets, offsets = [], [], []
    for i in tqdm(range(first, last)):
        horizons = np.random.choice(horizon_pool, num_target, replace=False)
        history = range(i - history_size, i, step)
        for horizon in horizons:
            windows.append(dataset[history])
            offsets.append(
                np.array(list(history))[::-1] - i + history_size + horizon)
            targets.append(target[i + horizon])
    return maptofloat32(windows, targets, offsets)


def create_weather_data(config, csv_path, append_time=True):
    """Load the weather CSV and build batched train/val/test datasets.

    The feature columns are standardised with statistics computed over the
    whole file, then windowed by the builder matching
    ``config.experiment_case``. The label is the standardised
    ``'T (degC)'`` column.

    Args:
        config: object exposing ``experiment_case`` (``'regular'``,
            ``'rand_in'`` or ``'rand_out'``).
        csv_path: path of the CSV file read with ``pandas.read_csv``.
        append_time: when true, the per-step time values are appended to the
            inputs as an extra last channel.

    Returns:
        ``((train, val, test), data_meta)`` where the first element holds
        three batched ``tf.data.Dataset`` objects and
        ``data_meta['embedding_dim']`` is the size of the last input axis.

    Raises:
        ValueError: if ``config.experiment_case`` is not a supported case.
    """
    target_column = 'T (degC)'
    feat_columns = ['p (mbar)', 'T (degC)', 'Tpot (K)', 'Tdew (degC)',
                    'rh (%)', 'VPmax (mbar)', 'VPact (mbar)', 'VPdef (mbar)', 'sh (g/kg)',
                    'H2OC (mmol/mol)', 'rho (g/m**3)', 'wv (m/s)', 'max. wv (m/s)',
                    'wd (deg)']
    PAST_HISTORY = 720
    STEP = 6
    BATCH_SIZE = 256
    BUFFER_SIZE = 10000

    df = pd.read_csv(csv_path)
    feat = df[feat_columns].values
    feat = (feat - feat.mean(axis=0)) / feat.std(axis=0)
    target = feat[:, feat_columns.index(target_column)]

    # case -> (builder, (train/val split, val/test split, end),
    #          builder arguments that follow ``history_size``)
    plans = {
        'regular': (weather_regular, (300000, 350000, 400000), (72, STEP)),
        'rand_in': (weather_randin, (100000, 120000, 140000), (72, STEP)),
        'rand_out': (weather_randout, (50000, 60000, 80000), (5, 24, 72, STEP)),
    }
    case = config.experiment_case
    if case not in plans:
        raise ValueError(f'{case} is not a supported experiment case {list(plans)}')
    builder, (train_val_split, val_test_split, end), tail_args = plans[case]

    def build(start, stop):
        return builder(feat, target, start, stop, PAST_HISTORY, *tail_args)

    # Built in train, val, test order so the random draws of the rand_*
    # builders are consumed in the same sequence as before.
    x_train, y_train, x_time_train = build(0, train_val_split)
    x_val, y_val, x_time_val = build(train_val_split, val_test_split)
    # The test split always covers [val_test_split, end). The rand_out case
    # previously rebuilt the validation range here.
    x_test, y_test, x_time_test = build(val_test_split, end)

    if append_time:
        x_train = np.concatenate(
            (x_train, np.expand_dims(x_time_train, axis=-1)), axis=-1)
        x_val = np.concatenate(
            (x_val, np.expand_dims(x_time_val, axis=-1)), axis=-1)
        x_test = np.concatenate(
            (x_test, np.expand_dims(x_time_test, axis=-1)), axis=-1)

    train_data_single = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_data_single = train_data_single.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

    val_data_single = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    val_data_single = val_data_single.batch(BATCH_SIZE)

    test_data_single = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    test_data_single = test_data_single.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

    data_meta = {'embedding_dim': x_train.shape[-1]}

    return (train_data_single, val_data_single, test_data_single), data_meta



def create_wiki_data(config, csv_path, append_time=True):
    """Build train/validation arrays from a wide-format page-traffic CSV.

    The CSV is expected to have a ``'Page'`` column first, followed by one
    column per date (the remaining column names are parsed with
    ``pd.to_datetime``). Series values are transformed with ``log1p`` and
    centred on the per-series mean of the encoder window.

    Behaviour depends on ``config.experiment_case``:

    * ``'regular'``: nothing is built; all four returned values are ``None``.
    * ``'rand_in'``: each sample keeps 200 randomly chosen encoder time steps
      and the target is the whole 14-step prediction window.
    * ``'rand_out'``: each series yields 5 samples, one per randomly chosen
      step of the prediction window; the input is the full encoder window.

    In both random cases the last input channel is a time value and the
    channels in between are one-hot encodings of the page metadata.

    Args:
        config: object exposing ``experiment_case`` and ``debug``; with
            ``debug`` set only the first 1000 series are used.
        csv_path: path of the CSV file read with ``pandas.read_csv``.
        append_time: not used by this function.

    Returns:
        ``((train_x, train_y), (val_x, val_y))``.
    """
    (train_x, train_y), (val_x, val_y) = (None, None), (None, None)
    if config.experiment_case == 'regular':
        pass
    elif config.experiment_case == 'rand_in':
        df = pd.read_csv(csv_path)
        data_start_date = df.columns[1]
        data_end_date = df.columns[-1]
        page_df = df['Page'].str.rsplit('_', n=3, expand=True) # split page string and expand to multiple columns 
        page_df.columns = ['name','project','access','agent']
        # Day-of-week one-hot array; below it is only used for the length of
        # its time axis.
        dow_ohe = pd.get_dummies(pd.to_datetime(df.columns[1:]).dayofweek)
        dow_array = np.expand_dims(dow_ohe.values, axis=0) # add sample dimension
        dow_array = np.tile(dow_array,(df.shape[0],1,1))

        page_df = page_df.drop('name', axis=1)
        page_array = pd.get_dummies(page_df).values
        page_array = np.expand_dims(page_array, axis=1) # add timesteps dimension
        page_array = np.tile(page_array,(1,dow_array.shape[1],1)) # repeat OHE array along timesteps dimension 
        exog_array = page_array
        from datetime import timedelta
        pred_steps = 14
        pred_length=timedelta(pred_steps)
        first_day = pd.to_datetime(data_start_date) 
        last_day = pd.to_datetime(data_end_date)
        # Validation predicts the last ``pred_steps`` days; training predicts
        # the ``pred_steps`` days right before that.
        val_pred_start = last_day - pred_length + timedelta(1)
        val_pred_end = last_day
        train_pred_start = val_pred_start - pred_length
        train_pred_end = val_pred_start - timedelta(days=1)
        # Both encoder windows have the same length; the validation one is
        # shifted forward by ``pred_length``.
        enc_length = train_pred_start - first_day
        train_enc_start = first_day
        train_enc_end = train_enc_start + enc_length - timedelta(1)
        val_enc_start = train_enc_start + pred_length
        val_enc_end = val_enc_start + enc_length - timedelta(1)
        # Maps each date to its column position within ``series_array``.
        date_to_index = pd.Series(index=pd.Index([pd.to_datetime(c) for c in df.columns[1:]]),
                          data=[i for i in range(len(df.columns[1:]))])

        series_array = df[df.columns[1:]].values

        def get_time_block_series(series_array, date_to_index, start_date, end_date):
            # Select the columns whose dates fall in the label slice
            # start_date:end_date of ``date_to_index``.
            inds = date_to_index[start_date:end_date]
            return series_array[:,inds]

        def transform_series_encode(series_array):
            # log1p-transform, centre each row on its own mean, add a
            # trailing channel axis; the means are returned for the decoder.
            series_array = np.log1p(np.nan_to_num(series_array)) # filling NaN with 0
            series_mean = series_array.mean(axis=1).reshape(-1,1) 
            series_array = series_array - series_mean
            series_array = series_array.reshape((series_array.shape[0],series_array.shape[1], 1))
            
            return series_array, series_mean

        def transform_series_decode(series_array, encode_series_mean):
            # Same transform as the encoder, but centred on the encoder means.
            series_array = np.log1p(np.nan_to_num(series_array)) # filling NaN with 0
            series_array = series_array - encode_series_mean
            series_array = series_array.reshape((series_array.shape[0],series_array.shape[1], 1))
            
            return series_array

        # Randomly select 200 timepoints from the sequence (original length 522) to predict the 14-day interval values from 523-536
        num_samples = 200
        exog_inds = date_to_index[train_enc_start:train_enc_end]
    
        # sample of series from enc_start to enc_end  
        encoder_input_data = get_time_block_series(series_array, date_to_index, 
                                                train_enc_start, train_enc_end)
        encoder_input_data, encode_series_mean = transform_series_encode(encoder_input_data)

        # sample of series from pred_start to pred_end 
        decoder_target_data = get_time_block_series(series_array, date_to_index, 
                                                    train_pred_start, train_pred_end)
        decoder_target_data = transform_series_decode(decoder_target_data, encode_series_mean)
        train_x, train_y = [], []
        # NOTE(review): debug mode iterates a fixed range(1000), which
        # presumably assumes the CSV has at least 1000 rows -- confirm.
        index_range = tqdm(range(encoder_input_data.shape[0])) if not config.debug else tqdm(range(1000))
        for i in index_range:
            inds = np.sort(np.random.choice(exog_inds, num_samples, replace = False))
            # Distance of each sampled step from the last encoder position.
            # NOTE(review): ``exog_inds[-1]`` is an integer lookup on a
            # date-indexed Series; confirm it resolves positionally on the
            # pandas version in use.
            time_array = exog_inds[-1] - inds
            # Channels: series value, page one-hot features, time value.
            train_x.append(np.concatenate((np.expand_dims(encoder_input_data[i,inds,0],-1), exog_array[i,inds,:], np.expand_dims(time_array,-1)), axis=-1))
            train_y.append(decoder_target_data[i])

        train_x = np.array(train_x)
        train_y = np.array(train_y)[:, :, 0]

        exog_inds = date_to_index[val_enc_start:val_enc_end]
        # Re-base so positions index into the validation encoder block.
        exog_inds = exog_inds - min(exog_inds)
            
        # sample of series from enc_start to enc_end  
        encoder_input_data = get_time_block_series(series_array, date_to_index, 
                                                val_enc_start, val_enc_end)
        encoder_input_data, encode_series_mean = transform_series_encode(encoder_input_data)

        # sample of series from pred_start to pred_end 
        decoder_target_data = get_time_block_series(series_array, date_to_index, 
                                                    val_pred_start, val_pred_end)
        decoder_target_data = transform_series_decode(decoder_target_data, encode_series_mean)
        val_x, val_y = [], []
        index_range = tqdm(range(encoder_input_data.shape[0])) if not config.debug else tqdm(range(1000))
        for i in index_range:
            inds = np.sort(np.random.choice(exog_inds, num_samples, replace = False))
            time_array = exog_inds[-1] - inds
            val_x.append(np.concatenate((np.expand_dims(encoder_input_data[i,inds,0],-1), exog_array[i,inds,:], np.expand_dims(time_array,-1)), axis=-1))
            val_y.append(decoder_target_data[i])

        val_x = np.array(val_x)
        val_y = np.array(val_y)[:, :, 0]
        

    elif config.experiment_case == 'rand_out':
        # The preparation below mirrors the 'rand_in' branch; the differences
        # are the prediction-window layout and how samples are drawn.
        df = pd.read_csv(csv_path)
        data_start_date = df.columns[1]
        data_end_date = df.columns[-1]
        page_df = df['Page'].str.rsplit('_', n=3, expand=True) # split page string and expand to multiple columns 
        page_df.columns = ['name','project','access','agent']
        dow_ohe = pd.get_dummies(pd.to_datetime(df.columns[1:]).dayofweek)
        dow_array = np.expand_dims(dow_ohe.values, axis=0) # add sample dimension
        dow_array = np.tile(dow_array,(df.shape[0],1,1))
        page_df = page_df.drop('name', axis=1)
        page_array = pd.get_dummies(page_df).values
        page_array = np.expand_dims(page_array, axis=1) # add timesteps dimension
        page_array = np.tile(page_array,(1,dow_array.shape[1],1)) # repeat OHE array along timesteps dimension 
        exog_array = page_array
        from datetime import timedelta
        pred_steps = 15
        pred_length=timedelta(pred_steps)
        first_day = pd.to_datetime(data_start_date) 
        last_day = pd.to_datetime(data_end_date)
        # The validation prediction window spans the last half of the date
        # range; the training prediction window is the ``pred_steps`` days
        # right before it.
        train_val_split = int((last_day - first_day).days / 2)
        val_pred_start = last_day - timedelta(train_val_split) + timedelta(1)
        val_pred_end = last_day
        train_pred_start = val_pred_start - pred_length
        train_pred_end = val_pred_start - timedelta(days=1)
        enc_length = train_pred_start - first_day
        train_enc_start = first_day
        train_enc_end = train_enc_start + enc_length - timedelta(1)
        val_enc_start = train_enc_start + pred_length
        val_enc_end = val_enc_start + enc_length - timedelta(1)
        date_to_index = pd.Series(index=pd.Index([pd.to_datetime(c) for c in df.columns[1:]]),
                          data=[i for i in range(len(df.columns[1:]))])
        series_array = df[df.columns[1:]].values
        def get_time_block_series(series_array, date_to_index, start_date, end_date):
            inds = date_to_index[start_date:end_date]
            return series_array[:,inds]
        def transform_series_encode(series_array):
            series_array = np.log1p(np.nan_to_num(series_array)) # filling NaN with 0
            series_mean = series_array.mean(axis=1).reshape(-1,1) 
            series_array = series_array - series_mean
            series_array = series_array.reshape((series_array.shape[0],series_array.shape[1], 1))
            return series_array, series_mean
        def transform_series_decode(series_array, encode_series_mean):
            series_array = np.log1p(np.nan_to_num(series_array)) # filling NaN with 0
            series_array = series_array - encode_series_mean
            series_array = series_array.reshape((series_array.shape[0],series_array.shape[1], 1))
            return series_array

        exog_inds = date_to_index[train_enc_start:train_enc_end]
        # sample of series from enc_start to enc_end  
        encoder_input_data = get_time_block_series(series_array, date_to_index, 
                                                train_enc_start, train_enc_end)
        encoder_input_data, encode_series_mean = transform_series_encode(encoder_input_data)
        # sample of series from pred_start to pred_end 
        decoder_target_data = get_time_block_series(series_array, date_to_index, 
                                                    train_pred_start, train_pred_end)
        decoder_target_data = transform_series_decode(decoder_target_data, encode_series_mean)
        # Number of prediction steps drawn per series.
        num_samples = 5
        train_x, train_y = [], []
        index_range = tqdm(range(encoder_input_data.shape[0])) if not config.debug else tqdm(range(1000))
        for i in index_range:
            # The full encoder window is used; only the target step is random.
            inds = exog_inds
            target = np.random.choice(list(range(len(decoder_target_data[i]))), num_samples, replace = False)
            for tar in target:
                # Reversed encoder positions shifted by the chosen target step.
                time_array = tar + inds[::-1]
                train_x.append(np.concatenate((np.expand_dims(encoder_input_data[i,inds,0],-1), exog_array[i,inds,:], np.expand_dims(time_array,-1)), axis=-1))
                train_y.append(decoder_target_data[i][tar])

        train_x = np.array(train_x)
        train_y = np.array(train_y)
        exog_inds = date_to_index[val_enc_start:val_enc_end]
        # Re-base so positions index into the validation encoder block.
        exog_inds = exog_inds - min(exog_inds)
        # sample of series from enc_start to enc_end  
        encoder_input_data = get_time_block_series(series_array, date_to_index, 
                                                val_enc_start, val_enc_end)
        encoder_input_data, encode_series_mean = transform_series_encode(encoder_input_data)
        # sample of series from pred_start to pred_end 
        decoder_target_data = get_time_block_series(series_array, date_to_index, 
                                                    val_pred_start, val_pred_end)
        decoder_target_data = transform_series_decode(decoder_target_data, encode_series_mean)
        val_x, val_y = [], []
        index_range = tqdm(range(encoder_input_data.shape[0])) if not config.debug else tqdm(range(1000))
        for i in index_range:
            inds = exog_inds
            target = np.random.choice(list(range(len(decoder_target_data[i]))), num_samples, replace = False)
            for tar in target:
                time_array = tar + inds[::-1]
                val_x.append(np.concatenate((np.expand_dims(encoder_input_data[i,inds,0],-1), exog_array[i,inds,:], np.expand_dims(time_array,-1)), axis=-1))
                val_y.append(decoder_target_data[i][tar])
        val_x = np.array(val_x)
        val_y = np.array(val_y)

    
    return (train_x, train_y), (val_x, val_y)

def rnn_base(output_dim, rnn_units=32, ffn_units=16):
    """Return a compiled GRU -> Dense(relu) -> Dense Sequential model.

    The model is compiled with RMSprop and mean-absolute-error loss.
    """
    model = tf.keras.Sequential([
        layers.GRU(rnn_units),
        layers.Dense(ffn_units, activation='relu'),
        layers.Dense(output_dim),
    ])
    model.compile(optimizer=tf.keras.optimizers.RMSprop(), loss='mae')
    return model


def rnn_triangular(output_dim, rnn_units=32, ffn_units=16, sin_feat_dim=10, cos_feat_dim=10):
    """Return the GRU baseline preceded by a ``sin_cos_layer`` time encoding.

    The model is compiled with RMSprop and mean-absolute-error loss.
    """
    model = tf.keras.Sequential([
        sin_cos_layer(sin_feat_dim, cos_feat_dim, name='sin_cos'),
        layers.GRU(rnn_units),
        layers.Dense(ffn_units, activation='relu'),
        layers.Dense(output_dim),
    ])
    model.compile(optimizer=tf.keras.optimizers.RMSprop(), loss='mae')
    return model


def rnn_deep_time(output_dim, rnn_units=32,
                  ffn_units=16,
                  time_dim=32,
                  init_dim=100,
                  inter_dims=(32, 16),
                  time_only=True,
                  invert_layer=True):
    """Return a compiled RNN built on a GRU-type ``TimeRNNCell``.

    The recurrent layer is followed by Dense(relu) and a linear Dense head.
    The model is compiled with Adam (learning rate 1e-4) and
    mean-absolute-error loss.
    """
    cell_options = dict(
        output_dim=rnn_units,
        combine='concat',
        cell_type='gru',
        time_dim=time_dim,
        intercept=True,
        time_only=time_only,
        invert_layer=invert_layer,
        distribution='uniform',
        init_dim=init_dim,
        num_affine_block=1,
        inter_dims=inter_dims,
        name='trnn_cell',
    )
    model = tf.keras.models.Sequential()
    model.add(layers.RNN(TimeRNNCell(**cell_options)))
    model.add(layers.Dense(ffn_units, activation='relu'))
    model.add(layers.Dense(output_dim))
    model.compile(optimizer=tf.keras.optimizers.Adam(0.0001), loss='mae')
    return model


def cnn_base(embedding_dim, output_dim, sin_dim=16, cos_dim=16, n_filters=16, filter_width=2, pred_steps=1):
    """Return a compiled stack of causal dilated convolutions.

    Five Conv1D layers with dilation rates 1, 2, 4, 8 and 16 are followed by
    Dense(relu) and a linear Dense head; only the last ``pred_steps`` time
    steps are kept. ``sin_dim`` and ``cos_dim`` are accepted but not used.
    The model is compiled with Adam and mean-absolute-error loss.
    """
    def _keep_tail(tensor, seq_length):
        return tensor[:, -seq_length:, :]

    inputs = layers.Input(shape=(None, embedding_dim))
    hidden = inputs
    for exponent in range(5):
        conv = layers.Conv1D(filters=n_filters,
                             kernel_size=filter_width,
                             padding='causal',
                             dilation_rate=2 ** exponent,
                             kernel_regularizer=tf.keras.regularizers.l2(0.001))
        hidden = conv(hidden)
    hidden = layers.Dense(16, activation='relu')(hidden)
    hidden = layers.Dense(output_dim)(hidden)

    outputs = layers.Lambda(
        _keep_tail, arguments={'seq_length': pred_steps})(hidden)
    model = Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mae')
    return model


def cnn_triangular(embedding_dim, output_dim, sin_dim=16, cos_dim=16, n_filters=16, filter_width=2, pred_steps=1):
    """Return the dilated causal CNN preceded by a ``sin_cos_layer``.

    After the time encoding, five Conv1D layers with dilation rates 1, 2, 4,
    8 and 16 are followed by Dense(relu) and a linear Dense head; only the
    last ``pred_steps`` time steps are kept. The model is compiled with Adam
    and mean-absolute-error loss.
    """
    def _keep_tail(tensor, seq_length):
        return tensor[:, -seq_length:, :]

    inputs = layers.Input(shape=(None, embedding_dim))
    hidden = sin_cos_layer(sin_dim, cos_dim)(inputs)
    for exponent in range(5):
        conv = layers.Conv1D(filters=n_filters,
                             kernel_size=filter_width,
                             padding='causal',
                             dilation_rate=2 ** exponent,
                             kernel_regularizer=tf.keras.regularizers.l2(0.001))
        hidden = conv(hidden)
    hidden = layers.Dense(16, activation='relu')(hidden)
    hidden = layers.Dense(output_dim)(hidden)

    outputs = layers.Lambda(
        _keep_tail, arguments={'seq_length': pred_steps})(hidden)
    model = Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mae')
    return model


def cnn_deep_time(embedding_dim, output_dim, time_dim=32, in_dim=32,
                  time_only=True, invert_layer=True, actnorm=True,
                  n_filters=16, filter_width=2, pred_steps=1, num_affine_block=1,
                  init_dim=100, inter_dims=(32, 16)):
    """Return the dilated causal CNN fed through a ``rand_feat_layer``.

    The last input channel is split off and passed to ``rand_feat_layer`` as
    its second argument, alongside the remaining channels. The result goes
    through five Conv1D layers (dilation rates 1, 2, 4, 8, 16), Dense(relu)
    and a linear Dense head; only the last ``pred_steps`` time steps are
    kept. The model is compiled with Adam and mean-absolute-error loss.
    """
    def _keep_tail(tensor, seq_length):
        return tensor[:, -seq_length:, :]

    inputs = layers.Input(shape=(None, embedding_dim))
    values = inputs[:, :, :-1]
    stamps = tf.expand_dims(inputs[:, :, -1], axis=-1)

    time_layer = rand_feat_layer(in_dim=in_dim,
                                 combine='concat',
                                 time_dim=time_dim,
                                 intercept=True,
                                 time_only=time_only,
                                 invert_layer=invert_layer,
                                 distribution='uniform',
                                 init_dim=init_dim,
                                 num_affine_block=num_affine_block,
                                 inter_dims=inter_dims,
                                 actnorm=actnorm,
                                 name='time_layer')
    hidden = time_layer(values, stamps)

    for exponent in range(5):
        conv = layers.Conv1D(filters=n_filters,
                             kernel_size=filter_width,
                             padding='causal',
                             dilation_rate=2 ** exponent,
                             kernel_regularizer=tf.keras.regularizers.l2(0.001))
        hidden = conv(hidden)

    hidden = layers.Dense(16, activation='relu')(hidden)
    hidden = layers.Dense(output_dim)(hidden)

    outputs = layers.Lambda(
        _keep_tail, arguments={'seq_length': pred_steps})(hidden)
    model = Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mae')
    return model


def attention_regular(embedding_dim, output_dim):
    """Return a compiled single-block causal attention model.

    Three L2-regularised Dense(16, relu) projections of the input are passed
    to a scaled causal ``Attention`` layer. The result is average-pooled over
    time and mapped to ``output_dim`` by a linear Dense head. The model is
    compiled with Adam and mean-absolute-error loss.
    """
    def _projection():
        return layers.Dense(16, activation='relu',
                            kernel_regularizer=tf.keras.regularizers.l2(0.001))

    inputs = layers.Input(shape=(None, embedding_dim))
    # Projections are created in the original order: first, second, third
    # element of the list handed to the attention layer.
    first = _projection()(inputs)
    second = _projection()(inputs)
    third = _projection()(inputs)
    attended = layers.Attention(use_scale=True, causal=True)([first, second, third])
    pooled = layers.GlobalAveragePooling1D()(attended)
    outputs = layers.Dense(output_dim)(pooled)

    model = Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mae')
    return model

def attention_deep_time(embedding_dim, output_dim):
    """Return the causal attention model fed through a ``rand_feat_layer``.

    The last input channel is split off and passed to ``rand_feat_layer`` as
    its second argument, alongside the remaining channels. Three
    L2-regularised Dense(16, relu) projections of the result feed a scaled
    causal ``Attention`` layer, followed by average pooling over time and a
    linear Dense head. The model is compiled with Adam and
    mean-absolute-error loss.
    """
    def _projection():
        return layers.Dense(16, activation='relu',
                            kernel_regularizer=tf.keras.regularizers.l2(0.001))

    inputs = layers.Input(shape=(None, embedding_dim))
    values = inputs[:, :, :-1]
    stamps = tf.expand_dims(inputs[:, :, -1], axis=-1)

    time_layer = rand_feat_layer(in_dim=embedding_dim - 1,
                                 combine='concat',
                                 time_dim=32,
                                 intercept=True,
                                 time_only=True,
                                 invert_layer=True,
                                 distribution='uniform',
                                 init_dim=100,
                                 num_affine_block=1,
                                 inter_dims=[],
                                 actnorm=True,
                                 l2_reg=0.05,
                                 name='time_layer')
    encoded = time_layer(values, stamps)

    first = _projection()(encoded)
    second = _projection()(encoded)
    third = _projection()(encoded)
    attended = layers.Attention(use_scale=True, causal=True)([first, second, third])
    pooled = layers.GlobalAveragePooling1D()(attended)
    outputs = layers.Dense(output_dim)(pooled)

    model = Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mae')
    return model



def attention_triangular(embedding_dim, output_dim, sin_dim=16, cos_dim=16):
    """Return the causal attention model preceded by a ``sin_cos_layer``.

    Three Dense(32, relu) projections of the time-encoded input feed a scaled
    causal ``Attention`` layer. The result is average-pooled over time, then
    passed through Dense(128, relu), Dropout(0.2) and a linear Dense head.
    The model is compiled with Adam and mean-absolute-error loss.
    """
    inputs = layers.Input(shape=(None, embedding_dim))
    encoded = sin_cos_layer(sin_dim, cos_dim)(inputs)

    first = layers.Dense(32, activation='relu')(encoded)
    second = layers.Dense(32, activation='relu')(encoded)
    third = layers.Dense(32, activation='relu')(encoded)
    attended = layers.Attention(use_scale=True, causal=True)([first, second, third])
    pooled = layers.GlobalAveragePooling1D()(attended)

    hidden = layers.Dense(128, activation='relu')(pooled)
    hidden = layers.Dropout(.2)(hidden)
    outputs = layers.Dense(output_dim)(hidden)

    model = Model(inputs, outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mae')
    return model
