import copy
import models
import autograd
import scipy.spatial
import scipy.optimize
import autograd.numpy as np


#####
# The following two functions tell autograd how to differentiate through
#   this process:
#   1) Set the l_2 regularization parameter, lambda.
#   2) Refit the model.
# I feel these should be methods of the Model object, but I couldn't
#   figure out how to make autograd happy with that.
#####
@autograd.core.primitive # derivative is supplied by hand (fitWithLambda_vjp)
def fitWithLambda(lam, weights, model):
  '''
  Refit model with l_2 regularization strength lam and return the result
    of model.fit(weights).

  Side effects on model: model.theta is overwritten with a fresh standard
    normal draw of length model.training_data.D+1 before fitting, and
    model.regularization is replaced by the penalty lam/2 * ||theta||^2.
  '''
  # scipy.optimize.minimize hands lam over as a numpy array; we want a float
  lam = float(lam.squeeze()) if type(lam) is np.ndarray else lam

  model.theta = np.random.normal(size=model.training_data.D+1)

  def l2Penalty(theta):
    return lam / 2 * np.linalg.norm(theta)**2

  model.regularization = l2Penalty
  return model.fit(weights)

def fitWithLambda_vjp(ans, lam, weights, model):
  '''
  Manual version of the derivative d(fitWithLambda) / d(lam).
  Specifically, autograd needs this function to return a function that
    maps an upstream gradient g to the vector-Jacobian product at lam.

  Args:
    ans: not used here. (Per the autograd docs, this is the value the
      primitive returned, i.e. the output of fitWithLambda.)
    lam: l_2 regularization strength at which to differentiate.
    weights: forwarded unchanged to fitWithLambda for the refit.
    model: model object; it is mutated (see below).

  Returns:
    A function g -> inner(dTheta/dLam, g).

  The derivative comes from differentiating the stationarity condition
    grad(loss)(theta) + lam * theta = 0 with respect to lam, which gives
    H dTheta/dLam = -theta with H = (loss Hessian) + lam * I.

  ****NOTE*** This function always refits the model at lam and re-runs
    model.computeDerivs(), even if the caller has already done so; both
    are time-consuming. As a side effect model.theta is overwritten, and
    the call to fitWithLambda also overwrites model.regularization.
  '''
  # NOTE(review): fitWithLambda re-draws model.theta itself, so this draw
  #   looks redundant (it only advances the RNG state) -- confirm.
  model.theta = np.random.normal(size=model.training_data.D+1)
  model.theta = fitWithLambda(lam, weights, model)
  model.computeDerivs()
  X = model.training_data.X
  # Hessian of the penalty lam/2 * ||theta||^2.
  regHess = lam * np.eye(X.shape[1])
  # Presumably model.D2 holds one second-derivative value per datapoint, so
  #   X^T diag(D2) X is the Hessian of the unregularized loss -- TODO confirm.
  #H = X.T.dot(model.D2[:,np.newaxis]*X) + regHess ... this crashes np.solve
                                            # during 2nd deriv computation?!?!
  H = X.T @ (model.D2[:,np.newaxis]*X) + regHess # the @ form does not crash;
                                                 # reason unknown, keep as is
  dThetadLam = -np.linalg.solve(H, model.theta)
  return lambda g: np.inner(dThetadLam, g)
# Register the hand-written derivative above for fitWithLambda.
autograd.core.defvjp(fitWithLambda, fitWithLambda_vjp)

