import numpy as np
from math import sqrt
from scipy.sparse import csr_matrix
from greedy import greedy
from greedy_reg import greedy_reg
from local import local
from local_other import local_other
from local_new import local_new
from local_newnew import local_newnew
from soft import knn, nuclear, soft
from mysvd import mysvd
import sys
from tinydb import *
import json
import tools

if len(sys.argv) == 1:
	# No arguments given: show usage and stop. sys.exit is used rather than the
	# interactive-shell helper exit(), which the site module may not install
	# (e.g. `python -S`, some embedded/frozen interpreters).
	print('Usage: python3 movielens.py split_id ranks algorithms [save|nosave]')
	sys.exit(0)

# Initialize the results database (db.json).
# Table defaults are configured once, before the database is opened.
# cache_size=0 presumably disables TinyDB's per-table query cache so rows
# written by other runs are not served stale -- TODO confirm for the installed
# TinyDB version.
TinyDB.DEFAULT_TABLE_KWARGS = {'cache_size': 0}
db = TinyDB('db.json')
# Query builder used by save() to match existing result rows.
User = Query()
def save(split_id, alg, rank, train_error, val_error):
	"""Store one result row in ``db``, replacing any earlier row for the same key.

	Nothing is written unless the script was started with ``save`` as its
	fourth argument (``sys.argv[4]``).

	Args:
		split_id: split identifier string; appended to ``'movielens_10m_'`` to
			form the stored dataset name.
		alg: algorithm name, stored under ``'algorithm'``.
		rank: rank, stored as ``int(rank)`` under ``'k'``.
		train_error: training error, stored as ``float``.
		val_error: validation error, stored as ``float``.
	"""
	if len(sys.argv) < 5 or sys.argv[4] != 'save':
		return
	dataset = 'movielens_10m_' + split_id
	rank_int = int(rank)
	# Drop any previous row for this (dataset, algorithm, rank) so reruns overwrite.
	db.remove((User.dataset == dataset) & (User.algorithm == alg) & (User.k == rank_int))
	db.insert({
		'dataset': dataset,
		'algorithm': alg,
		'k': rank_int,
		'train_error': float(train_error),
		'val_error': float(val_error),
	})

split_id = sys.argv[1]
# Seed NumPy's global RNG from the split id: fixed per split, distinct across splits.
np.random.seed(int(split_id) + 1024)

# NOTE(review): a reader named read_movielens_1m is applied to the 10M file --
# confirm it handles this file's format. omega1 / omega2 are presumably the
# training and validation index sets (each algorithm receives omega2 as
# `omega2=`, and the second error it reports is saved as val_error).
M, omega1, omega2 = tools.read_movielens_1m('movielens/movielens_10m.dat')

ranks_arg = sys.argv[2] if len(sys.argv) > 2 else None
rank_list = tools.parse_ranks(ranks_arg, max_rank=20)

# Algorithm names are '-'-separated (e.g. 'greedy_alt-soft'); default: greedy_alt only.
if len(sys.argv) > 3:
	algorithms = sys.argv[3].split('-')
else:
	algorithms = ['greedy_alt']

