# coding: utf-8

import tools
tools.success("Module loading...")

import argparse
import math
import pathlib
import signal
import shlex
import sys

import torch

import experiments

# ---------------------------------------------------------------------------- #
# Miscellaneous initializations
tools.success("Miscellaneous initializations...")

# Accessor pair of the one-time "exit requested" flag: a query function and a setter
exit_is_requested, exit_set_requested = tools.onetime("exit")

# Both interruption (SIGINT) and termination (SIGTERM) only raise the "exit requested" flag
for signum in (signal.SIGINT, signal.SIGTERM):
  signal.signal(signum, exit_set_requested)

# ---------------------------------------------------------------------------- #
# Command-line processing
tools.success("Command-line processing...")

def process_commandline():
  """ Declare the command-line options of this script, then parse them.
  Returns:
    'argparse.Namespace' holding the parsed configuration
  """
  parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
  # Option specifications, in declaration order: (flag, value type, default value, help message)
  specs = (
    ("--data-directory", str, "results-data", "Path of the data directory, containing the data gathered from the experiments"),
    ("--plot-directory", str, "results-plot", "Path of the plot directory, containing the graphs traced from the experiments"),
    ("--devices", str, "auto", "Comma-separated list of devices on which to run the experiments, used in a round-robin fashion"),
    ("--supercharge", int, 1, "How many experiments are run in parallel per device, must be positive"))
  for flag, kind, default, text in specs:
    parser.add_argument(flag, type=kind, default=default, help=text)
  # Parse every process argument except the program name
  arguments = sys.argv[1:]
  return parser.parse_args(arguments)

with tools.Context("cmdline", "info"):
  args = process_commandline()
  # Check the "supercharge" parameter
  if args.supercharge < 1:
    tools.fatal(f"Expected a positive supercharge value, got {args.supercharge}")
  # Make the result directories
  def check_make_dir(path):
    """ Check that the given path is a directory, creating it (with its parents) if it does not exist.
    Args:
      path Path of the directory
    Returns:
      'pathlib.Path' instance of the directory
    """
    path = pathlib.Path(path)
    if path.exists():
      if not path.is_dir():
        tools.fatal(f"Given path {str(path)!r} must point to a directory")
    else:
      path.mkdir(mode=0o755, parents=True)
    return path
  args.data_directory = check_make_dir(args.data_directory)
  args.plot_directory = check_make_dir(args.plot_directory)
  # Preprocess/resolve the devices to use
  if args.devices == "auto":
    # Use every CUDA device if CUDA is available, otherwise fall back to the CPU
    if torch.cuda.is_available():
      args.devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    else:
      args.devices = ["cpu"]
  else:
    # Drop the empty entries (e.g. from a trailing comma or "a,,b"), which are not valid device names
    device_names = [name for name in (part.strip() for part in args.devices.split(",")) if len(name) > 0]
    if len(device_names) == 0:
      tools.fatal(f"Expected at least one device name, got {args.devices!r}")
    args.devices = device_names

# ---------------------------------------------------------------------------- #
# Serial preloading of the dataset
tools.success("Pre-downloading datasets...")

# Pre-load the datasets to prevent the first parallel runs from downloading them several times
# NOTE: this list must cover every dataset the experiments below are run on (see 'datasets')
with tools.Context("dataset", "info"):
  for name in ("mnist", "fashionmnist"):
    with tools.Context(name, "info"):
      experiments.make_datasets(name, 1, 1)

# ---------------------------------------------------------------------------- #
# Run (missing) experiments
tools.success("Running experiments...")

# Training datapoints per dataset (read by 'compute_batch_factor')
trainpoints = {
  "mnist": 60000,
  "fashionmnist": 60000 }

# Hyperparameters to test
# Batch sizes (a 'None' entry would request an increasing batch size instead of a fixed one)
batchs = (25, 50, 150, 300, 500, 750, 1000, 1250, 1500)
# Privacy epsilons (a 'None' entry would disable privacy)
epsilons = (0.2, 0.1, 0.05)
# Pairs (f, fm): 'f' Byzantine workers, and 'fm' GARs dropped from the end of 'gars_mnist' for that 'f'
byzcounts = ((3, 0), (6, 1))
# Gradient aggregation rules tested under attack
gars_mnist = ("brute", "krum", "median", "bulyan")
# Pairs (attack name, attack arguments as one string, split with 'shlex' before submission)
attacks = (("little", "factor:1. negative:True"), ("empire", "factor:1.1"))
# Values given to the "momentum-at" parameter of the attacked runs
momentums = ("worker",)
# Datasets on which every experiment is run
datasets = ("mnist", "fashionmnist")

