import tensorflow as tf
import numpy as np
from mpi4py import MPI

def set_flat6(flat, flat_sizes, v1, v2, v3, v4, v5, v6):
  """Scatter `flat` into six variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly six names raises ValueError on a count mismatch.
  f1, f2, f3, f4, f5, f6 = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6)
  pieces = (f1, f2, f3, f4, f5, f6)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))

def set_flat7(flat, flat_sizes, v1, v2, v3, v4, v5, v6, v7):
  """Scatter `flat` into seven variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly seven names raises ValueError on a count mismatch.
  f1, f2, f3, f4, f5, f6, f7 = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7)
  pieces = (f1, f2, f3, f4, f5, f6, f7)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))

def set_flat10(flat, flat_sizes, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10):
  """Scatter `flat` into ten variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly ten names raises ValueError on a count mismatch.
  f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10)
  pieces = (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))

def set_flat12(flat, flat_sizes, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12):
  """Scatter `flat` into twelve variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly twelve names raises ValueError on a count mismatch.
  (f1, f2, f3, f4, f5, f6,
   f7, f8, f9, f10, f11, f12) = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12)
  pieces = (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))

def set_flat13(flat, flat_sizes, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13):
  """Scatter `flat` into thirteen variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly thirteen names raises ValueError on a count mismatch.
  (f1, f2, f3, f4, f5, f6, f7,
   f8, f9, f10, f11, f12, f13) = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13)
  pieces = (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))

def set_flat18(flat, flat_sizes,
    v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
    v11, v12, v13, v14, v15, v16, v17, v18):
  """Scatter `flat` into eighteen variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly eighteen names raises ValueError on a count mismatch.
  (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
   f11, f12, f13, f14, f15, f16, f17, f18) = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
             v11, v12, v13, v14, v15, v16, v17, v18)
  pieces = (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
            f11, f12, f13, f14, f15, f16, f17, f18)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))


def set_flat20(flat, flat_sizes,
    v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
    v11, v12, v13, v14, v15, v16, v17, v18, v19, v20):
  """Scatter `flat` into twenty variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly twenty names raises ValueError on a count mismatch.
  (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
   f11, f12, f13, f14, f15, f16, f17, f18, f19, f20) = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
             v11, v12, v13, v14, v15, v16, v17, v18, v19, v20)
  pieces = (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
            f11, f12, f13, f14, f15, f16, f17, f18, f19, f20)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))


def set_flat24(flat, flat_sizes,
    v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
    v11, v12, v13, v14, v15, v16, v17, v18, v19, v20,
    v21, v22, v23, v24):
  """Scatter `flat` into twenty-four variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly 24 names raises ValueError on a count mismatch.
  (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
   f11, f12, f13, f14, f15, f16, f17, f18, f19, f20,
   f21, f22, f23, f24) = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
             v11, v12, v13, v14, v15, v16, v17, v18, v19, v20,
             v21, v22, v23, v24)
  pieces = (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
            f11, f12, f13, f14, f15, f16, f17, f18, f19, f20,
            f21, f22, f23, f24)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))


def set_flat30(flat, flat_sizes,
    v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
    v11, v12, v13, v14, v15, v16, v17, v18, v19, v20,
    v21, v22, v23, v24, v25, v26, v27, v28, v29, v30):
  """Scatter `flat` into thirty variables, in argument order.

  `flat` is split along axis 0 using `flat_sizes` (one size per variable);
  each slice is reshaped to its variable's runtime shape and assigned.
  """
  # Unpacking into exactly 30 names raises ValueError on a count mismatch.
  (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
   f11, f12, f13, f14, f15, f16, f17, f18, f19, f20,
   f21, f22, f23, f24, f25, f26, f27, f28, f29, f30) = tf.split(flat, flat_sizes)
  targets = (v1, v2, v3, v4, v5, v6, v7, v8, v9, v10,
             v11, v12, v13, v14, v15, v16, v17, v18, v19, v20,
             v21, v22, v23, v24, v25, v26, v27, v28, v29, v30)
  pieces = (f1, f2, f3, f4, f5, f6, f7, f8, f9, f10,
            f11, f12, f13, f14, f15, f16, f17, f18, f19, f20,
            f21, f22, f23, f24, f25, f26, f27, f28, f29, f30)
  for var, piece in zip(targets, pieces):
    var.assign(tf.reshape(piece, tf.shape(var)))

# Number of variables -> tf.function-compiled setter taking exactly that many.
set_flat_call = {
  arity: tf.function(setter)
  for arity, setter in (
    (6, set_flat6),
    (7, set_flat7),
    (10, set_flat10),
    (12, set_flat12),
    (13, set_flat13),
    (18, set_flat18),
    (20, set_flat20),
    (24, set_flat24),
    (30, set_flat30),
  )
}

class SetFlat(object):
  """Callable that writes a flat parameter vector back into `var_list`.

  The vector is split according to each variable's element count, and each
  slice is reshaped to its variable's shape before being assigned.
  """
  def __init__(self, var_list):
    self.var_list = var_list
    self.shapes = [var.shape.as_list() for var in var_list]
    # Element count of each variable; these are the tf.split sizes.
    self.flat_sizes = tf.constant([int(np.prod(shape)) for shape in self.shapes], tf.int32)
    fn = set_flat_call.get(len(var_list))
    if fn is None:
      # No fixed-arity setter is registered for this many variables. The
      # original lookup raised KeyError here; fall back to a variadic setter
      # instead. Its Python loop is unrolled when the function is traced.
      def assign_all(flat, flat_sizes, *variables):
        if not variables:
          return
        for var, piece in zip(variables, tf.split(flat, flat_sizes)):
          var.assign(tf.reshape(piece, tf.shape(var)))
      fn = tf.function(assign_all)
    self.fn = fn

  def __call__(self, θ):
    """Assign the flat vector `θ` into the tracked variables."""
    self.fn(θ, self.flat_sizes, *self.var_list)

