# coding: utf-8
###
 # @file   reproduce_cifar.py
 #
 # @section LICENSE
 #
 # @section DESCRIPTION
 #
 # Running clipping experiments on CIFAR-10.
###

import tools, study, misc
tools.success("Module loading...")
import signal, torch

# ---------------------------------------------------------------------------- #
# Miscellaneous initializations
tools.success("Miscellaneous initializations...")

# Accessor pair for the "exit requested" flag: the first one is polled later in
# this script, the second one is what the signal handlers below invoke
exit_is_requested, exit_set_requested = tools.onetime("exit")

# Route both the interruption and the termination signals to the flag setter
for _signum in (signal.SIGINT, signal.SIGTERM):
  signal.signal(_signum, exit_set_requested)

# ---------------------------------------------------------------------------- #
#JS: Pick the dataset on which to run experiments
dataset = "cifar10"
result_directory = f"results-data-{dataset}"
plot_directory = f"results-plot-{dataset}"

with tools.Context("cmdline", "info"):
  args = misc.process_commandline()
  # Make the result directories (one for raw data, one for plots)
  args.result_directory = misc.check_make_dir(result_directory)
  args.plot_directory = misc.check_make_dir(plot_directory)
  # Resolve the devices to use into a list of device names
  if args.devices != "auto":
    # Explicit comma-separated list given on the command line
    args.devices = [name.strip() for name in args.devices.split(",")]
  elif torch.cuda.is_available():
    # "auto" with CUDA: use every visible GPU
    args.devices = [f"cuda:{index}" for index in range(torch.cuda.device_count())]
  else:
    # "auto" without CUDA: fall back to the CPU
    args.devices = ["cpu"]

# ---------------------------------------------------------------------------- #
# Run (missing) experiments
# NOTE(review): "(missing)" presumably means that runs whose results already
# exist are skipped by tools.Jobs — confirm against its implementation
tools.success("Running experiments...")

# Baseline parameters forwarded to train.py; every job starts from a copy of
# this mapping and overrides/adds its own entries (key order is kept as-is)
params_cifar = {
  "batch-size": 50,
  "loss": "NLLLoss",
  "learning-rate-decay": 25,
  "learning-rate-decay-delta": 1500,
  "weight-decay": 0.01,
  "evaluation-delta": 100,
  "nb-steps": 2000,
  "momentum-worker": 0.9,
  "numb-labels": 10,
  "mimic-learning-phase": 400,
  # Defaults for the attack-free baseline; the attacked runs override these
  "aggregator": "average",
  "nb-decl-byz": 0,
  "nb-real-byz": 0,
}
params_common = params_cifar


# Hyperparameters to test

# Attacks evaluated against every robust aggregator
attacks = ["auto_FOE", "auto_ALIE_pos", "LF", "SF", "mimic"]

# Robust aggregators, mapped to the short label shown in the plot legends
gar_names = {
  "trmean": "CWTM",
  "geometric_median": "GM",
  "median": "CWMED",
  "multi_krum": "MK",
}
gars = list(gar_names)

# Number of honest workers in every run
honest_workers = 16

# (Dirichlet alpha, number of Byzantine workers) pairs for the attacked runs
alpha_byz = [(alpha, 1) for alpha in (0.075, 0.2, 0.1, 0.05)]

# Dirichlet alphas for the attack-free baseline runs
alphas = [0.1, 0.075, 0.2, 0.05]

# (model name, learning rate) pairs
models = [("cnn_cifar_old", 0.05)]

# Static clipping thresholds; the string "None" is a placeholder that the job
# submission code turns into an actual None
static_clip = ["None"]


# Command maker helper
def make_command(params):
  """ Build the command that launches 'train.py' with the given parameters.
  Args:
    params Parameter mapping, converted to command-line arguments by 'tools.dict_to_cmdlist'
  Returns:
    'tools.Command' wrapping the full argument list
  """
  return tools.Command(["python3", "-OO", "train.py", *tools.dict_to_cmdlist(params)])

# Jobs
jobs  = tools.Jobs(args.result_directory, devices=args.devices, devmult=args.supercharge)
seeds = jobs.get_seeds()

#JS: DSHB (attack-free baseline: plain averaging over the honest workers only)
for model, lr in models:
  for alpha in alphas:
    params = dict(params_common)
    params.update({
      "dataset": dataset,
      "model": model,
      "learning-rate": lr,
      "nb-workers": honest_workers,
    })
    # Either the "extreme" heterogeneity switch, or a Dirichlet concentration
    if alpha != "extreme":
      params["dirichlet-alpha"] = alpha
    else:
      params["hetero"] = True
    job_name = f"{dataset}-average-n_{honest_workers}-model_{model}-lr_{lr}-alpha_{alpha}"
    jobs.submit(job_name, make_command(params))

