from argparse import ArgumentParser
import os
import pathlib
import pprint

import numpy as np
import pandas as pd
from pyod.models.dif import DIF
from pyod.models.knn import KNN
from pyod.models.auto_encoder_torch import AutoEncoder
from pyod.models.ecod import ECOD
from sklearn.model_selection import train_test_split
from ruamel.yaml import YAML
from tqdm.auto import tqdm

# Shared YAML handle created with typ='safe'; the entry point at the bottom of
# this file uses it to read the run config and the evaluator config.
yaml = YAML(typ='safe')

from evaluation import RankingEvaluator
from utils import get_dataset, get_true_ranking, save_model, save_config, EVALUATOR

# implement some dummy rankers
class PayoutRanker(object):
    """Baseline ranker that scores every row by its observed outcome.

    Nothing is learned from the data: ``fit`` only prints a notice and
    ``decision_function`` hands back the ``OUTCOME_COL`` column of its input.
    """

    def __init__(self):
        """Create the ranker; there is no state to set up."""

    def fit(self, X):
        """Print a notice that no fitting takes place; ``X`` is not used."""
        print("Observational-only -- no fitting needed.")

    def decision_function(self, X):
        """Return ``X[OUTCOME_COL]`` as the per-row score."""
        outcome_scores = X[OUTCOME_COL]
        return outcome_scores

class RandomRanker(object):
    """Baseline ranker that assigns uniformly random scores in ``[0, 1)``.

    Each instance owns a private ``numpy.random.RandomState`` seeded with
    ``seed``. Constructing a ranker therefore no longer reseeds NumPy's
    global generator, and the scores of one ranker do not depend on other
    code drawing random numbers in between. For a given seed the private
    generator produces the same stream that ``np.random.seed(seed)``
    followed by ``np.random.rand`` would.

    Args:
        seed: Seed for this ranker's private generator.
    """

    def __init__(self, seed=42):
        self.seed = seed
        # Private generator: keeps results reproducible per instance and
        # leaves the process-wide NumPy random state untouched.
        self._rng = np.random.RandomState(self.seed)

    def fit(self, X):
        """Print a notice that no fitting takes place; ``X`` is not used."""
        print("Observational-only -- no fitting needed.")

    def decision_function(self, X):
        """Return one random score per row of ``X`` (an array of ``len(X)``)."""
        return self._rng.rand(len(X))


# Maps the ``algorithm`` value of the model config to the class that
# fit_outlier_detector instantiates.
# NOTE(review): AutoEncoder is imported at the top of the file but is not
# registered here -- confirm whether that is intentional.
ALGO_DICT = {
    "knn": KNN,
    "dif": DIF,
    "ecod": ECOD,
    "payout": PayoutRanker,
    "random": RandomRanker,
}
# Columns whose names start with this prefix are used as features.
COVARIATE_COL_PREFIX = "x"
# Outcome column: appended to the feature columns, and returned as the score
# by PayoutRanker.
OUTCOME_COL = "d_obs"
# Column the per-row scores are grouped by when the final ranking is built.
TREATMENT_COL = "t"

def fit_outlier_detector(dev_df, test_df, model_cfg, split_no):
    """Fit the configured detector on ``dev_df`` and rank treatments on ``test_df``.

    Features are every column whose name starts with ``COVARIATE_COL_PREFIX``
    plus ``OUTCOME_COL``. The detector's ``decision_function`` scores each
    test row, the scores are averaged per ``TREATMENT_COL`` value, and the
    treatment values are returned in ascending order of mean score.

    Args:
        dev_df: DataFrame the detector is fitted on.
        test_df: DataFrame that is scored. Side effect: a ``"scores"`` column
            is written into this frame.
        model_cfg: Mapping with an ``"algorithm"`` key (a key of
            ``ALGO_DICT``) and an optional ``"kwargs"`` mapping of constructor
            arguments. It is not modified.
        split_no: Split identifier; used as the seed when the algorithm is
            ``"random"``.

    Returns:
        Tuple ``(clf, final_ranking)``: the fitted detector and the list of
        treatment values ordered by ascending mean score.

    Raises:
        KeyError: If ``model_cfg["algorithm"]`` is missing or not in
            ``ALGO_DICT``.
    """
    feat_cols = [c for c in dev_df.columns if c.startswith(COVARIATE_COL_PREFIX)] + [OUTCOME_COL]
    X_tr = dev_df[feat_cols]
    X_ts = test_df[feat_cols]

    # Work on a copy: the per-split seed injected below must not leak back
    # into the caller's config (which is saved to disk later). ``or {}`` also
    # tolerates an explicit null ``kwargs:`` entry in the config.
    kwargs = dict(model_cfg.get("kwargs") or {})
    if model_cfg["algorithm"] == "random":
        kwargs["seed"] = split_no
    clf = ALGO_DICT[model_cfg["algorithm"]](**kwargs)
    print("Fitting outlier detector...")
    clf.fit(X_tr)

    print("Evaluating outlier detector...")

    # Kept for backward compatibility: the scores are stored on the caller's
    # test frame before being aggregated per treatment.
    test_df["scores"] = clf.decision_function(X_ts)
    agg_scores = test_df.groupby(TREATMENT_COL)["scores"].mean()
    final_ranking = agg_scores.sort_values().index.tolist()
    return clf, final_ranking
    

