import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras import models

def spatial_broadcast(slots, resolution, rel_s_p, rel_s_s, batch_size,
                      num_slots=None):
  """Broadcast slots onto a spatial grid and build a per-slot relative grid.

  Args:
    slots: Tensor whose last axis is the slot feature dimension. It is
      flattened to `(-1, features)` before tiling.
    resolution: Pair `(height, width)` of the target grid.
    rel_s_p: Per-slot 2-vectors subtracted from the absolute grid; must be
      broadcastable to `(batch_size, num_slots, 1, 1, 2)` once two singleton
      axes are inserted (presumably slot positions -- confirm with caller).
    rel_s_s: Per-slot 2-vectors the shifted grid is divided by, same layout
      as `rel_s_p` (presumably slot scales -- confirm with caller). No
      epsilon is added, so zero entries produce inf/nan.
    batch_size: Number of examples in the batch.
    num_slots: Number of slots per example. When omitted it is read from
      axis 1 of a rank-3 `slots` tensor, falling back to the previously
      hard-coded value 11 when that axis is unavailable.

  Returns:
    A tuple `(grid, rel_grid)`: `grid` is the flattened slots tiled to
    `(-1, height, width, features)`, and `rel_grid` has shape
    `(batch_size * num_slots, height, width, 2)`.
  """
  if num_slots is None:
    # Only a (batch, num_slots, features) layout exposes the slot count
    # unambiguously; any other rank keeps the legacy default.
    slot_dim = slots.shape[1] if len(slots.shape) == 3 else None
    num_slots = 11 if slot_dim is None else slot_dim

  height, width = resolution[0], resolution[1]
  per_slot_shape = (batch_size, num_slots, height, width, 2)

  # One flat row per slot, with singleton spatial axes ready for tiling.
  flat_slots = tf.reshape(slots, [-1, slots.shape[-1]])[:, None, None, :]

  abs_grid = tf.broadcast_to(build_grid(resolution), per_slot_shape)

  # Insert the two spatial axes in front of the coordinate axis so the
  # per-slot vectors line up with the grid, then materialise the full shape.
  rel_s_p = tf.expand_dims(tf.expand_dims(rel_s_p, axis=-2), axis=-2)
  rel_s_p = tf.broadcast_to(rel_s_p, per_slot_shape)
  rel_s_s = tf.expand_dims(tf.expand_dims(rel_s_s, axis=-2), axis=-2)
  rel_s_s = tf.broadcast_to(rel_s_s, per_slot_shape)

  # Express every grid coordinate relative to each slot's reference frame.
  rel_grid = (abs_grid - rel_s_p) / rel_s_s
  # Merge batch and slot axes so rel_grid matches the flattened slots.
  rel_grid = tf.reshape(rel_grid, (batch_size * num_slots, height, width, 2))

  grid = tf.tile(flat_slots, [1, height, width, 1])
  return grid, rel_grid

def unstack_and_split(x, batch_size, num_channels=3):
  """Restore the batch axis and separate colour channels from the mask.

  Args:
    x: Tensor whose leading axis merges batch and slots and whose last axis
      holds `num_channels + 1` values.
    batch_size: Number of examples; the slot axis is inferred from it.
    num_channels: Number of colour channels preceding the single mask channel.

  Returns:
    A tuple `(channels, masks)` split along the last axis into
    `num_channels` and 1 entries respectively.
  """
  trailing_dims = x.shape.as_list()[1:]
  # -1 lets TensorFlow infer the per-example slot count.
  per_example = tf.reshape(x, [batch_size, -1, *trailing_dims])
  channels, masks = tf.split(per_example, [num_channels, 1], axis=-1)
  return channels, masks

