from math import sqrt

import numpy as np
from joblib import Parallel, delayed
from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse.linalg import lsqr
from sklearn.linear_model import LinearRegression, LogisticRegression, HuberRegressor

def random_data(m = 100, n = 100, k = 6, SNR = 1.0, p = 0.5):
	"""Draw a noisy rank-k matrix plus a random 80/20 split of its observed entries.

	U (m x k) and V (n x k) are sampled from a zero-mean normal with scale 1.0.
	The noise matrix e (m x n) is sampled with scale (1.0 / SNR) ** 2; note that
	this value is passed to np.random.normal as its scale argument. An entry
	(i, j) counts as observed when an independent uniform(0, 1) draw is <= p.

	Returns:
		(U, V, e, M, omega1, omega2) with M = U V^T + e. omega1 is the first
		80% of the shuffled observed (row, col) pairs, omega2 the remainder.
	"""
	factor_scale = 1.0
	noise_scale = (factor_scale / SNR) ** 2

	U = np.random.normal(0, factor_scale, size=(m, k))
	V = np.random.normal(0, factor_scale, size=(n, k))
	e = np.random.normal(0, noise_scale, size=(m, n))
	M = U.dot(V.T) + e

	observed = np.random.uniform(0, 1, size=(m, n)) <= p
	# Row-major enumeration of the observed positions, shuffled in place.
	omega = [(row, col) for row in range(M.shape[0]) for col in range(M.shape[1]) if observed[row, col]]
	np.random.shuffle(omega)

	cut = int(0.8 * len(omega))
	return U, V, e, M, omega[:cut], omega[cut:]

def read_netflix(f, max_movie_id=2000, n_users=2000, out_path='netflix_medium.txt'):
	"""Read a Netflix-style ratings file, keep the most active users, and densify ids.

	The file is expected to alternate header lines holding a single field (an
	integer movie id followed by one trailing character, which is dropped) and
	comma-separated rating lines whose first two fields are the user id and the
	rating. Reading stops as soon as a header with an id above `max_movie_id`
	is seen. Only the `n_users` users with the most ratings are retained (ties
	broken by user id). Movie and user ids are then re-indexed to contiguous
	0-based indices in sorted order, and the triples are written to `out_path`
	as "movie user rating" lines.

	Args:
		f: path of the input file.
		max_movie_id: stop reading once a movie id exceeds this value.
		n_users: number of most-active users to keep (must be positive).
		out_path: where the re-indexed triples are written.

	Returns:
		A scipy.sparse.coo_matrix with movies as rows and users as columns.

	Raises:
		ValueError: if a rating line precedes the first movie header, or the
			file yields no ratings.
	"""
	# Imported here so the function does not depend on the module header.
	from scipy.sparse import coo_matrix

	print('Reading', f)
	movie_id = None
	ratings = {}
	cnt = 0
	guy = {}  # user_id -> number of rating lines seen for that user
	with open(f) as fin:
		for l in fin:
			a = l.strip().split(',')
			if a == ['']:
				# Blank line: nothing to parse.
				continue
			if len(a) == 1:
				print(cnt)
				movie_id = int(a[0][:-1])
				print('movie', movie_id)
				cnt = 0
			else:
				if movie_id is None:
					raise ValueError('rating line found before any movie header in {}'.format(f))
				cnt += 1
				user_id = int(a[0])
				ratings[(movie_id, user_id)] = float(a[1])
				guy[user_id] = guy.get(user_id, 0) + 1
			if movie_id is not None and movie_id > max_movie_id:
				break

	if not ratings:
		raise ValueError('no ratings found in {}'.format(f))

	# Sort by (count, user_id) and keep the tail, i.e. the most active users.
	most_active = sorted((c, u) for (u, c) in guy.items())[-n_users:]
	kept_users = {u for (_, u) in most_active}
	ratings = {key: r for (key, r) in ratings.items() if key[1] in kept_users}

	movies, users = zip(*ratings)
	movies = np.unique(movies)
	users = np.unique(users)
	movies_inv = {mov: idx for (idx, mov) in enumerate(movies)}
	users_inv = {user: idx for (idx, user) in enumerate(users)}
	ratings_new = {(movies_inv[mov], users_inv[user]): r for ((mov, user), r) in ratings.items()}

	rows, cols = zip(*ratings_new)
	M = coo_matrix((list(ratings_new.values()), (rows, cols)))
	with open(out_path, 'w') as out:
		out.writelines('{} {} {}\n'.format(movie, user, rating) for ((movie, user), rating) in ratings_new.items())
	return M

