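"""
Code to replicate the LiNGAM experiments (experiment 1).

Commands, exposed on the command line via python-fire:
    run   -- launch the experiment jobs with Ray and write the results CSV.
    plot  -- plot the results CSV produced by ``run``.
"""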

import os
import time

import fire
import pandas as pd
import ray

from abs_lingam import evaluate_lingam
from learnabs.dataset import generate_datasets
from learnabs.experiments import preprocess_results


def plot(
    results_dir: str = "results/",
    flavor: str = "small",
    plot_dir: str = "plots/",
    store: bool = True,
    show: bool = False,
):
    """Plot the results of experiment 1 for one dataset flavor.

    Reads ``<results_dir>/experiment1_<flavor>.csv`` (the file written by
    ``run``) and preprocesses it with ``preprocess_results``. For each metric
    it draws a line plot against the number of paired samples, with one line
    per method, plus a dashed vertical line at the mean number of concrete
    nodes.

    Args:
        results_dir: Directory containing the results CSV.
        flavor: Dataset flavor, one of ``"small"``, ``"medium"``, ``"large"``.
        plot_dir: Output directory for the figures. It is created if missing
            when ``store`` is True.
        store: Save every figure as ``exp1_<metric>_<flavor>.pdf`` and
            ``exp1_<metric>_<flavor>.pgf`` in ``plot_dir``.
        show: Display every figure interactively.

    Returns:
        The processed results DataFrame if ``show`` is True, otherwise None.

    Raises:
        ValueError: If ``flavor`` is not one of the known flavors.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    # Dataset signature per flavor. NOTE: only referenced by the disabled
    # signature filter below; kept so that filter can be switched back on.
    signature_prefixes = {
        "small": "d5_e8_tER_m5_M10",
        "medium": "d10_e20_tER_m5_M10",
        "large": "d10_e20_tER_m10_M15",
    }
    if flavor not in signature_prefixes:
        raise ValueError(f"Unknown flavour {flavor}.")
    dset = (
        signature_prefixes[flavor]
        + "_h1000.0_p0.5_iTrue_n20000_Nexponential_I0_jFalse_vNone_aTrue"
    )

    results_fname = os.path.join(results_dir, f"experiment1_{flavor}.csv")
    results = pd.read_csv(results_fname)

    print(f"== Experiment 1 ({flavor}) ==")

    # filter results
    print(f"Total Entries: {len(results)}")
    # results = results[results["dset/signature"] == dset]
    results = preprocess_results(results)
    print(f"Flavor Entries: {len(results)}")

    def replicate_per_paired_samples(df, target: str):
        """Add copies of the ``target`` rows for every non-zero budget.

        For every row whose ``params/method`` equals ``target``, one copy is
        appended per non-zero ``dset/paired_samples`` value present in
        ``df``, with ``dset/paired_samples`` set to that value. The result is
        that the method is drawn across the whole x-axis. Returns ``df``
        unchanged when there is nothing to add.
        """
        rows = df[df["params/method"] == target].to_dict("records")
        paired_values = df["dset/paired_samples"].unique()
        # Filter 0 by value, not by position: ``unique()`` keeps order of
        # appearance, so 0 is not necessarily the first entry.
        paired_values = paired_values[paired_values != 0]
        new_rows = [
            {**row, "dset/paired_samples": value}
            for row in rows
            for value in paired_values
        ]
        if not new_rows:
            return df
        return pd.concat([df, pd.DataFrame(new_rows)], ignore_index=True)

    results = replicate_per_paired_samples(results, "Full Concrete")
    results = replicate_per_paired_samples(results, "Abs-GT")

    # Mean number of concrete nodes, with 1.96 standard deviations reported
    # as the +/- spread (computed on the replicated frame).
    avg_cnc_nodes = results["dset/cnc_nodes"].mean()
    std_cnc_nodes = results["dset/cnc_nodes"].std() * 1.96
    print(
        "Concrete Nodes:",
        round(avg_cnc_nodes, 2),
        "±",
        round(std_cnc_nodes, 2),
    )

    # (results column, y-axis label) for every figure to draw.
    metrics = [
        ("eval/abstract_roc_auc", r"ROCAUC $\mathcal{H}$"),
        ("eval/abstract_precision", r"Precision $\mathcal{H}$"),
        ("eval/abstract_recall", r"Recall $\mathcal{H}$"),
        ("eval/abstract_f1", r"F1 $\mathcal{H}$"),
        ("eval/tau_roc_auc", r"ROCAUC $\mathbf{T}$"),
        ("eval/tau_precision", r"Precision $\mathbf{T}$"),
        ("eval/tau_recall", r"Recall $\mathbf{T}$"),
        ("eval/tau_f1", r"F1 $\mathbf{T}$"),
        ("eval/concrete_roc_auc", r"ROCAUC $\mathcal{L}$"),
        ("eval/concrete_precision", r"Precision $\mathcal{L}$"),
        ("eval/concrete_recall", r"Recall $\mathcal{L}$"),
        ("eval/concrete_f1", r"F1 $\mathcal{L}$"),
        ("eval/pk_precision", "Prior Knowledge Precision"),
        ("eval/pk_recall", "Prior Knowledge Recall"),
        ("eval/time", "Time (s)"),
    ]

    if store:
        os.makedirs(plot_dir, exist_ok=True)

    for y, y_label in metrics:
        sns.lineplot(
            x="dset/paired_samples",
            y=y,
            hue="Method",
            style="Method",
            data=results,
        )
        plt.xlabel(r"Paired Samples $|\mathcal{D}_P|$")
        plt.ylabel(y_label)
        plt.yscale("log")
        plt.legend(title="Method", loc="lower right")
        if y == "eval/concrete_roc_auc":
            plt.ylim(0.5, 1)
        elif "concrete" in y:
            # A bottom limit of 0 is invalid on a log-scaled axis (matplotlib
            # ignores it with a warning), so only the top is capped.
            plt.ylim(top=1)
        # Dashed reference line at the mean number of concrete nodes.
        plt.axvline(avg_cnc_nodes, color="black", alpha=0.5, linestyle="--")
        metric = y.split("/")[1]
        if store:
            for ext in ("pdf", "pgf"):
                plt.savefig(
                    os.path.join(plot_dir, f"exp1_{metric}_{flavor}.{ext}")
                )
        if show:
            plt.show()
        plt.clf()

    return results if show else None


def _paired_sample_budgets(max_paired_samples: int):
    """Return the paired-sample budgets to evaluate.

    The budgets are the multiples of ``max_paired_samples // 10`` (clamped to
    at least 1, so ``range`` never gets a zero step) that lie strictly below
    ``max_paired_samples``, followed by ``max_paired_samples`` itself.
    """
    step = max(1, max_paired_samples // 10)
    # ``range`` excludes its stop, so the maximum is always appended once.
    return list(range(step, max_paired_samples, step)) + [max_paired_samples]


def run(
    seed: int = 1011329608,
    num_cpus: int = 2,
    num_runs: int = 5,
    flavor: str = "small",
    data_dir: str = "data/",
    results_dir: str = "results/",
    verbose: bool = False,
):
    """Launch experiment 1 with Ray and write the results to a CSV file.

    First calls ``generate_datasets`` with ``force=False``, which presumably
    reuses datasets already on disk. Then, for each of ``num_runs`` runs, it
    submits these ``evaluate_lingam`` jobs:

    * ``Abs-GT`` with ``n_paired=0``;
    * ``Abs-Fit`` for every paired-sample budget, both without bootstrapping
      and with ``bootstrap_samples`` in 1, 2, 5 and 10.

    Every record is tagged with the seed and a launch timestamp. The records
    are written to ``<results_dir>/experiment1_<flavor>.csv``, overwriting
    any existing file. Ray is shut down before returning.

    Args:
        seed: Seed forwarded to every job.
        num_cpus: Number of CPUs given to Ray; must be at least 2.
        num_runs: Number of dataset runs to evaluate.
        flavor: Dataset flavor, one of ``"small"``, ``"medium"``, ``"large"``.
        data_dir: Directory holding the generated datasets.
        results_dir: Output directory for the CSV (created if missing).
        verbose: Forwarded to every job.

    Raises:
        ValueError: If ``num_cpus < 2`` or ``flavor`` is unknown.
    """
    # Validate before starting Ray so bad arguments fail fast and leave no
    # running cluster behind. (``assert`` would be stripped under ``-O``.)
    if num_cpus < 2:
        raise ValueError(f"num_cpus must be at least 2, got {num_cpus}.")
    flavor_overrides = {
        "small": {},
        "medium": {"n_nodes": 10, "n_edges": 20},
        "large": {
            "n_nodes": 10,
            "n_edges": 20,
            "min_readout": 10,
            "max_readout": 15,
        },
    }
    if flavor not in flavor_overrides:
        raise ValueError(f"Unknown flavour {flavor}.")

    dset_params = {
        "n_nodes": 5,
        "n_edges": 8,
        "graph_type": "ER",
        "min_readout": 5,
        "max_readout": 10,
        "alpha": 1e3,
        "marginalize_ratio": 0.5,
        "internal": True,
        "n_samples": 50000,
        "noise_term": "exponential",
    }
    dset_params.update(flavor_overrides[flavor])

    ray.init(num_cpus=num_cpus, num_gpus=0)
    try:
        timestamp = time.strftime("%Y-%m-%d-%H-%M-%S")

        # TODO: outer loop on different dset configurations
        generate_datasets(dset_params, data_dir, num_runs, force=False)

        n_concrete = 20000

        def launch(run_idx, method, n_paired, **extra):
            """Submit one ``evaluate_lingam`` job with the shared settings."""
            return evaluate_lingam.remote(
                dset_params,
                data_dir,
                method=method,
                n_paired=n_paired,
                n_concrete=n_concrete,
                run=run_idx,
                shuffle_features=True,
                normalize=True,
                seed=seed,
                verbose=verbose,
                **extra,
            )

        max_paired_samples = (
            dset_params["n_nodes"] * dset_params["max_readout"] * 3
        )
        paired_budgets = _paired_sample_budgets(max_paired_samples)

        futures = []
        for run_idx in range(num_runs):
            futures.append(launch(run_idx, "Abs-GT", 0))
            for n_paired in paired_budgets:
                futures.append(launch(run_idx, "Abs-Fit", n_paired))
                # Same budget, with bootstrapping.
                for bootstrap_samples in (1, 2, 5, 10):
                    futures.append(
                        launch(
                            run_idx,
                            "Abs-Fit",
                            n_paired,
                            bootstrap_samples=bootstrap_samples,
                        )
                    )

        # get records and add experiment info
        print(f"Launched {len(futures)} total jobs.")
        records = ray.get(futures)
        print(f"Finished {len(records)} total jobs.")
        for record in records:
            record["experiment/seed"] = seed
            record["experiment/datetime"] = timestamp

        # Build the dataframe and write it (overwrites any previous file).
        # Done before shutting Ray down because ray.get may return views
        # into the object store.
        df = pd.DataFrame.from_records(records)
        os.makedirs(results_dir, exist_ok=True)
        results_fname = os.path.join(results_dir, f"experiment1_{flavor}.csv")
        df.to_csv(results_fname, index=False)
    finally:
        ray.shutdown()


if __name__ == "__main__":
    # Expose only the experiment commands on the command line
    # (``run`` / ``plot``). A bare ``fire.Fire()`` would expose every
    # module-level name, including imported modules and helpers.
    fire.Fire({"run": run, "plot": plot})