if __name__ == '__main__':
    # Command-line entry point: for every split of the dataset, fit the
    # configured detector, rank the treatments, optionally evaluate against
    # the true ranking, and persist models, config, metrics and rankings
    # under ./estimators/<name>.
    psr = ArgumentParser()
    psr.add_argument("--config", type=str, required=True)
    psr.add_argument("--name", type=str, required=True)
    psr.add_argument("--dataset", type=str)
    psr.add_argument("--overwrite", action='store_true')
    args = psr.parse_args()

    with open(args.config, "r") as f:
        cfg = yaml.load(f)
    # TODO: we probably don't need to do I/O every time -- refactor
    with open(EVALUATOR, "r") as f:
        eval_cfg = yaml.load(f)
    evaluator = RankingEvaluator.from_config(eval_cfg)

    # The output directory is named after --name.
    save_dir = os.path.join("./estimators", args.name)
    if os.path.isdir(save_dir) and not args.overwrite:
        raise ValueError(f"{save_dir} exists. Exiting.")

    save_path = os.path.join(save_dir, "model_{}.pkl")
    result_path = os.path.join(save_dir, "results.csv")
    ranking_path = os.path.join(save_dir, "rankings.csv")
    config_path = os.path.join(save_dir, "config.yml")

    # --dataset, when given, takes precedence over the name in the config.
    dataset_name = cfg["dataset"]["name"] if args.dataset is None else args.dataset
    if args.dataset is not None:
        print("Overriding dataset specification -- using dataset:", args.dataset)

    df, _, data_config_file = get_dataset(dataset_name)
    # Decide on the dataset that was actually loaded (not the config's name),
    # so a --dataset override is honoured when looking up the true ranking.
    if dataset_name.startswith("synth"):
        true_ranking = get_true_ranking(data_config_file)
    else:
        true_ranking = None

    all_metrics = []
    all_rankings = []

    for i, split in enumerate(tqdm(sorted(df["split"].unique()))):
        print("Working on split", i)
        df_subset = df.loc[df["split"] == split]
        dev_df, test_df = train_test_split(
            df_subset,
            test_size=cfg["split"]["size"],
            random_state=cfg["split"]["seed"],
        )
        clf, pred_rankings = fit_outlier_detector(dev_df, test_df, cfg["model"], split)

        print("Predicted rankings:", pred_rankings)
        if true_ranking is not None:
            metrics = evaluator.evaluate(true_ranking, pred_rankings, save_results=False)
            pprint.pprint(metrics)
            all_metrics.append(metrics)
            print("True ranking:", true_ranking)
        print("Saving model...")
        # parents=True: ./estimators itself may not exist on a fresh checkout.
        pathlib.Path(save_dir).mkdir(parents=True, exist_ok=True)
        save_model(save_path.format(i), clf)
        all_rankings.append(pred_rankings)

    print("Saving config...")

    save_config(config_path, cfg)
    metric_df = pd.DataFrame(all_metrics)
    ranking_arr = np.array(all_rankings)
    # Metrics exist only when a true ranking was available.
    if len(metric_df):
        with pd.option_context('display.max_columns', 20):
            print(metric_df.describe(percentiles=[0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.975]))
        metric_df.to_csv(result_path)
    # NOTE(review): np.savetxt's default format is numeric; this presumably
    # relies on treatment labels being numbers -- confirm.
    np.savetxt(ranking_path, ranking_arr, delimiter=",")
