
from functools import reduce
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import os
import numpy as np
import cv2

import tensorflow as tf

from tensorpack import logger, ModelSaver, EstimatedTimeLeft, \
                        ClassificationError, InferenceRunner, \
                        QueueInput, EnableCallbackIf, \
                        SyncMultiGPUTrainerReplicated, SmartInit, \
                        SimpleDatasetPredictor, logger, AutoResumeTrainConfig, \
                        FeedfreeInput, RunOp
from tensorpack.predict import PredictConfig, OfflinePredictor
from tensorpack.utils.stats import RatioCounter

from attacker import NoOpAttacker, PGDAttacker
from nets import Model, AdvModel


class CustomTrainer(SyncMultiGPUTrainerReplicated):
  """Replicated multi-GPU trainer that can drop ``dynamic_relu`` gradients.

  When ``robust_activation_trainable`` is falsy, every per-tower gradient
  entry whose variable name contains ``dynamic_relu`` together with ``gamma``
  or ``beta`` is removed before the gradients are handed to the builder, so
  no update op is built from those entries.

  Attributes
  ----------
  robust_activation_trainable : bool
    Defaults to ``True`` (no filtering). Callers may override it on the
    instance before ``setup_graph`` is invoked.
  """

  # Class-level default so that a trainer whose caller never assigned the
  # attribute does not fail with AttributeError inside `_setup_graph`.
  robust_activation_trainable = True

  @staticmethod
  def _is_robust_activation_param(var):
    """Return True if `var` is named like a dynamic_relu gamma/beta variable."""
    name = var.name
    return 'dynamic_relu' in name and ('gamma' in name or 'beta' in name)

  def _setup_graph(self, input, get_cost_fn, get_opt_fn):
    """Build the training op, optionally filtering per-tower gradients.

    Parameters
    ----------
    input :
      Input source; must be a ``FeedfreeInput`` when more than one device
      is used.
    get_cost_fn, get_opt_fn :
      Forwarded unchanged to ``_make_get_grad_fn`` and ``_builder.build``.

    Returns
    -------
    list
      A single ``RunOp`` callback wrapping the builder's post-init op, or
      an empty list when the builder returns no such op.
    """
    if len(self.devices) > 1:
      assert isinstance(input, FeedfreeInput), input
    tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)
    grad_list = self._builder.call_for_each_tower(tower_fn)

    if not self.robust_activation_trainable:
      # Each entry is indexed as a pair whose second element is the variable
      # (presumably (gradient, variable) -- confirm against the builder).
      grad_list = [
          [entry for entry in tower
           if not self._is_robust_activation_param(entry[1])]
          for tower in grad_list
      ]

    self.train_op, post_init_op = self._builder.build(grad_list, get_opt_fn)

    if post_init_op is None:
      return []

    cb = RunOp(
        post_init_op,
        run_before=True,
        run_as_trigger=self.BROADCAST_EVERY_EPOCH,
        verbose=True)
    cb.name_scope = "SyncVariables"
    return [cb]


def create_eval_callback(name, dataflow_func, tower_func, condition):
  """Create an evaluation callback that only runs on selected epochs.

  Three metrics are tracked: the top-1 and top-5 classification error
  rates (for comparison with defenders) and the attack success rate (for
  comparison with attackers).

  Parameters
  ----------
  name : str
    Used as the inference tower name and as the prefix of each summary name.
  dataflow_func : callable
    Zero-argument callable; its return value is wrapped in a ``QueueInput``.
  tower_func :
    Passed to ``InferenceRunner`` as ``tower_func``.
  condition : callable
    Called with the callback's ``epoch_num``; the wrapped runner is enabled
    when it returns a truthy value.

  Returns
  -------
  The ``EnableCallbackIf`` wrapper around the inference runner.
  """
  metric_specs = (
      ('wrong-top1', '{}-top1-error'),
      ('wrong-top5', '{}-top5-error'),
      ('attack_success', '{}-attack-success-rate'),
  )
  inferencers = [
      ClassificationError(tensor_name, summary_fmt.format(name))
      for tensor_name, summary_fmt in metric_specs
  ]

  runner = InferenceRunner(
      QueueInput(dataflow_func()),
      inferencers,
      tower_name=name,
      tower_func=tower_func)
  runner = runner.set_chief_only(False)

  def enabled_this_epoch(callback):
    return condition(callback.epoch_num)

  return EnableCallbackIf(runner, enabled_this_epoch)


