import numpy as np
import numpy.linalg as LA
from quantization import *
import argparse
import matplotlib.pyplot as plt
import os
from scipy.stats import special_ortho_group

''' We generate data X (one sample per row) from a Gaussian with a diagonal covariance matrix. The data is split once, at the start,
between the 2 machines, which both hold the current unit vector v. They should update it as v <- (X^T X v)/||X^T X v||. The difference
||X_1^T X_1 v - X_2^T X_2 v|| is quite small, so its bound y can be fixed at the start by running a few iterations of the baseline and setting y to
twice the largest value seen, as described in the paper.
'''

def grad(A,w,b):
	"""Return the gradient of the mean squared residual (1/n)*||A w - b||^2 w.r.t. w.

	Differentiating (1/n)*||A w - b||^2 gives (2/n) * A^T (A w - b), which is
	the expression evaluated here.

	Args:
		A: 2-D array with one sample per row; n is taken from A.shape[0].
		w: weight vector; must be multiplicable as A @ w.
		b: targets, with the same shape as A @ w.

	Returns:
		Array equal to 2 * A^T (A w - b) / n.
	"""
	n = A.shape[0]
	# Form the residual first: this needs only two matrix-vector products and
	# never builds the d x d matrix A^T A.
	residual = A @ w - b
	return 2*(A.T @ residual)/n

def infnorm(a):
	"""Return the infinity norm of ``a`` as computed by ``numpy.linalg.norm``."""
	order = np.inf
	return LA.norm(a, ord=order)

def random_rotate(X, random_state=None):
	"""Rotate the rows of X by a random matrix drawn from the special orthogonal group.

	Args:
		X: 2-D array-like with one point per row.
		random_state: optional seed or generator forwarded to
			``special_ortho_group.rvs``. The default (None) keeps the previous
			behaviour of drawing from scipy's default random state.

	Returns:
		A new array ``X @ U`` where U is the sampled d x d matrix; X itself is
		not modified.
	"""
	points = np.asarray(X)
	d = points.shape[1]
	U = special_ortho_group.rvs(d, random_state=random_state)
	# The matrix product already allocates a fresh array, so no defensive copy
	# of the input is needed.
	return points @ U

