# import functools
# import math
# from typing import Callable, Union, Optional
#
# import torch
# from torch import Tensor
# from torch.nn import Sequential, ReLU, Linear, SmoothL1Loss
# from torch.optim import Adam
# from tqdm import tqdm
#
# from algorithms.convergence_algorithms.egl_scheduler import EGLScheduler
# from algorithms.convergence_algorithms.modules import BaseSequentialModel
# from algorithms.mapping.trust_region import TanhTrustRegion
# from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
# from algorithms.space.callable_function import CallableSpace
# from config import DEFAULT_EGL_SCHEDULER, DEFAULT_EGL_CONFIG
# from scripts.exceptions import UnspecifiedDimensionSizeError
# from utils import make_dataclass_from_dict
#
# # A bound is either a scalar float, broadcast to every dimension by
# # ensure_bounds, or a per-dimension tensor that is used as-is.
# BoundsType = Union[float, Tensor]
#
#
# def egl_minimize(
#     fun: Callable,
#     upper_bound: BoundsType,
#     lower_bound: BoundsType,
#     config: dict,
#     dim: Optional[int] = None,
#     device: Optional[int] = 0,
# ) -> Tensor:
#     """Minimize ``fun`` inside a box by training an ``EGLScheduler``.
#
#     Args:
#         fun: Objective; wrapped in a ``CallableSpace`` together with the bounds.
#         upper_bound: Scalar float (broadcast to ``dim`` entries by
#             ``ensure_bounds``) or a per-dimension tensor (used as-is).
#         lower_bound: Same form as ``upper_bound``.
#         config: Overrides merged on top of ``DEFAULT_EGL_CONFIG`` and
#             ``DEFAULT_EGL_SCHEDULER`` (later dicts win under ``|``) and then
#             converted with ``make_dataclass_from_dict``.
#         dim: Number of input dimensions. When falsy it is taken from
#             ``len(upper_bound)``.
#         device: Passed to ``.to(device=...)`` for both networks and to the
#             ``TanhTrustRegion`` and ``EGLScheduler`` constructors.
#
#     Returns:
#         The trained model's parameter tensor, passed through
#         ``egl.input_mapping.inverse`` and ``egl.env.denormalize``. This is
#         presumably the found point in the caller's coordinates; confirm
#         against EGLScheduler.
#
#     Raises:
#         UnspecifiedDimensionSizeError: As written, when ``dim`` is given and
#             both bounds are floats (see the review note on the check below).
#     """
#     # NOTE(review): this condition looks inverted. The message asks for either a
#     # dimension size or a vector bound, so the error presumably belongs on
#     # `dim is None`. As written it rejects the valid (dim, float, float) call,
#     # while (dim=None, float, float) falls through to `len(upper_bound)` below
#     # and fails with a TypeError instead.
#     if dim is not None and (
#         isinstance(upper_bound, float) and isinstance(lower_bound, float)
#     ):
#         raise UnspecifiedDimensionSizeError(
#             "You must either specify dimension size or give a vector for a bound"
#         )
#
#     config = make_dataclass_from_dict(
#         DEFAULT_EGL_CONFIG | DEFAULT_EGL_SCHEDULER | config
#     )
#     dim = dim or len(upper_bound)
#     lower_bound = ensure_bounds(lower_bound, dim)
#     upper_bound = ensure_bounds(upper_bound, dim)
#     space = CallableSpace(fun, lower_bound, upper_bound)
#
#     # NOTE(review): grad_net is built in torch's default dtype, while model_net
#     # below is explicitly float64. Confirm EGLScheduler reconciles the two.
#     grad_net = Sequential(
#         Linear(dim, 10), ReLU(), Linear(10, 15), ReLU(), Linear(15, dim),
#     )
#     grad_net = grad_net.to(device=device)
#     # A single bias-free Linear(dim, 1). Its weights are presumably the
#     # candidate point read back via model_parameter_tensor() at the end;
#     # confirm in BaseSequentialModel.
#     model_net = Sequential(Linear(dim, 1, bias=False, dtype=torch.float64)).to(
#         device=device
#     )
#     egl_adapted_net = BaseSequentialModel(model_net)
#     trust_region = TanhTrustRegion(
#         space.lower_bound,
#         space.upper_bound,
#         shrink_factor=config.shrink_factor,
#         min_trust_region_size=config.min_tr_size,
#         device=device,
#     )
#
#     egl = EGLScheduler(
#         space,
#         grad_net,
#         model_lr_factor=config.lr_factor,
#         train_quantile=config.quantile,
#         weights_creator=config.w_creator,
#         model_to_train=egl_adapted_net,
#         value_optimizer=Adam(
#             grad_net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-04
#         ),
#         # Assumes config.model_optimizer is an optimizer class that accepts a
#         # `momentum` keyword (e.g. torch.optim.SGD). Confirm against the config.
#         model_to_train_optimizer=config.model_optimizer(
#             model_net.parameters(), lr=config.model_lr, momentum=config.momentum
#         ),
#         # Base epsilon is scaled by sqrt(dim); presumably to keep the
#         # perturbation size comparable across dimensions. Confirm.
#         epsilon=config.eps * math.sqrt(dim),
#         epsilon_factor=config.eps_factor,
#         min_epsilon=config.min_eps,
#         perturb=config.perturb,
#         grad_loss=SmoothL1Loss(reduction="none"),
#         output_mapping=AdaptedOutputUnconstrainedMapping(config.output_outlier),
#         input_mapping=trust_region,
#         device=device,
#     )
#
#     # NOTE(review): this silences tqdm for the whole process, not just this
#     # call. It patches the tqdm class itself, is never restored, and wraps the
#     # current __init__ again on every call. Prefer a scoped way to disable
#     # progress bars.
#     tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True)
#
#     egl.train(
#         config.epochs,
#         config.explore_size,
#         config.num_loop_without_improvement,
#         config.min_iter,
#         warmup_minibatch=config.warmup_minibatch,
#         warmup_loops=config.warmup_loops,
#     )
#     # Read the trained parameters, undo the input mapping (presumably the
#     # trust region passed above), then denormalize through the scheduler env.
#     return egl.env.denormalize(
#         egl.input_mapping.inverse(egl.model_to_train.model_parameter_tensor())
#     )
#
#
# def ensure_bounds(bounds: BoundsType, dim_size: int) -> Tensor:
#     """Broadcast a scalar float bound to a per-dimension tensor.
#
#     A float is repeated ``dim_size`` times into a 1-D float64 tensor. Any other
#     value is returned unchanged.
#     """
#     # NOTE(review): only `float` is broadcast. An int bound such as 0 is
#     # returned as a plain int, and a tensor bound is neither checked against
#     # dim_size nor converted to float64. Confirm callers always pass floats or
#     # correctly shaped tensors.
#     if isinstance(bounds, float):
#         return torch.tensor([bounds] * dim_size, dtype=torch.float64)
#     return bounds
