"""
Model definitions for MinibatchProx.
"""

from functools import partial

import numpy as np
import tensorflow as tf

# Default optimizer factory: tf.train.AdamOptimizer with beta1 pinned to 0.
# Remaining constructor arguments (e.g. learning_rate) are supplied by the
# caller when the partial is invoked.
DEFAULT_OPTIMIZER = partial(
    tf.train.AdamOptimizer,
    beta1=0,
)

# pylint: disable=R0903
class MiniImageNetModel:
    """
    A model for Mini-ImageNet classification with a proximal weight penalty.

    The network is four blocks of (3x3 conv with 32 filters -> batch norm ->
    2x2 max-pool -> ReLU), flattened and followed by a dense layer that
    produces `num_classes` logits.

    `minimize_op` minimizes the per-example cross-entropy plus
    `lam * w_loss`. `w_loss` is the squared distance between the trainable
    variables and the values fed through `w_phs`.

    Args:
      num_classes: number of output logits.
      optimizer: factory called as `optimizer(**optim_kwargs)` to build the
        optimizer; defaults to `DEFAULT_OPTIMIZER` (Adam with beta1=0).
      **optim_kwargs: keyword arguments forwarded to `optimizer`.

    Attributes:
      input_ph: float32 placeholder of shape (None, 84, 84, 3).
      label_ph: int32 placeholder of shape (None,) holding class indices.
      logits: dense-layer output with `num_classes` units.
      predictions: argmax of `logits` over the last axis.
      lam: feedable float32 weight of the proximal term; defaults to 1.0.
      vars: trainable variables collected when the model is built.
      w_phs: one placeholder per entry of `vars` (same dtype/shape), holding
        the values the weights are pulled toward.
      w_losses / w_loss: per-variable mean squared difference, and their mean.
      loss: per-example sparse softmax cross-entropy.
      minimize_op: one optimizer step on `loss + lam * w_loss`.
    """
    def __init__(self, num_classes, optimizer=DEFAULT_OPTIMIZER, **optim_kwargs):
        self.input_ph = tf.placeholder(tf.float32, shape=(None, 84, 84, 3))
        out = self.input_ph
        for _ in range(4):
            out = tf.layers.conv2d(out, 32, 3, padding='same')
            # training=True everywhere: batch statistics are always used,
            # there is no separate inference-mode graph.
            out = tf.layers.batch_normalization(out, training=True)
            out = tf.layers.max_pooling2d(out, 2, 2, padding='same')
            out = tf.nn.relu(out)
        out = tf.reshape(out, (-1, int(np.prod(out.get_shape()[1:]))))
        self.logits = tf.layers.dense(out, num_classes)
        self.label_ph = tf.placeholder(tf.int32, shape=(None,))

        # Weight for the weight-regularization (proximal) term. It now
        # actually scales `w_loss` in the objective. The default of 1.0 keeps
        # the objective unchanged when the caller does not feed it.
        self.lam = tf.placeholder_with_default(
            tf.constant(1.0, dtype=tf.float32), shape=None)

        # NOTE(review): this is every trainable variable in the default graph
        # at this point, including any created before this model.
        self.vars = tf.trainable_variables()
        # Anchor values for the proximal term, one per variable.
        self.w_phs = [tf.placeholder(v.dtype.base_dtype, shape=v.get_shape())
                      for v in self.vars]

        # Mean squared difference per variable, then averaged over variables.
        self.w_losses = [tf.reduce_mean(tf.squared_difference(v, val))
                         for v, val in zip(self.vars, self.w_phs)]
        self.w_loss = tf.add_n(self.w_losses) / len(self.w_losses)

        # Per-example cross-entropy (one value per batch row).
        self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.label_ph, logits=self.logits)
        self.predictions = tf.argmax(self.logits, axis=-1)
        # The scalar penalty broadcasts against the per-example loss vector.
        # minimize() differentiates the sum of that vector, so the penalty's
        # gradient is effectively multiplied by the batch size.
        self.minimize_op = optimizer(**optim_kwargs).minimize(
            self.loss + self.lam * self.w_loss)