# Command maker helper
def make_command(params):
  """ Build the command running 'momentum.py' (with 'python3 -OO') on the given parameters.
  Args:
    params Dictionary of parameters, converted to command-line arguments by 'tools.dict_to_cmdlist'
  Returns:
    'tools.Command' instance wrapping the full command line
  """
  command = ["python3", "-OO", "momentum.py"]
  command.extend(tools.dict_to_cmdlist(params))
  return tools.Command(command)

# Compute batch increase factor helper
def compute_batch_factor(params):
  """ Compute the per-step batch size multiplier of a "batch-increase" run.
  The factor is chosen so that, after "nb-steps" multiplications, the initial "batch-size" reaches the
  number of training points available per honest worker (the dataset size split evenly among them).
  Args:
    params Experiment parameters ("dataset", "nb-workers", "batch-size", "nb-steps", optional "nb-real-byz")
  Returns:
    Multiplicative batch size increase factor
  """
  honest_count = params["nb-workers"] - params.get("nb-real-byz", 0)
  points_per_honest = math.floor(trainpoints[params["dataset"]] / honest_count)
  growth = points_per_honest / params["batch-size"]
  return math.pow(growth, 1. / params["nb-steps"])

# Jobs
# Job runner over the data directory, spreading the jobs on 'args.devices' with 'args.supercharge' jobs per device
# (per the '--devices'/'--supercharge' help messages)
jobs  = tools.Jobs(args.data_directory, devices=args.devices, devmult=args.supercharge)
# Seeds of the runs: the results of experiment "<name>" with seed "<seed>" are loaded from "<data directory>/<name>-<seed>"
seeds = jobs.get_seeds()

# Base parameters for the (Fashion-)MNIST experiments
# (copied, then specialized with dataset/batch/privacy/attack settings, for each submitted job)
params_mnist = {
  "batch-size": 25,
  "model": "simples-full",
  "loss": "nll",
  "learning-rate": 0.5,
  "learning-rate-decay-delta": 300,
  "l2-regularize": 1e-4,
  "evaluation-delta": 5,
  "gradient-clip": 2,
  "nb-steps": 300,
  "nb-for-study": 1,
  "nb-for-study-past": 1,
  # Total number of workers (reduced by the Byzantine count for the attack-free runs)
  "nb-workers": 15,
  "privacy-delta": 1e-5,
  "momentum": 0.99,
  "dampening": 0.99 }

# Submit all (Fashion-)MNIST experiments
# For every (dataset, batch size, epsilon, Byzantine count) combination: one attack-free run, with the
# Byzantine workers removed, and one run per (GAR, attack, momentum placement) with 'f' Byzantine workers
# NOTE(review): these are module-level loops, so 'ds', 'params' and the other loop variables stay bound afterwards
for ds in datasets:
  for batch in batchs:
    for epsilon in epsilons:
      for f, fm in byzcounts:
        # No attack
        params = params_mnist.copy()
        params["dataset"] = ds
        # Remove the 'f' Byzantine workers instead of simulating them
        params["nb-workers"] -= f
        if epsilon is not None:
          params["privacy"] = True
          params["privacy-epsilon"] = epsilon
        # A 'None' batch size requests an increasing batch size, with a factor derived from the training set size
        if batch is None:
          params["batch-increase"] = True
          params["batch-increase-factor"] = compute_batch_factor(params)
        else:
          params["batch-size"] = batch
        jobs.submit(f"{ds}-average-n_{params['nb-workers']}-m_{params['momentum']}-b_{'incr' if batch is None else batch}-e_{'none' if epsilon is None else epsilon}", make_command(params))
        # Attacked
        # Skip the last 'fm' GARs of 'gars_mnist' (presumably those not supporting that many Byzantine workers)
        for gar in gars_mnist[:len(gars_mnist) - fm]:
          for attack, attargs in attacks:
            for momentum in momentums:
              params = params_mnist.copy()
              params["dataset"] = ds
              # Keep the total worker count, 'f' of them being declared and actually Byzantine
              params["nb-decl-byz"] = params["nb-real-byz"] = f
              params["gar"] = gar
              params["attack"] = attack
              # Split the attack argument string into its space-separated "key:value" tokens
              params["attack-args"] = shlex.split(attargs)
              params["momentum-at"] = momentum
              if epsilon is not None:
                params["privacy"] = True
                params["privacy-epsilon"] = epsilon
              if batch is None:
                params["batch-increase"] = True
                params["batch-increase-factor"] = compute_batch_factor(params)
              else:
                params["batch-size"] = batch
              jobs.submit(f"{ds}-{attack}-{gar}-f_{f}-m_{params['momentum']}-at_{momentum}-b_{'incr' if batch is None else batch}-e_{'none' if epsilon is None else epsilon}", make_command(params))

