import json
import os

from absl import app
from absl import flags
from absl import logging

from datetime import datetime

import numpy as np

from agents import (
  GreedyRidge, LinUCB, LinUniform, SquareCB, PolicyGradient, DelayedAgent,
  TwoStage)
from bandits import (
  LinGaussBandit, LinBernBandit, LinCategBandit, LinPoissBandit, WikiBandit, AmazonBandit)
from games import ContextBanditGame
import plotting
from utils import get_ctxt_sampler


# All command-line flags of the experiment script are declared below, grouped
# by topic. `main` dumps their final values to `flags.json` in the output
# directory.
# ------------------------------- RUN PARAMETERS ------------------------------
flags.DEFINE_integer("n_rounds", 1000,
                     "The number of timesteps to run.")
flags.DEFINE_bool("batch_mode", False,
                  "Whether to run in batch mode. Currently only supported for "
                  "single_stage systems.")
flags.DEFINE_integer("batch_update_freq", -1,
                     "At what frequency to perform updates in batch_mode. "
                     "If negative, update ten times in total, i.e., "
                     "batch_update_freq = n_rounds // 10")
# ------------------------------- AGENTS --------------------------------------
flags.DEFINE_enum("stages", "single_stage", ("two_stage", "single_stage"),
                  "Which stages to run.")
flags.DEFINE_enum("agent_ranker", "ucb", ("ucb", "pg", "greedy", "sqcb", "random"),
                  "Which agent to use for the ranker.")
flags.DEFINE_enum("agent_nominators", "ucb", ("ucb", "pg", "greedy", "sqcb"),
                  "Which agent to use for the nominators. Only relevant when "
                  "`stages` is `two_stage`.")
flags.DEFINE_integer("n_nominators", 2, "How many nominators to use.")
# ------------------------------- BANDIT --------------------------------------
flags.DEFINE_enum("bandit_type", "gauss",
                  ("gauss", "bernoulli", "categorical", "poisson", 'wiki', 'amazon'),
                  "Reward distribution of the linear contextual bandit.")
flags.DEFINE_float("reward_std", 0.1, "The standard deviation of the reward "
                                      "distribution.")
# The help text used to describe the feature dimensionality; this flag is the
# number of arms handed to the bandit (and split across the nominators).
flags.DEFINE_integer("n_arms", 100,
                     "Number of arms of the bandit.")
flags.DEFINE_bool("split_arms", True,
                  "Whether to split the arms across the different nominators.")
# ------------------------------- CONTEXTS ------------------------------------
flags.DEFINE_integer("n_contexts", 10,
                     "Number of contexts.")
flags.DEFINE_integer("n_features", 10,
                     "Number of features (for ranker).")
flags.DEFINE_integer("n_nom_features", 5,
                     "Number of features for nominators.")
flags.DEFINE_enum("dtype", "unif", ("mix", "unif"),
                  "Whether ground truth regression vector is chosen "
                  "uniformly at random or as mixture.")
flags.DEFINE_float("mixture_std", 0.001,
                   "Standard deviation of the Gaussian perturbing the "
                   "regression vector in the mixture model.")
flags.DEFINE_bool("split_features", True,
                  "Whether to split the features across the nominators.")
flags.DEFINE_float("context_scaling", -1.,
                   "Scaling factor for contexts. All contexts are normalized "
                   "by default first and then scaled by this factor. Negative "
                   "means scaling by sqrt(d), which amounts to whitening.")
# ------------------------------- HYPERPARAMETERS ---------------------------
# ------------------- PG
flags.DEFINE_float("pg_lr_ranker", 1,
                   "Policy gradient learning rate for ranker.")
flags.DEFINE_float("pg_lr_nominators", 1,
                   "Policy gradient learning rate for nominators.")
flags.DEFINE_bool("pg_random_greedy_ranker", False,
                  "Whether to choose greedily from softmax in ranker pg.")
flags.DEFINE_bool("pg_random_greedy_nominators", False,
                  "Whether to choose greedily from softmax in nominator pg.")
flags.DEFINE_float("pg_eps_scaling_ranker", 0.0,
                   "Epsilon scaling for policy gradient of ranker.")
flags.DEFINE_float("pg_eps_scaling_nominators", 0.0,
                   "Epsilon scaling for policy gradient of nominators.")