def read_movielens(f):
	"""Load a tab-separated ratings table and split its rows 80/20 at random.

	Columns 0 and 1 of the file are used as (row, col) indices and column 2 as
	the stored value; any further columns are ignored. The row order is
	shuffled with np.random before the split.

	Args:
		f: path of the tab-separated file.

	Returns:
		(M, omega, omega2): M is a csr_matrix built from every line of the
		file (float32 values); omega / omega2 are lists of (row, col) pairs
		for the first 80% / last 20% of the shuffled lines.
	"""
	print('Reading', f)
	# np.int was removed from NumPy; the builtin int is the same type.
	# ndmin=2 keeps a single-line file two-dimensional.
	a = np.loadtxt(f, delimiter='\t', ndmin=2).astype(int)
	idx = np.arange(len(a))
	np.random.shuffle(idx)
	cut = int(0.8 * len(idx))
	omega = [(a[i,0], a[i,1]) for i in idx[:cut]]
	omega2 = [(a[i,0], a[i,1]) for i in idx[cut:]]
	M = csr_matrix((a[:,2].astype(np.float32), (a[:,0], a[:,1])))
	return M, omega, omega2

def read_movielens_u1(f, base_path='movielens/u1.base', test_path='movielens/u1.test'):
	"""Load a predefined train/test pair of tab-separated ratings files.

	Columns 0 and 1 of each file are used as (row, col) indices and column 2
	as the stored value.

	Args:
		f: only echoed in the progress message; it does not select the files
			that are read.
		base_path: file whose lines form the training index set.
		test_path: file whose lines form the test index set.

	Returns:
		(M, omega, omega2): M is a csr_matrix built from both files together
		(float32 values); omega lists the (row, col) pairs of `base_path` and
		omega2 those of `test_path`, in file order.
	"""
	print('Reading', f)
	# np.int was removed from NumPy; the builtin int is the same type.
	# ndmin=2 keeps a single-line file two-dimensional.
	a = np.loadtxt(base_path, delimiter='\t', ndmin=2).astype(int)
	b = np.loadtxt(test_path, delimiter='\t', ndmin=2).astype(int)
	c = np.concatenate((a, b), axis=0)
	omega = [(row[0], row[1]) for row in a]
	omega2 = [(row[0], row[1]) for row in b]
	M = csr_matrix((c[:,2].astype(np.float32), (c[:,0], c[:,1])))
	return M, omega, omega2

def read_movielens_1m(f):
	"""Load a '::'-separated ratings table and split its rows 80/20 at random.

	Columns 0 and 1 of the file are used as (row, col) indices and column 2 as
	the stored value; any further columns are ignored. The row order is
	shuffled with np.random before the split.

	Args:
		f: path of the '::'-separated file.

	Returns:
		(M, omega, omega2): M is a csr_matrix built from every line of the
		file (float32 values); omega / omega2 are lists of (row, col) pairs
		for the first 80% / last 20% of the shuffled lines.
	"""
	print('Reading', f)
	# Current np.loadtxt only accepts single-character delimiters, so the
	# two-character '::' separator is rewritten to a tab on the fly.
	# np.int was removed from NumPy; the builtin int is the same type.
	with open(f) as fin:
		a = np.loadtxt((line.replace('::', '\t') for line in fin), delimiter='\t', ndmin=2).astype(int)
	idx = np.arange(len(a))
	np.random.shuffle(idx)
	cut = int(0.8 * len(idx))
	omega = [(a[i,0], a[i,1]) for i in idx[:cut]]
	omega2 = [(a[i,0], a[i,1]) for i in idx[cut:]]
	M = csr_matrix((a[:,2].astype(np.float32), (a[:,0], a[:,1])))
	return M, omega, omega2