# Wait for the jobs to finish and close the pool
# ('exit_is_requested' is given so that the wait can react to a SIGINT/SIGTERM)
jobs.wait(exit_is_requested)
jobs.close()

# Check if exit requested before going to plotting the results
# ('sys.exit' rather than the 'exit' builtin, which the 'site' module only provides for interactive use)
if exit_is_requested():
  sys.exit(0)

# ---------------------------------------------------------------------------- #
# Plot results
tools.success("Plotting results...")

# Import additional modules
# (only needed for plotting, hence imported here; a missing one is reported through 'tools.fatal')
try:
  import numpy
  import pandas
  import study
except ImportError as err:
  tools.fatal(f"Unable to plot results: {err}")

# GAR renaming
# Display names of the GARs in the plot legends (GARs not listed here are displayed capitalized)
gar_to_name = {
  "brute": "MDA" }

def compute_avg_err_op(name, *colops, avgs="", errs="-err"):
  """ Compute the average and standard deviation of the selected columns over the given experiment.
  Args:
    name Given experiment name
    ...  Tuples of (selected column name (through 'study.select'), optional reduction operation name)
    avgs Suffix for average column names
    errs Suffix for standard deviation (or "error") column names
  Returns:
    List of data frames, one per given column selector, holding the average and standard deviation columns,
    List, one per given column selector, of tuples of reduced values per seed (or None if None was provided for 'op')
  Raises:
    'RuntimeError' if a reduction operation was specified for a column selector that did not select exactly 1 column
  """
  # Load all the runs (one per seed) for the given experiment name, and keep only the selected columns
  datas = tuple(study.select(study.Session(args.data_directory / f"{name}-{seed}").compute_ratio(nowarn=True), *(col for col, _ in colops)) for seed in seeds)
  # Make the aggregated data frame (and the optional reduction) for one column selector
  def make_df_ro(col, op):
    # Select the columns in every run, dropping the rows with missing values
    # NOTE(review): 'numpy.stack' below requires every run to keep the same number of rows after 'dropna'
    subds = tuple(study.select(data, col).dropna() for data in datas)
    df    = pandas.DataFrame(index=subds[0].index)
    ro    = None
    for cn in subds[0]:
      # Generate compound column names
      avgn = cn + avgs
      errn = cn + errs
      # Compute compound columns, aggregating over the runs
      numds = numpy.stack(tuple(subd[cn].to_numpy() for subd in subds))
      df[avgn] = numds.mean(axis=0)
      df[errn] = numds.std(axis=0)
      # Compute reduction, if requested
      if op is not None:
        if ro is not None:
          raise RuntimeError(f"column selector {col!r} selected more than one column ({(', ').join(subds[0].columns)}) while a reduction operation was requested")
        ro = tuple(getattr(subd[cn], op)().item() for subd in subds)
    # The loop above only detects the "more than one column" case; a reduction also needs at least one column
    if op is not None and ro is None:
      raise RuntimeError(f"column selector {col!r} selected no column while a reduction operation was requested")
    # Return the built data frame and optional computed reduction
    return df, ro
  dfs = list()
  ros = list()
  for col, op in colops:
    df, ro = make_df_ro(col, op)
    dfs.append(df)
    ros.append(ro)
  # Return the built data frames and optional computed reductions
  return dfs, ros

def compute_mean_dev(vals):
  """ Compute the mean and the sample (Bessel-corrected) standard deviation of an iterable of values.
  Args:
    vals Iterable of values
  Returns:
    Mean value,
    Sample standard deviation (sum of squared deviations divided by n - 1, square-rooted)
  Raises:
    'AssertionError' if less than two values
  """
  # Materialize the values, so that any iterable (not only sized sequences) is supported
  vals = tuple(vals)
  assert len(vals) >= 2
  mean = sum(vals) / len(vals)
  # Normalize by (n - 1): without it the "deviation" grows with the number of values
  sdev = math.sqrt(sum((val - mean) ** 2 for val in vals) / (len(vals) - 1))
  return mean, sdev

