"""Common stuff for CIFAR10-based development scripts."""
import collections
import os
from typing import Any, List, Sequence

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

from em.fishers import per_example

from local_scripts.soc import soc_dev_common as sdc

tfd = tfp.distributions


###############################################################################
###############################################################################

# Directory under which the model checkpoint below is stored.
BASE_DIR = '/fruitbasket/users/m/project_data/extract_merge1/soc_dev_cifar'

# Weights file for the conv net built by `make_model`; this is the default
# `ckpt` argument of `load_trained_model`.
MODEL_CKPT = os.path.join(BASE_DIR, 'conv_ckpt.h5')

###############################################################################
###############################################################################


def make_model(model_type):
    """Build a fresh (untrained) conv net for 32x32x3 inputs with 10 outputs.

    Args:
        model_type: Selects the sequential container class that wraps the
            layers: 'keras' uses `tf.keras.Sequential`, 'hf' uses
            `sdc.HfLogitsSequential`.

    Returns:
        The newly constructed model. The final `Dense(10)` layer has no
        activation.

    Raises:
        ValueError: If `model_type` is not 'keras' or 'hf'.
    """
    if model_type == 'keras':
        cls = tf.keras.Sequential
    elif model_type == 'hf':
        cls = sdc.HfLogitsSequential
    else:
        # An unknown type used to fall through and crash below with an
        # UnboundLocalError on `cls`; fail early with a clear message instead.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'keras' or 'hf'."
        )

    return cls([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=[32, 32, 3]),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10),
    ])


def load_training_dataset(batch_size: int):
    """Return a repeating, shuffled, batched dataset of CIFAR10 training pairs.

    Args:
        batch_size: Number of (image, label) pairs per batch.

    Returns:
        A `tf.data.Dataset` yielding (images, labels) batches. Pixel values
        are divided by 255 before the dataset is built. The dataset repeats
        indefinitely and is shuffled with a buffer of 1000 elements.
    """
    (images, labels), _ = tf.keras.datasets.cifar10.load_data()
    images = images / 255.00

    image_ds = tf.data.Dataset.from_tensor_slices(images)
    label_ds = tf.data.Dataset.from_tensor_slices(labels)

    pairs = tf.data.Dataset.zip((image_ds, label_ds))
    pairs = pairs.repeat()
    pairs = pairs.shuffle(1000)
    return pairs.batch(batch_size)


def _train_and_save_model(ckpt: str):
    """Train a new 'keras' model on CIFAR10 and write its weights to `ckpt`.

    Training uses Adam with sparse categorical cross-entropy on logits, for
    6 epochs of 1024 steps each at batch size 256.

    Args:
        ckpt: Path the trained weights are saved to.
    """
    net = make_model('keras')

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    net.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=loss_fn,
        metrics=['accuracy'],
    )

    batches = load_training_dataset(batch_size=256)
    net.fit(batches, epochs=6, steps_per_epoch=1024)

    net.save_weights(ckpt)


def load_trained_model(ckpt: str = MODEL_CKPT, model_type: str = 'hf') -> tf.keras.Model:
    """Load the trained conv net, training and saving it first if needed.

    Args:
        ckpt: Path of the weights file. A leading `~` is expanded.
        model_type: Passed to `make_model` to choose the container class
            of the returned model ('keras' or 'hf').

    Returns:
        A model built by `make_model(model_type)` with weights loaded from
        `ckpt`.
    """
    ckpt = os.path.expanduser(ckpt)

    # Train and save the model if the requested checkpoint doesn't exist.
    # The check must look at `ckpt` rather than the module-level default:
    # otherwise a custom path is never trained when missing, or is retrained
    # and overwritten just because the default file is absent.
    if not os.path.exists(ckpt):
        _train_and_save_model(ckpt)

    model = make_model(model_type)
    model.load_weights(ckpt)

    return model


###############################################################################
###############################################################################


def compute_component_size(model: tf.keras.Model) -> int:
    """Return the total number of scalar entries in the trainable variables.

    Args:
        model: Model whose `trainable_variables` are counted.

    Returns:
        The summed `tf.size` of every trainable variable as a plain Python
        `int`; 0 when the model has no trainable variables.
    """
    # Accumulate in Python so the result really is an `int` (not a NumPy
    # scalar) and the empty case yields integer 0 instead of a float.
    return sum(int(tf.size(v)) for v in model.trainable_variables)