def parse_ranks(r, max_rank=100):
	"""Turn a rank specification string into a collection of ranks.

	Accepted forms of `r`:
		'a-b'  -> every integer from a to b inclusive (a list).
		'a:b'  -> the default grid restricted to [a, b] (a numpy array).
		'n'    -> [n] when n > 0, otherwise the full default grid.

	The default grid is the sorted union of 2 .. min(20, max_rank) - 1 and
	min(max_rank - 1, 15) linearly spaced values between 2 and max_rank,
	truncated to integers.
	"""
	def default_grid():
		to = max_rank
		dense = np.array(list(range(2, min(20, int(to))))).astype(int)
		spread = np.linspace(2, to, min(to - 1, 15)).astype(int)
		return np.unique(np.concatenate((dense, spread)))

	if '-' in r:
		parts = r.split('-')
		lo, hi = int(parts[0]), int(parts[1])
		return list(range(lo, hi + 1))

	if ':' in r:
		parts = r.split(':')
		lo, hi = int(parts[0]), int(parts[1])
		grid = default_grid()
		return grid[(grid >= lo) & (grid <= hi)]

	if int(r) > 0:
		return [int(r)]
	return default_grid()

def optimize(U, V, M, omega, iter_lim = None):
	"""Fit a k x k mixing matrix X so that U X V^T matches M on omega; return U X.

	For each (i, j) in omega the row of the least-squares system is the
	flattened outer product of U[i] and V[j], and the target is M[i, j]. The
	system is solved with LSQR, started from the identity matrix.

	Args:
		U: array with one row of length k per matrix row.
		V: array with one row of length k per matrix column.
		M: object supporting M[i, j] scalar lookup.
		omega: iterable of (row, col) index pairs.
		iter_lim: optional LSQR iteration limit; None uses LSQR's default.

	Returns:
		The array U X, with the same shape as U.

	Raises:
		ValueError: if omega is empty.
	"""
	pairs = list(omega)
	if not pairs:
		raise ValueError('omega must contain at least one (row, col) pair')
	rows = [i for (i, _) in pairs]
	cols = [j for (_, j) in pairs]
	k = U.shape[1]

	# Row t holds U[rows[t], a] * V[cols[t], b] at flat position a * k + b.
	A = np.einsum('ta,tb->tab', U[rows], V[cols]).reshape(len(pairs), -1)
	y = np.array([M[i, j] for (i, j) in pairs])

	# lsqr treats iter_lim=None as "use the default", so one call covers both cases.
	x = lsqr(A, y, iter_lim=iter_lim, x0=np.eye(k).reshape(-1))[0]
	return np.dot(U, x.reshape(k, k))

def optimize_reg(U, V, M, omega, rho):
	"""Regularised fit of a k x k mixing matrix X; returns U X.

	The least-squares system stacks two groups of rows:
	  * for each (i, j) in omega, the flattened outer product of U[i] and V[j]
	    with target M[i, j];
	  * for every position (i, j) of M, the same outer product scaled by rho
	    with target 0.
	It is solved with an intercept-free LinearRegression.
	"""
	pairs = list(omega)
	n_rows, n_cols = M.shape[0], M.shape[1]
	k = U.shape[1]

	# Data-fit rows.
	blocks = [np.dot(U[i:i+1].T, V[j:j+1]).reshape(1, -1) for (i, j) in pairs]
	targets = [M[i, j] for (i, j) in pairs]

	# Penalty rows, one per matrix position in row-major order.
	blocks.extend(
		(rho * np.dot(U[i:i+1].T, V[j:j+1])).reshape(1, -1)
		for i in range(n_rows)
		for j in range(n_cols)
	)
	targets.extend(0 for _ in range(n_rows * n_cols))

	A = np.concatenate(blocks, axis=0)
	y = np.array(targets).reshape(-1, 1)
	model = LinearRegression(fit_intercept=False, n_jobs=-1)
	model.fit(A, y)
	return np.dot(U, model.coef_.reshape(k, k))

