import os
import shutil
import json
import argparse
import numpy as np
from pathlib import Path

from vgn.detection import VGN
from vgn.detection_implicit import VGNImplicit
# from vgn.detection_implicit_top import VGNImplicitTop
# from vgn.detection_implicit_pc import VGNImplicitPC
from vgn.experiments import clutter_removal_single
from vgn.utils.misc import set_random_seed


def _summarize_results(results):
    """Compute the grasp success rate (GSR) and declutter rate (DR).

    Args:
        results: mapping of round index -> per-round result sequence whose
            first three entries are (successful grasps, grasp attempts,
            objects in the scene), matching how ``main`` indexes them.

    Returns:
        tuple[float, float]: ``(gsr, dr)`` where
        ``gsr = total successes / total attempts`` and ``dr`` is the fraction
        of rounds in which every object was removed. Each is NaN when its
        denominator is zero (no attempts / no rounds) instead of relying on a
        numpy divide-by-zero warning.
    """
    rounds = list(results.values())
    successes = sum(r[0] for r in rounds)
    attempts = sum(r[1] for r in rounds)
    decluttered = sum(1 for r in rounds if r[0] == r[2])
    gsr = successes / attempts if attempts else float('nan')
    dr = decluttered / len(rounds) if rounds else float('nan')
    return gsr, dr


def _json_default(obj):
    """``json.dump`` fallback that converts numpy values to built-in types.

    Raises:
        TypeError: for any other non-serializable object, mirroring the
            error ``json`` itself would raise.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


def _dump_results(results, path):
    """Atomically write ``results`` as indented JSON to ``path``.

    The data is written to a sibling ``.tmp`` file first and then moved into
    place, so a crash mid-write never leaves a truncated ``results.json``.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    with open(tmp_path, 'w') as f:
        json.dump(results, f, indent=2, default=_json_default)
    os.replace(tmp_path, path)


def main(args):
    """Run repeated single-scene clutter-removal rounds and report metrics.

    Each of ``args.num_rounds`` rounds builds a new scene and runs the
    VGNImplicit planner on it via ``clutter_removal_single.run_one_scene``.
    Per-round outputs go to
    ``experiments/<experiment_name>/<save_dir>/round_XXX``. The collected
    results are saved to ``results.json`` in that directory after every
    round. The grasp success rate (GSR) and declutter rate (DR) are printed
    at the end.

    Warning: the output directory is deleted and recreated on every call.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
            Note that ``args.seed`` is overwritten with each round's seed.
    """
    grasp_planner = VGNImplicit(args.experiment_name,
                                args.ckpt_index,
                                best=args.best,
                                qual_th=args.qual_th,
                                force_detection=args.force,
                                out_th=args.out_th,
                                select_top=args.select_top,
                                visualize=args.vis,
                                resolution=args.res)

    set_random_seed(args.seed)

    results = {}
    save_dir_root = Path("experiments") / args.experiment_name / args.save_dir
    # Start from an empty directory so rounds from a previous run are never
    # mixed with this run's outputs.
    if os.path.exists(save_dir_root):
        shutil.rmtree(save_dir_root)
    os.makedirs(save_dir_root)
    results_path = save_dir_root / 'results.json'

    for index in range(args.num_rounds):
        # Per-round scene seed drawn from numpy's global RNG (presumably
        # seeded by set_random_seed above). Values lie in [0, 3000), so seeds
        # may repeat across rounds.
        args.seed = np.random.randint(3000)
        save_dir = save_dir_root / f'round_{index:03d}'
        print(f'===============  Round {index:03d} start  ===============')
        os.makedirs(save_dir, exist_ok=True)
        # 1 + Poisson(num_objects - 1): at least one object, mean num_objects.
        object_count = np.random.poisson(args.num_objects - 1) + 1
        results[index] = clutter_removal_single.run_one_scene(grasp_plan_fn=grasp_planner,
                                                              save_pkl=False,
                                                              save_dir=save_dir,
                                                              scene=args.scene,
                                                              object_set=args.object_set,
                                                              num_objects=object_count,
                                                              num_view=args.num_view,
                                                              seed=args.seed,
                                                              sim_gui=args.sim_gui,
                                                              add_noise=args.add_noise,
                                                              sideview=args.sideview)
        print(f'Round {index} finished, result: {results[index]}')
        # Checkpoint every round: an exception or Ctrl-C late in a long run
        # must not discard the rounds that already completed.
        _dump_results(results, results_path)

    # GSR = total successes / total attempts;
    # DR = fraction of rounds where successes == objects in the scene.
    gsr, dr = _summarize_results(results)
    print(f'GSR: {gsr:.4f}, DR: {dr}')

    _dump_results(results, results_path)
    print(f'Saving results to {results_path}')


