import numpy as np
from math import sqrt
from scipy.sparse import csr_matrix
from greedy import greedy
from greedy_reg import greedy_reg
from local import local
from local_other import local_other
from local_new import local_new
from local_newnew import local_newnew
from soft import knn, nuclear, soft
from mysvd import mysvd
import sys
from tinydb import *
import json
import tools

# The script needs at least a config id. Without any argument, print usage
# and stop. sys.exit is used rather than the `exit` builtin, which is only
# injected by the `site` module (absent under `python -S` or frozen builds)
# and is intended for the interactive shell.
if len(sys.argv) == 1:
	print('Usage: python3 rnd.py config_id ranks algorithms [save|nosave]')
	sys.exit(0)

# Initialize the results database (db.json in the working directory).
# The default table kwargs are set once, before TinyDB() is constructed,
# presumably so the default table is created with them. 'cache_size': 0
# looks intended to disable TinyDB's query cache.
# NOTE(review): DEFAULT_TABLE_KWARGS is a TinyDB 3.x-era class attribute —
# confirm against the installed TinyDB version.
TinyDB.DEFAULT_TABLE_KWARGS = {'cache_size': 0}
db = TinyDB('db.json')
# Query builder used by save() to select existing records by field.
User = Query()
def save(config_id, alg, rank, cur_mean, cur_std):
	"""Write one aggregated error record to the TinyDB database.

	The record is identified by its dataset name ('random_error_' + config_id),
	its algorithm label and its integer rank. Any existing record with that
	same identity is removed before the new one is inserted, so re-running an
	experiment overwrites the earlier result. Nothing is written unless the
	script was launched with 'save' as its fourth command-line argument.

	Args:
		config_id: configuration key (a string; it is concatenated into the
			dataset name).
		alg: algorithm label.
		rank: rank the statistics belong to; stored as an int.
		cur_mean: indexable; [0] is stored as the train-error mean and [1]
			as the validation-error mean.
		cur_std: indexable; [0] / [1] are the matching standard deviations.
	"""
	should_persist = len(sys.argv) >= 5 and sys.argv[4] == 'save'
	if not should_persist:
		return

	dataset = 'random_error_' + config_id
	rank_int = int(rank)

	# Drop any previous result for this (dataset, algorithm, rank) triple.
	same_key = (User.dataset == dataset) & (User.algorithm == alg) & (User.k == rank_int)
	db.remove(same_key)

	record = {
		'dataset': dataset,
		'algorithm': alg,
		'k': rank_int,
		'train_error_mean': float(cur_mean[0]),
		'train_error_std': float(cur_std[0]),
		'val_error_mean': float(cur_mean[1]),
		'val_error_std': float(cur_std[1]),
	}
	db.insert(record)

# Number of independent random trials and the seed of the first one. Trial t
# uses seed SEED_BASE + t, so repeated runs see the same problem instances.
N_TRIALS = 50
SEED_BASE = 1024

# Algorithm labels handled by the trial loop below, in the order they run.
# An older per-rank driver (local_opt, greedy_reg_alt, mysvd_alt, mysvd_opt)
# had been disabled by wrapping it in a string literal. It has been removed;
# those names are now reported as unknown.
KNOWN_ALGORITHMS = (
	'greedy_alt', 'greedy_alt2', 'local_alt', 'local_alt2',
	'local_other_alt2', 'local_new_alt2', 'local_newnew_alt2',
	'greedy_opt', 'greedy_opt2', 'soft',
)

# all_errors[algorithm][rank] -> list holding one error record per trial.
# The records are whatever the algorithms return for that rank. This script
# treats record[0] as the train error and record[1] as the validation error.
all_errors = {}


def _record(algorithm, errors, ranks):
	"""Append errors[k] to all_errors[algorithm][k] for every k in ranks."""
	per_rank = all_errors.setdefault(algorithm, {})
	for k in ranks:
		per_rank.setdefault(k, []).append(errors[k])


# The CLI arguments and the experiment config are the same for every trial,
# so they are parsed once here instead of on each of the N_TRIALS iterations.
config_id = sys.argv[1]
with open('config.json') as f:
	config = json.load(f)[config_id]
rank_list = tools.parse_ranks(sys.argv[2] if len(sys.argv) > 2 else None, max_rank=20)
rank_set = set(rank_list)
algorithms = sys.argv[3].split('-') if len(sys.argv) > 3 else ['greedy_alt']

# Misspelled or retired algorithm names would otherwise be skipped silently.
unknown_algorithms = [a for a in algorithms if a not in KNOWN_ALGORITHMS]
if unknown_algorithms:
	print('warning: ignoring unknown algorithm(s):', ', '.join(unknown_algorithms), file=sys.stderr)

