#coding:utf-8

from tqdm import tqdm
from scipy.sparse import vstack
from time import time
import pandas as pd
import os
import subprocess as sb
from copy import deepcopy

from metrics import *
from adaptive_lambda.adaptive_lambda import AdaHedgeInit, compute_best_post_lbd
from bdivrec.fabaphe.utils import set_score

# Environment classes (compared by their str(type(...))) for which the
# embedding-based "*_fast" metric variants are used.
_FAST_METRIC_ENVS = (
	"<class 'known_setting.SyntheticCosine'>",
	"<class 'unknown_setting.SyntheticLinearGaussian'>",
)


def _fresh_state():
	"""Return the running state of a user before any batch has been recommended."""
	return dict(cum_rel=0, cum_iadiv=0, cum_iediv=0, nu=0, cum_t=0, current_T=0, recs=[], regret=[0], cum_regret=0, lbds=[], value=0, rels_total=[], lbd_star=-1, value_star=0, cum_precision=0)


def _parse_list(value, default):
	"""Turn a checkpointed list (stored as its repr string) back into a list; 0 (key missing) gives `default`."""
	if isinstance(value, str):
		# NOTE(review): eval on file contents -- only resume from trusted checkpoints.
		# eval (not ast.literal_eval) is kept for compatibility with existing checkpoints.
		return eval(value)
	return default if value == 0 else value


def _load_checkpoint(checkpoint_fname, new_checkpoints, user, env):
	"""Return the running state to start from, resuming from `checkpoint_fname` if possible.

	- `checkpoint_fname` is None: fresh state.
	- `new_checkpoints`: an existing checkpoint file is deleted first.
	- No checkpoint file: its parent folder is created and a fresh state is returned.
	- Existing checkpoint: the saved values are restored, and the items already
	  recommended are appended to the user history in `env`.
	"""
	if checkpoint_fname is None:
		return _fresh_state()
	if new_checkpoints:
		try:
			os.remove(checkpoint_fname)
		except FileNotFoundError:
			pass
	if not os.path.exists(checkpoint_fname):
		folder = os.path.dirname(checkpoint_fname)
		if folder:
			os.makedirs(folder, exist_ok=True)
		return _fresh_state()
	# The single column "0" mixes numbers and lists, so values come back as strings.
	saved = pd.read_csv(checkpoint_fname, index_col=0)["0"].to_dict()
	state = {k: saved.get(k, 0) for k in _fresh_state()}
	for k in ("cum_rel", "cum_iadiv", "cum_iediv", "nu", "cum_t", "value", "cum_precision"):
		state[k] = float(state[k])
	state["current_T"] = int(float(state["current_T"]))
	state["recs"] = _parse_list(state["recs"], [])
	state["regret"] = _parse_list(state["regret"], [0])
	state["lbds"] = _parse_list(state["lbds"], [])
	# NOTE(review): rels_total, lbd_star and value_star are never checkpointed.
	# rels_total restarts empty (it must stay appendable), and lbd_star stays 0, so the
	# post-hoc best-lambda search is skipped. A resumed adaptive-tradeoff run therefore
	# does not reproduce an uninterrupted run.
	state["rels_total"] = []
	# Replay the already recommended items into the user history.
	env.update_user_hist(user, list(set([r for rl in state["recs"] for r in rl])))
	return state


def _save_checkpoint(checkpoint_fname, **values):
	"""Write `values` as a one-column CSV (column 0, one row per key), as read by `_load_checkpoint`."""
	pd.DataFrame({0: values}).to_csv(checkpoint_fname)


def _stack_history(Phi_S, Phi_H):
	"""Prepend the embeddings of a new batch to the history embeddings (newest rows first)."""
	return Phi_S.copy() if Phi_H is None else vstack((Phi_S, Phi_H))


def _batch_metrics(S, Phi_S, Phi_H, H, user, env, K, eta, thres, use_fast):
	"""Return (relevance, intra-batch div., inter-batch div., precision) of batch `S`.

	`Phi_H` and `H` must describe the history *before* `S` is added to it.
	"""
	if use_fast:
		rel = relevance_fast(Phi_S, user, env)
	else:
		rel = relevance(S, user, env)
	iadiv = intrabatch_div_fast(Phi_S, K, eta)
	iediv = interbatch_div_fast(Phi_S, Phi_H, H, K, eta)
	if use_fast:
		prec = precision_fast(Phi_S, user, env, thres=thres)
	else:
		prec = precision(S, user, env, thres=thres)
	return rel, iadiv, iediv, prec


