import torch, math, argparse, sys, os
import study, pandas, numpy, tools, pathlib

# ---------------------------------------------------------------------------- #
#JS: Functions used for experiments in the reproduce scripts

#JS: Function used to create the directories needed to store their results in the reproduce scripts
def check_make_dir(path):
  """ Ensure that the given path denotes a directory, creating it (and any missing parent) if absent.
  Args:
    path Path (string or path-like) of the directory
  Returns:
    The 'pathlib.Path' of the directory
  """
  directory = pathlib.Path(path)
  if not directory.exists():
    directory.mkdir(mode=0o755, parents=True)
  elif not directory.is_dir():
    # Something else (e.g. a regular file) already sits at this path
    tools.fatal(f"Given path {str(directory)!r} must point to a directory")
  return directory

#JS: Function used to parse the command-line arguments of the reproduce scripts
def process_commandline():
  """ Parse the command-line arguments of the reproduce scripts.
  Returns:
    'argparse.Namespace' with attributes 'result_directory', 'plot_directory', 'devices' and 'supercharge'
  """
  # (flag, value type, default value, help text) of every supported option
  options = (
    ("--result-directory", str, "results-data", "Path of the data directory, containing the data gathered from the experiments"),
    ("--plot-directory", str, "results-plot", "Path of the plot directory, containing the graphs traced from the experiments"),
    ("--devices", str, "auto", "Comma-separated list of devices on which to run the experiments, used in a round-robin fashion"),
    ("--supercharge", int, 1, "How many experiments are run in parallel per device, must be positive"))
  parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
  for flag, kind, default, text in options:
    parser.add_argument(flag, type=kind, default=default, help=text)
  # Everything after the program name is parsed
  arguments = sys.argv[1:]
  return parser.parse_args(arguments)


#JS: Function used to aggregate (mean and standard deviation across seeds) the results of the experiments in the reproduce scripts
def compute_avg_err_op(name, seeds, result_directory, location, *colops, avgs="", errs="-err"):
  """ Compute the average and standard deviation of the selected columns over the runs of the given experiment, one run per seed.
  Args:
    name             Given experiment name
    seeds            Seeds used for the experiment; the run of seed 's' is read from '<result_directory>/<name>-<s>'
    result_directory Directory where the results are stored
    location         Script to read from (forwarded to 'study.Session')
    ...              Tuples of (selected column name (through 'study.select'), optional reduction operation name or None)
    avgs             Suffix for average column names
    errs             Suffix for standard deviation (or "error") column names
  Returns:
    List of data frames, one per given column selector, each holding for every selected column
    its average and its standard deviation across the seeds
  Raises:
    'RuntimeError' if a reduction operation was specified for a column selector that selected more than 1 column
  """
  # Load all the runs (one per seed) of the given experiment, and keep only the selected columns
  datas = tuple(
    study.select(study.Session(result_directory + "/" + name + "-" + str(seed), location), *(col for col, _ in colops))
    for seed in seeds)

  def make_df_ro(col, op):
    """ Build the aggregated data frame for one column selector, and compute the optional per-seed reduction.
    Args:
      col Column selector (through 'study.select')
      op  Name of the reduction method to call on each selected column, or None
    Returns:
      Aggregated data frame, tuple of reduced values per seed (or None if 'op' is None)
    """
    # Select the columns in every run, dropping the rows with missing values
    subds = tuple(study.select(data, col).dropna() for data in datas)
    df    = pandas.DataFrame(index=subds[0].index)
    ro    = None
    for cn in subds[0]:
      # Stack the column of every run: one row per seed
      numds = numpy.stack(tuple(subd[cn].to_numpy() for subd in subds))
      df[cn + avgs] = numds.mean(axis=0)
      df[cn + errs] = numds.std(axis=0)
      # Compute reduction, if requested (only allowed for a single selected column)
      if op is not None:
        if ro is not None:
          raise RuntimeError(f"column selector {col!r} selected more than one column ({(', ').join(subds[0].columns)}) while a reduction operation was requested")
        ro = tuple(getattr(subd[cn], op)().item() for subd in subds)
    return df, ro

  # The per-seed reductions are still computed (so that an invalid request raises),
  # but only the data frames are part of the return value
  return [make_df_ro(col, op)[0] for col, op in colops]


# ---------------------------------------------------------------------------- #
#JS: Functions used for dataset manipulation in dataset.py

#JS: Lazy-initialize and return the default dataset root directory path
def get_default_root():
    """ Return the default dataset root directory, creating it if needed.
    Returns:
        'pathlib.Path' of the 'datasets/cache' directory located next to this source file
    """
    root = pathlib.Path(__file__).parent.joinpath("datasets", "cache")
    # 'exist_ok' makes repeated calls harmless
    root.mkdir(parents=True, exist_ok=True)
    return root