def select_ymax(data_w):
  """ Pick the upper y-axis bound of a variance-norm ratio plot.
  Args:
    data_w Sequence of data frames, whose element 3 has a "Sampled ratio" column and element 1 an "Honest ratio" column
  Returns:
    The first of 1, 2, 6 and 12 strictly above the largest ratio, or that ratio rounded up if none is
  """
  highest = max(data_w[3]["Sampled ratio"].max(), data_w[1]["Honest ratio"].max())
  above = [bound for bound in (1., 2., 6., 12.) if highest < bound]
  return above[0] if above else math.ceil(highest)

def post_batchs(axis):
  """ Place one x-axis tick per tested batch size, leaving the label of the second tick empty
  (presumably so that it does not overlap with the first one).
  Args:
    axis Matplotlib axis instance on which the graph is drawn
  """
  labels = ["" if index == 1 else str(size) for index, size in enumerate(batchs)]
  axis.set_xticks(batchs)
  axis.set_xticklabels(labels)

def _gather_points(maxaccs, xkeys):
  """ Gather the (x, mean, standard deviation) points available among the given max accuracies.
  Args:
    maxaccs Dictionary mapping experiment keys to (mean, standard deviation) tuples
    xkeys   Iterable of (x value, experiment key) pairs
  Returns:
    List of (x, mean, standard deviation) tuples, only for the keys present in 'maxaccs'
  """
  return [(x,) + maxaccs[key] for x, key in xkeys if key in maxaccs]

# Momentum value appearing in every experiment name (all the submitted experiments use the one of 'params_mnist')
momentum_value = params_mnist["momentum"]

