#coding:utf-8

import argparse
import os
import shutil
import subprocess as sb

import pandas as pd
import yaml
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS

import sys
sys.path.insert(0,"../src/")
from bdivrec.fabaphe.fabaphe import Fabaphe
from bdivrec.fabaphe.utils import chunks, seed_everything
from bdivrec.fabaphe.kernels import DenudedKernel
from known_setting.baselines_known import *
from known_setting.known_envs import *
from unknown_setting.baselines_unknown import *
from unknown_setting.unknown_envs import *
from unknown_setting.learners import *
from adaptive_lambda.adaptive_lambda import AdaHedgeInit
from simulate import run_single_user
from plots import plot_grid_single_user, plot_fulltraj_single_user, plot_regret

def _str2bool(value):
	"""Parse a boolean command-line value.

	argparse's ``type=bool`` is a trap: ``bool("False")`` is True, so
	``--verbose False`` would silently enable the flag. This converter accepts
	true/false, 1/0, yes/no, t/f, y/n (case-insensitive). An empty string maps
	to False, as ``bool("")`` did.

	Raises:
		argparse.ArgumentTypeError: if the value is not a recognized boolean.
	"""
	if (isinstance(value, bool)):
		return value
	lowered = value.strip().lower()
	if (lowered in ("1", "true", "t", "yes", "y")):
		return True
	if (lowered in ("0", "false", "f", "no", "n", "")):
		return False
	raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='Fabaphe')
parser.add_argument('--recommenders', type=str, help="Recommender systems delimited by commas", default="Fabaphe")
parser.add_argument('--recom_types', type=str, help="Type of recommendations, delimited by commas", default="SAMPLE", choices=["SAMPLE", "MAX", "MAX,SAMPLE", "SAMPLE,MAX"])
parser.add_argument('--T', type=int, help="Trajectory length", default=5)
parser.add_argument('--B', type=int, help="Batch size", default=3)
parser.add_argument('--users', type=str, help="User identifiers to consider, delimited by commas", default="0")
parser.add_argument('--niters', type=int, help="Number of iterations of experiments", default=1)
parser.add_argument('--config', type=str, help="Path to config .yml file", default="config.yml") 
parser.add_argument('--name', type=str, help="Experiment name", default="")
parser.add_argument('--adaptive_tradeoff', type=int, help="If true, performs an adaptive quality-diversity tradeoff", default=0, choices=[0,1])
## Data parameters
parser.add_argument('--envir', type=str, help="Environment", default="SyntheticCosine", choices=["SyntheticCosine", "SyntheticLinearGaussian", "MovieLens", "PREDICT(private)", "Gottlieb", "Cdataset", "DNdataset", "LRSSL", "PREDICT_Gottlieb", "TRANSCRIPT", "PREDICT"])
parser.add_argument('--learner', type=str, help="Type of learner (only in the unknown setting)", default="None", choices=["None","Linear"])
parser.add_argument('--nitem', type=int, help="Number of items (only for SyntheticCosine)", default=1000) 
parser.add_argument('--nuser', type=int, help="Number of users (only for SyntheticCosine)", default=10) 
parser.add_argument('--ngroup', type=int, help="Number of collinear groups (only for SyntheticCosine)", default=3) 
parser.add_argument('--nvar', type=float, help="Offset for collinear groups (only for SyntheticCosine)", default=0.01) 
parser.add_argument('--d', type=int, help="User/item embedding dimension (only for SyntheticCosine)", default=5) 
parser.add_argument('--quantize', type=float, help="Quantization number for values (only for SyntheticCosine)", default=-1) 
parser.add_argument('--new_envir', type=int, help="Erases and creates a new data set (only for SyntheticCosine)", default=0, choices=[0,1]) 
parser.add_argument('--qthres', type=float, help="Threshold to distinguish between good and bad items", default=0.5)
## Model parameters
parser.add_argument('--lbd', type=float, help="Relative weight of the relevance task", default=0.5)
parser.add_argument('--c', type=float, help="Numerical stability", default=2.)
parser.add_argument('--alpha', type=float, help="Maximum distance in the user history (only for Fabaphe)", default=0.)
parser.add_argument('--eta', type=float, help="Regularization factor", default=0.01)
parser.add_argument('--epsilon', type=float, help="Percentage of greedy strategy (only for EpsGreedy)", default=0.1)
## Kernel parameters
parser.add_argument('--kernel', type=str, help="Kernel type", default="linear", choices=list(PAIRWISE_KERNEL_FUNCTIONS.keys()))
parser.add_argument('--beta', type=float, help="Denuding factor for the kernel", default=-1)
parser.add_argument('--n_components', type=int, help="Number of components for Nystroem approximation", default=25)
## Miscellaneous
## Boolean flags go through _str2bool so that "--flag False" really yields False.
parser.add_argument('--shared_learners', type=_str2bool, help="If set to True, share all online learners", default=False)
parser.add_argument('--seed', type=int, help="Random seed number", default=1234)
parser.add_argument('--folder', type=str, help="Result folder", default="")
parser.add_argument('--data_folder', type=str, help="Data folder", default="")
parser.add_argument('--raw_data_folder', type=str, help="Raw data folder", default="../dw_datasets/")
parser.add_argument('--clean', type=int, help="Erases prior results with the same configuration", default=0, choices=[0,1]) 
parser.add_argument('--verbose', type=_str2bool, help="Verbose", default=False)
args = parser.parse_args()

