'''
Bunch of utility functions for running experiments. Basically  blocks of
  code put here to avoid cluttering scripts
'''
import matplotlib.pyplot as plt
import autograd.numpy as np
import scipy.optimize
import autograd
import time
import os

def compareMethods(model, methods):
  '''
  Runs several lambda-selection methods on the same model and times each one.

  methods is a list of tuples (function, label). The function takes in the
    model and returns (optimalLambda, CV(optimalLambda)); the intent is to
    use things like lambda model: utils.chooseLamBFGS(model, 'IJ') with a
    label such as 'BFGS_IJ'.

  Returns a dict keyed by label. Each value is a dict with keys 'time'
    (elapsed time.process_time() around the call), 'optLam' and 'CVVal'.
    Each label is printed right before its method is run.
  '''
  results = {}
  for method in methods:
    chooseLam, label = method[0], method[1]
    print(label)
    tic = time.process_time()
    optLam, CVVal = chooseLam(model)
    elapsed = time.process_time() - tic
    results[label] = {'time': elapsed, 'optLam': optLam, 'CVVal': CVVal}
  return results
  

'''
Utilities for running LOOCV for l2 regularized linear regression
'''
def findNonQVXInds(losses, returnLocations=False):
  '''
  Finds which loss curves have a local maximum along the lambda grid (which
    violates quasiconvexity).

  losses has the lambda grid along axis 0: either a 2-D array with one curve
    per column, or a single 1-D curve.

  Returns the indices (along axis 1) of the violating curves; for a single
    1-D curve this is array([0]) if it violates and an empty array otherwise.
    If returnLocations is True, returns (indices, isCrossing, isMaximum),
    where isCrossing / isMaximum are boolean arrays marking where the sign
    of the finite difference changes at all / changes from positive to
    negative. Both have two fewer entries along axis 0 than losses.
  '''
  # Sign of the finite-difference "derivative" between adjacent lambdas, and
  #  how that sign changes from one step to the next.
  signChanges = np.diff(np.sign(np.diff(losses, axis=0)), axis=0)

  # Derivative going from positive to negative is a maximum, which violates
  #  quasiconvexity (negative to positive is a minimum, which is fine).
  isMaximum = signChanges < 0
  violatesQVX = isMaximum.sum(axis=0) > 0

  # For a single 1-D curve violatesQVX is a scalar; np.where needs at least
  #  one dimension to return indices.
  violatingInds = np.where(np.atleast_1d(violatesQVX))[0]
  if not returnLocations:
    return violatingInds

  # A critical point has occurred wherever the derivative changes sign.
  isCrossing = signChanges != 0
  return violatingInds, isCrossing, isMaximum


def findNonQVXInds_old(losses, returnLocations=False):
  '''
  Older quasiconvexity check. Flags a curve (lambda grid along axis 0) if it
    has more than one critical point or at least one local maximum.
    (It's not actually necessary to check both conditions.)

  Returns the indices of the flagged curves, or (indices, isCrossing,
    isMaximum) if returnLocations is True.
  '''
  # Change in the sign of the finite difference between adjacent lambdas.
  signChanges = np.diff(np.sign(np.diff(losses, axis=0)), axis=0)

  # Any change of sign marks a critical point; count them per curve.
  isCrossing = signChanges != 0
  numCrossings = isCrossing.sum(axis=0)

  # Incr. to decr. is a maximum (violates quasiconvexity); decr. to incr. is
  #  just a minimum.
  isMaximum = signChanges < 0
  numMaxima = isMaximum.sum(axis=0)

  violatesQVX = (numCrossings > 1) | (numMaxima > 0)
  flagged = np.where(violatesQVX)[0]
  if returnLocations:
    return flagged, isCrossing, isMaximum
  return flagged

