"""Script to make plots for the ICLR 2024 comparison."""

import argparse
import json
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt


def _read_jsons_from_dir(dir_path: Path) -> dict:
    """Reads all search results jsons from a directory and return them concatenated."""
    json_paths = dir_path.glob("*.json")
    output = None
    for json_path in json_paths:
        with open(json_path, "r") as f:
            res = json.load(f)
        if output is None:
            output = res
        else:
            assert res["analysis_times"] == output["analysis_times"]
            output["search_results"].update(res["search_results"])
    return output


# Heuristics / settings which were evaluated.
HEURISTIC_LIST = ["optimistic", "sascore"]
CORR_LIST = ["independent", "gplvm"]
MARG_LIST = ["constant", "rank"]

# Short names used in plot legends / titles.
ALG_TO_LABEL = {
    "breadth-first-NONE": "BFS",
    "ORbreadth-first-NONE": "O-BFS",
    **{f"retro-star-{h}": "retro*" for h in HEURISTIC_LIST},
    **{f"retro-fallback-{h}": "retro-fb" for h in HEURISTIC_LIST},
    **{f"mcts-{h}": "MCTS" for h in HEURISTIC_LIST},
}
MARG_TO_LABEL = {"constant": "C", "rank": "R"}
CORR_TO_LABEL = {"independent": "I", "gplvm": "G"}

# A SMILES is "trivial" for a condition if BFS reaches at least
# TRIVIAL_SUCCESS_PROB success probability at analysis time TRIVIAL_NUM_CALLS.
TRIVIAL_NUM_CALLS = 10
TRIVIAL_SUCCESS_PROB = 0.95

# Shortest-route lengths above this value are drawn at this value in the box
# plots (presumably this includes molecules without any route).
MAX_ROUTE_LEN = 20

# Runtime plots: the log-log line is fit only to searches with more than
# 10**LOG10_LINEAR_FIT_CUTOFF nodes and drawn up to 10**LOG10_FIT_PLOT_MAX nodes.
# Only a random subset of at most MAX_SCATTER_POINTS points is scattered,
# otherwise the plot is too crowded.
LOG10_LINEAR_FIT_CUTOFF = 2.5
LOG10_FIT_PLOT_MAX = 4
MAX_SCATTER_POINTS = 500
SCATTER_SEED = 0


def _algorithms_for(heuristic: str) -> list:
    """Return the names of the algorithm result directories compared for ``heuristic``.

    OR-BFS ("ORbreadth-first-NONE") is deliberately excluded for now.
    """
    return [
        f"retro-fallback-{heuristic}",
        "breadth-first-NONE",
        f"retro-star-{heuristic}",
        f"mcts-{heuristic}",
    ]


def _load_all_results(results_dir) -> dict:
    """Read the results of every algorithm under every experimental condition.

    The jsons for algorithm ``alg`` under condition ``(heuristic, corr, marg)``
    are read from ``results_dir / f"{marg}-{corr}" / alg``. They come in small
    batches because the experiments were run in parallel.

    Returns:
        Dict mapping ``(heuristic, corr, marg)`` to a dict mapping algorithm name
        to the merged results returned by ``_read_jsons_from_dir``.

    Raises:
        FileNotFoundError: If an expected results directory is missing.
        ValueError: If a results directory yields no results.
    """
    conditions_to_results = dict()
    for heuristic in HEURISTIC_LIST:
        for corr in CORR_LIST:
            for marg in MARG_LIST:
                condition_key = (heuristic, corr, marg)
                print(f"Reading {condition_key}")
                alg_to_res = dict()
                for alg in _algorithms_for(heuristic):
                    alg_path = Path(results_dir) / f"{marg}-{corr}" / alg
                    if not alg_path.exists():
                        raise FileNotFoundError(f"Results directory not found: {alg_path}")
                    res = _read_jsons_from_dir(alg_path)
                    if res is None:  # directory contained no jsons
                        raise ValueError(f"No result jsons found in {alg_path}")
                    alg_to_res[alg] = res
                    print(f"\t{alg}: num SMILES = {len(res['search_results'])}")
                conditions_to_results[condition_key] = alg_to_res
    return conditions_to_results


