#coding:utf-8

import argparse
import os
import shutil
import subprocess as sb

import pandas as pd
from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS
from tqdm import tqdm
import yaml

import sys
sys.path.insert(0,"../src/")
from bdivrec.fabaphe.fabaphe import Fabaphe
from bdivrec.fabaphe.utils import chunks, seed_everything
from bdivrec.fabaphe.kernels import DenudedKernel
from known_setting.baselines_known import *
from known_setting.known_envs import *
from adaptive_lambda.adaptive_lambda import AdaHedgeInit
from simulate import run_single_user
from plots import plot_grid_single_user, plot_fulltraj_single_user, plot_regret

pd.set_option('future.no_silent_downcasting', True)


def _str2bool(value):
    """Convert a command-line string to a boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is True for
    every non-empty string, so ``--verbose False`` used to *enable* the flag.
    Accepted spellings (case-insensitive): 1/0, true/false, t/f, yes/no, y/n.
    An empty string maps to False, as ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: for any other text; argparse reports it as
            a regular usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


## Command-line interface. Values read from the --config YAML file are merged
## on top of the parsed arguments afterwards, so they take precedence.
parser = argparse.ArgumentParser(description='Fabaphe')
parser.add_argument('--recommenders', type=str, help="Recommender systems delimited by commas", default="Fabaphe")
parser.add_argument('--recom_types', type=str, help="Type of recommendations, delimited by commas", default="SAMPLE", choices=["SAMPLE", "MAX", "MAX,SAMPLE", "SAMPLE,MAX"])
parser.add_argument('--T', type=int, help="Trajectory length", default=5)
parser.add_argument('--B', type=int, help="Batch size", default=3)
parser.add_argument('--users', type=str, help="User identifiers to consider, delimited by commas", default="0")
parser.add_argument('--niters', type=int, help="Number of iterations of experiments", default=1)
parser.add_argument('--config', type=str, help="Path to config .yml file", default="config.yml") 
parser.add_argument('--name', type=str, help="Experiment name", default="")
parser.add_argument('--adaptive_tradeoff', type=int, help="If true, performs an adaptive quality-diversity tradeoff", default=0, choices=[0,1])
## Data parameters
parser.add_argument('--envir', type=str, help="Environment", default="SyntheticCosine", choices=["SyntheticCosine", "SyntheticLinearGaussian", "MovieLens", "PREDICT(private)", "Gottlieb", "Cdataset", "DNdataset", "LRSSL", "PREDICT_Gottlieb", "TRANSCRIPT", "PREDICT"])
parser.add_argument('--learner', type=str, help="Type of learner (only in the unknown setting)", default="None", choices=["None","Linear"])
parser.add_argument('--nitem', type=int, help="Number of items (only for SyntheticCosine)", default=1000) 
parser.add_argument('--nuser', type=int, help="Number of users (only for SyntheticCosine)", default=10) 
parser.add_argument('--ngroup', type=int, help="Number of collinear groups (only for SyntheticCosine)", default=3) 
parser.add_argument('--nvar', type=float, help="Offset for collinear groups (only for SyntheticCosine)", default=0.01) 
parser.add_argument('--d', type=int, help="User/item embedding dimension (only for SyntheticCosine)", default=5) 
parser.add_argument('--quantize', type=float, help="Quantization number for values (only for SyntheticCosine)", default=-1) 
parser.add_argument('--nadd_to_hist', type=int, help="Size of initial history (only for SyntheticCosine)", default=5) 
parser.add_argument('--new_envir', type=int, help="Erases and creates a new data set (only for SyntheticCosine)", default=0, choices=[0,1]) 
parser.add_argument('--qthres', type=float, help="Threshold to distinguish between good and bad items", default=0.5)
## Model parameters
parser.add_argument('--lbd', type=float, help="Relative weight of the relevance task", default=0.5)
parser.add_argument('--c', type=float, help="Numerical stability", default=2.)
parser.add_argument('--alpha', type=float, help="Maximum distance in the user history (only for Fabaphe)", default=0.)
parser.add_argument('--eta', type=float, help="Regularization factor", default=0.01)
parser.add_argument('--epsilon', type=float, help="Percentage of greedy strategy (only for EpsGreedy)", default=0.1)
parser.add_argument('--type_algo', type=str, help="Type algorithm (tree, only for Fabaphe)", default="FAISS", choices=["kd-tree", "FAISS"])
## Kernel parameters
parser.add_argument('--kernel', type=str, help="Kernel type", default="linear", choices=list(PAIRWISE_KERNEL_FUNCTIONS.keys()))
parser.add_argument('--beta', type=float, help="Denuding factor for the kernel", default=-1)
parser.add_argument('--n_components', type=int, help="Number of components for Nystroem approximation", default=25)
## Miscellaneous
## Boolean flags go through _str2bool: with type=bool, "False" would parse as True.
parser.add_argument('--shared_learners', type=_str2bool, help="If set to True, share all online learners", default=False)
parser.add_argument('--seed', type=int, help="Random seed number", default=1234)
parser.add_argument('--folder', type=str, help="Result folder", default="")
parser.add_argument('--data_folder', type=str, help="Data folder", default="")
parser.add_argument('--raw_data_folder', type=str, help="Raw data folder", default="../dw_datasets/")
parser.add_argument('--clean', type=int, help="Erases prior results with the same configuration", default=0, choices=[0,1]) 
parser.add_argument('--verbose', type=_str2bool, help="Verbose", default=False)
args = parser.parse_args()

## Merge the optional YAML config into the parsed CLI arguments, create the
## result/data folders and snapshot the effective configuration.
## NOTE: keys read from the YAML file override the command-line values.
if args.config:
    if not (args.config.endswith(".yml") and os.path.exists(args.config)):
        parser.error(f"--config must point to an existing .yml file (got {args.config!r})")
    with open(args.config, "r") as f:
        ## An empty YAML file loads as None; treat it as "no overrides".
        file_config = yaml.safe_load(f) or {}
    merged_config = vars(args)
    merged_config.update(file_config)
    args = argparse.Namespace(**merged_config)

## Check the experiment name *before* touching the filesystem: with an empty
## name the default result folder is "../results//", and --clean 1 would
## delete every previous experiment under ../results/.
if not args.name:
    parser.error("--name must be a non-empty experiment name")

if (len(args.folder)==0):
    args.folder = f"../results/{args.name}/"
## shutil/os instead of shelling out to `rm -rf` / `mkdir -p`: paths with
## spaces work, and failures raise instead of being silently ignored.
if (bool(args.clean) and os.path.exists(args.folder)):
    shutil.rmtree(args.folder)
os.makedirs(args.folder, exist_ok=True)
if (len(args.data_folder)==0):
    args.data_folder = f"../datasets/{args.name}/"
os.makedirs(args.data_folder, exist_ok=True)
## Save the fully-resolved configuration next to the results.
with open(f"{args.folder}/config.yml", "w") as f:
    yaml.safe_dump(vars(args), f)
## Normalize list-valued options and validate every argument.
## Checks go through parser.error() instead of `assert`, so they still run
## under `python -O` and fail with a readable usage message (exit status 2).
def _require(condition, message):
    """Abort with an argparse usage error unless *condition* holds."""
    if not condition:
        parser.error(message)


args.recommenders = args.recommenders.split(",")
## Without a learner (known setting), only the known-setting baselines and Fabaphe apply.
_require(
    (args.learner != "None")
    or set(args.recommenders).issubset(all_baselines_known+["Fabaphe"]),
    "--recommenders: with --learner None, only known-setting baselines and Fabaphe are allowed",
)
args.recom_types = args.recom_types.split(",")
## T == -1 selects the "recommend after each prefix of the user history" mode.
_require(args.T > 0 or args.T == -1, "--T must be positive or -1")
_require(args.B > 0, "--B must be positive")
args.users = list(map(int, args.users.split(",")))
_require(args.niters > 0, "--niters must be positive")
_require(args.nadd_to_hist >= 0, "--nadd_to_hist must be non-negative")
_require(args.nitem > 0, "--nitem must be positive")
_require(args.nuser > 0, "--nuser must be positive")
print(args.name)
_require(len(args.name) > 0, "--name must be non-empty")
_require(
    args.envir != "SyntheticCosine" or all(0 <= u < args.nuser for u in args.users),
    "--users must lie in [0, nuser) for SyntheticCosine",
)
_require(args.ngroup > 1, "--ngroup must be greater than 1")
_require(args.nvar > 0, "--nvar must be positive")
_require(args.d > 0, "--d must be positive")
_require(0 <= args.lbd <= 1, "--lbd must lie in [0, 1]")
_require(args.c > 1, "--c must be greater than 1")
_require(args.alpha >= 0, "--alpha must be non-negative")
_require(args.eta >= 0, "--eta must be non-negative")
_require(0 <= args.epsilon <= 1, "--epsilon must lie in [0, 1]")
## A negative beta means "no denuding factor".
args.beta = None if (args.beta < 0) else args.beta
_require(args.n_components > 0, "--n_components must be positive")
_require(
    args.envir != "SyntheticCosine" or args.n_components <= args.d,
    "--n_components must not exceed --d for SyntheticCosine",
)
_require(args.seed > 0, "--seed must be positive")
## A learner is required exactly for SyntheticLinearGaussian (unknown setting).
_require(
    (args.learner != "None") == (args.envir in ["SyntheticLinearGaussian"]),
    "--learner must be set for SyntheticLinearGaussian and 'None' for every other environment",
)
args.adaptive_tradeoff = bool(args.adaptive_tradeoff)

## Main experiment loop. For every repetition seed: build the environment and
## fit the kernel, run every (recommender, recommendation type) pair on every
## user, save the per-user results to CSV, and update running statistics
## over all (seed, user) pairs.
rng = seed_everything(args.seed)
## One sub-seed per repetition, drawn from [0, max(niters, 1e8)).
## NOTE(review): drawn with replacement, so two repetitions may (rarely) share
## a seed and overwrite each other's result files.
seeds = rng.choice(int(max(args.niters,1e8)), size=args.niters)
## Running mean and unbiased (n-1) variance of every scalar metric; N is the
## number of (seed, user) pairs aggregated so far.
results_mean, results_var, N = None, None, 0
with tqdm(
    total=len(seeds),
    position=0,
    desc='Runs',
    unit='seed'
) as pbar_seed:
    for seed in seeds:
        pbar_seed.set_postfix_str(f"Seed: {seed}")
        rng = np.random.default_rng(seed)
        env_name = args.data_folder
        add_to_hist = None
        if (args.envir in ["SyntheticCosine", "SyntheticLinearGaussian"]):
            ## Synthetic data sets: the same random initial history is given to every user.
            add_to_hist = rng.choice(
                args.nitem, size=args.nadd_to_hist, replace=False
            ).ravel().tolist()
            params = dict(
                name=env_name, nitem=args.nitem, nvar=args.nvar,
                ngroups=args.ngroup, nuser=args.nuser, d=args.d, seed=seed,
                quantize_digit=args.quantize, new=bool(args.new_envir)
            )
        elif (args.envir=="MovieLens"):
            ## https://files.grouplens.org/datasets/movielens/ml-latest-small.zip
            movielens_filepath = f'{args.raw_data_folder}ml-latest-small/'
            params = dict(name=env_name, movielens_filepath=movielens_filepath, seed=seed)
        elif (args.envir=="PREDICT(private)"):
            filepath = f'{args.raw_data_folder}PREDICT_private/'
            params = dict(name=env_name, filepath=filepath, seed=seed)
        else:
            params = dict(name=env_name, dataset_name=args.envir, seed=seed)
        ## args.envir is restricted by argparse `choices`; every non-synthetic,
        ## non-MovieLens data set is handled by the DrugRepurposing environment.
        env = eval(args.envir if (args.envir in ["SyntheticCosine","MovieLens","SyntheticLinearGaussian"]) else "DrugRepurposing")(params)
        ## NOTE(review): the kernel uses the global args.seed, not the per-repetition seed.
        K = DenudedKernel(kernel=args.kernel, n_components=args.n_components, seed=args.seed, beta=args.beta)
        sample_ids = rng.choice(
            env.nitem,
            size=min(env.nitem, args.n_components),
            replace=False
        )
        Phi = env.item_embs(sample_ids)
        _, _ = K(Phi) ## fit the kernel in advance
        ## "%d" is filled with the user id when the per-user CSV is written.
        results_user_seed_fname = f"{args.folder}/final_user=%d_seed={seed}.csv" 
        results_user_seed = {}  ## user -> {f"{rec_name}_{rec_type}": result dict}
        with tqdm(
            total=len(args.recommenders) * len(args.recom_types),
            position=1,
            leave=False,
            desc=f'Recommenders[seed={seed}]',
            unit='model'
        ) as pbar_reco:
            for rec_name in args.recommenders:
                for rec_type in args.recom_types:
                    pbar_reco.set_postfix_str(f"Recommender: {rec_name}[{rec_type}]")
                    ## Initialize shared learners 
                    if ((args.learner != "None") and args.shared_learners):
                        learner = eval(args.learner)(dict(nelement=env.nitem,d=env.d+env.d_user,reg_eta=args.eta))
                    else:
                        learner = None
                    if (args.adaptive_tradeoff and args.shared_learners): 
                        adaptive = AdaHedgeInit(args.lbd, 0.01)    
                    else:
                        adaptive = None
                    with tqdm(
                        total=len(args.users),
                        position=2,
                        leave=False,
                        desc=f'Users[model={rec_name}_{rec_type}]',
                        unit='user'
                    ) as pbar_user:
                        for user in args.users:
                            pbar_user.set_postfix_str(f"User: {user}")
                            if add_to_hist is not None:
                                ## set history for user in synthetic data sets (reinitialize at each call)
                                env.set_user_hist(user, add_to_hist)
                            else:
                                env.reset_user_hist()
                            if (args.T>0):
                                recom = eval(rec_name)(
                                        dict(lbd=args.lbd, c=args.c, eta=args.eta, 
                                            rec_type=rec_type, seed=args.seed, 
                                            alpha=args.alpha, epsilon=args.epsilon, 
                                            T=args.T, type_algo=args.type_algo
                                        )
                                )
                                result = run_single_user(user, args.T, args.B, recom, env, K, args.seed, args.qthres, 
                                    checkpoint_fname=f"{args.folder}/{rec_name}_{rec_type}_user={user}_seed={seed}.csv", 
                                    new_checkpoints=False, checkpoint_every=1, learner_type=args.learner, 
                                    adaptive_tradeoff=args.adaptive_tradeoff, nu_with_hist=False, learner=learner, adaptive=adaptive)
                            else: ## T=-1 and we recommend at each time 1, 2, ..., N where N is the size of the user history
                                H_initial = env.get_user_hist(user)
                                result_T = None  ## one column per history prefix length T_test
                                with tqdm(
                                    total=len(H_initial)+1,
                                    position=3,
                                    leave=False,
                                    desc=f'Steps[user={user}]',
                                    unit='step'
                                ) as pbar_step:
                                    for T_test in range(len(H_initial)+1):
                                        env.set_user_hist(user, H_initial[:T_test])
                                        ## type_algo is passed here as in the T>0 branch,
                                        ## so --type_algo also applies in this mode.
                                        recom = eval(rec_name)(
                                                dict(lbd=args.lbd if (adaptive is None) else adaptive.act(), c=args.c, eta=args.eta, 
                                                    rec_type=rec_type, seed=args.seed, 
                                                    alpha=args.alpha, epsilon=args.epsilon, 
                                                    T=1, type_algo=args.type_algo
                                                )
                                        )
                                        result_T_test = run_single_user(user, 1, args.B, recom, env, K, args.seed, args.qthres, 
                                            checkpoint_fname=f"{args.folder}/{rec_name}_{rec_type}_user={user}_seed={seed}_T={T_test}.csv", 
                                            new_checkpoints=False, checkpoint_every=1, learner_type=args.learner, 
                                            adaptive_tradeoff=args.adaptive_tradeoff, nu_with_hist=True, learner=learner, adaptive=adaptive)
                                        ## Restore the prefix history the step started from.
                                        env.set_user_hist(user, H_initial[:T_test])
                                        result_T_test = pd.DataFrame({T_test: result_T_test})
                                        if (result_T is None):
                                            result_T = result_T_test.copy()
                                        else:
                                            result_T = result_T.join(result_T_test)
                                        pbar_step.update(1)
                                ## Scalar metrics are averaged over all steps; list-valued rows
                                ## (recs, regret, lbds) are handled separately below.
                                res = result_T.drop("recs").drop(["regret", "lbds"], errors="ignore")
                                result = pd.DataFrame(res.mean(axis=1), columns=[0])
                                ## For list-valued rows keep, for EACH step `col`, the first entry of
                                ## that step's list (each step runs with T=1).
                                result.loc["recs"] = [[result_T.loc["recs"][col][0] for col in result_T.columns]]
                                if ("regret" in result_T.index):
                                    result.loc["regret"] = [[result_T.loc["regret"][col][0] for col in result_T.columns]]
                                if ("lbds" in result_T.index):
                                    result.loc["lbds"] = [[result_T.loc["lbds"][col][0] for col in result_T.columns]]
                                result = result.to_dict()[0]
                            result = {f"{rec_name}_{rec_type}": result}
                            di_user = results_user_seed.get(user, {})
                            di_user.update(result)
                            results_user_seed.update({user: di_user})
                            pbar_user.update(1)
                    pbar_reco.update(1)
        for user in args.users:
            df = pd.DataFrame(results_user_seed[user])
            df.to_csv(results_user_seed_fname % user)
            ## Keep only the scalar metrics for the summary statistics.
            dff = df.drop(["recs"]+(["regret"] if ("regret" in df.index) else [])+(["lbds","value_star"] if ("lbds" in df.index) else []))
            values = dff.astype(float)
            if (results_mean is None):
                results_mean = values
                ## Float zeros (an empty frame + fillna would be object dtype,
                ## which .round() later skips).
                results_var = pd.DataFrame(0.0, index=results_mean.index, columns=results_mean.columns)
            else:
                ## Recursive unbiased variance after adding sample N+1:
                ##   s2_{N+1} = (N-1)/N * s2_N + (N+1) * (m_{N+1} - m_N)^2
                ## (since x_{N+1} - m_N = (N+1) * (m_{N+1} - m_N)).
                results_meanm1 = results_mean.copy()
                results_mean = (results_mean*N+values)/(N+1)
                results_var = (N-1)/N*results_var+(N+1)*(results_mean-results_meanm1).pow(2)
            N += 1
        pbar_seed.update(1)

## Summary over all (seed, user) pairs: one "mean +-std" string per
## metric (rows) and model (columns), saved as CSV.
std_table = results_var.pow(0.5)
results = results_mean.round(3).astype(str) + " +-" + std_table.round(3).astype(str)
results.to_csv(f"{args.folder}/final_summary.csv")

## LaTeX variant (2 decimals), transposed so that each model is a row.
results_latex = (results_mean.round(2).astype(str) + r" $\pm$" + std_table.round(2).astype(str)).T
print(results_latex)
latex_headers = {
    "relevance": "Relevance",
    "intrabatch_div": "Intra. div.",
    "interbatch_div": "Inter. div.",
    "precision": "Precision",
    "pos_div": "Summary",
    "nunique": r"\%unique rec.",
    "runtime": "Runtime (sec.)",
    "cum_regret": "Cum. regret",
    "lbd_final": r"$\lambda_f$",
    "value": "value",
    "lbd_star": r"$\lambda^\star$",
}
## A metric without an entry in latex_headers raises KeyError here.
results_latex.columns = [latex_headers[metric] for metric in results_latex.columns]

## Select the reported columns; regret/tradeoff columns only when relevant.
selected_headers = ["Relevance", "Precision", "Intra. div.", "Inter. div.", "Summary", r"\%unique rec.", "Runtime (sec.)"]
if args.learner != "None":
    selected_headers.append("Cum. regret")
if args.adaptive_tradeoff:
    selected_headers.extend([r"$\lambda_f$", "value", r"$\lambda^\star$"])
results_latex = results_latex[selected_headers]

## Row order: for SAMPLE then MAX, first the baselines, then Fabaphe and FabapheUCB.
row_order = []
for tag, fabaphe_rows in (
    ("_SAMPLE", ("Fabaphe_SAMPLE", "FabapheUCB_SAMPLE")),
    ("_MAX", ("Fabaphe_MAX", "FabapheUCB_MAX")),
):
    row_order.extend(name for name in results_latex.index if tag in name and "Fabaphe" not in name)
    row_order.extend(name for name in fabaphe_rows if name in results_latex.index)
results_latex = results_latex.loc[row_order]
## '&'-separated cells and " \\" line endings give LaTeX tabular rows.
results_latex.to_csv(f"{args.folder}/final_summary_latex.csv", sep='&', lineterminator=" \\\\\n")
print(f"Saved in file {args.folder}/final_summary.csv")
## Print the summary in column groups of (at most) 5 models.
for column_group in chunks(results.shape[1], 5):
    print(results[results.columns[column_group]])
## Rebuild the environment for plotting.
## NOTE(review): `params` are those of the LAST repetition, yet they are used
## for the plots of every seed below — confirm this is intended.
plot_env_name = args.envir if (args.envir in ["SyntheticCosine", "SyntheticLinearGaussian", "MovieLens"]) else "DrugRepurposing"
env = eval(plot_env_name)(params)
## Plot at most the first four users and the first four seeds.
for user in args.users[:4]:
    for seed in seeds[:4]:
        tag = f"user={user}_seed={seed}"
        results_us = pd.read_csv(f"{args.folder}/final_{tag}.csv", index_col=0)
        plot_grid_single_user(results_us, env, list(results_us.columns), f"{args.folder}/grid_plot_{tag}", seed)
        plot_fulltraj_single_user(results_us, user, env, list(results_us.columns), f"{args.folder}/traj_plot_{tag}", thres=args.qthres)
        if args.learner != "None":
            plot_regret(results_us, list(results_us.columns), f"{args.folder}/regret_plot_{tag}")