# ------------------- SQCB
flags.DEFINE_float("sqcb_lr_ranker", 10,
                   "Learning rate scaling for ranker.")
flags.DEFINE_float("sqcb_lr_nominators", 10,
                   "Learning rate scaling for nominators.")
# ------------------- UCB
flags.DEFINE_float("ucb_regularizer_ranker", 0.1,
                   "Initial regularization parameter for UCB ranker.")
flags.DEFINE_float("ucb_regularizer_nominators", 0.1,
                   "Initial regularization parameter for UCB nominators.")
flags.DEFINE_bool("fixed_ucb_alpha_ranker", True,
                  "Use fixed alpha for the ranker instead of beta.")
flags.DEFINE_bool("fixed_ucb_alpha_nominators", True,
                  "Use fixed alpha for the nominators instead of beta.")
# The help texts below now name the flags that actually exist
# (`fixed_ucb_alpha_*`, not `fixed_alpha_*`).
flags.DEFINE_float("ucb_alpha_ranker", 1.0,
                   "The exploration parameter alpha in LinUCB. Only used if "
                   "`fixed_ucb_alpha_ranker` is True.")
flags.DEFINE_float("ucb_alpha_nominators", 1.0,
                   "The exploration parameter alpha in LinUCB. Only used if "
                   "`fixed_ucb_alpha_nominators` is True.")
# ------------------- Greedy (and thus sqcb)
flags.DEFINE_float("greedy_regularizer_ranker", 0.1,
                   "Initial regularization parameter for greedy ranker.")
flags.DEFINE_float("greedy_regularizer_nominators", 0.1,
                   "Initial regularization parameter for greedy nominators.")
# ------------------- Greedy and UCB
flags.DEFINE_bool("scale_regularizer_by_dim", True,
                  "Whether to scale the regularization by feature dimension.")
# ---------------------------- INPUT/OUTPUT -----------------------------------
flags.DEFINE_string("output_dir", "../results/",
                    "Path to the output directory (for results).")
flags.DEFINE_string("output_name", "",
                    "Name for result folder. Use timestamp if empty.")
flags.DEFINE_integer("plot_verbosity", 0,
                     "Verbosity of plot output for each individual run.")
# ------------------------------ MISC -----------------------------------------
flags.DEFINE_bool("try_seed_from_slurm", False,
                  "Try to get the seed from a slurm job array. If doesn't "
                  "exist, use `seed` flag.")
# `main` also falls back to this value when `try_seed_from_slurm` is set but
# SLURM_ARRAY_TASK_ID is missing, so the help text says so.
flags.DEFINE_integer("seed", 0,
                     "The random seed. Overridden by the slurm array task id "
                     "if `try_seed_from_slurm` is True and "
                     "SLURM_ARRAY_TASK_ID is set.")
flags.DEFINE_integer("fixed_pool_seed", None,
                     "Use a fixed seed for the pool allocation.")
flags.DEFINE_integer("fixed_feature_seed", None,
                     "Use a fixed seed for the feature allocation.")
FLAGS = flags.FLAGS

# TODO(nki): Track what each nominator and each ranker suggested at each time?
#  Which nom (if any) suggests best arm?
#  corresponding rewards, cosine distances (and other) between truth and
#  estimate over time?


def log_kwargs(kwargs):
  """Emit one tab-indented `key: value` info log line per dictionary entry.

  Args:
    kwargs: Dictionary whose entries are logged in iteration order.
  """
  for entry in kwargs.items():
    k, v = entry
    logging.info(f"\t{k}: {v}")


def init_random_ranker(seed: int, active_features=None):
  """Create the `LinUniform` agent that serves as the 'random' ranker.

  Args:
    seed: Seed handed to the agent.
    active_features: Forwarded unchanged as the agent's `active_features`.

  Returns:
    The constructed `LinUniform` instance.
  """
  ranker_config = {
    "n_features": FLAGS.n_features,
    "active_features": active_features,
    "exp_arms": None,
    "seed": seed,
  }
  logging.info(f"Initialize random ranker with:")
  log_kwargs(ranker_config)
  return LinUniform(**ranker_config)


