import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Select the Agg backend; every figure in this script is written to a file
# with savefig rather than shown in a window.
matplotlib.use('Agg')

# Seed NumPy's global RNG so the randomly sampled functions are repeatable.
np.random.seed(1)

# Default figure size and font size for the figures created below.
plt.rcParams.update({'figure.figsize': [12, 12], 'font.size': 18})

# Finite-sum objective, defined by listing all of its component functions.
# Each row holds the coefficients [a, b, c] of one quadratic
# f_i(x) = a*x^2 + b*x + c.
fs = np.array([
    [40, 0, 0],
    [20, 40, 0],
    [6, 24, 0],
    [-10, 20, 0],
    [-3, 12, 0],
    [-20, 20, 0],
])


def get_fbar(fs):
  """
    fs: collection of quadratics, each given by 3 values [a,b,c]
    return: the coefficient-wise mean over all quadratics in fs,
    i.e. the coefficients of the averaged function
  """
  lengths = [len(coeffs) for coeffs in fs]
  # every entry must carry exactly the three coefficients a, b, c
  assert all(n == 3 for n in lengths)
  averaged = np.mean(fs, axis=0)
  return averaged
      
# Print the coefficients [a, b, c] of the averaged function.
print(get_fbar(fs))


# consider all functions of the form ax^2 + bx + c

def grad(f, x):
  """
    Derivative of the quadratic a*x^2 + b*x + c.
    f: list of 3 values [a,b,c] that defines the quadratic
    x: point at which the derivative is evaluated
    return: 2*a*x + b
  """
  assert len(f) == 3
  # the constant term does not contribute to the derivative
  quadratic, linear, _constant = f[0], f[1], f[2]
  slope = 2*quadratic*x
  return slope + linear

def eval(f,x):
  """
    f: list of 3 values [a,b,c] that defines the quadratic
    x: point at which the quadratic is evaluated
    return: a*x^2 + b*x + c
    note: inside this module the name hides Python's built-in eval
  """
  quadratic, linear, constant = f[0], f[1], f[2]
  second_order = quadratic*x**2
  first_order = linear*x
  return second_order + first_order + constant

def plot(xs, fs):
  """
    Draw every quadratic in fs and their average on the current axes.
    xs: points at which the functions are evaluated
    fs: collection of quadratics, each given by 3 values [a,b,c]
    Quadratics with a >= 0 are drawn in orange ("convex"), the others in
    green ("concave"); the average is drawn as a dashed red line.
    note: the legend text of the average and the position of the x_* marker
    are hard-coded, they are not derived from fs.
  """
  for coeffs in fs:
    curve = np.vectorize(lambda x : eval(coeffs, x))
    # the sign of the leading coefficient decides colour and legend entry
    if coeffs[0] >= 0:
      colour, tag = 'orange', r'convex $f_i$'
    else:
      colour, tag = 'green', r'concave $f_i$'
    plt.plot(xs, curve(xs), color=colour, label=tag)

  # averaged function
  mean_coeffs = get_fbar(fs)
  mean_curve = np.vectorize(lambda x : eval(mean_coeffs, x))
  plt.plot(xs, mean_curve(xs), linestyle='dashed', color='red', label=r'$f(x)= 5.5 x^2 +19.33x$')

  # mark and annotate x_\ast at the origin
  plt.plot(0, 0, marker="*", markersize=15, color="black")
  plt.annotate(r"$x_\ast$", xy=(0,0), xycoords="data", xytext=(-0.2, 10),
            arrowprops=dict(facecolor='black', shrink=0.07, width=1),
            horizontalalignment='right', verticalalignment='top')

  # many curves share a label; keep only the first handle seen per label
  all_handles, all_labels = plt.gca().get_legend_handles_labels()
  first_handle = {}
  for handle, label in zip(all_handles, all_labels):
    first_handle.setdefault(label, handle)
  plt.legend(list(first_handle.values()), list(first_handle.keys()), loc=2)

  # coordinate axes through the origin
  plt.axhline(y=0, color='black')
  plt.axvline(x=0, color='black')

# First figure: all quadratics and their average on the interval [-0.5, 1].
plot(np.linspace(-.5,1, num=1000), fs)
# Raw string: "\s" and "\m" are not valid Python escape sequences, so the
# non-raw literal triggers a warning on recent Python versions. The resulting
# text is unchanged.
plt.title(r'Example with $\sigma^2_{\mathcal{X}}=0$')

