"""Script to make plots for the variability study."""

import argparse
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from plots_iclr24_comparison import _read_jsons_from_dir


# Info about which heuristics/settings were used, and what to call the algorithms.
# NOTE: this is copied from other script
_HEURISTIC_LIST = ["optimistic"]
_CORR_LIST = ["independent", "gplvm"]
_MARG_LIST = ["constant", "rank"]
_MARG_TO_LABEL = {
    "constant": "C",
    "rank": "R",
}
_CORR_TO_LABEL = {
    "independent": "I",
    "gplvm": "G",
}

# Which "number of samples" folder is used for the per-SMILES plots:
# retro-fallback runs are read from the 1000 folder, every other algorithm
# from the 0 folder.
_PER_SMILES_NUM_SAMPLES_RFB = 1000
_PER_SMILES_NUM_SAMPLES_OTHER = 0


def _algorithm_labels():
    """Return the mapping from algorithm folder name to legend label."""
    alg_to_label = {
        "breadth-first-NONE": "BFS",
    }
    for heuristic in _HEURISTIC_LIST:
        alg_to_label[f"retro-star-{heuristic}"] = "retro*"
        alg_to_label[f"retro-fallback-{heuristic}"] = "retro-fb"
    return alg_to_label


def _read_all_results(results_dir: Path) -> dict:
    """Read every result json below ``results_dir``.

    The data is laid out in folders of
    ``<marg>-<corr>/<algorithm>/<number of samples>/<trial>``.

    Returns:
        A dict keyed by ``(heuristic, corr, marg)`` whose values map
        algorithm name -> number of samples (int) -> list with one entry per
        trial (whatever ``_read_jsons_from_dir`` returns for that trial).

    Raises:
        FileNotFoundError: if an expected algorithm folder is missing.
    """
    conditions_to_results = {}
    for heuristic in _HEURISTIC_LIST:
        for corr in _CORR_LIST:
            for marg in _MARG_LIST:
                condition_key = (heuristic, corr, marg)
                print(f"Reading {condition_key}")
                alg_to_res = {}
                # NOTE: insertion order here determines the plotting order
                # (and hence the colours) in the per-SMILES plots.
                for alg in [
                    f"retro-fallback-{heuristic}",
                    "breadth-first-NONE",
                    f"retro-star-{heuristic}",
                ]:
                    alg_path = results_dir / f"{marg}-{corr}" / alg
                    # Explicit check instead of `assert`, which is skipped
                    # when Python runs with -O.
                    if not alg_path.exists():
                        raise FileNotFoundError(
                            f"Expected results directory not found: {alg_path}"
                        )
                    alg_to_res[alg] = {
                        int(num_samples.name): [
                            _read_jsons_from_dir(trial)
                            for trial in num_samples.iterdir()
                        ]
                        for num_samples in alg_path.iterdir()
                    }
                conditions_to_results[condition_key] = alg_to_res
    return conditions_to_results


def _collect_smiles(conditions_to_results: dict) -> list:
    """Return the sorted SMILES shared by all runs.

    Raises:
        ValueError: if two runs were evaluated on different sets of SMILES,
            or if no run was found at all.
    """
    all_smiles = None
    for alg_to_res in conditions_to_results.values():
        for ns_res in alg_to_res.values():
            for res_list in ns_res.values():
                for res in res_list:
                    smiles = set(res["search_results"].keys())
                    if all_smiles is None:
                        all_smiles = smiles
                    elif smiles != all_smiles:
                        raise ValueError(
                            "Not all algorithms were evaluated for the same "
                            "set of SMILES."
                        )
    if all_smiles is None:
        raise ValueError("No results were found.")
    return sorted(all_smiles)


def _plot_success_prob_vs_num_samples(
    conditions_to_results: dict, all_smiles: list, output_dir: Path
) -> None:
    """Plot 1: retro-fallback success probability w.r.t. number of samples.

    One figure is written per feasibility model, with one curve per number
    of samples (mean over trials, error bars are the std over trials).
    """
    for heuristic in _HEURISTIC_LIST:
        for corr in _CORR_LIST:
            for marg in _MARG_LIST:
                fig = plt.figure(figsize=(4, 3))
                plt.sca(fig.gca())
                rfb_results = conditions_to_results[(heuristic, corr, marg)][
                    f"retro-fallback-{heuristic}"
                ]
                for n_samples, res_list in sorted(rfb_results.items()):
                    # The x-values are taken from the first trial only;
                    # presumably all trials share the same analysis times
                    # (TODO confirm).
                    times = res_list[0]["analysis_times"]
                    # Indexed as [trial][molecule][...].
                    succ_probs = np.asarray(
                        [
                            [
                                res["search_results"][s]["success_probabilities"]
                                for s in all_smiles
                            ]
                            for res in res_list
                        ]
                    )
                    succ_probs = succ_probs.mean(axis=1)  # average over molecules
                    plt.errorbar(
                        times,
                        np.mean(succ_probs, axis=0),
                        yerr=np.std(succ_probs, axis=0),
                        fmt=".-",
                        label=f"{n_samples}",
                        capsize=3,
                    )

                plt.title(
                    f"$\\xi_f$: {_MARG_TO_LABEL[marg]}, {_CORR_TO_LABEL[corr]}"
                )
                plt.xscale("log")
                plt.xlabel("num. calls to $B$")
                plt.legend(title="$k$")
                plt.ylabel("average SSP")
                plt.ylim(-0.05, 1.05)
                plt.tight_layout()
                plt.savefig(
                    output_dir
                    / f"rfb-succ-prob-{marg}-{corr}-{heuristic}-wrt-num-samples.pdf"
                )
                plt.close(fig=fig)


