"""Main experiment driver file."""
import json
import os
import random
from datetime import datetime
from functools import partialmethod

import jax
import jax.numpy as jnp
import numpy as np
from absl import app
from absl import flags
from absl import logging
from ml_collections import config_flags, ConfigDict
from tqdm import tqdm

import games
import manifold
import optimisation
import utils

# Command-line flags. `config` is loaded from an ml_collections config file
# (default: config.py) and is read as FLAGS.config in `main`; `enable_x64` is
# applied to JAX inside `main`.
FLAGS = flags.FLAGS
flags.DEFINE_bool('debug', False, 'debugging flag')
flags.DEFINE_bool('enable_x64', True, 'jax_enable_x64')
config_flags.DEFINE_config_file('config', 'config.py', 'training configuration')

# Turn implicit rank promotion (broadcasting between arrays of different rank)
# into an error instead of letting it happen silently.
jax.config.update('jax_numpy_rank_promotion', 'raise')
# Allow JAX's own config options to be set from absl command-line flags; this
# runs after the update above, so a command-line value can override it.
jax.config.parse_flags_with_absl()


def run_experiment(C: ConfigDict):
  """Seed all RNGs, build the game, run the producer dynamics and summarise.

  Pipeline: seed `random`, NumPy and JAX from `C.seed`; draw producers and
  consumers via `games.game_from_config`; build the utility,
  reparametrisation and regulariser; run simultaneous or sequential ascent;
  test the result with the second-order Riemannian checker; and cluster the
  unit-normalised optimal strategies.

  Args:
    C: Experiment configuration.

  Returns:
    Dict with keys 'found_local_opt', 'excess', 'steps', 'num_clusters',
    'cluster_numbers', 'cluster_counts' and 'cluster_ids', plus
    'opt_strategies', 'consumer_vectors' and 'losses' when the corresponding
    `C.log_optimal_strategies`, `C.log_consumers` and `C.log_loss` switches
    are truthy.

  Raises:
    NotImplementedError: If `C.dynamics` is neither 'simultaneous' nor
      'sequential'.
  """
  seed = C.seed
  # Seed every RNG source the downstream modules might draw from.
  random.seed(seed)
  np.random.seed(seed)
  root_key = jax.random.PRNGKey(seed)
  np_rng = np.random.RandomState(seed)
  logging.info(f'Set random seed {seed}')

  logging.info(f'Setup game with {C.n_producers} producers and '
               f'{C.n_consumers} consumers in {C.dim} dimensions using '
               f'nonnegative: {C.nonnegative}...')
  _, game_key = jax.random.split(root_key)
  (producers, consumers), _ = games.game_from_config(
    C=C, key=game_key, rng=np_rng)

  tau = C.tau
  logging.info(f'Using {C.utility} as utility with tau={tau}')
  logits_and_probs = games.get_logits_and_probs(consumers=consumers, tau=tau)
  utility = games.get_utility(consumers=consumers, tau=tau, version=C.utility)

  logging.info(f'Reparametrisation: {C.reparam}')
  reparam_fn = optimisation.get_reparam_fn(C.reparam)
  logging.info(f'Regularization for loss={C.regulariser}')
  reg_fn = optimisation.get_regulariser_fn(C.regulariser)

  # Only the second checker is kept. Per the original note it checks
  # criticality and local PNE using Riemannian gradients and Hessians.
  _, check_second_order = manifold.get_riemann_checkers(
    utility=utility, reparam_fn=reparam_fn, tol=1e-5, jit_compile=True,
    utility_type=C.utility, tau=tau, consumers=consumers,
    logits_and_probs=logits_and_probs)

  # Inner-loop optimiser that is handed to the outer dynamics below.
  optimise = optimisation.optimiser_from_config(
    C=C, utility=utility, reparam_fn=reparam_fn, reg_fn=reg_fn,
    logits_and_probs=logits_and_probs, consumers=consumers)

  logging.info(f'Starting {C.dynamics} optimization')
  dynamics_runners = {
    'simultaneous': lambda: optimisation.simultaneous_ascent(
      optimise, param=producers),
    'sequential': lambda: optimisation.sequential_ascent(
      optimise, param=producers, n_rounds=C.n_outer_rounds, tol=C.tol),
  }
  run_dynamics = dynamics_runners.get(C.dynamics)
  if run_dynamics is None:
    raise NotImplementedError(f'unknown dynamics "{C.dynamics}"')
  opt_param, (success, excess, losses) = run_dynamics()

  logging.info(f'Optimization finished successfully: {success}')
  logging.info(f'Optimization excess: {excess}')

  local_optimum, excess2nd = check_second_order(opt_param, pid=None)

  logging.info('Compute optimal strategies and clusters...')
  opt_strategies = reparam_fn(opt_param)
  # Rescale each strategy (last axis) to unit L2 norm before clustering.
  opt_strategies /= jnp.linalg.norm(opt_strategies, axis=-1, keepdims=True)

  clusters = utils.recursive_cluster(opt_strategies, max_iters=50, tol=1e-5)
  cluster_ids = utils.get_cluster_ids(opt_strategies, clusters)
  cluster_number, cluster_counts = np.unique(cluster_ids, return_counts=True)
  n_steps = losses.shape[0]

  verdict = 'Found' if local_optimum else 'Did not find'
  logging.info(f'{verdict} local optimum.')
  logging.info(f'L2 norms of (concat) Riemannian gradients: {excess2nd[0]}.')
  logging.info(f'Sum of positive eigenvalues: {excess2nd[1]}.')
  logging.info(f'Took {n_steps} steps.')
  logging.info(f'Cluster number: {cluster_number}.')
  logging.info(f'Cluster counts: {cluster_counts}.')

  results = dict(
    found_local_opt=int(local_optimum),
    excess=excess2nd,
    steps=n_steps,
    num_clusters=len(cluster_counts),
    cluster_numbers=cluster_number,
    cluster_counts=cluster_counts,
    cluster_ids=cluster_ids,
  )
  # (result key, config switch, value), in the order they are added.
  optional_outputs = (
    ('opt_strategies', C.log_optimal_strategies, opt_strategies),
    ('consumer_vectors', C.log_consumers, consumers),
    ('losses', C.log_loss, losses),
  )
  results.update(
    {name: value for name, wanted, value in optional_outputs if wanted})
  return results


