import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from tqdm import tqdm
from matplotlib.colors import LogNorm

# Seed NumPy's global RNG so the random initial vectors drawn in this script are reproducible.
np.random.seed(seed=1)

# Dimension of the vectors a and b.
d = 5

def lmbda(a,b):
	"""Return |a|^2 + |b|^2, the sum of the squared Euclidean norms of a and b."""
	norm_a_sq = np.linalg.norm(a) ** 2
	norm_b_sq = np.linalg.norm(b) ** 2
	return norm_a_sq + norm_b_sq

def Q(a,b):
	"""Return sum_i |a_i^2 - b_i^2|, the L1 size of the elementwise imbalance between a and b."""
	imbalance = np.power(a, 2) - np.power(b, 2)
	return np.absolute(imbalance).sum()

def GD(a0,b0,eta,T, tol=1e-3, window=4, max_abs_eps=100):
	"""Run gradient descent on the scalar factorisation loss (a . b - 1)**2 / 2.

	Each step applies the simultaneous update
		a <- a - eta * eps * b,    b <- b - eta * eps * a,    with eps = a . b - 1,
	and records the new iterates together with eps, lmbda(a, b) and Q(a, b).

	Args:
		a0, b0: initial vectors (copied; the inputs are never modified).
		eta: step size.
		T: maximum number of steps.
		tol: the run counts as converged once the last `window` residuals are all
			below `tol` in absolute value.
		window: number of trailing residuals inspected for convergence.
		max_abs_eps: the run is abandoned (treated as diverged) once |eps| exceeds this.

	Returns:
		(As, Bs, Eps, lmbdas, Qs, converged): per-iterate lists whose index 0 is the
		initial point, plus the convergence flag evaluated on the final window.
	"""
	def settled():
		# np.max propagates NaN; builtin max's result depends on where a NaN sits,
		# which could report a NaN run as converged.
		return np.max(np.absolute(Eps[-window:])) < tol

	As, Bs = [a0.copy()], [b0.copy()]
	Eps = [np.dot(a0, b0) - 1]
	lmbdas, Qs = [lmbda(a0, b0)], [Q(a0, b0)]
	for t in range(T):
		a, b, eps = As[t], Bs[t], Eps[t]
		# Both updates use the *old* a and b (simultaneous GD step).
		a_next, b_next = a - eta*eps*b, b - eta*eps*a
		As.append(a_next)
		Bs.append(b_next)
		eps_next = np.dot(a_next, b_next) - 1.
		Eps.append(eps_next)
		lmbdas.append(lmbda(a_next, b_next))
		Qs.append(Q(a_next, b_next))
		# Check the residual just computed (not the previous one). Treat NaN/inf as
		# divergence: NaN compares False with `>`, so otherwise it would run all T steps.
		if not np.isfinite(eps_next) or np.absolute(eps_next) > max_abs_eps or settled():
			break
	return As, Bs, Eps, lmbdas, Qs, settled()

def gen_ab_with_lambdas(eps, lambdas, dim=None):
	"""Build initial pairs (a, b) with a fixed residual and prescribed |a|^2 + |b|^2.

	One random direction pair (a, b) is drawn from NumPy's global RNG, and a is
	rescaled so that a . b - 1 == eps. For every target l in `lambdas`, a is
	multiplied by a scalar c > 0 and b is divided by it. This leaves a . b unchanged
	and solves c^2 |a|^2 + |b|^2 / c^2 == l, taking the larger root for c.

	Args:
		eps: target residual a . b - 1 shared by every returned pair.
		lambdas: 1-D sequence of target values of |a|^2 + |b|^2.
		dim: vector dimension; defaults to the module-level `d`.

	Returns:
		(As, Bs): arrays of shape (len(lambdas), dim); row i is the pair for lambdas[i].
		Rows whose target is below the attainable minimum 2|a||b| are filled with NaN,
		and a warning is issued.
	"""
	if dim is None:
		dim = d
	a, b = np.random.randn(dim), np.random.randn(dim)
	a = a * ((eps + 1) / np.dot(a, b))  # now a . b - 1 == eps

	anorm, bnorm = np.linalg.norm(a), np.linalg.norm(b)
	lambdas = np.asarray(lambdas, dtype=float)

	# Over c > 0, c^2|a|^2 + |b|^2/c^2 has minimum 2|a||b| = 2|a . b| / |cos(a, b)|.
	# That exceeds 2|eps + 1| unless a and b are parallel, so small targets can be unreachable.
	disc = lambdas**2 - (2*anorm*bnorm)**2
	feasible = disc >= 0
	if not np.all(feasible):
		import warnings
		warnings.warn(
			"gen_ab_with_lambdas: {} target lambda(s) below the attainable minimum {:.4g}; "
			"returning NaN rows for them".format(int(np.sum(~feasible)), 2*anorm*bnorm),
			RuntimeWarning,
		)

	factors = np.full(lambdas.shape, np.nan)
	factors[feasible] = np.sqrt(lambdas[feasible] + np.sqrt(disc[feasible])) / (np.sqrt(2)*anorm)
	As = a[None, :] * factors[:, None]
	Bs = b[None, :] / factors[:, None]
	return As, Bs


