# coding:utf8
import numpy as np

class DataInput:
  """Iterate over ``data`` in mini-batches with zero-padded histories.

  Each record ``t`` in ``data`` is read positionally:
  ``t[0]`` -> ``u``, ``t[1]`` -> a variable-length integer sequence that is
  zero-padded into ``hist_i``, ``t[2]`` -> ``i``, ``t[3]`` -> ``y``,
  ``t[4]`` -> ``display``. (Presumably user id / item history / item id /
  label / display flag -- confirm against the data producer.)

  When ``model_name == 'CE'`` only records with ``t[4] == 1`` are emitted.
  Each batch holds up to ``batch_size`` such records, and every qualifying
  record is emitted exactly once. For any other ``model_name``, ``data`` is
  cut into consecutive slices of ``batch_size`` records.

  Iteration yields ``(batch_no, (u, i, y, hist_i, sl, display))``.
  ``batch_no`` is 1-based. ``hist_i`` is an ``int64`` array of shape
  ``(len(batch), max(sl))``. ``sl`` holds the unpadded history lengths.
  The iterator is single-use: ``__iter__`` does not reset its position.
  """

  def __init__(self, data, batch_size, model_name):
    self.batch_size = batch_size
    self.data = data
    # Ceiling division: number of batches needed to cover every record.
    self.epoch_size = len(self.data) // self.batch_size
    if self.epoch_size * self.batch_size < len(self.data):
      self.epoch_size += 1
    self.i = 0  # number of batches emitted so far
    self.model_name = model_name
    self.start = 0  # CE mode: index of the next raw record to scan

  def __iter__(self):
    return self

  def __next__(self):
    """Return the next ``(batch_no, (u, i, y, hist_i, sl, display))``.

    Raises:
      StopIteration: when all batches have been produced. In CE mode this
        also happens when no remaining record has ``t[4] == 1``.
    """
    if self.i == self.epoch_size:
      raise StopIteration

    if self.model_name == 'CE':
      ts = self._next_ce_rows()
    else:
      ts = self.data[self.i * self.batch_size:
                     min((self.i + 1) * self.batch_size, len(self.data))]

    self.i += 1
    return self.i, self._collate(ts)

  # Python 2 spelling of the iterator protocol.
  next = __next__

  def _next_ce_rows(self):
    """Collect up to ``batch_size`` records with ``t[4] == 1`` from ``self.start``."""
    n = len(self.data)
    if self.start >= n:
      raise StopIteration
    rows = []
    pos = self.start
    while pos < n and len(rows) < self.batch_size:
      record = self.data[pos]
      pos += 1
      if record[4] == 1:
        rows.append(record)
    # Resume right after the last scanned record. Advancing by batch_size
    # instead would re-scan records already emitted and produce duplicates
    # whenever records with t[4] != 1 were skipped.
    self.start = pos
    if not rows:
      # Only non-qualifying records remained; nothing left to emit.
      raise StopIteration
    return rows

  @staticmethod
  def _collate(ts):
    """Split records into per-field lists and zero-pad the histories."""
    u = [t[0] for t in ts]
    i = [t[2] for t in ts]
    y = [t[3] for t in ts]
    sl = [len(t[1]) for t in ts]
    display = [t[4] for t in ts]
    max_sl = max(sl)

    hist_i = np.zeros([len(ts), max_sl], np.int64)
    for k, t in enumerate(ts):
      # Left-aligned copy; positions past the row's length stay 0 (padding).
      hist_i[k, :sl[k]] = t[1]

    return u, i, y, hist_i, sl, display