# import logging
# from logging import Logger
# from typing import List, Tuple, Callable, Optional
#
# import torch
# from torch import Tensor
# from torch.nn import Module
# from torch.optim import Optimizer
#
# from algorithms.convergence_algorithms.base import Algorithm
# from algorithms.convergence_algorithms.exceptions import AlgorithmFinish
# from algorithms.convergence_algorithms.modules import ConstructableModel
# from algorithms.convergence_algorithms.typing import (
#     BoundedEvaluatedSamplerIdentifiableSpace,
# )
# from algorithms.convergence_algorithms.utils import (
#     ball_perturb,
#     reset_all_weights,
#     sample_random_pairs_from_databases,
#     step_model_with_gradient,
# )
# from algorithms.mapping.base import InputMapping, DefaultMapping, OutputMapping
# from algorithms.space.exceptions import NoMoreBudgetError
# from algorithms.stopping_condition.base import AlgorithmStopCondition
# from algorithms.stopping_condition.quadratic_model import distance_point_from_min
# from algorithms.stopping_condition.trsut_region import StopAfterXTimes
# from handlers.base_handler import AlgorithmCallbackHandler
# from handlers.drawers.drawable_algorithms import ConvergenceDrawable
#
#
# class MultipleConvergenceAlgorithm(ConvergenceDrawable, Algorithm):
#     ALGORITHM_NAME = "multiple_egl"
#
#     def __init__(
#         self,
#         env: BoundedEvaluatedSamplerIdentifiableSpace,
#         helper_network: Module,
#         models_to_train: List[ConstructableModel],
#         value_optimizer: Optimizer,
#         models_to_train_optimizer: List[Optimizer],
#         epsilon: float,
#         epsilon_factor: float,
#         min_epsilon: float,
#         perturb: float,
#         max_batch_size: int = 1024,
#         num_of_batch_reply: int = 32,
#         output_mapping: OutputMapping = None,
#         input_mapping: InputMapping = None,
#         device: int = None,
#         logger: Logger = logging.getLogger(__name__),
#     ):
#         self.env = env
#         self.models_to_train = models_to_train
#         self.helper_network = helper_network
#         self.best_models = torch.stack(
#             [model.model_parameter_tensor().detach() for model in self.models_to_train]
#         )
#         self.helper_optimizer = value_optimizer
#         self.model_to_train_optimizer = models_to_train_optimizer
#         self.epsilon = epsilon
#         self.min_epsilon = min_epsilon
#         self.epsilon_factor = epsilon_factor
#         self.perturb = perturb
#         self.max_batch_size = max_batch_size
#         self.num_of_batch_reply = num_of_batch_reply
#         self.output_mapping = output_mapping or DefaultMapping()
#         self.input_mapping = input_mapping
#         self.device = device
#         self.logger = logger
#         self.finish_convergence_points = []
#
#     def set_start_point(self, start_point: Tensor):
#         raise NotImplementedError()
#
#     @property
#     def best_model(self):
#         unreal_points = [model for model in self.best_models]
#         points = (
#             [self.env.denormalize(self.input_mapping.inverse(point)) for point in unreal_points]
#             if self.input_mapping
#             else unreal_points
#         )
#         values = torch.tensor(
#             [self.env(point, debug_mode=True) for point in points], dtype=torch.float64
#         )
#         best_point_index = torch.argmin(values)
#         return unreal_points[best_point_index]
#
#     @property
#     def best_point_until_now(self):
#         best_point = self.best_model
#         best_point = self.env.denormalize(self.input_mapping.inverse(best_point))
#         if not self.finish_convergence_points:
#             return best_point
#         best_point_value = self.env(best_point, debug_mode=True)
#         best_point_found, best_point_found_value = min(
#             self.finish_convergence_points, key=lambda x: x[1]
#         )
#         if best_point_found_value < best_point_value:
#             return best_point_found
#         return best_point
#
#     @property
#     def environment(self):
#         return self.env
#
#     @property
#     def curr_point_to_draw(self):
#         point, unreal_point, _ = self.best_point_from_all_convergence
#         return point
#
#     @property
#     def curr_points_values(self):
#         unreal_points = torch.stack(
#             [model.model_parameter_tensor().detach() for model in self.models_to_train]
#         )
#         points = (
#             [self.env.denormalize(self.input_mapping.inverse(point)) for point in unreal_points]
#             if self.input_mapping
#             else unreal_points
#         )
#         values = torch.tensor([self.env(point) for point in points], dtype=torch.float64)
#         return points, unreal_points, values
#
#     @property
#     def best_point_from_all_convergence(self) -> Tuple[Tensor, Tensor, float]:
#         points, unreal_points, values = self.curr_points_values
#         best_point_index = torch.argmin(values)
#         return (
#             points[best_point_index],
#             unreal_points[best_point_index],
#             values[best_point_index],
#         )
#
#     def train(
#         self,
#         epochs: int,
#         exploration_size: int,
#         num_loop_without_improvement: int,
#         min_iteration_before_shrink: int,
#         max_num_of_shrink: int = None,
#         helper_model_training_epochs: int = 60,
#         warmup_minibatch: int = 5,
#         warmup_loops: int = 6,
#         stopping_conditions: List[AlgorithmStopCondition] = None,
#         callback_handlers: List[AlgorithmCallbackHandler] = None,
#         **kwargs,
#     ):
#         self.logger.info(f"Starting running {self.__class__.__name__} for {epochs} epochs")
#         stopping_conditions = stopping_conditions or []
#         callback_handlers = callback_handlers or []
#
#         # stopping_conditions.append(GOAL_IS_REACHED_STOPPING_CONDITION)
#         if max_num_of_shrink:
#             stopping_conditions.append(StopAfterXTimes(max_num_of_shrink))
#
#         for callback_handler in callback_handlers:
#             callback_handler.on_algorithm_start(self)
#
#         self.training_start_hook(
#             epochs,
#             exploration_size,
#             num_loop_without_improvement,
#             min_iteration_before_shrink,
#             max_num_of_shrink,
#             helper_model_training_epochs,
#             warmup_minibatch,
#             callback_handlers,
#             **kwargs,
#         )
#
#         # Start with an empty tensor so the end-of-run handlers below still work
#         # if an exception is raised before `databases` is assigned.
#         databases = torch.tensor([])
#
#         try:
#             databases, evaluations = self.explore(warmup_minibatch * exploration_size)
#
#             self.logger.info(f"Warmup for {warmup_loops} in {self.__class__.__name__}")
#             num_of_samples_in_database = databases.shape[0] * databases.shape[1]
#             batch_size = min(self.max_batch_size, num_of_samples_in_database)
#             minibatches = num_of_samples_in_database // batch_size
#             for i in range(warmup_loops):
#                 self.train_helper_model(
#                     databases,
#                     evaluations,
#                     minibatches,
#                     batch_size,
#                     exploration_size,
#                     helper_model_training_epochs,
#                 )
#
#             _, _, best_models_value = self.curr_points_values
#             best_model_index = best_models_value.argmin()
#
#             reply_memory_size = self.num_of_batch_reply * exploration_size
#             no_improvement_in_model_count = 0
#             counter = 0
#             num_of_shrinks = 0
#
#             for _ in range(epochs):
#                 counter += 1
#                 # Explore
#                 self.logger.info(
#                     f"Exploring {exploration_size} on env with {self.input_mapping}"
#                 )
#                 samples, new_evaluations = self.explore(exploration_size)
#                 databases = torch.cat((databases, samples), dim=1)
#                 databases = torch.stack(
#                     [database[-reply_memory_size:] for database in databases]
#                 )
#                 evaluations = torch.cat((evaluations, new_evaluations), dim=1)
#                 evaluations = torch.stack(
#                     [evaluation[-reply_memory_size:] for evaluation in evaluations]
#                 )
#
#                 num_of_samples_in_database = databases.shape[0] * databases.shape[1]
#                 batch_size = min(self.max_batch_size, num_of_samples_in_database)
#                 minibatches = num_of_samples_in_database // batch_size
#                 self.train_helper_model(
#                     databases,
#                     evaluations,
#                     minibatches,
#                     batch_size,
#                     exploration_size,
#                     helper_model_training_epochs,
#                 )
#                 self.train_model()
#
#                 # Handle end of epoch
#                 for handler in callback_handlers:
#                     handler.on_epoch_end(
#                         self, database=databases.reshape(-1, databases.shape[-1])
#                     )
#
#                 # Check improvement
#                 _, unreal_points, curr_models_values = self.curr_points_values
#                 models_to_improve_idx = best_models_value > curr_models_values
#                 self.logger.warning(
#                     f"{sum(models_to_improve_idx)} were improved in {self.env} - "
#                     f"old best {best_models_value.min()}. "
#                     f"New best - {min(curr_models_values.min(), best_models_value.min())}"
#                 )
#                 self.best_models[models_to_improve_idx] = unreal_points[models_to_improve_idx]
#
#                 if best_models_value.min() <= curr_models_values.min():
#                     no_improvement_in_model_count += 1
#                     self.logger.info(f"No new best {no_improvement_in_model_count}")
#                 else:
#                     self.logger.info(f"Improved point {curr_models_values.min()}")
#                 best_models_value[models_to_improve_idx] = curr_models_values[
#                     models_to_improve_idx
#                 ]
#                 new_min_index = best_models_value.argmin()
#                 if best_model_index != new_min_index:
#                     points_distance = (
#                         (self.best_models[new_min_index] - self.best_models[best_model_index])
#                         .pow(2)
#                         .sum()
#                         .sqrt()
#                     )
#                     self.logger.info(
#                         f"Different point reached better solution "
#                         f"{best_model_index} to {new_min_index} - {points_distance}"
#                     )
#                     best_model_index = new_min_index
#
#                 # Get new points according to new min
#                 models_idx_to_remove = []
#                 for i, model in enumerate(self.models_to_train):
#                     point = model.model_parameter_tensor().detach()
#                     distance = distance_point_from_min(self, self.epsilon, point)
#                     if distance < 0.45:
#                         models_idx_to_remove.append(i)
#                         self.logger.info(
#                             f"Finish with point {i} with value {curr_models_values[i]}"
#                         )
#                 for i in models_idx_to_remove:
#                     self.finish_convergence_points.append(
#                         (
#                             self.env.denormalize(
#                                 self.input_mapping.inverse(
#                                     self.models_to_train[i].model_parameter_tensor()
#                                 )
#                             )
#                             .detach()
#                             .clone(),
#                             curr_models_values[i].detach().clone(),
#                         )
#                     )
#                     new_point = self.env.sample_from_space(1, self.device)
#                     new_point_mapped = self.input_mapping.map(self.env.normalize(new_point))
#                     self.models_to_train[i] = self.models_to_train[i].from_parameter_tensor(
#                         new_point_mapped
#                     )
#                     point_value = self.env(new_point)
#                     curr_models_values[i] = point_value
#                     best_models_value[i] = point_value
#
#                 if (
#                     no_improvement_in_model_count >= num_loop_without_improvement
#                     and counter >= min_iteration_before_shrink
#                 ):
#                     counter = 0
#                     no_improvement_in_model_count = 0
#                     num_of_shrinks += 1
#                     self.epsilon *= self.epsilon_factor
#                     self.epsilon = max(self.epsilon, self.min_epsilon)
#                     if self.input_mapping:
#                         self.before_shrinking_hook()
#                         best_parameters_real = self.input_mapping.inverse(self.best_model)
#                         all_best_parameters_real = self.input_mapping.inverse(self.best_models)
#                         real_databases = self.input_mapping.inverse(databases.detach())
#                         self.input_mapping.squeeze(best_parameters_real)
#
#                         self.models_to_train = [
#                             model.from_parameter_tensor(best_point.clone())
#                             for model, best_point in zip(
#                                 self.models_to_train,
#                                 self.input_mapping.map(all_best_parameters_real),
#                             )
#                         ]
#                         self.best_models = torch.stack(
#                             [
#                                 model.model_parameter_tensor().detach()
#                                 for model in self.models_to_train
#                             ]
#                         )
#                         self.after_shrinking_hook()
#                         self.logger.info(
#                             f"Shrinking trust region, new center is {best_parameters_real.tolist()} "
#                             f"with {self.input_mapping}"
#                         )
#                         # NOTE - The helper network is reset only if the trust region has changed,
#                         #       because if it has not, the network should look the same.
#                         reset_all_weights(self.helper_network)
#                         databases = self.input_mapping.map(real_databases)
#
#                     self.logger.info(f"Shrinking sample radius to {self.epsilon}")
#                     self.logger.info(f"Space status {self.env}")
#                     # d = torch.stack([p for p, _ in self.finish_convergence_points])
#                     # self.logger.info(f"The database is {d}")
#                     for handler in callback_handlers:
#                         handler.on_algorithm_update(
#                             self,
#                             # database=d
#                             self,
#                             database=databases.reshape(-1, databases.shape[-1]),
#                         )
#
#                 for stop_condition in stopping_conditions:
#                     if stop_condition.should_stop(self, counter=num_of_shrinks):
#                         raise AlgorithmFinish(
#                             stop_condition.REASON.format(
#                                 alg=self.__class__.__name__,
#                                 env=self.env,
#                                 best_point=self.best_point_until_now,
#                                 tr=self.input_mapping,
#                             )
#                         )
#         except NoMoreBudgetError as e:
#             self.logger.warning("No more Budget", exc_info=e)
#         except AlgorithmFinish as e:
#             self.logger.info(f"{self.__class__.__name__} Finish stopped {e}")
#
#         for handler in callback_handlers:
#             self.logger.info(f"Calling upon {handler.on_algorithm_end} finishing convergence")
#             if databases.numel() != 0:
#                 handler.on_algorithm_end(
#                     self, database=databases.reshape(-1, databases.shape[-1])
#                 )
#             else:
#                 handler.on_algorithm_end(self)
#
#     def explore(self, exploration_size: int):
#         new_model_samples = torch.stack(
#             [
#                 self.samples_points(model.model_parameter_tensor(), exploration_size)
#                 for model in self.models_to_train
#             ]
#         )
#
#         # Evaluate
#         evaluations = torch.stack(
#             [
#                 self.evaluate_point(samples).to(device=self.device)
#                 for samples in new_model_samples
#             ]
#         )
#
#         self.output_mapping.adapt(evaluations.reshape(-1))
#         return new_model_samples, evaluations
#
#     def samples_points(self, base_point: Tensor, exploration_size: int):
#         new_model_samples = torch.cat(
#             (
#                 # TODO - use dependency injection (DI) to allow other exploration strategies
#                 ball_perturb(base_point, self.epsilon, exploration_size - 1, self.device),
#                 base_point.reshape(1, -1),
#             )
#         )
#         return new_model_samples
#
#     def evaluate_point(self, new_model_samples: Tensor):
#         if self.input_mapping:
#             real_samples = self.input_mapping.inverse(new_model_samples)
#             real_samples = self.env.denormalize(
#                 real_samples
#             )  # Map the samples to entire space from (-1, 1)
#         else:
#             real_samples = new_model_samples
#
#         return self.env(real_samples.detach().cpu())
#
#     def train_helper_model(
#         self,
#         samples: Tensor,
#         samples_value: Tensor,
#         num_of_minibatch: int,
#         batch_size: int,
#         exploration_size: int,
#         epochs: int,
#     ):
#         raise NotImplementedError()
#
#     def train_model(self):
#         raise NotImplementedError()
#
#     def training_start_hook(self, *args, **kwargs):
#         pass
#
#     def before_shrinking_hook(self):
#         pass
#
#     def after_shrinking_hook(self):
#         pass
#
#
# class MultipleEGL(MultipleConvergenceAlgorithm):
#     def __init__(self, *args, grad_loss: Callable, **kwargs):
#         super(MultipleEGL, self).__init__(*args, **kwargs)
#         self.grad_loss = grad_loss
#
#     @property
#     def grad_network(self):
#         return self.helper_network
#
#     @property
#     def grad_optimizer(self):
#         return self.helper_optimizer
#
#     def train_helper_model(
#         self,
#         samples: Tensor,
#         samples_value: Tensor,
#         num_of_minibatch: int,
#         batch_size: int,
#         exploration_size: int,
#         epochs: int,
#     ):
#         self.grad_network.train()
#         mapped_evaluations = self.output_mapping.map(samples_value)
#         # TODO -
#         #       1. Create a sigmoid here for each database.
#         #       2. Return which database each tuple belongs to.
#         #       3. Create a weight for each tuple according to its database and the matching weight function.
#         for _ in range(epochs):
#             for x_i, x_j, y_i, y_j in sample_random_pairs_from_databases(
#                 samples,
#                 mapped_evaluations,
#                 num_of_minibatch,
#                 batch_size,
#                 exploration_size,
#             ):
#                 self.grad_optimizer.zero_grad()
#                 x_tag_perturb = ball_perturb(
#                     x_i, self.epsilon * self.perturb, batch_size, self.device
#                 )
#                 grad_on_perturb = self.grad_network(x_tag_perturb)
#
#                 value = ((x_j - x_i) * grad_on_perturb).sum(dim=1)
#                 target = y_j - y_i
#
#                 loss = self.grad_loss(value, target)
#                 loss.backward()
#                 self.grad_optimizer.step()
#         self.grad_network.eval()
#
#     def train_model(self):
#         for model, optimizer in zip(self.models_to_train, self.model_to_train_optimizer):
#             model.train()
#             optimizer.zero_grad()
#             self.grad_optimizer.zero_grad()
#             self.grad_network.eval()
#
#             model_to_train_gradient = self.grad_network(model.model_parameter_tensor())
#             model_to_train_gradient[model_to_train_gradient != model_to_train_gradient] = 0
#             self.logger.info(
#                 f"Algorithm {self.__class__.__name__} "
#                 f"moving Step size: {torch.norm(model_to_train_gradient)} on {self.env}"
#             )
#
#             step_model_with_gradient(model, model_to_train_gradient, optimizer)
#             model.eval()
