# %%
import numpy as np
import tensorflow as tf
import matplotlib.cm as cm
from matplotlib import pyplot as plt
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense
import keras
from sklearn.model_selection import train_test_split
from sklearn.datasets import (
    make_circles,
    make_classification,
    make_moons,
    make_blobs,
)
from sklearn.svm import SVC, LinearSVC
from sklearn.ensemble import GradientBoostingClassifier
from tensorflow.keras.layers import (
    BatchNormalization,
    Activation,
    GlobalAveragePooling2D,
    Dense,
)

import tensorflow.keras.activations as activations

# packages for learning from crowds (crowd layer and aggregators)

from crowdlayer.crowd_layers import (
    CrowdsClassification,
    MaskedMultiCrossEntropy,
)
from crowdlayer.crowd_aggregators import CrowdsCategoricalAggregator
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
from keras.models import Model
import numpy as np


# Number of target classes.
# NOTE(review): not referenced anywhere else in this file — the code below
# uses the literal 8 and N_CLASSES instead; consider unifying on one constant.
nb_classes = 8
# %%
def load_data(filename):
    """Return the array that ``np.load`` reads from *filename*.

    Every caller in this script passes a ``.npy`` path written by ``np.save``.
    """
    return np.load(filename)


def lr_scheduler(epoch, lr):
    """Step learning-rate schedule.

    Returns ``lr * 0.1`` when *epoch* is 50 or 75 and *lr* unchanged
    otherwise. It is used with ``LearningRateScheduler`` in the training
    loop, which feeds back the current rate, so the two drops compound.
    """
    return lr * 0.1 if epoch in (50, 75) else lr


def create_normal_model(n_classes=8):
    """Build an ImageNet-pretrained VGG16 classifier for 224x224x3 inputs.

    The convolutional base (``include_top=False``) is followed by
    ``GlobalAveragePooling2D`` and a softmax ``Dense`` head.

    Args:
        n_classes: number of units in the softmax head. Defaults to 8, the
            class count used throughout this script (previously hard-coded).

    Returns:
        A functional ``Model`` from the VGG16 input to the softmax output.

    NOTE(review): ``Model`` resolves to ``keras.models.Model`` (the later
    import in this file), while the pooling/dense layers come from
    ``tensorflow.keras``. Mixing them only works when both names refer to
    the same Keras installation — confirm for your environment.
    """
    base = VGG16(
        include_top=False, input_shape=(224, 224, 3), weights="imagenet"
    )
    features = GlobalAveragePooling2D()(base.layers[-1].output)
    outputs = Dense(n_classes, activation="softmax")(features)
    return Model(base.inputs, outputs)


def create_batch_norm_model():
    """Rebuild the VGG16 classifier with BatchNorm inserted after each conv.

    The graph of ``create_normal_model()`` is re-traced layer by layer. Each
    layer whose name contains "conv" gets its activation replaced by
    ``activations.linear`` and is followed by ``BatchNormalization()`` and
    ``Activation("relu")``. Every other layer is re-applied unchanged.

    The original layer objects are re-called, not copied. The returned model
    therefore shares their (ImageNet-pretrained) weights, and the conv
    layers' ``activation`` attribute is mutated in place.

    Returns:
        A functional ``Model`` from the original input tensor to the
        re-traced output.
    """
    model = create_normal_model()
    for i, layer in enumerate(model.layers):
        if i == 0:
            # The first entry (presumably the InputLayer) supplies the graph
            # input. It is named `inputs` so the builtin `input` is not shadowed.
            inputs = layer.input
            x = inputs
        elif "conv" in layer.name:
            # Conv(relu) becomes Conv(linear) -> BN -> ReLU. VGG16's conv
            # layers are presumably named like "block1_conv1".
            layer.activation = activations.linear
            x = layer(x)
            x = BatchNormalization()(x)
            x = Activation("relu")(x)
        else:
            x = layer(x)

    return Model(inputs, x)


