# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `data_utils.py`."""

import datetime
import numpy as np
import xarray as xa

from absl.testing import absltest, parameterized

from . import data_utils


class DataUtilsTest(parameterized.TestCase):
    """Tests for the progress features and derived variables in `data_utils`."""

    def setUp(self):
        super().setUp()
        # Fix the seed so the randomly generated inputs are reproducible.
        np.random.seed(0)

    def test_year_progress_is_zero_at_year_start_or_end(self):
        """Year progress is exactly 0 at whole multiples of AVG_SEC_PER_YEAR."""
        year_progress = data_utils.get_year_progress(
            np.array([
                0,
                data_utils.AVG_SEC_PER_YEAR,
                data_utils.AVG_SEC_PER_YEAR * 42,  # 42 years.
            ])
        )
        np.testing.assert_array_equal(year_progress, np.zeros_like(year_progress))

    def test_year_progress_is_almost_one_before_year_ends(self):
        """Year progress is just below 1 shortly before a year boundary."""
        year_progress = data_utils.get_year_progress(
            np.array([
                data_utils.AVG_SEC_PER_YEAR - 1,
                (data_utils.AVG_SEC_PER_YEAR - 1) * 42,  # ~42 years
            ])
        )
        with self.subTest("Year progress values are close to 1"):
            self.assertTrue(np.all(year_progress > 0.999))
        with self.subTest("Year progress values != 1"):
            self.assertTrue(np.all(year_progress < 1.0))

    def test_day_progress_computes_for_all_times_and_longitudes(self):
        """Day progress has one value per (time, longitude) pair."""
        # Pin int64 explicitly: 1e10 does not fit in int32, which is the default
        # integer dtype on some platforms and would make randint raise.
        times = np.random.randint(low=0, high=10**10, size=10, dtype=np.int64)
        longitudes = np.arange(0, 360.0, 1.0)
        day_progress = data_utils.get_day_progress(times, longitudes)
        with self.subTest("Day progress is computed for all times and longitudes"):
            self.assertSequenceEqual(day_progress.shape, (len(times), len(longitudes)))

    @parameterized.named_parameters(
        dict(
            testcase_name="random_date_1",
            year=1988,
            month=11,
            day=7,
            hour=2,
            minute=45,
            second=34,
        ),
        dict(
            testcase_name="random_date_2",
            year=2022,
            month=3,
            day=12,
            hour=7,
            minute=1,
            second=0,
        ),
    )
    def test_day_progress_is_in_between_zero_and_one(self, year, month, day, hour, minute, second):
        """Day progress lies in [0, 1) for every longitude at a given instant."""
        # Datetime from the parameterized calendar components.
        dt = datetime.datetime(year, month, day, hour, minute, second)
        # Unix epoch.
        epoch_time = datetime.datetime(1970, 1, 1)
        # Seconds since epoch, as a length-1 array.
        seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])

        # Longitudes with 1 degree resolution.
        longitudes = np.arange(0, 360.0, 1.0)

        day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
        with self.subTest("Day progress >= 0"):
            self.assertTrue(np.all(day_progress >= 0.0))
        with self.subTest("Day progress < 1"):
            self.assertTrue(np.all(day_progress < 1.0))

    def test_day_progress_is_zero_at_day_start_or_end(self):
        """Day progress at longitude 0 is exactly 0 at whole multiples of a day."""
        day_progress = data_utils.get_day_progress(
            seconds_since_epoch=np.array([
                0,
                data_utils.SEC_PER_DAY,
                data_utils.SEC_PER_DAY * 42,  # 42 days.
            ]),
            longitude=np.array([0.0]),
        )
        np.testing.assert_array_equal(day_progress, np.zeros_like(day_progress))

    def test_day_progress_specific_value(self):
        """123 s after the epoch at longitude 0 is 123 / 86400 of a day."""
        day_progress = data_utils.get_day_progress(
            seconds_since_epoch=np.array([123]),
            longitude=np.array([0.0]),
        )
        np.testing.assert_array_almost_equal(day_progress, np.array([[0.00142361]]), decimal=6)

    def test_featurize_progress_valid_values_and_dimensions(self):
        """featurize_progress returns raw, sin and cos features with the given dims."""
        day_progress = np.array([0.0, 0.45, 0.213])
        feature_dimensions = ("time",)
        progress_features = data_utils.featurize_progress(
            name="day_progress", dims=feature_dimensions, progress=day_progress
        )
        for feature in progress_features.values():
            with self.subTest(f"Valid dimensions for {feature}"):
                self.assertSequenceEqual(feature.dims, feature_dimensions)

        with self.subTest("Valid values for day_progress"):
            np.testing.assert_array_equal(day_progress, progress_features["day_progress"].values)

        # Expected values are sin/cos(2 * pi * progress).
        with self.subTest("Valid values for day_progress_sin"):
            np.testing.assert_array_almost_equal(
                np.array([0.0, 0.30901699, 0.97309851]),
                progress_features["day_progress_sin"].values,
                decimal=6,
            )

        with self.subTest("Valid values for day_progress_cos"):
            np.testing.assert_array_almost_equal(
                np.array([1.0, -0.95105652, 0.23038943]),
                progress_features["day_progress_cos"].values,
                decimal=6,
            )

    def test_featurize_progress_invalid_dimensions(self):
        """A 1-D progress array with two declared dims raises ValueError."""
        year_progress = np.array([0.0, 0.45, 0.213])
        feature_dimensions = ("time", "longitude")
        with self.assertRaises(ValueError):
            data_utils.featurize_progress(
                name="year_progress", dims=feature_dimensions, progress=year_progress
            )

    def test_add_derived_vars_variables_added(self):
        """add_derived_vars adds year/day progress in place and keeps existing vars."""
        data = xa.Dataset(
            data_vars={"var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))},
            coords={
                "lon": np.array([0.0, 0.5]),
                "datetime": np.array([
                    datetime.datetime(2021, 1, 1),
                    datetime.datetime(2023, 1, 1),
                    datetime.datetime(2023, 1, 3),
                ]),
            },
        )
        data_utils.add_derived_vars(data)
        all_variables = set(data.variables)

        with self.subTest("Original value was not removed"):
            self.assertIn("var1", all_variables)
        with self.subTest("Year progress feature was added"):
            self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
        with self.subTest("Day progress feature was added"):
            self.assertIn(data_utils.DAY_PROGRESS, all_variables)

    def test_add_derived_vars_existing_vars_not_overridden(self):
        """Pre-existing progress variables are left untouched by add_derived_vars."""
        dims = ["x", "lon", "datetime"]
        data = xa.Dataset(
            data_vars={
                "var1": (dims, 8 * np.random.randn(2, 2, 3)),
                data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)),
                data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)),
            },
            coords={
                "lon": np.array([0.0, 0.5]),
                "datetime": np.array([
                    datetime.datetime(2021, 1, 1),
                    datetime.datetime(2023, 1, 1),
                    datetime.datetime(2023, 1, 3),
                ]),
            },
        )

        data_utils.add_derived_vars(data)

        with self.subTest("Year progress feature was not overridden"):
            np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111)
        with self.subTest("Day progress feature was not overridden"):
            np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222)

    @parameterized.named_parameters(
        dict(testcase_name="missing_datetime", coord_name="lon"),
        dict(testcase_name="missing_lon", coord_name="datetime"),
    )
    def test_add_derived_vars_missing_coordinate_raises_value_error(self, coord_name):
        """add_derived_vars raises ValueError when `lon` or `datetime` is absent.

        Args:
          coord_name: The only coordinate present; the other one is missing.
        """
        with self.subTest(f"Missing {coord_name} coordinate"):
            data = xa.Dataset(
                data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
                coords={
                    coord_name: np.array([0.0, 0.5]),
                },
            )
            with self.assertRaises(ValueError):
                data_utils.add_derived_vars(data)

    def test_add_tisr_var_variable_added(self):
        """add_tisr_var adds the TISR variable in place."""
        data = xa.Dataset(
            data_vars={"var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0))},
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable("time", np.array([10, 20], dtype="datetime64[D]")),
            },
        )

        data_utils.add_tisr_var(data)

        self.assertIn(data_utils.TISR, set(data.variables))

    def test_add_tisr_var_existing_var_not_overridden(self):
        """A pre-existing TISR variable is left untouched by add_tisr_var."""
        dims = ["time", "lat", "lon"]
        data = xa.Dataset(
            data_vars={
                "var1": (dims, np.full((2, 2, 2), 8.0)),
                data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)),
            },
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable("time", np.array([10, 20], dtype="datetime64[D]")),
            },
        )

        # Call the function under test. The previous call to add_derived_vars
        # never touched TISR, so the test passed vacuously.
        data_utils.add_tisr_var(data)

        np.testing.assert_allclose(data[data_utils.TISR], 1200.0)

    def test_add_tisr_var_works_with_batch_dim_size_one(self):
        """add_tisr_var accepts a `datetime` coordinate with a size-1 batch dim."""
        data = xa.Dataset(
            data_vars={
                "var1": (
                    ["batch", "time", "lat", "lon"],
                    np.full((1, 2, 2, 2), 8.0),
                )
            },
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable(
                    ("batch", "time"), np.array([[10, 20]], dtype="datetime64[D]")
                ),
            },
        )

        data_utils.add_tisr_var(data)

        self.assertIn(data_utils.TISR, set(data.variables))

    def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self):
        """add_tisr_var raises when the batch dim of `datetime` is larger than 1."""
        data = xa.Dataset(
            data_vars={
                "var1": (
                    ["batch", "time", "lat", "lon"],
                    np.full((2, 2, 2, 2), 8.0),
                )
            },
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable(
                    ("batch", "time"),
                    np.array([[10, 20], [100, 200]], dtype="datetime64[D]"),
                ),
            },
        )

        # Matches xarray's error when squeezing a dimension whose size is > 1.
        with self.assertRaisesRegex(ValueError, r"cannot select a dimension"):
            data_utils.add_tisr_var(data)


if __name__ == "__main__":
    # Executed only when this file is run directly (not on import); hand
    # control to absl's test runner, which discovers the tests in this module.
    absltest.main()
