import argparse
import json
import logging
import multiprocessing as mp
import os

import numpy as np
import pandas as pn
import torch
import yaml
from tqdm import tqdm

import blahut_arimoto as ba
import eval
import utils
from sklearn.metrics import auc


def get_idx_out_d(datasets_idx_ood, ood_ds_idx: int = None):
    """Return the positions of OOD samples in random order.

    Args:
        datasets_idx_ood: per-sample dataset index; 0 marks the in-distribution set.
        ood_ds_idx: when given, keep only samples from this dataset index; when None,
            keep every sample whose dataset index is non-zero.

    Returns:
        1-D array of positions, shuffled with NumPy's global RNG (one permutation call).
    """
    if ood_ds_idx is None:
        selected = datasets_idx_ood != 0
    else:
        selected = datasets_idx_ood == ood_ds_idx
    positions = np.where(selected)[0]
    return np.random.permutation(positions)


def control_pi(idx_in_d, idx_out_d):
    """Balance in-distribution and OOD index sets so that pi (the in-d fraction) is 0.5.

    Both index arrays are cut to the length of the shorter one; their existing order
    is kept, so shuffle beforehand to obtain a random subset.

    Returns:
        (idx_in_d, idx_out_d) truncated to equal length.
    """
    n_keep = min(idx_in_d.shape[0], idx_out_d.shape[0])
    return idx_in_d[:n_keep], idx_out_d[:n_keep]


def get_s_ood(scores_array_ood, idx_in_d, idx_out_d):
    """Gather OOD-detector scores: in-distribution positions first, then OOD positions."""
    in_part = scores_array_ood[idx_in_d]
    out_part = scores_array_ood[idx_out_d]
    return np.concatenate([in_part, out_part])


def get_s_sc(scores_array_sc, idx_in_d, idx_out_d):
    """Gather selective-classification scores: in-distribution positions first, then OOD positions."""
    in_part = scores_array_sc[idx_in_d]
    out_part = scores_array_sc[idx_out_d]
    return np.concatenate([in_part, out_part])


def get_y_scod(labels_array_ood, idx_in_d, idx_out_d):
    """Gather labels for the SCOD evaluation: in-distribution positions first, then OOD positions."""
    in_part = labels_array_ood[idx_in_d]
    out_part = labels_array_ood[idx_out_d]
    return np.concatenate([in_part, out_part])


def eval_scod(tprs_scod, fprs_scod, risks_scod, coverages_scod, lbds_scod_ood):
    """Summarize a SCOD sweep with two areas under curves.

    Returns:
        dict with
          - "auroc": area under the (fprs_scod, tprs_scod) curve,
          - "aurc":  area under the (coverages_scod, risks_scod) curve.

    `lbds_scod_ood` is currently unused. It was only consumed by a disabled
    FPR-at-95%-TPR metric (`eval.fpr_at_fixed_tpr(..., fixed_tpr=0.95)`). It is kept
    in the signature so existing callers keep working.
    """
    metrics = {}
    metrics["auroc"] = auc(fprs_scod, tprs_scod)
    metrics["aurc"] = auc(coverages_scod, risks_scod)
    return metrics


def _append_combined_method(scores, *, model_name, in_dataset_name, method, source_scores, loader, device):
    """Build one combined-score method with `loader` and append it to `scores` as a new row.

    `loader` is one of the `utils.get_model_scores_labels_idx_w_*` helpers. It is called on
    `source_scores`, and the LAST column of each of its three outputs (scores, labels,
    dataset idx) is taken as the combined method. This assumes one column per method,
    with the combined one last (TODO confirm against utils).

    Returns:
        (scores_with_new_row, labels_column, datasets_idx_column)
    """
    scores_array, labels_array, datasets_idx = loader(scores=source_scores, model_name=model_name, device=device)
    labels_col = labels_array[:, -1]
    scores_col = scores_array[:, -1]
    idx_col = datasets_idx[:, -1]
    row = {
        "model": model_name,
        "in_dataset_name": in_dataset_name,
        "method": method,
        "method_kwargs": {},
        "scores": json.dumps(scores_col.tolist()),
        "labels": json.dumps(labels_col.tolist()),
        "idx": json.dumps(idx_col.tolist()),
    }
    # Wrap every value in a list so from_dict builds exactly one row.
    row = {k: [v] for k, v in row.items()}
    scores = pn.concat([scores, pn.DataFrame.from_dict(row)], ignore_index=True)
    return scores, labels_col, idx_col


