import tensorflow as tf
import numpy as np
import dataset_funcs as ds
import pickle
import os
from tensorflow.keras.callbacks import CSVLogger
from tensorflow.keras import backend as K
import cpc_funcs as cpc




# Dataset bookkeeping shared by the training routines in this file.
# tot_samps is what they divide by batch_size to get steps_per_epoch and the
# learning-rate schedule boundaries.
# NOTE(review): the factor 4 presumably counts the data sources that are
# mixed together -- confirm against dataset_funcs.
tot_samps_from_each = 138 * 512
samps_from_each = 512
tot_samps = 4 * tot_samps_from_each
num_samples = 4 * samps_from_each
# Spelling ("tfrecods") kept as-is: other modules may reference this name.
num_tfrecods = tot_samps // num_samples

# Output locations: saved networks and their per-epoch loss logs.
save_dir = './nets/'
loss_dir = './nets/losses/'

class Sampling(tf.keras.layers.Layer):
    """Draw a latent vector z from (z_mean, z_log_var).

    Implements z = z_mean + exp(0.5 * z_log_var) * epsilon with epsilon
    sampled from a standard normal, so the sample stays differentiable with
    respect to both inputs.
    """

    def call(self, inputs):
        mean, log_var = inputs
        # Noise has the same (batch, dim) shape as the mean.
        mean_shape = tf.shape(mean)
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        std = tf.exp(0.5 * log_var)
        return mean + std * noise

def build_resnet(num_channels=[16,32,32],num_blocks=[2,2,2],l2_reg=None,reg=None,inits='glorot_uniform',last_layer=True, actreg_all=0):
    """Build the convolutional residual encoder for 64x64x3 inputs.

    Each stage is a 3x3 convolution, a stride-2 max-pool and
    ``num_blocks[stage]`` pre-activation residual blocks
    (ReLU-conv-ReLU-conv plus skip connection).

    Args:
        num_channels: filters per stage.
        num_blocks: residual blocks per stage, indexed like ``num_channels``.
        l2_reg: L2 kernel-regularisation factor, or None for no kernel
            regulariser.
        reg: L1 activity-regularisation factor for the final Dense layer only;
            when None that layer uses ``actreg_all`` like every other layer.
        inits: kernel initializer passed to every Conv2D/Dense layer.
        last_layer: when False, return the flattened convolutional features
            without the final 128-unit Dense layer.
        actreg_all: L1 activity-regularisation factor applied to every layer.

    Returns:
        ``(inputs, outputs)`` Keras tensors, ready for ``tf.keras.Model``.
    """
    inputs = tf.keras.Input(shape=(64,64,3))
    actreg_func = tf.keras.regularizers.l1(actreg_all)
    kernel_reg = tf.keras.regularizers.l2(l2_reg) if l2_reg is not None else None

    def conv3x3(filters, **extra):
        # Every convolution in the encoder shares these settings.
        return tf.keras.layers.Conv2D(
            filters, 3, padding='same',
            kernel_initializer=inits,
            kernel_regularizer=kernel_reg,
            activity_regularizer=actreg_func,
            **extra
        )

    x = inputs
    for stage, width in enumerate(num_channels):
        # Downscale: only the very first convolution carries the explicit
        # input-shape / data-format arguments.
        if stage == 0:
            stem_args = {'input_shape': (64,64,3), 'data_format': "channels_last"}
        else:
            stem_args = {}
        x = conv3x3(width, **stem_args)(x)
        x = tf.keras.layers.MaxPool2D(pool_size=[3,3],padding='same',strides=[2,2],data_format="channels_last")(x)

        # Residual blocks
        for _ in range(num_blocks[stage]):
            shortcut = x
            x = tf.keras.layers.ReLU()(x)
            x = conv3x3(width)(x)
            x = tf.keras.layers.ReLU()(x)
            x = conv3x3(width)(x)
            x = x + shortcut

    x = tf.keras.layers.ReLU()(x)
    flattened = tf.keras.layers.Flatten()(x)
    if not last_layer:
        return inputs, flattened

    # The embedding layer optionally gets its own activity regulariser.
    embedding_actreg = actreg_func if reg is None else tf.keras.regularizers.l1(reg)
    embedding = tf.keras.layers.Dense(
        128, use_bias=True, activation='relu',
        kernel_initializer=inits,
        kernel_regularizer=kernel_reg,
        activity_regularizer=embedding_actreg,
    )(flattened)
    return inputs, embedding

