import pickle
import itertools
from tqdm import tqdm
import autograd.numpy as np
import matplotlib.pyplot as plt
from pathos.multiprocessing import ProcessingPool

import sys
sys.path.append('../')
import expUtils

def oneTrial(noise, U, S, theta, lams):
    """Build one noisy, mean-centred response and run the U-only trial on it.

    The design matrix is ``U @ diag(S)``; the response is that design times
    ``theta`` plus ``noise``, shifted to zero mean.  ``U``, ``S`` and ``lams``
    are forwarded unchanged.

    Returns:
        Whatever ``expUtils.runTrialUOnlyNoisy(U, S, Y, lams)`` returns for
        the centred response ``Y``.
    """
    design = U @ np.diag(S)
    response = design @ theta + noise
    centred = response - response.mean()
    return expUtils.runTrialUOnlyNoisy(U, S, centred, lams)

def runForOneN(N, D, thetaStar, S, lams, nUs, nNoises, noiseMultipliers):
    """Run the full noise-multiplier sweep for one value of ``N``.

    For each entry of ``noiseMultipliers``, draw ``nUs`` fresh matrices with
    ``expUtils.getUFast(N, D)``.  For each matrix, draw ``nNoises`` noise
    vectors.  Each noise vector is a Gaussian draw multiplied by
    ``I - U U^T`` (the projector onto the orthogonal complement of span(U),
    assuming U has orthonormal columns -- presumably what getUFast returns;
    confirm).  It is then rescaled to Euclidean norm ``noiseMultipliers[mm]``.
    Every (U, noise) pair is passed to ``oneTrial``, and element ``[1]`` of
    its result is stored.

    Args:
        N: number of rows of each U matrix.
        D: number of columns of each U matrix.
        thetaStar: coefficient vector forwarded to ``oneTrial``.
        S: per-column scales forwarded to ``oneTrial``.
        lams: forwarded unchanged to ``oneTrial`` (presumably a
            regularisation grid).
        nUs: number of U matrices drawn per multiplier.
        nNoises: number of noise vectors drawn per U.
        noiseMultipliers: 1-D array of target noise norms.

    Returns:
        Array of shape ``(len(noiseMultipliers), nUs, nNoises)``.
    """
    nMultipliers = noiseMultipliers.shape[0]
    nonQVXLocal = np.empty((nMultipliers, nUs, nNoises))
    # Every slot is overwritten on each multiplier iteration, so allocate once.
    noises = np.empty((nUs, nNoises, N))
    Us = np.empty((nUs, N, D))
    for mm in tqdm(range(nMultipliers)):
        multiplier = noiseMultipliers[mm]
        # Draw all randomness for this multiplier before running any trial,
        # keeping the same RNG consumption order as before.
        for uu in range(nUs):
            Us[uu] = expUtils.getUFast(N, D)
            U = Us[uu]
            for ii in range(nNoises):
                # The scale is irrelevant: E is renormalised below.
                E = np.random.normal(size=N, scale=5.5)
                # (I - U U^T) E computed as E - U (U^T E): O(N*D) instead of
                # building an N x N identity and projector on every draw.
                E = E - U @ (U.T @ E)
                noises[uu, ii] = E / np.linalg.norm(E) * multiplier
        for uu in range(nUs):
            for ii in range(nNoises):
                nonQVXLocal[mm, uu, ii] = oneTrial(noises[uu, ii],
                                                   Us[uu],
                                                   S,
                                                   thetaStar,
                                                   lams)[1]
    return nonQVXLocal


def _runForOneNSeeded(seed, *args):
    """Reseed NumPy's global RNG, then return ``runForOneN(*args)``.

    Pool workers created by fork inherit the parent's RNG state.  Without
    reseeding, every worker starts from the same random stream, so the draws
    for different ``N`` values are correlated and depend on which worker
    runs which task.  A per-task seed gives each ``N`` its own reproducible
    stream.
    """
    np.random.seed(seed)
    return runForOneN(*args)


# Guard so the pool can safely re-import this module under the "spawn"
# start method (macOS/Windows) without re-running the experiment.
if __name__ == '__main__':
    import os

    np.random.seed(12346530)
    D = 5
    thetaStar = np.random.normal(size=D)
    thetaStar /= np.linalg.norm(thetaStar)
    lams = np.logspace(-8, 6, 600)
    S = np.ones(D)

    nUs = 4000
    nNoises = 150
    nNs = 10
    Ns = np.linspace(10, 40, nNs).astype(np.int32)
    nCores = 8
    nMultipliers = 60
    noiseMultipliers = np.linspace(0, 1.8, nMultipliers)
    # One independent seed per N, derived from the master seed so the whole
    # experiment stays reproducible regardless of task scheduling.
    taskSeeds = np.random.randint(0, 2**31 - 1, size=nNs)

    perN = ProcessingPool(nCores).map(_runForOneNSeeded,
                                      taskSeeds,
                                      Ns,
                                      itertools.repeat(D),
                                      itertools.repeat(thetaStar),
                                      itertools.repeat(S),
                                      itertools.repeat(lams),
                                      itertools.repeat(nUs),
                                      itertools.repeat(nNoises),
                                      itertools.repeat(noiseMultipliers))

    # Shape: (nNs, nMultipliers, nUs, nNoises).
    nonQVX = np.array(perN)

    # Create the output directory up front so a missing folder cannot
    # discard the results after the long computation.
    os.makedirs('output', exist_ok=True)

    fullResults = {'Ns': Ns,
                   'nUs': nUs,
                   'nNoises': nNoises,
                   'noiseMultipliers': noiseMultipliers,
                   'nonQVX': nonQVX,
                   'thetaStar': thetaStar}
    fullPath = 'output/epsilonScalingExperiment_fewerNoiseGPComp.pkl'
    with open(fullPath, 'wb') as f:
        pickle.dump(fullResults, f)

    meanResults = {'Ns': Ns,
                   'nUs': nUs,
                   'nNoises': nNoises,
                   'noiseMultipliers': noiseMultipliers,
                   'nonQVXMean': nonQVX.mean((2, 3)),
                   'thetaStar': thetaStar}
    meanPath = 'output/epsilonScalingExperiment_fewerNoiseGPComp_meanQVX.pkl'
    with open(meanPath, 'wb') as f:
        pickle.dump(meanResults, f)

    # Interactive inspection only after results are safely on disk.
    from IPython import embed
    np.set_printoptions(linewidth=80)
    embed()
