#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product

import numpy as np
import torch
from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError
from botorch.optim.parameter_constraints import (
    _arrayify,
    _generate_unfixed_lin_constraints,
    _make_linear_constraints,
    eval_lin_constraint,
    lin_constraint_jac,
    make_scipy_bounds,
    make_scipy_linear_constraints,
)
from botorch.utils.testing import BotorchTestCase
from scipy.optimize import Bounds


class TestParameterConstraints(BotorchTestCase):
    def test_arrayify(self):
        for dtype in (torch.float, torch.double, torch.int, torch.long):
            t = torch.tensor([[1, 2], [3, 4]], device=self.device).type(dtype)
            t_np = _arrayify(t)
            self.assertIsInstance(t_np, np.ndarray)
            self.assertTrue(t_np.dtype == np.float64)

    def test_eval_lin_constraint(self):
        res = eval_lin_constraint(
            flat_idxr=[0, 2],
            coeffs=np.array([1.0, -2.0]),
            rhs=0.5,
            x=np.array([1.0, 2.0, 3.0]),
        )
        self.assertEqual(res, -5.5)

    def test_lin_constraint_jac(self):
        dummy_array = np.array([1.0])
        res = lin_constraint_jac(
            dummy_array, flat_idxr=[0, 2], coeffs=np.array([1.0, -2.0]), n=3
        )
        self.assertTrue(all(np.equal(res, np.array([1.0, 0.0, -2.0]))))

    def test_make_linear_constraints(self):
        # equality constraints, 1d indices
        indices = torch.tensor([1, 2], dtype=torch.long, device=self.device)
        for dtype, shapeX in product(
            (torch.float, torch.double), (torch.Size([3, 2, 4]), torch.Size([2, 4]))
        ):
            coefficients = torch.tensor([1.0, 2.0], dtype=dtype, device=self.device)
            constraints = _make_linear_constraints(
                indices=indices,
                coefficients=coefficients,
                rhs=1.0,
                shapeX=shapeX,
                eq=True,
            )
            self.assertTrue(
                all(set(c.keys()) == {"fun", "jac", "type"} for c in constraints)
            )
            self.assertTrue(all(c["type"] == "eq" for c in constraints))
            self.assertEqual(len(constraints), shapeX[:-1].numel())
            x = np.random.rand(shapeX.numel())
            self.assertEqual(constraints[0]["fun"](x), x[1] + 2 * x[2] - 1.0)
            jac_exp = np.zeros(shapeX.numel())
            jac_exp[[1, 2]] = [1, 2]
            self.assertTrue(np.allclose(constraints[0]["jac"](x), jac_exp))
            self.assertEqual(constraints[-1]["fun"](x), x[-3] + 2 * x[-2] - 1.0)
            jac_exp = np.zeros(shapeX.numel())
            jac_exp[[-3, -2]] = [1, 2]
            self.assertTrue(np.allclose(constraints[-1]["jac"](x), jac_exp))

        # inequality constraints, 1d indices
        for shapeX in [torch.Size([1, 1, 2]), torch.Size([1, 2])]:
            lcs = _make_linear_constraints(
                indices=torch.tensor([1]),
                coefficients=torch.tensor([1.0]),
                rhs=1.0,
                shapeX=shapeX,
                eq=False,
            )
            self.assertEqual(len(lcs), 1)
            self.assertEqual(lcs[0]["type"], "ineq")

        # constraint across q-batch (2d indics), equality constraint
        indices = torch.tensor([[0, 3], [1, 2]], dtype=torch.long, device=self.device)

        for dtype, shapeX in product(
            (torch.float, torch.double), (torch.Size([3, 2, 4]), torch.Size([2, 4]))
        ):
            q, d = shapeX[-2:]
            b = 1 if len(shapeX) == 2 else shapeX[0]
            coefficients = torch.tensor([1.0, 2.0], dtype=dtype, device=self.device)
            constraints = _make_linear_constraints(
                indices=indices,
                coefficients=coefficients,
                rhs=1.0,
                shapeX=shapeX,
                eq=True,
            )
            self.assertTrue(
                all(set(c.keys()) == {"fun", "jac", "type"} for c in constraints)
            )
            self.assertTrue(all(c["type"] == "eq" for c in constraints))
            self.assertEqual(len(constraints), b)
            x = np.random.rand(shapeX.numel())
            offsets = [q * d, d]
            # rule is [i, j, k] is i * offset[0] + j * offset[1] + k
            for i in range(b):
                pos1 = i * offsets[0] + 3
                pos2 = i * offsets[0] + 1 * offsets[1] + 2
                self.assertEqual(constraints[i]["fun"](x), x[pos1] + 2 * x[pos2] - 1.0)
                jac_exp = np.zeros(shapeX.numel())
                jac_exp[[pos1, pos2]] = [1, 2]
                self.assertTrue(np.allclose(constraints[i]["jac"](x), jac_exp))
        # make sure error is raised for scalar tensors
        with self.assertRaises(ValueError):
            constraints = _make_linear_constraints(
                indices=torch.tensor(0),
                coefficients=torch.tensor([1.0]),
                rhs=1.0,
                shapeX=torch.Size([1, 1, 2]),
                eq=False,
            )
        # test that len(shapeX) < 2 raises an error
        with self.assertRaises(UnsupportedError):
            _make_linear_constraints(
                shapeX=torch.Size([2]),
                indices=indices,
                coefficients=coefficients,
                rhs=0.0,
            )

    def test_make_scipy_linear_constraints(self):
        for shapeX in [torch.Size([2, 1, 4]), torch.Size([1, 4])]:
            b = shapeX[0] if len(shapeX) == 3 else 1
            res = make_scipy_linear_constraints(
                shapeX=shapeX, inequality_constraints=None, equality_constraints=None
            )
            self.assertEqual(res, [])
            indices = torch.tensor([0, 1], dtype=torch.long, device=self.device)
            coefficients = torch.tensor([1.5, -1.0], device=self.device)
            # both inequality and equality constraints
            cs = make_scipy_linear_constraints(
                shapeX=shapeX,
                inequality_constraints=[(indices, coefficients, 1.0)],
                equality_constraints=[(indices, coefficients, 1.0)],
            )
            self.assertEqual(len(cs), 2 * b)
            self.assertTrue({c["type"] for c in cs} == {"ineq", "eq"})
            # inequality only
            cs = make_scipy_linear_constraints(
                shapeX=shapeX, inequality_constraints=[(indices, coefficients, 1.0)]
            )
            self.assertEqual(len(cs), b)
            self.assertTrue(all(c["type"] == "ineq" for c in cs))
            # equality only
            cs = make_scipy_linear_constraints(
                shapeX=shapeX, equality_constraints=[(indices, coefficients, 1.0)]
            )
            self.assertEqual(len(cs), b)
            self.assertTrue(all(c["type"] == "eq" for c in cs))

            # test that 2-dim indices work properly
            indices = indices.unsqueeze(0)
            cs = make_scipy_linear_constraints(
                shapeX=shapeX,
                inequality_constraints=[(indices, coefficients, 1.0)],
                equality_constraints=[(indices, coefficients, 1.0)],
            )
            self.assertEqual(len(cs), 2 * b)
            self.assertTrue({c["type"] for c in cs} == {"ineq", "eq"})

    def test_make_scipy_linear_constraints_unsupported(self):
        shapeX = torch.Size([2, 1, 4])
        coefficients = torch.tensor([1.5, -1.0], device=self.device)

        # test that >2-dim indices raises an UnsupportedError
        indices = torch.tensor([0, 1], dtype=torch.long, device=self.device)
        indices = indices.unsqueeze(0).unsqueeze(0)
        with self.assertRaises(UnsupportedError):
            make_scipy_linear_constraints(
                shapeX=shapeX,
                inequality_constraints=[(indices, coefficients, 1.0)],
                equality_constraints=[(indices, coefficients, 1.0)],
            )
        # test that out of bounds index raises an error
        indices = torch.tensor([0, 4], dtype=torch.long, device=self.device)
        with self.assertRaises(RuntimeError):
            make_scipy_linear_constraints(
                shapeX=shapeX,
                inequality_constraints=[(indices, coefficients, 1.0)],
                equality_constraints=[(indices, coefficients, 1.0)],
            )
        # test that two-d index out-of-bounds raises an error
        # q out of bounds
        indices = torch.tensor([[0, 0], [1, 0]], dtype=torch.long, device=self.device)
        with self.assertRaises(RuntimeError):
            make_scipy_linear_constraints(
                shapeX=shapeX,
                inequality_constraints=[(indices, coefficients, 1.0)],
                equality_constraints=[(indices, coefficients, 1.0)],
            )
        # d out of bounds
        indices = torch.tensor([[0, 0], [0, 4]], dtype=torch.long, device=self.device)
        with self.assertRaises(RuntimeError):
            make_scipy_linear_constraints(
                shapeX=shapeX,
                inequality_constraints=[(indices, coefficients, 1.0)],
                equality_constraints=[(indices, coefficients, 1.0)],
            )

    def test_generate_unfixed_lin_constraints(self):
        # Case 1: some fixed features are in the indices
        indices = [
            torch.arange(4, device=self.device),
            torch.arange(2, -1, -1, device=self.device),
        ]
        coefficients = [
            torch.tensor([-0.1, 0.2, -0.3, 0.4], device=self.device),
            torch.tensor([-0.1, 0.3, -0.5], device=self.device),
        ]
        rhs = [0.5, 0.5]
        dimension = 4
        fixed_features = {1: 1, 3: 2}
        new_constraints = _generate_unfixed_lin_constraints(
            constraints=list(zip(indices, coefficients, rhs)),
            fixed_features=fixed_features,
            dimension=dimension,
            eq=False,
        )
        for i, (new_indices, new_coefficients, new_rhs) in enumerate(new_constraints):
            if i % 2 == 0:  # first list of indices is [0, 1, 2, 3]
                self.assertTrue(
                    torch.equal(new_indices, torch.arange(2, device=self.device))
                )
            else:  # second list of indices is [2, 1, 0]
                self.assertTrue(
                    torch.equal(
                        new_indices, torch.arange(1, -1, -1, device=self.device)
                    )
                )
            mask = [True] * indices[i].shape[0]
            subtract = 0
            for j, old_idx in enumerate(indices[i]):
                if old_idx.item() in fixed_features:
                    mask[j] = False
                    subtract += fixed_features[old_idx.item()] * coefficients[i][j]
            self.assertTrue(torch.equal(new_coefficients, coefficients[i][mask]))
            self.assertEqual(new_rhs, rhs[i] - subtract)

        # Case 2: none of fixed features are in the indices, but have to be renumbered
        indices = [
            torch.arange(2, 6, device=self.device),
            torch.arange(5, 2, -1, device=self.device),
        ]
        fixed_features = {0: -10, 1: 10}
        dimension = 6
        new_constraints = _generate_unfixed_lin_constraints(
            constraints=list(zip(indices, coefficients, rhs)),
            fixed_features=fixed_features,
            dimension=dimension,
            eq=False,
        )
        for i, (new_indices, new_coefficients, new_rhs) in enumerate(new_constraints):
            if i % 2 == 0:  # first list of indices is [2, 3, 4, 5]
                self.assertTrue(
                    torch.equal(new_indices, torch.arange(4, device=self.device))
                )
            else:  # second list of indices is [5, 4, 3]
                self.assertTrue(
                    torch.equal(new_indices, torch.arange(3, 0, -1, device=self.device))
                )

            self.assertTrue(torch.equal(new_coefficients, coefficients[i]))
            self.assertEqual(new_rhs, rhs[i])

        # Case 3: all fixed features are in the indices
        indices = [
            torch.arange(4, device=self.device),
            torch.arange(2, -1, -1, device=self.device),
        ]
        # Case 3a: problem is feasible
        dimension = 4
        fixed_features = {0: 2, 1: 1, 2: 1, 3: 2}
        for eq in [False, True]:
            new_constraints = _generate_unfixed_lin_constraints(
                constraints=[(indices[0], coefficients[0], rhs[0])],
                fixed_features=fixed_features,
                dimension=dimension,
                eq=eq,
            )
            self.assertEqual(new_constraints, [])
        # Case 3b: problem is infeasible
        for eq in [False, True]:
            prefix = "Ineq" if not eq else "Eq"
            with self.assertRaisesRegex(CandidateGenerationError, prefix):
                new_constraints = _generate_unfixed_lin_constraints(
                    constraints=[(indices[1], coefficients[1], rhs[1])],
                    fixed_features=fixed_features,
                    dimension=dimension,
                    eq=eq,
                )