def findNonQVXIndsStable(loss, returnLocations=False):
  '''
  More stable version of findNonQVXInds. The plain version numerically fails
  when the loss has a numerically flat region: there, the differences between
  adjacent loss values are essentially numerical noise and may oscillate
  between positive and negative, giving spurious non-quasiconvexity.

  NOTE: this function only takes in a single loss (so a 1-D numpy array),
    whereas findNonQVXInds can process multiple at once.

  Returns a tuple (result, success). success is False (and result is an
    empty array) if the first numerically flat difference occurs at an index
    <= half the length of loss. Otherwise success is True and result is
    whatever findNonQVXInds returns on the loss, truncated just before the
    first flat difference if there is one.
  '''
  # A difference counts as "flat" if it is below 1000x the floating point
  #  resolution of the loss's dtype (this threshold is a bit arbitrary).
  tol = np.finfo(loss.dtype).resolution*1000
  flatInds = np.where(np.abs(np.diff(loss)) < tol)[0]

  # No flat region at all: check the whole curve.
  if flatInds.shape[0] == 0:
    return (findNonQVXInds(loss, returnLocations=returnLocations), True)

  firstFlat = flatInds[0]
  if firstFlat <= loss.shape[0]/2: # failed
    return (np.array([]), False)

  # Cut off at the first flat spot, then proceed exactly as in findNonQVXInds.
  return (findNonQVXInds(loss[:firstFlat], returnLocations=returnLocations),
          True)

def findAllMinima(loss):
  '''
  Finds all local minima of a single loss curve on the lambda grid.

  Only takes in a single loss (i.e. a one-d array of length number-of-lambdas).

  Returns (minInds, loss[minInds]). Interior minima come first, in increasing
    order; the last index is appended if the loss is still decreasing at the
    end of the grid, and then index 0 is appended if the loss is increasing
    at the start of the grid. With fewer than two grid points there are no
    differences to look at, so every available index (none, or just 0) is
    returned.
  '''
  numLams = loss.shape[0]
  if numLams < 2:
    # diffs would be empty and diffs[-1] / diffs[0] below would fail.
    minInds = np.arange(numLams)
    return minInds, loss[minInds]

  diffs = np.sign(np.diff(loss, axis=0))

  # Edge case: if one of the differences is exactly zero, its sign will be
  #  zero (this can only happen if L(lambda_t) = L(lambda_{t+1}), which
  #  is really unlikely in practice!). If we let the difference stay as zero,
  #  then none of the diffs of diffs will be 2.0, and we'll find no minimum
  #  below. Setting signs of zero to be signs of 1 breaks the tie.
  diffs[diffs == 0] = 1.0

  # If the difference of the signs of differences is 2, then that means you
  #  were decreasing (sign = -1) and then started increasing (sign = 1).
  #  Because of how np.diff works, this is offset by 1 from the actual
  #  location of the minimum.
  minInds = np.where(np.diff(diffs, axis=0) == 2)[0] + 1

  # The function is still decreasing, even at lambda_max. Presumably this
  #  means that lambda = infinity is a local minimum. Add on lambda_max as
  #  one of the minimizers.
  if diffs[-1] == -1:
    minInds = np.append(minInds, numLams-1)
  # Same thing if the function is increasing at lambda_min.
  if diffs[0] == 1:
    minInds = np.append(minInds, 0)

  return minInds, loss[minInds]


def runTrialUOnly(U, S, thetaStar, lams, individualLosses=False):
  '''
  Computes LOOCV(lambda) *much* faster given your covariate matrix X
    is just X = U @ np.diag(S) (i.e. the right singular vectors are I_D) and
    there is no noise in Y (i.e. Y = X @ thetaStar).

  U is (N, D); S and thetaStar are length D; lams is the 1-D grid of L
    regularization values.

  Returns (losses, nonQVX). losses is the (N, L) array of per-datapoint
    leave-one-out losses if individualLosses is True, otherwise their mean
    over datapoints (length L). nonQVX is True if findNonQVXIndsStable flags
    the mean loss as non-quasiconvex; it is also False when that check
    reports failure.
  '''
  # shrink[d,l] = S[d]^2 / (S[d]^2 + lams[l])
  shrink = S[:,None]**2 / (S[:,None]**2 + lams[None,:])

  # Compute Qs[n,l] = X[n]^T (X^T X + lams[l]*I_D)^{-1} X[n]
  Qs = np.einsum('nd,ndl->nl', U, U[:,:,None] * shrink)

  # Compute trainLoss[n,l] = (x_n^T thetaHat_lams[l] - y_n)**2
  b = U[:,:,None] * (S[:,None] * lams[None,:] / (S[:,None]**2 + lams[None,:]))
  trainLoss = np.einsum('d,ndl->nl', thetaStar, b)**2

  loocvLoss = (trainLoss / (1-Qs)**2)
  meanLoss = loocvLoss.mean(0)

  # findNonQVXIndsStable returns (indices, success), not just the indices.
  nonQVXInds, _ = findNonQVXIndsStable(meanLoss)
  nonQVX = nonQVXInds.shape[0] > 0

  if individualLosses:
    return (loocvLoss, nonQVX)
  else:
    return (meanLoss, nonQVX)


