# Copyright (c) Facebook, Inc. and its affiliates.

import os
import pickle
import sys
import unittest
from functools import partial
import torch
from iopath.common.file_io import LazyPath

from detectron2 import model_zoo
from detectron2.config import get_cfg, instantiate
from detectron2.data import (
    DatasetCatalog,
    DatasetFromList,
    MapDataset,
    ToIterableDataset,
    build_batch_data_loader,
    build_detection_test_loader,
    build_detection_train_loader,
)
from detectron2.data.common import (
    AspectRatioGroupedDataset,
    set_default_dataset_from_list_serialize_method,
)
from detectron2.data.samplers import InferenceSampler, TrainingSampler


def _a_slow_func(x):
    return "path/{}".format(x)


class TestDatasetFromList(unittest.TestCase):
    """Tests for DatasetFromList element access and its pluggable serialize method."""

    # Failing for py3.6, likely due to pickle
    # Compare the full version tuple rather than only ``.minor`` so the guard
    # cannot misfire on a different major version.
    @unittest.skipIf(sys.version_info < (3, 7), "Not supported in Python 3.6")
    def test_using_lazy_path(self):
        """LazyPath values stored in a DatasetFromList come back as LazyPath objects
        that resolve (via ``os.fspath``) to the path the wrapped function produces."""
        dataset = DatasetFromList(
            [{"file_name": LazyPath(partial(_a_slow_func, i))} for i in range(10)]
        )
        for i in range(10):
            path = dataset[i]["file_name"]
            self.assertIsInstance(path, LazyPath)
            self.assertEqual(os.fspath(path), _a_slow_func(i))

    def test_alternative_serialize_method(self):
        """A callable passed as ``serialize`` is used to transform stored elements."""
        dataset = DatasetFromList([1, 2, 3], serialize=torch.tensor)
        self.assertEqual(dataset[2], torch.tensor(3))

    def test_change_default_serialize_method(self):
        """``set_default_dataset_from_list_serialize_method`` only affects datasets
        built inside its ``with`` block; afterwards ``serialize=True`` yields the
        plain element again."""
        dataset = [1, 2, 3]
        with set_default_dataset_from_list_serialize_method(torch.tensor):
            dataset_1 = DatasetFromList(dataset, serialize=True)
            self.assertEqual(dataset_1[2], torch.tensor(3))
        dataset_2 = DatasetFromList(dataset, serialize=True)
        self.assertEqual(dataset_2[2], 3)


class TestMapDataset(unittest.TestCase):
    @staticmethod
    def map_func(x):
        if x == 2:
            return None
        return x * 2

    def test_map_style(self):
        ds = DatasetFromList([1, 2, 3])
        ds = MapDataset(ds, TestMapDataset.map_func)
        self.assertEqual(ds[0], 2)
        self.assertEqual(ds[2], 6)
        self.assertIn(ds[1], [2, 6])

    def test_iter_style(self):
        class DS(torch.utils.data.IterableDataset):
            def __iter__(self):
                yield from [1, 2, 3]

        ds = DS()
        ds = MapDataset(ds, TestMapDataset.map_func)
        self.assertIsInstance(ds, torch.utils.data.IterableDataset)

        data = list(iter(ds))
        self.assertEqual(data, [2, 6])

    def test_pickleability(self):
        ds = DatasetFromList([1, 2, 3])
        ds = MapDataset(ds, lambda x: x * 2)
        ds = pickle.loads(pickle.dumps(ds))
        self.assertEqual(ds[0], 2)


class TestAspectRatioGrouping(unittest.TestCase):
    """Tests for AspectRatioGroupedDataset bucket bookkeeping."""

    def test_reiter_leak(self):
        """After repeatedly abandoning iteration early, every internal bucket
        must still hold fewer items than one batch."""
        batch_size = 2
        sizes = [(1, 0), (0, 1), (1, 0), (0, 1)]
        records = [{"width": w, "height": h} for w, h in sizes]
        grouped = AspectRatioGroupedDataset(records, batch_size)

        for _ in range(5):
            # Stop after the second batch so the iterator never finishes by itself.
            for step, _batch in enumerate(grouped):
                if step == 1:
                    break
            # Leftover items must not accumulate into an oversized bucket.
            for bucket in grouped._buckets:
                self.assertLess(len(bucket), batch_size)


