"""Main experiment driver file for data experiments."""
import json
import os
import random
from datetime import datetime
from functools import partialmethod

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from tqdm import tqdm
from absl import app
from absl import flags
from absl import logging
from ml_collections import config_flags, ConfigDict

import games
import manifold
import optimisation
import recsys_eval
import utils

# Command-line flags. The experiment configuration itself is loaded from the
# `--config` file (default 'recsys_config.py') as an ml_collections ConfigDict.
FLAGS = flags.FLAGS
flags.DEFINE_bool('debug', False, 'Debugging flag.')
flags.DEFINE_bool('enable_x64', True, 'jax_enable_x64')
# lock_config=False leaves the loaded ConfigDict unlocked; `main` mutates it
# in place (debug overrides, seed and output name).
config_flags.DEFINE_config_file(
  'config', 'recsys_config.py', 'training configuration', lock_config=False)

# Make implicit rank promotion (broadcasting arrays of different rank) an
# error instead of silently broadcasting.
jax.config.update('jax_numpy_rank_promotion', 'raise')
# Allow JAX configuration options to be set via absl command-line flags.
jax.config.parse_flags_with_absl()


def experiment_fixed_recsys(
    C: ConfigDict, consumers, base_producers, data_game, sample_producers):
  """Optimise producers against one fixed recommender and evaluate the result.

  Sets up logits/probabilities, utility, reparameterisation, regulariser and
  optimiser for the given consumers, evaluates `base_producers` once, then
  runs `C.n_repeats_per_recsys` repetitions. Each repetition:
    - samples initial producer parameters;
    - optimises them with `C.dynamics` ('simultaneous' or 'sequential');
    - normalises the resulting strategies along the last axis;
    - runs the second-order checker;
    - clusters the strategies and computes `utils.utility_jumps` between
      clusters;
    - evaluates the dataset-specific statistics.

  Args:
    C: experiment configuration.
    consumers: consumer representations; predicted ratings are computed as
      `consumers @ s.T` (presumably shape (n_consumers, dim) -- TODO confirm).
    base_producers: producer representations before strategic adaptation.
    data_game: dataset game object; queried via `get_info` for gender group
      labels.
    sample_producers: callable `(key, n_producers=...)` returning initial
      producer parameters.

  Returns:
    `((results, results_norm), (base_results, base_results_norm))`.
    `results` maps each statistic name to a list with one entry per
    repetition. It holds the general optimisation stats plus the
    un-normalised `iter_stats`. `results_norm` holds only the `iter_stats`
    computed with `normalise=True`. `base_results` / `base_results_norm` are
    the outputs of `base_stats` with `normalise=False` / `True`.

  Raises:
    NotImplementedError: if `C.dataset` or `C.dynamics` is not supported.
  """
  # utility and loss functions + optimiser
  logging.info(f"Setup logits and probabilities for tau={C.tau}...")
  logits_and_probs = games.get_logits_and_probs(consumers=consumers, tau=C.tau)

  logging.info(f"Setup utility {C.utility} with tau={C.tau}...")
  utility = games.get_utility(consumers=consumers, tau=C.tau, version=C.utility)

  logging.info(f"Using reparameterization {C.reparam}...")
  reparam_fn = optimisation.get_reparam_fn(C.reparam)

  logging.info(f"Using regularization {C.regulariser}...")
  reg_fn = optimisation.get_regulariser_fn(C.regulariser)

  logging.info(f"Setup optimizer...")
  optimise = optimisation.optimiser_from_config(
    C=C, utility=utility, reparam_fn=reparam_fn, reg_fn=reg_fn,
    logits_and_probs=logits_and_probs, consumers=consumers)

  # Only the second checker returned here is used below. Its tolerance (1e-5)
  # is fixed and does not come from `C.tol`.
  _, second_order_checker = manifold.get_riemann_checkers(
    utility=utility, reparam_fn=reparam_fn, tol=1e-5, jit_compile=True,
    utility_type=C.utility, tau=C.tau, consumers=consumers,
    logits_and_probs=logits_and_probs)

  def ratings_and_probs(s):  # compute ratings and recommendation probabilities
    # TODO: hard-coded assumption that ratings are consumer/producer dot products
    pred_ratings = consumers @ s.T  # TODO: hard-coded assumption!
    _, rec_probs = logits_and_probs(s)
    return pred_ratings, rec_probs

  # Select per-dataset group labels and evaluation functions.
  if C.dataset == 'ml-100k':
    logging.info("Running experiment on ml-100k dataset")
    group_info = np.array(  # extract gender info for all users (consumers)
      [data_game.get_info('consumer', i).gender for i in range(len(consumers))])

    # ml-100k uses the same statistics function for base and optimised offers.
    iter_stats = base_stats = recsys_eval.all_stats_ml100k
  elif C.dataset == 'lastfm-360k':
    logging.info("Running experiment on lastfm-360k dataset")
    # Here groups are labels of the base *producers*, not of the consumers.
    group_info = np.array([
      data_game.get_info('producer', i).strict_gender
      for i in range(len(base_producers))])

    base_stats = recsys_eval.base_stats_lastfm360
    iter_stats = recsys_eval.iter_stats_lastfm360
  else:
    raise NotImplementedError(
      f'evaluation not implemented for dataset "{C.dataset}"')

  logging.info("Evaluating results for producers before strategic adaptation")
  pred_ratings, rec_probs = ratings_and_probs(base_producers)
  base_results = base_stats(
    group_info=group_info, offer=base_producers, consumers=consumers,
    ratings=pred_ratings, probs=rec_probs, label='base', normalise=False)
  base_results_norm = base_stats(
    group_info=group_info, offer=base_producers, consumers=consumers,
    ratings=pred_ratings, probs=rec_probs, label='base', normalise=True)

  results_norm = {}  # general iter stats only saved in standard `results`
  results = {k: [] for k in (
    'success', 'lne', 'jump_count', 'jump_max', 'excess_grad', 'excess_hess',
    'cluster_count', 'cluster_max', 'cluster_min')}
  for seed in range(C.n_repeats_per_recsys):
    logging.info(f"---------- Repetition {seed} ----------")
    # NOTE(review): the key depends only on the repetition index. Every
    # recommender re-init in `run_experiment`, and every value of `C.seed`,
    # reuses keys 0..n_repeats_per_recsys-1 for producer initialisation.
    # Confirm this is intended.
    key = jax.random.PRNGKey(seed)

    # sample new initial locations
    # (the updated `key` from this split is not used afterwards)
    key, gkey = jax.random.split(key)
    param = sample_producers(gkey, n_producers=C.n_producers)

    # optimise the producer locations
    if C.dynamics == 'simultaneous':
      logging.info("Starting simultaneous optimization")
      opt_param, aux = optimisation.simultaneous_ascent(optimise, param=param)
    elif C.dynamics == 'sequential':
      logging.info("Starting sequential optimization")
      opt_param, aux = optimisation.sequential_ascent(
        optimise, param=param, n_rounds=C.n_outer_rounds, tol=C.tol)
    else:
      raise NotImplementedError(f'unknown dynamics "{C.dynamics}"')

    # evaluate
    # `aux` is unpacked as a 3-tuple whose first entry is the success flag.
    success, _, _ = aux
    opt_strategies = reparam_fn(opt_param)
    # Normalise each strategy to unit L2 norm along the last axis.
    opt_strategies /= jnp.linalg.norm(opt_strategies, axis=-1, keepdims=True)
    # The second-order check runs on the raw parameters, not on the
    # normalised strategies. `excess2nd[0]` / `excess2nd[1]` are recorded
    # below as gradient / Hessian excess.
    local_optimum, excess2nd = second_order_checker(opt_param, pid=None)

    logging.info(f"Finished {'un' if not success else ''}successfully")

    # compute collapsed cluster centers
    logging.info("Compute clusters...")
    clusters = utils.recursive_cluster(opt_strategies, max_iters=50, tol=1e-5)
    clusters /= jnp.linalg.norm(clusters, ord=2, axis=-1, keepdims=True)
    cluster_ids = utils.get_cluster_ids(opt_strategies, clusters)
    # Number of strategies assigned to each (occupied) cluster.
    _, cluster_counts = np.unique(cluster_ids, return_counts=True)

    # check for existence of strategic cluster jumps
    logging.info("Check for strategic cluster jumps...")
    # One row per value of the dict returned by `utils.utility_jumps`.
    # NOTE(review): `.max()` raises ValueError if that dict is empty.
    jump_diffs = np.array(list(utils.utility_jumps(
      utility, strategies=opt_strategies, clusters=clusters,
      cluster_ids=cluster_ids).values()))
    max_jump = jump_diffs.max().item()
    # Number of rows whose largest jump gain exceeds `C.tol`.
    jump_count = np.sum(np.max(jump_diffs, axis=-1) > C.tol).item()

    logging.info("Evaluate and collect results...")
    # evaluate the results
    results['success'].append(success)
    results['lne'].append(local_optimum.item())
    results['excess_grad'].append(excess2nd[0].item())
    results['excess_hess'].append(excess2nd[1].item())
    results['cluster_count'].append(len(clusters))
    results['cluster_max'].append(cluster_counts.max().item())
    results['cluster_min'].append(cluster_counts.min().item())
    results['jump_count'].append(jump_count)
    results['jump_max'].append(max_jump)

    # Dataset-specific statistics, computed both un-normalised (into
    # `results`) and normalised (into `results_norm`).
    offer_util = utility(opt_strategies)
    pred_ratings, rec_probs = ratings_and_probs(opt_strategies)
    for norm, dct in zip((False, True), (results, results_norm)):
      iter_results = iter_stats(
        group_info=group_info, offer=opt_strategies, consumers=consumers,
        ratings=pred_ratings, probs=rec_probs, label='opt', normalise=norm,
        offer_util=offer_util, base_offer=base_producers)

      for k in iter_results.keys():
        dct.setdefault(k, []).append(iter_results[k])

  return (results, results_norm), (base_results, base_results_norm)


