# from __future__ import print_function
# python run_mocwd.py --dataset residuals --pop_size 20 --generations 20 --number_of_runs 1 --seed 42 --epochs 10 --batch_size 32 --max_conv_layers 20 --exp_name residuals20p20g10e20mcl_08_03_2022
# python run_mocwd.py --dataset sixray --pop_size 10 --generations 20 --number_of_runs 1 --seed 42 --epochs 5 --batch_size 16 --max_conv_layers 20 --exp_name sixray20p20g10e20mcl_14_03 --synflow True
# python run_mocwd.py --exp_name 02042022-residuals-20-20 --dataset residuals --batch_size 16 --generations 10 --pop_size 10 --number_of_runs 1 --seed 42 --type_problem classification --max_conv_layers 25
# RAN ON 16052022 python run_mocwd.py --exp_name 02042022-residuals-20-20 --dataset residuals --batch_size 16 --generations 50 --pop_size 20 --number_of_runs 1 --seed 42 --type_problem classification --max_conv_layers 25 --multi_proxy True --naswot True --synflow True
# RAN ON 22052022 python run_mocwd_ae.py --dataset sixray_ae --pop_size 10 --batch_size 16 --generations 10 --number_of_runs 1 --exp_name sixray_ae_24_05 --type_problem ae --epochs 1 --max_conv_layers 30 
# RAN ON 08062022 python run_mocwd_ae.py --dataset sixray_ae --pop_size 10 --batch_size 16 --generations 10 --number_of_runs 1 --exp_name sixray_ae_06_06 --type_problem ae --epochs 1 --max_conv_layers 35
import sys
sys.path.append("..")

import operator
import random
import pickle
import math
import traceback
import os
import argparse
import numpy as np
from ne import PymooGenomeReduced


# TensorFlow logging / XLA configuration.
# The two environment variables are assigned BEFORE the tensorflow import
# below; presumably they have to be in place when the library initialises,
# so keep this ordering when reorganising the imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
import tensorflow as tf
# Python-side TensorFlow logger: only report errors; AutoGraph verbosity is set to 1.
tf.get_logger().setLevel('ERROR')
tf.autograph.set_verbosity(1)

#Used for calculating contributing HVI and normalised CHVI
from auxiliary.chvi import calculate_contrib_hvi, calculate_normalised_contrib_hvi

from datasets import get_dataset


from tensorflow.keras import backend as K

# Command-line interface.
# Every option is registered with argparse's defaults, i.e. it is optional and
# its value arrives as a string (or None when omitted). Conversion to
# int/bool and the fallback defaults are applied after parsing.
_CLI_OPTIONS = (
    "--dataset",
    "--pop_size",
    "--batch_size",
    "--generations",
    "--number_of_runs",
    "--fitness_fn",
    "--seed",
    "--pss",
    "--nasbench",
    "--type_problem",
    "--exp_name",
    "--naswot",
    "--epochs",
    "--max_conv_layers",
    "--synflow",
    "--multi_proxy",
)

parser = argparse.ArgumentParser()
for _option in _CLI_OPTIONS:
    parser.add_argument(_option)
args = parser.parse_args()

# Comparison operators and aggregation builtins, paired by position
# (index 0: less-than / min, index 1: greater-than / max).
METRIC_OPS = [operator.lt, operator.gt]
METRIC_OBJECTIVES = [min, max]

# Dataset identifier, taken verbatim from the command line (None when omitted).
DATASET = args.dataset

# Initial values for the experiment settings; the command-line handling that
# follows overrides them with whatever the user supplied.
EXPERIMENT_NAME = 'exp404'
TYPE_PROBLEM = 'ae'
BATCH_NORMALIZATION = False
BATCH_SIZE = 0
EPOCHS = 3
MAX_CONV_LAYERS = 20

# Optional switches, all off until requested on the command line.
NASWOT = SYNFLOW = MULTI_PROXY = False


def _flag_enabled(raw):
    """Interpret an optional string command-line value as an on/off switch.

    Args:
        raw: The value argparse stored for the option; ``None`` when the
            option was not given.

    Returns:
        ``False`` for ``None``, the empty string and the usual "off"
        spellings (``false``, ``0``, ``no``, ``off``; case-insensitive),
        ``True`` for anything else. Plain truthiness would turn
        ``--naswot False`` into an enabled switch because any non-empty
        string is truthy.
    """
    if raw is None:
        return False
    return str(raw).strip().lower() not in ("", "false", "0", "no", "off")


# --- Boolean switches -------------------------------------------------------
NASWOT = _flag_enabled(args.naswot)
MULTI_PROXY = _flag_enabled(args.multi_proxy)
SYNFLOW = _flag_enabled(args.synflow)
NASBENCH = _flag_enabled(args.nasbench)