def optimize_diag(U, V, M, omega):
	"""Fit one scale per factor column and return U with its columns rescaled.

	For each (i, j) in omega the regression row is the elementwise product
	U[i] * V[j] and the target is M[i, j]. The coefficients of an
	intercept-free LinearRegression form the diagonal matrix applied to U.
	"""
	pairs = list(omega)
	features = [U[i:i+1] * V[j:j+1] for (i, j) in pairs]
	targets = [M[i, j] for (i, j) in pairs]

	A = np.concatenate(features, axis=0)
	y = np.array(targets).reshape(-1, 1)

	model = LinearRegression(fit_intercept=False, n_jobs=-1)
	model.fit(A, y)
	scales = model.coef_.reshape(-1)
	return np.dot(U, np.diag(scales))

def optimize_diag_reg(U, V, M, omega, rho):
	"""Regularised variant of the per-column scale fit; returns the rescaled U.

	The regression stacks two groups of rows:
	  * for each (i, j) in omega, U[i] * V[j] with target M[i, j];
	  * for every position (i, j) of M, rho * U[i] * V[j] with target 0.
	The shapes of the stacked system are printed before fitting an
	intercept-free LinearRegression, whose coefficients form the diagonal
	matrix applied to U.
	"""
	pairs = list(omega)
	n_rows, n_cols = M.shape[0], M.shape[1]

	# Data-fit rows.
	features = [U[i:i+1] * V[j:j+1] for (i, j) in pairs]
	targets = [M[i, j] for (i, j) in pairs]

	# Penalty rows, one per matrix position in row-major order.
	features.extend(
		rho * U[i:i+1] * V[j:j+1]
		for i in range(n_rows)
		for j in range(n_cols)
	)
	targets.extend(0 for _ in range(n_rows * n_cols))

	A = np.concatenate(features, axis=0)
	y = np.array(targets).reshape(-1, 1)
	print(A.shape, y.shape)

	model = LinearRegression(fit_intercept=False, n_jobs=-1)
	model.fit(A, y)
	scales = model.coef_.reshape(-1)
	return np.dot(U, np.diag(scales))

def step_par(U, M, mask):
	"""Solve one least-squares problem per column of M, dispatched through joblib.

	For every column index present in `mask`, the rows listed in mask[col]
	select both the rows of U and the entries of that column of M, and LSQR
	solves for the coefficient vector. Columns absent from `mask` get a zero row.

	Args:
		U: array with one row of length k per row of M.
		M: matrix whose M[rows, col] result exposes .toarray().
		mask: mapping from column index to the row indices to use.

	Returns:
		Array with one row of length k per column of M.
	"""
	n_cols, rank = M.shape[1], U.shape[1]

	def solve_column(col):
		if col not in mask:
			return np.zeros((1, rank))
		rhs = M[mask[col], col].toarray().reshape(-1)
		return lsqr(U[mask[col]], rhs)[0].reshape(1, -1)

	workers = min(32, n_cols)
	solved = Parallel(n_jobs=workers)(delayed(solve_column)(col) for col in range(n_cols))
	return np.concatenate(solved, axis=0)

