import numpy as np
import tensorflow as tf
from nn.layer import BatchDense, Dense, Relu, fwd, fwd1, fwd_fold, fwd_fold1, fwdv
from const import ns, na, nh, a_scale
from util.rms import π_rms

class GMMPolicy:
  """Gaussian-mixture policy plus a separate value network.

  Four small MLPs are built from ``nn.layer`` primitives:

  * ``z_net`` -- ``ns -> nk`` logits over mixture components.
  * ``μ_net`` -- ``nk`` batched heads ``ns -> na``; used as component means.
  * ``σ_net`` -- ``nk`` batched heads ``ns -> na``; used as component
    log standard deviations (``act`` scales noise by ``exp(σ)``).
  * ``v_net`` -- ``ns -> 1``; presumably a state-value estimate (see ``fwdv``).

  Every forward pass first runs the state through ``π_rms.nrm``
  (presumably running mean/std normalization -- TODO confirm).
  """

  def __init__(self, nk=4):
    """Build the policy and value networks.

    Args:
      nk: number of Gaussian mixture components. Defaults to 4, the value
        that was previously hard-coded.
    """
    self.nk = nk

    # Layers are constructed in the same order as before (z, μ, σ, v) so
    # variable creation order is unchanged.
    self.z_net = z_net = [
        Dense(ns, nh//2), Relu(),
        Dense(nh//2, nh//2), Relu(),
        Dense(nh//2, nk),
    ]

    self.μ_net = μ_net = [
        BatchDense(nk, ns, nh//2), Relu(),
        BatchDense(nk, nh//2, nh//2), Relu(),
        BatchDense(nk, nh//2, na),
    ]

    self.σ_net = σ_net = [
        BatchDense(nk, ns, nh//2), Relu(),
        BatchDense(nk, nh//2, nh//2), Relu(),
        BatchDense(nk, nh//2, na),
    ]
    # Shift the initial log-std bias down by 1 so every component starts
    # with a smaller noise scale (exp(-1) times whatever the head outputs).
    σ_net[-1].b.assign_add([[-1]*na]*nk)

    self.v_net = [
        Dense(ns, nh), Relu(),
        Dense(nh, nh), Relu(),
        Dense(nh, 1),
    ]

    # Snapshots of the variable lists, taken once after construction.
    self.vars = self._vars
    self.v_vars = self._v_vars
    self.net_vars = self._net_vars

    self.z_vars, self.μ_vars, self.σ_vars = self.net_vars

  @tf.function
  def fwd(self, s):
    """Run the three policy heads on normalized states ``s``.

    Returns:
      Tuple ``(z_logits, μs, σs)`` from the ``nn.layer`` ``fwd`` /
      ``fwd_fold`` helpers. Presumably batched over states -- shapes are
      determined by those helpers (TODO confirm).
    """
    s = π_rms.nrm(s)
    # NOTE: bare ``fwd`` is the module-level nn.layer helper, not this method.
    return (fwd(s, *self.z_vars),
            fwd_fold(s, *self.μ_vars),
            fwd_fold(s, *self.σ_vars))

  @tf.function
  def fwd1(self, s):
    """Single-state variant of ``fwd`` (used by ``act``).

    Returns:
      Tuple ``(z_logits, μs, σs)``. ``act`` treats ``z_logits`` as a rank-1
      logit vector and indexes ``μs``/``σs`` by component along axis 0.
    """
    s = π_rms.nrm(s)
    # NOTE: bare ``fwd1`` is the module-level nn.layer helper.
    return (fwd1(s, *self.z_vars),
            fwd_fold1(s, *self.μ_vars),
            fwd_fold1(s, *self.σ_vars))

  @tf.function
  def fwdv(self, s):
    """Evaluate the value network on normalized states ``s``."""
    s = π_rms.nrm(s)
    # NOTE: bare ``fwdv`` is the module-level nn.layer helper.
    return fwdv(s, *self.v_vars)

  @tf.function
  def act(self, s):
    """Sample an action for a single state ``s``.

    A mixture component ``z`` is drawn from the categorical distribution
    over the z-logits. The pre-squash action is ``u = μ_z + ε * exp(σ_z)``
    with ``ε ~ N(0, 1)``. The action is then squashed into
    ``(-a_scale, a_scale)`` as ``a = a_scale * tanh(u / a_scale)``.

    Returns:
      ``(a, u)``: the squashed action and the raw Gaussian sample.
    """
    lz, μs, σs = self.fwd1(s)
    # categorical expects a batch of logits: add a batch dim of 1, draw one
    # sample, then reshape the [1, 1] result to a scalar index.
    z = tf.reshape(tf.random.categorical(
        tf.expand_dims(lz, 0), 1, dtype=tf.int32), [])
    μ = tf.gather(μs, z)
    σ = tf.gather(σs, z)
    ε = tf.random.normal(tf.shape(σ))
    u = μ + ε * tf.exp(σ)
    a = a_scale * tf.tanh(u/a_scale)
    return a, u

  @staticmethod
  def _layer_vars(layers):
    """Concatenate the ``vars`` of each layer in ``layers``, in order."""
    return [v for layer in layers for v in layer.vars]

  @property
  def _vars(self):
    """All policy variables: z, μ, then σ network, in layer order."""
    return self._layer_vars(self.z_net + self.μ_net + self.σ_net)

  @property
  def _v_vars(self):
    """All value-network variables, in layer order."""
    return self._layer_vars(self.v_net)

  @property
  def _net_vars(self):
    """Per-network variable lists ``[z_vars, μ_vars, σ_vars]``."""
    return [self._layer_vars(net)
            for net in (self.z_net, self.μ_net, self.σ_net)]