def _plot_runtime_vs_num_samples(
    conditions_to_results: dict, output_dir: Path
) -> None:
    """Plot 2: average retro-fallback runtime vs number of samples.

    One figure is written per heuristic, with one curve per feasibility
    model. Each point is the mean "total_search_time" over all trials and
    all molecules.
    """
    for heuristic in _HEURISTIC_LIST:
        fig = plt.figure(figsize=(4, 3))
        plt.sca(fig.gca())
        for corr in _CORR_LIST:
            for marg in _MARG_LIST:
                rfb_results = conditions_to_results[(heuristic, corr, marg)][
                    f"retro-fallback-{heuristic}"
                ]
                ns_list = []
                runtime_list = []
                for n_samples, res_list in sorted(rfb_results.items()):
                    ns_list.append(n_samples)
                    runtime_list.append(
                        float(
                            np.mean(
                                [
                                    d["total_search_time"]
                                    for res in res_list
                                    for d in res["search_results"].values()
                                ]
                            )
                        )
                    )
                plt.loglog(
                    ns_list,
                    runtime_list,
                    ".-",
                    label=f"{_MARG_TO_LABEL[marg]}, {_CORR_TO_LABEL[corr]}",
                )

        plt.title("Retro-fallback average speed")
        plt.legend(title=r"$\xi_f$")
        plt.xlabel("Number of samples $k$")
        plt.ylabel("Average runtime (s)")
        plt.tight_layout()
        plt.savefig(output_dir / f"rfb-{heuristic}-runtime-wrt-num-samples.pdf")
        plt.close(fig=fig)


def _plot_per_smiles_success_prob(
    conditions_to_results: dict, all_smiles: list, output_dir: Path
) -> None:
    """Plot set 3: per-molecule success probabilities for all feasibility models.

    One figure is written per (heuristic, SMILES), with one panel per
    feasibility model and one curve per algorithm (median over trials, with
    the 25%-75% quantile range shaded).
    """
    alg_to_label = _algorithm_labels()
    # One panel per feasibility model, derived from the settings rather than
    # hard-coded so the figure stays valid if the lists change.
    panels = [(corr, marg) for corr in _CORR_LIST for marg in _MARG_LIST]
    for heuristic in _HEURISTIC_LIST:
        for s_i, s in enumerate(all_smiles):
            # squeeze=False always returns a 2D array of axes, even when
            # there is only a single panel.
            fig, axes = plt.subplots(
                1,
                len(panels),
                sharex=True,
                sharey=True,
                figsize=(3 * len(panels), 3),
                squeeze=False,
            )
            for ax_i, (corr, marg) in enumerate(panels):
                plt.sca(axes[0, ax_i])
                for alg, alg_res in conditions_to_results[
                    (heuristic, corr, marg)
                ].items():
                    if "retro-fallback" in alg:
                        num_samples = _PER_SMILES_NUM_SAMPLES_RFB
                    else:
                        num_samples = _PER_SMILES_NUM_SAMPLES_OTHER
                    times = alg_res[num_samples][0]["analysis_times"]
                    succ_probs = np.asarray(
                        [
                            res["search_results"][s]["success_probabilities"]
                            for res in alg_res[num_samples]
                        ]
                    )
                    median = np.median(succ_probs, axis=0)
                    q75 = np.quantile(succ_probs, 0.75, axis=0)
                    q25 = np.quantile(succ_probs, 0.25, axis=0)
                    plt.plot(times, median, ".-", label=alg_to_label[alg])
                    plt.fill_between(times, q25, q75, alpha=0.2)

                # Set labels/etc for this subplot
                if ax_i == 0:
                    plt.ylabel("SSP")
                    plt.legend()
                plt.xscale("log")
                plt.xlabel("num. calls to $B$")
                # NOTE: the title numbers SMILES from 1, the file name below
                # numbers them from 0.
                plt.title(
                    f"$\\xi_f$: {_MARG_TO_LABEL[marg]}, {_CORR_TO_LABEL[corr]} (SMILES #{s_i+1})"
                )
                plt.ylim(-0.05, 1.05)

            # Finalize the details of the plot
            plt.tight_layout()
            plt.savefig(output_dir / f"per-smiles-succ-prob-{heuristic}-{s_i}.pdf")
            plt.close(fig=fig)


def main():
    """Read the variability-study results and write all plots as PDFs.

    Command-line arguments:
        --results-dir: root folder containing the result jsons.
        --output-dir: folder the PDFs are written to (created if missing).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--results-dir", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    args = parser.parse_args()

    # Read all the data. Format is also jsons.
    # This time it is in folders of algorithm/number of samples/trials.
    conditions_to_results = _read_all_results(Path(args.results_dir))

    # Check that all algorithms were evaluated for the same set of SMILES
    all_smiles = _collect_smiles(conditions_to_results)
    print(f"Found {len(all_smiles)} SMILES")

    # savefig does not create missing folders, so make sure the target exists.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    _plot_success_prob_vs_num_samples(conditions_to_results, all_smiles, output_dir)
    _plot_runtime_vs_num_samples(conditions_to_results, output_dir)
    _plot_per_smiles_success_prob(conditions_to_results, all_smiles, output_dir)


# Run the plotting pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