def _load_method_arrays(scores, model_name, method):
    """Decode the JSON-encoded labels, scores and dataset indices stored for (model_name, method).

    Only the first matching row is used.

    Returns:
        (labels, scores, datasets_idx), each flattened to a 1-D numpy array.

    Raises:
        ValueError: if no row matches `model_name` and `method`.
    """
    rows = scores[(scores["model"] == model_name) & (scores["method"] == method)]
    if rows.empty:
        raise ValueError(f"No scores found for model={model_name!r}, method={method!r}")
    return tuple(np.array(json.loads(rows[column].iloc[0])).reshape(-1) for column in ("labels", "scores", "idx"))


def main_loop(configuration, logger, seed: int = 42):
    """Run every (selective-classification, OOD) method pair for one seed and save the results.

    Steps:
      1. Load the precomputed scores (``configuration["scores_path1"]``) and keep those of
         ``configuration["model_name"]``.
      2. Append three combined-score "methods" to that table:
         - "baseline": equal weights;
         - "ba": computed over the whole table;
         - "ba_famous": computed over knn_euclides / igeood_logits / max_logits / energy only.
      3. For every pair from ``utils.create_exp_dict`` and every OOD dataset index (plus one
         extra index, presumably mapped to "all" in ``configuration["datasets_idx"]``):
         balance in-d / OOD samples, compute selective-classification, OOD and SCOD metrics,
         and collect one result row.
      4. Write the rows to ``<res_folder>/res_<seed>.csv`` and ``<res_folder>/res_<seed>.pt``.

    Args:
        configuration: dict-like config. Reads "model_name", "scores_path1", "eps", "c_fn",
            "in_dataset", "methods_sc", "methods_ood", "datasets_idx" and "res_folder".
        logger: logger used for progress messages.
        seed: passed to ``utils.initialize_seeds``; also used in the output file names.
    """
    logger.info(f"Starting main function for configuration: {configuration}")

    model_name = configuration["model_name"]
    utils.initialize_seeds(seed)
    logger.info(f"Seed: {seed}")
    eps = configuration["eps"]
    c_fn = configuration["c_fn"]
    in_dataset_name = configuration["in_dataset"]

    scores = utils.get_prep_scores(configuration["scores_path1"])
    scores = scores.query("model == @model_name")

    device = utils.get_device()

    # Baseline: combination of all methods with equal weights.
    scores, _, _ = _append_combined_method(
        scores,
        model_name=model_name,
        in_dataset_name=in_dataset_name,
        method="baseline",
        source_scores=scores,
        loader=utils.get_model_scores_labels_idx_w_baseline,
        device=device,
    )
    # BA over the whole table.
    # NOTE(review): the table already holds the "baseline" row here, so it may take part in this
    # combination too, unless the helper filters it out. Confirm this is intended.
    scores, labels_array_ba, datasets_idx_ba = _append_combined_method(
        scores,
        model_name=model_name,
        in_dataset_name=in_dataset_name,
        method="ba",
        source_scores=scores,
        loader=utils.get_model_scores_labels_idx_w_ba,
        device=device,
    )
    # BA restricted to a fixed subset of well-known detectors.
    famous_methods = ["knn_euclides", "igeood_logits", "max_logits", "energy"]
    subscores = scores.query("method==@famous_methods", inplace=False)
    scores, _, _ = _append_combined_method(
        scores,
        model_name=model_name,
        in_dataset_name=in_dataset_name,
        method="ba_famous",
        source_scores=subscores,
        loader=utils.get_model_scores_labels_idx_w_ba,
        device=device,
    )

    methods_sc = configuration["methods_sc"]
    methods_ood = configuration["methods_ood"] + ["ba", "ba_famous", "baseline"]

    experiments_dict = utils.create_exp_dict(methods_sc, methods_ood)

    # Collect one-row frames and concatenate once at the end; repeated concat is quadratic.
    result_frames = []
    for experiment, spec in experiments_dict.items():
        method_ood = spec["ood"]
        method_sc = spec["sc"]
        logger.info(f"Experiment: {experiment}: method sc = {method_sc} and method ood = {method_ood}")
        labels_array_ood, scores_array_ood, datasets_idx_ood = _load_method_arrays(scores, model_name, method_ood)
        labels_array_sc, scores_array_sc, datasets_idx_sc = _load_method_arrays(scores, model_name, method_sc)

        utils.assert_equivalence_label_idx(labels_array_ood, datasets_idx_ood, labels_array_sc, datasets_idx_sc, labels_array_ba, datasets_idx_ba)

        # Full shuffled in-distribution pool (dataset index 0). Each OOD dataset below draws its
        # balanced subset from this pool. It must NOT be overwritten by control_pi's truncated
        # output, or later datasets would only see the smallest earlier subset.
        idx_in_d_all = np.random.permutation(np.where(datasets_idx_ood == 0)[0])

        ls = utils.remove_element_from_list(np.unique(datasets_idx_ood).tolist(), 0)
        for ood_ds_idx in ls + [len(ls) + 1]:
            ood_dataset = configuration["datasets_idx"][ood_ds_idx]
            if ood_dataset == "all":
                ood_ds_idx = None
            idx_out_d = get_idx_out_d(datasets_idx_ood, ood_ds_idx)
            idx_in_d, idx_out_d = control_pi(idx_in_d_all, idx_out_d)

            s_sc = get_s_sc(scores_array_sc, idx_in_d, idx_out_d)
            y_scod = get_y_scod(labels_array_ood, idx_in_d, idx_out_d)

            risks_sc, coverages_sc, thrs_sc = eval.risks_coverages_selective_net(
                scores=torch.Tensor(scores_array_sc[idx_in_d]),
                labels=torch.Tensor(labels_array_sc[idx_in_d]),
                sort=True,
                n=1000,
            )
            ood_performance = eval.get_ood_results(in_scores=torch.Tensor(scores_array_ood[idx_in_d]), ood_scores=torch.Tensor(scores_array_ood[idx_out_d]))
            ood_performance["aurc"] = auc(coverages_sc, risks_sc)

            # Transform the raw OOD scores with ba.p_value_fn, using up to 1000 in-d scores as reference.
            X = scores_array_ood[idx_in_d][:1000].reshape(-1, 1)
            probs_array = ba.p_value_fn(scores_array_ood.reshape(-1, 1), X=X).reshape(-1)
            ood_performance_with_CDF = eval.get_ood_results(in_scores=torch.Tensor(probs_array[idx_in_d]), ood_scores=torch.Tensor(probs_array[idx_out_d]))

            # The SCOD statistics are computed on the transformed scores, not on the raw ones.
            s_ood = np.concatenate((probs_array[idx_in_d], probs_array[idx_out_d]))
            s_ood_label = np.concatenate((labels_array_ood[idx_in_d], labels_array_ood[idx_out_d]))

            pi_in_star = len(idx_in_d) / float((len(idx_in_d) + len(idx_out_d)))

            (
                tprs_scod_ood,
                fprs_scod_ood,
                risks_scod_ood,
                risks_num_scod_ood,
                risks_den_scod_ood,
                coverages_scod_ood,
                lbds_scod_ood,
            ) = utils.get_scod_stats(
                scores_sc=s_sc,
                scores_ood=s_ood,
                lbd_range=utils.get_lbd_range(pi_in_star=pi_in_star, c_fn=c_fn),
                y_scod=y_scod,
                eps=eps,
                pi_in_star=pi_in_star,
                c_fn=c_fn,
            )
            scod_performance = eval_scod(tprs_scod_ood, fprs_scod_ood, risks_scod_ood, coverages_scod_ood, lbds_scod_ood)

            dict_res = {
                "seed": seed,
                "model": model_name,
                "ood_dataset": ood_dataset,
                "method_sc": method_sc,
                "method_ood": method_ood,
                "ood_ds_idx": ood_ds_idx,
                "risks_sc": risks_sc,
                "coverages_sc": coverages_sc,
                "thrs_sc": thrs_sc,
                "ood_performance": ood_performance,
                "ood_performance_with_CDF": ood_performance_with_CDF,
                "scod_performance": scod_performance,
                "tprs_scod_ood": tprs_scod_ood,
                "fprs_scod_ood": fprs_scod_ood,
                "risks_scod_ood": risks_scod_ood,
                "risks_num_scod_ood": risks_num_scod_ood,
                "risks_den_scod_ood": risks_den_scod_ood,
                "coverages_scod_ood": coverages_scod_ood,
                "lbds_scod_ood": lbds_scod_ood,
                "s_ood": get_s_ood(scores_array_ood, idx_in_d, idx_out_d),
                "s_ood_prob": np.concatenate((probs_array[idx_in_d], probs_array[idx_out_d])),
                "labels_array_ood": labels_array_ood,
                "labels_array_sc": labels_array_sc,
                "datasets_idx_ood": datasets_idx_ood,
                "datasets_idx_sc": datasets_idx_sc,
                "scores_array_ood": scores_array_ood,
                "scores_array_sc": scores_array_sc,
                "idx_in_d": idx_in_d,
                "idx_out_d": idx_out_d,
                "s_ood_label": s_ood_label,
            }
            result_frames.append(pn.DataFrame.from_dict([dict_res]))

    if result_frames:
        res_df = pn.concat(result_frames, ignore_index=True)
    else:
        # Without this guard an empty experiment set crashed at save time.
        logger.warning("No experiments were run; saving an empty result table.")
        res_df = pn.DataFrame()

    res_df.to_csv(configuration["res_folder"] + f"/res_{seed}.csv")
    torch.save(res_df, configuration["res_folder"] + f"/res_{seed}.pt")
    logger.info(f"Saved results in {configuration['res_folder']}")