#JS: Returns the indices of the training datapoints selected for each honest worker, in case of Dirichlet distribution
def draw_indices(samples_distribution, indices_per_label, nb_workers):
    """ Split the training datapoints among the honest workers according to per-label worker proportions
    (used in case of Dirichlet distribution).
    Args:
        samples_distribution Sequence, indexed by label, of sequences of per-worker proportions
        indices_per_label    Sequence, indexed by label, of the indices of the datapoints having that label
        nb_workers           Number of workers
    Returns:
        Dictionary mapping each worker (0 to nb_workers - 1) to the list of indices of the datapoints it holds
    """
    worker_samples = {worker: list() for worker in range(nb_workers)}
    for label, label_distribution in enumerate(samples_distribution):
        label_indices = indices_per_label[label]
        number_samples_label = len(label_indices)
        # Hand out consecutive, non-overlapping chunks of this label's indices to the workers in order;
        # the few trailing indices left over by flooring each chunk size are not assigned
        last_sample = 0
        for worker, worker_proportion in enumerate(label_distribution):
            samples_for_worker = int(worker_proportion * number_samples_label)
            worker_samples[worker].extend(label_indices[last_sample:last_sample + samples_for_worker])
            # Advance to the end of the chunk just assigned (not merely to its size)
            last_sample += samples_for_worker
    return worker_samples


# ---------------------------------------------------------------------------- #
#JS: Functions used in train.py and train_p2p.py

#JS: Store a result as a new tab-separated line in the given result file.
def store_result(fd, *entries):
  """ Append one tab-separated line of entries to a result file.
  Args:
    fd         Descriptor of the valid result file
    entries... Object(s) to convert to string and write in order in a new line
  """
  line = "\t".join(map(str, entries))
  # The line separator precedes the entries, terminating the previously written line
  fd.write(os.linesep + line)
  fd.flush()

#JS: Write the header line of the results file.
def make_result_file(fd, *fields):
  """ Write the header line of a result file: '# ' followed by the tab-separated field names.
  Args:
    fd        Descriptor of the valid result file
    fields... Object(s) to convert to string, used in order as the column names
  Notes:
    No trailing line separator is written.
  """
  header = "\t".join([str(field) for field in fields])
  fd.write(f"# {header}")
  fd.flush()

#JS: Print the configuration of the current training in question
def print_conf(subtree, level=0):
  """ Render a configuration (sub)tree as an indented, human-readable string.
  Args:
    subtree Dictionary, non-empty tuple of (label, node) pairs, or any other value (rendered as a leaf)
    level   Current nesting level (2 spaces of indentation per level)
  Returns:
    Rendered string; a non-empty tree yields one entry per line, each entry starting with a line separator
  """
  is_pairs = (isinstance(subtree, tuple) and len(subtree) > 0
              and isinstance(subtree[0], tuple) and len(subtree[0]) == 2)
  if is_pairs:
    entries = subtree
  elif isinstance(subtree, dict):
    if len(subtree) == 0:
      return " - <none>"
    entries = subtree.items()
  else:
    # Leaf value
    return f" - {subtree}"
  # Labels are padded to the widest one so the values line up
  width  = max(len(label) for label, _ in entries)
  indent = "  " * level
  return "".join(
    f"{os.linesep}{indent}· {label}{' ' * (width - len(label))}{print_conf(node, level + 1)}"
    for label, node in entries)


# ---------------------------------------------------------------------------- #
#JS: Criteria to evaluate accuracy of models. Used in worker.py and p2pWorker.py

def topk(output, target, k=1):
  """ Compute the top-k criterion from the output and the target.
  Args:
    output Batch × model logits
    target Batch × target index (flattened before comparison)
    k      Number of highest-scoring classes among which the target must appear
  Returns:
    1D-tensor [#correct classification, batch size]
  """
  _, predicted = output.topk(k, dim=1)
  # Column vector so that it broadcasts against the k predictions of each sample
  expected = target.view(-1).unsqueeze(1)
  nb_correct = (predicted == expected).any(dim=1).sum()
  batch_size = torch.tensor(target.shape[0], dtype=nb_correct.dtype, device=nb_correct.device)
  return torch.stack((nb_correct, batch_size))