def _get_common_smiles(conditions_to_results: dict) -> set:
    """Return the set of evaluated SMILES, checking every run used the same set.

    Raises:
        ValueError: If any (condition, algorithm) pair was evaluated on a
            different set of SMILES.
    """
    smiles_sets = [
        set(res["search_results"])
        for alg_to_res in conditions_to_results.values()
        for res in alg_to_res.values()
    ]
    all_smiles = smiles_sets[0]
    if any(smiles != all_smiles for smiles in smiles_sets[1:]):
        raise ValueError("Not all algorithms were evaluated on the same set of SMILES.")
    return all_smiles


def _get_trivial_smiles(conditions_to_results: dict) -> dict:
    """Return, for each condition, the set of SMILES which are trivial under it.

    A SMILES is trivial if BFS reaches a success probability of at least
    ``TRIVIAL_SUCCESS_PROB`` by analysis time ``TRIVIAL_NUM_CALLS``. This differs
    between conditions, so one set is computed per condition.
    """
    conditions_to_trivial_smiles = dict()
    for condition, alg_to_res in conditions_to_results.items():
        bfs_res = alg_to_res["breadth-first-NONE"]
        # Raises ValueError if TRIVIAL_NUM_CALLS is not one of the analysis times.
        trivial_index = bfs_res["analysis_times"].index(TRIVIAL_NUM_CALLS)
        conditions_to_trivial_smiles[condition] = {
            s
            for s, r in bfs_res["search_results"].items()
            if r["success_probabilities"][trivial_index] >= TRIVIAL_SUCCESS_PROB
        }
    return conditions_to_trivial_smiles


def _make_condition_figure():
    """Create a one-row figure with one subplot per (corr, marg) setting.

    Subplots share x and y axes and are ordered with ``corr`` as the outer loop
    and ``marg`` as the inner loop.

    Returns:
        The figure and a list of ``(ax, corr, marg)`` tuples, one per subplot.
    """
    settings = [(corr, marg) for corr in CORR_LIST for marg in MARG_LIST]
    fig, axes = plt.subplots(
        1,
        len(settings),
        sharex=True,
        sharey=True,
        figsize=(3 * len(settings), 3),
        squeeze=False,
    )
    return fig, [(ax, corr, marg) for ax, (corr, marg) in zip(axes[0], settings)]


def _select_smiles(smiles_to_plot: str, all_smiles: set, trivial_smiles: set) -> set:
    """Return the SMILES subset named by ``smiles_to_plot`` ("all", "trivial", "non-trivial")."""
    if smiles_to_plot == "all":
        return all_smiles
    if smiles_to_plot == "trivial":
        return trivial_smiles
    if smiles_to_plot == "non-trivial":
        return all_smiles - trivial_smiles
    raise ValueError(smiles_to_plot)


def _plot_avg_success_prob(
    conditions_to_results: dict,
    conditions_to_trivial_smiles: dict,
    all_smiles: set,
    output_dir: Path,
) -> None:
    """Plot average success probability (with standard-error bars) over time.

    One figure per (heuristic, molecule subset), saved as
    ``avg-success-prob-{heuristic}-{subset}.pdf``. Only algorithms using the same
    heuristic are compared.
    """
    for heuristic in HEURISTIC_LIST:
        for smiles_to_plot in ["all", "trivial", "non-trivial"]:
            fig, panels = _make_condition_figure()
            for ax_i, (ax, corr, marg) in enumerate(panels):
                condition = (heuristic, corr, marg)
                smiles_set = _select_smiles(
                    smiles_to_plot, all_smiles, conditions_to_trivial_smiles[condition]
                )
                for alg, res in conditions_to_results[condition].items():
                    arr = np.asarray(
                        [
                            d["success_probabilities"]
                            for s, d in res["search_results"].items()
                            if s in smiles_set
                        ]
                    )
                    if arr.shape[0] == 0:
                        # e.g. no trivial molecules under this condition: a mean
                        # over zero molecules is undefined, so draw nothing.
                        continue
                    y = np.mean(arr, axis=0)
                    # Standard error of the mean, based on the sample std
                    # (ddof=1). That is undefined for a single molecule, in
                    # which case no error bar is drawn.
                    if arr.shape[0] > 1:
                        sem = np.std(arr, axis=0, ddof=1) / np.sqrt(arr.shape[0])
                    else:
                        sem = np.zeros_like(y)
                    ax.errorbar(
                        res["analysis_times"],
                        y,
                        yerr=sem,
                        fmt=".-",
                        label=ALG_TO_LABEL[alg],
                        capsize=3,
                    )

                if ax_i == 0:
                    ax.set_ylabel("average SSP")
                    ax.legend()
                ax.set_xscale("log")
                ax.set_xlabel("num. calls to $B$")
                ax.set_title(
                    f"$\\xi_f$: {MARG_TO_LABEL[marg]}, {CORR_TO_LABEL[corr]} ({smiles_to_plot} mols)"
                )

            fig.tight_layout()
            fig.savefig(output_dir / f"avg-success-prob-{heuristic}-{smiles_to_plot}.pdf")
            plt.close(fig)


