import numpy as np
import tensorflow as tf
from const import a_scale

def sample_τ_segment(π, env, size):
  """Yield fixed-size segments of environment transitions collected with π.

  This is an infinite generator. Transitions are written into preallocated
  buffers of length ``size``. Each time ``size`` new transitions have been
  collected, the buffers are copied into ``tf.constant`` tensors and yielded
  as a dict. Episodes may span segment boundaries: a terminal step resets the
  environment and collection simply continues.

  Args:
    π: policy object. ``π.act(s)`` must return a pair ``(a, u)``: the action
      passed to ``env.step`` and a raw action stored alongside it
      (presumably the pre-squashing action — TODO confirm).
    env: environment exposing ``reset() -> s``,
      ``step(a) -> (ś, r, terminal, truncated)``, and an action space
      ``env.A`` with a ``sample()`` method.
    size: number of transitions per yielded segment.

  Yields:
    dict with keys
      "s", "a", "u", "r_true", "ś": float32 tensors of states, actions, raw
        actions, rewards and next states (leading dimension ``size``);
      "t": bool tensor, the ``terminal`` flag returned by ``env.step``;
      "b": bool tensor, ``terminal and not truncated``;
      "i": bool tensor, True for the first transition of each episode;
      "ep_len", "ep_ret": Python lists with the length and return of every
        episode that finished while this segment was being collected.
  """
  reset = env.reset
  step = env.step
  act = π.act

  s = reset()
  a = env.A.sample()
  # The initial state/action only fix the buffer shapes; every slot is
  # overwritten before the first yield.
  states = np.array([s for _ in range(size)], dtype=np.float32)
  actions = np.array([a for _ in range(size)], dtype=np.float32)
  raw_actions = np.array([a for _ in range(size)], dtype=np.float32)
  rewards = np.array([0. for _ in range(size)], dtype=np.float32)
  next_states = np.array([s for _ in range(size)], dtype=np.float32)

  # Builtin ``bool`` rather than the ``np.bool`` alias, which NumPy 1.24
  # removed (AttributeError there); the resulting dtype is identical.
  terminals = np.ones(size, bool)
  absorbings = np.ones(size, bool)
  initials = np.ones(size, bool)

  cnt = 0
  new = True  # the very first transition starts an episode
  cur_len = 0
  cur_ret = np.float32(0.)
  ep_len = []
  ep_ret = []

  while True:
    # Emit a segment once the buffers have been completely filled again.
    if cnt > 0 and cnt % size == 0:
      yield {
          "s": tf.constant(states.copy()),
          "a": tf.constant(actions.copy()),
          "u": tf.constant(raw_actions.copy()),
          "r_true": tf.constant(rewards.copy()),
          "ś": tf.constant(next_states.copy()),
          "t": tf.constant(terminals.copy()),
          "b": tf.constant(absorbings.copy()),
          "i": tf.constant(initials.copy()),
          "ep_len": ep_len,
          "ep_ret": ep_ret}
      # Rebind (not clear) so the lists handed out above stay intact.
      ep_ret = []
      ep_len = []

    i = cnt % size

    initials[i] = new
    a, u = act(s)
    states[i] = s
    actions[i] = a
    raw_actions[i] = u
    ś, r, terminal, truncated = step(a)
    terminals[i] = terminal
    absorbings[i] = terminal and not truncated
    new = terminal  # the step after a terminal one begins a new episode
    rewards[i] = r
    next_states[i] = s = ś
    cur_ret += r
    cur_len += 1
    cnt += 1
    if terminal:
      ep_len.append(cur_len)
      ep_ret.append(cur_ret)
      cur_ret = np.float32(0.)
      cur_len = 0
      s = reset()