class SlotAttentionAutoEncoder(layers.Layer):
  """Spatial broadcast decoder that recombines slots into one prediction.

  Each slot is tiled over a 16x16 grid, combined with an embedding of its
  relative position grid, decoded by a transposed-convolution stack into
  colour channels plus one mask logit, and the per-slot reconstructions are
  blended with masks normalised across slots.
  """

  def __init__(self, resolution, num_slots, num_iterations, num_channels):
    """Create the decoder sub-networks.

    Args:
      resolution: Stored on the instance; `call` does not read it (the
        decoder starts from `decoder_initial_size` instead).
      num_slots: Stored on the instance; not read by `call`.
      num_iterations: Stored on the instance; not read by `call`.
      num_channels: Number of colour channels to reconstruct.
    """
    super().__init__()
    self.resolution = resolution
    self.num_slots = num_slots
    self.num_iterations = num_iterations
    self.num_channels = num_channels

    self.layer_norm = layers.LayerNormalization()

    # Grid the slots are broadcast to; the three stride-2 layers below each
    # double the spatial size.
    self.decoder_initial_size = (16, 16)
    self.decoder_cnn = tf.keras.Sequential([
        layers.Conv2DTranspose(
            128, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            128, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            128, 5, strides=(2, 2), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(1, 1), padding="SAME", activation="relu"),
        layers.Conv2DTranspose(
            64, 5, strides=(1, 1), padding="SAME", activation="relu"),
        # Final 1x1 layer emits the colour channels plus one mask logit.
        layers.Conv2DTranspose(
            self.num_channels + 1, 1, strides=(1, 1), padding="SAME", activation=None)
    ], name="decoder_cnn")

    # Embeds the 2-D relative grid so it can be added to the tiled slots.
    self.dense_pos_decode = tf.keras.Sequential([
        layers.Dense(256, activation="relu"),
        layers.Dense(128)
    ], name="dense_pos_decode")

    # MLP applied to the position-augmented slot features before decoding.
    self.mlp_inputs_decode = tf.keras.Sequential([
        layers.Dense(256, activation="relu"),
        layers.Dense(128)
    ], name="decode")

  def call(self, image):
    """Decode slots into a combined reconstruction.

    Args:
      image: Sequence `(slots, s_p, s_s)`; `s_p` and `s_s` are forwarded to
        `spatial_broadcast` as the per-slot offset and divisor.

    Returns:
      A tuple `(recon_combined, recons, masks, slots)`: the mask-weighted sum
      over the slot axis, the per-slot reconstructions, the masks after a
      softmax over the slot axis, and the unmodified input slots.
    """
    slots, s_p, s_s = image

    # Use the real batch size instead of a fixed 32 so the layer works for
    # whatever batch size the model was built with. Fall back to the dynamic
    # shape when the static batch dimension is unknown.
    batch_size = slots.shape[0]
    if batch_size is None:
      batch_size = tf.shape(slots)[0]

    x, rel = spatial_broadcast(
        slots, self.decoder_initial_size, s_p, s_s, batch_size)
    x = x + self.dense_pos_decode(rel)
    x = self.mlp_inputs_decode(x)
    x = self.decoder_cnn(x)
    recons, masks = unstack_and_split(x, batch_size, self.num_channels)
    # Normalise masks across slots (axis 1) so they sum to one per pixel.
    masks = tf.nn.softmax(masks, axis=1)
    recon_combined = tf.reduce_sum(recons * masks, axis=1)
    return recon_combined, recons, masks, slots


def build_grid(resolution):
  """Build a float32 grid encoding 2-D positions in the range [-1, 1].

  Args:
    resolution: Sequence of axis sizes; the first two give the grid's
      height and width.

  Returns:
    A NumPy array of shape `(1, resolution[0], resolution[1], C)` where the
    last axis holds one linearly spaced coordinate per entry of `resolution`.
  """
  axes = [np.linspace(-1., 1., num=size) for size in resolution]
  # "ij" indexing keeps the first coordinate varying along the first axis.
  coords = np.stack(np.meshgrid(*axes, sparse=False, indexing="ij"), axis=-1)
  coords = coords.reshape([resolution[0], resolution[1], -1])
  # Prepend a singleton axis so the grid broadcasts against batched tensors.
  return coords[np.newaxis].astype(np.float32)


def append_pos(inputs, resolution, batch_size):
  """Concatenate the absolute position grid onto the last axis of `inputs`.

  Args:
    inputs: Tensor concatenated with a grid of shape
      `(batch_size, resolution[0], resolution[1], 2)` along the last axis.
    resolution: Pair `(height, width)` of the grid.
    batch_size: Number of examples the grid is broadcast to.

  Returns:
    `inputs` with two extra trailing channels holding the grid coordinates.
  """
  height, width = resolution[0], resolution[1]
  positions = tf.broadcast_to(
      build_grid(resolution), (batch_size, height, width, 2))
  return tf.concat([inputs, positions], axis=-1)


def build_model(resolution, batch_size, num_slots, num_iterations,
                num_channels=3, model_type="object_discovery"):
  """Build the Keras decoder model that reconstructs images from slots.

  Args:
    resolution: Forwarded to the decoder layer's constructor.
    batch_size: Fixed batch size of the three model inputs.
    num_slots: NOTE(review): currently ignored -- the value is overwritten
      with 11 below before it is used.
    num_iterations: Forwarded to the decoder layer's constructor.
    num_channels: Number of channels in the predicted reconstruction.
    model_type: Only "object_discovery" is accepted.

  Returns:
    A `tf.keras.Model` taking inputs named 'slots', 's_p' and 's_s' and
    producing the decoder layer's outputs.

  Raises:
    ValueError: If `model_type` is not "object_discovery".
  """
  if model_type != "object_discovery":
    raise ValueError("Invalid name for model type.")
  model_def = SlotAttentionAutoEncoder

  # The slot count is pinned to 11 regardless of the argument.
  num_slots = 11

  # (input name, size of the last axis) for each per-slot input tensor.
  input_specs = (("slots", 128), ("s_p", 2), ("s_s", 2))
  decoder_inputs = [
      tf.keras.Input(shape=(num_slots, width), batch_size=batch_size, name=name)
      for name, width in input_specs
  ]

  decoder = model_def(resolution, num_slots, num_iterations, num_channels)
  outputs = decoder(decoder_inputs)
  return tf.keras.Model(inputs=decoder_inputs, outputs=outputs)