def _plot_fraction_solved(conditions_to_results: dict, output_dir: Path) -> None:
    """Plot the fraction of SMILES with non-zero success probability over time.

    One figure per heuristic, saved as ``fraction-solved-{heuristic}.pdf``.
    """
    for heuristic in HEURISTIC_LIST:
        fig, panels = _make_condition_figure()
        for ax_i, (ax, corr, marg) in enumerate(panels):
            for alg, res in conditions_to_results[heuristic, corr, marg].items():
                success_probs = np.asarray(
                    [d["success_probabilities"] for d in res["search_results"].values()]
                )
                ax.plot(
                    res["analysis_times"],
                    np.mean(success_probs > 0, axis=0),
                    ".-",
                    label=ALG_TO_LABEL[alg],
                )

            if ax_i == 0:
                ax.set_ylabel("Fraction solved")
                ax.legend()
            ax.set_xscale("log")
            ax.set_xlabel("num. calls to $B$")
            ax.set_title(f"$\\xi_f$: {MARG_TO_LABEL[marg]}, {CORR_TO_LABEL[corr]}")
            ax.set_ylim(-0.05, 1.05)

        fig.tight_layout()
        fig.savefig(output_dir / f"fraction-solved-{heuristic}.pdf")
        plt.close(fig)


def _plot_final_value_boxplots(
    conditions_to_results: dict,
    output_dir: Path,
    result_key: str,
    ylabel: str,
    filename_prefix: str,
    max_value=None,
) -> None:
    """Box-plot, per algorithm, the last entry of ``result_key`` over all SMILES.

    One figure per heuristic, saved as ``{filename_prefix}-{heuristic}.pdf``.
    If ``max_value`` is given, values are clipped to it and a dashed horizontal
    line is drawn at that value.
    """
    for heuristic in HEURISTIC_LIST:
        fig, panels = _make_condition_figure()
        for ax_i, (ax, corr, marg) in enumerate(panels):
            alg_to_res = conditions_to_results[heuristic, corr, marg]
            sorted_alg_names = sorted(alg_to_res)
            per_alg_values = []
            for alg in sorted_alg_names:
                final_values = [
                    d[result_key][-1] for d in alg_to_res[alg]["search_results"].values()
                ]
                if max_value is not None:
                    final_values = [min(max_value, v) for v in final_values]
                per_alg_values.append(final_values)

            ax.boxplot(per_alg_values)
            # Label the boxes (at default positions 1..n) explicitly instead of
            # via boxplot's ``labels`` argument, which Matplotlib 3.9 renamed to
            # ``tick_labels``; this works on both old and new versions.
            ax.set_xticks(list(range(1, len(sorted_alg_names) + 1)))
            ax.set_xticklabels([ALG_TO_LABEL[alg] for alg in sorted_alg_names])

            if ax_i == 0:
                ax.set_ylabel(ylabel)
            if max_value is not None:
                ax.axhline(max_value, color="black", linestyle="--")
            ax.set_title(f"{MARG_TO_LABEL[marg]}, {CORR_TO_LABEL[corr]}")

        fig.tight_layout()
        fig.savefig(output_dir / f"{filename_prefix}-{heuristic}.pdf")
        plt.close(fig)


def _node_runtime_pairs(res: dict):
    """Collect (num. nodes, search duration) pairs over all SMILES and time points.

    Pairs whose duration is NaN are dropped.

    Returns:
        Two equal-length float arrays: number of nodes and runtime.
    """
    n_nodes = []
    times = []
    for d in res["search_results"].values():
        for n, t in zip(d["num_nodes_over_time"], d["search_duration_over_time"]):
            if not np.isnan(t):
                n_nodes.append(n)
                times.append(t)
    return np.asarray(n_nodes, dtype=float), np.asarray(times, dtype=float)


