import os
from pathlib import Path
import re
import sys
import unittest

import ray
from ray import tune
from src.rllib.examples.env.multi_agent import MultiAgentCartPole
from src.rllib.utils.test_utils import framework_iterator


def rollout_test(algo, env="CartPole-v0", test_episode_rollout=False):
    extra_config = ""
    if algo == "ARS":
        extra_config = ",\"train_batch_size\": 10, \"noise_size\": 250000"
    elif algo == "ES":
        extra_config = ",\"episodes_per_batch\": 1,\"train_batch_size\": 10, "\
                       "\"noise_size\": 250000"

    for fw in framework_iterator(frameworks=("tf", "torch")):
        fw_ = ", \"framework\": \"{}\"".format(fw)

        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir,
                                                 os.path.exists(rllib_dir)))
        os.system("python {}/train.py --local-dir={} --run={} "
                  "--checkpoint-freq=1 ".format(rllib_dir, tmp_dir, algo) +
                  "--config='{" + "\"num_workers\": 1, \"num_gpus\": 0{}{}".
                  format(fw_, extra_config) +
                  ", \"timesteps_per_iteration\": 5,\"min_iter_time_s\": 0.1, "
                  "\"model\": {\"fcnet_hiddens\": [10]}"
                  "}' --stop='{\"training_iteration\": 1}'" +
                  " --env={} --no-ray-ui".format(env))

        checkpoint_path = os.popen("ls {}/default/*/checkpoint_000001/"
                                   "checkpoint-1".format(tmp_dir)).read()[:-1]
        if not os.path.exists(checkpoint_path):
            sys.exit(1)
        print("Checkpoint path {} (exists)".format(checkpoint_path))

        # Test rolling out n steps.
        os.popen("python {}/rollout.py --run={} \"{}\" --steps=10 "
                 "--out=\"{}/rollouts_10steps.pkl\" --no-render".format(
                     rllib_dir, algo, checkpoint_path, tmp_dir)).read()
        if not os.path.exists(tmp_dir + "/rollouts_10steps.pkl"):
            sys.exit(1)
        print("rollout output (10 steps) exists!")

        # Test rolling out 1 episode.
        if test_episode_rollout:
            os.popen("python {}/rollout.py --run={} \"{}\" --episodes=1 "
                     "--out=\"{}/rollouts_1episode.pkl\" --no-render".format(
                         rllib_dir, algo, checkpoint_path, tmp_dir)).read()
            if not os.path.exists(tmp_dir + "/rollouts_1episode.pkl"):
                sys.exit(1)
            print("rollout output (1 ep) exists!")

        # Cleanup.
        os.popen("rm -rf \"{}\"".format(tmp_dir)).read()


