from src.problem_surrogate_projection import SurrogateProjectionProblem
from src.problem_quadratic import QuadraticProblem
from src.runner_vanilla import VanillaRunner
from src.predictor import Predictor
import torch


class Runner(VanillaRunner):
    """Runner that trains the predictor on a weighted combination of losses.

    Besides the reward loss (``self.compute_reward_loss``, not defined in this class),
    this runner can add proximity, solution-distance, projection-distance, MSE,
    surrogate-problem and local r-smoothing losses. Every loss has a weight for regular
    training and a separate ``pretraining_*`` weight that is used when ``run_on_batch``
    is called with ``pretraining=True``. A loss whose current weight is 0 is not
    computed; a zero tensor shaped like the reward loss is used instead.
    """

    def __init__(self,
                 problem: QuadraticProblem,
                 predictor: Predictor,
                 lr: float,
                 optimizer_class: str = 'adam',
                 reward_weight: float = 1,
                 proximity_weight: float = 0,
                 solution_distance_weight: float = 0,
                 projection_distance_weight: float = 0,
                 mse_weight: float = 0,
                 surrogate_problem_weight: float = 0,
                 surrogate_problem_min_radius: float = 1e-3,
                 local_loss_weight: float = 0,
                 local_loss_th: float = 1e-2,
                 add_inwards_gradient: bool = False,
                 inwards_gradient_th: float = 0,
                 pretrain_for: bool = False,
                 pretraining_proximity_weight: float = 1,
                 pretraining_reward_weight: float = 0,
                 pretraining_solution_distance_weight: float = 0,
                 pretraining_projection_distance_weight: float = 0,
                 pretraining_mse_weight: float = 0,
                 pretraining_surrogate_problem_weight: float = 0,
                 pretraining_local_loss_weight: float = 0,
                 ):
        """
        Args:
            problem: Problem object providing observations, penalties, distances and rewards.
            predictor: Model whose ``predict(obs)`` output parameterizes the controller problem.
            lr: Learning rate, forwarded to ``VanillaRunner.__init__``.
            optimizer_class: Optimizer name, forwarded to ``VanillaRunner.__init__``.
            reward_weight .. local_loss_weight: Loss weights used during regular training.
            surrogate_problem_min_radius: Lower bound on the surrogate problem radius.
            local_loss_th: Stored only; not read anywhere in this class.
            add_inwards_gradient, inwards_gradient_th: See ``compute_local_loss``.
            pretrain_for: Stored only. NOTE(review): annotated as bool, but the name
                suggests a duration/count — confirm against callers.
            pretraining_*_weight: Loss weights used when ``pretraining=True``.
        """
        super().__init__(problem, predictor, lr, optimizer_class)
        # Self regularization parameters
        self.reward_weight = reward_weight
        self.proximity_weight = proximity_weight
        self.solution_distance_weight = solution_distance_weight
        self.projection_distance_weight = projection_distance_weight
        self.mse_weight = mse_weight
        self.surrogate_problem_weight = surrogate_problem_weight
        # Derived from the regular (non-pretraining) weight only.
        self.use_surrogate_problem = surrogate_problem_weight != 0
        self.surrogate_problem_min_radius = surrogate_problem_min_radius
        self.surrogate_problem = (SurrogateProjectionProblem(problem.n_variables, flexible_W_sq=predictor.train_W_sq))
        self.local_loss_weight = local_loss_weight
        self.local_loss_th = local_loss_th
        self.add_inwards_gradient = add_inwards_gradient
        self.inwards_gradient_th = inwards_gradient_th
        # Pretraining parameters
        self.pretrain_for = pretrain_for
        self.pretraining_reward_weight = pretraining_reward_weight
        self.pretraining_proximity_weight = pretraining_proximity_weight
        self.pretraining_solution_distance_weight = pretraining_solution_distance_weight
        self.pretraining_projection_distance_weight = pretraining_projection_distance_weight
        self.pretraining_mse_weight = pretraining_mse_weight
        self.pretraining_surrogate_problem_weight = pretraining_surrogate_problem_weight
        self.pretraining_local_loss_weight = pretraining_local_loss_weight

    def compute_proximity_loss(self, x_u, param_values_batch):
        """Return the constraint penalty of ``x_u`` as given by ``problem.get_penalty``."""
        return self.problem.get_penalty(x_u, param_values_batch)

    def compute_solution_distance_loss(self, x_u, param_values_batch):
        """Return ``problem.get_solution_distance`` evaluated at ``x_u``."""
        return self.problem.get_solution_distance(x_u, param_values_batch)

    def compute_projection_distance_loss(self, x_u, solution, param_values_batch):
        """Return ``problem.get_projection_distance`` between ``x_u`` and ``solution``."""
        return self.problem.get_projection_distance(x_u, solution, param_values_batch)

    def compute_mse_loss(self, prediction_dict, param_values_batch):
        """Return ``problem.get_mse`` of the raw predictor output."""
        return self.problem.get_mse(prediction_dict, param_values_batch)

    def compute_surrogate_problem_loss(self, x_u, prediction_dict, param_values_batch, solution, force=False):
        """Reward loss of a surrogate projection problem (alternative local smoothing).

        The original author marked this as an alternative way to compute local
        smoothing that is not used anymore; ``compute_local_loss`` is the other variant.

        The surrogate is centred at the detached ``solution`` with squared radius
        ``max(||x_u - solution|| / 2, surrogate_problem_min_radius) ** 2`` and reuses the
        predicted ``w_lin`` and ``W_sq`` (plus ``gamma_sqrt`` if present). The reward
        loss of the surrogate solution is returned for every sample, whether or not
        ``x_u`` is feasible.

        Args:
            x_u: Point compared to ``solution``; presumably (batch, n_variables) — TODO confirm.
            prediction_dict: Predictor output; must contain ``w_lin`` and ``W_sq``.
            param_values_batch: Problem parameters for the batch.
            solution: Solution of the controller problem for this batch.
            force: Compute the loss even when ``use_surrogate_problem`` is False.

        Returns:
            The surrogate reward loss, or the int ``0`` if the loss is disabled and
            ``force`` is False.

        Raises:
            AssertionError: If the batch size is not 1.
        """
        if not (self.use_surrogate_problem or force):
            return 0
        o = solution.detach()
        distance = torch.norm(x_u - o, dim=1).detach()
        assert len(distance) == 1, 'Only bs=1 is supported for surrogate for now!'
        # clamp() applies the lower bound without building a separate (CPU) tensor.
        radius = (distance / 2).clamp(min=self.surrogate_problem_min_radius)[:, None]
        surrogate_params = {'w_lin': prediction_dict['w_lin'],
                            'W_sq': prediction_dict['W_sq'],
                            'o': o,
                            'r_square': radius ** 2}
        if 'gamma_sqrt' in prediction_dict:
            surrogate_params['gamma_sqrt'] = prediction_dict['gamma_sqrt']
        solution_surrogate = self.surrogate_problem.solve(surrogate_params)['solution']
        reward_loss, _, _ = self.compute_reward_loss(solution_surrogate, param_values_batch)
        return reward_loss

    def compute_local_loss(self, x_u, solution, param_values_batch):
        """Compute the local r-smoothing loss.

        The loss is ``-(d * x_u).sum(-1)`` with ``d`` detached, so its gradient with
        respect to ``x_u`` is ``-d``; a descent step moves ``x_u`` along ``d``.

        How ``d`` is built:

        * ``d`` starts as ``problem.get_reward_gradient`` evaluated at the detached ``solution``.
        * For samples where ``x_u`` has a positive (max-method) penalty, the component of
          ``d`` along ``r = solution - x_u`` is removed.
        * If ``add_inwards_gradient`` is set, that removed component is added back,
          pointing along ``r``, for samples where ``cos(grad, r) >= inwards_gradient_th``.

        Returns:
            Per-sample loss tensor (last dimension summed out).
        """
        x = solution.detach()
        constraint_violation = self.problem.get_penalty(x_u, param_values_batch, proximity_method_to_use='max')
        mask_is_infeasible = (constraint_violation > 0)[:, None]
        reward_gradient = self.problem.get_reward_gradient(x, param_values_batch)
        radius_vector = x - x_u
        dot_product = (reward_gradient * radius_vector).sum(-1, keepdim=True)
        radius_vector_norm_sq = (radius_vector * radius_vector).sum(-1, keepdim=True)
        # If x_u coincides with the solution, r == 0 and the plain division is 0/0 = NaN.
        # Masking does not remove it (0 * NaN = NaN). Dividing by 1 there gives the
        # correct limit: the projection of anything onto a zero vector is 0.
        safe_radius_norm_sq = torch.where(radius_vector_norm_sq > 0, radius_vector_norm_sq,
                                          torch.ones_like(radius_vector_norm_sq))
        projection = radius_vector * dot_product / safe_radius_norm_sq
        desired_direction = reward_gradient - mask_is_infeasible * projection
        if self.add_inwards_gradient:
            reward_vector_norm_sq = (reward_gradient * reward_gradient).sum(-1, keepdim=True)
            norm_product = torch.sqrt(radius_vector_norm_sq * reward_vector_norm_sq)
            # Same 0/0 guard: if either vector is zero, dot_product is 0 and cos becomes 0.
            safe_norm_product = torch.where(norm_product > 0, norm_product, torch.ones_like(norm_product))
            cos = dot_product / safe_norm_product
            reward_points_inward_mask = cos >= self.inwards_gradient_th
            # projection * sign(cos) always points along +r, i.e. towards the solution.
            desired_direction += mask_is_infeasible * reward_points_inward_mask * projection * torch.sign(cos)
        loss = (-desired_direction.detach() * x_u).sum(-1)
        return loss

    def _estimate_gradient_for_loss(self, obs, param_values_batch, prefix, loss_fn, best_effort_backward=False):
        """Take one SGD step for a single loss and report the resulting changes.

        Args:
            obs: Detached observation fed to ``predictor.predict``.
            param_values_batch: Problem parameters for the batch.
            prefix: Key prefix for the returned entries (e.g. ``'rew'``).
            loss_fn: Callable ``(prediction_dict, solution) -> per-sample loss``.
            best_effort_backward: If True, a failing backward pass only emits a warning.
                The optimizer step then has no fresh gradients to apply.

        Returns:
            Dict with ``'<prefix>-d_<key>'`` (prediction change per predictor output key)
            and ``'<prefix>-d_solution'`` (change of the controller solution).
        """
        import warnings

        # Presumably a copy, so self.predictor is left untouched — see copy_predictor_and_optimizer.
        predictor, optimizer = self.copy_predictor_and_optimizer(optimizer_class_to_use='sgd')
        optimizer.zero_grad()
        prediction_dict_before = predictor.predict(obs)
        solution_before = self._solve_controller_problem(param_values_batch, prediction_dict_before)['solution']
        loss = loss_fn(prediction_dict_before, solution_before)
        if best_effort_backward:
            try:
                loss.mean(0).backward()
            except Exception as err:  # deliberate best-effort; never swallow KeyboardInterrupt
                warnings.warn(f'Backward pass failed for gradient estimate {prefix!r}: {err}')
        else:
            loss.mean(0).backward()
        optimizer.step()
        prediction_dict_after = predictor.predict(obs)
        solution_after = self._solve_controller_problem(param_values_batch, prediction_dict_after)['solution']
        grads = {f'{prefix}-d_{key}': val - prediction_dict_before[key] for key, val in prediction_dict_after.items()}
        grads[f'{prefix}-d_solution'] = solution_after - solution_before
        return grads

    def estimate_all_gradients(self, param_values_batch, **kwargs):
        """Auxiliary function used to visualize gradients from different losses.

        For each loss separately, one plain SGD step is taken and the resulting change
        of every predictor output and of the controller solution is recorded.

        Key prefixes, one per loss:

        * ``rew``: reward
        * ``p``: proximity
        * ``sd``: solution distance
        * ``pd``: projection distance
        * ``mse``: MSE
        * ``surr``: surrogate, always computed with ``force=True``; its backward pass is best-effort
        * ``loc``: local loss

        Raises:
            AssertionError: If any parameter tensor has a batch size other than 1.
        """
        for val in param_values_batch.values():
            assert val.shape[0] == 1, 'Estimate gradients only works for bs=1!'
        obs = self.problem.get_observation(param_values_batch).detach()
        pvb = param_values_batch

        def x_u_of(prediction_dict):
            return self.problem.get_x_u(prediction_dict, pvb)

        # (prefix, loss_fn(prediction_dict, solution), best_effort_backward)
        loss_specs = (
            ('rew', lambda pred, sol: self.compute_reward_loss(sol, pvb)[0], False),
            ('p', lambda pred, sol: self.compute_proximity_loss(x_u_of(pred), pvb), False),
            ('sd', lambda pred, sol: self.compute_solution_distance_loss(x_u_of(pred), pvb), False),
            ('pd', lambda pred, sol: self.compute_projection_distance_loss(x_u_of(pred), sol, pvb), False),
            ('mse', lambda pred, sol: self.compute_mse_loss(pred, pvb), False),
            ('surr', lambda pred, sol: self.compute_surrogate_problem_loss(x_u_of(pred), pred, pvb, sol,
                                                                           force=True), True),
            ('loc', lambda pred, sol: self.compute_local_loss(x_u_of(pred), sol, pvb), False),
        )
        grads_dict = {}
        for prefix, loss_fn, best_effort in loss_specs:
            grads_dict.update(self._estimate_gradient_for_loss(obs, pvb, prefix, loss_fn,
                                                               best_effort_backward=best_effort))
        return grads_dict

    def get_current_weights(self, pretraining):
        """Return the loss weights for the current phase.

        Returns:
            Tuple ``(reward, proximity, solution_distance, projection_distance, mse,
            surrogate_problem, local_loss)`` weights. They come from the ``pretraining_*``
            attributes when ``pretraining`` is truthy and from the regular ones otherwise.
        """
        if pretraining:
            return (self.pretraining_reward_weight, self.pretraining_proximity_weight,
                    self.pretraining_solution_distance_weight, self.pretraining_projection_distance_weight,
                    self.pretraining_mse_weight, self.pretraining_surrogate_problem_weight,
                    self.pretraining_local_loss_weight)
        return (self.reward_weight, self.proximity_weight, self.solution_distance_weight,
                self.projection_distance_weight, self.mse_weight, self.surrogate_problem_weight,
                self.local_loss_weight)

    def run_on_batch(self, param_values_batch, train=True, normalize_losses=False, pretraining=False):
        """Evaluate all active losses on one batch and optionally take an optimizer step.

        Args:
            param_values_batch: Dict of problem parameters for the batch.
            train: If True, back-propagate the batch-mean total loss and step ``self.optimizer``.
            normalize_losses: If True, divide the total loss by the sum of the current weights.
                The division is skipped when that sum is 0, to avoid dividing by zero.
            pretraining: Use the ``pretraining_*`` weights instead of the regular ones.

        Returns:
            Dict with:

            * ``reward_loss``, ``proximity_loss``, ``solution_distance``,
              ``projection_distance``, ``mse_loss``, ``total_loss``
            * ``x_u``, ``reward``, ``reward_true``, ``regret``, ``regret_true``,
              ``r_max``, ``solution``
            * every entry of the predictor's output dict, merged in last
        """
        obs = self.problem.get_observation(param_values_batch).detach()
        (reward_weight, proximity_weight, solution_distance_weight, projection_distance_weight,
         mse_weight, surrogate_problem_weight, local_loss_weight) = self.get_current_weights(pretraining)
        weight_sum = (reward_weight + proximity_weight + solution_distance_weight + projection_distance_weight
                      + mse_weight + surrogate_problem_weight + local_loss_weight)
        normalizer = weight_sum if normalize_losses and weight_sum != 0 else 1
        # Prepare optimizer, make prediction, precompute important points
        if train:
            self.optimizer.zero_grad()
        prediction_dict = self.predictor.predict(obs)
        x_u = self.problem.get_x_u(prediction_dict, param_values_batch)
        solution_dict = self._solve_controller_problem(param_values_batch, prediction_dict)
        solution = solution_dict['solution']
        # Compute losses
        reward_loss, reward, r_max = self.compute_reward_loss(solution, param_values_batch)
        if 'solution_true' in solution_dict:
            # Report the reward of the separately provided true solution.
            _, reward_true, r_max_true = self.compute_reward_loss(solution_dict['solution_true'],
                                                                  param_values_batch)
        else:
            # Detached copies; as_tensor() avoids the copy-construct warning that
            # torch.tensor() emits when it is given an existing tensor.
            reward_true = torch.as_tensor(reward).detach().clone()
            r_max_true = torch.as_tensor(r_max).detach().clone()
        proximity_loss = (self.compute_proximity_loss(x_u, param_values_batch)
                          if proximity_weight else torch.zeros_like(reward_loss))
        solution_distance_loss = (self.compute_solution_distance_loss(x_u, param_values_batch)
                                  if solution_distance_weight else torch.zeros_like(reward_loss))
        projection_distance_loss = (self.compute_projection_distance_loss(x_u, solution, param_values_batch)
                                    if projection_distance_weight else torch.zeros_like(reward_loss))
        mse_loss = (self.compute_mse_loss(prediction_dict, param_values_batch)
                    if mse_weight else torch.zeros_like(reward_loss))
        # The weight check already decides whether this loss is active. force=True makes a
        # nonzero *pretraining* surrogate weight effective: use_surrogate_problem only
        # reflects the regular weight and would otherwise make the loss a silent 0.
        surrogate_problem_loss = (self.compute_surrogate_problem_loss(x_u, prediction_dict, param_values_batch,
                                                                      solution, force=True)
                                  if surrogate_problem_weight else torch.zeros_like(reward_loss))
        local_loss = (self.compute_local_loss(x_u, solution, param_values_batch)
                      if local_loss_weight else torch.zeros_like(reward_loss))
        # Total loss
        total_loss = (reward_weight * reward_loss
                      + proximity_weight * proximity_loss
                      + solution_distance_weight * solution_distance_loss
                      + projection_distance_weight * projection_distance_loss
                      + mse_weight * mse_loss
                      + surrogate_problem_loss * surrogate_problem_weight
                      + local_loss * local_loss_weight) / normalizer
        if train:
            total_loss.mean(dim=0).backward()
            self.optimizer.step()
        results_dict = {'reward_loss': reward_loss,
                        'proximity_loss': proximity_loss,
                        'solution_distance': solution_distance_loss,
                        'projection_distance': projection_distance_loss,
                        'mse_loss': mse_loss,
                        'total_loss': total_loss,
                        'x_u': x_u,
                        'reward': reward,
                        'reward_true': reward_true,
                        'regret': r_max - reward,
                        'regret_true': r_max_true - reward_true,
                        'r_max': r_max,
                        'solution': solution,
                        }
        results_dict.update(prediction_dict)
        return results_dict

