###### import sys
import sys
assert sys.version_info >= (3, 5)
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from sklearn.preprocessing import OneHotEncoder
import pandas as pd
import tensorflow_probability as tfp
# Fix the global NumPy and TensorFlow seeds at import time.
np.random.seed(42)
tf.random.set_seed(42)

# Shared kernel initializer for every Dense layer built in this file:
# normal distribution with mean 0.0 and stddev 0.05. No initializer-level seed
# is given (seed=None).
randN_05 = keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)


def get_network(dim_x, dim_y, MINE_LAYER=2):
    """Build and compile the MINE statistics network.

    One MLP is shared between two inputs (a batch of joint samples and a batch
    of marginal samples, each a flat vector of length ``dim_x + dim_y``). The
    model output is the MINE objective evaluated over the whole batch::

        mean(T(joint)) - log(mean(exp(T(marginal))))

    Args:
        dim_x: Number of features in one x sample.
        dim_y: Number of features in one y sample.
        MINE_LAYER: ``2`` builds two hidden Dense(100)+Dropout(0.3) stages;
            any other value builds a single hidden stage.

    Returns:
        A ``keras.Model`` taking ``[joint, marginal]`` inputs, compiled with
        ``loss_func`` and a Nadam optimizer (learning rate 0.001).
    """
    # Clear old models/graphs
    K.clear_session()

    # total dimensionality of each sample
    total_dim = dim_x + dim_y

    # each Input is a flat vector of length total_dim
    input_A = keras.Input(shape=(total_dim,))
    input_B = keras.Input(shape=(total_dim,))

    # shared MLP
    transform_layers = [
        layers.Dense(100, kernel_initializer=randN_05, activation="relu"),
        layers.Dropout(0.3),
    ]
    if MINE_LAYER == 2:
        transform_layers += [
            layers.Dense(100, kernel_initializer=randN_05, activation="relu"),
            layers.Dropout(0.3),
        ]
    transform_layers.append(
        layers.Dense(1, kernel_initializer=randN_05, activation=None)
    )
    transform = keras.Sequential(transform_layers)

    # apply the same weights to both joint & marginal batches
    output_A = transform(input_A)   # shape=(batch, 1)
    output_B = transform(input_B)   # shape=(batch, 1)

    def _mine_objective(tensors):
        """Return mean(T_joint) - log(mean(exp(T_marginal))) as a scalar."""
        joint_scores = tensors[0]
        marg_scores = tensors[1]
        # log(mean(exp(t))) == logsumexp(t) - log(N). The logsumexp form does
        # not overflow to inf/NaN when the network emits large scores.
        n_scores = tf.cast(tf.size(marg_scores), marg_scores.dtype)
        log_mean_exp = tf.reduce_logsumexp(marg_scores) - tf.math.log(n_scores)
        return K.mean(joint_scores) - log_mean_exp

    # MINE objective via a Lambda
    output_C = layers.Lambda(_mine_objective)([output_A, output_B])

    model = keras.Model(inputs=[input_A, input_B], outputs=output_C)
    model.compile(
        loss=loss_func,
        optimizer=keras.optimizers.Nadam(learning_rate=0.001)
    )
    return model
# def get_network(dim_x,dim_y,MINE_LAYER=2):
#     tf.keras.backend.clear_session()
    
#     input_A = keras.layers.Input(shape=[dim_x+dim_y])
#     input_B = keras.layers.Input(shape=[dim_x+dim_y])

#     transform_layers = [
#         layers.Dense(100, kernel_initializer=randN_05, activation="relu"),
#         keras.layers.Dropout(rate=0.3)
#     ]
#     if MINE_LAYER == 2:
#         transform_layers.extend([
#             layers.Dense(100, kernel_initializer=randN_05, activation="relu"),
#             keras.layers.Dropout(rate=0.3)
#         ])

#     transform_layers.append(layers.Dense(1, kernel_initializer=randN_05, activation=None))
    
#     transform = keras.models.Sequential(transform_layers)