def sample_τ_segment_abs(π, env, size):
  """Yield fixed-size transition segments, with absorbing-state wrapping.

  This is an infinite generator. Transitions are written into preallocated
  buffers of length ``size``. Each time ``size`` new transitions have been
  collected, the buffers are copied into ``tf.constant`` tensors and yielded
  as a dict.

  A real termination (``terminal and not truncated``) is handled specially:

  * the stored next state of that step is replaced by
    ``env.absorbing_state``;
  * on the following iteration an extra synthetic transition
    ``absorbing_state -> absorbing_state`` is recorded, with an action drawn
    from ``env.A.sample()``, reward 0 and the ``t``/``b``/``y`` flags set.

  Truncated terminations are stored unchanged and get no synthetic
  transition. Synthetic transitions occupy a buffer slot, but they are not
  counted in ``ep_len``/``ep_ret``.

  Args:
    π: policy object. ``π.act(s)`` must return ``(a, u)``: the action passed
      to ``env.step`` and its raw counterpart. For synthetic transitions the
      raw action is computed as ``a_scale * atanh(a / a_scale)``.
    env: environment exposing ``reset() -> s``,
      ``step(a) -> (ś, r, terminal, truncated)``, an action space ``env.A``
      with ``sample()``, and an ``absorbing_state`` attribute.
    size: number of transitions per yielded segment.

  Yields:
    dict with keys
      "s", "a", "u", "r_true", "ś": float32 tensors of states, actions, raw
        actions, rewards and next states (leading dimension ``size``);
      "t": bool tensor, env ``terminal`` flag (always True for synthetic
        transitions);
      "b": bool tensor, ``terminal and not truncated`` (True for synthetic
        transitions);
      "y": bool tensor, True only for synthetic absorbing transitions;
      "i": bool tensor, True for the first transition of each episode;
      "ep_len", "ep_ret": Python lists with the length and return of every
        episode that finished while this segment was being collected.
  """
  reset = env.reset
  step = env.step
  act = π.act

  s = reset()
  a = env.A.sample()
  # The initial state/action only fix the buffer shapes; every slot is
  # overwritten before the first yield.
  states = np.array([s for _ in range(size)], dtype=np.float32)
  actions = np.array([a for _ in range(size)], dtype=np.float32)
  raw_actions = np.array([a for _ in range(size)], dtype=np.float32)
  rewards = np.array([0. for _ in range(size)], dtype=np.float32)
  next_states = np.array([s for _ in range(size)], dtype=np.float32)

  # Builtin ``bool`` rather than the ``np.bool`` alias, which NumPy 1.24
  # removed (AttributeError there); the resulting dtype is identical.
  terminals = np.ones(size, bool)
  absorbings = np.ones(size, bool)
  absorbing_masks = np.ones(size, bool)
  initials = np.ones(size, bool)

  cnt = 0
  new = True  # the very first transition starts an episode
  cur_len = 0
  cur_ret = np.float32(0.)
  ep_len = []
  ep_ret = []

  # True when the previous env step was a real termination, i.e. the next
  # buffer slot must hold the synthetic absorbing transition.
  absorbing_phase = False

  while True:
    # Emit a segment once the buffers have been completely filled again.
    if cnt > 0 and cnt % size == 0:
      yield {
          "s": tf.constant(states.copy()),
          "a": tf.constant(actions.copy()),
          "u": tf.constant(raw_actions.copy()),
          "r_true": tf.constant(rewards.copy()),
          "ś": tf.constant(next_states.copy()),
          "t": tf.constant(terminals.copy()),
          "b": tf.constant(absorbings.copy()),
          "y": tf.constant(absorbing_masks.copy()),
          "i": tf.constant(initials.copy()),
          "ep_len": ep_len,
          "ep_ret": ep_ret}
      # Rebind (not clear) so the lists handed out above stay intact.
      ep_ret = []
      ep_len = []

    i = cnt % size

    if absorbing_phase:
      # Synthetic absorbing -> absorbing transition. At this point ``s``
      # holds the state returned by ``reset()`` after the termination, and
      # it must stay untouched so the next real step acts on (and records)
      # the env's actual current state. The absorbing state therefore goes
      # straight into the buffers instead of being routed through ``s``.
      absorbing_state = env.absorbing_state
      a = env.A.sample()
      # NOTE(review): atanh is ±inf/NaN when |a / a_scale| >= 1; this assumes
      # env.A samples lie strictly inside (-a_scale, a_scale) — confirm.
      u = a_scale * tf.atanh(a / a_scale)

      initials[i] = False
      states[i] = absorbing_state
      actions[i] = a
      raw_actions[i] = u
      terminals[i] = True
      rewards[i] = 0
      next_states[i] = absorbing_state
      absorbings[i] = True
      absorbing_masks[i] = True

      absorbing_phase = False
      cnt += 1
    else:
      initials[i] = new
      a, u = act(s)
      states[i] = s
      actions[i] = a
      raw_actions[i] = u
      ś, r, terminal, truncated = step(a)
      terminals[i] = terminal
      # A non-truncated termination schedules the synthetic transition for
      # the next iteration and stores the absorbing state as next state.
      absorbings[i] = absorbing_phase = terminal and not truncated
      absorbing_masks[i] = False
      if absorbing_phase:
        ś = env.absorbing_state
      new = terminal  # the next real step begins a new episode
      rewards[i] = r
      next_states[i] = s = ś
      cur_ret += r
      cur_len += 1
      cnt += 1
      if terminal:
        ep_len.append(cur_len)
        ep_ret.append(cur_ret)
        cur_ret = np.float32(0.)
        cur_len = 0
        s = reset()