def init_greedy_ranker(seed: int, active_features=None):
  """Create the `GreedyRidge` ranker from the greedy-related flags.

  Args:
    seed: Seed handed to the agent.
    active_features: Forwarded unchanged as the agent's `active_features`.

  Returns:
    The constructed `GreedyRidge` instance.
  """
  ranker_config = {
    "n_features": FLAGS.n_features,
    "regulariser": FLAGS.greedy_regularizer_ranker,
    "scale_regulariser_by_dim": FLAGS.scale_regularizer_by_dim,
    "active_features": active_features,
    "exp_arms": None,
    "seed": seed,
  }
  logging.info(f"Initialize greedy ranker with:")
  log_kwargs(ranker_config)
  return GreedyRidge(**ranker_config)


def init_ucb_ranker(seed: int, active_features=None):
  """Create the `LinUCB` ranker from the UCB-related flags.

  Args:
    seed: Seed handed to the agent.
    active_features: Forwarded unchanged as the agent's `active_features`.

  Returns:
    The constructed `LinUCB` instance.
  """
  # `fixed_ucb_alpha_ranker` selects between the two upper-bound modes.
  if FLAGS.fixed_ucb_alpha_ranker:
    bound_mode = 'const'
  else:
    bound_mode = 'adapt'
  ranker_config = {
    "n_features": FLAGS.n_features,
    "regulariser": FLAGS.ucb_regularizer_ranker,
    "scale_regulariser_by_dim": FLAGS.scale_regularizer_by_dim,
    "upper_bound_fn": bound_mode,
    "alpha": FLAGS.ucb_alpha_ranker,
    "active_features": active_features,
    "exp_arms": None,
    "seed": seed,
  }
  logging.info(f"Initialize UCB ranker with:")
  log_kwargs(ranker_config)
  return LinUCB(**ranker_config)


def init_pg_ranker(seed: int, active_features=None):
  """Create the `PolicyGradient` ranker from the policy-gradient flags.

  Args:
    seed: Seed handed to the agent.
    active_features: Forwarded unchanged as the agent's `active_features`.

  Returns:
    The constructed `PolicyGradient` instance.
  """
  # Learning-rate schedule built from the `pg_lr_ranker` flag.
  schedule = PolicyGradient.root_lr_schedule(FLAGS.pg_lr_ranker)
  ranker_config = {
    "n_features": FLAGS.n_features,
    "lr_schedule": schedule,
    "random_greedy": FLAGS.pg_random_greedy_ranker,
    "eps_scaling": FLAGS.pg_eps_scaling_ranker,
    "active_features": active_features,
    "exp_arms": None,
    "seed": seed,
  }
  logging.info(f"Initialize PG ranker with:")
  log_kwargs(ranker_config)
  return PolicyGradient(**ranker_config)


def init_sqcb_ranker(seed: int, active_features=None):
  """Create a `SquareCB` ranker that wraps a greedy ranker as its oracle.

  Args:
    seed: Seed for the greedy oracle; the `SquareCB` agent gets `seed + 1`.
    active_features: Forwarded to the greedy oracle.

  Returns:
    The constructed `SquareCB` instance.
  """
  # The oracle is built (and logged) before the wrapper's own settings.
  greedy_oracle = init_greedy_ranker(seed, active_features)
  ranker_config = {
    "n_arms": FLAGS.n_arms,
    "oracle": greedy_oracle,
    "sqcb_lr_scale": FLAGS.sqcb_lr_ranker,
    "seed": seed + 1,
  }
  logging.info(f"Initialize SqCB ranker with greedy agent and:")
  log_kwargs(ranker_config)
  return SquareCB(**ranker_config)


def init_greedy_nominators(pools, feats, seed: int):
  """Create one `GreedyRidge` nominator per (feature set, arm pool) pair.

  Args:
    pools: Per-nominator arm pools, passed on as `exp_arms`.
    feats: Per-nominator feature sets, passed on as `active_features`.
    seed: Base seed; the i-th nominator (counting from 1) gets `seed + i`.

  Returns:
    List of `GreedyRidge` agents, one per pair in `zip(feats, pools)`.
  """
  shared_config = {
    "n_features": FLAGS.n_nom_features,
    "regulariser": FLAGS.greedy_regularizer_nominators,
    "scale_regulariser_by_dim": FLAGS.scale_regularizer_by_dim,
  }
  logging.info(f"Initialize greedy nominators with:")
  log_kwargs(shared_config)
  nominators = []
  for offset, (af, ea) in enumerate(zip(feats, pools), start=1):
    nominators.append(GreedyRidge(active_features=af,
                                  exp_arms=ea,
                                  seed=seed + offset,
                                  **shared_config))
  return nominators


