import numpy as np

import ray
import src.rllib.agents.ppo as ppo
from src.rllib.models.torch.torch_modelv2 import TorchModelV2
from src.rllib.models.catalog import ModelCatalog
from src.rllib.utils.framework import try_import_torch

torch, _ = try_import_torch()

# __sphinx_doc_begin__


# The custom model that will be wrapped by an LSTM.
class MyCustomModel(TorchModelV2):
    """Minimal custom torch model meant to be auto-wrapped by an LSTM.

    ``forward()`` returns the flattened observation scaled by 2.0 and no
    recurrent state. The LSTM itself is provided by RLlib's auto-wrapper
    when the model config sets ``use_lstm=True``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        # forward() emits one value per flattened observation element, so
        # advertise that width instead of the requested num_outputs.
        # (np.prod is used because np.product was removed in NumPy 2.0.)
        self.num_outputs = int(np.prod(self.obs_space.shape))
        # Batch size / dtype / device of the most recent forward() output;
        # value_function() uses them to build a compatible tensor.
        self._last_batch_size = None
        self._last_dtype = None
        self._last_device = None

    # Implement your own forward logic, whose output will then be sent
    # through an LSTM.
    def forward(self, input_dict, state, seq_lens):
        """Return 2x the flattened observation and an empty state list.

        Args:
            input_dict: Model inputs; only ``input_dict["obs_flat"]`` is
                read (presumably a (batch, flat_obs_dim) tensor — TODO
                confirm).
            state: Recurrent state; unused, this model is stateless.
            seq_lens: Sequence lengths; unused.

        Returns:
            Tuple ``(obs_flat * 2.0, [])``.
        """
        obs = input_dict["obs_flat"]
        out = obs * 2.0
        # Remember batch size, dtype and device so value_function() can
        # return a tensor matching this forward pass's output.
        self._last_batch_size = obs.shape[0]
        self._last_dtype = out.dtype
        self._last_device = out.device
        # This output is further sent through an automatically provided
        # LSTM head (b/c use_lstm=True is set in the model config).
        return out, []

    def value_function(self):
        """Return a dummy all-zeros value estimate of shape ``(batch,)``.

        Raises:
            ValueError: If called before ``forward()``.
        """
        if self._last_batch_size is None:
            raise ValueError("must call forward() first")
        # Create the zeros directly in torch with the forward output's
        # dtype/device; torch.from_numpy(np.zeros(...)) always produced a
        # float64 CPU tensor regardless of where the model was running.
        return torch.zeros(
            self._last_batch_size,
            dtype=self._last_dtype,
            device=self._last_device,
        )


if __name__ == "__main__":
    # Bring up Ray before anything else touches it.
    ray.init()

    # Make the custom model above available to RLlib under a string key.
    ModelCatalog.register_custom_model("my_torch_model", MyCustomModel)

    # Model settings: auto-wrap the custom(!) model with an LSTM.
    model_config = {
        "use_lstm": True,
        # Further customization of the LSTM auto-wrapper.
        "lstm_cell_size": 64,
        # The custom model registered above ...
        "custom_model": "my_torch_model",
        # ... and extra kwargs passed to its c'tor (none here).
        "custom_model_config": {},
    }
    trainer_config = {
        "framework": "torch",
        "model": model_config,
    }

    # Build the PPO trainer on CartPole and call train() once.
    trainer = ppo.PPOTrainer(env="CartPole-v0", config=trainer_config)
    trainer.train()

# __sphinx_doc_end__
