import json
from pathlib import Path

import ray
from ray import tune
from src.rllib.agents.registry import get_trainer_class

from src.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from src.rllib.examples.env.stateless_cartpole import StatelessCartPole

# Example environments selectable for this script, looked up by name when
# filling in the "env" entry of the trainer config.
envs = dict(
    RepeatAfterMeEnv=RepeatAfterMeEnv,
    StatelessCartPole=StatelessCartPole,
)

def _lstm_model_config():
    """Return a fresh dict of the recurrent-network settings.

    The policy model and the Q model use identical values; a new dict is
    built on every call so the two entries never share one mutable object.
    """
    return {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "fcnet_hiddens": [64, 64],
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    }


# Trainer-level settings; nested under the "config" key of `config` below.
_trainer_config = {
    # Framework / resources / logging.
    "framework": "torch",
    "num_workers": 4,
    "num_envs_per_worker": 1,
    "num_cpus_per_worker": 1,
    "log_level": "INFO",

    # Environment. Use envs["RepeatAfterMeEnv"] here to try the other
    # example environment instead.
    "env": envs["StatelessCartPole"],
    "horizon": 1000,

    # Learning / replay hyperparameters.
    "gamma": 0.95,
    "batch_mode": "complete_episodes",
    "prioritized_replay": False,
    "buffer_size": 100000,
    "learning_starts": 1000,
    "train_batch_size": 480,
    "target_network_update_freq": 480,
    "tau": 0.3,
    "burn_in": 4,
    "zero_init_states": False,
    "optimization": {
        "actor_learning_rate": 0.005,
        "critic_learning_rate": 0.005,
        "entropy_learning_rate": 0.0001,
    },

    # Network definitions.
    "model": {
        "max_seq_len": 20,
    },
    "policy_model": _lstm_model_config(),
    "Q_model": _lstm_model_config(),
}

# Experiment settings; the script unpacks this dict as keyword arguments to
# `tune.run`.
config = {
    "name": "RNNSAC_example",
    # Results are written to an "example_out" directory next to this file.
    "local_dir": str(Path(__file__).parent / "example_out"),
    "checkpoint_freq": 1,
    "keep_checkpoints_num": 1,
    "checkpoint_score_attr": "episode_reward_mean",
    "stop": {
        "episode_reward_mean": 65.0,
        "timesteps_total": 100000,
    },
    "metric": "episode_reward_mean",
    "mode": "max",
    "verbose": 2,
    "config": _trainer_config,
}

if __name__ == "__main__":
    # Script entry point: train RNNSAC with Tune, then restore the best
    # checkpoint and roll it out (exploration disabled) for a few episodes,
    # printing the total reward of each one.

    # Number of evaluation episodes to run after training.
    num_test_episodes = 10

    # INIT
    # The trainer config asks for 4 workers with 1 CPU each; the fifth CPU
    # is presumably for the driver process.
    ray.init(num_cpus=5)
    try:
        # TRAIN
        results = tune.run("RNNSAC", **config)

        # TEST
        best_checkpoint = results.best_checkpoint
        print("Loading checkpoint: {}".format(best_checkpoint))
        # The trial's params.json is looked up two directory levels above
        # the checkpoint path.
        checkpoint_config_path = str(
            Path(best_checkpoint).parent.parent / "params.json")
        with open(checkpoint_config_path, "rb") as f:
            checkpoint_config = json.load(f)

        # Evaluate without exploration.
        checkpoint_config["explore"] = False

        # The env is passed explicitly from the in-memory config rather than
        # taken from the JSON-loaded one.
        agent = get_trainer_class("RNNSAC")(
            env=config["config"]["env"], config=checkpoint_config)
        agent.restore(best_checkpoint)

        env = agent.env_creator({})
        try:
            for eps in range(1, num_test_episodes + 1):
                # Every episode starts from a fresh recurrent state and
                # zeroed previous action / reward inputs.
                state = agent.get_policy().get_initial_state()
                prev_action = 0
                prev_reward = 0
                obs = env.reset()
                ep_reward = 0
                done = False
                while not done:
                    action, state, info_trainer = agent.compute_action(
                        obs,
                        state=state,
                        prev_action=prev_action,
                        prev_reward=prev_reward,
                        full_fetch=True)
                    obs, reward, done, info = env.step(action)
                    prev_action = action
                    prev_reward = reward
                    ep_reward += reward
                    # Rendering is best-effort: carry on if the env does not
                    # implement it or a rendering dependency is missing.
                    try:
                        env.render()
                    except (NotImplementedError, ImportError):
                        pass
                print("Episode {}: {}".format(eps, ep_reward))
        finally:
            # Release env resources (e.g. a render window) even on error.
            env.close()
    finally:
        # Always tear Ray down, including when training or evaluation fails.
        ray.shutdown()