###############################################################################
###############################################################################

def create_dataset_for_per_example_fishers(batch_size: int):
    """Return a repeating, shuffled, batched dataset of CIFAR10 train images.

    Labels are not included; only the images (divided by 255) are emitted.

    Args:
        batch_size: Number of images per batch.

    Returns:
        A `tf.data.Dataset` of image batches that repeats indefinitely and
        is shuffled with a buffer of 1000 elements.
    """
    (images, _), _ = tf.keras.datasets.cifar10.load_data()
    scaled = images / 255.00

    image_ds = tf.data.Dataset.from_tensor_slices(scaled)
    image_ds = image_ds.repeat()
    image_ds = image_ds.shuffle(1000)
    return image_ds.batch(batch_size)


def create_soc_dataset(
    trained_model,
    soc_batch_size: int,
    pe_batch_size: int,
    *,
    prefetch_factor: int = 4,
):
    """Build a dataset of flattened per-example diagonal Fishers.

    Image batches of size `pe_batch_size` are fed through
    `per_example.stream_per_example_diagonal_fishers`. Every item of that
    stream is treated as a collection of tensors; each tensor is reshaped to
    `[pe_batch_size, -1]` and they are concatenated along the last axis.
    The resulting rows are then re-batched to `soc_batch_size`.

    Args:
        trained_model: Model whose trainable variables the Fishers are
            computed for.
        soc_batch_size: Batch size of the returned dataset.
        pe_batch_size: Batch size used when streaming per-example Fishers.
        prefetch_factor: The prefetch buffer holds
            `soc_batch_size * prefetch_factor` unbatched rows.

    Returns:
        A `tf.data.Dataset` yielding `(x, x)` tuples, where `x` is a float32
        batch with `compute_component_size(trained_model)` columns.
    """
    image_batches = create_dataset_for_per_example_fishers(pe_batch_size)

    def _flattened_fisher_batches():
        fisher_stream = per_example.stream_per_example_diagonal_fishers(
            trained_model,
            dataset=image_batches,
            variables=trained_model.trainable_variables,
            unbatch=False,
            expectation_wrt_logits=False,
        )
        for per_variable in fisher_stream:
            rows = []
            for fisher in per_variable:
                # One row per example, one flat slab of columns per variable.
                rows.append(tf.reshape(fisher, [pe_batch_size, -1]))
            yield tf.concat(rows, axis=-1)

    buffer_size = soc_batch_size * prefetch_factor
    n_columns = compute_component_size(trained_model)
    signature = tf.TensorSpec(shape=[None, n_columns], dtype=tf.float32)

    ds = tf.data.Dataset.from_generator(
        _flattened_fisher_batches,
        output_signature=signature,
    )
    ds = ds.unbatch()
    ds = ds.prefetch(buffer_size)
    ds = ds.batch(soc_batch_size)
    return ds.map(lambda row: (row, row))


###############################################################################
###############################################################################

def imshow_cifar_multi(x, row_size: int):
    """Display a batch of 32x32x3 images in a grid and show the figure.

    Args:
        x: Image data that can be reshaped to `[-1, 32, 32, 3]`.
        row_size: Number of images per grid row (the number of columns).

    Raises:
        ValueError: If `row_size` is less than 1.
    """
    if row_size < 1:
        raise ValueError(f'row_size must be at least 1, got {row_size!r}.')

    x = tf.reshape(x, [-1, 32, 32, 3])
    n_images = x.shape[0]
    # Ceiling division: a partially filled last row still needs a grid row.
    n_rows = -(-n_images // row_size)
    n_cols = row_size

    # squeeze=False keeps `axs` two-dimensional even for a single row or
    # column; otherwise matplotlib returns a 1-D array (or a bare Axes) and
    # the `axs[row, col]` indexing below fails.
    _, axs = plt.subplots(n_rows, n_cols, squeeze=False)

    # Hide the axes of every cell, including unused ones in the last row.
    for ax in axs.flat:
        ax.axis('off')

    for i in range(n_images):
        row, col = divmod(i, row_size)
        axs[row, col].imshow(x[i], cmap='gray')
    plt.tight_layout()

    plt.show()