class GetFlat(object):
  """Callable returning every variable in `var_list` flattened and joined.

  Each variable is reshaped to 1-D, and the pieces are concatenated in
  list order along axis 0.
  """
  def __init__(self, var_list):
    self.var_list = var_list

  @tf.function
  def __call__(self):
    flattened = []
    for var in self.var_list:
      flattened.append(tf.reshape(var, [-1]))
    return tf.concat(flattened, axis=0)

class MpiOptimizer:
  """Base class with the MPI plumbing shared by MPI-aware optimizers.

  It provides flat get/set helpers over `var_list`, a float32 scratch
  buffer sized to the total element count, and collective routines that
  make every rank's parameters match rank 0's. Subclasses implement
  `update`.
  """
  def __init__(self, var_list):
    self.set_flat = SetFlat(var_list)
    self.get_flat = GetFlat(var_list)
    self.comm = MPI.COMM_WORLD
    self.comm_size = self.comm.Get_size()
    n_params = int(tf.reduce_sum(self.set_flat.flat_sizes))
    self.buffer = np.empty(n_params, dtype=np.float32)

  def update(self, localg, stepsize):
    raise NotImplementedError

  def sync(self):
    """Collective: overwrite every rank's parameters with rank 0's."""
    if self.comm is None:
      return
    params = self.get_flat().numpy()
    # Bcast works in place: rank 0 sends `params`, other ranks receive into it.
    self.comm.Bcast(params, root=0)
    self.set_flat(params)

  def check_synced(self):
    """Collective: make any non-root rank that drifted from rank 0 adopt its values.

    Rank 0 only broadcasts. The other ranks receive into `self.buffer`,
    assert that it contains no NaN, and call `set_flat` when it differs
    from their local parameters.
    """
    if self.comm is None:
      return
    is_root = self.comm.Get_rank() == 0
    if is_root:
      self.comm.Bcast(self.get_flat().numpy(), root=0)
      return
    params_local = self.get_flat().numpy()
    self.comm.Bcast(self.buffer, root=0)
    params_root = self.buffer
    assert not np.isnan(params_root).any()
    if not (params_root == params_local).all():
      self.set_flat(params_root)

class MpiAdam(MpiOptimizer):
  """Adam optimizer whose gradient is averaged over all MPI ranks.

  The first and second moment estimates (`m`, `v`) are float32
  tf.Variables sized to the total number of parameter elements. Parameters
  are read and written through the flat get/set helpers that
  MpiOptimizer sets up.
  """
  def __init__(self, var_list, *, β1=0.9, β2=0.999):
    self.β1 = β1
    self.β2 = β2
    total_size = sum(int(np.prod(v.shape.as_list())) for v in var_list)
    self.m = tf.Variable(tf.zeros(total_size, tf.float32))
    self.v = tf.Variable(tf.zeros(total_size, tf.float32))
    self.t = 0  # Python-side step counter; update() advances it by 10.
    super().__init__(var_list)

  @tf.function
  def update_tf(self, globalg, stepsize, t=None):
    """Apply one Adam step with the (already rank-averaged) gradient `globalg`.

    Args:
      globalg: flat gradient, same length as the flat parameter vector.
      stepsize: base learning rate.
      t: step counter, passed as a tensor. `self.t` is a plain Python int,
        and tf.function reads such attributes only while tracing, so
        relying on it freezes the bias correction at the first traced
        value. When None, `self.t` is used, which keeps backward
        compatibility with existing direct callers.
    """
    β1 = self.β1
    β2 = self.β2
    if t is None:
      t = self.t
    t = tf.cast(t, tf.float64)
    # NOTE(review): the bias-correction exponent is 10*t, and update()
    # advances t by 10 per call, so the exponent is 100x the call count.
    # Textbook Adam uses β**t with t += 1. The formula is kept unchanged.
    correction = tf.sqrt(1 - β2 ** (10 * t)) / (1 - β1 ** (10 * t))
    a = stepsize * tf.cast(correction, tf.float32)
    self.m.assign(β1 * self.m + (1 - β1) * globalg)
    self.v.assign(β2 * self.v + (1 - β2) * tf.square(globalg))
    step = (- a) * self.m / (tf.sqrt(self.v) + 1e-8)
    θ_before = self.get_flat()
    θ = θ_before + step
    # Keep the previous value for any element the step made inf/NaN.
    θ = tf.where(tf.math.is_finite(θ), θ, θ_before)
    self.set_flat(θ)

  def update(self, localg, stepsize):
    """Average `localg` across ranks (when there is more than one) and take a step.

    With more than one rank, parameters are cross-checked against rank 0 on
    the first call and on every 10th call after that (self.t is a multiple
    of 100, and t advances by 10 per call).
    """
    comm = self.comm
    if self.comm_size > 1 and self.t % 100 == 0:
      self.check_synced()
    self.t += 10
    # Pass the counter as a tensor: the compiled update then sees the
    # current value and does not retrace for each new step number.
    t = tf.constant(self.t, dtype=tf.float64)
    if self.comm_size > 1:
      buff = self.buffer
      # The send buffer must match the float32 receive buffer.
      sendbuf = np.ascontiguousarray(localg, dtype=np.float32)
      comm.Allreduce(sendbuf, buff, op=MPI.SUM)
      globalg = tf.divide(buff, self.comm_size)
      self.update_tf(globalg, stepsize, t=t)
    else:
      self.update_tf(localg, stepsize, t=t)

