"""Utility functions for Nested Bandit analysis."""
import os
from typing import Text
from hashlib import md5
import json
import urllib.request
import zipfile

import numpy as np
import scipy
from scipy.special import softmax
from scipy.optimize import minimize

import parameter_sweep as ps


# Root directory for datasets, resolved relative to this module's location.
# `fetch_zip` extracts archives here and `open_zipped` reads files from it.
DATA_DIR = os.path.join(os.path.dirname(__file__), '../data')


def get_min_gap(means: np.ndarray) -> np.ndarray:
  """Return the gap between the largest mean and the next distinct value.

  Entries tied with the maximum are excluded before the runner-up is looked
  up, so the gap is measured against the largest value strictly below the
  maximum. If no such value exists (e.g. all entries are equal) the inner
  `np.max` is called on an empty selection and raises `ValueError`.

  :param means: Array of mean values; boolean-mask indexing is used, so a
      numpy array is required.
  :return: Scalar difference `max(means) - max(means[means < max(means)])`.
  """
  best = np.max(means)
  below_best = means[np.less(means, best)]
  return best - np.max(below_best)


def fit_pg_supervised(
        X, y, X_test=None, y_test=None, init_param=None, regulariser=1e-5,
        method='L-BFGS-B', tol=1e-4, maxiter=500, verbose=False):
  """Fit cost-sensitive logistic regression on batch data.

  The policy is `softmax(X @ param)` over the arm axis. The minimised
  objective is the negated mean (over points) of the policy-weighted outputs
  `sum_arm policy * y`, plus the ridge term
  `0.5 * regulariser * sum(param**2) / param.size`. The analytic gradient is
  handed to `scipy.optimize.minimize` together with the loss.

  :param X: Training inputs. Dimensions (point, arm, feature).
  :param y: Training outputs. Dimensions (point, arm).
  :param X_test: Test inputs (optional). Dimensions as for `X`.
  :param y_test: Test outputs (optional). Dimensions as for `y`.
  :param init_param: Initial parameter guess. Must match last dim of `X`.
      Defaults to a zero vector.
  :param regulariser: Ridge regression penalty.
  :param method: Optimisation method. See scipy.optimize.minimize.
  :param tol: Tolerance for the optimiser. See scipy.optimize.minimize.
  :param maxiter: Maximum number of optimisation iterations.
  :param verbose: Print auxiliary info. Also enables the optimiser's own
      display and, when test data is given, a test accuracy/loss report.
  :return: Optimised parameter, and auxiliary optimiser outputs.
  """
  if init_param is None:
    init_param = np.zeros(X.shape[-1])

  def objective(theta):
    """Returns the regularised loss and its gradient at `theta`."""
    policy = softmax(X @ theta, axis=-1)
    value = - np.mean(np.sum(policy * y, axis=-1))
    penalty = 0.5 * regulariser * np.sum(theta**2) / theta.size

    # Policy-weighted average feature vector of every point.
    avg_feat = np.einsum('xij,xi->xj', X, policy, optimize='greedy')
    # Softmax derivative: each arm's features centred on that average.
    centred = X - np.expand_dims(avg_feat, -2)
    per_point = np.einsum(
      'xij,xi,xi->xj', centred, policy, y, optimize='greedy')
    grad = - np.mean(per_point, 0) + regulariser * theta / theta.size

    return value + penalty, grad

  solver_options = {'maxiter': maxiter, 'disp': verbose}
  optim_out = minimize(
    objective, init_param, method=method, jac=True, tol=tol,
    options=solver_options)
  param = optim_out.x

  has_test_data = X_test is not None and y_test is not None
  if verbose and has_test_data:
    test_probs = softmax(X_test @ param, axis=-1)
    # Accuracy: does the most likely arm coincide with the best-output arm?
    hits = np.argmax(test_probs, axis=-1) == np.argmax(y_test, axis=-1)
    test_acc = np.mean(hits)
    test_loss = - np.mean(np.sum(test_probs * y_test, axis=-1))
    print(f'test_acc = {test_acc}, test_loss = {test_loss}')

  return param, optim_out