if 'greedy_alt' in algorithms:
	print('greedy_alt')
	# Single run up to the largest requested rank; `errors` is then indexed by
	# each rank in rank_list to get that rank's (train, validation) RMSE.
	_, _, errors = greedy(
		csr_matrix(M),
		omega1,
		k=max(rank_list),
		omega2=omega2,
		method='alt',
		err_type='rmse',
		rng=[1, 5],
	)
	for k in rank_list:
		save(split_id, 'greedy_alt', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'greedy_alt2' in algorithms:
	print('greedy_alt2')
	# Same as greedy_alt but capped at iter_lim=3 (noted as the best setting).
	_, _, errors = greedy(
		csr_matrix(M),
		omega1,
		k=max(rank_list),
		omega2=omega2,
		method='alt',
		err_type='rmse',
		iter_lim=3,
		rng=[1, 5],
	)
	for k in rank_list:
		save(split_id, 'greedy_alt2', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'greedy_reg_alt' in algorithms:
	print('greedy_reg_alt')
	# NOTE(review): unlike every other branch, these results are never passed to
	# save(), and M is passed as-is instead of csr_matrix(M). Confirm whether
	# greedy_reg expects a dense matrix and how train_error / val_errors are
	# indexed by rank before wiring up persistence.
	_, _, train_error, val_errors = greedy_reg(
		M,
		omega1,
		k=max(rank_list),
		omega2=omega2,
		method='alt',
		rho=0.015,
		err_type='rmse',
		rng=[1, 5],
	)

if 'local_alt' in algorithms:
	print('local_alt')
	# T is derived from the same target rank passed as k=. Keyword arguments do
	# not bind names, so a bare `k` here would be whatever loop variable an
	# earlier branch left behind, or undefined (NameError) if none ran.
	_, _, errors = local(
		csr_matrix(M),
		omega1,
		k=max(rank_list),
		omega2=omega2,
		method='alt',
		T=max(rank_list) + 4,
		err_type='rmse',
		rng=[1, 5],
	)
	for k in rank_list:
		save(split_id, 'local_alt', k, errors[k][0], errors[k][1])

if 'local_alt2' in algorithms:
	print('local_alt2')
	# One run to the largest rank with T equal to that rank and iter_lim=3;
	# per-rank errors are read back from the returned `errors`.
	_, _, errors = local(
		csr_matrix(M),
		omega1,
		k=max(rank_list),
		omega2=omega2,
		method='alt',
		T=max(rank_list),
		err_type='rmse',
		iter_lim=3,
		rng=[1, 5],
	)
	for k in rank_list:
		save(split_id, 'local_alt2', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'local_other_alt2' in algorithms:
	print('local_other_alt2')
	# Refit from scratch for every requested rank.
	for k in rank_list:
		_, _, errors = local_other(
			csr_matrix(M), omega1,
			k=k, omega2=omega2,
			method='alt', err_type='rmse',
			iter_lim=3, rng=[1, 5],
		)
		save(split_id, 'local_other_alt2', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'local_new_alt2' in algorithms:
	print('local_new_alt2')
	# Refit from scratch for every requested rank.
	for k in rank_list:
		_, _, errors = local_new(
			csr_matrix(M), omega1,
			k=k, omega2=omega2,
			method='alt', err_type='rmse',
			iter_lim=3, rng=[1, 5],
		)
		save(split_id, 'local_new_alt2', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'local_newnew_alt2' in algorithms:
	print('local_newnew_alt2')
	# Fit each rank in turn; the warm_start returned by one fit is fed into the
	# next (None for the first).
	warm_start = None
	for k in rank_list:
		_, _, errors, warm_start = local_newnew(
			csr_matrix(M), omega1,
			k=k, omega2=omega2,
			method='alt', err_type='rmse',
			iter_lim=2, warm_start=warm_start,
			rng=[1, 5],
		)
		save(split_id, 'local_newnew_alt2', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'local_newnew_opt' in algorithms:
	print('local_newnew_opt')
	# Same chaining as local_newnew_alt2, but with method='opt' and no iter_lim.
	warm_start = None
	for k in rank_list:
		_, _, errors, warm_start = local_newnew(
			csr_matrix(M), omega1,
			k=k, omega2=omega2,
			method='opt', err_type='rmse',
			warm_start=warm_start,
			rng=[1, 5],
		)
		save(split_id, 'local_newnew_opt', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'greedy_opt' in algorithms:
	print('greedy_opt')
	# Single greedy run (method='opt') up to the largest requested rank.
	_, _, errors = greedy(
		csr_matrix(M),
		omega1,
		k=max(rank_list),
		omega2=omega2,
		method='opt',
		err_type='rmse',
		rng=[1, 5],
	)
	for k in rank_list:
		save(split_id, 'greedy_opt', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'greedy_opt2' in algorithms:
	print('greedy_opt2')
	# As greedy_opt, but with an explicit iter_lim=100.
	_, _, errors = greedy(
		csr_matrix(M),
		omega1,
		k=max(rank_list),
		omega2=omega2,
		method='opt',
		err_type='rmse',
		iter_lim=100,
		rng=[1, 5],
	)
	for k in rank_list:
		save(split_id, 'greedy_opt2', k,
			errors[k][0],  # training error
			errors[k][1])  # validation error

if 'soft' in algorithms:
	print('soft')
	# Ranks already covered by some soft() call. It is passed to soft() as
	# to_avoid, and a single call may report errors for several ranks (every key
	# of `errors` is visited below), so later ranks can be skipped entirely.
	ranks_found = {}
	rank_list_set = set(rank_list)
	for k in rank_list:
		if k in ranks_found:
			continue
		_, _, errors = soft(
			csr_matrix(M), omega1,
			k=k, omega2=omega2,
			err_type='rmse', to_avoid=ranks_found,
			rng=[1, 5],
		)
		for kk in errors:
			if kk not in ranks_found and kk in rank_list_set:
				print(kk, 'error', errors[kk][0], errors[kk][1])
				# Save under the rank actually reported (kk), not the rank
				# that triggered this soft() call (k).
				save(split_id, 'soft', kk, errors[kk][0], errors[kk][1])
			ranks_found[kk] = True
# Results are persisted per rank via save() as each algorithm finishes; there is
# no end-of-run aggregation step. (An earlier mean/std aggregation over an
# `all_errors` dict, which this script never builds, was kept here as an inert
# string literal and has been removed.)