def step(U, M, mask, iter_lim = None, V = None):
	"""Update the column factors: one regression per column of M against U.

	Columns absent from `mask` get a zero row. For the others:
	  * iter_lim given: LSQR on the rows listed in mask[col], limited to
	    `iter_lim` iterations and warm-started from V[col] when V is supplied
	    (from zero otherwise).
	  * iter_lim None: a HuberRegressor is fitted on all of U against the full
	    column M[:, col]; neither `mask[col]` nor `V` is used on this path.
	    NOTE(review): this path looks experimental — confirm it is intended.

	Args:
		U: array with one row of length k per row of M.
		M: matrix whose M[rows, col] result exposes .toarray().
		mask: mapping from column index to the row indices to use.
		iter_lim: LSQR iteration limit, or None to select the Huber path.
		V: optional array with one row of length k per column of M, used as
			the LSQR starting point.

	Returns:
		Array with one row of length k per column of M.
	"""
	# Keep the output shape in separate names so the caller's V is not
	# overwritten and can actually be used as the warm start.
	n_cols, rank = M.shape[1], U.shape[1]

	def f(i):
		if i not in mask:
			return np.zeros((1, rank))
		if iter_lim is None:
			huber = HuberRegressor(max_iter=10, tol=1e-3, fit_intercept=False).fit(U, M[:,i])
			return huber.coef_.reshape(1,-1)
		# x0=None makes lsqr start from zero, the same as an all-zero V row.
		x0 = None if V is None else V[i]
		return lsqr(U[mask[i]], M[mask[i], i].toarray().reshape(-1), iter_lim=iter_lim, x0=x0)[0].reshape(1,-1)

	Vis = [f(i) for i in range(n_cols)]
	return np.concatenate(Vis, axis=0)

def step_reg(U, M, mask, rho, V=None, iter_lim=2):
	"""Regularised column-factor update: one short LSQR solve per column of M.

	For each column in `mask` the system stacks the rows U[mask[col]] (targets
	taken from that column of M) on top of rho * U (targets 0), and LSQR is
	run for at most `iter_lim` iterations. Columns absent from `mask` get a
	zero row.

	Args:
		U: array with one row of length k per row of M.
		M: matrix whose M[rows, col] result exposes .toarray().
		mask: mapping from column index to the row indices to use.
		rho: weight of the penalty rows.
		V: optional array with one row of length k per column of M, used as
			the LSQR starting point; None starts from zero.
		iter_lim: LSQR iteration limit (default 2).

	Returns:
		Array with one row of length k per column of M.
	"""
	# Keep the output shape in separate names so the caller's V is not
	# overwritten and can actually be used as the warm start.
	n_cols, rank = M.shape[1], U.shape[1]
	penalty_rows = rho * U
	penalty_rhs = np.zeros(U.shape[0])

	def f(i):
		if i not in mask:
			return np.zeros((1, rank))
		A = np.concatenate((U[mask[i]], penalty_rows), axis=0)
		b = np.concatenate((M[mask[i], i].toarray().reshape(-1), penalty_rhs))
		# x0=None makes lsqr start from zero, the same as an all-zero V row.
		x0 = None if V is None else V[i]
		return lsqr(A, b, iter_lim=iter_lim, x0=x0)[0].reshape(1,-1)

	Vis = [f(i) for i in range(n_cols)]
	return np.concatenate(Vis, axis=0)

def compute_UV(U, V, omegaT, rng=None):
	"""Evaluate the product U V^T only at the positions listed in omegaT.

	Args:
		U: array with one row of length k per matrix row.
		V: array with one row of length k per matrix column.
		omegaT: pair (rows, cols) of equally long index sequences.
		rng: optional (low, high) pair, low <= high; values are clipped to it.

	Returns:
		csr_matrix of shape (U.shape[0], V.shape[0]) holding the dot product
		of U[rows[t]] and V[cols[t]] at each listed position.
	"""
	rows = np.asarray(omegaT[0], dtype=np.intp)
	cols = np.asarray(omegaT[1], dtype=np.intp)

	# Accumulate one factor column at a time: this needs only O(len(rows))
	# extra memory and no worker processes.
	ret = np.zeros(rows.shape[0], dtype=np.result_type(U, V))
	for a in range(U.shape[1]):
		ret += U[rows, a] * V[cols, a]

	# "is not None" also works when rng is a numpy array.
	if rng is not None:
		ret = np.clip(ret, rng[0], rng[1])
	return csr_matrix((ret, (rows, cols)), shape=(U.shape[0], V.shape[0]))