def get_ctxt_sampler(
        d, dtype, n_arms, n_contexts, ctxt_scale=-1., ctxt_std=0.1, seed=None):
  """Build a sampler of per-arm context vectors (uniform or mixture-based).

  The returned callable `ctxt_sampler(n=1)` yields an array of shape
  `(n, n_arms, d)`, or `(n_arms, d)` when `n == 1`, multiplied by
  `ctxt_scale`. All randomness comes from one `RandomState` seeded with
  `seed`, shared by every call of the returned sampler.

  * `dtype == 'unif'`: for `d == 1` values are uniform on [-1, 1); otherwise
    standard normal draws are normalised to unit Euclidean norm.
  * `dtype == 'mix'`: `n_contexts` component means per arm are drawn once up
    front. Each sample picks one component index (the same index for every
    arm). For `d > 1`, Gaussian noise with standard deviation `ctxt_std` is
    added and the result is normalised to unit norm; for `d == 1` the
    component means are returned as they are.

  :param d: Context dimension.
  :param dtype: Sampler family, `'unif'` or `'mix'`.
  :param n_arms: Number of arms.
  :param n_contexts: Number of mixture components (used by `'mix'` only).
  :param ctxt_scale: Output multiplier; a negative value selects `sqrt(d)`.
  :param ctxt_std: Noise scale for the `'mix'` sampler when `d > 1`.
  :param seed: Seed passed to `np.random.RandomState`.
  :return: The sampling function.
  :raises NotImplementedError: If `dtype` is not one of the two families.
  """
  rng = np.random.RandomState(seed)
  if ctxt_scale < 0:
    ctxt_scale = np.sqrt(d)

  def unit_norm(vecs):
    """Scales every vector along the last axis to unit Euclidean norm."""
    return vecs / np.linalg.norm(vecs, axis=-1, keepdims=True)

  def finish(batch, n):
    """Drops the batch axis for a single draw and applies the scale."""
    if n == 1:
      batch = batch.squeeze(0)
    return ctxt_scale * batch

  if dtype == 'unif':

    def ctxt_sampler(n=1):
      shape = (n, n_arms, d)
      if d == 1:
        draws = 2 * rng.uniform(size=shape) - 1
      else:
        draws = unit_norm(rng.normal(size=shape))
      return finish(draws, n)

    return ctxt_sampler

  if dtype == 'mix':
    means_shape = (n_arms, n_contexts, d)
    if d == 1:
      ctxt_means = 2 * rng.uniform(size=means_shape) - 1
    else:
      ctxt_means = rng.normal(size=means_shape)

    def ctxt_sampler(n=1):
      component = rng.randint(0, n_contexts, size=n)
      # Indexing yields (arm, draw, feature); put the draw axis first.
      draws = np.moveaxis(ctxt_means[:, component], 1, 0)
      if d > 1:
        noise = rng.normal(size=(n, n_arms, d))  # noise perturbation
        draws = unit_norm(draws + ctxt_std * noise)
      return finish(draws, n)

    return ctxt_sampler

  raise NotImplementedError(dtype)


# -----------------------------------------------------------------------------
# Input output and result management helpers
# -----------------------------------------------------------------------------
def get_output_name(value_dict: ps.ArgDict) -> Text:
  """Get the name of the output directory.

  Each entry is rendered as `<key>_<value>` and the pieces are joined with
  `-` in the dictionary's iteration order. An empty dictionary gives an
  empty string.

  Args:
    value_dict: Mapping of argument names to values.

  Returns:
    The joined name, e.g. `{'a': 1, 'b': 'x'}` gives `'a_1-b_x'`.
  """
  # Join once instead of growing a string with `+=` and slicing off the
  # leading separator afterwards.
  return "-".join(f"{k}_{v}" for k, v in value_dict.items())