class TestMakeScipyBounds(BotorchTestCase):
    def test_make_scipy_bounds(self):
        X = torch.zeros(3, 1, 2)
        # both None
        self.assertIsNone(make_scipy_bounds(X=X, lower_bounds=None, upper_bounds=None))
        # lower None
        upper_bounds = torch.ones(2)
        bounds = make_scipy_bounds(X=X, lower_bounds=None, upper_bounds=upper_bounds)
        self.assertIsInstance(bounds, Bounds)
        self.assertTrue(
            np.all(np.equal(bounds.lb, np.full((3, 1, 2), float("-inf")).flatten()))
        )
        self.assertTrue(np.all(np.equal(bounds.ub, np.ones((3, 1, 2)).flatten())))
        # upper None
        lower_bounds = torch.zeros(2)
        bounds = make_scipy_bounds(X=X, lower_bounds=lower_bounds, upper_bounds=None)
        self.assertIsInstance(bounds, Bounds)
        self.assertTrue(np.all(np.equal(bounds.lb, np.zeros((3, 1, 2)).flatten())))
        self.assertTrue(
            np.all(np.equal(bounds.ub, np.full((3, 1, 2), float("inf")).flatten()))
        )
        # floats
        bounds = make_scipy_bounds(X=X, lower_bounds=0.0, upper_bounds=1.0)
        self.assertIsInstance(bounds, Bounds)
        self.assertTrue(np.all(np.equal(bounds.lb, np.zeros((3, 1, 2)).flatten())))
        self.assertTrue(np.all(np.equal(bounds.ub, np.ones((3, 1, 2)).flatten())))

        # 1-d tensors
        bounds = make_scipy_bounds(
            X=X, lower_bounds=lower_bounds, upper_bounds=upper_bounds
        )
        self.assertIsInstance(bounds, Bounds)
        self.assertTrue(np.all(np.equal(bounds.lb, np.zeros((3, 1, 2)).flatten())))
        self.assertTrue(np.all(np.equal(bounds.ub, np.ones((3, 1, 2)).flatten())))
