import concurrent.futures
import threading
import time
import uuid
from collections import deque, defaultdict
from functools import partial as bind

import numpy as np
import embodied

from . import chunk as chunklib


class NaiveChunks(embodied.Replay):

  def __init__(self, length, capacity=None, directory=None, chunks=1024, seed=0):
    assert 1 <= length <= chunks
    self.length = length
    self.capacity = capacity
    self.directory = directory and embodied.Path(directory)
    self.chunks = chunks
    self.buffers = {}
    self.rng = np.random.default_rng(seed)
    self.ongoing = defaultdict(bind(chunklib.Chunk, chunks))
    if directory:
      self.directory.mkdirs()
      self.workers = concurrent.futures.ThreadPoolExecutor(16)
      self.promises = deque()

  def __len__(self):
    return len(self.buffers) * self.length

  @property
  def stats(self):
    return {'size': len(self), 'chunks': len(self.buffers)}

  def add(self, step, worker=0):
    chunk = self.ongoing[worker]
    chunk.append(step)
    if len(chunk) >= self.chunks:
      self.buffers[chunk.uuid] = self.ongoing.pop(worker)
      self.promises.append(self.workers.submit(chunk.save, self.directory))
      for promise in [x for x in self.promises if x.done()]:
        promise.result()
        self.promises.remove(promise)
    while len(self) > self.capacity:
      del self.buffers[next(iter(self.buffers.keys()))]

  def _sample(self):
    counter = 0
    while not self.buffers:
      if counter % 100 == 0:
        print('Replay sample is waiting')
      time.sleep(0.1)
      counter += 1
    keys = tuple(self.buffers.keys())
    chunk = self.buffers[keys[self.rng.integers(0, len(keys))]]
    idx = self.rng.integers(0, len(chunk) - self.length + 1)
    seq = {k: chunk.data[k][idx: idx + self.length] for k in chunk.data.keys()}
    seq['is_first'][0] = True
    return seq

  def dataset(self):
    while True:
      yield self._sample()

  def save(self, wait=False):
    for chunk in self.ongoing.values():
      if chunk.length:
        self.promises.append(self.workers.submit(chunk.save, self.directory))
    if wait:
      [x.result() for x in self.promises]
      self.promises.clear()

  def load(self, data=None):
    filenames = chunklib.Chunk.scan(self.directory, capacity)
    if not filenames:
      return
    threads = min(len(filenames), 32)
    with concurrent.futures.ThreadPoolExecutor(threads) as executor:
      chunks = list(executor.map(chunklib.Chunk.load, filenames))
    self.buffers = {chunk.uuid: chunk for chunk in chunks}
