"""Example of using rollout worker classes directly to implement training.

Instead of using the built-in Trainer classes provided by RLlib, here we define
a custom Policy class and manually coordinate distributed sample
collection and policy optimization.
"""

import argparse
import gym
import numpy as np
import os

import ray
from ray import tune
from src.rllib.evaluation import RolloutWorker
from src.rllib.evaluation.metrics import collect_metrics
from src.rllib.policy.policy import Policy
from src.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

# Command-line interface for this example script. Flag names, types and
# defaults are part of the script's interface and must stay stable.
parser = argparse.ArgumentParser(
    description="Train a custom policy by driving rollout workers directly "
    "instead of using a built-in Trainer.")
parser.add_argument(
    "--gpu",
    action="store_true",
    help="Reserve one GPU for the trial.")
parser.add_argument(
    "--num-iters",
    type=int,
    default=20,
    help="Number of iterations of the training loop "
    "(default: %(default)s).")
parser.add_argument(
    "--num-workers",
    type=int,
    default=2,
    help="Number of remote rollout workers to create "
    "(default: %(default)s).")
parser.add_argument(
    "--num-cpus",
    type=int,
    default=0,
    help="Number of CPUs passed to ray.init(); 0 passes None instead "
    "(default: %(default)s).")


class CustomPolicy(Policy):
    """Minimal hand-written policy that acts randomly.

    It carries a single scalar parameter ``w`` so that weight syncing and
    ad-hoc method calls can be demonstrated. For a real policy, extending
    TF/TorchPolicy is likely more convenient than starting from scratch.
    """

    def __init__(self, observation_space, action_space, config):
        """Initialize the base Policy and the example parameter ``w``."""
        super().__init__(observation_space, action_space, config)
        # This policy is not tied to any deep learning framework.
        self.config["framework"] = None
        # Example parameter; the only "weight" this policy has.
        self.w = 1.0

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Sample one random action per entry of ``obs_batch``.

        Every argument other than ``obs_batch`` is ignored, and the
        observations themselves are only counted, never inspected.

        Returns:
            A 3-tuple of (numpy array of sampled actions, empty list,
            empty dict).
        """
        random_actions = [self.action_space.sample() for _ in obs_batch]
        return np.array(random_actions), [], {}

    def learn_on_batch(self, samples):
        """Placeholder for learning code; ignores ``samples``.

        Returns:
            An empty dict.
        """
        return {}

    def update_some_value(self, w):
        """Overwrite ``self.w``; shows that arbitrary methods can be called."""
        self.w = w

    def get_weights(self):
        """Return the policy's weights as ``{"w": <current w>}``."""
        current_weights = {"w": self.w}
        return current_weights

    def set_weights(self, weights):
        """Load ``self.w`` from the ``"w"`` entry of ``weights``."""
        self.w = weights["w"]


def training_workflow(config, reporter):
    """Run a hand-coordinated sample/update loop on remote rollout workers.

    Args:
        config: Mapping read for ``"num_workers"`` (how many remote
            RolloutWorker actors to create) and ``"num_iters"`` (how many
            loop iterations to run).
        reporter: Callable invoked once per iteration with the collected
            worker metrics expanded as keyword arguments.
    """

    def make_env(c):
        # The argument supplied by the worker is not used.
        return gym.make("CartPole-v0")

    def gather_batch():
        # Ask every worker for a sample, wait for all results, and merge
        # them into one batch.
        pending = [worker.sample.remote() for worker in rollout_workers]
        return SampleBatch.concat_samples(ray.get(pending))

    # Local policy copy; a local env instance provides the spaces it needs.
    spaces_env = gym.make("CartPole-v0")
    local_policy = CustomPolicy(
        spaces_env.observation_space, spaces_env.action_space, {})

    # Remote policy evaluation actors.
    remote_worker_cls = RolloutWorker.as_remote()
    rollout_workers = []
    for _ in range(config["num_workers"]):
        rollout_workers.append(
            remote_worker_cls.remote(
                env_creator=make_env, policy=CustomPolicy))

    for _ in range(config["num_iters"]):
        # Push the local weights to every worker; the same object
        # reference is handed to each of them.
        weights_ref = ray.put(
            {DEFAULT_POLICY_ID: local_policy.get_weights()})
        for worker in rollout_workers:
            worker.set_weights.remote(weights_ref)

        # First batch, sampled with the freshly broadcast weights.
        first_batch = gather_batch()

        # Change a value on the remote policy replicas only, then sample
        # a second batch.
        doubled_w = local_policy.w * 2.0
        for worker in rollout_workers:
            worker.for_policy.remote(
                lambda p: p.update_some_value(doubled_w))
        second_batch = gather_batch()

        # Improve the local policy using the first batch.
        local_policy.learn_on_batch(first_batch)

        # Arbitrary update driven by the second batch's rewards.
        local_policy.update_some_value(sum(second_batch["rewards"]))

        metrics = collect_metrics(remote_workers=rollout_workers)
        reporter(**metrics)


if __name__ == "__main__":
    # Script entry point: parse flags, start Ray, and run the workflow as a
    # single Tune trial.
    args = parser.parse_args()

    # A --num-cpus value of 0 is falsy, so None is passed to ray.init.
    ray.init(num_cpus=args.num_cpus or None)

    # One GPU is requested if --gpu was given or RLLIB_FORCE_NUM_GPUS is a
    # non-zero integer; the env var is only consulted when --gpu is absent.
    wants_gpu = args.gpu or int(os.environ.get("RLLIB_FORCE_NUM_GPUS", 0))
    trial_resources = {
        "gpu": 1 if wants_gpu else 0,
        "cpu": 1,
        "extra_cpu": args.num_workers,
    }
    workflow_config = {
        "num_workers": args.num_workers,
        "num_iters": args.num_iters,
    }

    tune.run(
        training_workflow,
        resources_per_trial=trial_resources,
        config=workflow_config,
        verbose=1,
    )