def runTrialUOnlyNoisy(U, S, Y, lams, allLosses=False):
  '''
  Same computation as runTrialUOnly (X = U @ np.diag(S)), but takes the
    responses Y directly instead of a noiseless thetaStar.

  Returns (losses, nonQVX, success). losses is the (N, L) array of
    per-datapoint leave-one-out losses if allLosses is True, otherwise the
    mean over datapoints. nonQVX and success come from running
    findNonQVXIndsStable on the mean loss.
  '''
  # shrinkage[d,l] = S[d]^2 / (S[d]^2 + lams[l])
  shrinkage = S[:,None]**2 / (S[:,None]**2 + lams[None,:])

  # leverages[n,l] = X[n]^T (X^T X + lams[l]*I_D)^{-1} X[n]
  leverages = np.einsum('nd,ndl->nl', U, U[:,:,None] * shrinkage)

  # residSq[n,l] = (x_n^T thetaHat_lams[l] - y_n)**2
  fitted = np.einsum('dl,nd->nl', shrinkage * (U.T @ Y)[:,None], U)
  residSq = (fitted - Y[:,None])**2

  looLosses = residSq / (1-leverages)**2
  avgLoss = looLosses.mean(0)

  nonQVXInds, success = findNonQVXIndsStable(avgLoss)
  isNonQVX = nonQVXInds.shape[0] > 0
  if allLosses:
    return (looLosses, isNonQVX, success)
  return (avgLoss, isNonQVX, success)


def runTrialKFold(U, S, Y, lams, holdOuts):
  '''
  K-fold CV losses for l2 regularized regression with X = U @ np.diag(S).

  holdOuts is a sequence of index arrays, one per fold (e.g. the output of
    getKFoldInds). For each fold and each lambda, the held-out rows are
    removed, thetaHat is formed in closed form, and the mean squared error
    on the held-out rows is stored.

  Returns (losses, nonQVX, success), where losses[k,l] is the held-out loss
    of fold k at lams[l] and nonQVX / success come from running
    findNonQVXIndsStable on the loss averaged over folds.
  '''
  numFolds = len(holdOuts)
  numLams = lams.shape[0]
  losses = np.empty((numFolds, numLams))

  diagS = np.diag(S)
  fullRhs = diagS @ U.T @ Y  # X^T Y on the full dataset
  for kk, heldOut in enumerate(holdOuts):
    foldSize = heldOut.shape[0]
    Uk = U[heldOut]
    Yk = Y[heldOut]
    # X^T Y with the held-out rows' contribution removed.
    trainRhs = fullRhs - diagS @ Uk.T @ Yk

    for ll, lam in enumerate(lams):
      denom = S**2 + lam
      scaledS = np.diag(S / denom)
      # mtx has the form of a Woodbury-style update of diag(1/(S^2 + lam))
      #  that removes the held-out rows; that reading presumably relies on
      #  U having orthonormal columns -- not checked here.
      middleMtx = np.linalg.inv(np.eye(foldSize)
                                - Uk @ np.diag(S**2/denom) @ Uk.T)
      mtx = (np.diag(1/denom)
             + scaledS @ Uk.T @ middleMtx @ Uk @ scaledS)
      thetaHatk = mtx @ trainRhs
      losses[kk,ll] = ((Uk @ diagS @ thetaHatk - Yk)**2).mean()

  nonQVXInds, success = findNonQVXIndsStable(losses.mean(0))
  return losses, nonQVXInds.shape[0] > 0, success