def abse(U, V, M, omega, omegaT = None, rng = None, UV=None):
	"""Relative absolute error of the prediction on omega.

	Returns sum |M[i, j] - pred[i, j]| / sum |M[i, j]| over (i, j) in omega.

	Predictions come from one of two sources:
	  * omegaT is None: the dot product of U[i] and V[j], clipped to `rng`
	    when given;
	  * otherwise: entries of `UV`, which is computed with
	    compute_UV(U, V, omegaT, rng=rng) when not supplied.

	Args:
		U, V: factor arrays with rows of equal length.
		M: dense array or sparse matrix supporting M[rows, cols] indexing.
		omega: iterable of (row, col) index pairs to evaluate on.
		omegaT: optional pair (rows, cols) of positions for compute_UV.
		rng: optional (low, high) pair, low <= high.
		UV: optional precomputed prediction matrix.

	Raises:
		ValueError: if omega is empty.
	"""
	pairs = np.asarray(list(omega), dtype=np.intp).reshape(-1, 2)
	if pairs.shape[0] == 0:
		raise ValueError('omega must contain at least one (row, col) pair')
	rows, cols = pairs[:, 0], pairs[:, 1]

	def gather(mat):
		# Works for dense arrays, np.matrix results and sparse results alike.
		vals = mat[rows, cols]
		if hasattr(vals, 'toarray'):
			vals = vals.toarray()
		return np.asarray(vals, dtype=float).reshape(-1)

	truth = gather(M)
	if omegaT is None:
		pred = np.zeros(rows.shape[0])
		for a in range(U.shape[1]):
			pred += U[rows, a] * V[cols, a]
		if rng is not None:
			pred = np.clip(pred, rng[0], rng[1])
	else:
		UV = compute_UV(U, V, omegaT, rng=rng) if UV is None else UV
		pred = gather(UV)

	return np.abs(truth - pred).sum() / np.abs(truth).sum()

def rse(U, V, M, omega, omegaT = None, rng = None, UV=None):
	"""Relative squared error of the prediction on omega.

	Returns sum (M[i, j] - pred[i, j])**2 / sum M[i, j]**2 over (i, j) in omega.

	Predictions come from one of two sources:
	  * omegaT is None: the dot product of U[i] and V[j], clipped to `rng`
	    when given;
	  * otherwise: entries of `UV`, which is computed with
	    compute_UV(U, V, omegaT, rng=rng) when not supplied.

	Args:
		U, V: factor arrays with rows of equal length.
		M: dense array or sparse matrix supporting M[rows, cols] indexing.
		omega: iterable of (row, col) index pairs to evaluate on.
		omegaT: optional pair (rows, cols) of positions for compute_UV.
		rng: optional (low, high) pair, low <= high.
		UV: optional precomputed prediction matrix.

	Raises:
		ValueError: if omega is empty.
	"""
	pairs = np.asarray(list(omega), dtype=np.intp).reshape(-1, 2)
	if pairs.shape[0] == 0:
		raise ValueError('omega must contain at least one (row, col) pair')
	rows, cols = pairs[:, 0], pairs[:, 1]

	def gather(mat):
		# Works for dense arrays, np.matrix results and sparse results alike.
		vals = mat[rows, cols]
		if hasattr(vals, 'toarray'):
			vals = vals.toarray()
		return np.asarray(vals, dtype=float).reshape(-1)

	truth = gather(M)
	if omegaT is None:
		pred = np.zeros(rows.shape[0])
		for a in range(U.shape[1]):
			pred += U[rows, a] * V[cols, a]
		if rng is not None:
			pred = np.clip(pred, rng[0], rng[1])
	else:
		UV = compute_UV(U, V, omegaT, rng=rng) if UV is None else UV
		pred = gather(UV)

	return np.square(truth - pred).sum() / np.square(truth).sum()

