import numpy as np
import keras
import matplotlib.pyplot as plt


class MCDV:
    """Data valuation using Monte Carlo (MC) dropout in Keras.

    A small dense classifier is trained whose dropout layer stays active at
    inference time. Repeated forward passes therefore give different
    predictions. The majority vote is used as the prediction, and the spread
    across passes is used as a per-sample value/uncertainty signal.
    """

    def __init__(self, X_train, y_train, X_test, y_test, hidden_layer_size=100, num_epochs=10, num_steps=100,
                 num_classes=2, dropout_rate=0.3, plot_history=True):
        """Build and train the MC-dropout classifier.

        Args:
            X_train, y_train: training features and integer class labels
                (labels are consumed by ``sparse_categorical_crossentropy``).
            X_test, y_test: validation features and integer class labels.
            hidden_layer_size: number of units in the hidden ``Dense`` layer.
            num_epochs: number of training epochs.
            num_steps: number of stochastic forward passes used by ``predict``
                and ``predict_dv``.
            num_classes: number of output classes (default 2, as before).
            dropout_rate: dropout rate of the MC-dropout layer.
            plot_history: if True, plot the training loss/accuracy curve and
                call ``plt.show()`` (which may block).

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1")

        # Build model
        inp = keras.layers.Input(shape=X_train.shape[1:])
        hidden = keras.layers.Dense(hidden_layer_size, activation='relu')(inp)
        # training=True keeps dropout active at inference time too; this is
        # what makes repeated predictions stochastic (MC dropout).
        hidden = keras.layers.Dropout(dropout_rate)(hidden, training=True)
        out = keras.layers.Dense(num_classes, activation='softmax')(hidden)

        model = keras.models.Model(inputs=inp, outputs=out)
        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

        # Fit model (single-output model, so targets are passed directly).
        hist = model.fit(X_train, y_train,
                         batch_size=128,
                         epochs=num_epochs,
                         verbose=0,
                         validation_data=(X_test, y_test))

        if plot_history:
            self._plot_history(hist.history)

        self.model = model
        self.num_steps = num_steps
        self.num_classes = num_classes

    @staticmethod
    def _plot_history(history):
        """Plot training loss and accuracy from a Keras ``History.history`` dict."""
        plt.plot(history['loss'], label="loss")
        # Keras >= 2.3 / tf.keras log the metric as "accuracy"; older
        # standalone Keras used "acc".
        acc_key = 'accuracy' if 'accuracy' in history else 'acc'
        if acc_key in history:
            plt.plot(history[acc_key], label="accuracy")
        plt.legend()
        plt.show()

    def evaluate_mc(self, X_test, y_test, num_steps=100):
        """Evaluate accuracy over ``num_steps`` stochastic passes and report stats.

        Prints the mean, std, min and max accuracy across the passes.

        Returns:
            1-D numpy array with the accuracy of each pass.

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.
        """
        if num_steps < 1:
            raise ValueError("num_steps must be at least 1")
        # evaluate() returns [loss, accuracy] because one metric is compiled
        # in; keep only the accuracy so loss does not pollute the statistics.
        accuracies = np.array([self.model.evaluate(X_test, y_test, verbose=0)[1]
                               for _ in range(num_steps)])
        print("Mean accuracy:", np.mean(accuracies), "std:", np.std(accuracies))
        print("Min:", np.min(accuracies), "max:", np.max(accuracies))
        return accuracies

    def _sample_predictions(self, X):
        """Return a ``(num_steps, n_samples)`` array of per-pass predicted labels."""
        return np.stack([np.argmax(self.model.predict(X, verbose=0), axis=1)
                         for _ in range(self.num_steps)])

    def predict(self, X):
        """Predict labels by majority vote over ``num_steps`` MC-dropout passes.

        Ties are broken in favour of the lowest class index.

        Returns:
            1-D integer array with one predicted label per sample.
        """
        samples = self._sample_predictions(X)
        # votes[i, c] = number of passes that predicted class c for sample i.
        votes = (samples[:, :, np.newaxis] == np.arange(self.num_classes)).sum(axis=0)
        return np.argmax(votes, axis=1)

    def predict_dv(self, X):
        """Return the per-sample std of predicted labels across MC passes.

        A larger value means the stochastic passes disagree more about the
        sample. For binary labels this directly measures the vote split; for
        more than two classes it is the std of the raw class indices.
        """
        return np.std(self._sample_predictions(X), axis=0)