import tensorflow as tf
import numpy as np
from mpi4py import MPI
from const import a_limit

@tf.function
def no_op():
  """Do nothing and return None; serves as the inert branch passed to ``tf.cond``."""

class DataSet:
  """Shuffled mini-batch sampler shared by SA, SAS and SASB.

  Subclasses set these attributes in ``__init__`` and then call
  ``init_ptr()``:

    data: ``tf.constant`` with one sample per entry along axis 0 and all
      fields concatenated along the last axis.
    idx: ``tf.Variable`` of int32 indices into ``data``; reshuffled by
      ``init_ptr``.
    size_splits: widths used to split the last axis back into fields.
    len: number of samples in ``data``.
    ptr: int32 ``tf.Variable`` holding the position of the next unread
      entry of ``idx``.
  """

  @tf.function
  def init_ptr(self):
    """Rewind the read position to 0 and reshuffle the sample order."""
    self.ptr.assign(0)
    self.idx.assign(tf.random.shuffle(self.idx))

  @tf.function
  def get_next_batch(self, size):
    """Return the next ``size`` samples as a tuple of field tensors.

    Samples are read in the order held by ``idx``.  When fewer than ``size``
    unread entries remain, ``init_ptr`` reshuffles and reading restarts at 0,
    so samples are not repeated within one pass.  If ``size`` exceeds
    ``len`` the index slice is truncated and the batch holds only ``len``
    samples.

    Args:
      size: number of samples to return.

    Returns:
      Tuple of tensors, one per entry of ``size_splits``.
    """
    # Reshuffle only when the remaining entries cannot fill a batch. Using a
    # strict `<` here would also reshuffle when exactly `size` entries are
    # left, silently discarding the last full batch of every pass.
    tf.cond(tf.less_equal(self.ptr + size, self.len), no_op, self.init_ptr)
    start = tf.identity(self.ptr)
    end = start + size
    self.ptr.assign(end)
    rows = tf.gather(self.data, self.idx[start:end])
    return tuple(tf.split(rows, self.size_splits, -1))

class SA(DataSet):
  """Dataset of (s, a) samples; the inherited ``get_next_batch`` yields ``(s, a)``.

  ``s`` and ``a`` are concatenated along their last axis, so every other
  dimension must agree between them.
  """
  def __init__(self, s, a):
    state_width, action_width = s.shape[-1], a.shape[-1]
    self.n = len(s)
    self.len = s.shape[0]
    merged = np.concatenate((s, a), axis=-1)
    self.data = tf.constant(merged, dtype=tf.float32)
    self.size_splits = tf.constant([state_width, action_width])
    self.idx = tf.Variable(tf.range(len(s), dtype=tf.int32))
    self.ptr = tf.Variable(0, dtype=tf.int32, trainable=False)
    # Shuffle once so the first batch is already randomised.
    self.init_ptr()

class SAS(DataSet):
  """Dataset of (s, a, ś) samples with shuffled mini-batching.

  The three arrays are concatenated along their last axis, so their other
  dimensions must agree; ``s`` and ``ś`` must also share the same last-axis
  width, since both are split back out with width ``ns``.
  """
  def __init__(self, s, a, ś):
    self.n = len(s)
    self.data = tf.constant(np.concatenate((s,a,ś), axis=-1), dtype=tf.float32)
    self.idx = tf.Variable(tf.range(len(s), dtype=tf.int32))
    ns = s.shape[-1]
    na = a.shape[-1]
    self.size_splits = tf.constant([ns, na, ns])
    self.len = s.shape[0]
    self.ptr = tf.Variable(0, dtype=tf.int32, trainable=False)
    self.init_ptr()

  # Compiled with tf.function, matching DataSet.get_next_batch.
  @tf.function
  def get_next_batch(self, size):
    """Return the next ``size`` samples as ``(s, a, ś)`` tensors.

    When fewer than ``size`` unread entries remain, the order is reshuffled
    via ``init_ptr`` and reading restarts at 0.
    """
    # `<=`: exactly `size` remaining entries still form a full batch; a
    # strict `<` would reshuffle early and drop that batch every pass.
    tf.cond(tf.less_equal(self.ptr + size, self.len), no_op, self.init_ptr)
    start = tf.identity(self.ptr)
    end = tf.math.add(start, size)
    self.ptr.assign(end)
    s, a, ś = tf.split(tf.gather(self.data, self.idx[start:end]),
        self.size_splits, -1)
    return s, a, ś