def _build_parser():
    """Build the command-line parser for the clutter-removal benchmark.

    Returns:
        argparse.ArgumentParser: parser whose namespace is consumed by main().

    Notes:
        ``--best``, ``--force`` and ``--sideview`` default to True. Each one
        has a matching ``--no-*`` flag that switches it off; the positive
        flags are kept for backward compatibility. ``--type``,
        ``--result-path``, ``--simple-constrain`` and ``--silence`` are
        accepted but not read by main() in this file.
    """
    parser = argparse.ArgumentParser()
    # Underscore spellings are the originals; hyphenated aliases match the
    # rest of the options. argparse derives dest from the first long option.
    parser.add_argument("--experiment_name", "--experiment-name", type=str,
                        default="2023-04-23-11-38-26")
    parser.add_argument("--ckpt_index", "--ckpt-index", type=int,
                        default=-1)
    parser.add_argument("--type", type=str,
                        default="agate")
    # argparse applies `type` to string defaults, so this is a Path either way.
    parser.add_argument("--save-dir", type=Path,
                        default="grasp_results")
    parser.add_argument("--scene",
                        type=str,
                        choices=["pile", "packed"],
                        default="pile")
    parser.add_argument("--object-set", type=str,
                        default="pile/test")  # blocks or packed/test or pile/test
    parser.add_argument("--num-objects", type=int, default=5, help="average number of objects")
    parser.add_argument("--num-view", type=int, default=1)
    parser.add_argument("--num-rounds", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--sim-gui", action="store_true")
    parser.add_argument("--qual-th", type=float, default=0.9)
    parser.add_argument(
        "--best",
        action="store_true",
        default=True,
        help="Whether to use best valid grasp (or random valid grasp)")
    # store_true with default=True can never be turned off; this is the off switch.
    parser.add_argument(
        "--no-best",
        dest="best",
        action="store_false",
        help="Disable --best")
    parser.add_argument("--result-path", type=str, default="results/001")
    parser.add_argument(
        "--force",
        action="store_true",
        default=True,
        help="When all grasps are under threshold, force the detector to select the best grasp"
    )
    parser.add_argument(
        "--no-force",
        dest="force",
        action="store_false",
        help="Disable --force")
    parser.add_argument(
        "--add-noise",
        type=str,
        default='dex',
        help="Whether add noise to depth observation, trans | dex | norm | ''")
    parser.add_argument("--sideview",
                        action="store_true",
                        help="Whether to look from one side",
                        default=True)
    parser.add_argument("--no-sideview",
                        dest="sideview",
                        action="store_false",
                        help="Disable --sideview")
    parser.add_argument("--simple-constrain",
                        action="store_true",
                        help="Whether to contrain grasp from backward")
    parser.add_argument("--res", type=int, default=40)
    parser.add_argument("--out-th", type=float, default=0.5)
    parser.add_argument("--silence",
                        action="store_true",
                        help="Whether to disable tqdm bar")
    parser.add_argument("--select-top",
                        action="store_true",
                        help="Use top heuristic")
    parser.add_argument("--vis",
                        action="store_true",
                        help="visualize and save affordance")
    return parser


if __name__ == "__main__":
    args = _build_parser().parse_args()
    main(args)