class _MyData(torch.utils.data.IterableDataset):
    def __iter__(self):
        while True:
            yield 1


class TestDataLoader(unittest.TestCase):
    def _get_kwargs(self):
        # get kwargs of build_detection_train_loader
        cfg = model_zoo.get_config("common/data/coco.py").dataloader.train
        cfg.dataset.names = "coco_2017_val_100"
        cfg.pop("_target_")
        kwargs = {k: instantiate(v) for k, v in cfg.items()}
        return kwargs

    def test_build_dataloader_train(self):
        kwargs = self._get_kwargs()
        dl = build_detection_train_loader(**kwargs)
        next(iter(dl))

    def test_build_iterable_dataloader_train(self):
        kwargs = self._get_kwargs()
        ds = DatasetFromList(kwargs.pop("dataset"))
        ds = ToIterableDataset(ds, TrainingSampler(len(ds)))
        dl = build_detection_train_loader(dataset=ds, **kwargs)
        next(iter(dl))

    def test_build_iterable_dataloader_from_cfg(self):
        cfg = get_cfg()
        cfg.DATASETS.TRAIN = ["iter_data"]
        DatasetCatalog.register("iter_data", lambda: _MyData())
        dl = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False)
        next(iter(dl))

        dl = build_detection_test_loader(cfg, "iter_data", mapper=lambda x: x)
        next(iter(dl))

    def _check_is_range(self, data_loader, N):
        # check that data_loader produces range(N)
        data = list(iter(data_loader))
        data = [x for batch in data for x in batch]  # flatten the batches
        self.assertEqual(len(data), N)
        self.assertEqual(set(data), set(range(N)))

    def test_build_batch_dataloader_inference(self):
        # Test that build_batch_data_loader can be used for inference
        N = 96
        ds = DatasetFromList(list(range(N)))
        sampler = InferenceSampler(len(ds))
        dl = build_batch_data_loader(ds, sampler, 8, num_workers=3)
        self._check_is_range(dl, N)

    def test_build_batch_dataloader_inference_incomplete_batch(self):
        # Test that build_batch_data_loader works when dataset size is not multiple of
        # batch size or num_workers
        def _test(N, batch_size, num_workers):
            ds = DatasetFromList(list(range(N)))
            sampler = InferenceSampler(len(ds))

            dl = build_batch_data_loader(ds, sampler, batch_size, num_workers=num_workers)
            data = list(iter(dl))
            self.assertEqual(len(data), len(dl))  # floor(N / batch_size)
            self._check_is_range(dl, N // batch_size * batch_size)

            dl = build_batch_data_loader(
                ds, sampler, batch_size, num_workers=num_workers, drop_last=False
            )
            data = list(iter(dl))
            self.assertEqual(len(data), len(dl))  # ceil(N / batch_size)
            self._check_is_range(dl, N)

        _test(48, batch_size=8, num_workers=3)
        _test(47, batch_size=8, num_workers=3)
        _test(46, batch_size=8, num_workers=3)
        _test(40, batch_size=8, num_workers=3)
        _test(39, batch_size=8, num_workers=3)

    def test_build_dataloader_inference(self):
        N = 50
        ds = DatasetFromList(list(range(N)))
        sampler = InferenceSampler(len(ds))
        # test that parallel loader works correctly
        dl = build_detection_test_loader(
            dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3
        )
        self._check_is_range(dl, N)

        # test that batch_size works correctly
        dl = build_detection_test_loader(
            dataset=ds, sampler=sampler, mapper=lambda x: x, batch_size=4, num_workers=0
        )
        self._check_is_range(dl, N)

    def test_build_iterable_dataloader_inference(self):
        # Test that build_detection_test_loader supports iterable dataset
        N = 50
        ds = DatasetFromList(list(range(N)))
        ds = ToIterableDataset(ds, InferenceSampler(len(ds)))
        dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3)
        self._check_is_range(dl, N)