'''
def runTrialKFold(U, S, Y, lams, holdOuts):
  L = lams.shape[0]
  losses = np.empty((len(holdOuts), L))

  vals, vecs = np.linalg.eigh(Uk @ Uk.T)
  nonZeroVals = np.abs(vals) > 1e-12
  vecs = vecs[:,nonZeroVals]
  vals = vals[nonZeroVals]
  middleMtx2 = np.eye(K) - vecs @ vecs.T
  middleMtx2 += vecs @ np.diag(1 / (1 + S**2 / (S**2 + lams[ll]))) @ vecs.T

  lhs = np.diag(S) @ U.T @ Y
  for ll in range(L):
    for kk in range(len(holdOuts)):
      K = holdOuts[kk].shape[0]
      Uk = U[holdOuts[kk]]
      Yk = Y[holdOuts[kk]]
      vec = lhs - np.diag(S) @ Uk.T @ Yk
      middleMtx = np.linalg.inv(np.eye(K) - Uk @ np.diag(S**2/(S**2 + lams[ll])) @ Uk.T)
      mtx = (np.diag(1/(S**2 + lams[ll]))
             + np.diag(S / (S**2 + lams[ll])) @ Uk.T @ middleMtx @ Uk @ np.diag(S / (S**2 + lams[ll])))
      thetaHatk = mtx @ vec
      losses[kk,ll] = ((Uk @ np.diag(S) @ thetaHatk - Yk)**2).mean()
  nonQVXInds, success = findNonQVXIndsStable(losses.mean(0))
  return losses, nonQVXInds.shape[0] > 0, success
'''


def getKFoldInds(N, K):
  '''
  Splits the indices 0, ..., N-1 into K contiguous folds.

  The first K-1 folds each get floor(N/K) indices; the last fold gets
    everything that is left over. Returns a list of K index arrays.
  '''
  foldSize = int(np.floor(N/K))
  folds = [np.arange(kk*foldSize, min((kk+1)*foldSize, N))
           for kk in range(K-1)]
  # Last fold starts where the equal-sized folds stopped and runs to the end.
  folds.append(np.arange(foldSize*len(folds), N))
  return folds

  
def runTrial(trial, N, Xfull, Yfull, lams,
             seedBase=12345,
             normalizeData=False,
             returnLosses=False):
  '''
  Runs exact LOOCV for l2 regularized regression on a random size-N subset
    of (Xfull, Yfull) and checks the resulting losses for quasiconvexity.

  The subset is drawn without replacement after seeding numpy's global RNG
    with seedBase+trial. If normalizeData is True, the columns of the subset
    are scaled to unit variance and then centered.

  Returns (number of datapoints whose individual LOO loss is flagged by
    findNonQVXInds, whether the LOO loss summed over datapoints is flagged,
    singular values of the subset X), plus the (len(lams), N) array of
    individual LOO losses if returnLosses is True.
  '''
  # Take random subset of full dataset
  np.random.seed(seedBase+trial)
  subsetInds = np.random.choice(Xfull.shape[0], size=N, replace=False)
  X = Xfull[subsetInds].copy()
  Y = Yfull[subsetInds]
  if normalizeData:
    X /= np.sqrt(np.var(X, axis=0))
    X -= X.mean(axis=0)
  S = np.linalg.svd(X, compute_uv=False)

  D = X.shape[1]
  gram = X.T @ X
  XtY = X.T @ Y
  lossesLOO = np.empty((lams.shape[0],N))
  for ll, lam in enumerate(lams):
    regGram = gram + lam*np.eye(D)
    thetaHat = np.linalg.solve(regGram, XtY)
    # Qs[n] = X[n]^T (X^T X + lam*I_D)^{-1} X[n]
    Qs = np.einsum('nd,nd->n', np.linalg.solve(regGram, X.T).T, X)
    Xtheta = X @ thetaHat
    # Leave-one-out predictions obtained from the full fit and Qs, without
    #  refitting N times.
    predsLOO = Xtheta + Qs * (Xtheta - Y) / (1-Qs)
    lossesLOO[ll] = (Y - predsLOO)**2

  numNonQVX = findNonQVXInds(lossesLOO).shape[0] # non-qvx individual losses
  overallNonQVX = findNonQVXInds(lossesLOO.sum(axis=1)).shape[0] > 0

  if returnLosses:
    return (numNonQVX, overallNonQVX, S, lossesLOO)
  return (numNonQVX, overallNonQVX, S)


