import argparse

from s3.visualize import visualize_s3


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--algo', default='s3', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--gid', default=0, type=int)
    parser.add_argument('--env_name', default='AntMaze', type=str)
    parser.add_argument('--model_dir', default='./pretrained_models', type=str)

    parser.add_argument('--manager_propose_freq', default=10, type=int)
    parser.add_argument('--candidate_goals', default=10, type=int)
    parser.add_argument('--goal_loss_coeff', default=20.0, type=float)
    parser.add_argument('--no_correction', action='store_true')
    parser.add_argument('--absolute_goal', action='store_true')

    parser.add_argument('--n_mix', default=5, type=int)
    parser.add_argument('--disable_reach', action='store_true')
    parser.add_argument('--reach_path', default='', type=str,
                        help='Optional path template to reach-net weights (supports {env},{algo},{suffix},{tag}).')

    parser.add_argument('--checkpoints', default='base', type=str,
                        help="Comma separated checkpoint suffixes; use 'base' for latest weights.")
    parser.add_argument('--episodes', default=3, type=int)
    parser.add_argument('--max_steps', default=500, type=int)
    parser.add_argument('--plots_dir', default='./plots', type=str)

    parser.add_argument('--sequence_checkpoint', default='', type=str,
                        help='Checkpoint to use for the detailed sequence plot (defaults to last in --checkpoints).')
    parser.add_argument('--sequence_index', default=0, type=int,
                        help='Episode index (within the chosen checkpoint) for the sequence plot.')

    parser.add_argument('--figure_size', default=7.5, type=float)
    parser.add_argument('--floor_alpha', default=0.45, type=float)
    parser.add_argument('--overlay_cmap', default='viridis', type=str)
    parser.add_argument('--sequence_cmap', default='plasma', type=str)
    parser.add_argument('--trace_width', default=2.0, type=float)
    parser.add_argument('--trace_alpha', default=0.85, type=float)
    parser.add_argument('--subgoal_marker_size', default=60.0, type=float)
    parser.add_argument('--landing_marker_size', default=32.0, type=float,
                        help='Marker size for visualizing landing states.')
    parser.add_argument('--landing_alpha', default=0.45, type=float,
                        help='Alpha value for landing state markers.')
    parser.add_argument('--subgoal_radius_alpha', default=0.2, type=float,
                        help='Alpha value for the shaded landing radius around subgoals.')
    parser.add_argument('--path_color', default='tab:gray', type=str)
    parser.add_argument('--path_width', default=2.5, type=float)
    parser.add_argument('--default_marker_color', default='tab:blue', type=str)
    parser.add_argument('--annotate_height', action='store_true')
    parser.add_argument('--dpi', default=220, type=int)

    args = parser.parse_args()
    if not args.sequence_checkpoint:
        args.sequence_checkpoint = None

    visualize_s3(args)