def _apply_debug_overrides(C):
  """Overwrite `C` in place with a small, fast configuration for debugging."""
  logging.info('Running in debug mode.')

  C.n_consumers = 10
  C.n_producers = 3
  C.tau = 1e-1  # alternative value: 1e-2
  C.dim = 5

  C.lr = 1e-1
  C.tol = 1e-8
  C.scale_lr_by_temperature = False

  C.n_inner_rounds = 10_000
  C.n_outer_rounds = 25
  C.optimiser = 'optax'
  C.dynamics = 'simultaneous'
  C.regulariser = None  # alternative value: 1e-1
  C.reparam = 'rescale'

  C.em = False
  C.n_em_rounds = 25
  # Uncomment to run without jit compilation while debugging:
  # jax.config.update('jax_disable_jit', True)


def _resolve_seed_and_output_name(C):
  """Choose the run seed (optionally from SLURM) and return the output name.

  If `C.try_seed_from_slurm` is truthy and SLURM_ARRAY_TASK_ID is set, the
  task id becomes the seed and is appended to the output name, and
  `C.output_name` is updated to match. `C.seed` is always written back.

  Args:
    C: Experiment configuration; mutated in place.

  Returns:
    The output name to use for the run directory (may be empty).

  Raises:
    ValueError: If SLURM_ARRAY_TASK_ID is set but is not an integer.
  """
  output_name = C.output_name
  seed = C.seed
  if C.try_seed_from_slurm:
    logging.info("Trying to fetch seed from slurm environment variable...")
    slurm_seed = os.getenv("SLURM_ARRAY_TASK_ID")
    if slurm_seed is not None:
      logging.info(f"Found task id {slurm_seed}")
      seed = int(slurm_seed)
      # NOTE(review): an empty output_name becomes "_<seed>" here rather than
      # a timestamped name; confirm this is intended.
      output_name = f'{output_name}_{seed}'
      logging.info(f"Set output directory {output_name}")
      C.output_name = output_name
  C.seed = seed
  return output_name


def main(_):
  """absl entry point: configure, run the experiment and store its outputs.

  Writes `flags.json` (the final configuration) and `results.npz` (the dict
  returned by `run_experiment`) into `<C.output_dir>/<dir name>`. The
  directory name is the resolved output name, or a timestamp when that name
  is empty or None.
  """
  jax.config.update('jax_enable_x64', FLAGS.enable_x64)
  C = FLAGS.config

  if C.disable_tqdm:
    # Globally disable every tqdm progress bar constructed from here on.
    tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

  if FLAGS.debug:
    _apply_debug_overrides(C)

  FLAGS.alsologtostderr = True
  output_name = _resolve_seed_and_output_name(C)

  # Fall back to a timestamped directory name when no name is configured.
  # A falsy check also covers None, which os.path.join would reject.
  dir_name = output_name or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
  out_dir = os.path.join(os.path.abspath(C.output_dir), dir_name)
  logging.info(f"Save all output to {out_dir}...")
  # exist_ok avoids the exists()/makedirs() race when several jobs start at
  # once and target the same directory.
  os.makedirs(out_dir, exist_ok=True)

  logging.info("Save FLAGS (arguments)...")
  with open(os.path.join(out_dir, 'flags.json'), 'w', encoding='utf-8') as fp:
    json.dump(C.to_dict(), fp, sort_keys=True, indent=2)

  results = run_experiment(C=C)

  logging.info("Store results...")
  result_path = os.path.join(out_dir, "results.npz")
  np.savez(result_path, **results)

  logging.info("DONE!")


if __name__ == '__main__':
  # absl parses the command-line flags (including --config) before calling main.
  app.run(main)