## Merge the YAML configuration file (if any) into the parsed arguments.
## NOTE: values read from the YAML file override the command-line values.
assert (not args.config) or (args.config.endswith(".yml") and os.path.exists(args.config)), (
	f"--config must be empty or point to an existing .yml file, got {args.config!r}")
if (args.config):
	with open(args.config, "r") as f:
		old_config = yaml.safe_load(f) or {}  # an empty YAML file loads as None
	new_config = vars(args)
	new_config.update(old_config)
	args = argparse.Namespace(**new_config)

## The experiment name is checked first: it builds the default result/data
## folders, and an empty name combined with --clean 1 would delete "../results/".
print(args.name)
assert len(args.name)>0, "--name must be a non-empty experiment name"

if (not args.folder):
	args.folder = f"../results/{args.name}/"
if (not args.data_folder):
	args.data_folder = f"../datasets/{args.name}/"
## Snapshot taken before the conversions below (split lists, beta=None, ...), so
## config.yml stores options in the same raw form they have on the command line.
saved_config = dict(vars(args))

## Convert and validate the remaining arguments.
args.recommenders = args.recommenders.split(",")
assert ((args.learner == "None") and all([r in all_baselines_known+["Fabaphe"] for r in args.recommenders])) or ((args.learner != "None") and all([r in all_baselines_unknown+["FabapheUCB"] for r in args.recommenders])) 
args.recom_types = args.recom_types.split(",")
assert args.T>0 or args.T==-1
assert args.B>0
args.users = list(map(int,args.users.split(",")))
assert args.niters>0
assert args.nitem>0
assert args.nuser>0
assert not (args.envir == "SyntheticCosine") or all([u < args.nuser and u >= 0 for u in args.users])
assert args.ngroup > 1
assert args.nvar > 0
assert args.d>0
assert args.lbd>=0 and args.lbd<=1
assert args.c>1
assert args.alpha>=0
assert args.eta>=0
assert args.epsilon>=0 and args.epsilon<=1
args.beta = None if (args.beta<0) else args.beta
assert args.n_components>0
assert (args.envir != "SyntheticCosine") or (args.n_components<=args.d)
assert args.seed>0
assert ((args.learner == "None") and (args.envir not in ["SyntheticLinearGaussian"])) or ((args.learner != "None") and (args.envir in ["SyntheticLinearGaussian"]))
args.adaptive_tradeoff = bool(args.adaptive_tradeoff)

## Filesystem side effects happen only once every argument has been validated.
## shutil/os are used instead of "rm -rf"/"mkdir -p" subprocesses, which broke
## on paths containing spaces.
if (bool(args.clean) and os.path.exists(args.folder)):
	shutil.rmtree(args.folder)
os.makedirs(args.folder, exist_ok=True)
os.makedirs(args.data_folder, exist_ok=True)
with open(f"{args.folder}/config.yml", "w") as f:
	yaml.safe_dump(saved_config, f)

## Aggregate the per-(user, seed) result files into a running mean and a running
## unbiased (sample) variance over every (seed, user) pair whose file exists.


