import tensorflow as tf
from tensorflow.keras import layers


def res_block(inputs, norm_type, activation, dropout, ff_dim):
    """Build one residual block: a temporal linear step, then a feature MLP.

    Each half is wrapped in its own skip connection, so the result has the
    same layout as ``inputs``.

    Args:
        inputs: Rank-3 tensor laid out as [Batch, Input Length, Channel].
        norm_type: ``'L'`` selects ``layers.LayerNormalization``; any other
            value selects ``layers.BatchNormalization``.
        activation: Activation handed to the temporal Dense layer and to the
            first (hidden) feature Dense layer.
        dropout: Rate passed to every ``layers.Dropout`` in the block.
        ff_dim: Number of units in the hidden feature Dense layer.

    Returns:
        Tensor laid out as [Batch, Input Length, Channel].
    """
    if norm_type == 'L':
        norm_cls = layers.LayerNormalization
    else:
        norm_cls = layers.BatchNormalization
    n_channels = inputs.shape[-1]

    # Temporal Linear: move time to the last axis so Dense mixes across steps.
    normed = norm_cls(axis=[-2, -1])(inputs)
    by_channel = tf.transpose(normed, perm=[0, 2, 1])  # [Batch, Channel, Input Length]
    n_steps = by_channel.shape[-1]
    time_mixed = layers.Dense(n_steps, activation=activation)(by_channel)
    by_step = tf.transpose(time_mixed, perm=[0, 2, 1])  # [Batch, Input Length, Channel]
    by_step = layers.Dropout(dropout)(by_step)
    res = by_step + inputs  # first skip connection

    # Feature Linear: expand to ff_dim, then project back to the channel count.
    hidden = norm_cls(axis=[-2, -1])(res)
    hidden = layers.Dense(ff_dim, activation=activation)(hidden)  # [Batch, Input Length, FF_Dim]
    hidden = layers.Dropout(dropout)(hidden)
    projected = layers.Dense(n_channels)(hidden)  # [Batch, Input Length, Channel]
    projected = layers.Dropout(dropout)(projected)
    return projected + res  # second skip connection


def build_model(input_shape, pred_len, norm_type, activation, n_block,
                dropout, ff_dim, target_slice):
    """Assemble a Keras model from stacked ``res_block`` calls and a time projection.

    Args:
        input_shape: Per-sample shape forwarded to ``tf.keras.Input``; the
            resulting tensor is treated as [Batch, Input Length, Channel].
        pred_len: Number of output time steps (units of the final Dense layer).
        norm_type: Forwarded to ``res_block``.
        activation: Forwarded to ``res_block``.
        n_block: How many residual blocks to chain.
        dropout: Forwarded to ``res_block``.
        ff_dim: Forwarded to ``res_block``.
        target_slice: Index applied to the channel axis after the blocks.
            It is only applied when truthy.

    Returns:
        A ``tf.keras.Model`` mapping the input to a tensor laid out as
        [Batch, Output Length, Channel].
    """
    model_input = tf.keras.Input(shape=input_shape)

    # Chain the residual blocks; layout stays [Batch, Input Length, Channel].
    features = model_input
    for _block_idx in range(n_block):
        features = res_block(features, norm_type, activation, dropout, ff_dim)

    # NOTE: this is a truthiness test, so falsy values such as None (or 0)
    # leave every channel in place.
    if target_slice:
        features = features[:, :, target_slice]

    # Project along time: Dense acts on the last axis, so swap time into it,
    # map Input Length -> pred_len, then swap back.
    by_channel = tf.transpose(features, perm=[0, 2, 1])  # [Batch, Channel, Input Length]
    forecast = layers.Dense(pred_len)(by_channel)  # [Batch, Channel, Output Length]
    model_output = tf.transpose(forecast, perm=[0, 2, 1])  # [Batch, Output Length, Channel]

    return tf.keras.Model(model_input, model_output)