def experiment(n,d,qlevel,iterations,seed,args):
	"""Run power iteration on two data halves with several quantizers and record the errors.

	Args:
		n: number of data points to sample.
		d: dimension of each data point (indices 0 and 1 of the covariance
			diagonal are overwritten, so d must be at least 2).
		qlevel: quantization level passed to every quantizer.
		iterations: number of power-iteration steps per method.
		seed: seed for numpy's global random generator.
		args: parsed command-line namespace; only ``args.random_rotate`` is read.

	Returns:
		Dict mapping a series name to a list with one entry per iteration.
		Method names hold min(||v - eig||, ||v + eig||); "<name>_qe" holds the
		quantization error of the summed vector; "2diff", "diff", "norm" and
		"maxmin" are recorded during the baseline run only.
	"""
	# seed
	np.random.seed(seed)
	
	# print to stdout
	print("-----")
	print("0. seed = ", seed)
	print("1. data points (n) = {}, d = {}".format(n,d))
	print("2. qlevel = ", qlevel)
	print("-----")

	# generating the random dataset: X has row as data, so we need to look at C = X.T X
	mu = np.zeros(d) 
	C = np.eye(d)
	C[0,0], C[1,1] = 9, 10 # the two leading axes dominate; the variance is largest along the second axis
	X = np.random.multivariate_normal(mu,C,n).astype(np.float64) # doing a /n might help for larger d
	if args.random_rotate:
		X = random_rotate(X)
	half = int(n/2)
	X1 = X[0:half,:]
	X2 = X[half:n,:]

	# get principal eigenvector eig; X.T @ X is symmetric, so eigh is the
	# appropriate solver and always returns real eigenpairs
	w, U = LA.eigh(X.T @ X/n)
	eig = U[:,np.argmax(w)] 

	# arrays for plotting
	result = {"LQSGD":[], "RLQSGD":[], "QSGD":[], "hadamard":[], "baseline":[], "2diff":[], "diff":[], "maxmin":[], "norm":[]
	,"LQSGD_qe":[], "RLQSGD_qe":[], "QSGD_qe":[], "hadamard_qe":[]} # ||w - w_star|| as they did in that paper
		
	# v_init and making it a unit vector
	v_init = np.random.rand(d).astype(np.float64)
	v_init /= LA.norm(v_init)  

	# quantizers
	rlqsgd = RLQSGD(dimension = d, qlevel = qlevel, side = 0)
	lqsgd = LQSGD(dimension = d, qlevel = qlevel, side = 0)
	qsgd = QSGD(k = qlevel)
	hadamard = HadamardQuantizer(k = qlevel)

	# The per-machine Gram matrices do not depend on v, so build them once
	# instead of on every iteration of every method.
	G1 = X1.T @ X1
	G2 = X2.T @ X2

	for quantizer in ["baseline", rlqsgd, qsgd, hadamard,lqsgd]: # baseline should be before to ensure setting of y happens first
		is_baseline = isinstance(quantizer, str)
		name = quantizer if is_baseline else quantizer.name
		print(name)
		
		# initialize same v for all methods; v is rebound to a new array before
		# the in-place normalisation below, so v_init itself is never modified
		v = v_init

		for it in range(iterations):
			# getting Xi.t Xi v
			u1, u2 = G1 @ v, G2 @ v

			# random signing for the rlqsgd (drawn for every method so that all
			# methods advance the random stream in the same way)
			D = np.sign(np.random.rand(len(rlqsgd.pad(u1))) - 0.5)
			
			# this basically corresponds to running a few iterations on baseline
			# to set side; in this experiment, we use a constant side
			if it < 5 and is_baseline:
				t = 2*infnorm(rlqsgd.HD(u1-u2,D))/(qlevel-1) 
				rlqsgd.set_side(max(rlqsgd.side, 2*t)) # 2 is for slack
				t1 = 2*infnorm(u1-u2)/(qlevel-1) 
				lqsgd.set_side(max(lqsgd.side, 2*t1)) # 2 is for slack
			
			# sum them, and make it unit vector
			if is_baseline:
				v = u1+u2
			elif quantizer is rlqsgd:
				v = 2 * quantizer.average(u1,u2,D)[0]
			elif quantizer is lqsgd:
				v = 2 * quantizer.average(u1,u2)[0]
			else:
				v = quantizer.compress(u1) + quantizer.compress(u2)
			qe = LA.norm(v-(u1+u2)) # error that came due to quantizing in calculating X1^t X1 v + X2^t X2 v
			v /= LA.norm(v)
			
			# keeping track of error and appending; the eigenvector's sign is
			# arbitrary, so take the closer of +eig and -eig
			err = min(LA.norm(v-eig),LA.norm(v+eig))
			result[name].append(err)
			if is_baseline:
				result["2diff"].append(LA.norm(u1-u2))
				result["diff"].append(LA.norm(u1-u2,np.inf))
				result["norm"].append(LA.norm(u1))
				result["maxmin"].append(np.max(u1)-np.min(u1))
			else:
				result[name+"_qe"].append(qe)

			# print stats
			if it % 10 == 0:
				print("iteration = {}, error = {}".format(it,err))


	return result

def _render_figure(args, xs, result_avg, window, series, ylabel, title, stem):
	"""Plot the given (key, label) series over ``window``, then save or show the figure."""
	for key, label in series:
		plt.plot(xs[window], result_avg[key][window], label=label)
	plt.xlabel('iteration')
	plt.ylabel(ylabel)
	plt.title(title, fontsize = 15)
	plt.legend()
	if args.save:
		os.makedirs('out', exist_ok=True)
		name = "out/{}_q_{}".format(stem, args.qlevel)
		if args.log_scale:
			name += "_log"
		if args.random_rotate:
			name += "_random_rotate"
		plt.savefig(name+".pdf")
	else:
		plt.show()
	plt.close()