def build_base_model(conn_type="MW"):
    """Build and compile the VGG16-BN classifier topped with a crowd layer.

    The BatchNorm VGG16 from ``create_batch_norm_model`` is reused without
    its final softmax ``Dense``. A new 8-unit head (``Dense`` then softmax)
    and a ``CrowdsClassification`` layer over ``N_ANNOT`` annotators are
    appended. The model is compiled with SGD (learning rate 0.1, momentum
    0.9) and ``MaskedMultiCrossEntropy().loss``.

    Args:
        conn_type: connection type forwarded to ``CrowdsClassification``.

    Relies on the module-level ``N_ANNOT``, which is assigned when the
    answers are loaded later in this script.
    """
    # NOTE(review): recent TensorFlow releases accept `decay` only on the
    # legacy optimizers — confirm against the installed version.
    sgd = tf.keras.optimizers.SGD(
        learning_rate=0.1, name="SGD", momentum=0.9, decay=5e-4
    )
    backbone = create_batch_norm_model()
    # NOTE(review): a ReLU on the values fed to softmax clamps them at 0 —
    # confirm this is intended rather than a linear Dense.
    head = [
        Dense(8, activation="relu"),
        Activation("softmax"),
        CrowdsClassification(8, N_ANNOT, conn_type=conn_type),
    ]
    # Drop the backbone's last layer (its softmax Dense) and stack the new head.
    crowd_model = tf.keras.models.Sequential(backbone.layers[:-1] + head)
    crowd_model.compile(optimizer=sgd, loss=MaskedMultiCrossEntropy().loss)
    return crowd_model


def one_hot(target, n_classes):
    """Return float one-hot rows for *target*.

    *target* may be a single integer label or an array-like of integer
    labels. It is flattened, so the result has shape
    ``(number_of_labels, n_classes)``, or ``(1, n_classes)`` for a scalar.
    Negative labels wrap around (numpy indexing), so callers must filter out
    missing-value markers such as -1 beforehand.
    """
    flat_labels = np.ravel(np.array([target]))
    identity = np.eye(n_classes)
    return identity[flat_labels]


def reframe_votes(votes, n_workers=8):
    """Convert nested vote dicts into a dense (tasks x workers) label matrix.

    Args:
        votes: mapping ``task -> {worker: label}``. Task and worker keys may
            be ints or numeric strings; both are passed through ``int``.
        n_workers: number of worker columns. Defaults to 8, the width that
            was previously hard-coded.

    Returns:
        ``np.array`` with one row per task index ``0..max(task)`` and
        ``n_workers`` columns. Entries without a vote are ``-1``.
        An empty ``votes`` yields an empty array.
    """
    # Allocate every row up front and index by the task id itself. The old
    # append-then-index approach raised IndexError unless tasks arrived
    # exactly in order 0, 1, 2, ... with no gaps.
    n_tasks = 1 + max((int(task) for task in votes), default=-1)
    answers = [[-1] * n_workers for _ in range(n_tasks)]
    for task, ans in votes.items():
        row = answers[int(task)]
        for worker, lab in ans.items():
            row[int(worker)] = lab
    return np.array(answers)


def expected_calibration_error(num_bins=15, logits=None, labels_true=None):
    """Return the expected calibration error (ECE) of class-probability rows.

    Predictions are grouped into ``num_bins`` equal-width confidence bins
    over [0, 1]. The result is
    ``sum_over_bins |sum_{i in bin} (correct_i - conf_i)| / n``, i.e. the
    sample-weighted mean gap between accuracy and confidence per bin.

    Args:
        num_bins: number of equal-width confidence bins.
        logits: array of shape (n, n_classes). Despite the name, each row is
            expected to hold probabilities (e.g. softmax outputs), because
            the row maximum is used directly as the confidence.
        labels_true: n integer class labels. They are flattened, so an
            (n, 1) column works too.

    Returns:
        The ECE as a float, or 0.0 when there are no samples.
    """
    probs = np.asarray(logits)
    labels = np.ravel(labels_true)
    n_samples = probs.shape[0]
    if n_samples == 0:
        return 0.0

    correct = (np.argmax(probs, axis=-1) == labels).astype(np.float64)
    confidence = np.max(probs, axis=-1)

    # num_bins + 1 edges give exactly num_bins bins. The old
    # linspace(0, 1, num=num_bins) produced only num_bins - 1 bins.
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    # Digitizing against the interior edges maps [0, 1/num_bins] -> 0, ...,
    # (1 - 1/num_bins, 1] -> num_bins - 1, so every sample lands in a bin.
    bin_ids = np.digitize(confidence, edges[1:-1], right=True)

    total_gap = 0.0
    for bin_id in range(num_bins):
        in_bin = bin_ids == bin_id
        if np.any(in_bin):
            total_gap += np.abs(np.sum(correct[in_bin] - confidence[in_bin]))
    return total_gap / n_samples