def runTrialFreshX(trial,
                   N,
                   D,
                   S,
                   lams,
                   theta,
                   exactSingularValues=True,
                   saveTag='',
                   seedBase=12345,):
  '''
  Builds a synthetic N x D covariate matrix X from the singular values S,
    sets Y = X @ theta (no noise), runs exact LOOCV for l2 regularized
    regression over lams, and checks the losses for quasiconvexity.

  numpy's global RNG is seeded with seedBase+trial. Cached factors are read
    from 'Uoutput/U-<trial>-<saveTag>.txt' and 'Uoutput/V-<trial>-<saveTag>.txt'
    when both exist and pass the checks below; otherwise X is regenerated:
      - exactSingularValues=True: X = createX(N, S); the factors of its SVD
        are written to the two cache files.
      - exactSingularValues=False: singular vectors of a Gaussian matrix are
        combined with S, then the columns are rescaled to unit variance and
        centered. Nothing is cached on this path.

  Returns (number of datapoints whose individual LOO loss is flagged by
    findNonQVXInds, whether the LOO loss summed over datapoints is flagged,
    Strial), where Strial are the singular values of the X actually used.
  '''
  np.random.seed(seedBase+trial)
  U = None
  V = None
  # Try the on-disk cache first (this is attempted regardless of
  #  exactSingularValues).
  if ( os.path.exists('Uoutput/U-%d-%s.txt' % (trial, saveTag)) and
       os.path.exists('Uoutput/V-%d-%s.txt' % (trial, saveTag)) ):
    U = np.loadtxt('Uoutput/U-%d-%s.txt' % (trial, saveTag))
    V = np.loadtxt('Uoutput/V-%d-%s.txt' % (trial, saveTag))
    # Discard the cache if it was written for different N or D.
    if U.shape != (N,D) or V.shape != (D,D):
      U = None
      V = None
    else:
      X = U @ np.diag(S) @ V
      # Also discard it if the rebuilt X does not have (approximately)
      #  zero-mean, unit-variance columns.
      if ( (not np.allclose(np.mean(X, axis=0), 0, rtol=0, atol=1e-4)) or
           (not np.allclose(np.var(X, axis=0), 1, rtol=0, atol=1e-4))  or
           X.shape != (N,D)):

        U = None
        V = None
  # No usable cache: generate X from scratch.
  if (U is None) or (V is None):
    if exactSingularValues:
      X = createX(N, S)
      U, Strial, V = np.linalg.svd(X, full_matrices=False)
      # NOTE(review): the Uoutput directory is not created here; presumably
      #  it must already exist for these writes to succeed.
      np.savetxt('Uoutput/U-%d-%s.txt' % (trial, saveTag), U)
      np.savetxt('Uoutput/V-%d-%s.txt' % (trial, saveTag), V)
    elif not exactSingularValues:
      X = np.random.normal(size=(N,D))
      U, _, V = np.linalg.svd(X, full_matrices=False)
      X = U @ np.diag(S) @ V
      # Rescaling / centering the columns changes the singular values, so
      #  they are no longer exactly S on this path.
      X /= np.sqrt(np.var(X, axis=0))
      X -= X.mean(axis=0)
  #X /= np.sqrt(np.var(X, axis=0))
  #X -= np.mean(X, axis=0)
  
  # Recompute the SVD of whatever X ended up being; Strial is what gets
  #  returned.
  U, Strial, V = np.linalg.svd(X, full_matrices=False)
  if exactSingularValues:
    # Sanity check only: a mismatch with the requested S is printed, not
    #  raised. NOTE(review): the bare except also hides any other error
    #  raised while evaluating the check.
    try:
      assert(np.abs(S-Strial).sum() <= 1e-4*D)
    except:
      print('bad', np.abs(S-Strial).sum())
      
  Y = X @ theta #+ np.random.normal(size=N, scale=0.01)
  #Y -= Y.mean()
  # NOTE(review): losses is allocated but never used below.
  losses = np.empty((lams.shape[0],N))
  lossesLOO = np.empty((lams.shape[0],N))
  D = X.shape[1]
  for ll, lam in enumerate(lams):
    thetaHat = np.linalg.solve(X.T @ X + lam*np.eye(D), X.T @ Y)
    # Qs[n] = X[n]^T (X^T X + lam*I_D)^{-1} X[n]
    Qs = np.einsum('nd,nd->n',
                   np.linalg.solve(X.T @ X + lam*np.eye(D), X.T).T,
                   X)
    Xtheta = X @ thetaHat
    # Leave-one-out predictions obtained from the full fit and Qs, without
    #  refitting N times.
    predsLOO = Xtheta + Qs * (Xtheta - Y) / (1-Qs)
    lossesLOO[ll] = (Y - predsLOO)**2
  nonQVX_n = findNonQVXInds(lossesLOO)
  return (nonQVX_n.shape[0], # num. of non-qvx individual losses
          findNonQVXInds(lossesLOO.sum(axis=1)).shape[0] > 0, # nonQVX overall?
          Strial)



