#!/usr/bin/env python
import os
import argparse
import numpy as np
from bottleneck_env import SimpleEnv
from generate_state_transition_matrix import build_state_transition_matrix
from utils import BottleneckVisualization

def _option_title(k, k_base, signed):
    """Return the plot title for option index ``k``.

    When ``signed`` is true the file is treated as holding ``2 * k_base``
    options: indices below ``k_base`` are labelled ``(+)``, the remaining
    ones ``(-)``, and each half is numbered starting from 0. Otherwise the
    raw index is used without a sign suffix.
    """
    if not signed:
        return f"VPS Option #{k}"
    if k < k_base:
        return f"VPS Option #{k} (+)"
    return f"VPS Option #{k - k_base} (-)"


def main():
    """Load saved option arrays from ``--save_dir`` and plot a policy for each.

    For every group index ``m`` in ``range(--outer_num)``, each ``.npy`` file
    in ``--save_dir`` whose name contains ``--keyword`` and ends with
    ``_{m}.npy`` is loaded. For up to ``--max_show`` leading-axis slices
    ``Q[k]`` of the loaded array, the argmax over axis 1 is computed and passed,
    together with the slice, to ``BottleneckVisualization.plot_policy_arrows``.

    The directory is listed once and processed in sorted name order so the
    sequence of plots is reproducible. An unreadable ``--save_dir`` is reported
    through ``parser.error`` (usage message, exit status 2). Files whose array
    has fewer than three dimensions are skipped with a warning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='option_results',
                        help='Directory containing VPSOpt option files')
    parser.add_argument('--max_show', type=int, default=50,
                        help='Number of options to visualize per group')
    parser.add_argument('--outer_num', type=int, default=1,
                        help='Number of experiment groups')
    parser.add_argument('--k_base', type=int, default=10,
                        help='Number of base intrinsic rewards')
    parser.add_argument('--keyword', type=str, default='VPSOpt',
                        help='File keyword for VPS Option')
    args = parser.parse_args()

    # Fail fast on a bad directory before any environment is constructed.
    # os.listdir() returns names in arbitrary order, so sort for determinism.
    try:
        names = sorted(os.listdir(args.save_dir))
    except (FileNotFoundError, NotADirectoryError) as err:
        parser.error(f"cannot read --save_dir {args.save_dir!r}: {err}")

    env = SimpleEnv(render_mode=None)
    env.reset()
    # The returned values are not used here; the call is kept in case it has
    # side effects on env that are not visible from this file.
    build_state_transition_matrix(env)
    vis = BottleneckVisualization(env)

    k_base = args.k_base
    for m in range(args.outer_num):
        pattern = f'gridworld_*_{args.keyword}_{m}.npy'
        suffix = f'_{m}.npy'
        files = [f for f in names
                 if args.keyword in f and f.endswith(suffix)]
        if not files:
            print(f"[Warning] No file for group {m}, pattern: {pattern}")
            continue
        for fn in files:
            Q = np.load(os.path.join(args.save_dir, fn))
            # Q[k] must be at least 2-D for the argmax over axis 1 below.
            if Q.ndim < 3:
                print(f"[Warning] Skipping {fn}: expected at least 3 "
                      f"dimensions, got shape {Q.shape}")
                continue
            total = Q.shape[0]
            # Twice as many entries as base rewards is read as a (+)/(-) pair
            # per base index.
            signed = (total == 2 * k_base)
            print(f"[Load] {fn}  (K={total})  sign={signed}")
            for k in range(min(args.max_show, total)):
                title = _option_title(k, k_base, signed)
                policy = np.argmax(Q[k], 1)
                vis.plot_policy_arrows(policy, Q[k], title=title)

# Run the visualization only when this file is executed as a script;
# importing it as a module must not trigger any work.
if __name__ == '__main__':
    main()