# Sweep a grid of (normalised step size eta*lambda(0), initial balance lambda(0)).
# Rows index etas (eta*lambda(0)); columns index lambda0s.
T = 500      # maximum number of GD steps per run
eps0 = -2.   # initial residual a(0) . b(0) - 1

lambda_min = 2*np.abs(eps0+1)
lambda0s = np.linspace(1.5*lambda_min, 20*lambda_min, 500)

etas = np.linspace(3e-3, 4.25, 500)

A0s, B0s = gen_ab_with_lambdas(eps0, lambda0s)

grid_shape = (len(etas), len(lambda0s))
T_to_converge = np.zeros(grid_shape)
Q_over_Q0 = np.zeros(grid_shape)

for i in tqdm(range(len(etas))):
	for j in tqdm(range(len(lambda0s)), leave=False):
		lambda0 = lambda0s[j]
		eta = etas[i] / lambda0  # the grid stores eta*lambda(0); recover the raw step size
		a0, b0 = A0s[j], B0s[j]
		As, Bs, Eps, lmbdas, Qs, converged = GD(a0,b0,eta,T)
		# Runs that did not converge get the sentinel -1 so the heatmaps can mask them.
		T_to_converge[i, j] = len(Eps) if converged else -1
		Q_over_Q0[i, j] = Qs[-1] / Qs[0] if converged else -1

def closeness(n):
	"""Return the distance from n to its nearest integer (0 for integers, at most 0.5)."""
	frac = n % 1
	other = 1 - frac
	return other if other < frac else frac

# Tick the eta axis at both endpoints plus every interior grid row whose value is
# closer to an integer than both of its neighbours.
_last_eta = len(etas) - 1
_interior_eta_ticks = [
	k for k in range(1, _last_eta)
	if closeness(etas[k]) < closeness(etas[k - 1]) and closeness(etas[k]) < closeness(etas[k + 1])
]
eta_ticks = sorted([0, _last_eta] + _interior_eta_ticks)
eta_labels = (
	['{:.3f}'.format(etas[0])]
	+ ['{}'.format(round(etas[k])) for k in eta_ticks[1:-1]]
	+ ['{:.2f}'.format(etas[-1])]
)

def _plot_heatmap(data, title, filename, norm=None):
	"""Draw `data` as a seaborn heatmap on the shared lambda(0) / eta*lambda(0) axes and save it.

	Negative cells (the -1 "did not converge" sentinel) are masked out.
	Returns the seaborn Axes.
	"""
	plt.figure()
	extra = {} if norm is None else {'norm': norm}
	ax = sns.heatmap(data, mask=(data < 0), linewidth=0., **extra)
	ax.set(xlabel=r"$\lambda(0)$", ylabel=r"$\eta\lambda(0)$", title=title)
	# seaborn draws cell k on [k, k+1] and centres its own ticks at k + 0.5;
	# ticks at integer k would sit on cell edges, half a cell off their labels.
	xticks = list(range(0, len(lambda0s), 50)) + [len(lambda0s)-1]
	ax.set_xticks([k + 0.5 for k in xticks])
	ax.set_xticklabels(['{:.1f}'.format(lambda0s[k]) for k in xticks])
	ax.set_yticks([k + 0.5 for k in eta_ticks])
	ax.set_yticklabels(eta_labels)
	plt.tight_layout()
	plt.savefig(filename)
	return ax


converged_steps = T_to_converge[T_to_converge > 0]
if converged_steps.size:
	step_norm = LogNorm(vmin=converged_steps.min(), vmax=T_to_converge.max())
else:
	# No run converged: every cell is masked, and .min() on an empty array would raise.
	step_norm = None
_plot_heatmap(T_to_converge, '#steps to convergence', 'time-to-converge.jpg', norm=step_norm)
_plot_heatmap(Q_over_Q0, r'$Q(T) / Q(0)$', 'q-over-q0.jpg')

plt.show()