def objV(vprev, S, v, N):
  '''
  Penalty objective for picking one vector v at a time. It is zero exactly
    when v is orthogonal to every column of vprev, satisfies
    || np.diag(S) @ v ||_2^2 = N, and has unit norm.
  '''
  # v is orthogonal to prev. v's
  orthogPenalty = ((vprev.T @ v)**2).sum()
  # ensures || np.diag(S) @ v ||_2^2 = N
  scalePenalty = ((S**2 * v**2).sum() - N)**2
  # v is unit norm
  unitNormPenalty = (np.linalg.norm(v)**2 - 1)**2
  return orthogPenalty + scalePenalty + unitNormPenalty

def objVSimultaneous(V, S, N, D, normXtheta=None, theta=None):
  '''
  Penalty objective for picking all of V at once. V is passed in flattened
    and reshaped to (D,D) here. The objective penalizes
      - columns of np.diag(S) @ V whose squared norm differs from N,
      - V.T @ V differing from the identity,
      - and, only if both normXtheta and theta are given,
        S**2 @ (V @ theta)**2 differing from normXtheta.
  '''
  Vmtx = np.reshape(V, (D,D))

  colSqNorms = ((np.diag(S) @ Vmtx)**2).sum(axis=0)
  total = ((colSqNorms - N)**2).sum()
  total = total + ((Vmtx.T @ Vmtx - np.eye(D))**2).sum()

  if (normXtheta is None) or (theta is None):
    return total
  return total + (S**2 @ (Vmtx @ theta)**2 - normXtheta)**2

def getNextv(vprev, S, N):
  '''
  Finds a vector v that drives objV(vprev, S, v, N) to (numerically) zero,
    using scipy.optimize.minimize with an autograd gradient.

  There may be local optima, so the optimization is restarted from a fresh
    random init until the objective value is at most 1e-9 in absolute value.
    Returns the minimizer from the last run.
  '''
  def evalObj(v):
    return objV(vprev, S, v, N)
  gradObj = autograd.grad(evalObj)

  while True:
    # Random init with its components along the columns of vprev subtracted
    #  off, then scaled to unit norm.
    start = np.random.normal(size=S.shape[0])
    start -= ((vprev.T @ start) * vprev).sum(axis=1)
    start /= np.linalg.norm(start)

    ret = scipy.optimize.minimize(evalObj, x0=start, jac=gradObj)
    if not (np.abs(ret.fun) > 1e-9):
      return ret.x