class SASB(DataSet):
  """Dataset of (s, a, ś, b) samples with shuffled mini-batching.

  ``b`` must have shape ``s.shape[:-1]``. It is stored as one extra float
  column and returned as a flattened bool tensor (nonzero -> True).
  ``s`` and ``ś`` must share the same last-axis width.
  """
  def __init__(self, s, a, ś, b):
    self.n = len(s)
    self.data = tf.constant(np.concatenate((s,a,ś,np.expand_dims(b,-1)), axis=-1), dtype=tf.float32)
    self.idx = tf.Variable(tf.range(len(s), dtype=tf.int32))
    ns = s.shape[-1]
    na = a.shape[-1]
    self.size_splits = tf.constant([ns, na, ns, 1])
    self.len = s.shape[0]
    self.ptr = tf.Variable(0, dtype=tf.int32, trainable=False)
    self.init_ptr()

  # Compiled with tf.function, matching DataSet.get_next_batch.
  @tf.function
  def get_next_batch(self, size):
    """Return the next ``size`` samples as ``(s, a, ś, b)`` tensors.

    ``b`` is cast back to bool and flattened to 1-D. When fewer than
    ``size`` unread entries remain, the order is reshuffled via
    ``init_ptr`` and reading restarts at 0.
    """
    # `<=`: exactly `size` remaining entries still form a full batch; a
    # strict `<` would reshuffle early and drop that batch every pass.
    tf.cond(tf.less_equal(self.ptr + size, self.len), no_op, self.init_ptr)
    start = tf.identity(self.ptr)
    end = tf.math.add(start, size)
    self.ptr.assign(end)
    s, a, ś, b = tf.split(tf.gather(self.data, self.idx[start:end]),
        self.size_splits, -1)
    return s, a, ś, tf.reshape(tf.cast(b, tf.bool), [-1])

def get_trj(path, intype):
  """Load a trajectory archive and wrap it in the requested dataset type.

  Args:
    path: path of an ``.npz`` archive readable by ``np.load``. It must
      contain ``'s'``, ``'a'`` and ``'ep_rets'``. ``'ś'`` is additionally
      required for ``'saś'``/``'saśb'``, and ``'b'`` for ``'saśb'``.
    intype: ``'sa'``, ``'saś'`` or ``'saśb'``, selecting ``SA``, ``SAS`` or
      ``SASB`` respectively.

  Returns:
    The constructed dataset. Actions are clipped to ``[-a_limit, a_limit]``.

  Raises:
    NotImplementedError: if ``intype`` is not one of the supported values.
  """
  if intype not in ('sa', 'saś', 'saśb'):
    # Validate before touching the file.
    raise NotImplementedError("unsupported intype: %r" % (intype,))
  # The context manager closes the archive even if a key is missing or
  # dataset construction raises.
  with np.load(path) as τ_data:
    # NpzFile item access returns freshly read arrays, so no copies are needed.
    s = τ_data['s']
    a = np.clip(τ_data['a'], -a_limit, a_limit)
    # Read only the arrays the requested dataset uses, so e.g. an 'sa'
    # archive without 'ś'/'b' still loads.
    if intype == 'sa':
      d = SA(s, a)
    elif intype == 'saś':
      d = SAS(s, a, τ_data['ś'])
    else:
      d = SASB(s, a, τ_data['ś'], τ_data['b'])
    # Only MPI rank 0 prints the episode-return statistics.
    if MPI.COMM_WORLD.Get_rank() == 0:
      G = τ_data['ep_rets']
      if len(G) == 0:
        print("Average returns: n/a (no episodes recorded)")
      else:
        print("Average returns: %f" % np.mean(G))
        print("Std for returns: %f" % np.std(G))
  return d