from mpi4py import MPI
import numpy as np
import tensorflow as tf
from const import ns, na

@tf.function
def compute_vec(x):
  """Pack per-feature batch statistics into one flat float32 vector.

  The result is laid out as ``[sum(x, axis=0), sum(x**2, axis=0), count]``,
  where ``count`` is the number of rows (``shape(x)[0]``). For ``n`` features
  the vector therefore has length ``2*n + 1``. This is the layout that
  ``RMS.inc`` splits and that ``RMS.buffer`` (dtype ``'f4'``) receives via
  MPI Allreduce.

  Args:
    x: batch of samples; presumably shape (batch, n) — TODO confirm with
      callers. Any real dtype is accepted and converted to float32.

  Returns:
    A float32 tensor of shape ``(2*n + 1,)``.
  """
  # tf.concat requires all pieces to share one dtype, and the count below is
  # float32. Without this cast, float64 inputs (numpy's default) would fail
  # in the concat. float32 also matches the 'f4' MPI buffer.
  x = tf.cast(x, tf.float32)
  batch_sum = tf.reduce_sum(x, 0)
  batch_sumsq = tf.reduce_sum(tf.square(x), 0)
  # Slicing the shape vector keeps a length-1 tensor, ready for tf.concat.
  count = tf.cast(tf.shape(x)[:1], tf.float32)
  return tf.concat([batch_sum, batch_sumsq, count], 0)

class RMS(object):
  """Running, exponentially-decayed feature normalizer synchronized over MPI.

  Keeps per-feature decayed sums of ``x`` and ``x**2`` plus a decayed sample
  count. Each ``update`` sums the batch statistics across all ranks of
  ``MPI.COMM_WORLD`` before folding them in, so every rank ends up with the
  same statistics. ``nrm`` standardizes inputs with the resulting mean and
  std.

  The accumulators start at sum=0, sumsq=1 and cnt=1. This acts as a prior of
  mean 0 and variance 1, weighted like a single sample.

  Args:
    n: number of features; inputs are normalized along their trailing axis.
    decay: multiplicative decay applied to the stored statistics before each
      new batch is added.
    eps: lower bound on the variance used by ``nrm``. It avoids division by
      zero for constant features.
  """

  def __init__(self, n, decay=0.99999, eps=1e-8):
    self.n = n
    # Decayed accumulators (see class docstring for their initial prior).
    self._sum = tf.Variable(tf.zeros(n, tf.float32))
    self._sumsq = tf.Variable(tf.ones(n, tf.float32))
    self._cnt = tf.Variable(tf.ones(1, tf.float32))
    self.decay = tf.constant(decay, tf.float32)
    self.eps = tf.constant(eps, tf.float32)
    # Reusable Allreduce receive buffer laid out as [sum (n), sumsq (n), count (1)].
    self.buffer = np.zeros(n * 2 + 1, 'f4')
    self.size_splits = tf.constant([n, n, 1], tf.int32)

  @tf.function
  def inc(self, newval):
    """Decay the stored statistics, then add an already-reduced batch.

    Args:
      newval: float32 vector of length ``2*n + 1`` laid out as
        ``[sum, sum of squares, count]``, i.e. the layout ``compute_vec``
        produces.
    """
    # tf.split already returns tensors, so no further conversion is needed.
    newsum, newsumsq, newcount = tf.split(newval, self.size_splits)
    decay = self.decay
    self._sum.assign(self._sum * decay + newsum)
    self._sumsq.assign(self._sumsq * decay + newsumsq)
    self._cnt.assign(self._cnt * decay + newcount)

  def update(self, x):
    """Fold a local batch into the statistics, summed over all MPI ranks.

    This is a collective operation (``Allreduce`` on ``MPI.COMM_WORLD``).
    Every rank must call it the same number of times, in the same order, or
    the ranks will block or desynchronize.

    Args:
      x: local batch of samples; it is reduced over axis 0 by
        ``compute_vec`` (presumably shape (batch, n) — TODO confirm).
    """
    # Hand MPI an explicit contiguous float32 ndarray. This guarantees the
    # send buffer matches the 'f4' receive buffer, rather than relying on the
    # tensor object exposing the buffer protocol itself.
    local = np.asarray(compute_vec(x), dtype=np.float32)
    total = self.buffer
    MPI.COMM_WORLD.Allreduce(local, total, op=MPI.SUM)
    self.inc(total)

  @tf.function
  def nrm(self, x):
    """Return ``(x - mean) / std`` using the current running statistics.

    The variance is clamped below at ``eps``. Both statistics are wrapped in
    ``stop_gradient``, so no gradient flows into them.
    """
    mean = self._sum / self._cnt
    var = self._sumsq / self._cnt - tf.square(mean)
    std = tf.sqrt(tf.maximum(var, self.eps))
    return (x - tf.stop_gradient(mean)) / tf.stop_gradient(std)

  @property
  def vars(self):
    """The underlying state variables: [sum, sumsq, count]."""
    return [self._sum, self._sumsq, self._cnt]

# Module-level normalizers. They remain None until init() builds them.
# NOTE(review): the π_/s_/a_ prefixes presumably stand for policy input /
# state / action — confirm against callers.
π_rms = s_rms = a_rms = None

def init():
  """(Re)build the module-level normalizers.

  ``π_rms`` and ``s_rms`` are sized ``ns`` and ``a_rms`` is sized ``na``.
  Both sizes are imported from ``const`` (presumably state and action
  dimensions — confirm). Each call rebinds the globals to fresh ``RMS``
  instances with freshly initialized statistics.
  """
  global a_rms, s_rms, π_rms
  # Construct and bind one at a time, in a fixed order.
  π_rms = RMS(ns)
  s_rms = RMS(ns)
  a_rms = RMS(na)