# --- Problem type -----------------------------------------------------------
# Keeps the previously assigned default when --type_problem is omitted.
if args.type_problem:
    TYPE_PROBLEM = args.type_problem

# TODO make as user input
if TYPE_PROBLEM == 'ae':
    BATCH_NORMALIZATION = True

# --- Numeric settings (value from the command line, else the default) --------
POP_SIZE = int(args.pop_size) if args.pop_size else 20
# Decided by --batch_size itself. The previous check looked at --pop_size,
# which crashed on int(None) when only --pop_size was given and silently
# ignored --batch_size when only --batch_size was given.
BATCH_SIZE = int(args.batch_size) if args.batch_size else 32
GENERATIONS = int(args.generations) if args.generations else 20
NUMBER_OF_RUNS = int(args.number_of_runs) if args.number_of_runs else 1
FITNES_FN = args.fitness_fn or 'CHVI'
# Without --seed a random seed in [1, 100] is drawn.
SEED = int(args.seed) if args.seed else random.randint(1, 100)

# These two keep their previously assigned defaults when not supplied.
if args.epochs:
    EPOCHS = int(args.epochs)
if args.max_conv_layers:
    MAX_CONV_LAYERS = int(args.max_conv_layers)

# PSS is an int when supplied and False otherwise.
PSS = int(args.pss) if args.pss else False

# --- Experiment name (mandatory) --------------------------------------------
if not args.exp_name:
    raise ValueError('Please specify experiment name to avoid mess!')
EXPERIMENT_NAME = args.exp_name
print(f'==== RUNNING EXPERIMENT {EXPERIMENT_NAME} =====')


# Verbosity is now 0

# Request memory growth on the first listed GPU. This is best effort: the
# script continues (after printing a notice) when it cannot be applied.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
try:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
except (IndexError, ValueError, RuntimeError):
    # IndexError: no GPU was listed.
    # ValueError / RuntimeError: the errors set_memory_growth documents for an
    # invalid device or an already initialised runtime.
    # A bare `except:` would also have swallowed KeyboardInterrupt/SystemExit.
    print("Invalid device or cannot modify virtual devices once initialized.")


# **Prepare dataset**



# Fitness functions by name.
# NOTE(review): presumably selected through FITNES_FN ('CHVI' by default);
# the lookup itself is not visible in this script -- confirm.
fitnes_fns = {
    'CHVI': calculate_contrib_hvi,
    'CHVI_norm': calculate_normalised_contrib_hvi,
}

# Load the dataset together with the settings get_dataset reports for it.
(
    dataset,
    input_shape,
    n_classes,
    TRAIN_WITH_GEN,
    TRAIN_WITH_LOGITS,
    batch_size,
    normalize,
    multilabel,
) = get_dataset(DATASET, batch_size=BATCH_SIZE, PSS=PSS)

# Bookkeeping for the runs that follow.
result = None
seeds_used = []
seed = SEED

print('RUNNING')

# Imported once, before the first run, so that a missing optimisation or
# plotting dependency is reported before any (expensive) search has started
# rather than after it.
import matplotlib.pyplot as plt
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.factory import (
    get_crossover,
    get_mutation,
    get_performance_indicator,
    get_sampling,
    get_termination,
)
from pymoo.optimize import minimize
from pymoo.util.running_metric import RunningMetric
from pymoo.visualization.scatter import Scatter