def get_run_id(value_dict: ps.ArgDict) -> Text:
  """Get hash of value dictionary for ID purposes.

  The dictionary is serialised to JSON with sorted keys, so the ID does not
  depend on insertion order, then encoded as UTF-8 and hashed with MD5.

  Args:
    value_dict: JSON-serialisable mapping of argument names to values.

  Returns:
    Hexadecimal MD5 digest of the serialised dictionary.
  """
  canonical = json.dumps(value_dict, sort_keys=True)
  digest = md5(canonical.encode('utf-8'))
  return digest.hexdigest()


def get_curve_with_errors(dat,
                          df,
                          run_id: Text,
                          mid_func=np.mean,
                          errs='se',
                          quantiles=(0.25, 0.75)):
  """Compute statistics over cumulative sums of multiple runs for given setting.

  For each of 'regret' and 'return', the entry `dat[f'{run_id}_{key}']` is
  cumulatively summed along axis 1 and summarised across axis 0 (presumably
  a (run, step) layout -- confirm against the code that writes `dat`).

  If that entry is missing (`KeyError`), precomputed `..._mean` and `..._std`
  entries of `dat` are used instead, with the number of seeds read from `df`.
  In that case `mid_func`, `errs` and `quantiles` are ignored and the band is
  always `mean -/+ 1.96 * std / sqrt(n)`.

  Args:
    dat: Numpy data file (or other mapping) containing all runs.
    df: Dataframe containing summaries for runs. Only consulted in the
        fallback case, where columns `id` and `n_seeds_<key>` are read.
    run_id: The key to a given run setting
    mid_func: Which function to compute for the main curve (usually `np.mean` or
        `np.median`). Called with `axis=0`.
    errs: Can be either 'se', 'std', or 'quantile' for how to compute errors.
        'se' gives `mid -/+ 1.96 * std / sqrt(number of runs)`, 'std' gives
        `mid -/+ std`.
    quantiles: If `errs` is `quantile` then this tuple specifies lower and
        upper quantiles to show (lower, upper).

  Returns:
    Dictionary with two keys: 'regret' and 'return'. When the full curves are
    available the value is a 4-tuple of np.ndarray (mid, lower, upper, std),
    where std is all zeros for `errs='quantile'`. In the fallback case the
    value is the 3-tuple (mid, lower, upper).

  Raises:
    NotImplementedError: If full curves are available and `errs` is not one
        of the supported methods.
  """
  z_value = 1.96
  results = {}
  for key in ('regret', 'return'):
    try:
      runs = dat[f'{run_id}_{key}']
    except KeyError:
      # No per-run curves stored; fall back to the precomputed summaries.
      centre = dat[f'{run_id}_{key}_mean']
      spread = dat[f'{run_id}_{key}_std']
      matching_rows = df[df.id == run_id]
      n_seeds = matching_rows[f'n_seeds_{key}'].values
      half_width = z_value * spread / np.sqrt(n_seeds)
      results[key] = (centre, centre - half_width, centre + half_width)
    else:
      totals = np.cumsum(runs, axis=1)
      centre = mid_func(totals, axis=0)
      if errs == 'quantile':
        spread = np.zeros_like(centre)
        lower, upper = np.quantile(totals, quantiles, axis=0)
      elif errs == 'std':
        spread = np.std(totals, axis=0)
        lower, upper = centre - spread, centre + spread
      elif errs == 'se':
        spread = np.std(totals, axis=0)
        half_width = z_value * spread / np.sqrt(totals.shape[0])
        lower, upper = centre - half_width, centre + half_width
      else:
        raise NotImplementedError(f'Unknown error method {errs}.')
      results[key] = (centre, lower, upper, spread)
  return results


