import numpy as np
from tqdm import tqdm

from . import interact

class ReplayBuffer:
    """Fixed-capacity circular buffer of observations for transition sampling.

    Each stored experience consists of an observation, a ``done`` flag marking
    the last observation of an episode, and an optional state. Once
    ``capacity`` experiences have been stored, every new experience overwrites
    the oldest one.
    """

    def __init__(self, capacity):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of experiences held at once (>= 1).

        Raises:
            ValueError: If ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError('capacity must be at least 1, got {}'.format(capacity))
        self.capacity = capacity
        self.full = False  # set once the last slot has been written
        self.position = 0  # slot the next experience will be written to
        self.obs_buffer = []
        self.done_buffer = []
        self.state_buffer = []

    def addExperience(self, ob, state=None, done=False):
        """Store one experience, overwriting the oldest entry when full.

        Args:
            ob: The observation to store.
            state: Optional state stored alongside the observation.
            done: True if ``ob`` is the last observation of its episode.
        """
        # Grow the underlying lists until they reach capacity; afterwards the
        # slots are reused in ring order.
        if len(self.obs_buffer) < self.capacity:
            self.obs_buffer.append(None)
            self.done_buffer.append(None)
            self.state_buffer.append(None)
        self.obs_buffer[self.position] = ob
        self.done_buffer[self.position] = done
        self.state_buffer[self.position] = state
        if (self.position + 1) == self.capacity:
            self.full = True
        self.position = (self.position + 1) % self.capacity

    def addEpisode(self, obs, states=None):
        """Store a whole episode, flagging only its last observation as done.

        Args:
            obs: Sequence of observations, in episode order.
            states: Optional sequence of states paired with ``obs``; defaults
                to ``None`` for every observation.

        Raises:
            IndexError: If ``obs`` is empty.
        """
        if len(obs) == 0:
            raise IndexError('Tried to add an empty episode')
        if states is None:
            states = [None]*len(obs)
        for ob, state in zip(obs[:-1], states[:-1]):
            self.addExperience(ob, state, done=False)
        self.addExperience(obs[-1], states[-1], done=True)

    def fill(self, env, seed, single_episode=False):
        """Run episodes and store them until the buffer is full.

        Episodes are produced by ``interact.runEpisode(env, seed,
        render=False)``; the seed is incremented after every episode.

        Args:
            env: Environment handed to ``interact.runEpisode``.
            seed: Seed used for the first episode.
            single_episode: If True, stop after one episode even if the
                buffer is not yet full.

        Returns:
            The seed value following the last one that was used.
        """
        print('Filling replay buffer...')
        with tqdm(total=self.capacity) as progress:
            while not self.full:
                obs, _, _, states = interact.runEpisode(env, seed, render=False)
                self.addEpisode(obs, states)
                seed += 1
                progress.update(len(obs))

                if single_episode:
                    break
        if single_episode:
            print('Stopping after one episode, as requested.')
        return seed

    def sample(self, n_steps=1, batch_size=1):
        """Sample a batch of (observation, later observation) pairs.

        Args:
            n_steps: Desired number of steps between the two observations.
            batch_size: Number of pairs to draw.

        Returns:
            A tuple ``(obs, next_obs)`` of two lists of length ``batch_size``.

        Raises:
            IndexError: If the buffer holds too few experiences to sample from.
        """
        obs = []
        next_obs = []
        for _ in range(batch_size):
            ob, next_ob = self._sampleObsPair(n_steps)
            obs.append(ob)
            next_obs.append(next_ob)
        return obs, next_obs

    def _sampleObsPair(self, n_steps=1):
        """Draw one observation and the observation ``n_steps`` after it.

        A random start index is drawn up to 20 times. An attempt succeeds when
        the start is not a 'done' observation and ``n_steps`` further
        observations can be read without reaching a 'done' observation or the
        write head (the boundary between the newest and the oldest stored
        experience). If every attempt fails, the pair from the last attempt is
        returned anyway; its second element is then fewer than ``n_steps``
        steps after the first (possibly the same observation).

        Raises:
            IndexError: If the buffer is not full and holds fewer than two
                experiences.
        """
        size = len(self)
        # While filling, the newest element is excluded as a starting point
        # because nothing has been stored after it yet.
        max_idx = self.capacity if self.full else size - 1
        if max_idx < 1:
            if size == 0:
                raise IndexError('Tried to sample from empty buffer')
            raise IndexError('Tried to sample from buffer with only one element')

        n_attempts = 20  # Try a few times to get a valid transition, then give up
        for _ in range(n_attempts):
            idx = np.random.randint(max_idx)
            ob, done = self._get(idx)
            next_ob = ob
            if done:
                continue  # a terminal observation has no successor; retry
            interrupted = False
            for step in range(1, n_steps + 1):
                next_idx = idx + step
                # The slot at `position` is either not written yet (buffer
                # still filling) or holds the oldest experience (buffer full).
                # In both cases it does not follow the previous slot in time.
                if next_idx % self.capacity == self.position:
                    interrupted = True
                    break
                next_ob, next_done = self._get(next_idx)
                if next_done:  # about to transition into next episode
                    interrupted = True
                    break
            if not interrupted:  # found an n_step transition without a boundary
                break
        return (ob, next_ob)

    def _get(self, idx):
        """Return ``(observation, done)`` at ``idx``, wrapping around capacity."""
        idx = idx % self.capacity
        return self.obs_buffer[idx], self.done_buffer[idx]

    def __len__(self):
        """Return the number of experiences currently stored."""
        return len(self.obs_buffer)

