import numpy as np
import tensorflow as tf
from nn.layer import Dense, Relu, fwd, fwd1, fwdv
from const import ns, na, nh, a_scale
from util.rms import π_rms

class GaussianPolicy:
  """Tanh-squashed Gaussian policy plus a separate state-value network.

  Three independent MLPs are built from the project's `nn.layer` primitives:
    * `μ_net` -- state -> action mean,
    * `σ_net` -- state -> log standard deviation (exponentiated in `act`),
    * `v_net` -- state -> single value output.
  `q_net` is created as an empty list and is not populated by this class.
  Every forward pass first normalises the state with `π_rms.nrm`.
  """

  def __init__(self):
    def build_mlp(n_out):
      # ns -> nh -> nh -> n_out, ReLU between Dense layers. Layers are
      # instantiated left to right, matching the original construction order.
      return [Dense(ns, nh), Relu(), Dense(nh, nh), Relu(), Dense(nh, n_out)]

    self.μ_net = build_mlp(na)
    self.σ_net = build_mlp(na)
    # Shift the log-σ head's initial bias by -1 so the initial σ = exp(out)
    # is scaled by e^-1 compared with an unshifted head.
    log_sigma_head = self.σ_net[-1]
    log_sigma_head.b.assign_add([-1] * na)
    self.v_net = build_mlp(1)
    self.q_net = []

    # Snapshots taken once at construction; the underscored properties
    # below rebuild fresh lists on every access.
    self.vars = self._vars
    self.v_vars = self._v_vars
    self.net_vars = self._net_vars
    self.μ_vars, self.σ_vars = self.net_vars

  @tf.function
  def fwd(self, s):
    """Return (mean, log-σ) computed with `nn.layer.fwd` on normalised `s`."""
    # `fwd` below resolves to the module-level import, not this method.
    s_norm = π_rms.nrm(s)
    mean = fwd(s_norm, *self.μ_vars)
    log_sigma = fwd(s_norm, *self.σ_vars)
    return mean, log_sigma

  @tf.function
  def fwd1(self, s):
    """Return (mean, log-σ) computed with `nn.layer.fwd1` on normalised `s`.

    NOTE(review): presumably the single-state variant of `fwd` -- confirm
    against `nn.layer`.
    """
    s_norm = π_rms.nrm(s)
    mean = fwd1(s_norm, *self.μ_vars)
    log_sigma = fwd1(s_norm, *self.σ_vars)
    return mean, log_sigma

  @tf.function
  def fwdv(self, s):
    """Return the value-network output via `nn.layer.fwdv` on normalised `s`."""
    s_norm = π_rms.nrm(s)
    return fwdv(s_norm, *self.v_vars)

  @tf.function
  def act(self, s):
    """Sample an action for state `s`.

    Draws u = mean + ε·exp(log-σ) with ε ~ N(0, 1), then soft-clips it with
    a_scale·tanh(u / a_scale). Returns (squashed action, pre-squash sample u).
    """
    mean, log_sigma = self.fwd1(s)
    noise = tf.random.normal(tf.shape(log_sigma))
    u = mean + noise * tf.exp(log_sigma)
    squashed = a_scale * tf.tanh(u / a_scale)
    return squashed, u

  @property
  def _vars(self):
    """Fresh flat list: all μ-net layer vars, then all σ-net layer vars."""
    return [var for layer in self.μ_net + self.σ_net for var in layer.vars]

  @property
  def _v_vars(self):
    """Fresh flat list of every value-network layer variable."""
    return [var for layer in self.v_net for var in layer.vars]

  @property
  def _net_vars(self):
    """Fresh two-element list: [μ-net vars, σ-net vars], each flattened."""
    return [
        [var for layer in net for var in layer.vars]
        for net in (self.μ_net, self.σ_net)
    ]