# One NSGA-II search per run: build the problem, optimise, pickle the final
# objective values (F) and decision vectors (X), then plot diagnostics from
# the recorded history.
# NOTE(review): `seed` is never changed between runs and the output file names
# do not contain the run index, so with NUMBER_OF_RUNS > 1 every run reuses
# the same seed and overwrites the previous run's files -- confirm intent.
for run_num in range(NUMBER_OF_RUNS):
    seeds_used.append(seed)
    np.random.seed(seed)
    random.seed(seed)
    tf.random.set_seed(seed)

    # 6->10
    # 256 -> 512
    problem = PymooGenomeReduced(
        max_conv_layers=MAX_CONV_LAYERS,
        max_dense_layers=0,
        max_nodes=256,
        max_filters=256,
        input_shape=input_shape,
        n_classes=n_classes,
        dropout=False,
        type_problem=TYPE_PROBLEM,
        batch_size=BATCH_SIZE,
        TRAIN_WITH_LOGITS=TRAIN_WITH_LOGITS,
        batch_normalization=BATCH_NORMALIZATION,
        NASWOT=NASWOT,
        SYNFLOW=SYNFLOW,
        MULTI_PROXY=MULTI_PROXY,
        # Added for 0806 ae experiment
        min_downsample_rate=64,
    )
    problem.feed_data(
        train_with_gen=TRAIN_WITH_GEN,
        dataset=dataset,
        num_generations=GENERATIONS,
        pop_size=POP_SIZE,
        pss=PSS,
        # NOTE(review): hard-coded although get_dataset also returns a
        # `multilabel` value -- confirm which one is intended.
        multilabel=True,
        metric='loss',
        batch_size=BATCH_SIZE,
        # TODO ADD AS PARAM TRUE FOR SIXRAY FALSE FOR RESIDUALS
        gen_to_tf_data=True,
        epochs=EPOCHS,
        normalize=normalize,
    )

    # AGEMOEA (pymoo.algorithms.moo.age) was previously configured here with
    # the same sampling / crossover / mutation operators.
    algorithm = NSGA2(
        pop_size=POP_SIZE,
        n_offsprings=None,
        sampling=get_sampling("int_random"),
        crossover=get_crossover("int_k_point", n_points=3, prob=0.9),
        mutation=get_mutation("int_pm", eta=0.01, prob=0.05),
        eliminate_duplicates=True,
    )
    termination = get_termination("n_gen", GENERATIONS)

    try:
        res = minimize(problem,
                       algorithm,
                       termination,
                       seed=SEED,
                       save_history=True,
                       verbose=True)
    except Exception as error:
        # Report the failure and skip the post-processing for this run.
        # Previously execution fell through with res = None: the output
        # pickles were truncated by open(..., 'wb') and the script then died
        # with an AttributeError on res.F.
        traceback.print_exc()
        print(error)
        continue

    print('==========================')
    print(res.F)
    print('==========================')
    print(res.X)
    print('==========================')

    # ---- Persist the results of this run --------------------------------
    with open(f"l{DATASET}-{EXPERIMENT_NAME}-F.pkl", 'wb') as f:
        pickle.dump(res.F, f)

    with open(f"l{DATASET}-{EXPERIMENT_NAME}-X.pkl", 'wb') as f:
        pickle.dump(res.X, f)

    with open(f'seeds_used_{EXPERIMENT_NAME}.pkl', 'wb') as f:
        pickle.dump(seeds_used, f)

    # ---- Objective-space scatter plot -------------------------------------
    # The problem's Pareto front is drawn as a line when it provides one.
    pf = problem.pareto_front(use_cache=False, flatten=False)

    plot = Scatter(title="Objective Space")
    plot.add(res.F)
    if pf is not None:
        plot.add(pf, plot_type="line", color="black", alpha=0.7)
    plot.show()

    # ---- Per-generation statistics from the saved history -----------------
    n_evals = []    # number of function evaluations after each generation
    F = []          # feasible objective values of the optimum, per generation
    cv = []         # least constraint violation, per generation

    # res.history holds one algorithm snapshot per generation.
    for snapshot in res.history:
        n_evals.append(snapshot.evaluator.n_eval)

        opt = snapshot.opt
        cv.append(opt.get("CV").min())

        # Keep only the feasible members of the optimum.
        feasible_idx = np.where(opt.get("feasible"))[0]
        F.append(opt.get("F")[feasible_idx])

    # ---- Constraint-violation curve ----------------------------------------
    # First generation whose least constraint violation is <= 0, or None when
    # there is no such generation (min() over an empty list used to raise).
    first_feasible = next(
        (idx for idx, violation in enumerate(cv) if violation <= 0), None)

    plt.plot(n_evals, cv, '--', label="CV")
    if first_feasible is None:
        print("No feasible solution found in any generation")
    else:
        first_feas_evals = n_evals[first_feasible]
        print(f"First feasible solution found after {first_feas_evals} evaluations")
        plt.scatter(first_feas_evals, cv[first_feasible],
                    color="red", label="First Feasible")
    plt.xlabel("Function Evaluations")
    plt.ylabel("Constraint Violation (CV)")
    plt.legend()
    plt.show()

    # ---- Hypervolume convergence curve -------------------------------------
    # MODIFY - the reference point is problem dependent.
    # One coordinate of 1.0 per objective; identical to the former hard-coded
    # [1.0, 1.0, 1.0] when there are three objectives.
    ref_point = np.ones(res.F.shape[1])

    metric = get_performance_indicator("hv", ref_point=ref_point)
    hv = [metric.do(front) for front in F]

    plt.plot(n_evals, hv, '-o', markersize=4, linewidth=2)
    plt.title("Convergence")
    plt.xlabel("Function Evaluations")
    plt.ylabel("Hypervolume")
    plt.show()

    # ---- Running metric ----------------------------------------------------
    running = RunningMetric(delta_gen=GENERATIONS/2,
                            n_plots=2,
                            only_if_n_plots=True,
                            key_press=False,
                            do_show=True)

    for snapshot in res.history[:GENERATIONS]:
        running.notify(snapshot)