def read_bandit_dataset(name: str):
  """Load the feature and rating matrices of a named bandit dataset.

  Files are obtained through `open_zipped`, which fetches and unpacks the
  dataset archive first when it is not yet present locally.

  :param name: Either `'wiki10-31k'` or `'amazoncat-13k-bert'`.
  :return: Tuple `(features, ratings)`. For `'wiki10-31k'` the features are
      a sparse matrix converted to a dense array; for `'amazoncat-13k-bert'`
      they are the first array stored in the `.npz` archive. Ratings are a
      scipy sparse matrix in CSR format in both cases.
  :raises ValueError: If `name` is not a recognised dataset.
  """
  # Resolve the dataset-specific settings first; unknown names fail before
  # any file or network access happens.
  if name == 'wiki10-31k':
    zipped_dir_name = 'wiki10-31k'
    data_url = 'https://kkrauth.s3-us-west-2.amazonaws.com/wiki10-31k.zip'
    sparse_features = True
  elif name == 'amazoncat-13k-bert':
    zipped_dir_name = 'amazoncat-13k-bert'
    data_url = 'https://kkrauth.s3-us-west-2.amazonaws.com/amazoncat-13k-bert.zip'
    sparse_features = False
  else:
    raise ValueError('Dataset name not recognized.')

  with open_zipped(zipped_dir_name=zipped_dir_name,
                   data_name='features.npz',
                   data_url=data_url,
                   mode='rb') as feature_file:
    if sparse_features:
      features = scipy.sparse.load_npz(feature_file).toarray()
    else:
      # Read the array while the file is still open.
      archive = np.load(feature_file)
      features = archive[archive.files[0]]

  with open_zipped(zipped_dir_name=zipped_dir_name,
                   data_name='ratings.npz',
                   data_url=data_url,
                   mode='rb') as ratings_file:
    ratings = scipy.sparse.load_npz(ratings_file).tocsr()

  return features, ratings


def open_zipped(zipped_dir_name: str, data_name: str, data_url: str, mode: str):
  """Open a file from a dataset directory, fetching the dataset if needed.

  :param zipped_dir_name: Name of the dataset directory under `DATA_DIR`.
  :param data_name: File name inside that directory.
  :param data_url: URL of the zip archive, handed to `fetch_zip`.
  :param mode: Mode string passed to the builtin `open`.
  :return: The open file object; the caller is responsible for closing it
      (e.g. by using it in a `with` statement).
  """
  target_path = os.path.join(DATA_DIR, zipped_dir_name, data_name)
  # Make sure the dataset has been downloaded and unpacked before opening.
  fetch_zip(zipped_dir_name, data_url)
  handle = open(target_path, mode)
  return handle


def fetch_zip(zipped_dir_name: str, data_url: str):
  """Download and unpack a zipped dataset into `DATA_DIR` if it is missing.

  Nothing happens when `DATA_DIR/<zipped_dir_name>` already exists as a
  directory. Otherwise the archive is downloaded next to that location as
  `<zipped_dir_name>.zip`, extracted into `DATA_DIR`, and then deleted.

  :param zipped_dir_name: Name of the directory the archive is expected to
      unpack to (also used for the temporary `.zip` file name).
  :param data_url: URL of the zip archive.
  """
  data_dir = os.path.join(DATA_DIR, zipped_dir_name)
  if os.path.isdir(data_dir):
    return

  os.makedirs(DATA_DIR, exist_ok=True)
  download_location = '{}.zip'.format(data_dir)
  try:
    urllib.request.urlretrieve(data_url,
                               filename=download_location)
    with zipfile.ZipFile(download_location, 'r') as zip_ref:
      zip_ref.extractall(DATA_DIR)
  finally:
    # Remove the archive even when the download or extraction fails, so a
    # partial or corrupt zip is not left behind. The existence check keeps a
    # failure before the file was created from being masked by os.remove.
    if os.path.exists(download_location):
      os.remove(download_location)


def whiten(features):
  """Standardise each column to zero mean and unit standard deviation.

  Statistics are taken along axis 0. Standard deviations are floored at 1e-7
  so that constant columns do not cause a division by zero (they map to 0).

  :param features: Array-like of values; rows are reduced over.
  :return: Array of the same shape, centred and scaled per column.
  """
  column_scale = np.maximum(np.std(features, axis=0), 1e-7)
  centred = features - np.mean(features, axis=0)
  return centred / column_scale