def _update_running_stats(mean, var, n, sample):
	"""Add one observation to a running mean / unbiased variance.

	Args:
		mean, var: statistics over the first ``n`` observations (``n >= 1``).
		n: number of observations already folded in.
		sample: the (n+1)-th observation, aligned on the same labels as ``mean``.

	Returns:
		``(new_mean, new_var)`` over ``n+1`` observations, using the recurrence
		s2_{n+1} = (n-1)/n * s2_n + (x - mean_n)^2 / (n+1).
	"""
	new_mean = (mean*n+sample)/(n+1)
	new_var = (n-1)/n*var+(sample-mean).pow(2)/(n+1)
	return new_mean, new_var


rng = seed_everything(args.seed)  # presumably a numpy Generator (it provides .choice) — confirm
## One sub-seed per iteration; result files are named after these sub-seeds.
seeds = rng.choice(int(max(args.niters,1e8)), size=args.niters)
results_mean, results_var, N = None, None, 0
## Rows removed before averaging: "recs" must be present, the others only if they exist.
optional_rows = ["regret", "lbds", "value_star"]
for niter, seed in enumerate(seeds):
	## The file path depends only on (user, seed), so each file is read once.
	## A missing file only skips that user.
	results_user_seed = {}
	for user in args.users:
		nm = f"{args.folder}/final_user={user}_seed={seed}.csv"
		if (not os.path.exists(nm)):
			if (args.verbose):
				print(f"Missing result file {nm}")
			continue
		results_user_seed[user] = pd.read_csv(nm, index_col=0)
	print(f"niter x #users = {(niter,len(results_user_seed))}")
	for user in args.users:
		df = results_user_seed.get(user, None)
		if (df is None):
			continue
		dff = df.drop(["recs"]).drop(optional_rows, errors="ignore").astype(float)
		if (results_mean is None):
			results_mean = dff
			results_var = pd.DataFrame(0., index=dff.index, columns=dff.columns)
		else:
			results_mean, results_var = _update_running_stats(results_mean, results_var, N, dff)
		N += 1
## Format the aggregated statistics as "mean +-std" tables (plain text and LaTeX).
if (results_mean is None):
	sys.exit(f"No result file found in {args.folder} for users {args.users}")
results_std = results_var.pow(0.5)
results = results_mean.round(3).astype(str)+" +-"+results_std.round(3).astype(str)
## Raw strings keep the LaTeX backslashes without invalid-escape warnings.
results_latex = (results_mean.round(2).astype(str)+r" $\pm$"+results_std.round(2).astype(str)).T
latex_names = {
	"relevance": "Relevance",
	"intrabatch_div": "Intra. div.",
	"interbatch_div": "Inter. div.",
	"precision": "Precision",
	"pos_div": "Summary",
	"nunique": r"\%unique rec.",
	"runtime": "Runtime (sec.)",
	"cum_regret": "Cum. regret",
	"lbd_final": r"$\lambda_f$",
	"value": "value",
	"lbd_star": r"$\lambda^\star$",
}
## Metrics without a LaTeX label keep their raw name instead of raising KeyError.
results_latex.columns = [latex_names.get(i, i) for i in results_latex.columns]
latex_cols = ["Relevance","Precision","Intra. div.","Inter. div.","Summary",r"\%unique rec.","Runtime (sec.)"]
if (args.learner != "None"):
	latex_cols += ["Cum. regret"]
if (args.adaptive_tradeoff):
	latex_cols += [r"$\lambda_f$","value",r"$\lambda^\star$"]
results_latex = results_latex[latex_cols]
## Row order: for SAMPLE then MAX, baselines first, then Fabaphe, then FabapheUCB.
row_order = []
for suffix in ("_SAMPLE", "_MAX"):
	row_order += [i for i in results_latex.index if (suffix in i and "Fabaphe" not in i)]
	row_order += [r for r in ("Fabaphe"+suffix, "FabapheUCB"+suffix) if r in results_latex.index]
results_latex = results_latex.loc[row_order]
print(results_latex[["Relevance","Precision","Intra. div.","Inter. div.","Summary","Runtime (sec.)"]])
for lst in chunks(results.shape[1], 5):
	print(results[results.columns[lst]])