def createX(N, S, zeroMean=True, normXTheta=None, theta=None):
  '''
  Creates an NxD matrix X = U @ np.diag(S) @ V s.t. its singular values are S
    and its columns have unit variance.

  V is found by minimizing objVSimultaneous (restarting from random
    orthogonal inits until the objective is at most 1e-8, printing the
    objective after every attempt, then running one more pass from that
    solution). U starts as the Q factor of a Gaussian matrix and, if
    zeroMean is True, is refined by minimizing objU. normXTheta and theta
    are passed through to objVSimultaneous.

  Requires (S**2).sum() == N * D (asserted).
  '''
  D = S.shape[0]
  assert np.allclose((S**2).sum(), N*D)

  lbfgsOpts = dict(method='L-BFGS-B', tol=1e-12)

  def objectiveV(flatV):
    return objVSimultaneous(flatV, S, N, D, normXTheta, theta)
  gradV = autograd.jacobian(objectiveV)

  # Random restarts until the penalty is (numerically) zero.
  while True:
    Vinit = np.linalg.qr(np.random.normal(size=(D,D)))[0]
    fitV = scipy.optimize.minimize(objectiveV,
                                   x0=Vinit.ravel(),
                                   jac=gradV,
                                   **lbfgsOpts)
    print(fitV.fun)
    if not (fitV.fun > 1e-8):
      break

  # One more pass starting from the accepted solution.
  fitV = scipy.optimize.minimize(objectiveV,
                                 x0=fitV.x,
                                 jac=gradV,
                                 **lbfgsOpts)
  V = fitV.x.reshape(D,D)

  U = np.linalg.qr(np.random.normal(size=(N,D)))[0]
  if zeroMean:
    def objectiveU(flatU):
      return objU(V, S, N, flatU)
    fitU = scipy.optimize.minimize(objectiveU,
                                   x0=U.ravel(),
                                   jac=autograd.jacobian(objectiveU),
                                   **lbfgsOpts)
    U = fitU.x.reshape(N,D)

  return U @ np.diag(S) @ V


def getUFast(N,D):
  '''
  Draws a random N x min(N,D) Gaussian matrix and runs Gram-Schmidt on its
    columns: each column has its component along the all-ones direction
    removed, then its components along the previously processed columns,
    and is then scaled to unit norm.
  '''
  unitOnes = np.ones(N) / np.sqrt(N)
  numCols = min(N,D)
  basis = np.random.normal(size=(numCols,N)).T
  for d in range(numCols):
    col = basis[:,d]  # a view, so the in-place updates below modify basis
    col -= np.inner(unitOnes, col) * unitOnes
    for prev in range(d):
      col -= np.inner(basis[:,prev], col)*basis[:,prev]
    col /= np.linalg.norm(col)
  return basis

def objU(V, S, N, U):
  '''
  Penalty objective for U (passed in flattened, reshaped to (N,D) here).
    Penalizes nonzero column sums of X = U @ np.diag(S) @ V and U.T @ U
    differing from the identity.
  '''
  D = S.shape[0]
  Umtx = np.reshape(U, (N,D))
  colSums = (Umtx @ np.diag(S) @ V).sum(axis=0)
  zeroMeanPenalty = (colSums**2).sum()
  orthoPenalty = ((Umtx.T @ Umtx - np.eye(D))**2).sum()
  return zeroMeanPenalty + orthoPenalty