# Plot (Fashion-)MNIST results
for ds in datasets:
  with tools.Context(ds, "info"):
    # Max cross-accuracy (mean, standard deviation) per experiment key, for the successfully processed experiments only
    maxaccs = dict()
    for batch in batchs:
      btag = "incr" if batch is None else batch
      for epsilon in epsilons:
        etag = "none" if epsilon is None else epsilon
        for f, fm in byzcounts:
          gars = gars_mnist[:len(gars_mnist) - fm]
          # No attack
          name = f"{ds}-average-n_{params_mnist['nb-workers'] - f}-m_{momentum_value}-b_{btag}-e_{etag}"
          try:
            noattack, ros = compute_avg_err_op(name, ("Accuracy", "max"), ("Honest ratio", None), ("Average loss", None))
            maxaccs[(batch, epsilon, f)] = compute_mean_dev(ros[0])
          except Exception as err:
            tools.warning(f"Unable to process {name}: {err}")
            continue
          # Attacked
          for attack, _ in attacks:
            attacked_at = dict()
            for momentum in momentums:
              suffix = f"f_{f}-m_{momentum_value}-at_{momentum}-b_{btag}-e_{etag}"
              attacked = dict()
              for gar in gars:
                name = f"{ds}-{attack}-{gar}-{suffix}"
                try:
                  cols, ros = compute_avg_err_op(name, ("Accuracy", "max"), ("Honest ratio", None), ("Average loss", None), ("Sampled ratio", None))
                  attacked[gar] = cols
                  maxaccs[(batch, epsilon, f, attack, momentum, gar)] = compute_mean_dev(ros[0])
                except Exception as err:
                  tools.warning(f"Unable to process {name!r}: {err}")
                  continue
              attacked_at[momentum] = attacked
              # Plot top-1 cross-accuracy
              plot = study.LinePlot()
              plot.include(noattack[0], "Accuracy", errs="-err", lalp=0.8)
              legend = ["No attack"]
              for gar in gars:
                if gar not in attacked:
                  continue
                plot.include(attacked[gar][0], "Accuracy", errs="-err", lalp=0.8)
                legend.append(gar_to_name.get(gar, gar.capitalize()))
              plot.finalize(None, "Step number", "Top-1 cross-accuracy", xmin=0, xmax=300, ymin=0, ymax=0.85, legend=legend)
              plot.save(args.plot_directory / f"{ds}-{attack}-{suffix}.png", xsize=3, ysize=1.5)
              # Plot average loss
              plot = study.LinePlot()
              plot.include(noattack[2], "Average loss", errs="-err", lalp=0.8)
              legend = ["No attack"]
              for gar in gars:
                if gar not in attacked:
                  continue
                plot.include(attacked[gar][2], "Average loss", errs="-err", lalp=0.8)
                legend.append(gar_to_name.get(gar, gar.capitalize()))
              plot.finalize(None, "Step number", "Average loss", xmin=0, xmax=300, ymin=0, legend=legend)
              plot.save(args.plot_directory / f"{ds}-{attack}-{suffix}-loss.png", xsize=3, ysize=1.5)
            # Plot per-gar variance-norm ratios (from the runs with momentum at the "worker")
            for gar in gars:
              data_w = attacked_at.get("worker", dict()).get(gar)
              if data_w is None:
                continue
              plot = study.LinePlot()
              plot.include(data_w[3], "ratio", errs="-err", lalp=0.5, ccnt=0)
              plot.include(data_w[1], "ratio", errs="-err", lalp=0.5, ccnt=4)
              plot.finalize(None, "Step number", "Variance-norm ratio", xmin=0, xmax=300, ymin=0, ymax=select_ymax(data_w), legend=tuple(f"{gar_to_name.get(gar, gar.capitalize())} \"{at}\"" for at in ("sample", "submit")))
              plot.save(args.plot_directory / f"{ds}-{attack}-{gar}-f_{f}-m_{momentum_value}-at_worker-b_{btag}-e_{etag}-ratio.png", xsize=3, ysize=1.5)
    # Summary plots: the experiments that could not be processed above are skipped instead of raising 'KeyError'
    for momentum in momentums:
      for f, fm in byzcounts:
        gars = gars_mnist[:len(gars_mnist) - fm]
        for attack, _ in attacks:
          for batch in batchs:
            btag = "incr" if batch is None else batch
            # Plot max cross-accuracy function of epsilon
            baseline = _gather_points(maxaccs, ((epsilon, (batch, epsilon, f)) for epsilon in epsilons))
            if len(baseline) == 0:
              tools.warning(f"No data to plot the max cross-accuracy of {ds}-{attack}-f_{f}-at_{momentum}-b_{btag} function of epsilon")
              continue
            plot = study.LinePlot()
            plot.include_simple(baseline, "No attack", lalp=0.8, post=lambda axis: axis.set_xticks(epsilons))
            legend = ["No attack"]
            for gar in gars:
              points = _gather_points(maxaccs, ((epsilon, (batch, epsilon, f, attack, momentum, gar)) for epsilon in epsilons))
              if len(points) == 0:
                continue
              gar_name = gar_to_name.get(gar, gar.capitalize())
              plot.include_simple(points, gar_name, lalp=0.8)
              legend.append(gar_name)
            plot.finalize(None, "Epsilon", "Max cross-accuracy", xmin=min(epsilons), xmax=max(epsilons), ymin=0, ymax=1, legend=legend)
            plot.save(args.plot_directory / f"{ds}-{attack}-f_{f}-m_{momentum_value}-at_{momentum}-b_{btag}-e_variable.png", xsize=2, ysize=1.5)
          for epsilon in epsilons:
            etag = "none" if epsilon is None else epsilon
            # Plot max cross-accuracy function of batch-size
            baseline = _gather_points(maxaccs, ((batch, (batch, epsilon, f)) for batch in batchs))
            if len(baseline) == 0:
              tools.warning(f"No data to plot the max cross-accuracy of {ds}-{attack}-f_{f}-at_{momentum}-e_{etag} function of the batch size")
              continue
            plot = study.LinePlot()
            plot.include_simple(baseline, "No attack", lalp=0.8, post=post_batchs)
            legend = ["No attack"]
            for gar in gars:
              points = _gather_points(maxaccs, ((batch, (batch, epsilon, f, attack, momentum, gar)) for batch in batchs))
              if len(points) == 0:
                continue
              gar_name = gar_to_name.get(gar, gar.capitalize())
              plot.include_simple(points, gar_name, lalp=0.8)
              legend.append(gar_name)
            plot.finalize(None, "Batch size", "Max cross-accuracy", xmin=min(batchs), xmax=max(batchs), ymin=0, ymax=1, legend=legend)
            plot.save(args.plot_directory / f"{ds}-{attack}-f_{f}-m_{momentum_value}-at_{momentum}-b_variable-e_{etag}.png", xsize=3, ysize=1.5)