import copy
import numpy as np
import unittest

import ray
import src.rllib.agents.pg as pg
from src.rllib.evaluation.postprocessing import Postprocessing
from src.rllib.models.tf.tf_action_dist import Categorical
from src.rllib.models.torch.torch_action_dist import TorchCategorical
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.utils import check, check_compute_single_action, fc, \
    framework_iterator


class TestPG(unittest.TestCase):
    """Tests for the Policy Gradient (PG) trainer: build, learning, loss."""

    def setUp(self):
        ray.init()

    def tearDown(self):
        ray.shutdown()

    def test_pg_compilation(self):
        """Test whether a PGTrainer can be built with all frameworks."""
        # Deep-copy: a shallow `.copy()` would share nested dicts (e.g.
        # "model") with pg.DEFAULT_CONFIG, so edits could leak into it.
        config = copy.deepcopy(pg.DEFAULT_CONFIG)
        config["num_workers"] = 1
        config["rollout_fragment_length"] = 500
        num_iterations = 1

        for _ in framework_iterator(config):
            for env in ["FrozenLake-v0", "CartPole-v0"]:
                trainer = pg.PGTrainer(config=config, env=env)
                for _ in range(num_iterations):
                    print(trainer.train())
                check_compute_single_action(
                    trainer, include_prev_action_reward=True)
                # Free the trainer before building the next one.
                trainer.stop()

    def test_pg_fake_multi_gpu_learning(self):
        """Test whether PGTrainer can learn CartPole w/ faked multi-GPU."""
        config = copy.deepcopy(pg.DEFAULT_CONFIG)

        # Fake GPU setup.
        config["num_gpus"] = 2
        config["_fake_gpus"] = True

        # Mimic tuned_example for PG CartPole.
        config["model"]["fcnet_hiddens"] = [64]
        config["model"]["fcnet_activation"] = "linear"

        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            trainer = pg.PGTrainer(config=config, env="CartPole-v0")
            num_iterations = 300
            learnt = False
            for _ in range(num_iterations):
                results = trainer.train()
                print("reward={}".format(results["episode_reward_mean"]))
                # Keep this test short: stop as soon as the mean episode
                # reward exceeds 65.0.
                if results["episode_reward_mean"] > 65.0:
                    learnt = True
                    break
            assert learnt,\
                "PG multi-GPU (with fake-GPUs) did not learn CartPole!"
            trainer.stop()

    def test_pg_loss_functions(self):
        """Tests the PG loss function math."""
        # Deep-copy so the "model" overrides below do not mutate the
        # nested dict inside the shared pg.DEFAULT_CONFIG.
        config = copy.deepcopy(pg.DEFAULT_CONFIG)
        config["num_workers"] = 0  # Run locally.
        config["gamma"] = 0.99
        config["model"]["fcnet_hiddens"] = [10]
        config["model"]["fcnet_activation"] = "linear"

        # Fake CartPole episode of n time steps.
        train_batch = SampleBatch({
            SampleBatch.OBS: np.array([[0.1, 0.2, 0.3,
                                        0.4], [0.5, 0.6, 0.7, 0.8],
                                       [0.9, 1.0, 1.1, 1.2]]),
            SampleBatch.ACTIONS: np.array([0, 1, 1]),
            SampleBatch.REWARDS: np.array([1.0, 1.0, 1.0]),
            SampleBatch.DONES: np.array([False, False, True]),
            SampleBatch.EPS_ID: np.array([1234, 1234, 1234]),
            SampleBatch.AGENT_INDEX: np.array([0, 0, 0]),
        })

        for fw, sess in framework_iterator(config, session=True):
            dist_cls = Categorical if fw != "torch" else TorchCategorical
            trainer = pg.PGTrainer(config=config, env="CartPole-v0")
            policy = trainer.get_policy()
            weights = policy.model.trainable_variables()
            if sess:
                weights = policy.get_session().run(weights)

            # Post-process (calculate simple (non-GAE) advantages) and attach
            # to train_batch dict.
            # A = [0.99^2 * 1.0 + 0.99 * 1.0 + 1.0, 0.99 * 1.0 + 1.0, 1.0] =
            # [2.9701, 1.99, 1.0]
            train_batch_ = pg.post_process_advantages(policy,
                                                      train_batch.copy())
            if fw == "torch":
                train_batch_ = policy._lazy_tensor_dict(train_batch_)

            # Check Advantage values.
            check(train_batch_[Postprocessing.ADVANTAGES], [2.9701, 1.99, 1.0])

            # Actual loss results.
            if sess:
                results = policy.get_session().run(
                    policy._loss,
                    feed_dict=policy._get_loss_inputs_dict(
                        train_batch_, shuffle=False))
            else:
                loss_fn = (pg.pg_tf_loss
                           if fw in ["tf2", "tfe"] else pg.pg_torch_loss)
                results = loss_fn(
                    policy,
                    policy.model,
                    dist_class=dist_cls,
                    train_batch=train_batch_)

            # Calculate expected results: a 2-layer linear forward pass.
            if fw != "torch":
                expected_logits = fc(
                    fc(train_batch_[SampleBatch.OBS],
                       weights[0],
                       weights[1],
                       framework=fw),
                    weights[2],
                    weights[3],
                    framework=fw)
            else:
                # NOTE(review): the torch model apparently lists its
                # variables in a different order (hidden layer at indices
                # 2/3), hence the swapped indices here — confirm.
                expected_logits = fc(
                    fc(train_batch_[SampleBatch.OBS],
                       weights[2],
                       weights[3],
                       framework=fw),
                    weights[0],
                    weights[1],
                    framework=fw)
            expected_logp = dist_cls(expected_logits, policy.model).logp(
                train_batch_[SampleBatch.ACTIONS])
            adv = train_batch_[Postprocessing.ADVANTAGES]
            # Convert the framework-specific results to numpy for comparison.
            if sess:
                expected_logp = sess.run(expected_logp)
            elif fw == "torch":
                expected_logp = expected_logp.detach().cpu().numpy()
                adv = adv.detach().cpu().numpy()
            else:
                expected_logp = expected_logp.numpy()
            # Vanilla PG loss: -mean(logp(a|s) * A).
            expected_loss = -np.mean(expected_logp * adv)
            check(results, expected_loss, decimals=4)
            # Free the trainer before the next framework iteration.
            trainer.stop()


if __name__ == "__main__":
    import sys

    import pytest

    # Run this module's tests verbosely and exit with pytest's status code.
    sys.exit(pytest.main(["-v", __file__]))