for t in range(N_TRIALS):
	print('t:', t)
	# Generate this trial's random problem instance. The algorithms below keep
	# their original order because they may also consume the seeded RNG.
	np.random.seed(SEED_BASE + t)
	_, _, _, M, omega1, omega2 = tools.random_data(
		m=int(config["m"]), n=int(config["n"]), k=int(config["k"]),
		p=float(config["p"]), SNR=float(config["SNR"]))

	# Each algorithm called with k=max(rank_list) is fitted once, and its
	# returned `errors` is then indexed for every k in rank_list.
	if 'greedy_alt' in algorithms:
		print('greedy_alt')
		_, _, errors = \
			greedy(csr_matrix(M), omega1, k=max(rank_list), omega2=omega2, method='alt', err_type='rse')
		_record('greedy_alt', errors, rank_list)

	if 'greedy_alt2' in algorithms:
		print('greedy_alt2')
		_, _, errors = \
			greedy(csr_matrix(M), omega1, k=max(rank_list), omega2=omega2, method='alt', err_type='rse', iter_lim=3) #best iter_lim=3
		_record('greedy_alt2', errors, rank_list)

	if 'local_alt' in algorithms:
		print('local_alt')
		# T used to be `k + 4`, where `k` was not bound by this section at all.
		# That raised NameError on the first trial, or silently reused a value
		# left over from an earlier loop. It is now tied to the rank requested.
		_, _, errors = \
			local(csr_matrix(M), omega1, k=max(rank_list), omega2=omega2, method='alt', T=max(rank_list) + 4, err_type='rse')
		_record('local_alt', errors, rank_list)

	if 'local_alt2' in algorithms:
		print('local_alt2')
		_, _, errors = \
			local(csr_matrix(M), omega1, k=max(rank_list), omega2=omega2, method='alt', T=max(rank_list), err_type='rse', iter_lim=3)
		_record('local_alt2', errors, rank_list)

	# The next three algorithms are fitted separately for each requested rank.
	if 'local_other_alt2' in algorithms:
		print('local_other_alt2')
		for k in rank_list:
			_, _, errors = \
				local_other(csr_matrix(M), omega1, k=k, omega2=omega2, method='alt', err_type='rse', iter_lim=3)
			_record('local_other_alt2', errors, [k])

	if 'local_new_alt2' in algorithms:
		print('local_new_alt2')
		for k in rank_list:
			_, _, errors = \
				local_new(csr_matrix(M), omega1, k=k, omega2=omega2, method='alt', err_type='rse', iter_lim=3)
			_record('local_new_alt2', errors, [k])

	if 'local_newnew_alt2' in algorithms:
		print('local_newnew_alt2')
		# Each rank's fit returns a warm_start that is fed to the next rank's fit.
		warm_start = None
		for k in rank_list:
			_, _, errors, warm_start = \
				local_newnew(csr_matrix(M), omega1, k=k, omega2=omega2, method='alt', err_type='rse', iter_lim=3, warm_start=warm_start)
			_record('local_newnew_alt2', errors, [k])

	if 'greedy_opt' in algorithms:
		print('greedy_opt')
		_, _, errors = \
			greedy(csr_matrix(M), omega1, k=max(rank_list), omega2=omega2, method='opt', err_type='rse')
		_record('greedy_opt', errors, rank_list)

	if 'greedy_opt2' in algorithms:
		print('greedy_opt2')
		_, _, errors = \
			greedy(csr_matrix(M), omega1, k=max(rank_list), omega2=omega2, method='opt', err_type='rse', iter_lim=100)
		_record('greedy_opt2', errors, rank_list)

	if 'soft' in algorithms:
		print('soft')
		# One soft() call can report errors for several ranks. Every rank seen
		# so far is remembered in ranks_found; such ranks are not requested
		# again, and the dict is passed to soft() as `to_avoid` (its exact
		# meaning is defined by soft.soft).
		# Only ranks that are in rank_list, and not yet recorded, are kept.
		ranks_found = {}
		for k in rank_list:
			if k in ranks_found:
				continue
			_, _, errors = \
				soft(csr_matrix(M), omega1, k=k, omega2=omega2, err_type='rse', to_avoid=ranks_found)
			for kk in errors:
				if kk not in ranks_found and kk in rank_set:
					print(kk, 'error', errors[kk][0], errors[kk][1])
					_record('soft', errors, [kk])
				ranks_found[kk] = True
# Collapse the per-trial records of each (algorithm, rank) into an
# element-wise mean and standard deviation (NumPy default ddof=0). Each
# trial's record is flattened into one row of a 2-D array. The statistics
# are passed to save(), which only writes to the database when the fourth
# CLI argument is 'save'.
for alg_name, by_rank in all_errors.items():
	for rank, trials in by_rank.items():
		stacked = np.vstack([np.ravel(trial) for trial in trials])
		save(config_id, alg_name, rank, stacked.mean(axis=0), stacked.std(axis=0))