def main():
	"""Parse the command line, average ``experiment`` over the seeds and draw three plots.

	The plots are: (1) convergence error per method, (2) norms recorded during
	the baseline run, (3) quantization error per quantized method. With --save
	they are written as PDFs under ``out/``; otherwise they are shown.
	"""
	parser = argparse.ArgumentParser(description='Superlinear Convergence')
	parser.add_argument('--qlevel', type=int, default=64, metavar='QLEVEL', help='qlevel (default: 64)')
	parser.add_argument('--iterations', type=int, default=80, metavar='ITER', help='iterations (default: 80)')
	parser.add_argument('--n', type=int, default=8192, metavar='N', help='no of points (default: 8192)')
	parser.add_argument('--d', type=int, default=128, metavar='D', help='dimension (default: 128)')
	parser.add_argument('--nseeds', type=int, default=1, metavar='NSEEDS', help='nseeds (default: 1)')
	parser.add_argument('--log-scale', action='store_true')
	parser.add_argument('--random_rotate', action='store_true') # do you want a random eigenvector
	parser.add_argument('--save', action='store_true')
	args = parser.parse_args()
	if args.nseeds < 1:
		# averaging over zero seeds would silently produce NaN (0/0) curves
		parser.error('--nseeds must be at least 1')
	
	# parameters for experiment
	n = args.n # number of data points
	d = args.d # dimension
	iterations = args.iterations
	qlevel = args.qlevel
	
	# repeating the experiment for args.nseeds seeds (0, 10, 20, ...)
	fields = ["LQSGD","RLQSGD","QSGD","hadamard","baseline","2diff","diff","maxmin","norm","LQSGD_qe","RLQSGD_qe","QSGD_qe","hadamard_qe"]
	seeds = [10*i for i in range(args.nseeds)]

	result_avg = {field: np.zeros(iterations) for field in fields}

	for seed in seeds:
		result = experiment(n,d,qlevel,iterations,seed,args)
		for field in fields:
			result_avg[field] += np.array(result[field])

	for field in fields:
		result_avg[field] /= len(seeds)

	if args.log_scale:
		for field in fields:
			result_avg[field] = np.log10(result_avg[field])
	
	# Plots: every figure shows at most the last 300 iterations
	iteration = list(range(iterations))
	window = slice(max(0,iterations-300), iterations)
	subtitle = '\n S = {}, d = {}, qlevel = {}'.format(n,d,qlevel)

	# (1) convergence
	# Raw strings keep the LaTeX backslashes (\log, \infty) from being read as
	# string escape sequences.
	_render_figure(
		args, iteration, result_avg, window,
		[
			("LQSGD", 'LQSGD (cubic)'),
			("RLQSGD", 'RLQSGD (cubic)'),
			("QSGD", 'QSGD'),
			("hadamard", 'Hadamard'),
			("baseline", 'Baseline'),
		],
		r'$\log_{10}||v-v_{top}||_2$' if args.log_scale else r'$||v-v_{top}||_2$',
		'Power Iteration Convergence' + subtitle,
		"convergence",
	)

	# (2) norm vs iteration
	_render_figure(
		args, iteration, result_avg, window,
		[
			("2diff", r'$||u_0-u_1||_2$'),
			("diff", r'$||u_0-u_1||_\infty$'),
			("norm", r'$||u_0||_2$'),
			("maxmin", r'$max(u_0)-min(u_0)$'),
		],
		r'$\log_{10}$(value)' if args.log_scale else 'value',
		'Power Iteration Norms' + subtitle,
		"norms",
	)

	# (3) quantization error
	_render_figure(
		args, iteration, result_avg, window,
		[
			("LQSGD_qe", 'LQSGD (cubic)'),
			("RLQSGD_qe", 'RLQSGD (cubic)'),
			("QSGD_qe", 'QSGD'),
			("hadamard_qe", 'Hadamard'),
		],
		r'$\log_{10}$(quantization error)' if args.log_scale else 'error',
		'Power Iteration Quantization Error' + subtitle,
		"qe",
	)
	

# Run the experiment and plotting only when executed as a script, not on import.
if __name__ == '__main__':
	main()