def gen_low_rank_X(N, D, ranks,
                   basis_vectors=None, lowRankNoise=0.0, rotate=True):
  '''
  Generate an N x D data matrix whose rows each lie in one of several
    low-dimensional subspaces of R^D (plus optional Gaussian noise).

  Args:
    N: number of datapoints (rows of X).
    D: ambient dimension (columns of X).
    ranks: a subspace size, or a list of subspace sizes; each datapoint is
      drawn from one of these subspaces, chosen uniformly at random. Sizes
      larger than D are clamped to D.
    basis_vectors: optional list with one (rank_i, D) array per entry of
      ranks, whose rows span subspace i. Generated here when None.
    lowRankNoise: standard deviation of i.i.d. Gaussian noise added to
      every entry of X.
    rotate: if True each subspace is spanned by the leading columns of a
      random orthogonal matrix; if False by the leading coordinate axes.

  Returns:
    (X, basis_vectors): the (N, D) data matrix and the list of bases used.
  '''
  if np.ndim(ranks) == 0: # make sure ranks is a list
    ranks = [ranks]

  # Generate random set of D-dimensional orthonormal vectors
  if basis_vectors is None:
    basis_vectors = []
    for rank in ranks:
      rank = min(rank, D)
      if rotate:
        # QR decomposition is much faster than using ortho_group
        #   U = scipy.stats.special_ortho_group.rvs(D)
        U = np.linalg.qr(np.random.normal(size=(D,D)))[0]
      else:
        U = np.eye(D)
      # Row r of the basis is U applied to the r-th unit vector, i.e. the
      #  r-th column of U.
      basis_vectors.append(np.array(U[:, :rank].T))

  # Generate X s.t. each X[n] lives in subspace spanned by
  #  basis_vectors[i] for some i.
  X = np.empty((N,D))
  for n in range(N):
    i = np.random.choice(np.arange(len(ranks)))
    basis = basis_vectors[i]
    # Size the coefficients from the basis itself rather than from ranks[i]:
    #  the two differ when a requested rank was clamped to D, and the
    #  product below would then fail with a shape mismatch.
    coeffs = np.random.normal(size=basis.shape[0])
    X[n] = basis.T.dot(coeffs)
  X += np.random.normal(scale=lowRankNoise, size=(N,D))

  return X, basis_vectors

def sigmoid(X, w, bias=0):
  '''
  Logistic function applied elementwise to the linear scores
    np.dot(X, w) + bias.
  '''
  scores = np.dot(X, w) + bias
  return 1./(1 + np.exp(-scores))

def hyperparameterCV(lam, model, mode='IJ'):
  '''
  Leave-one-out cross-validation loss of model at l_2 regularization
    strength lam (L2 regularization only, for now).

  Args:
    lam: regularization strength; a float, a numpy array holding a single
      value (as passed in by scipy.optimize.minimize), or an autograd box
      when this function is being differentiated.
    model: model object. It is mutated: model.regularization is set, and
      for mode == 'IJ' model.theta is refit on all the data.
    mode: 'exact' refits the model once per held-out datapoint; any other
      value is forwarded to model.retrainWeighted(). Only 'IJ' triggers the
      full-data fit and model.computeDThetaDWeights() beforehand.
      NOTE(review): other approximate modes get no preliminary fit here --
      confirm that callers fit the model first if they need one.

  Returns:
    The average over datapoints n of the unregularized objective on
      datapoint n alone, evaluated at parameters trained with n left out.
  '''
  if isinstance(lam, np.ndarray):
    # scipy.optimize.minimize passes in a shape-(1,) numpy array. Squeeze
    #  before converting (as fitWithLambda does): float() on an array with
    #  ndim > 0 is deprecated in recent NumPy.
    lam = float(lam.squeeze())
  N = model.training_data.X.shape[0]
  model.regularization = lambda theta: lam/2 * np.linalg.norm(theta)**2

  # Fit and compute needed stuff for ACV
  if mode == 'IJ':
    # Q: What's with copying and re-setting the regularization?
    # Answer: we're interested in autodiffing this function w.r.t. lam.
    #  Note the dependence on lam is only in model.regularization.
    #  Autograd will pass lam in as an ArrayBox object when differentiating.
    #  fitWithLambda is a hand-differentiated primitive, so inside it lam is
    #  a plain (non-ArrayBox) value, and fitWithLambda overwrites
    #  model.regularization using that plain value. Without restoring the
    #  copy afterwards, the ArrayBox (which encodes all the derivative
    #  information) would be clobbered and the derivative would be garbage.
    regularization_cpy = model.regularization
    model.theta = fitWithLambda(lam, np.ones(N), model)
    model.regularization = regularization_cpy
    model.computeDThetaDWeights()

  loss = 0.0
  for n in range(N):
    # Training weights: every datapoint except n.
    weights = np.ones(N)
    weights[n] = 0.0
    if mode == 'exact':
      # Same copy / restore of the regularization as for the full fit above.
      regularization_cpy = model.regularization
      thetanAppx = fitWithLambda(lam, weights, model)
      model.regularization = regularization_cpy
    else:
      # Doesn't involve call to model.fit() so ignore regularization copying
      thetanAppx = model.retrainWeighted(weights, mode=mode)
    # Evaluation weights: datapoint n only.
    weights = np.zeros(N)
    weights[n] = 1.0
    loss += model.evalObjective(thetanAppx, weights, regularized=False)

  return loss / N

