import contextlib
import tempfile
import unittest
from io import StringIO

import numpy as np

from tests.utils import create_dummy_data, preprocess_lm_data, train_language_model

# pyarrow's plasma module (and fairseq's wrappers around it) are optional
# dependencies; record whether they imported so the test class below can be
# skipped instead of erroring when they are missing.
try:
    from pyarrow import plasma

    from fairseq.data.plasma_utils import PlasmaStore, PlasmaView

    PYARROW_AVAILABLE = True
except ImportError:
    PYARROW_AVAILABLE = False

# Placeholder path argument passed to every PlasmaView in these tests. Its value
# is arbitrary; the tests rely on reusing the same string when they expect two
# views to resolve to the same object (presumably it feeds the object id — confirm
# against PlasmaView.get_object_id).
dummy_path = "dummy"


@unittest.skipUnless(PYARROW_AVAILABLE, "requires pyarrow with plasma support")
class TestPlasmaView(unittest.TestCase):
    """Tests for ``PlasmaView`` and ``PlasmaStore`` from fairseq's plasma utils.

    Every test gets a fresh plasma store process listening on a temporary path
    (``self.path``, 10000 bytes of capacity) and a client connected to it
    (``self.client``).
    """

    def setUp(self) -> None:
        # The temp file only provides a unique filesystem path for the store.
        self.tmp_file = tempfile.NamedTemporaryFile()  # noqa: P201
        self.path = self.tmp_file.name
        self.server = PlasmaStore.start(path=self.path, nbytes=10000)
        # Cleanups run LIFO, still run if a later setUp step fails, and keep
        # going if one of them raises, so the store process is never leaked.
        # Registration order gives: disconnect client -> close temp file ->
        # kill server.
        self.addCleanup(self.server.kill)
        self.addCleanup(self.tmp_file.close)
        self.client = plasma.connect(self.path, num_retries=10)
        self.addCleanup(self.client.disconnect)

    def test_two_servers_do_not_share_object_id_space(self):
        """The same (path, hash key) on two different stores yields separate objects."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        server_2_path = self.path
        with tempfile.NamedTemporaryFile() as server_1_path:
            server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
            try:
                arr1 = PlasmaView(
                    data_server_1, dummy_path, 1, plasma_path=server_1_path.name
                )
                assert len(arr1.client.list()) == 1
                assert (arr1.array == data_server_1).all()
                # Identical path and hash key, but a different store: the view
                # must hold server 2's data, not server 1's.
                arr2 = PlasmaView(
                    data_server_2, dummy_path, 1, plasma_path=server_2_path
                )
                assert (arr2.array == data_server_2).all()
                assert (arr1.array == data_server_1).all()
            finally:
                # Kill the extra store even if an assertion above failed.
                server.kill()

    def test_hash_collision(self):
        """Reusing a (path, hash key) pair returns the already-stored array."""
        data_server_1 = np.array([0, 1])
        data_server_2 = np.array([2, 3])
        arr1 = PlasmaView(data_server_1, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        # Same key, different data: no new object is created and the view
        # exposes the data stored under that key first.
        arr2 = PlasmaView(data_server_2, dummy_path, 1, plasma_path=self.path)
        assert len(arr1.client.list()) == 1
        assert len(arr2.client.list()) == 1
        assert (arr2.array == data_server_1).all()
        # New hash key based on tuples
        arr3 = PlasmaView(
            data_server_2, dummy_path, (1, 12312312312, None), plasma_path=self.path
        )
        assert (
            len(arr2.client.list()) == 2
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr2.client.list()
        ), "No new object was created by using a novel hash key"
        assert (
            arr3.object_id in arr3.client.list()
        ), "No new object was created by using a novel hash key"
        del arr3, arr2, arr1

    @staticmethod
    def _assert_view_equal(pv1, pv2):
        """Assert that two PlasmaViews expose element-wise equal arrays."""
        np.testing.assert_array_equal(pv1.array, pv2.array)

    def test_putting_same_array_twice(self):
        """Re-putting an existing key is a no-op; a new key adds an entry."""
        data = np.array([4, 4, 4])
        arr1 = PlasmaView(data, dummy_path, 1, plasma_path=self.path)
        assert len(self.client.list()) == 1
        arr1b = PlasmaView(
            data, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store
        arr1c = PlasmaView(
            None, dummy_path, 1, plasma_path=self.path
        )  # should not change contents of store

        assert len(self.client.list()) == 1
        self._assert_view_equal(arr1, arr1b)
        self._assert_view_equal(arr1, arr1c)
        PlasmaView(
            data, dummy_path, 2, plasma_path=self.path
        )  # new object id, adds new entry
        assert len(self.client.list()) == 2

        new_client = plasma.connect(self.path)
        # Registered after setUp's cleanups, so it runs before the store is killed.
        self.addCleanup(new_client.disconnect)
        assert len(new_client.list()) == 2  # new client can access same objects
        assert isinstance(arr1.object_id, plasma.ObjectID)
        del arr1b
        del arr1c

    def test_plasma_store_full_raises(self):
        """Putting an array larger than the store's capacity raises PlasmaStoreFull."""
        with tempfile.NamedTemporaryFile() as new_path:
            server = PlasmaStore.start(path=new_path.name, nbytes=10000)
            try:
                with self.assertRaises(plasma.PlasmaStoreFull):
                    # 10000 float64 values (80000 bytes) exceed the 10000-byte store
                    PlasmaView(
                        np.random.rand(10000, 1),
                        dummy_path,
                        1,
                        plasma_path=new_path.name,
                    )
            finally:
                # Kill the extra store even if the expected error was not raised.
                server.kill()

    def test_object_id_overflow(self):
        """A large integer hash key must not raise when building an object id."""
        PlasmaView.get_object_id("", 2**21)

    def test_training_lm_plasma(self):
        """End-to-end: train a small transformer_lm using --use-plasma-view."""
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_language_model(
                    data_dir,
                    "transformer_lm",
                    ["--use-plasma-view", "--plasma-path", self.path],
                    run_validation=True,
                )