def init_ucb_nominators(pools, feats, seed: int):
  """Create one `LinUCB` nominator per (feature set, arm pool) pair.

  Args:
    pools: Per-nominator arm pools, passed on as `exp_arms`.
    feats: Per-nominator feature sets, passed on as `active_features`.
    seed: Base seed; the i-th nominator (counting from 1) gets `seed + i`.

  Returns:
    List of `LinUCB` agents, one per pair in `zip(feats, pools)`.
  """
  # `fixed_ucb_alpha_nominators` selects between the two upper-bound modes.
  if FLAGS.fixed_ucb_alpha_nominators:
    bound_mode = 'const'
  else:
    bound_mode = 'adapt'
  shared_config = {
    "n_features": FLAGS.n_nom_features,
    "regulariser": FLAGS.ucb_regularizer_nominators,
    "scale_regulariser_by_dim": FLAGS.scale_regularizer_by_dim,
    "upper_bound_fn": bound_mode,
    "alpha": FLAGS.ucb_alpha_nominators,
  }
  logging.info(f"Initialize UCB nominators with:")
  log_kwargs(shared_config)
  nominators = []
  for offset, (af, ea) in enumerate(zip(feats, pools), start=1):
    nominators.append(LinUCB(active_features=af,
                             exp_arms=ea,
                             seed=seed + offset,
                             **shared_config))
  return nominators


def init_pg_nominators(pools, feats, seed: int):
  """Create one `PolicyGradient` nominator per (feature set, arm pool) pair.

  Args:
    pools: Per-nominator arm pools, passed on as `exp_arms`.
    feats: Per-nominator feature sets, passed on as `active_features`.
    seed: Base seed; the i-th nominator (counting from 1) gets `seed + i`.

  Returns:
    List of `PolicyGradient` agents, one per pair in `zip(feats, pools)`.
  """
  # A single schedule object is shared by all nominators.
  schedule = PolicyGradient.root_lr_schedule(FLAGS.pg_lr_nominators)
  shared_config = {
    "n_features": FLAGS.n_nom_features,
    "lr_schedule": schedule,
    "random_greedy": FLAGS.pg_random_greedy_nominators,
    "eps_scaling": FLAGS.pg_eps_scaling_nominators,
  }
  logging.info(f"Initialize PG nominators with:")
  log_kwargs(shared_config)
  nominators = []
  for offset, (af, ea) in enumerate(zip(feats, pools), start=1):
    nominators.append(PolicyGradient(active_features=af,
                                     exp_arms=ea,
                                     seed=seed + offset,
                                     **shared_config))
  return nominators


def init_sqcb_nominators(pools, feats, seed: int):
  """Create one `SquareCB` nominator per arm pool, each with a greedy oracle.

  Args:
    pools: Per-nominator arm pools. An entry may be `None` (this is what
      `get_arm_and_feature_splits` produces when arms are not split), in
      which case the nominator is sized for all `FLAGS.n_arms` arms.
    feats: Per-nominator feature sets, forwarded to the greedy oracles.
    seed: Base seed; the i-th nominator (counting from 1) gets `seed + i`.

  Returns:
    List of `SquareCB` agents, one per pair in `zip(pools, oracles)`.
  """
  oracles = init_greedy_nominators(pools, feats, seed)
  kwargs = dict(
    sqcb_lr_scale=FLAGS.sqcb_lr_nominators,
  )
  logging.info("Initialize SqCB nominators with greedy oracle and:")
  log_kwargs(kwargs)
  nominators = []
  # NOTE(review): the i-th SquareCB wrapper receives the same seed (seed + i)
  # as its greedy oracle, unlike `init_sqcb_ranker`, which offsets the wrapper
  # seed by one. Left unchanged to keep existing runs reproducible.
  for offset, (ea, oracle) in enumerate(zip(pools, oracles), start=1):
    # An unsplit pool is represented by None; `len(None)` used to raise a
    # TypeError here. Fall back to the full arm count, as the ranker does.
    n_arms = FLAGS.n_arms if ea is None else len(ea)
    nominators.append(SquareCB(n_arms=n_arms,
                               oracle=oracle,
                               seed=seed + offset,
                               **kwargs))
  return nominators