def chooseLamGridSeach(lams, model, mode):
  '''
  Grid search: evaluate hyperparameterCV at every candidate in the array
    lams and return (best lam, its CV loss).
  '''
  nCandidates = lams.shape[0]
  cvLosses = np.empty(nCandidates)
  for idx in range(nCandidates):
    cvLosses[idx] = hyperparameterCV(lams[idx], model, mode=mode)
  best = np.argmin(cvLosses)
  return lams[best], cvLosses[best]

def chooseLamNelderMead(model, mode, initLam=0.75):
  '''
  Choose lam by minimizing hyperparameterCV with the derivative-free
    Nelder-Mead simplex method, started from initLam.

  Returns:
    (lam, loss): the minimizer as returned by scipy (res.x, a numpy array)
      and the CV loss re-evaluated at that point.
  '''
  obj = lambda lam: hyperparameterCV(lam, model, mode)
  # The method must be named explicitly: with no method (and no bounds or
  #  constraints) scipy.optimize.minimize falls back to BFGS with
  #  finite-difference gradients, not Nelder-Mead.
  res = scipy.optimize.minimize(obj, x0=initLam, method='Nelder-Mead')
  return res.x, hyperparameterCV(res.x, model, mode=mode)

def chooseLamBFGS(model, mode, initLam=0.75):
  '''
  Choose lam by gradient-based minimization of hyperparameterCV, with the
    gradient obtained from autograd. No method is passed to
    scipy.optimize.minimize, so scipy picks its default.

  Returns:
    (lam, loss): the minimizer as returned by scipy (a numpy array) and the
      CV loss re-evaluated at that point.
  '''
  def cvLoss(lam):
    return hyperparameterCV(lam, model, mode)

  cvLossGrad = autograd.grad(cvLoss)
  result = scipy.optimize.minimize(cvLoss, jac=cvLossGrad, x0=initLam)
  lamOpt = result.x
  return lamOpt, hyperparameterCV(lamOpt, model, mode=mode)

def chooseLamBayesOpt(model, mode, maxLam,
                      nIter=10,
                      nInit=2):
  '''
  Choose lam in the interval (0, maxLam) with the bayes_opt package, using
    nInit initial points followed by nIter optimization iterations.

  Returns:
    (lam, loss): the best lam found and its CV loss.
  '''
  from bayes_opt import BayesianOptimization

  # The optimizer maximizes, so hand it the negated CV loss.
  def negCvLoss(lam):
    return -hyperparameterCV(lam, model, mode=mode)

  searchBounds = {'lam':(0,maxLam)}
  optimizer = BayesianOptimization(f=negCvLoss,
                                   pbounds=searchBounds,
                                   random_state=1234,
                                   verbose=0)
  optimizer.maximize(init_points=nInit, n_iter=nIter)
  bestLam = optimizer.max['params']['lam']
  bestLoss = -optimizer.max['target'] # undo the negation
  return bestLam, bestLoss
                                   
  


