import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Conv1D,  Flatten, Dense, Conv1DTranspose, Reshape, Input, LSTM, Lambda, Concatenate

class CSAE(tf.keras.Model):
  """Conditional autoencoder with an L1 penalty on its latent code.

  ``call`` takes ``(label, input_x, input_y)``:

  * ``input_x`` is summarised by an 8-unit LSTM (so it must be rank 3).
  * ``input_y`` is encoded by two stride-2 Conv1D layers and flattened.
  * Both are concatenated with ``label`` (which therefore must be rank 2,
    ``(batch, k)``) and projected to a deterministic latent vector ``z``.

  The decoder concatenates ``label``, the LSTM summary and ``z`` and returns a
  tensor of shape ``(batch, seq_len, feat_dim)``. :meth:`loss_` combines an
  L1 reconstruction error with ``Lambda * sum(|z|)``.

  Args:
    seq_len: Length of the decoded sequence.
    latent_dim: Size of the latent code ``z``.
    feat_dim: Number of features per decoded time step.
    hidden_layer_sizes: Filter counts (needs at least two entries). ``[0]``
      and ``[1]`` are used by the encoder convolutions, and ``[-1]`` and
      ``[-2]`` by the decoder transposed convolutions.
    Lambda: Weight of the L1 penalty on ``z``. (Inside ``__init__`` this
      parameter shadows the module-level ``Lambda`` layer import.)
    **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
  """

  def __init__(self, seq_len, latent_dim, feat_dim, hidden_layer_sizes,
               Lambda=0.1, **kwargs):
    super(CSAE, self).__init__(**kwargs)

    self.hidden_layer_sizes = hidden_layer_sizes
    self.latent_dim = latent_dim
    self.feat_dim = feat_dim
    self.seq_len = seq_len
    self.Lambda = Lambda

    # Conditioning branch for input_x plus the shared projection layers.
    self.lstm_x = LSTM(8)
    self.dense_x_encoder = Dense(8, activation="relu")
    self.dense_concat = Dense(64, activation="relu")
    self.dense_x_decoder = Dense(8, activation="relu")

    # Concatenate has no weights, so one instance is shared by both halves.
    self.concat = Concatenate()

    # Encoder: each stride-2 'same' convolution halves the time axis.
    self.conv1 = Conv1D(
      filters=self.hidden_layer_sizes[0],
      kernel_size=3,
      strides=2,
      activation='relu',
      padding='same')
    self.conv2 = Conv1D(
      filters=self.hidden_layer_sizes[1],
      kernel_size=3,
      strides=2,
      activation='relu',
      padding='same')
    self.flatten = Flatten()
    self.dense1 = Dense(self.latent_dim, name="z_mean")

    # Decoder. The 1600 width is hard-coded. reshape1 turns it into
    # (1600 // hidden_layer_sizes[-1], hidden_layer_sizes[-1]), so 1600 must
    # be divisible by hidden_layer_sizes[-1].
    self.dense3 = Dense(1600, name="dec_dense", activation='relu')
    self.reshape1 = Reshape(target_shape=(-1, self.hidden_layer_sizes[-1]), name="dec_reshape")
    # Each stride-2 'same' transposed convolution doubles the time axis.
    self.convTr1 = Conv1DTranspose(
      filters=self.hidden_layer_sizes[-1],
      kernel_size=3,
      strides=2,
      padding='same',
      activation='relu')
    self.convTr2 = Conv1DTranspose(
      filters=self.hidden_layer_sizes[-2],
      kernel_size=3,
      strides=2,
      padding='same',
      activation='relu')
    self.convTr3 = Conv1DTranspose(
      filters=self.feat_dim,
      kernel_size=3,
      strides=2,
      padding='same',
      activation='relu')
    # The final dense layer maps whatever length the transposed convolutions
    # produced onto exactly seq_len * feat_dim values.
    self.flatten2 = Flatten(name='dec_flatten')
    self.dense4 = Dense(self.seq_len * self.feat_dim, name="decoder_dense_final")
    self.reshape2 = Reshape(target_shape=(self.seq_len, self.feat_dim))

  def encoder(self, label, x, y):
    """Return the latent code ``z`` of shape ``(batch, latent_dim)``.

    Args:
      label: Conditioning tensor, rank 2.
      x: LSTM summary of ``input_x`` (output of ``self.lstm_x``).
      y: Sequence tensor fed to the Conv1D stack.
    """
    x = self.dense_x_encoder(x)
    y = self.conv1(y)
    y = self.conv2(y)
    y = self.flatten(y)
    t = self.concat([label, x, y])
    t = self.dense_concat(t)
    z = self.dense1(t)
    return z

  def decoder(self, label, x, z):
    """Decode ``(label, x, z)`` into a ``(batch, seq_len, feat_dim)`` tensor."""
    x = self.dense_x_decoder(x)
    y = self.concat([label, x, z])
    y = self.dense3(y)
    y = self.reshape1(y)
    y = self.convTr1(y)
    y = self.convTr2(y)
    y = self.convTr3(y)
    y = self.flatten2(y)
    y = self.dense4(y)
    y = self.reshape2(y)
    return y

  def reconstruction(self, inputs, out):
    """Mean absolute error over axis 1 (the time axis of ``out``).

    If ``inputs`` has the same ``(batch, seq_len, feat_dim)`` shape as
    ``out``, the result has shape ``(batch, feat_dim)``.
    """
    rec = tf.reduce_mean(tf.abs(inputs - out), axis=1)
    return rec

  def regularization(self, inputs, out):
    """Return the per-sample L1 penalty ``Lambda * sum(|z|)``, shape ``(batch,)``.

    Reads ``self.z``, which is set only by :meth:`call`, so the model must
    have been called first. ``inputs`` and ``out`` are unused; they are kept
    for interface symmetry with :meth:`reconstruction`.
    """
    regu = self.Lambda * tf.math.reduce_sum(tf.math.abs(self.z), axis=1)
    return regu

  def loss_(self, inputs, out):
    """Return the reconstruction error plus the L1 penalty, shape ``(batch, feat_dim)``.

    The penalty has one value per sample, so it is reshaped to
    ``(batch, 1)`` to broadcast across the reconstruction's feature axis.
    Adding the raw ``(batch,)`` vector would align it with the feature axis
    instead. That yields a ``(batch, batch)`` matrix when ``feat_dim == 1``,
    and a shape error (or silent misalignment when ``feat_dim == batch``)
    otherwise.
    """
    regu = tf.reshape(self.regularization(inputs, out), [-1, 1])
    rec = self.reconstruction(inputs, out)
    rec = tf.reshape(rec, [tf.shape(regu)[0], -1])
    loss = rec + regu
    return loss

  def cf_generation(self, label_real, label_cf, x, y):
    """Encode with ``label_real`` and decode the same code with ``label_cf``.

    Unlike :meth:`call`, this does not update ``self.z``.
    """
    x = self.lstm_x(x)
    z = self.encoder(label_real, x, y)
    out = self.decoder(label_cf, x, z)
    return out

  def call(self, inputs):
    """Reconstruct from ``(label, input_x, input_y)``.

    The latent code is stored on ``self.z`` for use by :meth:`regularization`.
    """
    label, input_x, input_y = inputs
    x = self.lstm_x(input_x)
    z = self.encoder(label, x, input_y)
    self.z = z
    out = self.decoder(label, x, z)
    return out


