import os
import sys
import subprocess
import shlex
import re
import json
import tqdm
import multiprocessing


# JSON file that build_results() writes and print_results() reads back.
RESULTS_FILE = "v4r_results.json"
# Update index whose logged fps value is recorded for each run; every run is
# launched with NUM_UPDATES = UPDATE_TARGET + 1.
UPDATE_TARGET = 500
# Number of seeds (0 .. NUM_SEEDS - 1, scaled inside build_results) run for
# every (ngpu, sensor) configuration.
NUM_SEEDS = 3


# Single-line command template for one training run.  The block below is
# written one argument per line for readability and then flattened into a
# space-separated string.  The {PYTHON}, {NGPU}, {NS}, {SPLIT}, {SEED},
# {NPROC}, {NUM_UPDATES}, {COLOR}, {DEPTH} and {MODEL} placeholders are filled
# in with str.format() by build_results().
COMMAND_TEMPLATE = " ".join(
    r"""
{PYTHON} -u -m torch.distributed.launch --nproc_per_node={NGPU} --use_env
-m habitat_baselines.run
--exp-config habitat_baselines/config/pointnav/ddppo_pointnav.yaml
--run-type train
NUM_PARALLEL_SCENES {NS}
TASK_CONFIG.DATASET.SPLIT {SPLIT}
TASK_CONFIG.SEED {SEED}
TENSORBOARD_DIR "tb/speed"
CHECKPOINT_FOLDER "data/checkpoints/speed"
EVAL_CKPT_PATH_DIR "data/checkpoints/speed"
LOG_INTERVAL 50
NUM_PROCESSES {NPROC}
TRAINER_NAME ddppo
RL.PPO.lr 1.0e-5
RL.PPO.weight_decay 1e-2
RL.PPO.lamb True
RL.DDPPO.scale_lr True
RL.PPO.lamb_min_trust 0.9
RL.PPO.num_mini_batch 2
RL.PPO.ppo_epoch 1
RL.PPO.num_steps 32
RL.DDPPO.backbone se_resnet9_fixup
RL.DDPPO.rnn_type LSTM
RL.DDPPO.num_recurrent_layers 2
RL.DDPPO.distrib_backend NCCL
RL.PPO.hidden_size 512
TOTAL_NUM_STEPS -1.0
NUM_CHECKPOINTS 1
NUM_UPDATES {NUM_UPDATES}
COLOR {COLOR}
DEPTH {DEPTH}
RL.DDPPO.pretrained_weights {MODEL}
RL.DDPPO.pretrained True
RL.DDPPO.train_encoder True
RL.DDPPO.reset_critic False
""".strip().splitlines()
)


# Matches a log line carrying both an "update: <int>" and an "fps: <float>"
# field.  Raw string so that \s and \d reach the regex engine verbatim instead
# of being treated as (invalid) string escapes.
_FPS_PATTERN = re.compile(r"update:\s(?P<update>\d+).*fps:\s(?P<fps>\d+\.\d+)")

# Seeds are multiples of the Mersenne prime 2**13 - 1 == 8191.  The parentheses
# matter: ``1 << 13 - 1`` parses as ``1 << 12`` because ``-`` binds tighter
# than ``<<``.
_SEED_STRIDE = (1 << 13) - 1

# Environment variables that all receive the same per-run thread count.
_THREAD_ENV_VARS = ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")


def _fps_at_update(output, target_update):
    """Return the fps logged for ``target_update`` in ``output``, else ``None``.

    Only the first log line whose update index equals ``target_update`` is
    used.
    """
    for match in _FPS_PATTERN.finditer(output):
        if int(match.group("update")) == target_update:
            return float(match.group("fps"))
    return None


def _run_configs(ngpu):
    """Return the template parameters of the RGB and depth runs for ``ngpu``."""
    multi_gpu = ngpu == 8
    return [
        dict(
            DEPTH=False,
            COLOR=True,
            SPLIT="train-2plus-mp3d",
            NS=4,
            NPROC=128,
            MODEL="data/checkpoints/final/gibson-mp3d-rgb-se-rs9-r_0/ckpt.87.pth",
        ),
        dict(
            DEPTH=True,
            COLOR=False,
            SPLIT="train-2plus",
            NS=4 if multi_gpu else 16,
            NPROC=128 if multi_gpu else 512,
            MODEL="data/checkpoints/final/gibson-depth-se-rs9-r_0/ckpt.87.pth",
        ),
    ]