def train(gpus, model,
          train_dataflow,
          extra_callbacks,
          get_val_dataflow_func=None,
          steps_per_epoch=None,
          max_epoch=99999999,
          starting_epoch=1,
          saver_max_to_keep=10,
          keep_checkpoint_every_n_hours=0.5,
          auto_resume=True,
          sessinit=None,
          attacker_step_size=1.0,
          attacker_epsilon=16.0):
  """Main train procedure.

  Steps:
    1. Assemble the callbacks: a ``ModelSaver``, an ``EstimatedTimeLeft``,
       the caller's ``extra_callbacks`` and, when a validation dataflow
       factory is given, clean and PGD evaluation callbacks.
    2. Resolve the session initializer and the starting epoch (taken from
       ``AutoResumeTrainConfig`` when ``auto_resume`` is set).
    3. Build a ``CustomTrainer``, set up its graph and run training.

  Parameters
  ----------
  gpus :
    Forwarded to ``CustomTrainer(gpus=...)``.
  model :
    Must provide ``get_inference_func_with_attacker``, ``num_classes``,
    ``robust_activation_trainable``, ``get_input_signature``,
    ``build_graph`` and ``get_optimizer``.
  train_dataflow :
    Training input passed to ``trainer.setup_graph``.
  extra_callbacks : list of callbacks
    Should contain learning rate scheduler at least.
  get_val_dataflow_func : callable, optional
    Zero-argument factory for the validation dataflow. When ``None``, no
    evaluation callbacks are registered and a warning is logged.
  steps_per_epoch, max_epoch :
    Forwarded to ``train_with_defaults``. Evaluation is additionally
    forced on any epoch greater than ``max_epoch - 1``.
  starting_epoch : int
    Ignored when ``auto_resume`` is set (the resume config decides).
  saver_max_to_keep, keep_checkpoint_every_n_hours :
    Forwarded to ``ModelSaver``.
  auto_resume : bool
    When true, ``sessinit`` must be ``None``.
  sessinit :
    Session initializer used only when ``auto_resume`` is false.
  attacker_step_size, attacker_epsilon : float
    Second and third positional arguments of every evaluation
    ``PGDAttacker`` (presumably the PGD step size and perturbation
    bound -- confirm against ``attacker.PGDAttacker``).
  """
  # Setup callbacks. `list(...)` lets callers pass any sequence, not only a
  # list, without changing the result for list inputs.
  callbacks = [
    ModelSaver(max_to_keep=saver_max_to_keep,
               keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours),
    EstimatedTimeLeft()
  ] + list(extra_callbacks)

  def add_eval_callback(name, attacker, condition):
    # Besides the per-callback schedule, always evaluate once the final
    # epoch is reached.
    cb = create_eval_callback(name, get_val_dataflow_func,
                              model.get_inference_func_with_attacker(attacker),
                              lambda epoch: condition(epoch) or epoch > max_epoch - 1)
    callbacks.append(cb)

  def every_50_epochs(epoch):
    return epoch % 50 == 0

  def every_epoch(epoch):
    return True

  if get_val_dataflow_func is None:
    # The default value used to crash with "'NoneType' object is not
    # callable" when the first eval callback was built.
    logger.warning("No validation dataflow provided; "
                   "evaluation callbacks are disabled.")
  else:
    # Adversarially-trained models are evaluated on clean data only every
    # 50 epochs; other models are evaluated on clean data every epoch.
    clean_condition = every_50_epochs if isinstance(model, AdvModel) else every_epoch
    add_eval_callback('eval-clean', NoOpAttacker(), clean_condition)

    # PGD evaluations; the first PGDAttacker argument matches the step
    # count in the callback name.
    for num_steps in (10, 20, 30, 50):
      add_eval_callback('eval-{}step'.format(num_steps),
                        PGDAttacker(num_steps, attacker_step_size, attacker_epsilon,
                                    num_classes=model.num_classes),
                        every_50_epochs)

  if auto_resume:
    assert sessinit is None, "If you want auto resume, why provide sessinit"
    config = AutoResumeTrainConfig(steps_per_epoch=steps_per_epoch)
    session_init = config.session_init
    starting_epoch = config.starting_epoch
  else:
    session_init = sessinit

  trainer = CustomTrainer(gpus=gpus, average=True)
  trainer.robust_activation_trainable = model.robust_activation_trainable
  trainer.setup_graph(model.get_input_signature(),
                      train_dataflow, model.build_graph, model.get_optimizer)
  trainer.train_with_defaults(
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        session_init=session_init,
        max_epoch=max_epoch,
        starting_epoch=starting_epoch)


def _feed_ratio(counter, values):
  """Accumulate one batch output into a ``RatioCounter``-like `counter`.

  A scalar output contributes ``(value, 1)``, exactly like calling
  ``counter.feed(value)``. An array output contributes
  ``(sum, number of elements)``, so the resulting ratio is a per-sample
  rate regardless of batch size.
  """
  counter.feed(np.sum(values), np.size(values))


def eval(model, sessinit, dataflow):
  """Run `model` over `dataflow` and print its error rates.

  NOTE: the name shadows the ``eval`` builtin inside this module; it is
  kept because callers refer to it by this name.

  Parameters
  ----------
  model :
    Model to evaluate. For an ``AdvModel`` the attack success rate is
    reported in addition to the top-1 / top-5 errors.
  sessinit :
    Session initializer passed to ``PredictConfig``.
  dataflow :
    Iterable of datapoints matching the ``['input', 'label']`` inputs.

  Returns
  -------
  None. Results are written to stdout.
  """
  if isinstance(model, AdvModel):
    pred_config = PredictConfig(
        model=model,
        session_init=sessinit,
        input_names=['input', 'label'],
        output_names=['wrong-top1', 'wrong-top5', 'attack_success'])
    predictor = OfflinePredictor(pred_config)
    dataflow.reset_state()
    top1, top5, succ = RatioCounter(), RatioCounter(), RatioCounter()
    for dp in dataflow:
      res = predictor(*dp)
      # Previously each output was fed with total=1, which only yields a
      # rate if the outputs are scalars. The helper handles both cases.
      _feed_ratio(top1, res[0])
      _feed_ratio(top5, res[1])
      _feed_ratio(succ, res[2])
    print("Top1 Error: {}".format(top1.ratio))
    print("Top5 Error: {}".format(top5.ratio))
    print("Attack Success Rate: {}".format(succ.ratio))
  else:
    pred_config = PredictConfig(
        model=model,
        session_init=sessinit,
        input_names=['input', 'label'],
        output_names=['wrong-top1', 'wrong-top5'])
    pred = SimpleDatasetPredictor(pred_config, dataflow)
    acc1, acc5 = RatioCounter(), RatioCounter()
    for top1, top5 in pred.get_result():
      _feed_ratio(acc1, top1)
      _feed_ratio(acc5, top5)
    print("Top1 Error: {}".format(acc1.ratio))
    print("Top5 Error: {}".format(acc5.ratio))