def sigmoid(output, target):
  """ Compute the sigmoid criterion from the output and the target.
  Args:
    output Batch × model logits (expected in [0, 1])
    target Batch × target index (expected in {0, 1})
  Returns:
    1D-tensor [#correct classification, batch size], with the dtype and device of 'output'
  """
  # A prediction counts as correct when it lies strictly within 0.5 of its target
  correct = (target - output).abs() < 0.5
  res = torch.empty(2, dtype=output.dtype, device=output.device)
  res[0], res[1] = correct.sum(), len(correct)
  return res

# ---------------------------------------------------------------------------- #
#JS: Functions for manipulating gradients and model parameters

#JS: Clip the vector if its L2 norm is greater than clip_threshold
def clip_vector(vector, clip_threshold):
    """ Rescale 'vector' in place so that its L2 norm does not exceed 'clip_threshold'.
    Args:
        vector         Tensor to clip (modified in place when clipped)
        clip_threshold Maximum allowed L2 norm
    Returns:
        The same tensor object
    """
    norm = vector.norm().item()
    # Written with '>' (not '<=') so that a NaN norm leaves the vector untouched
    if not norm > clip_threshold:
        return vector
    return vector.mul_(clip_threshold / norm)

#JS: flatten list of tensors. Used for model parameters and gradients
def flatten(list_of_tensors):
    """ Concatenate the given tensors, in order, into a single 1D tensor. """
    flat_views = [tensor.view(-1) for tensor in list_of_tensors]
    return torch.cat(flat_views)

#JS: unflatten a flat tensor. Used when setting model parameters and gradients
def unflatten(flat_tensor, model_shapes):
    """ Split a flat tensor into consecutive tensors of the given shapes (inverse of 'flatten').
    Args:
        flat_tensor  Flat (1D) tensor holding at least as many elements as all the shapes together
        model_shapes Iterable of shapes (sequences of ints or 'torch.Size'), consumed once
    Returns:
        List of tensors, one per shape, sharing storage with 'flat_tensor' and not requiring gradients
    """
    # Detaching once keeps the returned tensors out of the autograd graph, without allocating
    # a throw-away zero tensor per shape just to overwrite its '.data'
    source = flat_tensor.detach()
    returned_list = []
    offset = 0
    for shape in model_shapes:
        count = torch.Size(shape).numel()
        returned_list.append(source[offset:offset + count].view(shape))
        offset += count
    return returned_list

# ---------------------------------------------------------------------------- #
#JS: Functions for Byzantine attacks

#JS: used for Auto ALIE and Auto FOE
def line_maximize(scape, evals=16, start=0., delta=1., ratio=0.8):
  """ Best-effort arg-maximize a scape: ℝ⁺⟶ ℝ, by mere exploration.
  The search runs in two phases sharing one evaluation budget:
    1. Expansion: step forward from 'start', doubling the step after each strict improvement,
       until a proposal does not improve (the step is then multiplied by 'ratio');
    2. Contraction: from the last proposal, step forward if it is smaller than the best x found so far,
       backward otherwise (kept non-negative), multiplying the step by 'ratio' after each evaluation.
  Args:
    scape Function to best-effort arg-maximize
    evals Maximum number of evaluations, must be a positive integer
    start Initial x evaluated, must be a non-negative float
    delta Initial step delta, must be a positive float
    ratio Contraction ratio, must be between 0.5 and 1. (both excluded)
  Returns:
    Best-effort maximizer x under the evaluation budget, i.e. the evaluated x with the highest
    scape value (the earliest such x on ties, as only strict improvements replace it)
  """
  # Variable setup ('start' is always evaluated, consuming one evaluation)
  best_x = start
  best_y = scape(best_x)
  evals -= 1
  # Expansion phase: move forward with a growing step while the value strictly improves
  while evals > 0:
    prop_x = best_x + delta
    prop_y = scape(prop_x)
    evals -= 1
    # Check if best
    if prop_y > best_y:
      best_y = prop_y
      best_x = prop_x
      delta *= 2
    else:
      # No improvement: shrink the step and switch to contraction, starting from this rejected 'prop_x'
      delta *= ratio
      break
  # Contraction phase: only entered with budget left after the 'break' above, so 'prop_x' is defined
  while evals > 0:
    if prop_x < best_x:
      # Last proposal lies below the best x: step forward
      prop_x += delta
    else:
      # Otherwise step backward; while the result is negative, move it halfway back towards the last proposal
      x = prop_x - delta
      while x < 0:
        x = (x + prop_x) / 2
      prop_x = x
    prop_y = scape(prop_x)
    evals -= 1
    # Check if best
    if prop_y > best_y:
      best_y = prop_y
      best_x = prop_x
    # Reduce delta
    delta *= ratio
  # Return found maximizer
  return best_x