import numpy as np
import tensorflow as tf
from const import na

# Terse aliases for TensorFlow ops, so the forward passes below read as
# compact nested function calls.
m, e = tf.math.multiply, tf.math.exp   # elementwise multiply / exp
a = tf.math.add                        # elementwise (broadcasting) add
mm = tf.linalg.matmul                  # (batched) matrix multiply
mv = tf.linalg.matvec                  # matrix-vector multiply
f = tf.nn.relu                         # hidden-layer activation

@tf.function
def fwd(x,w1,b1,w2,b2,w3,b3):
  """Forward pass of a three-layer ReLU MLP using matrix products.

  Computes ``relu(relu(x @ w1 + b1) @ w2 + b2) @ w3 + b3``. There is no
  activation on the output layer.

  NOTE(review): presumably ``x`` is ``(batch, nx)`` and each ``w_i`` is
  ``(n_in, n_out)`` with ``b_i`` of length ``n_out``. ``tf.linalg.matmul``
  requires both operands to be at least rank 2. Confirm against callers.

  Returns:
    The output-layer activations, unreshaped.
  """
  hidden = f(a(mm(x, w1), b1))
  hidden = f(a(mm(hidden, w2), b2))
  y = a(mm(hidden, w3), b3)
  return y

@tf.function
def fwd1(x,w1,b1,w2,b2,w3,b3):
  """Three-layer ReLU MLP written with matrix-vector products.

  Each layer computes ``matvec(w, h, transpose_a=True) + b``, that is
  ``transpose(w) . h + b``. ReLU follows the first two layers; the last
  layer has no activation.

  NOTE(review): presumably ``x`` is a single input vector (or vectors whose
  leading dims match the weights' batch dims). Confirm against callers.
  """
  hidden = f(a(mv(w1, x, transpose_a=True), b1))
  hidden = f(a(mv(w2, hidden, transpose_a=True), b2))
  return a(mv(w3, hidden, transpose_a=True), b3)

@tf.function
def fwdv(x,w1,b1,w2,b2,w3,b3):
  """Three-layer ReLU MLP whose output is flattened to a 1-D tensor.

  Computes ``relu(relu(x @ w1 + b1) @ w2 + b2) @ w3 + b3`` and then
  reshapes the result with ``[-1]``.

  NOTE(review): fwdv, fwdd and fwdr have identical bodies. They are distinct
  tf.function objects, so each one traces independently.
  """
  layer1 = f(a(mm(x, w1), b1))
  layer2 = f(a(mm(layer1, w2), b2))
  logits = a(mm(layer2, w3), b3)
  flat = tf.reshape(logits, [-1])
  return flat

@tf.function
def fwdd(x,w1,b1,w2,b2,w3,b3):
  """Three-layer ReLU MLP whose output is flattened to a 1-D tensor.

  Computes ``relu(relu(x @ w1 + b1) @ w2 + b2) @ w3 + b3`` and then
  reshapes the result with ``[-1]``.

  NOTE(review): this body is identical to fwdv and fwdr. It is a separate
  tf.function object with its own trace cache.
  """
  act = f(a(mm(x, w1), b1))
  act = f(a(mm(act, w2), b2))
  return tf.reshape(a(mm(act, w3), b3), [-1])

@tf.function
def fwdr(x,w1,b1,w2,b2,w3,b3):
  """Three-layer ReLU MLP whose output is flattened to a 1-D tensor.

  Computes ``relu(relu(x @ w1 + b1) @ w2 + b2) @ w3 + b3`` and then
  reshapes the result with ``[-1]``.

  NOTE(review): this body is identical to fwdv and fwdd. It is a separate
  tf.function object with its own trace cache.
  """
  pre1 = a(mm(x, w1), b1)
  pre2 = a(mm(f(pre1), w2), b2)
  out = a(mm(f(pre2), w3), b3)
  return tf.reshape(out, [-1])

@tf.function
def fwd_fold(x,w1,b1,w2,b2,w3,b3):
  """Evaluate a stack of four three-layer ReLU MLPs on the same input.

  ``x`` is replicated four times along a new leading axis. The tile
  multiples ``[4,1,1]`` have length 3, so ``x`` must be rank 2. Each bias gets
  a singleton axis inserted at position 1 so it broadcasts across the second
  axis of the batched matmul result. Finally, the first two axes of the output
  are swapped.

  NOTE(review): presumably ``w_i`` is ``(4, n_in, n_out)`` and ``b_i`` is
  ``(4, n_out)``, giving an output of ``(batch, 4, n_out)``. Confirm against
  callers.
  """
  replicated = tf.tile(tf.expand_dims(x, 0), [4,1,1])
  h = f(a(mm(replicated, w1), tf.expand_dims(b1, 1)))
  h = f(a(mm(h, w2), tf.expand_dims(b2, 1)))
  stacked_out = a(mm(h, w3), tf.expand_dims(b3, 1))
  return tf.transpose(stacked_out, [1,0,2])