def _progress_text(n, cum_rel, cum_iadiv, cum_iediv, lbd=None):
	"""Progress-bar text with the running averages over `n` batches."""
	text = f"\n\nT={n} REL={np.round(cum_rel/n,2)}, ADIV={np.round(cum_iadiv/n,2)}, EDIV={np.round(cum_iediv/n,2)}"
	if lbd is not None:
		text += f", lambda={lbd}"
	return text + "\n\n"


def run_single_user(user, T, B, recommender, env, K, seed, thres=0.5, checkpoint_fname=None, new_checkpoints=False, checkpoint_every=1, learner_type="None", adaptive_tradeoff=False, nu_with_hist=False, learner=None, adaptive=None, verbose=False):
	"""Recommend T batches of B items to `user` in `env` and return the averaged metrics.

	At each round the recommender proposes a batch S. The batch is scored
	(relevance, intra-/inter-batch diversity, precision) and appended to the user
	history in `env`. When a learner is used, the batch's feedback is passed to it.
	With `adaptive_tradeoff`, `recommender.lbd` is updated online through an
	`AdaHedgeInit` object. kMarkovDPP produces all T batches in a single call at
	the first round.

	Args:
		user: user index in `env`.
		T: number of rounds (batches).
		B: number of items per batch.
		recommender: object exposing `name`, `lbd`, `eta` and `recommend(...)`
			(plus `grad_value_f(...)` when `adaptive_tradeoff`).
		env: environment holding user histories, item embeddings and feedback.
		K: kernel forwarded to the recommender and to the diversity metrics.
		seed: seed of the generator used to draw a cold-start history for the learner.
		thres: threshold forwarded to the precision and positive-diversity metrics.
		checkpoint_fname: CSV file used to save/resume progress (None disables checkpoints).
		new_checkpoints: delete any existing checkpoint before starting.
		checkpoint_every: save a checkpoint every `checkpoint_every` rounds.
		learner_type: name of a learner class (evaluated in this module), built when
			`learner` is None; "None" disables learning.
		adaptive_tradeoff: adapt `recommender.lbd` online (unsupported for kMarkovDPP / MMR).
		nu_with_hist: count the initial history in the "nunique" metric.
		learner: pre-built learner (takes precedence over `learner_type`); reset at the end.
		adaptive: pre-built object exposing `act()` / `incur()` (created when None
			and `adaptive_tradeoff`).
		verbose: show a progress bar.

	Returns:
		dict with relevance, intrabatch_div, interbatch_div, pos_div, nunique, recs,
		runtime (per batch) and precision. It also contains regret and cum_regret
		when a learner is used, and lbds, value, lbd_final, lbd_star and value_star
		when `adaptive_tradeoff` is set.

	Raises:
		ValueError: if `adaptive_tradeoff` is requested for kMarkovDPP or MMR.
	"""
	# Fail fast, before deleting a checkpoint or touching the environment.
	if adaptive_tradeoff and recommender.name in ("kMarkovDPP", "MMR"):
		raise ValueError(f"{recommender.name} not compatible with adaptive tradeoff")
	rng = np.random.default_rng(seed)
	H = env.get_user_hist(user)
	H0 = np.array(H).copy()  # initial history, before any replayed recommendation
	nu0 = nunique_recs([H])
	state = _load_checkpoint(checkpoint_fname, new_checkpoints, user, env)
	cum_rel = state["cum_rel"]
	cum_iadiv = state["cum_iadiv"]
	cum_iediv = state["cum_iediv"]
	cum_precision = state["cum_precision"]
	cum_t = state["cum_t"]
	nu = state["nu"]
	current_T = state["current_T"]
	recs = state["recs"]
	regret = state["regret"]
	lbds = state["lbds"]
	value = state["value"]
	rels_total = state["rels_total"]
	lbd_star = state["lbd_star"]
	value_star = state["value_star"]
	if recs:
		# Resumed run: the history now also holds the replayed recommendations,
		# which the inter-batch diversity and the learner warm-up must see.
		H = env.get_user_hist(user)
	Phi_H = None if len(H) == 0 else env.item_embs(H)
	## Learner
	if learner is None and learner_type != "None":
		# NOTE(review): eval of a caller-supplied class name -- pass trusted values only.
		learner = eval(learner_type)(dict(nelement=env.nitem, d=env.d+env.d_user, reg_eta=recommender.eta))
	if learner is not None:
		if len(H) == 0:
			# cold start: warm the learner up on 5 random items
			H_user = rng.choice(env.nitem, size=5, replace=False).ravel().tolist()
		else:
			H_user = deepcopy(H)
		## update the learner according to the history
		learner.update(H_user, env.observed_feedback(H_user, user), user, env, K)
	if adaptive_tradeoff and adaptive is None:
		adaptive = AdaHedgeInit(recommender.lbd, 0.01)
	if adaptive is not None:
		recommender.lbd = adaptive.act()
	lbd_init = recommender.lbd
	use_fast = str(type(env)) in _FAST_METRIC_ENVS
	pbar = tqdm(
		range(current_T, T),
		position=4,
		leave=False,
		disable=not verbose or T - current_T < 2
	)
	for it in pbar:
		if recommender.name == "kMarkovDPP":
			if it > 0:
				continue  # every batch was produced at it == 0 (no checkpoint either)
			start_time = time()
			recs = recommender.recommend(B, user, env, K, S=None)
			# One call yields all T batches: count its runtime once so that
			# cum_t / T is a per-batch time, as for the other recommenders.
			cum_t += time() - start_time
			for S in recs:
				H = env.get_user_hist(user)
				Phi_S = env.item_embs(S)
				# Score S against the history *before* S is added (as in the default branch).
				rel, iadiv, iediv, prec = _batch_metrics(S, Phi_S, Phi_H, H, user, env, K, recommender.eta, thres, use_fast)
				cum_rel += rel
				cum_iadiv += iadiv
				cum_iediv += iediv
				cum_precision += prec
				Phi_H = _stack_history(Phi_S, Phi_H)
				current_T = it + 1
				env.update_user_hist(user, S)
				lbds.append(recommender.lbd)
			pbar.set_description(_progress_text(max(len(recs), 1), cum_rel, cum_iadiv, cum_iediv))
		else:
			start_time = time()
			if learner is None:
				S = recommender.recommend(B, user, env, K, S=None)
			else:
				S = recommender.recommend(B, user, env, K, learner, S=None)
			cum_t += time() - start_time
			H = env.get_user_hist(user)
			Phi_S = env.item_embs(S)
			rel, iadiv, iediv, prec = _batch_metrics(S, Phi_S, Phi_H, H, user, env, K, recommender.eta, thres, use_fast)
			cum_rel += rel
			cum_iadiv += iadiv
			cum_iediv += iediv
			cum_precision += prec
			current_T = it + 1
			recs.append(S)
			## update history and history-related values
			env.update_user_hist(user, S)
			Phi_H = _stack_history(Phi_S, Phi_H)
			lbds.append(recommender.lbd)
			## feedback for the learner and/or the adaptive tradeoff
			if learner is not None:
				rels_S = env.observed_feedback(S, user)
				learner.update(S, rels_S, user, env, K)
			elif adaptive_tradeoff:
				rels_S = env.feedback(S, user)
			if adaptive_tradeoff:
				rels_total.append(rels_S)
				grads, val_lbd = recommender.grad_value_f(rels_S, user, env, K, S)
				adaptive.incur(-grads)
				recommender.lbd = adaptive.act()
				value += val_lbd
			pbar.set_description(_progress_text(current_T, cum_rel, cum_iadiv, cum_iediv, lbd=recommender.lbd if adaptive_tradeoff else None))
		if it % checkpoint_every == 0 and checkpoint_fname is not None:
			_save_checkpoint(
				checkpoint_fname,
				cum_rel=cum_rel,
				cum_iadiv=cum_iadiv,
				cum_iediv=cum_iediv,
				nu=nu,
				cum_t=cum_t,
				current_T=current_T,
				recs=recs,
				regret=regret,
				cum_regret=regret[-1],
				lbds=lbds,
				value=value,
				cum_precision=cum_precision,
			)
	assert len(recs) == T
	assert sum(len(ls) for ls in recs) == T*B
	if learner is not None:
		learner.reset()
	H0_list = H0.tolist()
	if nu_with_hist:
		nu = nunique_recs(recs+[H0_list])/(T*B+len(H0_list))
	else:
		nu = (nunique_recs(recs+[H0_list])-nu0)/(T*B)
	# recs=None: consider with the initial history
	totaldiv = totalpositive_div(user, env, K, recommender.eta, thres=thres, recs=recs)
	if adaptive_tradeoff and lbd_star < 0:
		lbd_star, value_star = compute_best_post_lbd(recs, rels_total, lbd_init, recommender, user, env, K)
	results = dict(
		relevance=cum_rel/T,
		intrabatch_div=cum_iadiv/T,
		interbatch_div=cum_iediv/T,
		pos_div=totaldiv,
		nunique=nu,
		recs=recs,
		runtime=cum_t/T,
		precision=cum_precision/T
	)
	if learner is not None:
		results.update(regret=regret, cum_regret=regret[-1])
	if adaptive_tradeoff:
		results.update(lbds=lbds, value=value_star-float(value), lbd_final=lbds[-1], lbd_star=lbd_star, value_star=value_star)
	return results
    
