from mpi4py import MPI
import tensorflow as tf, baselines.common.tf_util as U, numpy as np

class RunningMeanStd(object):
    """Running per-feature mean and standard deviation shared across MPI ranks.

    Three float64 TensorFlow variables hold the running sum, the running sum
    of squares and the sample count.  Every call to ``update`` or ``noupdate``
    sums the local batch statistics over ``MPI.COMM_WORLD`` and adds the
    global totals to those variables, so each rank applies the same increment.

    Attributes:
        shape: shape of a single sample, as passed to the constructor.
        mean: float32 tensor equal to ``sum / count``.
        std: float32 tensor, square root of the variance after the variance
            has been clamped from below at 1e-2.
        incfiltparams: callable ``(sum, sumsq, count)`` that adds the given
            increments to the three accumulators.
    """
    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-2, shape=()):
        """Create the accumulator variables and the derived mean/std tensors.

        Args:
            epsilon: initial value of both the sum-of-squares and the count
                accumulators.  A positive value yields initial statistics of
                mean 0 and std 1; with 0.0 they are 0/0 until the first update.
            shape: shape of a single sample (batch axis excluded).
        """
        self._sum = tf.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.constant_initializer(0.0),
            name="runningsum", trainable=False)
        self._sumsq = tf.get_variable(
            dtype=tf.float64,
            shape=shape,
            initializer=tf.constant_initializer(epsilon),
            name="runningsumsq", trainable=False)
        self._count = tf.get_variable(
            dtype=tf.float64,
            shape=(),
            initializer=tf.constant_initializer(epsilon),
            name="count", trainable=False)
        self.shape = shape

        # Form E[x^2] - E[x]^2 in float64 and cast only the final results.
        # Doing the subtraction in float32 cancels catastrophically when the
        # mean is large relative to the standard deviation.
        mean64 = self._sum / self._count
        var64 = self._sumsq / self._count - tf.square(mean64)
        self.mean = tf.cast(mean64, tf.float32)
        self.std = tf.cast(tf.sqrt(tf.maximum(var64, 1e-2)), tf.float32)

        newsum = tf.placeholder(shape=self.shape, dtype=tf.float64, name='sum')
        newsumsq = tf.placeholder(shape=self.shape, dtype=tf.float64, name='var')
        newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
        self.incfiltparams = U.function([newsum, newsumsq, newcount], [],
            updates=[tf.assign_add(self._sum, newsum),
                     tf.assign_add(self._sumsq, newsumsq),
                     tf.assign_add(self._count, newcount)])

    def _allreduce_and_apply(self, addvec, n):
        """Sum ``addvec`` over all MPI ranks and add the totals to the accumulators.

        ``addvec`` is laid out as ``[sum (n values), sumsq (n values), count]``.
        """
        totalvec = np.zeros(n*2+1, 'float64')
        MPI.COMM_WORLD.Allreduce(addvec, totalvec, op=MPI.SUM)
        self.incfiltparams(
            totalvec[0:n].reshape(self.shape),
            totalvec[n:2*n].reshape(self.shape),
            totalvec[2*n])

    def update(self, x):
        """Fold a local batch into the statistics shared by all ranks.

        Allreduce is a collective operation, so every rank of
        ``MPI.COMM_WORLD`` has to call ``update`` or ``noupdate`` together.

        Args:
            x: array-like whose first axis indexes samples; the remaining
                axes must hold ``prod(self.shape)`` values per sample.

        Raises:
            ValueError: if the per-sample size of ``x`` does not match
                ``self.shape``.  Raised before any MPI communication.
        """
        x = np.asarray(x, dtype='float64')
        n = int(np.prod(self.shape))
        batch_sum = x.sum(axis=0).ravel()
        if batch_sum.size != n:
            raise ValueError(
                "batch of shape %r does not match per-sample shape %r"
                % (x.shape, self.shape))
        batch_sumsq = np.square(x).sum(axis=0).ravel()
        batch_count = np.array([len(x)], dtype='float64')
        self._allreduce_and_apply(
            np.concatenate([batch_sum, batch_sumsq, batch_count]), n)

    def noupdate(self):
        """Take part in the collective update while contributing no samples."""
        n = int(np.prod(self.shape))
        # An all-zero vector adds nothing to the sum, sum of squares or count.
        self._allreduce_and_apply(np.zeros(n*2+1, 'float64'), n)

@U.in_session
def test_runningmeanstd():
    """Check RunningMeanStd against NumPy's mean/std on concatenated batches.

    Two cases are exercised: scalar samples (shape ``()``) and two-feature
    samples (shape ``(2,)``).
    """
    # Seed the generator so a failure can be reproduced.
    np.random.seed(0)
    cases = [
        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
        (np.random.randn(3,2), np.random.randn(4,2), np.random.randn(5,2)),
        ]
    for case_idx, (x1, x2, x3) in enumerate(cases):
        # RunningMeanStd creates its variables with fixed names through
        # tf.get_variable; a distinct scope per case keeps the second
        # instance from colliding with the first one's variables.
        with tf.variable_scope("runningmeanstd_case%d" % case_idx):
            rms = RunningMeanStd(epsilon=0.0, shape=x1.shape[1:])
        U.initialize()

        x = np.concatenate([x1, x2, x3], axis=0)
        ms1 = [x.mean(axis=0), x.std(axis=0)]
        rms.update(x1)
        rms.update(x2)
        rms.update(x3)
        ms2 = U.eval([rms.mean, rms.std])

        assert np.allclose(ms1, ms2)

@U.in_session
def test_dist():
    """Two-rank check: each rank feeds its own batches and both must end up
    with the mean/std of the pooled data.

    Has to be launched with exactly two MPI processes.
    """
    np.random.seed(0)
    # Same seed on both ranks, so each one can rebuild the other's data.
    # Samples are column vectors (N, 1) to match shape=(1,) below.
    p1 = np.random.randn(3,1)
    p2 = np.random.randn(4,1)
    p3 = np.random.randn(5,1)
    q1 = np.random.randn(6,1)
    q2 = np.random.randn(7,1)
    q3 = np.random.randn(8,1)

    comm = MPI.COMM_WORLD
    assert comm.Get_size()==2
    rank = comm.Get_rank()
    if rank==0:
        local_batches = (p1, p2, p3)
    elif rank==1:
        local_batches = (q1, q2, q3)
    else:
        assert False

    rms = RunningMeanStd(epsilon=0.0, shape=(1,))
    U.initialize()

    for batch in local_batches:
        rms.update(batch)

    pooled = np.concatenate([p1,p2,p3,q1,q2,q3])

    def report_allclose(expected, actual):
        # Print both values so a mismatch is visible in the mpirun output.
        print(expected,actual)
        return np.allclose(expected,actual)

    assert report_allclose(pooled.mean(axis=0), U.eval(rms.mean))
    assert report_allclose(pooled.std(axis=0), U.eval(rms.std))


if __name__ == "__main__":
    # Run with: mpirun -np 2 python <filename>
    # Exactly two processes are needed: test_dist asserts that the
    # communicator size is 2.
    test_dist()