@tf.function
def fwd_fold1(x,w1,b1,w2,b2,w3,b3):
  """Matrix-vector form of the stacked MLP, reshaped to ``[4, na]``.

  Each layer computes ``matvec(w, h, transpose_a=True) + b``. ReLU follows
  the first two layers. The final result is reshaped to ``[4, na]``, where
  ``na`` comes from the project's ``const`` module, so the output must hold
  exactly ``4 * na`` elements.

  NOTE(review): presumably the weights carry a leading stack axis of size 4.
  Confirm against callers.
  """
  h = f(a(mv(w1, x, transpose_a=True), b1))
  h = f(a(mv(w2, h, transpose_a=True), b2))
  final = a(mv(w3, h, transpose_a=True), b3)
  return tf.reshape(final, [4, na])

class Layer:
  """Base layer: an identity mapping that owns no variables.

  Subclasses override ``__call__`` to transform their input and ``vars``
  to expose the variables they own.
  """

  def __call__(self, x):
    # The base layer passes its input through unchanged.
    return x

  @property
  def vars(self):
    """Variables owned by this layer; a fresh empty list for the base class."""
    return list()

class Relu(Layer):
  """Variable-free layer that applies ``tf.nn.relu`` elementwise."""

  def __call__(self, x):
    activated = tf.nn.relu(x)
    return activated

class Dense(Layer):
  """Fully connected layer computing ``x @ w + b``.

  Args:
    nx: number of input features (rows of ``w``).
    ny: number of output features (columns of ``w`` and length of ``b``).
    std: target L2 norm of every column of the initial weight matrix.
  """

  def __init__(self, nx, ny, std=0.5):
    weights = np.random.randn(nx, ny).astype(np.float32)
    # Rescale every column to L2 norm ``std`` (norm taken over the input
    # axis). NOTE: a fixed np.sqrt(2/nx) scale was the commented-out
    # alternative here.
    column_norms = np.sqrt(np.square(weights).sum(axis=0, keepdims=True))
    weights *= std / column_norms
    self.w = tf.Variable(weights, dtype=tf.float32)
    # The bias starts at zero.
    self.b = tf.Variable(tf.zeros(ny, tf.float32), dtype=tf.float32)

  def __call__(self, x):
    return x @ self.w + self.b

  @property
  def vars(self):
    """The layer's variables: weight first, then bias."""
    return [self.w, self.b]

class BatchDense(Layer):
  """A stack of ``nb`` independent dense layers applied in one batched matmul.

  ``w`` has shape ``(nb, nx, ny)`` and ``b`` has shape ``(nb, ny)``. Each
  weight matrix ``w[k]`` is drawn from a standard normal and rescaled so that
  every column has L2 norm ``std``. Biases start at zero.
  """
  def __init__(self, nb, nx, ny, std=0.5):
    w_init = np.random.randn(nb, nx, ny).astype(np.float32)
    # Normalise over the input axis (axis=1) so that each column of each w[k]
    # has L2 norm ``std``. NOTE: a fixed np.sqrt(2/nx) scale was the
    # commented-out alternative here.
    w_init *= std / np.sqrt(np.square(w_init).sum(axis=1, keepdims=True))
    self.w = tf.Variable(w_init, dtype=tf.float32)
    b_init = tf.zeros([nb, ny], tf.float32)
    self.b = tf.Variable(b_init, dtype=tf.float32)

  def __call__(self, x):
    """Apply stack member ``k`` to ``x[k]``: ``x[k] @ w[k] + b[k]``.

    Args:
      x: assumed shape ``(nb, batch, nx)``. TODO confirm against callers.

    Returns:
      Tensor of shape ``(nb, batch, ny)``.
    """
    # ``x @ self.w`` is (nb, batch, ny), but the bias is (nb, ny). Adding it
    # directly would right-align the bias's nb axis with ``batch``. That
    # adds the wrong bias when batch == nb, yields (nb, nb, ny) when
    # batch == 1, and fails otherwise. Insert a singleton axis so the bias
    # broadcasts as (nb, 1, ny), the same treatment fwd_fold gives its biases.
    return x @ self.w + tf.expand_dims(self.b, 1)

  @property
  def vars(self):
    """The layer's variables: weight stack first, then bias stack."""
    return [self.w, self.b]