def run_experiment(C: ConfigDict):
  """Run the producer experiment for several recommender re-initialisations.

  For each of `C.n_recsys_reinits` recommenders, a game is built with
  `games.game_from_config` (yielding consumers, base producers and a producer
  sampler). `experiment_fixed_recsys` is then run on it. Per-recommender
  results are concatenated and tagged with a `rid` column.

  Args:
    C: experiment configuration.

  Returns:
    `((iter_df, iter_norm_df), (base_df, base_norm_df))` as pandas
    DataFrames. The iteration frames have one row per repetition
    (`C.n_repeats_per_recsys` per recommender). The base frames have one row
    per recommender. `*_norm_df` hold the normalised statistics.

  Raises:
    NotImplementedError: if `C.dataset` is not 'ml-100k' or 'lastfm-360k'.
  """
  if C.dataset not in ('ml-100k', 'lastfm-360k'):
    raise NotImplementedError(f'not implemented for dataset "{C.dataset}"')
  logging.info(f"Settings base seed {C.seed}...")
  # Seed every RNG source in use: Python, NumPy's global state, a dedicated
  # NumPy RandomState handed to the game builder, and the root JAX key.
  random.seed(C.seed)
  np.random.seed(C.seed)
  rng = np.random.RandomState(C.seed)
  base_key = jax.random.PRNGKey(C.seed)

  # Result accumulators. Position 0 collects un-normalised and position 1
  # normalised statistics, matching the order of `experiment_fixed_recsys`'s
  # return value.
  iter_results, iter_results_norm = {'rid': []}, {'rid': []}
  base_results, base_results_norm = {'rid': []}, {'rid': []}
  iters = (iter_results, iter_results_norm)
  bases = (base_results, base_results_norm)

  for rid in range(C.n_recsys_reinits):
    logging.info(f'==================== Recommender {rid} ====================')

    # A fresh subkey per recommender for the game builder.
    base_key, gkey = jax.random.split(base_key)
    logging.info('Setup game...')
    (_, consumers), aux = games.game_from_config(key=gkey, rng=rng, C=C)
    data_game, base_producers = aux['game'], aux['base_producers']
    sample_producers = aux['sample_producers']

    logging.info('Launch experiment...')
    r_iters, r_bases = experiment_fixed_recsys(
      C=C, consumers=consumers, base_producers=base_producers,
      data_game=data_game, sample_producers=sample_producers)

    logging.info('Collect results and put them into dataframe...')
    for d_iter, d_base, r_iter, r_base in zip(iters, bases, r_iters, r_bases):
      # Base statistics contribute one row per recommender ...
      d_base['rid'].append(rid)
      # ... iteration statistics one row per repetition.
      d_iter['rid'].extend([rid] * C.n_repeats_per_recsys)
      for k, v in r_base.items():
        d_base.setdefault(k, []).append(v)
      for k, v in r_iter.items():
        d_iter.setdefault(k, []).extend(v)

  return ((pd.DataFrame(iter_results), pd.DataFrame(iter_results_norm)),
          (pd.DataFrame(base_results), pd.DataFrame(base_results_norm)))