def worker(args):
    """Pool entry point: unpack one argument tuple and run `main_loop` with it.

    `args` is expected to be `(configuration, logger, seed)`, or a prefix of it, since
    `seed` has a default.
    """
    positional = tuple(args)
    main_loop(*positional)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Parser")
    parser.add_argument("-c", "--config", type=str, default="configs/cifar10.yaml", help="Path to the config file.")
    parser.add_argument("-m", "--model", type=str, default="resnet18_cifar10", help="Model name.")
    # Unknown command-line arguments are tolerated and ignored.
    parsed_args, _ = parser.parse_known_args()

    # NOTE(review): FullLoader is acceptable for trusted local configs; prefer yaml.safe_load
    # if configs may come from untrusted sources.
    with open(parsed_args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config["model_name"] = parsed_args.model

    # Results go to <res_folder>/<model_name>. Files left over from a previous run are deleted;
    # subdirectories are left alone (os.remove would fail on them).
    res_folder = "/".join([config["res_folder"], config["model_name"]])
    os.makedirs(res_folder, exist_ok=True)
    for fname in os.listdir(res_folder):
        path = "/".join([res_folder, fname])
        if os.path.isfile(path):
            os.remove(path)
    config["res_folder"] = res_folder

    config["log_path"] = "/".join([res_folder, "log.txt"])
    logger_main = logging.getLogger(__name__)
    logger_main.setLevel(logging.INFO)
    for handler in (logging.FileHandler(config["log_path"], mode="w"), logging.StreamHandler()):
        handler.setLevel(logging.INFO)
        logger_main.addHandler(handler)
    logger_main.info(f"Running with configuration file: {parsed_args.config}")

    # Seeds come from the config: either a list of ints, or an int N meaning seeds 0..N-1.
    # Default to 0..9 when the key is absent.
    seeds = config.get("seeds", list(range(10)))
    if isinstance(seeds, int):
        seeds = list(range(seeds))
    all_args = [(config, logger_main, seed) for seed in seeds]

    # Keep one CPU free, but always use at least one worker and never more workers than seeds
    # (os.cpu_count() may return None, and Pool(processes=0) raises).
    n_workers = max(1, min(len(all_args), (os.cpu_count() or 2) - 1))
    # NOTE(review): logger_main is pickled into each worker by name; under the "spawn" start
    # method the child's logger has no handlers, so INFO lines may be lost. Confirm on target OS.
    with mp.Pool(processes=n_workers) as pool:
        list(tqdm(pool.imap(worker, all_args), total=len(all_args)))

    def _load_trusted(path):
        """Load a result file written by this script.

        torch>=2.6 defaults to weights_only=True, which refuses to unpickle a DataFrame.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:  # torch < 1.13 has no weights_only argument
            return torch.load(path)

    # All seeds finished: aggregate the per-seed results into a single file.
    print("All seeds finished, aggregating results")
    all_res = pn.concat([_load_trusted(config["res_folder"] + f"/res_{seed}.pt") for seed in seeds], ignore_index=True)
    all_res.to_csv(config["res_folder"] + "/res.csv")
    torch.save(all_res, config["res_folder"] + "/res.pt")

    # Remove the per-seed intermediates; keep only the aggregated results and the log.
    keep = {"res.csv", "res.pt", "log.txt"}
    for fname in os.listdir(res_folder):
        path = "/".join([res_folder, fname])
        if fname not in keep and os.path.isfile(path):
            os.remove(path)
