import pickle

import numpy as np
import tensorflow as tf

import absl.app
import absl.flags
from absl import logging

from .misc_utils import define_flags_with_default, set_random_seed, TensorBoardLogger, print_flags


# Default values for the absl flags read in ``main`` (``FLAGS.learning_rate``,
# ``FLAGS.batch_size``, ``FLAGS.n_steps``), registered via the project helper
# ``define_flags_with_default``.
flags_def = define_flags_with_default(**{
    'learning_rate': 1e-3,
    'batch_size': 256,
    'n_steps': 10000,
})

def convnet(input_tensor, n_logits=10):
    """Build a small convolutional classifier and return its logits.

    The network is four ``tf.layers.conv2d`` layers (32 filters, kernel size 3,
    otherwise library defaults), each followed by ``tf.nn.leaky_relu``, then a
    flatten and a dense projection to ``n_logits`` outputs.

    Args:
        input_tensor: Image batch fed to the first convolution. In ``main``
            this is a float32 placeholder of shape ``[None, 28, 28, 3]``.
        n_logits: Number of output units (classes) of the final dense layer.
            Defaults to 10, matching the previous hard-coded width.

    Returns:
        The un-normalised logits tensor produced by the final dense layer.
    """
    x = input_tensor

    for _ in range(4):
        x = tf.layers.conv2d(x, 32, 3)
        x = tf.nn.leaky_relu(x)

    x = tf.layers.flatten(x)
    # Use the caller-supplied width; a literal 10 here silently ignored n_logits.
    x = tf.layers.dense(x, n_logits)
    return x


def main(_):
    """Train ``convnet`` on the Rainbow-MNIST training split with Adam.

    Intended as the ``absl.app.run`` entry point; the positional argument
    (remaining argv) is ignored. Hyper-parameters come from the absl flags
    ``learning_rate``, ``batch_size`` and ``n_steps``. After every optimisation
    step the batch loss and batch accuracy are printed to stdout.
    """
    flags = absl.flags.FLAGS

    # ---- Graph construction ------------------------------------------------
    images_ph = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 3])
    labels_ph = tf.placeholder(dtype=tf.int64, shape=[None])

    logits = convnet(images_ph)
    predicted = tf.argmax(logits, axis=1)
    hits = tf.cast(tf.equal(predicted, labels_ph), tf.float32)
    accuracy = tf.reduce_mean(hits)

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels_ph, logits=logits
    )
    mean_loss = tf.reduce_mean(per_example_loss)

    optimizer = tf.train.AdamOptimizer(flags.learning_rate)
    train_step = optimizer.minimize(mean_loss)

    initializer = tf.global_variables_initializer()

    # ---- Data loading ------------------------------------------------------
    # The path is resolved relative to the current working directory.
    # NOTE(review): pickle.load can execute arbitrary code; only use a data
    # file from a trusted source.
    with open('../data/rainbow_mnist.pkl', 'rb') as handle:
        dataset = pickle.load(handle)

    # data['train'] is iterated as a sequence of dicts with 'images'/'labels'
    # arrays that are stacked along the first axis. No normalisation is
    # applied here.
    train_images = np.concatenate([part['images'] for part in dataset['train']])
    train_labels = np.concatenate([part['labels'] for part in dataset['train']])
    n_train = train_labels.shape[0]

    # ---- Training loop -----------------------------------------------------
    with tf.Session() as sess:
        sess.run(initializer)

        for _ in range(flags.n_steps):
            # Indices are drawn uniformly with replacement.
            batch_idx = np.random.choice(n_train, flags.batch_size)
            feed = {
                images_ph: train_images[batch_idx],
                labels_ph: train_labels[batch_idx],
            }
            loss_val, acc_val, _ = sess.run(
                [mean_loss, accuracy, train_step], feed
            )
            print(loss_val, '   ', acc_val)


if __name__ == '__main__':
    # Hand control to absl, which parses the command-line flags before
    # invoking ``main``.
    absl.app.run(main=main)