def learn_test_plus_rollout(algo, env="CartPole-v0"):
    for fw in framework_iterator(frameworks=("tf", "torch")):
        fw_ = ", \\\"framework\\\": \\\"{}\\\"".format(fw)

        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via underlying tempdir (and cut tmp_.
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir,
                                                 os.path.exists(rllib_dir)))
        os.system("python {}/train.py --local-dir={} --run={} "
                  "--checkpoint-freq=1 --checkpoint-at-end ".format(
                      rllib_dir, tmp_dir, algo) +
                  "--config=\"{\\\"num_gpus\\\": 0, \\\"num_workers\\\": 1, "
                  "\\\"evaluation_config\\\": {\\\"explore\\\": false}" + fw_ +
                  "}\" " + "--stop=\"{\\\"episode_reward_mean\\\": 150.0}\"" +
                  " --env={}".format(env))

        # Find last checkpoint and use that for the rollout.
        checkpoint_path = os.popen("ls {}/default/*/checkpoint_*/"
                                   "checkpoint-*".format(tmp_dir)).read()[:-1]
        checkpoints = [
            cp for cp in checkpoint_path.split("\n")
            if re.match(r"^.+checkpoint-\d+$", cp)
        ]
        # Sort by number and pick last (which should be the best checkpoint).
        last_checkpoint = sorted(
            checkpoints,
            key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]
        assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
        if not os.path.exists(last_checkpoint):
            sys.exit(1)
        print("Best checkpoint={} (exists)".format(last_checkpoint))

        # Test rolling out n steps.
        result = os.popen(
            "python {}/rollout.py --run={} "
            "--steps=400 "
            "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
                rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
        if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
            sys.exit(1)
        print("Rollout output exists -> Checking reward ...")
        episodes = result.split("\n")
        mean_reward = 0.0
        num_episodes = 0
        for ep in episodes:
            mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
            if mo:
                mean_reward += float(mo.group(1))
                num_episodes += 1
        mean_reward /= num_episodes
        print("Rollout's mean episode reward={}".format(mean_reward))
        assert mean_reward >= 150.0

        # Cleanup.
        os.popen("rm -rf \"{}\"".format(tmp_dir)).read()


def learn_test_multi_agent_plus_rollout(algo):
    for fw in framework_iterator(frameworks=("tf", "torch")):
        tmp_dir = os.popen("mktemp -d").read()[:-1]
        if not os.path.exists(tmp_dir):
            # Last resort: Resolve via underlying tempdir (and cut tmp_.
            tmp_dir = ray._private.utils.tempfile.gettempdir() + tmp_dir[4:]
            if not os.path.exists(tmp_dir):
                sys.exit(1)

        print("Saving results to {}".format(tmp_dir))

        rllib_dir = str(Path(__file__).parent.parent.absolute())
        print("RLlib dir = {}\nexists={}".format(rllib_dir,
                                                 os.path.exists(rllib_dir)))

        def policy_fn(agent_id, episode, **kwargs):
            return "pol{}".format(agent_id)

        config = {
            "num_gpus": 0,
            "num_workers": 1,
            "evaluation_config": {
                "explore": False
            },
            "framework": fw,
            "env": MultiAgentCartPole,
            "multiagent": {
                "policies": {"pol0", "pol1"},
                "policy_mapping_fn": policy_fn,
            },
        }
        stop = {"episode_reward_mean": 150.0}
        tune.run(
            algo,
            config=config,
            stop=stop,
            checkpoint_freq=1,
            checkpoint_at_end=True,
            local_dir=tmp_dir,
            verbose=1)

        # Find last checkpoint and use that for the rollout.
        checkpoint_path = os.popen("ls {}/PPO/*/checkpoint_*/"
                                   "checkpoint-*".format(tmp_dir)).read()[:-1]
        checkpoint_paths = checkpoint_path.split("\n")
        assert len(checkpoint_paths) > 0
        checkpoints = [
            cp for cp in checkpoint_paths
            if re.match(r"^.+checkpoint-\d+$", cp)
        ]
        # Sort by number and pick last (which should be the best checkpoint).
        last_checkpoint = sorted(
            checkpoints,
            key=lambda x: int(re.match(r".+checkpoint-(\d+)", x).group(1)))[-1]
        assert re.match(r"^.+checkpoint_\d+/checkpoint-\d+$", last_checkpoint)
        if not os.path.exists(last_checkpoint):
            sys.exit(1)
        print("Best checkpoint={} (exists)".format(last_checkpoint))

        ray.shutdown()

        # Test rolling out n steps.
        result = os.popen(
            "python {}/rollout.py --run={} "
            "--steps=400 "
            "--out=\"{}/rollouts_n_steps.pkl\" --no-render \"{}\"".format(
                rllib_dir, algo, tmp_dir, last_checkpoint)).read()[:-1]
        if not os.path.exists(tmp_dir + "/rollouts_n_steps.pkl"):
            sys.exit(1)
        print("Rollout output exists -> Checking reward ...")
        episodes = result.split("\n")
        mean_reward = 0.0
        num_episodes = 0
        for ep in episodes:
            mo = re.match(r"Episode .+reward: ([\d\.\-]+)", ep)
            if mo:
                mean_reward += float(mo.group(1))
                num_episodes += 1
        mean_reward /= num_episodes
        print("Rollout's mean episode reward={}".format(mean_reward))
        assert mean_reward >= 150.0

        # Cleanup.
        os.popen("rm -rf \"{}\"".format(tmp_dir)).read()


class TestRolloutSimple1(unittest.TestCase):
    def test_a3c(self):
        rollout_test("A3C")

    def test_ddpg(self):
        rollout_test("DDPG", env="Pendulum-v0")


class TestRolloutSimple2(unittest.TestCase):
    def test_dqn(self):
        rollout_test("DQN")

    def test_es(self):
        rollout_test("ES")


class TestRolloutSimple3(unittest.TestCase):
    def test_impala(self):
        rollout_test("IMPALA", env="CartPole-v0")

    def test_ppo(self):
        rollout_test("PPO", env="CartPole-v0", test_episode_rollout=True)


class TestRolloutSimple4(unittest.TestCase):
    def test_sac(self):
        rollout_test("SAC", env="Pendulum-v0")


class TestRolloutLearntPolicy(unittest.TestCase):
    def test_ppo_train_then_rollout(self):
        learn_test_plus_rollout("PPO")

    def test_ppo_multi_agent_train_then_rollout(self):
        learn_test_multi_agent_plus_rollout("PPO")


if __name__ == "__main__":
    import pytest

    # One can specify the specific TestCase class to run.
    # None for all unittest.TestCase classes in this file.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    sys.exit(
        pytest.main(
            ["-v", __file__ + ("" if class_ is None else "::" + class_)]))