def testEmptyBuffer():
    """Check that sampling from a buffer with no experiences raises IndexError."""
    buffer = ReplayBuffer(capacity=8)
    raised = False
    try:
        buffer.sample()
    except IndexError:
        raised = True
    # Nothing was ever stored, so the sample call must have failed
    assert raised

def testOneElementBuffer():
    """Check that sampling from a buffer holding one experience raises IndexError."""
    buffer = ReplayBuffer(capacity=8)
    buffer.addExperience('foo', done=False)
    raised = False
    try:
        buffer.sample()
    except IndexError:
        raised = True
    # A single observation has no successor to pair it with
    assert raised

def testTwoElementBuffer():
    """Check that a two-experience buffer always yields the one possible pair."""
    first, second = 'foo', 'bar'
    buffer = ReplayBuffer(capacity=8)
    buffer.addExperience(first, done=False)
    buffer.addExperience(second, done=True)
    obs_batch, next_obs_batch = buffer.sample()
    sampled_ob = obs_batch[0]
    sampled_next_ob = next_obs_batch[0]
    # Only one (observation, next observation) pair exists in the buffer
    assert sampled_ob == first and sampled_next_ob == second

def testFullBuffer():
    """Check overwrite order and sampled pairs for an episode one longer than capacity."""
    capacity = 8
    buffer = ReplayBuffer(capacity=capacity)
    episode = list(range(capacity + 1))
    buffer.addEpisode(episode)

    # The one extra element wraps around and replaces the first slot
    expected_contents = [capacity]
    expected_contents.extend(range(1, capacity))
    assert buffer.obs_buffer == expected_contents

    batch_size = 10000
    sampled_obs, sampled_next = buffer.sample(batch_size=batch_size)
    n_bins = capacity + 1
    obs_counts = np.bincount(sampled_obs, minlength=n_bins)
    next_counts = np.bincount(sampled_next, minlength=n_bins)

    # Observation 0 was overwritten, so it can never be sampled
    assert obs_counts[0] == next_counts[0] == 0
    # Each observation k must be paired with observation k+1
    assert np.all(obs_counts[1:capacity-1] == next_counts[2:capacity])
    # Samples starting at the final element must match samples ending at observation 1
    assert obs_counts[-1] == next_counts[1]

    n_bad_samples = obs_counts[-1]
    if n_bad_samples > 0:
        percent_bad = 100*n_bad_samples/batch_size
        print('Warning: {} of {} samples were non-consecutive (~{}%).'.format(n_bad_samples, batch_size, percent_bad))

def testNStepSample():
    """Check that n-step sampling returns pairs exactly n_steps apart.

    A single episode 1.5 times the buffer capacity is stored, so the buffer
    wraps and the episode's final ('done') element ends up in the middle of
    the buffer. Observations are consecutive integers, which makes the step
    distance of a sampled pair equal to the difference of its two values.
    """
    capacity = 20
    batch_size = 10000
    replay = ReplayBuffer(capacity=capacity)
    # put a 'done' in the middle of the buffer
    replay.addEpisode(list(range(capacity + capacity//2)))

    def sample_arrays(n_steps):
        # Draw one batch and return both halves as arrays for arithmetic
        x1, x2 = replay.sample(n_steps=n_steps, batch_size=batch_size)
        return np.asarray(x1), np.asarray(x2)

    # when n_steps == 0, |x1 - x2| should be equal to 0
    x1, x2 = sample_arrays(0)
    assert all(x1-x2 == 0.0), np.mean(np.abs(x1-x2))

    # when n_steps is in [1, capacity*2//3), |x1 - x2| should be ~= n_steps
    for n_steps in range(1, capacity*2//3):
        x1, x2 = sample_arrays(n_steps)
        counts = np.bincount(np.abs(x1-x2))
        n_good_samples = counts[n_steps]
        n_bad_samples = sum(counts)-n_good_samples
        # Only short transitions (up to a third of the capacity) are required
        # to be exact; longer ones merely produce a warning
        if n_steps <= capacity/3:
            assert n_bad_samples == 0, {'n_steps:': n_steps, 'n_bad_samples': n_bad_samples}
        if n_bad_samples > 0:
            print('Warning: {} of {} {}-step samples were non-consecutive (~{}%).'.format(n_bad_samples, batch_size, n_steps, 100*n_bad_samples/batch_size))

    # when n_steps == capacity, no |x1 - x2| should be equal to n_steps, because
    # there aren't enough samples for an n_step transition
    n_steps = capacity
    x1, x2 = sample_arrays(n_steps)
    assert not any(np.isclose(np.abs(x1-x2), n_steps))


def main():
    """Run every self-test of this module in order, then report completion."""
    checks = (
        testEmptyBuffer,
        testOneElementBuffer,
        testTwoElementBuffer,
        testFullBuffer,
        testNStepSample,
    )
    for check in checks:
        check()
    print('Testing complete.')

# Run the self-tests only when this file is executed as a script
if __name__ == '__main__':
    main()