if __name__ == "__main__":
	# Smoke run: serve user 0 of a synthetic SyntheticCosine environment with every
	# known-setting baseline (plus Fabaphe) in both "SAMPLE" and "MAX" modes,
	# checkpointing each run, then save and print the per-method results.
	from time import time
	from kernels import DenudedKernel
	import sys
	# prepend ../src/ to the module search path for the imports below
	sys.path.insert(0,"../src/")
	from bdivrec.fabaphe import Fabaphe
	from utils import chunks
	## experiment parameters
	seed = 1234
	eta = 1
	lbd = 0.8
	epsilon = 0.1
	beta = None
	alpha = 0.5
	nc = 5
	B = 5
	T = 3 
	thres = 0.5
	user = 0
	nitem = 5000
	## output folder; its name encodes beta / nitem when they differ from None / 500000
	result_folder_name = "../results/Synthetic/DenudedKernel"
	result_folder_name += (f"_beta={beta}" if (beta is not None) else "")
	result_folder_name += (f"_nitem={nitem}" if (nitem != 500000) else "")
	result_folder_name += "/"
	rng = np.random.default_rng(seed)
	## same random initial history (5 items) restored before every run
	add_to_hist = rng.choice(nitem, size=5, replace=False).ravel().tolist()
	results = {}
	K = DenudedKernel(kernel="linear", n_components=nc, seed=seed, beta=beta)
	
	## Known setting
	from known_setting.known_envs import SyntheticCosine
	from known_setting.baselines_known import *
	env = SyntheticCosine(dict(name="../../datasets/Synthetic", nitem=nitem, nuser=100, d=25, seed=seed, 
		quantize_digit=-1, new=True))
	Phi = env.item_embs_slice(0, env.nitem)
	_, _ = K(Phi) ## fit the kernel in advance
	for rec_name in all_baselines_known+["Fabaphe"]: 
		# kMarkovDPP is skipped on catalogs larger than 5000 items
		if (rec_name in ["kMarkovDPP"] and env.nitem>5000):
			continue
		for rec_type in ["SAMPLE", "MAX"]: 
			## set history for user 0 (reinitialize at each call)
			env.set_user_hist(user, add_to_hist)
			print(f"{rec_name}_{rec_type}")
			# NOTE(review): eval resolves rec_name to a class brought in by the star import above
			recom = eval(rec_name)(dict(lbd=lbd, c=2, eta=eta, rec_type=rec_type, seed=seed, alpha=alpha, epsilon=epsilon, T=T))
			starttt=time()
			result = run_single_user(user, T, B, recom, env, K, seed, thres=thres, 
				checkpoint_fname=f"{result_folder_name}/{rec_name}_{rec_type}.csv", 
				new_checkpoints=False, 
				checkpoint_every=1)
			# wall-clock time of the whole run_single_user call
			result.setdefault("total_time", time()-starttt)
			result = {f"{rec_name}_{rec_type}": result}
			results.update(result)
			print("")
	## one column per "<method>_<mode>", one row per metric
	pd.DataFrame(results).to_csv(f"{result_folder_name}/final_{env.name.split('/')[-1]}_{K.name}.csv")
	df = pd.DataFrame(results)
	# print without the (long) "recs" row; presumably chunks() groups the column
	# indices 3 at a time -- TODO confirm against utils.chunks
	df = df.loc[[i for i in df.index if (i != "recs")]]
	for lst in chunks(df.shape[1], 3):
		print(df[df.columns[lst]])
	#proc = sb.Popen(f"rm -rf {result_folder_name}".split(" "))
	#proc.wait()