def pickThetaGivenNorm(S, norm):
  '''
  Picks a unit-norm theta s.t. np.linalg.norm(np.diag(S) @ theta)**2 == norm

  This is done by minimizing a sum of two squared penalties (unit norm and
    the target value of ||np.diag(S) @ theta||^2) with scipy.optimize.minimize
    and an autograd gradient. The optimization is restarted from a fresh
    random unit-norm init until the objective value is at most minVal.
    Returns the minimizer from the last run.
  '''
  D = S.shape[0]
  obj = lambda theta: ( ((theta**2).sum() - 1)**2 +
                        (((S*theta)**2).sum() - norm)**2 )
  grad = autograd.grad(obj)

  # Acceptance threshold: 1e-8, loosened when S[-1]**2 exceeds norm.
  #  (presumably S is sorted in decreasing order, so S[-1] is the smallest
  #   singular value -- TODO confirm)
  minVal = np.maximum(1e-8, (S[-1]**2 - norm) + 1e-8)

  while True:
    # Draw a new init for every attempt: the optimizer is deterministic, so
    #  retrying from the same init would just repeat the same failed run
    #  forever.
    init = np.random.normal(size=D)
    init /= np.linalg.norm(init)
    ret = scipy.optimize.minimize(obj,
                                  jac=grad,
                                  x0=init)
    if not (ret.fun > minVal):
      return ret.x


def runTrialRenormalizePern(X, Y, lams):
  '''
  Leave-one-out losses where the training covariates are re-normalized for
    every held-out point: for each n, row n is removed, the remaining rows
    are scaled to unit column variance and centered, and the l2 regularized
    fit on them is evaluated on the original (X[n], Y[n]).

  Returns an array of shape (N, len(lams)) with the squared errors.
  '''
  numPoints, numDims = X.shape
  lossesLoocv = np.empty((numPoints, lams.shape[0]))
  for n in range(numPoints):
    # Training set with datapoint n removed.
    Ytrain = np.concatenate([Y[:n], Y[n+1:]])
    Xtrain = np.concatenate((X[:n], X[n+1:]), axis=0)
    Xtrain /= np.sqrt(np.var(Xtrain, axis=0))
    Xtrain -= Xtrain.mean(axis=0)

    Un, Sn, VnT = np.linalg.svd(Xtrain, full_matrices=False)
    Vn = VnT.T
    for ll, lam in enumerate(lams):
      # Fit via the SVD, and cross-check against a direct solve of the
      #  regularized normal equations.
      thetaSvd = Vn @ np.diag(Sn / (Sn**2 + lam)) @ Un.T @ Ytrain
      thetaDirect = np.linalg.solve(Xtrain.T @ Xtrain + lam*np.eye(numDims),
                                    Xtrain.T @ Ytrain)
      assert(np.allclose(thetaSvd - thetaDirect, 0))
      lossesLoocv[n,ll] = (X[n] @ thetaSvd - Y[n])**2
  return lossesLoocv
      
      
def plotLocalMinQuality(ratio, phis, alphas, title, fs=20):
  '''
  Draws ratio as an image on the current matplotlib axes, with phis along
    the x-axis and one row per entry of alphas along the y-axis, plus a
    colorbar.

  ratio: 2-D array, or 3-D array that is averaged over its last axis first.
  phis: array; its min and max set the horizontal extent of the image.
  alphas: 1-D array; used for the y tick labels (roughly 5 ticks).
  title: plot title.
  fs: base font size; tick labels and the colorbar use slightly smaller
    sizes derived from it.
  '''
  if ratio.ndim == 3:
    ratio = ratio.mean(2)
  nAlphas = alphas.shape[0]
  # Keep the step at least 1: with fewer than 5 alphas the integer division
  #  gives 0, and a slice step of zero raises a ValueError.
  tickEvery = max(1, nAlphas // 5)
  cax = plt.gca().imshow(ratio, 
                         aspect='auto',
                         extent=[phis.min(), phis.max(), nAlphas, 0])

  plt.xticks([0, np.pi, 2*np.pi], labels=['0', r'$\pi$', r'$2\pi$'], fontsize=fs-1)
  plt.yticks(np.arange(nAlphas)[::tickEvery],
             labels=[f'{alpha:0.1e}' for alpha in alphas[::tickEvery]],
             fontsize=fs-1)
  plt.title(title, fontsize=fs)
  plt.xlabel(r'$\theta^*$', fontsize=fs)
  plt.ylabel(r'Min singular value', fontsize=fs)
  cbar = plt.colorbar(cax)
  cbar.ax.tick_params(labelsize=fs-5)

  plt.tight_layout()

  
  
  
