# Copyright (c) Facebook, Inc. and its affiliates.
import os
import unittest
import tempfile
from itertools import count

from detectron2.config import LazyConfig, LazyCall as L
from omegaconf import DictConfig


class TestLazyPythonConfig(unittest.TestCase):
    def setUp(self):
        self.curr_dir = os.path.dirname(__file__)
        self.root_filename = os.path.join(self.curr_dir, "root_cfg.py")

    def test_load(self):
        cfg = LazyConfig.load(self.root_filename)

        self.assertEqual(cfg.dir1a_dict.a, "modified")
        self.assertEqual(cfg.dir1b_dict.a, 1)
        self.assertEqual(cfg.lazyobj.x, "base_a_1")

        cfg.lazyobj.x = "new_x"
        # reload
        cfg = LazyConfig.load(self.root_filename)
        self.assertEqual(cfg.lazyobj.x, "base_a_1")

    def test_save_load(self):
        cfg = LazyConfig.load(self.root_filename)
        with tempfile.TemporaryDirectory(prefix="detectron2") as d:
            fname = os.path.join(d, "test_config.yaml")
            LazyConfig.save(cfg, fname)
            cfg2 = LazyConfig.load(fname)

        self.assertEqual(cfg2.lazyobj._target_, "itertools.count")
        self.assertEqual(cfg.lazyobj._target_, count)
        cfg2.lazyobj.pop("_target_")
        cfg.lazyobj.pop("_target_")
        # the rest are equal
        self.assertEqual(cfg, cfg2)

    def test_failed_save(self):
        cfg = DictConfig({"x": lambda: 3}, flags={"allow_objects": True})
        with tempfile.TemporaryDirectory(prefix="detectron2") as d:
            fname = os.path.join(d, "test_config.yaml")
            LazyConfig.save(cfg, fname)
            self.assertTrue(os.path.exists(fname))
            self.assertTrue(os.path.exists(fname + ".pkl"))

    def test_overrides(self):
        cfg = LazyConfig.load(self.root_filename)
        LazyConfig.apply_overrides(cfg, ["lazyobj.x=123", 'dir1b_dict.a="123"'])
        self.assertEqual(cfg.dir1b_dict.a, "123")
        self.assertEqual(cfg.lazyobj.x, 123)

        LazyConfig.apply_overrides(cfg, ["dir1b_dict.a=abc"])
        self.assertEqual(cfg.dir1b_dict.a, "abc")

    def test_invalid_overrides(self):
        cfg = LazyConfig.load(self.root_filename)
        with self.assertRaises(KeyError):
            LazyConfig.apply_overrides(cfg, ["lazyobj.x.xxx=123"])

    def test_to_py(self):
        cfg = LazyConfig.load(self.root_filename)
        cfg.lazyobj.x = {"a": 1, "b": 2, "c": L(count)(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]})}
        cfg.list = ["a", 1, "b", 3.2]
        py_str = LazyConfig.to_py(cfg)
        expected = """cfg.dir1a_dict.a = "modified"
cfg.dir1a_dict.b = 2
cfg.dir1b_dict.a = 1
cfg.dir1b_dict.b = 2
cfg.lazyobj = itertools.count(
    x={
        "a": 1,
        "b": 2,
        "c": itertools.count(x={"r": "a", "s": 2.4, "t": [1, 2, 3, "z"]}),
    },
    y="base_a_1_from_b",
)
cfg.list = ["a", 1, "b", 3.2]
"""
        self.assertEqual(py_str, expected)

    def test_bad_import(self):
        file = os.path.join(self.curr_dir, "dir1", "bad_import.py")
        with self.assertRaisesRegex(ImportError, "relative import"):
            LazyConfig.load(file)

    def test_bad_import2(self):
        file = os.path.join(self.curr_dir, "dir1", "bad_import2.py")
        with self.assertRaisesRegex(ImportError, "not exist"):
            LazyConfig.load(file)

    def test_load_rel(self):
        file = os.path.join(self.curr_dir, "dir1", "load_rel.py")
        cfg = LazyConfig.load(file)
        self.assertIn("x", cfg)