def get_random_pools(n_arms, n_nominators, pool_seed):
  """Randomly distribute the arm indices `0..n_arms-1` over the nominators.

  The indices are shuffled with a `RandomState` seeded by `pool_seed` and cut
  into `n_nominators` consecutive chunks of `max(1, n_arms // n_nominators)`
  arms. Arms left over after that are handed out one each to the first pools.
  When there are fewer arms than nominators the trailing pools are empty.

  Args:
    n_arms: Total number of arms.
    n_nominators: Number of pools to create.
    pool_seed: Seed for the shuffle.

  Returns:
    List of `n_nominators` sorted integer arrays of arm indices.
  """
  shuffled = np.arange(n_arms)
  np.random.RandomState(pool_seed).shuffle(shuffled)

  chunk = max(1, n_arms // n_nominators)
  cuts = [k * chunk for k in range(n_nominators + 1)]
  pools = [shuffled[lo:hi] for lo, hi in zip(cuts, cuts[1:])]

  # Distribute the remainder, one extra arm per pool starting at the first.
  for pool_idx, arm in enumerate(shuffled[chunk * n_nominators:]):
    pools[pool_idx] = np.concatenate([pools[pool_idx], np.array([arm])])

  return [np.sort(pool) for pool in pools]


def get_random_features(
    n_features, n_nominators, n_nom_features, feature_seed, single_stage=False):
  """Choose which feature indices each nominator (or the single agent) sees.

  For the 'amazon' and 'wiki' bandit types the indices `0..n_features-1` are
  shuffled first, using a `RandomState` seeded by `feature_seed`.

  In single-stage mode the first `n_nom_features` (possibly shuffled) indices
  are returned as one array, without sorting. Otherwise every nominator gets
  a consecutive slice of `min(n_nom_features, n_features // n_nominators)`
  indices. If that is fewer than `n_nom_features`, each slice is topped up
  with indices drawn without replacement from the features outside the slice.
  Every per-nominator set is sorted.

  Args:
    n_features: Total number of features.
    n_nominators: Number of nominators to create feature sets for.
    n_nom_features: Requested number of features per nominator.
    feature_seed: Seed for shuffling and for drawing the top-up features.
    single_stage: If True, return a single index array instead of a list.

  Returns:
    One index array if `single_stage` is True, else a list with one sorted
    index array per nominator.
  """
  rng = np.random.RandomState(feature_seed)
  features = np.arange(n_features)
  needs_shuffle = FLAGS.bandit_type in ['amazon', 'wiki']
  if needs_shuffle:
    logging.info(f"Shuffle features (with given random seed) before split")
    rng.shuffle(features)

  if single_stage:
    return features[:n_nom_features]

  fpn = min(n_nom_features, n_features // n_nominators)
  logging.info(f"Ended up with {fpn} features per nominator")
  cuts = [k * fpn for k in range(n_nominators + 1)]
  feats = [features[lo:hi] for lo, hi in zip(cuts, cuts[1:])]

  missing = n_nom_features - fpn
  if missing > 0:
    logging.info("Fewer features per nominator than requested. Correcting...")
    topped_up = []
    for own in feats:
      # Only features the nominator does not already hold are candidates.
      candidates = np.setdiff1d(features, own)
      extra = rng.choice(candidates, size=missing, replace=False)
      topped_up.append(np.concatenate([own, extra]))
    feats = topped_up

  return [np.sort(fs) for fs in feats]


def get_arm_and_feature_splits(pool_seed, feature_seed):
  """Determine the arm pool and feature set of every nominator from FLAGS.

  Args:
    pool_seed: Seed used when arms are split (`FLAGS.split_arms`).
    feature_seed: Seed used when features are split (`FLAGS.split_features`).

  Returns:
    Tuple `(pools, feats)` of lists with one entry per nominator. A list
    holds only `None` entries when the corresponding split is disabled.
  """
  if not FLAGS.split_arms:
    logging.info(f"Not splitting arms...")
    pools = [None] * FLAGS.n_nominators
  else:
    logging.info(f"Splitting {FLAGS.n_arms} arms across nominators...")
    pools = get_random_pools(FLAGS.n_arms, FLAGS.n_nominators, pool_seed)
  logging.info(f"Pools of arms for nominators: {pools}")

  if not FLAGS.split_features:
    logging.info("Not splitting features across nominators")
    return pools, [None] * FLAGS.n_nominators

  logging.info(f"Splitting {FLAGS.n_features} features across "
               f"{FLAGS.n_nominators} nominators using "
               f"{FLAGS.n_nom_features} per nominator...")
  feats = get_random_features(FLAGS.n_features, FLAGS.n_nominators,
                              FLAGS.n_nom_features, feature_seed)
  logging.info(f"Feature sets: {feats}")
  return pools, feats


# =============================================================================
# MAIN
# =============================================================================

def main(_):
  """Run one contextual bandit experiment configured through FLAGS.

  In order: resolve the random seed (optionally from the
  `SLURM_ARRAY_TASK_ID` environment variable), create the output directory
  and dump all flag values to `flags.json`, build the bandit, the ranker and
  (for `two_stage`) the nominators, play the game for `n_rounds` rounds and
  store the results in `results.npz` (plus figures if `plot_verbosity > 0`).

  Args:
    _: Unused positional arguments handed over by `app.run`.

  Raises:
    ValueError: If single-stage batch mode is requested with
      `batch_update_freq == 0`.
    NotImplementedError: If a bandit or agent type is not handled.
  """
  # ---------------------------------------------------------------------------
  # Directory setup, save flags, set random seed
  # ---------------------------------------------------------------------------
  FLAGS.alsologtostderr = True
  output_name = FLAGS.output_name

  seed = FLAGS.seed
  if FLAGS.try_seed_from_slurm:
    logging.info("Trying to fetch seed from slurm environment variable...")
    slurm_seed = os.getenv("SLURM_ARRAY_TASK_ID")
    if slurm_seed is not None:
      logging.info(f"Found task id {slurm_seed}")
      seed = int(slurm_seed)
      output_name = f'{output_name}_{seed}'
      logging.info(f"Set output directory {output_name}")
      FLAGS.output_name = output_name
  # Write the effective seed back so that it ends up in flags.json.
  FLAGS.seed = seed
  logging.info(f"Set random seed {seed}...")
  np.random.seed(seed)

  # ----------------- Seed management for varying things independently
  bandit_seed = seed + 1
  ctxtsampler_seed = seed + 2
  ranker_seed = seed + 3
  if FLAGS.fixed_pool_seed is not None:
    pool_seed = FLAGS.fixed_pool_seed
  else:
    pool_seed = seed + 4
  if FLAGS.fixed_feature_seed is not None:
    feature_seed = FLAGS.fixed_feature_seed
  else:
    feature_seed = seed + 5
  # The nominator init functions increment the seed before using it, so the
  # first nominator effectively gets seed + 6.
  nom_seed = seed + 5

  if output_name == "":
    dir_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
  else:
    dir_name = output_name
  out_dir = os.path.join(os.path.abspath(FLAGS.output_dir), dir_name)
  logging.info(f"Save all output to {out_dir}...")
  # exist_ok avoids the check-then-create race when several jobs start at the
  # same time and share a (timestamp) directory name.
  os.makedirs(out_dir, exist_ok=True)

  FLAGS.log_dir = out_dir
  # Create extra logfile?
  # logging.get_absl_handler().use_absl_log_file(program_name="run")

  logging.info("Save FLAGS (arguments)...")
  with open(os.path.join(out_dir, 'flags.json'), 'w') as fp:
    json.dump(FLAGS.flag_values_dict(), fp, sort_keys=True, indent=2)

  # ---------------------------------------------------------------------------
  # Initialize context samplers
  # ---------------------------------------------------------------------------

  # TODO(nki): These context samplers always spit out d features. When are we
  #  subsampling for nominators. does this require additional scaling of sqrt(d)?
  if FLAGS.bandit_type in ['wiki', 'amazon']:
    logging.info(f"Skipping context sampler setup for {FLAGS.bandit_type}...")
    logging.info("\tIgnoring distribution type")
    logging.info("\tIgnoring n_contexts")
    logging.info("\tIgnoring context_scaling")
    logging.info("\tIgnoring true parameter")
    bandit_kwargs = dict(n_arms=FLAGS.n_arms,
                         n_features=FLAGS.n_features,
                         seed=bandit_seed)
  else:
    logging.info("Setup the context samplers with...")
    logging.info(f"\tn_features: {FLAGS.n_features}")
    logging.info(f"\tdistribution type: {FLAGS.dtype}")
    logging.info(f"\tn_arms: {FLAGS.n_arms}")
    logging.info(f"\tn_contexts: {FLAGS.n_contexts} (only used for mixture)")
    # A negative value stands for sqrt(d) (see the flag help). The former
    # `value or 'sqrt(d)'` never showed that, because -1.0 is truthy.
    if FLAGS.context_scaling < 0:
      scaling_desc = 'sqrt(d)'
    else:
      scaling_desc = FLAGS.context_scaling
    logging.info(f"\tscaling contexts by: {scaling_desc}")
    context_sampler = get_ctxt_sampler(
      FLAGS.n_features,
      FLAGS.dtype,
      FLAGS.n_arms,
      FLAGS.n_contexts,
      ctxt_scale=FLAGS.context_scaling,
      ctxt_std=FLAGS.mixture_std,
      seed=ctxtsampler_seed)
    logging.info(f"Sample true parameter from {FLAGS.n_features}-dim Gaussian...")
    # Drawn from the global numpy RNG seeded above, then scaled to unit norm.
    true_param = np.random.normal(size=FLAGS.n_features)
    true_param /= np.linalg.norm(true_param)
    bandit_kwargs = dict(
      n_arms=FLAGS.n_arms,
      context_sampler=context_sampler,
      true_param=true_param,
      seed=bandit_seed)

  # ---------------------------------------------------------------------------
  # Initialize bandit
  # ---------------------------------------------------------------------------

  logging.info(f"Setup {FLAGS.bandit_type} bandit with")
  logging.info(f"\tn_arms: {FLAGS.n_arms}")
  if FLAGS.bandit_type == 'gauss':
    logging.info(f"\treward_std: {FLAGS.reward_std}")
    bandit_kwargs.update(reward_std=FLAGS.reward_std)
    bandit = LinGaussBandit(**bandit_kwargs)
  elif FLAGS.bandit_type == 'bernoulli':
    bandit = LinBernBandit(**bandit_kwargs)
  elif FLAGS.bandit_type == 'categorical':
    bandit = LinCategBandit(**bandit_kwargs)
  elif FLAGS.bandit_type == 'poisson':
    bandit = LinPoissBandit(**bandit_kwargs)
  elif FLAGS.bandit_type == 'wiki':
    bandit = WikiBandit(**bandit_kwargs)
  elif FLAGS.bandit_type == 'amazon':
    bandit = AmazonBandit(**bandit_kwargs)
  else:
    raise NotImplementedError(f"Unknown bandit type {FLAGS.bandit_type}.")

  # ---------------------------------------------------------------------------
  # Single stage with fewer features than all?
  # ---------------------------------------------------------------------------
  if FLAGS.stages == 'single_stage' and FLAGS.n_nom_features < FLAGS.n_features:
    active_features = get_random_features(FLAGS.n_features, FLAGS.n_nominators,
                                          FLAGS.n_nom_features, feature_seed,
                                          single_stage=True)
  else:
    active_features = None

  # ---------------------------------------------------------------------------
  # Initialize ranker (needed for two_stage and single_stage)
  # ---------------------------------------------------------------------------
  if FLAGS.agent_ranker == 'greedy':
    agent = init_greedy_ranker(ranker_seed, active_features)
  elif FLAGS.agent_ranker == 'ucb':
    agent = init_ucb_ranker(ranker_seed, active_features)
  elif FLAGS.agent_ranker == 'pg':
    agent = init_pg_ranker(ranker_seed, active_features)
  elif FLAGS.agent_ranker == 'sqcb':
    agent = init_sqcb_ranker(ranker_seed, active_features)
  elif FLAGS.agent_ranker == 'random':
    agent = init_random_ranker(ranker_seed, active_features)
  else:
    raise NotImplementedError(f"Unknown agent {FLAGS.agent_ranker}")

  # Are we running single stage in batch mode?
  if FLAGS.stages == 'single_stage' and FLAGS.batch_mode:
    logging.info("Running single stage agent in batch mode...")
    if FLAGS.batch_update_freq < 0:
      # Clamp to 1 so that n_rounds < 10 does not produce a frequency of 0
      # (which would be a modulo by zero in the predicate below).
      batch_update_freq = max(1, FLAGS.n_rounds // 10)
    elif FLAGS.batch_update_freq == 0:
      raise ValueError(
        "batch_update_freq must be non-zero: use a positive update "
        "frequency, or a negative value for n_rounds // 10.")
    else:
      batch_update_freq = FLAGS.batch_update_freq
    logging.info(f"\tUpdate frequency: {batch_update_freq}")
    agent = DelayedAgent(agent, lambda t: (t - 1) % batch_update_freq == 0)
  elif FLAGS.batch_mode:
    logging.warning("batch_mode is only supported for single_stage; ignoring.")

  # ---------------------------------------------------------------------------
  # Initialize nominators (if two_stage)
  # ---------------------------------------------------------------------------
  if FLAGS.stages == 'two_stage':
    pools, feats = get_arm_and_feature_splits(pool_seed, feature_seed)
    if FLAGS.agent_nominators == 'greedy':
      nominators = init_greedy_nominators(pools, feats, nom_seed)
    elif FLAGS.agent_nominators == 'ucb':
      nominators = init_ucb_nominators(pools, feats, nom_seed)
    elif FLAGS.agent_nominators == 'pg':
      nominators = init_pg_nominators(pools, feats, nom_seed)
    elif FLAGS.agent_nominators == 'sqcb':
      nominators = init_sqcb_nominators(pools, feats, nom_seed)
    else:
      # Report the nominator flag; this message used to show agent_ranker.
      raise NotImplementedError(f"Unknown agent {FLAGS.agent_nominators}")

    logging.info("Combining ranker and nominators into two stage system...")
    agent = TwoStage(agent, nominators)

  # Goes through logging like every other status message (was a bare print).
  logging.info(f"is twostage: {isinstance(agent, TwoStage)}")

  # ---------------------------------------------------------------------------
  # Initialize and run game
  # ---------------------------------------------------------------------------
  logging.info("Setup the contextual bandit game...")
  game = ContextBanditGame(bandit, agent)
  game.reset()

  # ---------------------------------------------------------------------------
  # Collect return and regret results
  # ---------------------------------------------------------------------------
  logging.info(f"Run the game for {FLAGS.n_rounds} rounds...")
  results = game.play(FLAGS.n_rounds)
  for res, val in results.items():
    logging.info(f"{res} shape: {val.shape}")

  # ---------------------------------------------------------------------------
  # Collect agent and true parameters
  # ---------------------------------------------------------------------------
  # `true_param` only exists for the synthetic bandits (see sampler setup).
  if FLAGS.bandit_type not in ["amazon", "wiki"]:
    if FLAGS.stages == "two_stage":
      results['weights'] = agent.ranker.weights
    elif FLAGS.batch_mode:
      # In batch mode the ranker is wrapped in a DelayedAgent.
      results['weights'] = agent.agent.weights
    else:
      results['weights'] = agent.weights
    results['true_param'] = true_param

  # ---------------------------------------------------------------------------
  # Store aggregate results
  # ---------------------------------------------------------------------------
  logging.info("Store results...")
  result_path = os.path.join(out_dir, "results.npz")
  np.savez(result_path, **results)

  # ---------------------------------------------------------------------------
  # Plot aggregate results
  # ---------------------------------------------------------------------------
  if FLAGS.plot_verbosity > 0:
    logging.info("Plot aggregate results...")
    plotting.plot_all(results, os.path.join(out_dir, 'figures'))

  logging.info("DONE")

  
if __name__ == "__main__":
  # Script entry point: hand `main` to absl's `app.run`, which parses the
  # command-line flags before calling it.
  app.run(main)
