import unittest
from unittest.mock import patch

import ray
from src.rllib.utils.actors import TaskPool


def createMockWorkerAndObjectRef(obj_ref):
    """Build a fake ``(worker, object_ref)`` pair for feeding a TaskPool.

    The stand-in "worker" is a single-entry dict ``{obj_ref: 1}``, so pairs
    built from different ``obj_ref`` values carry distinct workers. The
    object ref itself is passed through unchanged.
    """
    fake_worker = {obj_ref: 1}
    return fake_worker, obj_ref


class TaskPoolTest(unittest.TestCase):
    """Tests for TaskPool.completed_prefetch() and TaskPool.reset_workers().

    Every test patches ``ray.wait``, so no Ray runtime is needed: the mock's
    ``return_value`` is a ``(ready, pending)`` pair of object refs that
    decides which tasks the pool sees as complete. Tasks are built with
    ``createMockWorkerAndObjectRef``, whose object refs are plain integers.
    """

    @patch("ray.wait")
    def test_completed_prefetch_yieldsAllComplete(self, rayWaitMock):
        """Only tasks reported ready by ray.wait are yielded."""
        task1 = createMockWorkerAndObjectRef(1)
        task2 = createMockWorkerAndObjectRef(2)
        # Return the second task as complete and the first as pending
        rayWaitMock.return_value = ([2], [1])

        pool = TaskPool()
        pool.add(*task1)
        pool.add(*task2)

        fetched = list(pool.completed_prefetch())
        self.assertListEqual(fetched, [task2])

    @patch("ray.wait")
    def test_completed_prefetch_yieldsAllCompleteUpToDefaultLimit(
            self, rayWaitMock):
        """Without max_yield, a single call yields at most 999 tasks."""
        # Load the pool with 1000 tasks, mock them all as complete and then
        # check that the first call to completed_prefetch only yields 999
        # items and the second call yields the final one
        pool = TaskPool()
        for i in range(1000):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)

        rayWaitMock.return_value = (list(range(1000)), [])

        # For this test, we're only checking the object refs
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, list(range(999)))

        # Finally, check the next iteration returns the final task
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, [999])

    @patch("ray.wait")
    def test_completed_prefetch_yieldsAllCompleteUpToSpecifiedLimit(
            self, rayWaitMock):
        """An explicit max_yield caps a call; leftovers come on the next."""
        # Load the pool with 1000 tasks and mock them all as complete. Then
        # check that a call to completed_prefetch with max_yield=500 yields
        # only the first 500 items, and the next call yields the other 500
        pool = TaskPool()
        for i in range(1000):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)

        rayWaitMock.return_value = (list(range(1000)), [])

        # Verify that only the first 500 tasks are returned, this should leave
        # some tasks in the _fetching deque for later
        fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=500)]
        self.assertListEqual(fetched, list(range(500)))

        # Finally, check the next iteration returns the remaining tasks
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, list(range(500, 1000)))

    @patch("ray.wait")
    def test_completed_prefetch_yieldsRemainingIfIterationStops(
            self, rayWaitMock):
        """Abandoning the generator early must not lose prefetched tasks."""
        # Test for issue #7106
        # In versions of Ray up to 0.8.1, if the pre-fetch generator failed to
        # run to completion, then the TaskPool would fail to clear up already
        # fetched tasks resulting in stale object refs being returned
        pool = TaskPool()
        for i in range(10):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)

        rayWaitMock.return_value = (list(range(10)), [])

        # Consume just the first item, then abort the iteration with a
        # simulated worker failure (as would be raised by ray.get()).
        # assertRaises also proves the loop body actually ran.
        with self.assertRaises(ray.exceptions.RayError):
            for _ in pool.completed_prefetch():
                raise ray.exceptions.RayError

        # This fetch should return the remaining pre-fetched tasks
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, list(range(1, 10)))

    @patch("ray.wait")
    def test_reset_workers_pendingFetchesFromFailedWorkersRemoved(
            self, rayWaitMock):
        """Prefetched tasks from workers dropped by reset_workers vanish."""
        pool = TaskPool()
        # We need to hold onto the tasks for this test so that we can fail a
        # specific worker
        tasks = []

        for i in range(10):
            task = createMockWorkerAndObjectRef(i)
            pool.add(*task)
            tasks.append(task)

        # Simulate only some of the work being complete and fetch a couple of
        # tasks in order to fill the fetching queue
        rayWaitMock.return_value = ([0, 1, 2, 3, 4, 5], [6, 7, 8, 9])
        fetched = [pair[1] for pair in pool.completed_prefetch(max_yield=2)]
        # Tasks 2-5 are now prefetched but not yet yielded
        self.assertListEqual(fetched, [0, 1])

        # As we still have some pending tasks, we need to update the
        # completion states to remove the completed tasks
        rayWaitMock.return_value = ([], [6, 7, 8, 9])

        pool.reset_workers([
            tasks[0][0],
            tasks[1][0],
            tasks[2][0],
            tasks[3][0],
            # OH NO! WORKER 4 HAS CRASHED!
            tasks[5][0],
            tasks[6][0],
            tasks[7][0],
            tasks[8][0],
            tasks[9][0]
        ])

        # Fetch the remaining tasks which should already be in the _fetching
        # queue; task 4 must be gone because its worker was not kept
        fetched = [pair[1] for pair in pool.completed_prefetch()]
        self.assertListEqual(fetched, [2, 3, 5])


if __name__ == "__main__":
    # Allow running this module directly: hand it to pytest in verbose mode
    # and exit with pytest's return code.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