def rmse(U, V, M, omega, omegaT = None, rng = None, UV=None):
	"""Root-mean-square error of the prediction on omega.

	Returns sqrt(mean over (i, j) in omega of (M[i, j] - pred[i, j])**2).

	Predictions come from one of two sources:
	  * omegaT is None: the dot product of U[i] and V[j], clipped to `rng`
	    when given;
	  * otherwise: entries of `UV`, which is computed with
	    compute_UV(U, V, omegaT, rng=rng) when not supplied.

	Args:
		U, V: factor arrays with rows of equal length.
		M: dense array or sparse matrix supporting M[rows, cols] indexing.
		omega: iterable of (row, col) index pairs to evaluate on.
		omegaT: optional pair (rows, cols) of positions for compute_UV.
		rng: optional (low, high) pair, low <= high.
		UV: optional precomputed prediction matrix.

	Raises:
		ZeroDivisionError: if omega is empty.
	"""
	pairs = np.asarray(list(omega), dtype=np.intp).reshape(-1, 2)
	count = pairs.shape[0]
	if count == 0:
		# A mean over no entries is undefined; keep the division error type.
		raise ZeroDivisionError('rmse is undefined for an empty omega')
	rows, cols = pairs[:, 0], pairs[:, 1]

	def gather(mat):
		# Works for dense arrays, np.matrix results and sparse results alike.
		vals = mat[rows, cols]
		if hasattr(vals, 'toarray'):
			vals = vals.toarray()
		return np.asarray(vals, dtype=float).reshape(-1)

	truth = gather(M)
	if omegaT is None:
		pred = np.zeros(count)
		for a in range(U.shape[1]):
			pred += U[rows, a] * V[cols, a]
		if rng is not None:
			pred = np.clip(pred, rng[0], rng[1])
	else:
		UV = compute_UV(U, V, omegaT, rng=rng) if UV is None else UV
		pred = gather(UV)

	return sqrt(float(np.square(truth - pred).sum()) / count)

def get_max_sv(U, V, M, omega, rng = None):
	"""Estimate the leading singular vector pair of the residual M - U V^T.

	Runs 20 power-iteration steps on R^T R, with R = M - U V^T, from a random
	unit vector; R is never formed explicitly.

	Args:
		U: array of shape (M.shape[0], k).
		V: array of shape (M.shape[1], k).
		M: dense array or sparse matrix providing .dot and .T.
		omega: unused.
		rng: unused.

	Returns:
		(x, y): column vectors of unit norm, x of length M.shape[1] and
		y = R x / ||R x|| of length M.shape[0]. Both contain NaN when the
		residual is exactly zero.
	"""
	def residual_dot(vec):
		# R @ vec, using M.dot so that sparse M works as well as dense.
		return M.dot(vec) - np.dot(U, np.dot(V.T, vec))

	def residual_T_dot(vec):
		# R^T @ vec
		return M.T.dot(vec) - np.dot(V, np.dot(U.T, vec))

	x = np.random.normal(0, 1, size=(M.shape[1],1))
	x /= np.linalg.norm(x)
	for t in range(20):
		x = residual_T_dot(residual_dot(x))
		x /= np.linalg.norm(x)
	y = residual_dot(x)
	y /= np.linalg.norm(y)
	return x, y
def sgd(A, b, T=10, x0 = None):
	"""Approximate a least-squares solution of A x = b by stochastic gradient descent.

	Runs T passes over the rows of A in a freshly shuffled order, applying the
	update x += 0.01 * (b[i] - A[i] . x) * A[i] for each row i.

	Args:
		A: array with one row per equation.
		b: targets, indexable by row number.
		T: number of passes over the rows.
		x0: optional starting vector (copied, not modified); when None the
			start is drawn from a normal with mean 0 and scale 0.1.

	Returns:
		The final iterate x.
	"""
	step_size = 0.01
	x = np.random.normal(0, 0.1, size=(A.shape[1])) if x0 is None else x0.copy()

	n_rows = A.shape[0]
	for _ in range(T):
		order = list(range(n_rows))
		np.random.shuffle(order)
		for row in order:
			residual = b[row] - np.dot(A[row], x)
			x += step_size * residual * A[row]
	return x
if __name__ == "__main__":
	# Candidate inputs netflix/combined_data_1.txt .. combined_data_4.txt;
	# only the first one is processed.
	files = ['netflix/combined_data_' + str(i) + '.txt' for i in range(1,5)]
	# The reader defined in this module is read_netflix; no function named
	# `read` exists here.
	read_netflix(files[0])
