#coding:utf-8

import numpy as np
from scipy.sparse import vstack, lil_array
from scipy.sparse.linalg import norm
import sparseqr

from bdivrec.fabaphe.utils import volume

def relevance_fast(Phi_S, u, env):
	"""Mean feedback of user ``u`` over a batch given by its embeddings.

	Phi_S: embeddings of the recommended items (passed straight to
	``env.feedback_fast``).
	u: user index.
	env: environment exposing ``feedback_fast(Phi_S, u)``.
	Returns the average score as a Python float.
	"""
	scores = env.feedback_fast(Phi_S, u)
	mean_score = np.mean(scores)
	return float(mean_score)

def intrabatch_div_fast(Phi_S, K, eta):
	"""Intra-batch diversity computed directly from item embeddings.

	Same quantity as ``intrabatch_div``, but the caller supplies the batch
	embeddings ``Phi_S`` instead of item ids plus an environment.
	Returns ``volume(Phi_S, K, eta)`` cast to a Python float.
	"""
	batch_volume = volume(Phi_S, K, eta)
	return float(batch_volume)

def interbatch_div_fast(Phi_S, Phi_H, H, K, eta):
	"""Diversity of a batch together with the user history, from embeddings.

	Phi_S: embeddings of the recommended batch (one row per item).
	Phi_H: embeddings of the history items; only used when ``H`` is non-empty.
	H: the history itself; only its length is inspected here.
	K, eta: forwarded to ``volume``.

	Returns ``volume`` of the row-stacked ``[Phi_S; Phi_H]`` (or of ``Phi_S``
	alone when there is no history) as a Python float.

	``eta`` is forwarded to ``volume`` just as ``intrabatch_div_fast`` does.
	Otherwise the parameter would be accepted and ignored, and the inter-batch
	score would be computed with a different ``eta`` than the intra-batch one.
	"""
	if len(H) > 0:
		PSH = vstack((Phi_S, Phi_H))
	else:
		PSH = Phi_S
	return float(volume(PSH, K, eta))

def relevance(S, u, env):
	"""Average feedback that user ``u`` gives to the items in ``S``.

	S: recommended item ids (passed to ``env.feedback``).
	u: user index.
	env: environment exposing ``feedback(S, u)``.
	Returns the mean score as a Python float.
	"""
	feedback_scores = env.feedback(S, u)
	return float(np.mean(feedback_scores))

def intrabatch_div(S, env, K, eta):
	"""Intra-batch diversity of the items ``S``.

	Looks up the embeddings of ``S`` with ``env.item_embs`` and returns
	``volume(embeddings, K, eta)`` as a Python float.
	"""
	batch_embeddings = env.item_embs(S)
	batch_volume = volume(batch_embeddings, K, eta)
	return float(batch_volume)
	
def precision(S, u, env, thres=0.5):
	"""Fraction of items in ``S`` whose feedback for user ``u`` reaches ``thres``.

	S: recommended item ids; its length is the denominator.
	u: user index.
	env: environment exposing ``feedback(S, u)``.
	thres: an item counts as a hit when its score is ``>= thres``.
	"""
	hit_mask = env.feedback(S, u) >= thres
	n_hits = hit_mask.sum()
	return n_hits / len(S)
	
def precision_fast(Phi_S, u, env, thres=0.5):
	"""Embedding-based variant of ``precision``.

	Phi_S: batch embeddings; the number of rows (``shape[0]``) is the
	denominator.
	u: user index.
	env: environment exposing ``feedback_fast(Phi_S, u)``.
	thres: an item counts as a hit when its score is ``>= thres``.
	"""
	n_items = Phi_S.shape[0]
	hit_mask = env.feedback_fast(Phi_S, u) >= thres
	return hit_mask.sum() / n_items

def interbatch_div(S, u, env, K, eta):
	"""Diversity of batch ``S`` together with user ``u``'s history.

	The embeddings of ``S`` are stacked (rows) on top of the embeddings of
	``H = env.get_user_hist(u)`` when ``H`` is non-empty. Returns ``volume``
	of that matrix under kernel ``K`` as a Python float.

	``eta`` is forwarded to ``volume`` just as ``intrabatch_div`` does.
	Otherwise the parameter would be accepted and ignored, and the two
	diversity metrics would use different ``eta`` values.
	"""
	H = env.get_user_hist(u)
	Phi_S = env.item_embs(S)
	if len(H) > 0:
		Phi_H = env.item_embs(H)
		PSH = vstack((Phi_S, Phi_H))
	else:
		PSH = Phi_S
	return float(volume(PSH, K, eta))
	
def dist_2_space(S, u, env): 
	"""Product of the non-zero |diagonal| entries from a pivoted sparse QR.

	S: recommended item ids.
	u: user index; its history ``H = env.get_user_hist(u)`` is stacked below
	the embeddings of ``S`` (rows) when non-empty.

	The QR factorisation is taken of the transpose of that stacked matrix, so
	items are columns. The return value is a NumPy scalar; unlike the other
	metrics in this file it is not cast to ``float``. If every selected
	diagonal entry is zero, the result is the empty product (1).

	NOTE(review): the diagonal is read after un-permuting R's columns
	(``R[:, argsort(E)]``), i.e. entries ``R[i, EE[i]]``, not ``diag(R)``
	itself. For a column-pivoted QR the volume-like quantity is usually
	``prod(|diag(R)|)``. Confirm which of the two is intended.
	"""
	H = env.get_user_hist(u)
	Phi_S = env.item_embs(S)
	if (len(H)>0):
		Phi_H = env.item_embs(H)
		PSH = vstack((Phi_S,Phi_H))
	else:
		PSH = Phi_S
	# Perform QR decomposition (with column permutation E) of PSH.T.
	## PSH.T = Q @ lil_array(R)[:,EE]
	# Q and the returned rank are not used below; only R and E are.
	Q, R, E, _ = sparseqr.qr(PSH.T)
	# Inverse of the column permutation (assuming E is a permutation vector).
	EE = np.argsort(E)
	RR  = lil_array(R)[:,EE]
	arr = np.abs(RR.diagonal())
	# Zero entries are dropped before taking the product.
	return arr[arr!=0].prod()
	
