"""
Visualise a tree-structured reward function as a diagram.
"""
import argparse
from torch import load
from numpy import zeros
import matplotlib.pyplot as plt
from config.features.fastjet import *
from hyperrectangles import diagram


# Command-line interface: the task and model name select the file
# trained_models/<task>/<model>.reward; --svg is an integer switch where any
# non-zero value selects SVG output instead of a matplotlib image.
parser = argparse.ArgumentParser()
for positional in ("task", "model"):
    parser.add_argument(positional, type=str)
parser.add_argument("--svg", default=0, type=int)
args = parser.parse_args()

# Load the tree-based reward model, then rename its features for visualisation.
# NOTE(review): `device` is not defined in this file's visible code; presumably
# it is provided by the star import from config.features.fastjet — confirm.
# NOTE(review): the .reward file holds a pickled model object. torch.load
# unpickles it, so only load trusted files; recent PyTorch releases default
# to weights_only=True, which may refuse full objects like this one.
if "tree_" not in args.model:
    # Explicit check instead of `assert`, so validation survives `python -O`.
    raise ValueError(
        f"expected a tree-based model name containing 'tree_', got {args.model!r}"
    )
model = load(f"trained_models/{args.task}/{args.model}.reward", map_location=device)
model.device = device
tree = model.forest[0]["tree"]

# Underscores become spaces for readability. The last dimension is relabelled
# "r", which is the name used for pred_dims/colour_dim when drawing the diagram.
display_names = [name.replace("_", " ") for name in tree.space.dim_names]
display_names[-1] = "r"
tree.space.dim_names = display_names

# Gather trajectory-level return estimates (divided by lengths) and use their
# range for cmap_lims. Each row of the root's data pairs an episode index "ep"
# with a value "r"; when several rows share an episode, the last one wins.
# Only episodes actually present contribute. A fixed-size zero buffer would
# crash on indices beyond its size, and its unfilled zeros would distort the
# colour-map limits.
eps, rewards = tree.root.data("ep", "r").T
g_per_ep = {}
for i, r in zip(eps.astype(int), rewards):
    g_per_ep[i] = r
observed = list(g_per_ep.values())
# With no data at the root, fall back to the degenerate (0, 0) range the
# previous zero-filled implementation produced.
cmap_lims = (min(observed), max(observed)) if observed else (0.0, 0.0)

# Render the tree as a diagram, coloured by "r" within cmap_lims.
# In "plt" mode the returned image is drawn onto a large, axis-free figure.
# In "svg" mode the result is left to `diagram` itself.
# NOTE(review): what `diagram` does with out_as="svg" (e.g. writing a file)
# is not visible here — confirm.
out_format = "svg" if args.svg else "plt"
diag = diagram(
    tree,
    pred_dims=["r"],
    colour_dim="r",
    cmap_lims=cmap_lims,
    show_decision_node_preds=True,
    out_as=out_format,
)
if out_format == "plt":
    _, ax = plt.subplots(figsize=(30, 30))
    ax.axis("off")
    ax.imshow(diag)
plt.show()