#JS: NNM + Clip
for model, lr in models:
  for alpha, f in alpha_byz:
    for gar in gars:
      for attack in attacks:
        # Part of the job name shared by the static and adaptive variants
        run_prefix = f"{dataset}-{attack}-{gar}-n_{honest_workers + f}-f_{f}-model_{model}-lr_{lr}"

        #JS: NNM with static clipping
        params = dict(params_common)
        params.update({
          "dataset": dataset,
          "model": model,
          "learning-rate": lr,
          "nb-workers": honest_workers + f,
          "nb-decl-byz": f,
          "nb-real-byz": f,
          "attack": attack,
          "aggregator": gar,
          "pre-aggregator": "nnm",
        })
        # Either the "extreme" heterogeneity switch, or a Dirichlet concentration
        if alpha != "extreme":
          params["dirichlet-alpha"] = alpha
        else:
          params["hetero"] = True

        for clip_parameter in static_clip:
          # The "None" string placeholder is turned into an actual None
          params["gradient-clip"] = None if clip_parameter == "None" else clip_parameter
          jobs.submit(f"{run_prefix}-StaticClip+NNM_{clip_parameter}-alpha_{alpha}", make_command(params))

        #JS: NNM with adaptive clipping (same parameters, static clipping off, server clipping on)
        params["gradient-clip"] = None
        params["server-clip"] = True
        jobs.submit(f"{run_prefix}-adaptive-alpha_{alpha}", make_command(params))

# Wait for the jobs to finish and close the pool; the "exit requested" accessor
# is handed to the wait call
jobs.wait(exit_is_requested)
jobs.close()

# Check if exit requested before going to plotting the results
if exit_is_requested():
  # Raise SystemExit directly: the 'exit' builtin is injected by the 'site'
  # module for interactive sessions and is not guaranteed to exist in programs
  raise SystemExit(0)

# ---------------------------------------------------------------------------- #

# Plot results
tools.success("Plotting results...")

# One figure per (model, learning rate, alpha, f, attack): the attack-free
# baseline plus, for each robust aggregator, the curves without clipping and
# with adaptive clipping
for model, lr in models:
  for alpha, f in alpha_byz:

    #JS: DSGD (attack-free baseline for this heterogeneity level)
    name = f"{dataset}-average-n_{honest_workers}-model_{model}-lr_{lr}-alpha_{alpha}"
    dsgd = misc.compute_avg_err_op(name, seeds, result_directory, "eval", ("Accuracy", "max"))

    #JS: Robust aggregators
    for attack in attacks:
      plot = study.LinePlot()
      # NOTE(review): the *100 presumably turns a [0, 1] accuracy into a
      # percentage, matching the 0..100 y-axis below — confirm
      plot.include(dsgd[0]*100, "Accuracy", errs="-err", lalp=0.8)
      legend = ["DSGD (f = 0)"]
      for gar in gars:
        # Part of the result name shared by both variants; must match the
        # names used when the jobs were submitted
        run_prefix = f"{dataset}-{attack}-{gar}-n_{honest_workers + f}-f_{f}-model_{model}-lr_{lr}"

        #JS: no clip
        name = f"{run_prefix}-StaticClip+NNM_None-alpha_{alpha}"
        no_clip = misc.compute_avg_err_op(name, seeds, result_directory, "eval", ("Accuracy", "max"))
        plot.include(no_clip[0]*100, "Accuracy", errs="-err", lalp=0.8)
        legend.append(f"{gar_names[gar]} ∘ NNM")

        #JS: adaptive clip
        name = f"{run_prefix}-adaptive-alpha_{alpha}"
        adaptive_clip = misc.compute_avg_err_op(name, seeds, result_directory, "eval", ("Accuracy", "max"))
        plot.include(adaptive_clip[0]*100, "Accuracy", errs="-err", lalp=0.8)
        # Same spacing around "∘" as the label above, so legends read uniformly
        legend.append(f"{gar_names[gar]} ∘ NNM ∘ ARC")

      #JS: plot every time graph in terms of the maximum number of steps
      plot.finalize(None, "Step", "Test Accuracy", xmin=0, xmax=params_common['nb-steps'], ymin=0, ymax=100, legend=legend)
      plot.save(f"{plot_directory}/{dataset}_{attack}_f={f}_model={model}_lr={lr}_alpha={alpha}.pdf", xsize=3, ysize=1.5)