def dist_2_space_fast(Phi_S, Phi_H, H, K, eta): 
	"""Embedding-based variant of ``dist_2_space``.

	Phi_S: batch embeddings; Phi_H: history embeddings, stacked below
	``Phi_S`` only when ``H`` is non-empty (only ``len(H)`` is inspected).
	K and eta are accepted but not used by this function.

	Returns, as a NumPy scalar, the product of the non-zero absolute diagonal
	entries of the column-unpermuted R factor of a sparse QR of the stacked
	matrix's transpose. If no non-zero entries exist, it returns the empty
	product (1).

	NOTE(review): the diagonal is taken from ``R[:, argsort(E)]`` rather than
	from ``R`` itself. Confirm this is intended; ``prod(|diag(R)|)`` is the
	usual volume-like quantity for a pivoted QR.
	"""
	if (len(H)>0):
		PSH = vstack((Phi_S,Phi_H))
	else:
		PSH = Phi_S
	# Perform QR decomposition (with column permutation E) of PSH.T.
	## PSH.T = Q @ lil_array(R)[:,EE]
	# Q and the returned rank are not used below; only R and E are.
	Q, R, E, _ = sparseqr.qr(PSH.T)
	# Inverse of the column permutation (assuming E is a permutation vector).
	EE = np.argsort(E)
	RR  = lil_array(R)[:,EE]
	arr = np.abs(RR.diagonal())
	# Zero entries are dropped before taking the product.
	return arr[arr!=0].prod()

def totalpositive_div(u, env, K, eta, thres=0.5, recs=None):
	"""Diversity (``volume``) of the items that user ``u`` rated positively.

	u: user index.
	env: environment providing ``get_user_hist``, ``feedback`` and
	``item_embs``.
	K, eta: forwarded to ``volume``.
	thres: an item is positive when its feedback score is ``>= thres``.
	recs: optional list of recommendation batches. When given, the candidate
	items are the distinct items across all batches. Otherwise they come
	from ``env.get_user_hist(u)``.

	Returns the volume of the positive items' embeddings as a Python float,
	or 0.0 when there are no candidates or none of them is positive.
	"""
	if recs is None:
		H = env.get_user_hist(u)
	else:
		H = list({r for rl in recs for r in rl})
	# No candidates at all: same answer as "no positive item", without
	# asking the environment to score an empty selection.
	if len(H) == 0:
		return 0.0
	scores = env.feedback(H, u).toarray().ravel()
	positive = scores >= thres
	Hpos = [h for i, h in enumerate(H) if positive[i]]
	if len(Hpos) == 0:
		return 0.0
	Phi_Hpos = env.item_embs(Hpos)
	return float(volume(Phi_Hpos, K, eta))

def nunique_recs(recs):
	"""Count the distinct items across all recommendation batches in ``recs``."""
	distinct_items = set().union(*recs)
	return len(distinct_items)

if __name__ == "__main__":
	# Smoke test: print every metric on two toy tabular environments.
	from kernels import DenudedKernel
	from known_setting.known_envs import TabularNoHistory
	folder_name = "../../datasets/Tabular_test_metrics"
	# Settings shared by both toy instances.
	u = 0
	eta = 1e-3
	seed = 1234
	beta = 0.
	# Toy instance 1: seven 2-D items, one user pointing along the second axis.
	items1 = np.array([
	    [1, 0],
	    [1.2, 0.1],
	    [0, 1],
	    [-1, 0.1],
	    [-1.2, 0],
	    [-1.4, 0.3],
	    [0, -1]
	])
	users1 = np.array([[0, 1]])
	# Toy instance 2: seven one-hot items; the single user equals the last item.
	items2 = np.eye(7)
	users2 = items2[-1,:].reshape((1,-1))
	for items, users in [(items1, users1), (items2, users2)]:
		# Feedback table handed to the environment: item/user inner products.
		feedbacks = items @ users.T
		env = TabularNoHistory(dict(name=folder_name, items=items, users=users, feedbacks=feedbacks, seed=seed, new=True))
		# isinstance instead of comparing the str() of the type: robust to
		# repr formatting and to subclasses.
		assert isinstance(env, TabularNoHistory)
		print("Feedback")
		print(env.feedback_slice(0, env.nitem, 0).toarray().ravel())
		# Currently a single setting: n_components equal to the embedding dimension.
		for nc in [items.shape[1]]:
			K = DenudedKernel(kernel="linear", n_components=nc, seed=seed, beta=beta)
			all_recs = [[0,1], [2,6], [1,2,6]]
			for recs in all_recs:
				print(f"\n**** Recommendations: {recs}")
				print(f"Relevance {relevance(recs, u, env)}")
				print(f"Intrabatch {intrabatch_div(recs, env, K, eta)}")
				print(f"Interbatch {interbatch_div(recs, u, env, K, eta)}")
				# Presumably records recs in u's history, so later interbatch
				# scores include them — confirm against the env implementation.
				env.update_user_hist(u, recs)
			print(f"\nTotal positive diversity {totalpositive_div(u, env, K, eta, thres=0.1)}")
			print(f"#unique recommendations {nunique_recs(all_recs)}")
		print("\n")