class CVAE(tf.keras.Model):
  """Conditional variational autoencoder over ``(label, input_x, input_y)``.

  The architecture mirrors a conditional autoencoder. ``input_x`` is
  summarised by an 8-unit LSTM, ``input_y`` goes through two stride-2 Conv1D
  layers, and both are concatenated with ``label`` (rank 2). The encoder
  outputs a mean and a log-variance. ``z`` is drawn as
  ``mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``, and the decoder
  returns a ``(batch, seq_len, feat_dim)`` tensor. :meth:`loss_` is
  ``recon_weight * L1 reconstruction + KL(q(z|x) || N(0, var_prior * I))``.

  Args:
    seq_len: Length of the decoded sequence.
    latent_dim: Size of the latent code.
    feat_dim: Number of features per decoded time step.
    hidden_layer_sizes: Filter counts (needs at least two entries). ``[0]``
      and ``[1]`` are used by the encoder, and ``[-1]`` and ``[-2]`` by the
      decoder.
    recon_weight: Multiplier on the reconstruction term in :meth:`loss_`.
    **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
  """

  def __init__(self, seq_len, latent_dim, feat_dim, hidden_layer_sizes,
               recon_weight=200, **kwargs):
    super(CVAE, self).__init__(**kwargs)

    self.hidden_layer_sizes = hidden_layer_sizes
    self.latent_dim = latent_dim
    self.feat_dim = feat_dim
    self.seq_len = seq_len
    self.recon_weight = recon_weight

    # Prior p(z) = N(0, var_prior * I), used by the analytic KL in kl().
    # log det(var_prior * I) = latent_dim * log(var_prior). This is computed
    # as a product of logs rather than log(var_prior ** latent_dim), which
    # can overflow or underflow for large latent_dim.
    self.var_prior = 1.
    self.log_det_cov_pz = tf.math.log(self.var_prior) * self.latent_dim

    # Conditioning branch for input_x plus the shared projection layers.
    self.lstm_x = LSTM(8)
    self.dense_x_encoder = Dense(8, activation="relu")
    self.dense_concat = Dense(64, activation="relu")
    self.dense_x_decoder = Dense(8, activation="relu")

    # Concatenate has no weights, so one instance is shared by both halves.
    self.concat = Concatenate()

    # Encoder: each stride-2 'same' convolution halves the time axis.
    self.conv1 = Conv1D(
      filters=self.hidden_layer_sizes[0],
      kernel_size=3,
      strides=2,
      activation='relu',
      padding='same')
    self.conv2 = Conv1D(
      filters=self.hidden_layer_sizes[1],
      kernel_size=3,
      strides=2,
      activation='relu',
      padding='same')
    self.flatten = Flatten()
    self.dense1 = Dense(self.latent_dim, name="z_mean")
    # Distinct name: this head produces the log-variance. A second "z_mean"
    # would give the model two layers with the same name, which breaks
    # name-keyed (HDF5) weight saving.
    self.dense2 = Dense(self.latent_dim, name="z_log_var")

    # Decoder. The 1600 width is hard-coded. reshape2 turns it into
    # (1600 // hidden_layer_sizes[-1], hidden_layer_sizes[-1]), so 1600 must
    # be divisible by hidden_layer_sizes[-1].
    self.dense3 = Dense(1600, name="dec_dense", activation='relu')
    self.reshape2 = Reshape(target_shape=(-1, self.hidden_layer_sizes[-1]), name="dec_reshape")
    # Each stride-2 'same' transposed convolution doubles the time axis.
    self.convTr1 = Conv1DTranspose(
      filters=self.hidden_layer_sizes[-1],
      kernel_size=3,
      strides=2,
      padding='same',
      activation='relu')
    self.convTr2 = Conv1DTranspose(
      filters=self.hidden_layer_sizes[-2],
      kernel_size=3,
      strides=2,
      padding='same',
      activation='relu')
    self.convTr3 = Conv1DTranspose(
      filters=self.feat_dim,
      kernel_size=3,
      strides=2,
      padding='same',
      activation='relu')
    # The final dense layer maps the upsampled sequence onto exactly
    # seq_len * feat_dim values.
    self.flatten2 = Flatten(name='dec_flatten')
    self.dense4 = Dense(self.seq_len * self.feat_dim, name="decoder_dense_final")
    self.reshape3 = Reshape(target_shape=(self.seq_len, self.feat_dim))

  def encoder(self, label, x, y):
    """Return ``(mean, logvar)``, each of shape ``(batch, latent_dim)``."""
    x = self.dense_x_encoder(x)
    y = self.conv1(y)
    y = self.conv2(y)
    y = self.flatten(y)
    t = self.concat([label, x, y])
    t = self.dense_concat(t)
    mean = self.dense1(t)
    logvar = self.dense2(t)
    return mean, logvar

  def sample(self, mean, logvar):
    """Draw ``z = mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
    eps = tf.random.normal(shape=(tf.shape(mean)[0], self.latent_dim))
    z = eps * tf.exp(logvar * .5) + mean
    return z

  def decoder(self, label, x, z):
    """Decode ``(label, x, z)`` into a ``(batch, seq_len, feat_dim)`` tensor."""
    x = self.dense_x_decoder(x)
    y = self.concat([label, x, z])
    y = self.dense3(y)
    y = self.reshape2(y)
    y = self.convTr1(y)
    y = self.convTr2(y)
    y = self.convTr3(y)
    y = self.flatten2(y)
    y = self.dense4(y)
    y = self.reshape3(y)
    return y

  def reconstruction(self, inputs, out):
    """Mean absolute error over axis 1 (the time axis of ``out``).

    If ``inputs`` has the same ``(batch, seq_len, feat_dim)`` shape as
    ``out``, the result has shape ``(batch, feat_dim)``.
    """
    logpx_z = tf.reduce_mean(tf.abs(inputs - out), axis=1)
    return logpx_z

  def kl(self, inputs, out):
    """Analytic ``KL(q(z|x) || p(z))`` per sample, shape ``(batch,)``.

    Here ``q = N(mean, diag(exp(logvar)))`` uses the values stored by the
    most recent :meth:`call`, so the model must have been called first, and
    ``p = N(0, var_prior * I)``. ``inputs`` and ``out`` are unused.
    """
    mean = self.mean
    logvar = self.logvar
    var = tf.math.exp(logvar)
    # log det of a diagonal covariance equals the sum of its log-variances.
    # Summing logvar directly avoids log(prod(var)), which underflows or
    # overflows to -inf/+inf once latent_dim is moderately large.
    log_det_cov_qz_x = tf.math.reduce_sum(logvar, axis=1)
    kl = 0.5 * (
        self.log_det_cov_pz
        - log_det_cov_qz_x
        - self.latent_dim
        + tf.math.reduce_sum(mean * mean / self.var_prior, axis=1)
        + tf.math.reduce_sum(var / self.var_prior, axis=1))
    return kl

  def loss_(self, inputs, out):
    """Return ``recon_weight * reconstruction + KL``, shape ``(batch, feat_dim)``.

    The KL has one value per sample and is reshaped to ``(batch, 1)`` so it
    broadcasts across the reconstruction's feature axis. The reconstruction
    is kept per sample rather than flattened to ``(batch * feat_dim, 1)``,
    which could not be added to the ``(batch, 1)`` KL when ``feat_dim > 1``.
    """
    kl = tf.reshape(self.kl(inputs, out), [-1, 1])
    rec = tf.reshape(self.reconstruction(inputs, out), [tf.shape(kl)[0], -1])
    vae_loss = self.recon_weight * rec + kl
    return vae_loss

  def cf_generation(self, label_real, label_cf, x, y):
    """Encode with ``label_real``, sample ``z``, and decode with ``label_cf``.

    The output is stochastic, because ``z`` is sampled. This method does not
    update ``self.mean``, ``self.logvar`` or ``self.z``.
    """
    x = self.lstm_x(x)
    mean, logvar = self.encoder(label_real, x, y)
    z = self.sample(mean, logvar)
    out = self.decoder(label_cf, x, z)
    return out

  def call(self, inputs):
    """Reconstruct from ``(label, input_x, input_y)``.

    Stores ``self.mean``, ``self.logvar`` and ``self.z`` for use by
    :meth:`kl`.
    """
    label, input_x, input_y = inputs
    x = self.lstm_x(input_x)
    self.mean, self.logvar = self.encoder(label, x, input_y)
    self.z = self.sample(self.mean, self.logvar)
    z = self.z
    out = self.decoder(label, x, z)
    return out
  
class forecast_model(tf.keras.models.Model):
    """Sequence regressor that conditions two stacked LSTMs on an event vector.

    ``call`` expects ``(event, input_ts)``. ``input_ts`` passes through two
    32-unit LSTMs; only the second LSTM's final output is kept. That output
    is concatenated with ``event`` (so ``event`` must be rank 2), passed
    through a 32-unit ReLU layer, and mapped to ``pred_steps`` linear outputs.

    Args:
        pred_steps: Number of values produced per sample.
    """

    def __init__(self, pred_steps):
        super().__init__()
        # Layers are created in a fixed order so auto-generated layer names
        # and the weight ordering stay stable.
        self.lstm1 = LSTM(32, return_sequences=True)
        self.lstm2 = LSTM(32, return_sequences=False)
        self.concat = Concatenate()
        self.dense1 = Dense(32, activation="relu")
        self.dense2 = Dense(pred_steps)

    def call(self, inputs):
        """Return a ``(batch, pred_steps)`` forecast for ``(event, input_ts)``."""
        event, series = inputs
        summary = self.lstm2(self.lstm1(series))
        joined = self.concat([event, summary])
        hidden = self.dense1(joined)
        return self.dense2(hidden)

class event_predictor(tf.keras.models.Model):
    """Map a sequence to one sigmoid output in (0, 1) per sample.

    Two stacked 32-unit LSTMs feed a 32-unit dense layer and a single
    sigmoid unit. The output is presumably an event probability; confirm
    against the training code.
    """

    def __init__(self):
        super().__init__()
        self.lstm1 = LSTM(32, return_sequences=True)
        self.lstm2 = LSTM(32, return_sequences=False)
        # NOTE(review): dense1 has no activation, so it and dense2's
        # pre-activation compose to a single affine map. Confirm this is
        # intended.
        self.dense1 = Dense(32)
        self.dense2 = Dense(1, activation=tf.keras.activations.sigmoid)

    def call(self, inputs):
        """Apply the LSTM stack and dense head in sequence."""
        result = inputs
        for stage in (self.lstm1, self.lstm2, self.dense1, self.dense2):
            result = stage(result)
        return result