plt.text(0.2, -10, r"$\mathcal{X} = [0,1]$", {'size': 20})
plt.savefig('non-convex-example.pdf')
# Close the finished figure instead of only clearing it: 'figure.figsize' is
# read when a figure is created, so the size set below only takes effect if
# the next plotting call has to create a new figure.
plt.close()
plt.rcParams['figure.figsize'] = [10, 10]

# Simulation size: steps per trajectory and number of independent runs.
num_iterations = 200
num_trajectories = 10000

def distance_sol(x, sol):
  """
    x: current point
    sol: reference solution
    return: half of the squared difference between x and sol
  """
  offset = x - sol
  return 0.5*offset**2

def sample(fs):
  """
    fs: numpy array with one function per row
    return: one row of fs, picked uniformly at random with numpy's
    global random generator
  """
  row_count = fs.shape[0]
  picked = np.random.randint(row_count, size=1)
  return fs[picked][0]

def project(x):
  """
    Projection of a scalar onto the interval [0, 1].
    return: 0 if x <= 0, 1 if x >= 1, otherwise x itself
  """
  if x <= 0:
    return 0
  if x >= 1:
    return 1
  return x

def update(x, f, step):
  """
    One projected gradient step on the quadratic f.
    x: current point
    f: list of 3 values [a,b,c] that defines the quadratic
    step: step size multiplying the derivative
    return: project(x - step * f'(x))
  """
  moved = x - step*grad(f, x)
  return project(moved)

def polyak_update(x, f, sol_val):
  """
    One projected step on the quadratic f whose step size is
    (f(x) - sol_val) / f'(x)^2.
    x: current point
    f: list of 3 values [a,b,c] that defines the quadratic
    sol_val: function value subtracted from f(x) in the step size
    return: x unchanged when the derivative at x is zero, otherwise the
    projection of the stepped point
  """
  slope = grad(f, x)
  # a zero derivative would make the step size undefined
  if slope == 0:
    return x
  gap = eval(f,x) - sol_val
  return project(x - gap*slope/slope**2)

# Run num_trajectories independent runs of projected stochastic gradient
# steps, each starting at x = 1 and lasting num_iterations steps.
trajectories = []  # iterates x_t of every run
distances = []     # 0.5*(x_t - 0)^2 of every run
# Twice the largest |a| over all quadratics, i.e. the largest |f_i''|;
# its inverse is the constant step size used below.
L_max = 2*np.max(np.abs(fs[:,0]))
print("L_max " + str(L_max))
for i in range(num_trajectories):
  x = 1
  xs = [x]
  distance = [distance_sol(x, 0)]
  for j in range(num_iterations):
    # NOTE(review): only the last three quadratics of fs are sampled here,
    # not the whole finite sum -- confirm this is intended.
    f = sample(fs[-3:])
    x = update(x, f, 1/(L_max))
    xs.append(x)
    distance.append(distance_sol(x, 0))
  trajectories.append(xs)
  # Draw the distance to the solution rather than the raw iterate: the axis
  # label, the title and the mean curve added afterwards all refer to
  # 0.5*(x_t - x_*)^2, so plotting xs here mixed two different quantities.
  plt.plot(distance, color='blue', alpha=0.01)
  distances.append(distance)

# Average the per-iteration distance over all runs and draw it on top of the
# individual runs.
mean = np.mean(distances, axis=0)
plt.plot(mean, label=r'mean trajectory', color='blue')
# 1.96 standard deviations per iteration; not used further in this part of
# the script.
err = 1.96*np.std(distances, axis=0)
plt.yscale('log')

# keep only the first handle seen for every legend label
handles, labels = plt.gca().get_legend_handles_labels()
first_handle = {}
for handle, label in zip(handles, labels):
  first_handle.setdefault(label, handle)
newLabels = list(first_handle.keys())
newHandles = list(first_handle.values())

plt.title("Euclidean Divergence to Solution")
plt.xlabel('Iterations')
plt.ylabel(r'$\frac{1}{2}||x_t-x_\ast||^2_2$')
plt.legend(newHandles, newLabels, loc=1)
plt.savefig('trajectories.pdf')