# Copyright (c) Facebook, Inc. and its affiliates.
import os
import tempfile
import unittest
from collections import OrderedDict
import torch
from iopath.common.file_io import PathHandler, PathManager
from torch import nn

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.checkpoint.c2_model_loading import (
    _longest_common_prefix_str,
    align_and_update_state_dicts,
)
from detectron2.utils.logger import setup_logger


class TestCheckpointer(unittest.TestCase):
    """Tests for loading checkpoints whose keys only partially match a model's."""

    def setUp(self):
        setup_logger()

    def create_complex_model(self):
        """Build a small nested model plus a checkpoint-style state dict for it.

        The model's parameters are ``block1.layer1.*``, ``layer2.*`` and
        ``res.layer2.*``. The returned state dict names the first layer
        ``layer1.*`` instead, so that entry has no exact counterpart in the
        model, while ``layer2.*`` and ``res.layer2.*`` match exactly. Entries
        are listed in the same order as the model's own ``state_dict()``.

        Returns:
            tuple: ``(model, state_dict)`` where ``state_dict`` is an
            ``OrderedDict`` of freshly sampled ``torch.rand`` tensors.
        """
        m = nn.Module()
        m.block1 = nn.Module()
        m.block1.layer1 = nn.Linear(2, 3)
        m.layer2 = nn.Linear(3, 2)
        m.res = nn.Module()
        m.res.layer2 = nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)
        return m, state_dict

    def test_complex_model_loaded(self):
        """``align_and_update_state_dicts`` maps every checkpoint tensor onto the
        corresponding model parameter, with and without a ``DataParallel`` wrapper."""
        for add_data_parallel in [False, True]:
            with self.subTest(add_data_parallel=add_data_parallel):
                model, state_dict = self.create_complex_model()
                if add_data_parallel:
                    model = nn.DataParallel(model)
                model_sd = model.state_dict()

                sd_to_load = align_and_update_state_dicts(model_sd, state_dict)
                model.load_state_dict(sd_to_load)
                # The loop below pairs entries by position, which relies on both
                # dicts listing the layers in the same order. zip() would silently
                # drop unmatched trailing entries, so compare the sizes first.
                self.assertEqual(len(model_sd), len(state_dict))
                # model_sd's tensors are read after load_state_dict, so they are
                # expected to reflect the weights that were just loaded.
                for loaded, stored in zip(model_sd.values(), state_dict.values()):
                    # different tensor objects ...
                    self.assertIsNot(loaded, stored)
                    # ... holding the same content
                    self.assertTrue(loaded.to(stored).equal(stored))

    def test_load_with_matching_heuristics(self):
        """``?matching_heuristics=True`` lets ``layer1.*`` load into ``block1.layer1.*``."""
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:
            model, state_dict = self.create_complex_model()
            path = os.path.join(d, "checkpoint.pth")
            torch.save({"model": state_dict}, path)
            checkpointer = DetectionCheckpointer(model, save_dir=d)

            with torch.no_grad():
                # torch.rand samples from [0, 1), so an all-ones weight can never
                # equal the stored one; this makes a (non-)load observable.
                model.block1.layer1.weight.fill_(1)

            # Without matching heuristics, "layer1.weight" has no exact
            # counterpart, so block1.layer1 must keep its current weights.
            checkpointer.load(path)
            self.assertTrue(model.block1.layer1.weight.equal(torch.ones(3, 2)))

            # With matching heuristics it is filled from the checkpoint's "layer1.weight".
            checkpointer.load(path + "?matching_heuristics=True")
            self.assertFalse(model.block1.layer1.weight.equal(torch.ones(3, 2)))
            self.assertTrue(model.block1.layer1.weight.equal(state_dict["layer1.weight"]))

    def test_custom_path_manager_handler(self):
        """Checkpoints can be read through a custom ``PathManager`` handler."""
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as d:

            class CustomPathManagerHandler(PathHandler):
                # Maps "detectron2_test://<name>" to "<d>/<name>".
                PREFIX = "detectron2_test://"

                def _get_supported_prefixes(self):
                    return [self.PREFIX]

                def _get_local_path(self, path, **kwargs):
                    name = path[len(self.PREFIX) :]
                    return os.path.join(d, name)

                def _open(self, path, mode="r", **kwargs):
                    return open(self._get_local_path(path), mode, **kwargs)

            pathmgr = PathManager()
            pathmgr.register_handler(CustomPathManagerHandler())

            model, state_dict = self.create_complex_model()
            torch.save({"model": state_dict}, os.path.join(d, "checkpoint.pth"))
            checkpointer = DetectionCheckpointer(model, save_dir=d)
            checkpointer.path_manager = pathmgr

            checkpointer.load("detectron2_test://checkpoint.pth")
            # Exactly-matching keys must have been read through the custom handler.
            self.assertTrue(model.layer2.weight.equal(state_dict["layer2.weight"]))

            checkpointer.load("detectron2_test://checkpoint.pth?matching_heuristics=True")
            self.assertTrue(model.block1.layer1.weight.equal(state_dict["layer1.weight"]))

    def test_lcp(self):
        """``_longest_common_prefix_str`` returns the prefix shared by all names."""
        self.assertEqual(_longest_common_prefix_str(["class", "dlaps_model"]), "")
        self.assertEqual(_longest_common_prefix_str(["classA", "classB"]), "class")
        self.assertEqual(_longest_common_prefix_str(["classA", "classB", "clab"]), "cla")


if __name__ == "__main__":
    # Run every TestCase in this file when it is executed directly as a script.
    unittest.main(module="__main__")
