#!/usr/bin/env python
import os
import argparse
import numpy as np
from .bottleneck_env import SimpleEnv
from .generate_state_transition_matrix import build_state_transition_matrix
from .utils import BottleneckVisualization


def main():
    """Load saved VPS option Q-tables and plot each option's greedy policy.

    Because this module uses relative imports, it has to be run as part of its
    package (e.g. ``python -m <package>.<this_module>``), not as a bare script.

    Command-line arguments:
        --save_dir: directory holding the option ``.npy`` files. A relative
            path is resolved against this script's directory, not the CWD.
        --max_show: upper bound on how many options are plotted *per file*.
        --outer_num: number of experiment groups; group ``m`` is read from
            every file whose name contains ``--keyword`` and ends in
            ``_{m}.npy``.
        --k_base: number of base intrinsic rewards. A file holding exactly
            ``2 * k_base`` options is labelled as (+)/(-) pairs.
        --keyword: substring a file name must contain to be loaded.

    Raises:
        FileNotFoundError: if the resolved ``save_dir`` does not exist.
        ValueError: if the resolved ``save_dir`` exists but is not a directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_dir",
        type=str,
        default="option_results",
        help="Directory containing VPSOpt option files",
    )
    parser.add_argument(
        "--max_show", type=int, default=50, help="Number of options to visualize per group"
    )
    parser.add_argument("--outer_num", type=int, default=1, help="Number of experiment groups")
    parser.add_argument("--k_base", type=int, default=10, help="Number of base intrinsic rewards")
    parser.add_argument("--keyword", type=str, default="VPSOpt", help="File keyword for VPS Option")
    args = parser.parse_args()

    # Resolve save_dir relative to the script's directory (not the CWD) so the
    # default location works no matter where the program is launched from.
    if not os.path.isabs(args.save_dir):
        script_dir = os.path.dirname(os.path.abspath(__file__))
        args.save_dir = os.path.join(script_dir, args.save_dir)

    if not os.path.exists(args.save_dir):
        raise FileNotFoundError(
            f"Directory not found: {args.save_dir}\n"
            f"Please provide a valid --save_dir path or create the directory."
        )

    if not os.path.isdir(args.save_dir):
        raise ValueError(f"Path exists but is not a directory: {args.save_dir}")

    env = SimpleEnv(render_mode=None)
    env.reset()
    # The returned transition matrix and wall mask are not used below. The
    # call is kept in case it has side effects on ``env``.
    # NOTE(review): confirm whether this call is still required.
    build_state_transition_matrix(env)
    vis = BottleneckVisualization(env)
    k_base = args.k_base

    # List the directory once and sort it. os.listdir returns entries in
    # arbitrary order, so sorting makes the plotting order reproducible.
    entries = sorted(os.listdir(args.save_dir))

    for m in range(args.outer_num):
        suffix = f"_{m}.npy"
        files = [f for f in entries if args.keyword in f and f.endswith(suffix)]
        if not files:
            # Report the criterion actually used by the filter above.
            print(
                f"[Warning] No file for group {m}: expected a name containing "
                f"'{args.keyword}' and ending with '{suffix}'"
            )
            continue
        for fn in files:
            Q = np.load(os.path.join(args.save_dir, fn))
            total = Q.shape[0]
            # A file with exactly 2 * k_base options is treated as a paired
            # set: option k (+) and option k + k_base (-) share base index k.
            paired = total == 2 * k_base
            print(f"[Load] {fn}  (K={total})  sign={paired}")
            # max_show limits the number of plots per file, not per group.
            for k in range(min(args.max_show, total)):
                if paired:
                    sign_str = "(+)" if k < k_base else "(-)"
                    base_idx = k if k < k_base else k - k_base
                    title = f"VPS Option #{base_idx} {sign_str}"
                else:
                    title = f"VPS Option #{k}"
                # argmax over axis 1 of Q[k]; assumes Q[k] is
                # (n_states, n_actions) — TODO confirm against the saver.
                policy = np.argmax(Q[k], 1)
                vis.plot_policy_arrows(policy, Q[k], title=title)


if __name__ == "__main__":
    # Entry point when executed as a module: hand main()'s return value
    # (None -> exit status 0) to the interpreter as the process exit status.
    raise SystemExit(main())