def build_results():
    """Run the benchmark sweep and write the measured fps to ``RESULTS_FILE``.

    For every GPU count in ``[1, 8]``, every sensor configuration (RGB and
    depth) and every seed in ``range(NUM_SEEDS)``, a training run is launched
    from ``COMMAND_TEMPLATE`` as a subprocess.  Its combined stdout/stderr is
    scanned for the fps reported at update ``UPDATE_TARGET``; when found, one
    entry is appended to each of the parallel ``ngpu`` / ``fps`` / ``sensor``
    lists, which are finally dumped as JSON.

    Runs are best effort: a run that exits non-zero or never logs the target
    update is reported via ``tqdm.write`` and left out of the results rather
    than aborting the sweep.
    """
    env = os.environ.copy()
    env["V4R_BENCHMARK"] = "1"
    env["GLOG_minloglevel"] = "2"
    env["MAGNUM_LOG"] = "quiet"

    results = dict(ngpu=[], fps=[], sensor=[])

    for ngpu in tqdm.tqdm([1, 8]):
        # One shared value for all threading variables: an eighth of the CPUs
        # for the 8-GPU run (never below 1, so small hosts do not end up with
        # an invalid "0"), and a fixed 12 threads otherwise.
        if ngpu == 8:
            num_threads = max(1, multiprocessing.cpu_count() // 8)
        else:
            num_threads = 12
        for var in _THREAD_ENV_VARS:
            env[var] = str(num_threads)

        for run_params in tqdm.tqdm(_run_configs(ngpu), leave=False):
            sensor = "rgb" if run_params["COLOR"] else "depth"

            for seed in tqdm.trange(NUM_SEEDS, leave=False):
                command = COMMAND_TEMPLATE.format(
                    PYTHON=sys.executable,
                    NGPU=ngpu,
                    NUM_UPDATES=UPDATE_TARGET + 1,
                    SEED=seed * _SEED_STRIDE,
                    **run_params
                )
                tqdm.tqdm.write(command)

                # check=False on purpose: a failed run must not abort the
                # remaining runs of the sweep.
                res = subprocess.run(
                    shlex.split(command),
                    env=env,
                    check=False,
                    stderr=subprocess.STDOUT,
                    stdout=subprocess.PIPE,
                )

                # errors="replace" keeps one stray non-UTF-8 byte in the log
                # from raising and discarding every result gathered so far.
                output = res.stdout.decode("utf-8", errors="replace")
                fps = _fps_at_update(output, UPDATE_TARGET)

                if fps is None:
                    tqdm.tqdm.write(
                        "WARNING: no fps reported at update {} "
                        "(exit code {}); run skipped".format(
                            UPDATE_TARGET, res.returncode
                        )
                    )
                    continue

                results["ngpu"].append(ngpu)
                results["fps"].append(fps)
                results["sensor"].append(sensor)

    with open(RESULTS_FILE, "w") as f:
        json.dump(results, f)


def print_results():
    """Print mean fps and a 95% confidence half-width per configuration.

    Loads the parallel ``ngpu`` / ``fps`` / ``sensor`` lists from
    ``RESULTS_FILE`` and, for each sensor in ``("depth", "rgb")`` and each GPU
    count in ``(1, 8)``, prints one line::

        <sensor> <ngpu> <mean fps> <1.96 * std / sqrt(n)>

    with both numbers rounded to two decimals.  ``np.std`` is called with its
    default arguments.  A configuration without any recorded run prints
    ``nan nan``.
    """
    import numpy as np

    with open(RESULTS_FILE, "r") as f:
        results = json.load(f)

    def _select(results, sensor, ngpu):
        # Walk the three parallel lists in lockstep and keep the fps values of
        # the runs matching the requested configuration.
        return [
            fps
            for fps, run_ngpu, run_sensor in zip(
                results["fps"], results["ngpu"], results["sensor"]
            )
            if run_ngpu == ngpu and run_sensor == sensor
        ]

    for sensor in ["depth", "rgb"]:
        for ngpu in [1, 8]:
            config_results = _select(results, sensor, ngpu)

            if config_results:
                mean_fps = np.round(np.mean(config_results), 2)
                half_width = np.round(
                    1.96 * np.std(config_results) / np.sqrt(len(config_results)), 2
                )
            else:
                # No run of this configuration produced a measurement.  Report
                # NaN explicitly instead of letting numpy take the mean of an
                # empty list and divide by zero, which emits RuntimeWarnings.
                mean_fps = half_width = float("nan")

            print(sensor, ngpu, mean_fps, half_width)


# Script entry point: run the full benchmark sweep (which writes RESULTS_FILE),
# then read that file back and print the per-configuration summary.
if __name__ == "__main__":
    build_results()
    print_results()