def _apply_debug_overrides(C):
  """Overwrite `C` in place with a small, fast configuration for debugging."""
  # C.dataset = 'ml-100k'
  C.dataset = 'lastfm-360k'
  C.recommender = 'nmf'
  # C.recommender = 'mf-biased'
  C.random_start = True
  C.normalize_consumers = False

  C.n_recsys_reinits = 3
  C.n_repeats_per_recsys = 2

  # The 'random' branch only applies if the dataset line above is changed.
  if C.dataset == 'random':
    C.n_consumers = 10
    C.n_producers = 3
  else:
    C.n_consumers = None  # `None` w/ recsys dataset takes all users
    C.n_producers = 10  # `None` w/ recsys dataset takes all producers
  C.tau = 1e-1
  C.dim = 5

  C.lr = 1e-1
  C.tol = 1e-8
  C.scale_lr_by_temperature = False
  C.n_inner_rounds = 10_000
  C.n_outer_rounds = 25
  C.optimiser = 'optax'
  C.dynamics = 'simultaneous'
  C.regulariser = None  # 1e-1
  C.reparam = 'rescale'

  # jax.config.update('jax_disable_jit', True)  # useful for debug


def _resolve_seed_and_output_name(C) -> str:
  """Finalise `C.seed` and return the output name, optionally using SLURM.

  If `C.try_seed_from_slurm` is set and `SLURM_ARRAY_TASK_ID` is defined, the
  task id becomes the seed and `_<seed>` is appended to the output name. In
  that case `C.output_name` is updated too. `C.seed` is always written back.
  """
  output_name = C.output_name

  seed = C.seed
  if C.try_seed_from_slurm:
    logging.info("Trying to fetch seed from slurm environment variable...")
    slurm_seed = os.getenv("SLURM_ARRAY_TASK_ID")
    if slurm_seed is not None:
      logging.info(f"Found task id {slurm_seed}")
      seed = int(slurm_seed)
      output_name = f'{output_name}_{seed}'
      logging.info(f"Set output directory {output_name}")
      C.output_name = output_name
  C.seed = seed
  return output_name


