# coding:utf8
import numpy as np

class DataInputOnlyX:
  """Single-pass iterator that cuts ``data`` and ``data_start`` into batches.

  Batches are consecutive slices of at most ``batch_size`` records; the last
  one may be shorter. The instance is its own iterator and never rewinds its
  cursor, so looping over an exhausted instance again yields nothing.

  Each ``data`` record is indexed at positions 0 and 1, and ``len()`` is taken
  of field 1 (presumably a user id and that user's list of positive items --
  TODO confirm against the caller). Only field 0 of each ``data_start`` record
  is read.
  """

  def __init__(self, data, data_start, batch_size):
    """Store the inputs and work out how many batches cover ``data``.

    Args:
      data: sized, sliceable collection of records.
      data_start: sized, sliceable collection, batched with the same offsets
        as ``data``.
      batch_size: maximum number of records per batch.
    """
    self.batch_size = batch_size
    self.data = data
    self.data_start = data_start
    # Ceiling division: a positive remainder means one extra, shorter batch.
    whole, leftover = divmod(len(self.data), self.batch_size)
    self.epoch_size = whole + 1 if leftover > 0 else whole
    self.i = 0  # number of batches handed out so far

  def __iter__(self):
    """Return the instance itself; the cursor is not reset."""
    return self

  def __next__(self):
    """Return the next batch.

    Returns:
      A tuple ``(batch_no, (field0, field1, field1_lens), (start_field0, []))``
      where ``batch_no`` is the 1-based index of the batch, the three lists
      hold field 0, field 1 and ``len(field 1)`` of every ``data`` record in
      the batch, and ``start_field0`` holds field 0 of every ``data_start``
      record in the same window. The trailing list is always empty.

    Raises:
      StopIteration: once ``epoch_size`` batches have been returned.
    """
    if self.i == self.epoch_size:
      raise StopIteration

    lo = self.i * self.batch_size
    hi = (self.i + 1) * self.batch_size
    # Each source is clamped to its own length, so a shorter ``data_start``
    # simply contributes fewer (possibly zero) records.
    batch = self.data[lo:min(hi, len(self.data))]
    start_batch = self.data_start[lo:min(hi, len(self.data_start))]
    self.i += 1

    # One pass per record: field 0, field 1, then the length of field 1.
    rows = [(rec[0], rec[1], len(rec[1])) for rec in batch]
    field0 = [row[0] for row in rows]
    field1 = [row[1] for row in rows]  # [pos1, pos2, ..., posX]
    field1_lens = [row[2] for row in rows]
    start_field0 = [rec[0] for rec in start_batch]

    return self.i, (field0, field1, field1_lens), (start_field0, [])


class DataInputSyn:
  """Single-pass iterator that cuts ``data`` into column-wise batches.

  Batches are consecutive slices of at most ``batch_size`` records; the last
  one may be shorter. The instance is its own iterator and never rewinds its
  cursor, so looping over an exhausted instance again yields nothing.

  Every record is indexed at positions 0 through 3.
  """

  def __init__(self, data, batch_size):
    """Store the inputs and work out how many batches cover ``data``.

    Args:
      data: sized, sliceable collection of records.
      batch_size: maximum number of records per batch.
    """
    self.batch_size = batch_size
    self.data = data
    # Ceiling division: a positive remainder means one extra, shorter batch.
    whole, leftover = divmod(len(self.data), self.batch_size)
    self.epoch_size = whole + 1 if leftover > 0 else whole
    self.i = 0  # number of batches handed out so far

  def __iter__(self):
    """Return the instance itself; the cursor is not reset."""
    return self

  def __next__(self):
    """Return the next batch.

    Returns:
      A tuple ``(batch_no, (col0, col1, col2, col3))`` where ``batch_no`` is
      the 1-based index of the batch and ``colK`` is the list of field ``K``
      taken from every record in the batch.

    Raises:
      StopIteration: once ``epoch_size`` batches have been returned.
    """
    if self.i == self.epoch_size:
      raise StopIteration

    lo = self.i * self.batch_size
    hi = (self.i + 1) * self.batch_size
    batch = self.data[lo:min(hi, len(self.data))]
    self.i += 1

    # Transpose the records: column K collects field K of each record.
    columns = ([], [], [], [])
    for record in batch:
      for position, column in enumerate(columns):
        column.append(record[position])

    return self.i, columns