def build_resnet_decoder(latents,num_channels=[32,32,16],num_blocks=[2,2,2],l2_reg=None,inits='glorot_uniform',actreg_all=0):
    """Build the convolutional residual decoder on top of ``latents``.

    The latent tensor is projected to 2048 units and reshaped to 8x8x32. Every
    later stage doubles the resolution (UpSampling2D + 3x3 conv) before its
    residual blocks, and a final upsample + 3-channel sigmoid convolution
    produces the image. The hard-coded (8, 8, 32) reshape means the skip
    connections only line up when ``num_channels[0]`` is 32.

    Args:
        latents: Keras tensor to decode.
        num_channels: filters per stage.
        num_blocks: residual blocks per stage, indexed like ``num_channels``.
        l2_reg: L2 kernel-regularisation factor, or None for no kernel
            regulariser.
        inits: kernel initializer passed to every layer.
        actreg_all: L1 activity-regularisation factor applied to every layer.

    Returns:
        The decoded Keras tensor (sigmoid activations, 3 channels).
    """
    kernel_reg = tf.keras.regularizers.l2(l2_reg) if l2_reg is not None else None
    actreg_func = tf.keras.regularizers.l1(actreg_all)
    shared = dict(
        kernel_initializer=inits,
        kernel_regularizer=kernel_reg,
        activity_regularizer=actreg_func,
    )

    def conv3x3(filters, **extra):
        return tf.keras.layers.Conv2D(filters, 3, padding='same', **shared, **extra)

    x = tf.keras.layers.Dense(2048, use_bias=True, activation=None, **shared)(latents)
    for stage, width in enumerate(num_channels):
        if stage == 0:
            x = tf.keras.layers.Reshape((8,8,32))(x)
        else:
            x = tf.keras.layers.UpSampling2D(size=2)(x)
            x = conv3x3(width)(x)

        # Residual blocks (conv-ReLU-conv-ReLU, then the skip connection)
        for _ in range(num_blocks[stage]):
            shortcut = x
            x = conv3x3(width)(x)
            x = tf.keras.layers.ReLU()(x)
            x = conv3x3(width)(x)
            x = tf.keras.layers.ReLU()(x)
            x = x + shortcut

    # Final upscale to the output resolution
    x = tf.keras.layers.UpSampling2D(size=2)(x)
    return conv3x3(3, activation='sigmoid')(x)


def build_mlp(l2_reg=None,reg=None,inits='glorot_uniform',actreg_all=0):
    """Build a three-layer ReLU MLP (256-256-128) over a 107-dim input vector.

    Args:
        l2_reg: L2 kernel-regularisation factor, or None for no kernel
            regulariser.
        reg: accepted for signature parity with ``build_resnet``; it is not
            used by this function.
        inits: kernel initializer passed to every Dense layer.
        actreg_all: L1 activity-regularisation factor applied to every layer.

    Returns:
        ``(inputs, outputs)`` Keras tensors, ready for ``tf.keras.Model``.
    """
    # ``(107)`` is just the int 107; pass a real 1-tuple so the shape does not
    # depend on Keras coercing a bare integer.
    inputs = tf.keras.Input(shape=(107,))
    actreg_func = tf.keras.regularizers.l1(actreg_all)
    kernel_reg = tf.keras.regularizers.l2(l2_reg) if l2_reg is not None else None

    out = inputs
    for units in (256, 256, 128):
        out = tf.keras.layers.Dense(
            units, use_bias=True, activation='relu',
            kernel_initializer=inits,
            kernel_regularizer=kernel_reg,
            activity_regularizer=actreg_func,
        )(out)

    return inputs, out


def add_FC_layer(outputs,units,activation=None,use_bias=False,scaler=1):
    """Append a Dense layer, scaled by ``scaler``, to ``outputs``.

    A list of tensors is concatenated first, so several feature streams can be
    merged into one fully connected layer.
    """
    features = tf.keras.layers.concatenate(outputs) if isinstance(outputs, list) else outputs
    dense = tf.keras.layers.Dense(units, activation=activation, use_bias=use_bias)
    return scaler * dense(features)