#     output_A = transform(input_A)
#     output_B = transform(input_B)
#     output_C = tf.reduce_mean(output_A) - tf.math.log(tf.reduce_mean(tf.exp(output_B))) # MINE
#     #output_C = tf.reduce_mean(output_A) - tf.reduce_mean(tf.exp(output_B))+1 # MINE-f
#     MI_mod = keras.models.Model(inputs=[input_A, input_B], outputs=output_C)
#     MI_mod.compile(loss=loss_func, optimizer=keras.optimizers.Nadam(learning_rate=0.001))
#     return MI_mod

    

def loss_func(inp, outp):
    """Return the training loss: the negation of ``outp``.

    ``inp`` is accepted but never used. Minimising the returned value is
    therefore the same as maximising ``outp``.
    """
    negated = -outp
    return negated

def MINE_ready(x_sample, y_sample):
    """Build the joint and marginal batches fed to the MINE network.

    Let ``half = n_rows // 2``. The joint batch pairs the first ``half`` rows
    of x with the first ``half`` rows of y. The marginal batch pairs the next
    ``half`` rows of x with those same first ``half`` rows of y, so its x and
    y parts come from different rows.

    With an odd number of rows the last row is dropped instead of failing on
    an uneven split.

    Args:
        x_sample: 2-D array-like/tensor, one row per sample.
        y_sample: 2-D array-like/tensor with the same number of rows.

    Returns:
        ``(joint_sample, marg_sample)`` as float32 tensors, each with
        ``half`` rows and ``x_cols + y_cols`` columns.
    """
    # Ensure both tensors are of type float32
    x_all = tf.cast(x_sample, dtype=tf.float32)
    y_all = tf.cast(y_sample, dtype=tf.float32)

    # Rows of x and y are paired by position, so the row counts must agree.
    tf.debugging.assert_equal(
        tf.shape(x_all)[0],
        tf.shape(y_all)[0],
        message="x_sample and y_sample must have the same number of rows",
    )

    half = tf.shape(x_all)[0] // 2
    x_first = x_all[:half]
    x_second = x_all[half:2 * half]
    y_first = y_all[:half]

    joint_sample = tf.concat([x_first, y_first], axis=1)
    marg_sample = tf.concat([x_second, y_first], axis=1)
    return joint_sample, marg_sample

def MINE_MI(x_sample, y_sample, total_epochs=50, MINE_LAYER=2, batch_size=200):
    """Train a MINE network on (x, y) and return the final-epoch estimate.

    Args:
        x_sample: 2-D samples of x, one row per sample.
        y_sample: 2-D samples of y, same number of rows as ``x_sample``.
        total_epochs: Number of training epochs.
        MINE_LAYER: Forwarded to ``get_network``.
        batch_size: Mini-batch size passed to ``fit`` (defaults to the
            previously hard-coded 200).

    Returns:
        ``(mi_bits, history)``: the last epoch's loss negated and multiplied
        by log2(e), and the Keras ``History`` object from training.
    """
    joint_sample, marg_sample = MINE_ready(x_sample, y_sample)

    MI_mod = get_network(x_sample.shape[-1], y_sample.shape[-1], MINE_LAYER=MINE_LAYER)
    # get_network compiles with Nadam; recompile so training uses Adam with
    # weight decay instead.
    MI_mod.compile(
        loss=loss_func,
        optimizer=keras.optimizers.Adam(learning_rate=0.001, weight_decay=5e-4),
    )

    # loss_func ignores its first argument, so this slice is only a
    # placeholder target with one row per joint/marginal pair.
    placeholder_target = x_sample[0:int(x_sample.shape[0] // 2)]
    history_mi = MI_mod.fit(
        (joint_sample, marg_sample),
        placeholder_target,
        epochs=total_epochs,
        batch_size=batch_size,
        verbose=0,
    )

    # The loss is the negated objective in nats; multiply by log2(e) for bits.
    return -np.log2(np.exp(1)) * history_mi.history['loss'][-1], history_mi
