# coding: utf-8

import tools

import math
import torch

from . import register

# ---------------------------------------------------------------------------- #
# Attack implementation

def attack(grad_honests, f_real, f_decl, defense, model, epsilon=1, **kwargs):
  """ Generate the attack gradients: the honest average gradient scaled by '-epsilon'.
  Args:
    grad_honests Non-empty list of honest gradients
    f_real       Number of Byzantine gradients to generate
    f_decl       Number of declared Byzantine gradients
    defense      Aggregation rule in use to defeat
    model        Model with valid default dataset and loss set
    epsilon      Fixed epsilon if positive, number of evaluations for best attack epsilon if negative
    ...          Ignored keyword-arguments
  Returns:
    List holding the same Byzantine gradient tensor 'f_real' times
  """
  # Fast path
  if f_real == 0:
    return list()
  # Compute the honest average gradient on a detached copy (the honest gradients are left untouched)
  grad_avg = grad_honests[0].clone().detach()
  for grad in grad_honests[1:]:
    grad_avg.add_(grad)
  grad_avg.div_(len(grad_honests))
  # Evaluate the best attack epsilon (if required)
  if epsilon < 0:
    def eval_epsilon(epsilon):
      # Apply the given epsilon
      grad_attack = grad_avg.mul(-epsilon)
      # Measure effective squared distance between the aggregate and the honest average.
      # Out-of-place subtraction: the aggregation rule may return one of its input tensors
      # as-is (possibly an honest gradient), which an in-place 'sub_' would then corrupt.
      aggregated = defense(gradients=(grad_honests + [grad_attack] * f_real), f=f_decl, model=model)
      distance = aggregated.sub(grad_avg)
      return distance.dot(distance).item()
    epsilon = tools.line_maximize(eval_epsilon, evals=math.ceil(-epsilon))
  # Generate the Byzantine gradient ('grad_avg' is a private copy, so scaling it in place is safe)
  byz_grad = grad_avg.mul_(-epsilon)
  # Return this Byzantine gradient 'f_real' times (same tensor object in every slot)
  return [byz_grad] * f_real

def check(grad_honests, f_real, defense, epsilon=1, **kwargs):
  """ Check parameter validity for this attack.
  Args:
    grad_honests Non-empty list of honest gradients
    f_real       Number of Byzantine gradients to generate
    defense      Aggregation rule in use to defeat
    epsilon      Non-zero finite real number: fixed epsilon if positive, number of evaluations for best attack epsilon if negative
    ...          Ignored keyword-arguments
  Returns:
    None if the given parameters are valid, otherwise a string describing the first invalid parameter
  """
  if not isinstance(grad_honests, list) or len(grad_honests) == 0:
    return f"Expected a non-empty list of honest gradients, got {grad_honests!r}"
  if not isinstance(f_real, int) or f_real < 0:
    return f"Expected a non-negative number of Byzantine gradients to generate, got {f_real!r}"
  if not callable(defense):
    return f"Expected a callable for the aggregation rule, got {defense!r}"
  # Real-valued epsilons (e.g. 0.1) are valid: the attack only scales by epsilon and, when negative,
  # takes 'math.ceil' of it. Ints are accepted without 'math.isfinite' (which overflows on huge ints).
  valid_epsilon = isinstance(epsilon, int) or (isinstance(epsilon, float) and math.isfinite(epsilon))
  if not valid_epsilon or epsilon == 0:
    return f"Expected a non-zero, finite attack epsilon, got {epsilon!r}"

# ---------------------------------------------------------------------------- #
# Attack registration

# Register the attack as "empire-strict", handing the package-level 'register' (imported from '.')
# both the implementation ('attack') and its parameter-validity checker ('check')
register("empire-strict", attack, check)