class VAE(tf.keras.Model):
    """Variational autoencoder wrapping an encoder and a decoder model.

    ``encoder`` must return ``(z_mean, z_log_var, z)`` and ``decoder`` must map
    ``z`` back to the input space. Three running means are tracked and
    reported as metrics: total, reconstruction and KL loss.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them between epochs.
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    @staticmethod
    def _reconstruction_loss(data, reconstruction):
        """Binary cross-entropy summed over axes 1 and 2, averaged over the batch."""
        return tf.reduce_mean(
            tf.reduce_sum(
                tf.keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
            )
        )

    @staticmethod
    def _kl_loss(z_mean, z_log_var):
        """KL divergence to a standard normal, summed over latents, batch-averaged."""
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        return tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

    def train_step(self, data):
        """Run one optimisation step on ``data[0]`` and return the running losses."""
        data=data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = self._reconstruction_loss(data, reconstruction)
            kl_loss = self._kl_loss(z_mean, z_log_var)
            # NOTE(review): self.losses (kernel/activity regularisers of the
            # encoder and decoder) are not added here, so they do not affect
            # the gradients -- confirm whether that is intended.
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def call(self, data):
        """Reconstruct ``data`` and update the loss trackers as a side effect.

        ``train_step`` does not go through ``call``, so these updates are what
        the trackers report for calls made outside it (for example during
        evaluation).
        """
        if len(np.shape(data)) == 5:
            data=data[0]
        z_mean, z_log_var, z  = self.encoder(data)
        y_pred = self.decoder(z)
        reconstruction_loss = self._reconstruction_loss(data, y_pred)
        # Track the real KL and total losses; the reconstruction loss used to
        # be written into all three trackers, so "kl_loss" and "total_loss"
        # were both just copies of it.
        kl_loss = self._kl_loss(z_mean, z_log_var)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        self.total_loss_tracker.update_state(reconstruction_loss + kl_loss)
        return y_pred


def supervised_net(label,model_num,batch_size=128,save=True,epochs=75,l2_reg=.001,inits='truncated_normal',reg=None,l2actreg=0,LR=.001,opt='adam'):
    """Train, or resume training of, a supervised ResNet for one label set.

    A previously trained network is located through its loss log in
    ``loss_dir`` and reloaded from ``save_dir``. If none exists, or loading
    fails, a fresh network is built from ``build_resnet`` plus one output head
    per label.

    Args:
        label: 'task', 'reward', 'action', 'joints', 'appendages', 'zaxis', or
            the multi-head sets 'all' / 'all3'. 'acceleration' and 'touch' are
            recognised but deliberately not trained (the function returns).
        model_num: string appended to the network name tag.
        batch_size: batch size, also used to derive steps per epoch.
        save: save the trained model under ``save_dir`` when True.
        epochs: number of training epochs.
        l2_reg, inits, reg, l2actreg: forwarded to ``build_resnet``.
        LR: initial learning rate; divided by 10 after 25 and again after 50
            epochs' worth of steps.
        opt: 'adadelta', 'adam', 'rmsprop', 'sgd' or 'adagrad'.

    Raises:
        ValueError: for an unknown ``label`` or ``opt``.
    """
    # label -> (name tag, output units, output activation, output scale)
    single_heads = {
        'task': ('SupTsk_mn', 4, None, 1),
        'reward': ('SupRwd_mn', 1, 'sigmoid', 1),
        'action': ('SupAct_mn', 38, 'tanh', 1),
        'joints': ('SupJnt_mn', 30, 'tanh', np.pi/2),
        'appendages': ('SupApp_mn', 15, 'tanh', 1),
        'zaxis': ('SupZax_mn', 3, 'tanh', 1),
    }
    # label -> (name tag, [(head label, output units, output activation), ...])
    multi_heads = {
        'all': ('SupAll_mn', [('task', 4, None), ('reward', 1, 'sigmoid'), ('action', 38, 'tanh'), ('zaxis', 3, 'tanh')]),
        'all3': ('SupAl3_mn', [('task', 4, None), ('reward', 1, 'sigmoid'), ('zaxis', 3, 'tanh')]),
    }
    optimizer_classes = {
        'adadelta': tf.keras.optimizers.Adadelta,
        'adam': tf.keras.optimizers.Adam,
        'rmsprop': tf.keras.optimizers.RMSprop,
        'sgd': tf.keras.optimizers.SGD,
        'adagrad': tf.keras.optimizers.Adagrad,
    }

    if label in ('acceleration', 'touch'):
        # These used to fall through and die with a NameError on net_name.
        print('Do not train this')
        return
    if label not in single_heads and label not in multi_heads:
        raise ValueError('Unknown label: '+repr(label))
    if opt not in optimizer_classes:
        raise ValueError('Unknown optimizer: '+repr(opt))

    def head_loss(head_label):
        # The task head emits logits over classes; every other head regresses.
        if head_label == 'task':
            return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        return "mean_squared_error"

    # The loss is fixed up front so reloaded and fresh models are compiled the
    # same way (reloaded multi-head models used to get a single MSE loss).
    is_multi = label in multi_heads
    if is_multi:
        tag, heads = multi_heads[label]
        loss = [head_loss(head_label) for head_label, _, _ in heads]  # need to rescale so all obj are equal
    else:
        tag, output_dim, act_f, scaler = single_heads[label]
        loss = head_loss(label)
    net_name = tag+model_num

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(save_dir+'firstvals/', exist_ok=True)

    # Match on "<name>_" so that e.g. model 1 does not pick up model 10's log.
    reloaded = False
    logs = sorted(f for f in os.listdir(loss_dir) if f.startswith(net_name+'_') and f.endswith('Losses.csv'))
    if logs:
        net_name_load = logs[0][:-len('Losses.csv')]
        try:
            model = tf.keras.models.load_model(save_dir+net_name_load)
            reloaded = True
        except Exception as err:
            # Best effort: fall back to a fresh network, but say why.
            print('Could not load '+save_dir+net_name_load+' ('+repr(err)+'); building a new network')

    if reloaded:
        print(model.summary())
        print('Pre-trained model loaded')
        net_name = net_name_load
    else:
        net_name += '_L2'+str(l2_reg)+'_IN'+inits+'_ar'+str(reg)+'_l2AR'+str(l2actreg)+'_lr'+str(LR)+'_OP'+opt
        inputs, features = build_resnet(l2_reg=l2_reg,inits=inits,reg=reg,actreg_all=l2actreg)
        if is_multi:
            outputs = [
                add_FC_layer(features, units, activation=head_act, use_bias=(head_label != 'task'))
                for head_label, units, head_act in heads
            ]
        else:
            outputs = add_FC_layer(features, output_dim, activation=act_f, use_bias=(label != 'task'), scaler=scaler)
        model = tf.keras.Model(inputs, outputs)
        print(model.summary())

    steps_per_epoch = tot_samps//batch_size
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay([25*steps_per_epoch,50*steps_per_epoch], [LR,LR/10,LR/100])
    optimizer = optimizer_classes[opt](learning_rate=learning_rate_fn)

    model.compile(optimizer=optimizer, loss=loss)
    # TODO: the loader must return the labels in the same order as the output
    # heads defined above.
    train, val = ds.load_and_transform_for_training(label,batch_size)
    print('Training Supervised '+label+' Network')
    print(net_name)
    csv_logger = CSVLogger(loss_dir+net_name+'Losses.csv', append=True, separator=';')
    if not reloaded:
        # Record the validation loss of the untrained network as a baseline.
        first_val = model.evaluate(val,verbose=1)
        np.save(save_dir+'firstvals/'+net_name+'Firstval.npy',first_val)
    model.fit(train,verbose=2,epochs=epochs,steps_per_epoch=steps_per_epoch,validation_data=val,validation_steps=None,validation_freq=1, callbacks=[csv_logger])
    if save:
        model.save(save_dir+net_name)

def supervised_proprioaction_net(model_num,batch_size=128,save=True,epochs=75,l2_reg=.001,inits='truncated_normal',reg=None,l2actreg=0,LR=.001,opt='adam'):
    """Train, or resume training of, the two-stream (ResNet + MLP) network.

    The ``build_resnet`` and ``build_mlp`` outputs are concatenated, passed
    through a 128-unit tanh layer and mapped to a 38-unit tanh output trained
    with mean squared error on the 'proprio' dataset.

    Args:
        model_num: string appended to the network name tag.
        batch_size: batch size, also used to derive steps per epoch.
        save: save the trained model under ``save_dir`` when True.
        epochs: number of training epochs.
        l2_reg, inits, reg, l2actreg: forwarded to the network builders.
        LR: initial learning rate; divided by 10 after 25 and again after 50
            epochs' worth of steps.
        opt: 'adadelta', 'adam', 'rmsprop', 'sgd' or 'adagrad'.

    Raises:
        ValueError: for an unknown ``opt``.
    """
    optimizer_classes = {
        'adadelta': tf.keras.optimizers.Adadelta,
        'adam': tf.keras.optimizers.Adam,
        'rmsprop': tf.keras.optimizers.RMSprop,
        'sgd': tf.keras.optimizers.SGD,
        'adagrad': tf.keras.optimizers.Adagrad,
    }
    if opt not in optimizer_classes:
        # Used to surface later as a NameError on ``optimizer``.
        raise ValueError('Unknown optimizer: '+repr(opt))

    output_dim = 38
    net_name = 'SupPrp_mn'+model_num
    act_f = 'tanh'
    loss = "mean_squared_error"

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(save_dir+'firstvals/', exist_ok=True)

    # Match on "<name>_" so that e.g. model 1 does not pick up model 10's log.
    reloaded = False
    logs = sorted(f for f in os.listdir(loss_dir) if f.startswith(net_name+'_') and f.endswith('Losses.csv'))
    if logs:
        net_name_load = logs[0][:-len('Losses.csv')]
        try:
            model = tf.keras.models.load_model(save_dir+net_name_load)
            reloaded = True
        except Exception as err:
            # Best effort: fall back to a fresh network, but say why.
            print('Could not load '+save_dir+net_name_load+' ('+repr(err)+'); building a new network')

    if reloaded:
        print(model.summary())
        print('Pre-trained model loaded')
        net_name = net_name_load
    else:
        net_name += '_L2'+str(l2_reg)+'_IN'+inits+'_ar'+str(reg)+'_l2AR'+str(l2actreg)+'_lr'+str(LR)+'_OP'+opt
        inputs, outputs = build_resnet(l2_reg=l2_reg,inits=inits,reg=reg,actreg_all=l2actreg)
        inputs2, outputs2 = build_mlp(l2_reg=l2_reg,inits=inits,reg=reg,actreg_all=l2actreg)

        # Fuse the two streams, then map to the output.
        outputs = add_FC_layer([outputs,outputs2],128,activation=act_f,use_bias=True)
        outputs = add_FC_layer(outputs,output_dim,activation=act_f,use_bias=True)

        model = tf.keras.Model((inputs,inputs2), outputs)
        print(model.summary())

    steps_per_epoch = tot_samps//batch_size
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay([25*steps_per_epoch,50*steps_per_epoch], [LR,LR/10,LR/100])
    optimizer = optimizer_classes[opt](learning_rate=learning_rate_fn)

    model.compile(optimizer=optimizer, loss=loss)

    train, val = ds.load_and_transform_for_training('proprio',batch_size)
    print('Training Proprio '+model_num+' Network')
    print(net_name)
    csv_logger = CSVLogger(loss_dir+net_name+'Losses.csv', append=True, separator=';')
    if not reloaded:
        # Record the validation loss of the untrained network as a baseline.
        first_val = model.evaluate(val,verbose=1)
        np.save(save_dir+'firstvals/'+net_name+'Firstval.npy',first_val)
    model.fit(train,verbose=2,epochs=epochs,steps_per_epoch=steps_per_epoch,validation_data=val,validation_steps=None,validation_freq=1, callbacks=[csv_logger])
    if save:
        model.save(save_dir+net_name)

def supervised_justproprio_net(model_num,batch_size=128,save=True,epochs=75,l2_reg=.001,inits='truncated_normal',reg=None,l2actreg=0,LR=.001,opt='adam'):
    """Train, or resume training of, the MLP-only network.

    ``build_mlp`` is followed by a 38-unit tanh output layer and trained with
    mean squared error on the 'jprop' dataset.

    Args:
        model_num: string appended to the network name tag.
        batch_size: batch size, also used to derive steps per epoch.
        save: save the trained model under ``save_dir`` when True.
        epochs: number of training epochs.
        l2_reg, inits, reg, l2actreg: forwarded to ``build_mlp``.
        LR: initial learning rate; divided by 10 after 25 and again after 50
            epochs' worth of steps.
        opt: 'adadelta', 'adam', 'rmsprop', 'sgd' or 'adagrad'.

    Raises:
        ValueError: for an unknown ``opt``.
    """
    optimizer_classes = {
        'adadelta': tf.keras.optimizers.Adadelta,
        'adam': tf.keras.optimizers.Adam,
        'rmsprop': tf.keras.optimizers.RMSprop,
        'sgd': tf.keras.optimizers.SGD,
        'adagrad': tf.keras.optimizers.Adagrad,
    }
    if opt not in optimizer_classes:
        # Used to surface later as a NameError on ``optimizer``.
        raise ValueError('Unknown optimizer: '+repr(opt))

    output_dim = 38
    net_name = 'SupJpr_mn'+model_num
    act_f = 'tanh'
    loss = "mean_squared_error"

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(save_dir+'firstvals/', exist_ok=True)

    # Match on "<name>_" so that e.g. model 1 does not pick up model 10's log.
    reloaded = False
    logs = sorted(f for f in os.listdir(loss_dir) if f.startswith(net_name+'_') and f.endswith('Losses.csv'))
    if logs:
        net_name_load = logs[0][:-len('Losses.csv')]
        try:
            model = tf.keras.models.load_model(save_dir+net_name_load)
            reloaded = True
        except Exception as err:
            # Best effort: fall back to a fresh network, but say why.
            print('Could not load '+save_dir+net_name_load+' ('+repr(err)+'); building a new network')

    if reloaded:
        print(model.summary())
        print('Pre-trained model loaded')
        net_name = net_name_load
    else:
        net_name += '_L2'+str(l2_reg)+'_IN'+inits+'_ar'+str(reg)+'_l2AR'+str(l2actreg)+'_lr'+str(LR)+'_OP'+opt
        inputs, outputs = build_mlp(l2_reg=l2_reg,inits=inits,reg=reg,actreg_all=l2actreg)
        outputs = add_FC_layer(outputs,output_dim,activation=act_f,use_bias=True)

        model = tf.keras.Model(inputs, outputs)
        print(model.summary())

    steps_per_epoch = tot_samps//batch_size
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay([25*steps_per_epoch,50*steps_per_epoch], [LR,LR/10,LR/100])
    optimizer = optimizer_classes[opt](learning_rate=learning_rate_fn)

    model.compile(optimizer=optimizer, loss=loss)

    train, val = ds.load_and_transform_for_training('jprop',batch_size)
    print('Training Just-Proprio '+model_num+' Network')
    print(net_name)
    csv_logger = CSVLogger(loss_dir+net_name+'Losses.csv', append=True, separator=';')
    if not reloaded:
        # Record the validation loss of the untrained network as a baseline.
        first_val = model.evaluate(val,verbose=1)
        np.save(save_dir+'firstvals/'+net_name+'Firstval.npy',first_val)
    model.fit(train,verbose=2,epochs=epochs,steps_per_epoch=steps_per_epoch,validation_data=val,validation_steps=None,validation_freq=1, callbacks=[csv_logger])
    if save:
        model.save(save_dir+net_name)


def unsupervised_autoencoder_net(model_num,batch_size=128,save=True,epochs=75,l2_reg=.001,inits='glorot_uniform',reg=0.0001,LR=.001,l2actreg=0,opt='adam'):
    """Train, or resume training of, the convolutional autoencoder.

    ``build_resnet`` provides the encoder and ``build_resnet_decoder`` the
    decoder; the model is trained with binary cross-entropy on the 'image'
    dataset.

    Args:
        model_num: string appended to the network name tag.
        batch_size: batch size, also used to derive steps per epoch.
        save: save the trained model under ``save_dir`` when True.
        epochs: number of training epochs.
        l2_reg, inits, l2actreg: forwarded to both network builders.
        reg: L1 activity-regularisation factor for the encoder's final layer.
        LR: initial learning rate; divided by 10 after 25 and again after 50
            epochs' worth of steps.
        opt: 'adadelta', 'adam', 'rmsprop', 'sgd' or 'adagrad'.

    Raises:
        ValueError: for an unknown ``opt``.
    """
    optimizer_classes = {
        'adadelta': tf.keras.optimizers.Adadelta,
        'adam': tf.keras.optimizers.Adam,
        'rmsprop': tf.keras.optimizers.RMSprop,
        'sgd': tf.keras.optimizers.SGD,
        'adagrad': tf.keras.optimizers.Adagrad,
    }
    if opt not in optimizer_classes:
        # Used to surface later as a NameError on ``optimizer``.
        raise ValueError('Unknown optimizer: '+repr(opt))

    net_name = 'UnsImg_mn'+model_num
    # Defined before the reload attempt: compile() needs them on both paths,
    # and they used to exist only when a fresh model was built.
    loss = 'binary_crossentropy'
    metric = tf.keras.metrics.binary_crossentropy

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(save_dir+'firstvals/', exist_ok=True)

    # Match on "<name>_" so that e.g. model 1 does not pick up model 10's log.
    reloaded = False
    logs = sorted(f for f in os.listdir(loss_dir) if f.startswith(net_name+'_') and f.endswith('Losses.csv'))
    if logs:
        net_name_load = logs[0][:-len('Losses.csv')]
        try:
            model = tf.keras.models.load_model(save_dir+net_name_load)
            reloaded = True
        except Exception as err:
            # Best effort: fall back to a fresh network, but say why.
            print('Could not load '+save_dir+net_name_load+' ('+repr(err)+'); building a new network')

    if reloaded:
        print(model.summary())
        print('Pre-trained model loaded')
        net_name = net_name_load
    else:
        net_name += '_L2'+str(l2_reg)+'_IN'+inits+'_ar'+str(reg)+'_l2AR'+str(l2actreg)+'_lr'+str(LR)+'_OP'+opt

        inputs, outputs = build_resnet(l2_reg=l2_reg,inits=inits,reg=reg,actreg_all=l2actreg)
        outputs = build_resnet_decoder(outputs,l2_reg=l2_reg,inits=inits,actreg_all=l2actreg)
        model = tf.keras.Model(inputs, outputs)
        print(model.summary())

    steps_per_epoch = tot_samps//batch_size
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay([25*steps_per_epoch,50*steps_per_epoch], [LR,LR/10,LR/100])
    optimizer = optimizer_classes[opt](learning_rate=learning_rate_fn)

    model.compile(optimizer=optimizer, loss=loss,metrics=[metric])
    train, val = ds.load_and_transform_for_training('image',batch_size)

    print('Training Unsupervised Autoencoder Network')
    print(net_name)
    if not reloaded:
        # Record the validation loss of the untrained network as a baseline.
        first_val = model.evaluate(val,verbose=1)
        np.save(save_dir+'firstvals/'+net_name+'Firstval.npy',first_val)

    csv_logger = CSVLogger(loss_dir+net_name+'Losses.csv', append=True, separator=';')
    model.fit(train,verbose=2,epochs=epochs,steps_per_epoch=steps_per_epoch,validation_data=val,validation_steps=None,validation_freq=1,callbacks=[csv_logger])
    if save:
        model.save(save_dir+net_name)

def unsupervised_variautoencoder_net(model_num,batch_size=128,save=True,epochs=75,l2_reg=.001,inits='glorot_uniform',LR=.001,opt='adam',l2actreg=0):
    """Train, or resume training of, the variational autoencoder.

    The encoder is ``build_resnet`` followed by two 128-unit linear heads
    (mean and log-variance) and a ``Sampling`` layer; the decoder comes from
    ``build_resnet_decoder``. Both are wrapped in ``VAE`` and trained on the
    'image' dataset.

    Args:
        model_num: string appended to the network name tag.
        batch_size: batch size, also used to derive steps per epoch.
        save: when True, save the full model and, separately, the encoder
            (suffix ``_enc``) under ``save_dir``.
        epochs: number of training epochs.
        l2_reg, inits, l2actreg: forwarded to both network builders.
        LR: initial learning rate; divided by 10 after 25 and again after 50
            epochs' worth of steps.
        opt: 'adadelta', 'adam', 'rmsprop', 'sgd' or 'adagrad'.

    Raises:
        ValueError: for an unknown ``opt``.
    """
    optimizer_classes = {
        'adadelta': tf.keras.optimizers.Adadelta,
        'adam': tf.keras.optimizers.Adam,
        'rmsprop': tf.keras.optimizers.RMSprop,
        'sgd': tf.keras.optimizers.SGD,
        'adagrad': tf.keras.optimizers.Adagrad,
    }
    if opt not in optimizer_classes:
        # Used to surface later as a NameError on ``optimizer``.
        raise ValueError('Unknown optimizer: '+repr(opt))

    net_name = 'UnsVim_mn'+model_num

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(save_dir+'firstvals/', exist_ok=True)

    # Match on "<name>_" so that e.g. model 1 does not pick up model 10's log.
    reloaded = False
    logs = sorted(f for f in os.listdir(loss_dir) if f.startswith(net_name+'_') and f.endswith('Losses.csv'))
    if logs:
        net_name_load = logs[0][:-len('Losses.csv')]
        try:
            # NOTE(review): VAE is a subclassed model with a custom
            # train_step; confirm that load_model restores it in a trainable
            # form (it may need custom_objects).
            model = tf.keras.models.load_model(save_dir+net_name_load)
            reloaded = True
        except Exception as err:
            # Best effort: fall back to a fresh network, but say why.
            print('Could not load '+save_dir+net_name_load+' ('+repr(err)+'); building a new network')

    if reloaded:
        print(model.summary())
        print('Pre-trained model loaded')
        net_name  = net_name_load
        # ``encoder`` used to be undefined on this path, which made the final
        # encoder.save() raise a NameError after a full training run.
        encoder = getattr(model, 'encoder', None)
    else:
        net_name += '_L2'+str(l2_reg)+'_l2AR'+str(l2actreg)+'_IN'+inits+'_lr'+str(LR)+'_OP'+opt

        inputs, outputs = build_resnet(l2_reg=l2_reg,actreg_all=l2actreg,inits=inits,last_layer=True)

        z_mean = add_FC_layer(outputs,128,use_bias=True)
        z_log_var = add_FC_layer(outputs,128,use_bias=True)
        z = Sampling()([z_mean, z_log_var])
        encoder = tf.keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
        print(encoder.summary())

        latent_inputs = tf.keras.Input(shape=(128,))
        decoder_outputs = build_resnet_decoder(latent_inputs,l2_reg=l2_reg,actreg_all=l2actreg,inits=inits)

        decoder = tf.keras.Model(latent_inputs, decoder_outputs, name="decoder")
        print(decoder.summary())

        model = VAE(encoder, decoder)

    steps_per_epoch = tot_samps//batch_size
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay([25*steps_per_epoch,50*steps_per_epoch], [LR,LR/10,LR/100])
    optimizer = optimizer_classes[opt](learning_rate=learning_rate_fn)

    # No loss argument: VAE computes its own losses in train_step/call.
    model.compile(optimizer=optimizer)

    train, val = ds.load_and_transform_for_training('image',batch_size)

    print('Training Unsupervised Variational Autoencoder Network')
    print(net_name)
    if not reloaded:
        # Record the validation metrics of the untrained network as a baseline.
        first_val = model.evaluate(val,verbose=1); print('FV',first_val)
        np.save(save_dir+'firstvals/'+net_name+'Firstval.npy',first_val)

    csv_logger = CSVLogger(loss_dir+net_name+'Losses.csv', append=True, separator=';')
    model.fit(train,verbose=2,epochs=epochs,steps_per_epoch=steps_per_epoch,validation_data=val,validation_steps=None,validation_freq=1,callbacks=[csv_logger])
    if save:
        model.save(save_dir+net_name)
        if encoder is not None:
            encoder.save(save_dir+net_name+'_enc')
        else:
            print('Reloaded model exposes no encoder; skipped saving '+save_dir+net_name+'_enc')

def unsupervised_cpc_net(model_num,batch_size=128,save=True,epochs=75,l2_reg=.001,inits='glorot_uniform',LR=.001,l2actreg=0):
    """Train, or resume training of, the contrastive predictive coding network.

    A shared ``build_resnet`` encoder embeds both inputs. The x embedding goes
    through ``cpc.network_autoregressive`` and ``cpc.network_prediction``, and
    ``cpc.CPCLayer`` compares the prediction with the y embedding. The model
    is trained with Adam and binary cross-entropy.

    Args:
        model_num: string appended to the network name tag.
        batch_size: batch size, also used to derive steps per epoch.
        save: when True, save the full model and, separately, the encoder
            (suffix ``_enc``) under ``save_dir``.
        epochs: number of training epochs.
        l2_reg, inits, l2actreg: forwarded to ``build_resnet``.
        LR: initial learning rate; divided by 10 after 25 and again after 50
            epochs' worth of steps.
    """
    net_name = 'UnsCpc_mn'+model_num
    opt = 'adam'

    os.makedirs(loss_dir, exist_ok=True)
    os.makedirs(save_dir+'firstvals/', exist_ok=True)

    # Match on "<name>_" so that e.g. model 1 does not pick up model 10's log.
    reloaded = False
    logs = sorted(f for f in os.listdir(loss_dir) if f.startswith(net_name+'_') and f.endswith('Losses.csv'))
    if logs:
        net_name_load = logs[0][:-len('Losses.csv')]
        try:
            # NOTE(review): the graph contains cpc.CPCLayer; confirm that
            # load_model can restore it without custom_objects.
            model = tf.keras.models.load_model(save_dir+net_name_load)
            reloaded = True
        except Exception as err:
            # Best effort: fall back to a fresh network, but say why.
            print('Could not load '+save_dir+net_name_load+' ('+repr(err)+'); building a new network')

    if reloaded:
        print(model.summary())
        print('Pre-trained model loaded')
        net_name = net_name_load
        # ``encoder_model`` used to be undefined on this path, which made the
        # final encoder_model.save() raise a NameError. The encoder is the
        # nested model named 'encoder' when the network was built here.
        try:
            encoder_model = model.get_layer('encoder')
        except ValueError:
            encoder_model = None
    else:
        net_name += '_L2'+str(l2_reg)+'_l2AR'+str(l2actreg)+'_IN'+inits+'_lr'+str(LR)+'_OP'+opt
        K.set_learning_phase(1)
        inputs, outputs = build_resnet(l2_reg=l2_reg,inits=inits,actreg_all=l2actreg)
        # Add a length-2 axis by duplicating the embedding.
        # NOTE(review): presumably the sequence axis expected by the cpc
        # helpers -- confirm against cpc_funcs.
        outputs = K.expand_dims(outputs, axis=1)
        outputs = K.tile(outputs,[1,2,1])

        encoder_model = tf.keras.models.Model(inputs, outputs, name='encoder')

        # Define rest of model
        x_input = tf.keras.layers.Input((64,64,3))
        x_encoded = encoder_model(x_input)

        context = cpc.network_autoregressive(x_encoded)
        preds = cpc.network_prediction(context, 128, 1) #code size, pred frames

        y_input = tf.keras.layers.Input((64,64,3))
        y_encoded = encoder_model(y_input)

        dot_product_probs = cpc.CPCLayer()([preds, y_encoded])

        model = tf.keras.models.Model(inputs=[x_input, y_input], outputs=dot_product_probs)

    # Prepare data
    train, val = ds.load_and_transform_for_CPCtraining(batch_size)

    steps_per_epoch = tot_samps//batch_size
    learning_rate_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay([25*steps_per_epoch,50*steps_per_epoch], [LR,LR/10,LR/100])
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
    model.compile(
            optimizer=optimizer,
            loss='binary_crossentropy',
            metrics=['binary_accuracy']
        )

    print('Training Unsupervised Contrastive Predictive Coding Network')
    print(net_name)
    if not reloaded:
        # Record the validation metrics of the untrained network as a baseline.
        first_val = model.evaluate(val,verbose=1); print('FV',first_val)
        np.save(save_dir+'firstvals/'+net_name+'Firstval.npy',first_val)

    csv_logger = CSVLogger(loss_dir+net_name+'Losses.csv', append=True, separator=';')

    model.fit(train,verbose=2,epochs=epochs,steps_per_epoch=steps_per_epoch,validation_data=val,validation_steps=None,validation_freq=1,callbacks=[csv_logger])
    if save:
        model.save(save_dir+net_name)
        if encoder_model is not None:
            encoder_model.save(save_dir+net_name+'_enc')
        else:
            print('Reloaded model has no layer named encoder; skipped saving '+save_dir+net_name+'_enc')



 

def random_net(model_num,save=True):
    """Build an untrained, default-configured ResNet and optionally save it.

    Despite the progress message, no optimisation is performed: the network
    keeps its initial weights.
    """
    net_inputs, net_outputs = build_resnet()
    untrained = tf.keras.Model(net_inputs, net_outputs)
    print(untrained.summary())
    print('Training Random Network')
    if not save:
        return
    untrained.save(save_dir+"UntRnd_mn"+model_num)

#def local_agg_net(model_num)