def _plot_runtimes(conditions_to_results: dict, output_dir: Path, seed: int = SCATTER_SEED) -> None:
    """Scatter runtime against number of nodes, with a power-law fit per algorithm.

    One figure per condition, saved as ``runtime-{heuristic}-{corr}-{marg}.pdf``.
    The fit is a straight line in log-log space, i.e. ``runtime = a * nodes**p``;
    ``p`` is shown in the legend.

    Args:
        seed: Seed for choosing which points are scattered, so that the plots
            are reproducible.
    """
    rng = np.random.default_rng(seed)
    for heuristic in HEURISTIC_LIST:
        for corr in CORR_LIST:
            for marg in MARG_LIST:
                fig, ax = plt.subplots(figsize=(4, 3))
                for alg, res in conditions_to_results[heuristic, corr, marg].items():
                    n_nodes, times = _node_runtime_pairs(res)
                    if len(times) == 0:
                        continue

                    # Only scatter a subset of the data (otherwise it's too crowded);
                    # never ask for more points than exist.
                    plt_idxs = rng.choice(
                        len(times), size=min(MAX_SCATTER_POINTS, len(times)), replace=False
                    )
                    scatter = ax.loglog(n_nodes[plt_idxs], times[plt_idxs], ",")[0]

                    # Line of best fit, only for long runs (and positive times,
                    # since log10(0) is -inf).
                    mask = (n_nodes > 10**LOG10_LINEAR_FIT_CUTOFF) & (times > 0)
                    if np.count_nonzero(mask) < 2:
                        # Not enough points for a linear fit: label the scatter instead.
                        scatter.set_label(ALG_TO_LABEL[alg])
                        continue
                    poly_best_fit = np.polynomial.polynomial.Polynomial.fit(
                        np.log10(n_nodes[mask]), np.log10(times[mask]), deg=1
                    )
                    plot_log_n = np.linspace(LOG10_LINEAR_FIT_CUTOFF, LOG10_FIT_PLOT_MAX, 25)
                    ax.loglog(
                        10**plot_log_n,
                        10 ** poly_best_fit(plot_log_n),
                        "-",
                        color=scatter.get_color(),
                        # convert() maps back to the unscaled domain, so the last
                        # coefficient is the slope p.
                        label=ALG_TO_LABEL[alg]
                        + f" p={poly_best_fit.convert().coef[-1]:.2f}",
                    )

                ax.set_xlabel("num. nodes")
                ax.set_ylabel("runtime (s)")
                ax.legend(title="$ax^p$ fit")
                fig.tight_layout()
                fig.savefig(output_dir / f"runtime-{heuristic}-{corr}-{marg}.pdf")
                plt.close(fig)


def main():
    """Read all search results, make the comparison plots and save summary statistics.

    Command-line arguments:
        --results-dir: Directory containing one ``{marg}-{corr}/{alg}``
            subdirectory of result jsons per condition and algorithm.
        --output-dir: Directory in which the plot PDFs and ``info.json`` are
            written. It is created if it does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--results-dir", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # This dictionary will eventually be saved as a json.
    info_dict = dict()

    # Read all the data and check it is consistent.
    conditions_to_results = _load_all_results(args.results_dir)
    all_smiles = _get_common_smiles(conditions_to_results)

    # Separate "trivial" and "non-trivial" SMILES (separately for each condition).
    conditions_to_trivial_smiles = _get_trivial_smiles(conditions_to_results)
    info_dict["conditions_to_trivial_smiles_count"] = {
        str(k): len(v) for k, v in conditions_to_trivial_smiles.items()
    }

    # 1: average success probability over time (with error bars).
    _plot_avg_success_prob(
        conditions_to_results, conditions_to_trivial_smiles, all_smiles, output_dir
    )
    # 2: fraction of SMILES solved over time.
    _plot_fraction_solved(conditions_to_results, output_dir)
    # 3: distribution of shortest routes at the last time point.
    _plot_final_value_boxplots(
        conditions_to_results,
        output_dir,
        result_key="shortest_route_over_time",
        ylabel="Shortest route",
        filename_prefix="shortest-routes",
        max_value=MAX_ROUTE_LEN,
    )
    # 4: distribution of the most successful route at the last time point.
    _plot_final_value_boxplots(
        conditions_to_results,
        output_dir,
        result_key="most_feasible_route_over_time",
        ylabel="Most successful route",
        filename_prefix="most-successful-route",
    )
    # 5: runtime vs. number of nodes, one figure per condition.
    _plot_runtimes(conditions_to_results, output_dir)

    # Output statistics dict.
    with open(output_dir / "info.json", "w") as f:
        json.dump(info_dict, f, indent=2)


if __name__ == "__main__":
    main()
