import os

import tensorflow as tf
import numpy as np
from tensorflow_gan.examples.mnist import util as eval_util

import utils
import model
from hyperparameters import hps


def evaluate(c_model, g_models):
    """Measure how differently generators 0 and 1 populate the classifier's classes.

    For every element of ``utils.get_dataset(False)`` one batch of noise is
    drawn, split evenly across the generators, and the generated images are
    classified with ``c_model``. The softmax of the classifier output for
    generator 1 and generator 0 is fed into a ``utils.TVDistance`` metric.
    The real images are not used; the dataset only determines how many
    batches are evaluated.

    Args:
        c_model: classifier whose ``predict`` output is treated as logits.
        g_models: sequence of ``hps.num_gens`` generators, index-aligned with
            the noise splits.

    Returns:
        The accumulated ``TVDistance.result()``.

    Raises:
        ValueError: if ``hps.num_gens`` is smaller than 2, because the metric
            compares generator 0 against generator 1.
    """
    if hps.num_gens < 2:
        raise ValueError(
            "evaluate() compares generator 0 with generator 1; "
            "hps.num_gens must be at least 2, got {}".format(hps.num_gens))

    # tf.split with an integer count requires an exact division, so draw a
    # multiple of num_gens noise vectors. When batch_size is divisible by
    # num_gens this is exactly batch_size, matching the previous behaviour.
    per_gen = max(1, hps.batch_size // hps.num_gens)
    total = per_gen * hps.num_gens

    real_image_ds = utils.get_dataset(False)
    distance = utils.TVDistance()
    for _ in real_image_ds:
        rvs = tf.random.normal(shape=(total, hps.noise_dim))
        rvs = tf.split(rvs, hps.num_gens, axis=0)
        fake_images = tf.nest.map_structure(
            lambda rv, gen_i: gen_i.predict(rv),
            rvs, g_models
        )
        fake_images = tf.concat(fake_images, axis=0)
        fake_logits = c_model.predict(fake_images)
        fake_dists = tf.nn.softmax(fake_logits)
        # Equal-sized chunks, one per generator, in generator order.
        fake_dists_list = tf.split(fake_dists, hps.num_gens, axis=0)

        distance.update_state(fake_dists_list[1], fake_dists_list[0])

    return distance.result()


def save_im(g_models):
    """Write ``hps.batch_size`` samples from each generator as PNG files.

    Files are written to ``hps.gen_img_dir + "gen/<index>.png"``. The index
    runs over the concatenated samples, so generator 0's images come first,
    then generator 1's, and so on. The ``gen/`` directory is created if it
    does not exist.

    Args:
        g_models: sequence of ``hps.num_gens`` generators.
    """
    rvs = tf.random.normal(shape=(hps.num_gens * hps.batch_size, hps.noise_dim))
    rvs = tf.split(rvs, hps.num_gens, axis=0)

    fake_images = tf.nest.map_structure(
        lambda rv, gen_i: gen_i.predict(rv),
        rvs, g_models
    )
    fake_images = tf.concat(fake_images, axis=0)
    # Map generator output to pixel values. This assumes outputs lie in
    # [-1, 1], the same range get_fid uses for real images. Clipping keeps
    # stray values from pushing PIL into 32-bit ("I") mode.
    generated_imgs = tf.clip_by_value(
        (fake_images * 127.5) + 127.5, 0.0, 255.0)

    out_dir = hps.gen_img_dir + "gen/"
    os.makedirs(out_dir, exist_ok=True)
    for i, img in enumerate(generated_imgs):
        # scale=False keeps the fixed mapping above. The default (scale=True)
        # min-max rescales every image independently, which would undo it
        # and distort per-image contrast.
        pil_img = tf.keras.preprocessing.image.array_to_img(img, scale=False)
        pil_img.save("{}{}.png".format(out_dir, i))


def average_class_probs(g_models, c_model):
    """Return each generator's classifier output, softmaxed and averaged.

    Every one of the ``hps.num_gens`` generators is fed ``hps.batch_size``
    noise vectors. Its images go through ``c_model.predict`` followed by a
    softmax, and the resulting probabilities are averaged over axis 0 (the
    sample axis).

    Args:
        g_models: sequence of ``hps.num_gens`` generators.
        c_model: classifier whose ``predict`` output is treated as logits.

    Returns:
        list with one averaged probability tensor per generator, in
        generator order.
    """
    noise_chunks = tf.split(
        tf.random.normal(shape=(hps.num_gens * hps.batch_size, hps.noise_dim)),
        hps.num_gens,
        axis=0,
    )

    samples_per_gen = tf.nest.map_structure(
        lambda z, generator: generator.predict(z),
        noise_chunks, g_models
    )
    probs_per_gen = tf.nest.map_structure(
        lambda images: tf.nn.softmax(c_model.predict(images)),
        samples_per_gen
    )

    return [tf.reduce_mean(probs_per_gen[k], axis=0)
            for k in range(hps.num_gens)]

def get_fid(gen_lst):
    """Compute the MNIST Frechet distance between real test images and generated ones.

    Real images are the MNIST test split, rescaled from [0, 255] to [-1, 1],
    given a trailing channel axis and cast to float32. One generated image is
    produced per real image. The noise is shared out across the generators
    in ``gen_lst`` as evenly as possible.

    Args:
        gen_lst: sequence of ``hps.num_gens`` generators whose outputs can be
            compared directly with the real images. NOTE(review): presumably
            they must match the MNIST image size; confirm against
            ``model.Generator``.

    Returns:
        The value returned by ``eval_util.mnist_frechet_distance``.
    """
    (_, _), (real_imgs, _) = tf.keras.datasets.mnist.load_data()
    real_imgs = (real_imgs - 127.5) / 127.5
    real_imgs = tf.expand_dims(real_imgs, -1)
    real_imgs = tf.cast(real_imgs, tf.float32)

    num_samples = real_imgs.shape[0]
    # tf.split(x, k) with an integer k fails unless k divides the length.
    # Explicit sizes spread any remainder over the first generators, and
    # reduce to equal chunks (the previous behaviour) when k divides.
    base, extra = divmod(num_samples, hps.num_gens)
    split_sizes = [base + (1 if i < extra else 0)
                   for i in range(hps.num_gens)]

    rvs = tf.random.normal(shape=(num_samples, hps.noise_dim))
    rvs = tf.split(rvs, split_sizes, axis=0)

    fake_imgs = tf.nest.map_structure(
        lambda rv, gen_i: gen_i.predict(rv),
        rvs, gen_lst
    )
    fake_imgs = tf.cast(tf.concat(fake_imgs, axis=0), tf.float32)

    fid = eval_util.mnist_frechet_distance(real_imgs, fake_imgs)

    return fid

def _load_generators(*gen_args):
    """Build all ``hps.num_gens`` generators and load their saved weights.

    ``gen_args`` are forwarded to ``model.Generator`` after the generator
    index. Weights for generator ``i`` are read from
    ``hps.savedir + "gen<i>.h5"``.

    Returns:
        list of generators, index ``i`` holding generator ``i``.
    """
    generators = []
    for i in range(hps.num_gens):
        gen = model.Generator(i, *gen_args)
        gen.build((None, hps.noise_dim))
        gen.load_weights(hps.savedir + "gen{}".format(i) + ".h5")
        generators.append(gen)
    return generators


if __name__ == "__main__":
    np.set_printoptions(suppress=True)
    tf.keras.backend.clear_session()

    classifier = model.Classifier()
    classifier.build((None, 32, 32, 1))
    classifier.load_weights(hps.savedir + "classifier" + ".h5")

    generators = _load_generators()

    dist = evaluate(classifier, generators)
    save_im(generators)

    # The generators rebuilt with the extra ``True`` flag are used only for
    # FID. NOTE(review): get_fid compares them against raw MNIST, while the
    # classifier above is built for 32x32 input, so this flag presumably
    # changes the output size; confirm in model.Generator.
    fid_generators = _load_generators(True)

    FID = get_fid(fid_generators)

    print("Generator mean FID = {}".format(FID))
    print("Generators' image distance = {}".format(dist) + " (TVD)")

    # Classify samples from the same generator variant that evaluate() has
    # already paired with the classifier, so the image size fits its input.
    c = average_class_probs(generators, classifier)
    for i, probs in enumerate(c):
        print("Average class probability generated by gen{}(%):".format(i))
        print(probs.numpy() * 100)