def eval_model(model, test_data, test_labels):
    """Evaluate a classifier on held-out data.

    Args:
        model: object with a Keras-style ``predict`` returning one row of
            class scores per sample.
        test_data: inputs passed straight to ``model.predict``.
        test_labels: class labels compared against the arg-max prediction.

    Returns:
        ``(accuracy, ece)``: the fraction of arg-max predictions equal to
        ``test_labels``, and ``expected_calibration_error`` with 15 bins.
    """
    probabilities = model.predict(test_data)
    hits = np.argmax(probabilities, axis=1) == test_labels
    accuracy = np.sum(hits) / len(test_labels)
    return accuracy, expected_calibration_error(15, probabilities, test_labels)


# %%
# TF1-style session setup. allow_growth=False means TensorFlow is not asked
# to grow GPU memory on demand (it may reserve most of it up front).
# NOTE(review): `sess` is not referenced anywhere else in this file, and
# under TF2 eager execution Keras presumably does not use this session —
# confirm this cell is still needed.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = False
sess = tf.compat.v1.Session(config=config)

# %%
answers = load_data("./crowdlayer/answers.npy")
N_ANNOT = answers.shape[1]
N_CLASSES = 8
# Crowd-layer targets: each (item, annotator) answer is one-hot encoded over
# N_CLASSES, and a missing answer (-1) becomes a row of -1s instead.
# Presumably MaskedMultiCrossEntropy masks those -1 entries out — confirm.
# Assumes `answers` is 2-D (items, annotators) with integer labels.
# Vectorized: a boolean mask selects the observed answers and fancy-indexes
# the identity matrix, replacing the per-entry Python loop over one_hot().
observed = answers != -1
answers_bin_missings = -1 * np.ones((len(answers), N_ANNOT, N_CLASSES))
answers_bin_missings[observed] = np.eye(N_CLASSES)[answers[observed]]
# (items, annotators, classes) -> (items, classes, annotators)
answers_bin_missings = answers_bin_missings.swapaxes(1, 2)
answers_bin_missings.shape  # displayed when run as a notebook cell

# %%
data_train = load_data("./crowdlayer/data_train.npy")
data_train = data_train.reshape(-1, 224, 224, 3)
data_test = load_data("./crowdlayer/data_test.npy")
data_test = data_test.reshape(-1, 224, 224, 3)
labels_test = load_data("./crowdlayer/labels_test.npy")

# Train the crowd-layer model from scratch 5 times and collect the test
# accuracy and ECE of each run.
accuracies, eces = [], []
for i in range(5):
    # Every run builds a fresh VGG16 + crowd layer. Clearing Keras' global
    # state first stops memory from accumulating across the repeated builds.
    tf.keras.backend.clear_session()
    model = build_base_model()
    model.fit(
        data_train,
        answers_bin_missings,
        epochs=100,
        shuffle=True,
        batch_size=64,
        verbose=1,
        callbacks=[
            keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=0)
        ],
    )
    # Remove the CrowdsClassification layer so predict() returns the base
    # network's softmax outputs (presumably the crowd layer emits
    # per-annotator outputs). The unused, arbitrary-index
    # `weights = model.layers[5].get_weights()` fetch was dropped.
    model.pop()
    # Re-compile after pop(); predict() itself does not use the
    # optimizer/loss given here.
    model.compile(
        optimizer="SGD", loss="categorical_crossentropy", metrics=["accuracy"]
    )
    accuracy_test, ece_test = eval_model(model, data_test, labels_test)
    print("Accuracy: Test: %.3f" % (accuracy_test,))
    print("ECE: Test: %.3f" % (ece_test,))
    accuracies.append(accuracy_test)
    eces.append(ece_test)
    print("####### Acc", accuracies, "###### ECE", eces)
# %%