def main(_):
  """Configure and run the experiment, then write its results to disk.

  Writes these files into `<C.output_dir>/<dir_name>`:
    - `flags.json`: the resolved configuration;
    - `iter.csv`, `base.csv`, `iter_norm.csv`, `base_norm.csv`: the results.

  `dir_name` is the (possibly SLURM-suffixed) output name or, if that is
  empty/None, a timestamp.
  """
  jax.config.update('jax_enable_x64', FLAGS.enable_x64)
  C = FLAGS.config

  if C.disable_tqdm:
    # Default `disable=True` for every tqdm progress bar created afterwards.
    tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

  if FLAGS.debug:
    logging.info('Running in debug mode')
    _apply_debug_overrides(C)

  FLAGS.alsologtostderr = True
  output_name = _resolve_seed_and_output_name(C)

  # Fall back to a timestamp when no output name is configured. Using
  # truthiness also covers `None`, which os.path.join would reject.
  dir_name = output_name or datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
  out_dir = os.path.join(os.path.abspath(C.output_dir), dir_name)
  logging.info(f"Save all output to {out_dir}...")
  # exist_ok avoids the exists()/makedirs() race when concurrent jobs (e.g.
  # SLURM array tasks) create the same directory at the same time.
  os.makedirs(out_dir, exist_ok=True)

  logging.info("Save FLAGS (arguments)...")
  with open(os.path.join(out_dir, 'flags.json'), 'w') as fp:
    json.dump(C.to_dict(), fp, sort_keys=True, indent=2)

  (iter_results, iter_results_norm), (base_results, base_results_norm) = (
    run_experiment(C=C))

  logging.info("Store results...")
  for frame, file_name in ((iter_results, 'iter.csv'),
                           (base_results, 'base.csv'),
                           (iter_results_norm, 'iter_norm.csv'),
                           (base_results_norm, 'base_norm.csv')):
    frame.to_csv(os.path.join(out_dir, file_name))

  logging.info("DONE!")


if __name__ == '__main__':
  # absl parses command-line flags (including --config) and then calls main.
  app.run(main)
