# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Taken from https://raw.githubusercontent.com/facebookresearch/torchbeast/3f3029cf3d6d488b8b8f952964795f451a49048f/torchbeast/monobeast.py
# and modified slightly

import argparse
import logging
import os
import pprint
import threading
import time
import timeit
import traceback
import typing
import numpy as np
import itertools
import copy
import psutil

import tempfile
import os.path as path

os.environ["OMP_NUM_THREADS"] = "1"  # Necessary for multithreading.

import torch
from torch import multiprocessing as mp
from torch import nn
from torch.nn import functional as F

import continual_rl.utils.env_wrappers as atari_wrappers
from continual_rl.policies.impala.torchbeast.core import environment
from continual_rl.policies.impala.torchbeast.core import file_writer
from continual_rl.policies.impala.torchbeast.core import prof
from continual_rl.policies.impala.torchbeast.core import vtrace


# yapf: disable
# Command-line flags for standalone monobeast runs.
# NOTE(review): the functions below also read flags that are NOT declared here (use_clear, replay_ratio,
# replay_buffer_size, policy_cloning_cost, value_cloning_cost, large_file_path, return_after_reward_num);
# presumably the caller sets them on the flags object -- confirm before running this parser on its own.
parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")

parser.add_argument("--env", type=str, default="PongNoFrameskip-v4",
                    help="Gym environment.")
parser.add_argument("--mode", default="train",
                    choices=["train", "test", "test_render"],
                    help="Training or test mode.")
parser.add_argument("--xpid", default=None,
                    help="Experiment id (default: None).")

# Training settings.
parser.add_argument("--disable_checkpoint", action="store_true",
                    help="Disable saving checkpoint.")
parser.add_argument("--savedir", default="~/logs/torchbeast",
                    help="Root dir where experiment data will be saved.")
parser.add_argument("--num_actors", default=4, type=int, metavar="N",
                    help="Number of actors (default: 4).")
parser.add_argument("--total_steps", default=100000, type=int, metavar="T",
                    help="Total environment steps to train for.")
parser.add_argument("--batch_size", default=8, type=int, metavar="B",
                    help="Learner batch size.")
parser.add_argument("--unroll_length", default=80, type=int, metavar="T",
                    help="The unroll length (time dimension).")
# None means train() picks max(2 * num_actors, batch_size).
parser.add_argument("--num_buffers", default=None, type=int,
                    metavar="N", help="Number of shared-memory buffers.")
parser.add_argument("--num_learner_threads", "--num_threads", default=2, type=int,
                    metavar="N", help="Number of learner threads.")
parser.add_argument("--disable_cuda", action="store_true",
                    help="Disable CUDA.")
parser.add_argument("--use_lstm", action="store_true",
                    help="Use LSTM in agent model.")

# Loss settings.
parser.add_argument("--entropy_cost", default=0.0006,
                    type=float, help="Entropy cost/multiplier.")
parser.add_argument("--baseline_cost", default=0.5,
                    type=float, help="Baseline cost/multiplier.")
parser.add_argument("--discounting", default=0.99,
                    type=float, help="Discounting factor.")
parser.add_argument("--reward_clipping", default="abs_one",
                    choices=["abs_one", "none"],
                    help="Reward clipping.")

# Optimizer settings.
parser.add_argument("--learning_rate", default=0.00048,
                    type=float, metavar="LR", help="Learning rate.")
parser.add_argument("--alpha", default=0.99, type=float,
                    help="RMSProp smoothing constant.")
parser.add_argument("--momentum", default=0, type=float,
                    help="RMSProp momentum.")
parser.add_argument("--epsilon", default=0.01, type=float,
                    help="RMSProp epsilon.")
parser.add_argument("--grad_norm_clipping", default=40.0, type=float,
                    help="Global gradient norm clip.")
# yapf: enable


# Root logging setup: every record carries level, pid, module and line number; NOTSET (0) lets everything through.
logging.basicConfig(
    format="[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] %(message)s",
    level=logging.NOTSET,
)

# Rollout-buffer container: buffer name -> list of tensors (one tensor per buffer slot).
Buffers = typing.Dict[str, typing.List[torch.Tensor]]


def compute_baseline_loss(advantages):
    """Return half the summed squared advantage (value-function regression loss)."""
    squared = advantages.pow(2)
    return squared.sum() * 0.5


def compute_entropy_loss(logits):
    """Return the entropy loss, i.e., the negative entropy of the policy.

    The softmax is taken over the last dimension; the result is summed over every other element.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    probs = F.softmax(logits, dim=-1)
    return (probs * log_probs).sum()


def compute_policy_gradient_loss(logits, actions, advantages):
    """Return the advantage-weighted negative log-likelihood of the taken actions.

    The first two dims of ``logits`` and ``actions`` are merged for the NLL and the per-step result is reshaped
    back to ``advantages``' shape. Advantages are detached so no gradient flows into them.
    """
    flat_log_probs = F.log_softmax(torch.flatten(logits, 0, 1), dim=-1)
    flat_actions = torch.flatten(actions, 0, 1)
    per_step_nll = F.nll_loss(flat_log_probs, target=flat_actions, reduction="none")
    weighted = per_step_nll.view_as(advantages) * advantages.detach()
    return weighted.sum()


def compute_policy_cloning_loss(old_logits, curr_logits):
    """Return the summed KL divergence KL(old_policy || current_policy) over the last-dim action distributions."""
    target_probs = F.softmax(old_logits, dim=-1)
    current_log_probs = F.log_softmax(curr_logits, dim=-1)
    return F.kl_div(current_log_probs, target_probs, reduction="sum")


def compute_value_cloning_loss(old_value, curr_value):
    """Return the summed squared error between current and (detached) old value estimates."""
    error = curr_value - old_value.detach()
    return error.pow(2).sum()


def get_replay_buffer_filled_indices(replay_buffers):
    """Return the flat indices of replay entries that appear to have been populated.

    Per-actor baseline tensors are concatenated, so an index is ``actor * entries_per_actor + entry``.
    TODO: hacky -- an entry is treated as filled when the *mean* of its stored baseline is non-zero. This relies on
    the buffers starting out as zeros and misclassifies a filled entry whose baseline happens to average to 0.
    """
    per_actor_means = [actor_baselines.mean(dim=-1).detach() for actor_baselines in replay_buffers['baseline']]
    flat_means = np.concatenate(per_actor_means, axis=0)
    nonzero_mask = flat_means != 0
    return np.where(nonzero_mask)[0]


def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers: Buffers,
    replay_buffers: Buffers,
    initial_agent_state_buffers,
    replay_agent_state_buffers,
    replay_lock
):
    """Actor process body: repeatedly fill a free rollout buffer and hand it to the learner.

    Loops until a ``None`` index arrives on ``free_queue``. For each index it writes the previous step as t=0,
    records the agent state the rollout starts from, runs ``flags.unroll_length`` environment steps and, when
    ``flags.use_clear`` is set, reservoir-samples the rollout into this actor's slice of the replay buffers before
    putting the index on ``full_queue``.
    """
    try:
        logging.info("Actor %i started.", actor_index)
        timings = prof.Timings()  # Keep track of how fast things are.

        gym_env = create_env(flags)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)
        env = environment.Environment(gym_env)
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, unused_state = model(env_output, agent_state)
        while True:
            index = free_queue.get()
            if index is None:
                break

            # Write old rollout end.
            for key in env_output:
                buffers[key][index][0][0, ...] = env_output[key]
            for key in agent_output:
                buffers[key][index][0][0, ...] = agent_output[key]
            # Agent state at the *start* of this rollout; the learner replays the unroll from here.
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor

            # Do new rollout.
            for t in range(flags.unroll_length):
                timings.reset()

                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                timings.time("model")

                env_output = env.step(agent_output["action"])

                timings.time("step")

                for key in env_output:
                    buffers[key][index][0][t + 1, ...] = env_output[key]
                for key in agent_output:
                    buffers[key][index][0][t + 1, ...] = agent_output[key]

                timings.time("write")

            if flags.use_clear:
                # Each actor is responsible for populating a subset of the replay_buffer (only used in CLEAR)
                # To enable the enormous replay buffer to exist, it's structured as:
                # [num_actors, num_buffers_per_actor, ...]. The "normal" buffer is simply [num_buffers, 1, ...]
                # Generate the ids this actor is responsible, and some meta data required for reservoir sampling
                num_buffers_per_actor = len(replay_buffers['action'][actor_index])

                # Copy the most recent entry into the replay_buffer (vs "buffers" which is recent-only)
                new_entry_reservoir_val = np.random.uniform(0, 1.0)
                to_populate_replay_index = None

                # Get the unfilled entries in the replay buffer, in this actor's index range.
                # TODO: this is not the most efficient way to do this, but ...
                replay_indices = get_replay_buffer_filled_indices(replay_buffers)
                valid_id_set = set(range(num_buffers_per_actor * actor_index, num_buffers_per_actor * (actor_index+1)))
                valid_filled_replay_indices = valid_id_set.intersection(replay_indices)
                unfilled_indices = valid_id_set - valid_filled_replay_indices

                actor_replay_reservoir_vals = replay_buffers['reservoir_val'][actor_index]

                if len(unfilled_indices) > 0:
                    current_replay_index = min(unfilled_indices) % num_buffers_per_actor
                    to_populate_replay_index = current_replay_index
                else:

                    # If we've filled our quota, we need to find something to throw out.
                    reservoir_threshold = actor_replay_reservoir_vals.min()

                    # If our new value is higher than our existing minimum, replace that one with this new data
                    if new_entry_reservoir_val > reservoir_threshold:
                        to_populate_replay_index = int(np.argmin(actor_replay_reservoir_vals))

                # Do the replacement into the buffer, and update the reservoir_vals list
                if to_populate_replay_index is not None:
                    logging.info(f"For actor {actor_index}, replacing replay entry {to_populate_replay_index}")
                    actor_replay_reservoir_vals[to_populate_replay_index][0] = new_entry_reservoir_val

                    # replay_agent_state_buffers is a flat list; get_batch addresses it as
                    # actor * entries_per_actor + entry, so the same layout must be used when writing.
                    flat_replay_index = num_buffers_per_actor * actor_index + to_populate_replay_index

                    # The buffer indices are "locked" via the queues - when one is being processed by act, it's
                    # not available for processing by batch_and_learn, and vice versa. The replay buffers don't have
                    # this capability - they don't know ahead of time which one is going to need to get replaced.
                    # (TODO: ...well technically I guess they do, but it's awkward...)
                    with replay_lock:
                        for key in buffers:
                            if key == 'reservoir_val':
                                continue
                            replay_buffers[key][actor_index][to_populate_replay_index][...] = buffers[key][index]

                        # Store the state the rollout *started* from (not the post-rollout agent_state), matching
                        # what get_batch feeds the learner for the non-replay entries.
                        # TODO: check...not currently using LSTMs so this is undervalidated
                        for i, tensor in enumerate(initial_agent_state_buffers[index]):
                            replay_agent_state_buffers[flat_replay_index][i][...] = tensor

            time.sleep(0.01)  # TODO... this seems to prevent whatever deadlock was causing my with-clear runs to hang at 2M steps...
            full_queue.put(index)

        if actor_index == 0:
            logging.info("Actor %i: %s", actor_index, timings.summary())

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception as e:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise e


def get_batch(
    flags,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: Buffers,
    replay_buffers: Buffers,
    initial_agent_state_buffers,
    replay_agent_state_buffers,
    timings,
    replay_lock,
    lock=threading.Lock()
):
    """Assemble one learner batch from full rollout buffers (plus sampled replay entries under CLEAR).

    Returns:
        (replay_batch, combo_batch, batch, initial_agent_state, num_new_batches) where ``batch`` holds only the
        freshly dequeued rollouts, ``replay_batch`` the sampled replay rollouts (None without CLEAR),
        ``combo_batch`` both concatenated along dim 1 (new first, replay second), and ``initial_agent_state`` the
        per-rollout start states for ``combo_batch`` in the same order. All tensors are moved to ``flags.device``.
    """
    with lock:
        timings.time("lock")
        num_new_batches = int(flags.batch_size * (1 - flags.replay_ratio)) if flags.use_clear else flags.batch_size
        indices = [full_queue.get() for _ in range(num_new_batches)]
        # TODO: if we abort early, indices will contain a None. Just letting that exception kill the thread for now.
        timings.time("dequeue")
    batch = {
        key: torch.stack([buffers[key][m][0] for m in indices], dim=1) for key in buffers  # 0: all normal buffers have 1 entry per buffer
    }
    # Materialize now (torch.cat copies): once the indices go back on free_queue below, actors may overwrite
    # these shared state buffers, so the copy must not be deferred.
    initial_agent_state = tuple(
        torch.cat(ts, dim=1)
        for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
    )
    replay_batch = None

    if flags.use_clear:
        # Select a random batch set of replay buffers to add also. Only select from ones that have been filled
        # This is to make sure stuff is working without needing to pipe through some "populated" marker.
        # But it sure is hacky
        # TODO: probably somewhat broader than necessary, but...testing?
        with replay_lock:
            num_entries_per_actor = len(replay_buffers['baseline'][0])
            replay_indices = get_replay_buffer_filled_indices(replay_buffers)
            replay_entry_count = int(flags.batch_size * flags.replay_ratio)
            shuffled_subset = np.random.choice(replay_indices, replay_entry_count)  # Defaults to replace=True TODO?

            replay_batch = {
                # Get the actor_index and entry_id from the raw id
                key: torch.stack([replay_buffers[key][m//num_entries_per_actor][m % num_entries_per_actor]
                                  for m in shuffled_subset], dim=1) for key in replay_buffers
            }
            # Copied while still holding replay_lock so an actor cannot overwrite the entry mid-read.
            replay_agent_state = tuple(
                torch.cat(ts, dim=1)
                for ts in zip(*[replay_agent_state_buffers[m] for m in shuffled_subset])
            )

        # Combine the replay in with the recent entries
        combo_batch = {
            key: torch.cat((batch[key], replay_batch[key]), dim=1) for key in batch
        }
        # Concatenate each state component along the batch dim (same order as combo_batch); chaining the two
        # tuples would instead produce twice as many state tensors.
        initial_agent_state = tuple(
            torch.cat((new_state, replay_state), dim=1)
            for new_state, replay_state in zip(initial_agent_state, replay_agent_state)
        )
    else:
        combo_batch = batch

    timings.time("batch")
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
    combo_batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in combo_batch.items()}
    replay_batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in replay_batch.items()} if replay_batch is not None else None
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True) for t in initial_agent_state
    )
    timings.time("device")
    return replay_batch, combo_batch, batch, initial_agent_state, num_new_batches


def learn(
    flags,
    actor_model,
    model,
    batch,
    replay_batch,
    batch_for_train,
    initial_agent_state,
    optimizer,
    scheduler,
    lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step.

    Args:
        flags: configuration (discounting, reward_clipping, *_cost weights, grad_norm_clipping, use_clear).
        actor_model: model shared with the actors; receives the learner's weights after the step.
        model: the learner model being optimized.
        batch: freshly collected rollouts only; used for the episode-return statistics.
        replay_batch: replayed rollouts (CLEAR only, otherwise None); used for the cloning losses.
        batch_for_train: rollouts the V-trace losses are computed on (new followed by replay under CLEAR).
        initial_agent_state: per-rollout start states for ``batch_for_train``, concatenated along dim 1.
        optimizer, scheduler: each stepped once per call.
        lock: default instance is shared by all calls, serializing learner threads.

    Returns:
        dict of scalar statistics for logging.
    """
    with lock:
        learner_outputs, unused_state = model(batch_for_train, initial_agent_state)

        # Take final value function slice for bootstrapping.
        bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch_for_train = {key: tensor[1:] for key, tensor in batch_for_train.items()}
        learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()}

        rewards = batch_for_train["reward"]
        if flags.reward_clipping == "abs_one":
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif flags.reward_clipping == "none":
            clipped_rewards = rewards

        discounts = (~batch_for_train["done"]).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch_for_train["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch_for_train["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"],
            batch_for_train["action"],
            vtrace_returns.pg_advantages,
        )
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"]
        )
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"]
        )

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch["episode_return"][batch["done"]]
        stats = {
            "episode_returns": tuple(episode_returns.cpu().numpy()),
            "mean_episode_return": torch.mean(episode_returns).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
            "abs_max_vtrace_advantage": torch.abs(vtrace_returns.pg_advantages).max().item()
        }

        if flags.use_clear:
            # initial_agent_state spans the whole combined batch (replay entries last, along dim 1); the
            # replay-only forward pass must get just that trailing slice so its batch size matches replay_batch.
            # (Slice via an explicit start: "-0:" would select everything when there are no replay entries.)
            num_replay = replay_batch['frame'].shape[1]
            replay_agent_state = tuple(
                state[:, state.shape[1] - num_replay:] for state in initial_agent_state
            )
            replay_learner_outputs, unused_state = model(replay_batch, replay_agent_state)

            replay_batch_policy = replay_batch['policy_logits']
            current_policy = replay_learner_outputs['policy_logits']
            policy_cloning_loss = flags.policy_cloning_cost * compute_policy_cloning_loss(replay_batch_policy, current_policy)

            replay_batch_baseline = replay_batch['baseline']
            current_baseline = replay_learner_outputs['baseline']
            value_cloning_loss = flags.value_cloning_cost * compute_value_cloning_loss(replay_batch_baseline, current_baseline)

            total_loss = total_loss + policy_cloning_loss + value_cloning_loss

            stats["policy_cloning_loss"] = policy_cloning_loss.item()
            stats["value_cloning_loss"] = value_cloning_loss.item()

        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        scheduler.step()

        actor_model.load_state_dict(model.state_dict())
        return stats


def create_buffers(flags, num_buffers, entries_per_buffer, unroll_length, obs_shape, num_actions) -> Buffers:
    """Allocate ``num_buffers`` file-backed shared tensors for every rollout field.

    Each tensor has shape ``(entries_per_buffer, unroll_length + 1, ...)`` (``reservoir_val`` is
    ``(entries_per_buffer, 1)``) and lives in a temporary file under ``flags.large_file_path`` mapped with
    ``shared=True``. Storage backed by a freshly created file starts zeroed, which
    get_replay_buffer_filled_indices relies on to detect unfilled entries.
    Backing files on disk let the replay buffer grow very large (e.g. 5M+ entries).
    TODO: a location is specified because temp files did not always appear to get deleted and filled the drive.
    """
    steps = unroll_length + 1
    specs = {
        "frame": dict(size=(entries_per_buffer, steps, *obs_shape), dtype=torch.uint8),
        "reward": dict(size=(entries_per_buffer, steps), dtype=torch.float32),
        "done": dict(size=(entries_per_buffer, steps), dtype=torch.bool),
        "episode_return": dict(size=(entries_per_buffer, steps), dtype=torch.float32),
        "episode_step": dict(size=(entries_per_buffer, steps), dtype=torch.int32),
        "policy_logits": dict(size=(entries_per_buffer, steps, num_actions), dtype=torch.float32),
        "baseline": dict(size=(entries_per_buffer, steps), dtype=torch.float32),
        "last_action": dict(size=(entries_per_buffer, steps), dtype=torch.int64),
        "action": dict(size=(entries_per_buffer, steps), dtype=torch.int64),
        "reservoir_val": dict(size=(entries_per_buffer, 1), dtype=torch.float32),
    }
    # dtype -> (typed storage class used to map the file, tensor class wrapping that storage)
    storage_and_tensor_classes = {
        torch.uint8: (torch.ByteStorage, torch.ByteTensor),
        torch.int32: (torch.IntStorage, torch.IntTensor),
        torch.int64: (torch.LongStorage, torch.LongTensor),
        torch.bool: (torch.BoolStorage, torch.BoolTensor),
        torch.float32: (torch.FloatStorage, torch.FloatTensor),
    }

    buffers: Buffers = {name: [] for name in specs}
    for _ in range(num_buffers):
        for name, spec in specs.items():
            backing_file = tempfile.NamedTemporaryFile(dir=flags.large_file_path)

            shape = spec["size"]
            element_count = 1
            for extent in shape:
                element_count *= extent

            storage_cls, tensor_cls = storage_and_tensor_classes[spec["dtype"]]
            mapped_storage = storage_cls.from_file(backing_file.name, shared=True, size=element_count)
            buffers[name].append(tensor_cls(mapped_storage).view(shape).share_memory_())
    return buffers


def train(flags, existing_replay_buffers=None):  # pylint: disable=too-many-branches, too-many-statements
    """Run IMPALA (optionally CLEAR) training as a generator.

    Forks ``flags.num_actors`` actor processes and starts ``flags.num_learner_threads`` learner threads. The main
    loop polls every ~5 s; whenever the step count has advanced since the last yield it checkpoints, suspends the
    actors, yields ``(stats, replay_buffers)`` and resumes the actors when iteration continues (or the generator
    is closed).

    Args:
        flags: configuration namespace; ``xpid``, ``num_buffers`` and ``device`` may be filled in here.
        existing_replay_buffers: replay buffers to reuse; when None and ``flags.use_clear`` is set, new ones are
            allocated.
    """
    if flags.xpid is None:
        flags.xpid = "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    plogger = file_writer.FileWriter(
        xpid=flags.xpid, xp_args=flags.__dict__, rootdir=flags.savedir
    )
    checkpointpath = os.path.expandvars(
        os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
    )

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    env = create_env(flags)

    model = Net(env.observation_space.shape, env.action_space.n, flags.use_lstm)
    buffers = create_buffers(flags, flags.num_buffers, 1, flags.unroll_length, env.observation_space.shape,
                             model.num_actions)

    try:
        checkpoint_state = torch.load(checkpointpath, map_location="cpu")
        model.load_state_dict(checkpoint_state["model_state_dict"])
    except FileNotFoundError:
        logging.warning("Failed to load existing model (possibly nothing to load)")

    if existing_replay_buffers is None and flags.use_clear:
        num_entries_per_buffer = flags.replay_buffer_size // flags.num_actors
        replay_buffers = create_buffers(flags, flags.num_actors, num_entries_per_buffer, flags.unroll_length,
                                        env.observation_space.shape, model.num_actions)
    else:
        replay_buffers = existing_replay_buffers

    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    replay_agent_state_buffers = []
    for _ in range(flags.replay_buffer_size):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        replay_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()
    # Actors write the replay buffers from forked *processes* while learner threads read them here, so this must
    # be a process-shared lock: a threading.Lock is merely copied into each forked child and excludes nothing.
    replay_lock = ctx.Lock()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                replay_buffers,
                initial_agent_state_buffers,
                replay_agent_state_buffers,
                replay_lock
            ),
        )
        actor.start()
        actor_processes.append(actor)

    learner_model = Net(
        env.observation_space.shape, env.action_space.n, flags.use_lstm
    ).to(device=flags.device)

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    def lr_lambda(epoch):
        # TODO: using the original B, which means that for CLEAR it's not quite right...but also I'm not sure what it should
        # be for CLEAR in the first place.
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger = logging.getLogger("logfile")
    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
    ]
    logger.info("# Step\t%s", "\t".join(stat_keys))

    step, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        try:
            nonlocal step, stats
            timings = prof.Timings()
            while step < flags.total_steps:
                timings.reset()
                replay_batch, batch_for_train, batch, agent_state, num_new_batches = get_batch(
                    flags,
                    free_queue,
                    full_queue,
                    buffers,
                    replay_buffers,
                    initial_agent_state_buffers,
                    replay_agent_state_buffers,
                    timings,
                    replay_lock
                )
                stats = learn(
                    flags, model, learner_model, batch, replay_batch, batch_for_train, agent_state, optimizer, scheduler
                )
                timings.time("learn")
                with lock:
                    to_log = dict(step=step)
                    to_log.update({k: stats[k] for k in stat_keys})
                    plogger.log(to_log)
                    step += T * num_new_batches

            if i == 0:
                logging.info("Batch and learn: %s", timings.summary())

        except KeyboardInterrupt:
            pass  # Return silently.
        except Exception as e:
            logging.error("Exception in batch and learn process %i", i)
            traceback.print_exc()
            print()
            raise e

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn, name="batch-and-learn-%d" % i, args=(i,)
        )
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        last_returned_step = None
        while step < flags.total_steps:
            start_step = step
            start_time = timer()
            time.sleep(5)

            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            # Copy right away, because there's a race where stats can get re-set and then certain things set below
            # will be missing (eg "step")
            stats_to_return = copy.deepcopy(stats)

            sps = (step - start_step) / (timer() - start_time)
            if stats_to_return.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats_to_return["mean_episode_return"]
                )
            else:
                mean_return = ""
            total_loss = stats_to_return.get("total_loss", float("inf"))
            logging.info(
                "Steps %i @ %.1f SPS. Loss %f. %sStats:\n%s",
                step,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats_to_return),
            )
            stats_to_return["step"] = step

            if last_returned_step is None or last_returned_step != step:
                last_returned_step = step
                checkpoint()  # TODO: checkpointing more often...mostly to try get it before the last yield

                # The actors will keep going unless we pause them, so...do that.
                for actor in actor_processes:
                    psutil.Process(actor.pid).suspend()

                try:
                    yield stats_to_return, replay_buffers
                finally:
                    # Resume the actors -- also when the consumer closes the generator at this yield
                    # (GeneratorExit); otherwise they stay suspended and never see the shutdown sentinels below.
                    for actor in actor_processes:
                        psutil.Process(actor.pid).resume()

    except KeyboardInterrupt:
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)
        for _ in range(flags.num_learner_threads * flags.batch_size):
            full_queue.put(None)

    checkpoint()
    plogger.close()


def test(flags, existing_replay_buffers=None, num_episodes: int = 10):
    """Evaluate the checkpointed model; a generator that yields ``(stats, existing_replay_buffers)`` once.

    Runs until at least ``num_episodes`` episodes AND ``flags.total_steps`` steps have elapsed, or until more than
    ``flags.return_after_reward_num`` episodes have finished (when that flag is not None).

    Returns (via the single yield):
        stats: ``{"episode_returns": [...], "step": int, "num_episodes": num_episodes}``.
    """
    if flags.xpid is None:
        checkpointpath = "./latest/model.tar"
    else:
        checkpointpath = os.path.expandvars(
            os.path.expanduser("%s/%s/%s" % (flags.savedir, flags.xpid, "model.tar"))
        )

    gym_env = create_env(flags)
    env = environment.Environment(gym_env)
    model = Net(gym_env.observation_space.shape, gym_env.action_space.n, flags.use_lstm)
    model.eval()
    checkpoint = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    observation = env.initial()
    # Carry recurrent state between steps exactly like act() does (an empty tuple for non-LSTM models).
    agent_state = model.initial_state(batch_size=1)
    returns = []
    step = 0

    assert flags.return_after_reward_num is None or num_episodes >= flags.return_after_reward_num, \
        "Not enough episodes requested to fulfill the requested reward count"

    try:
        # Ensure we've both hit the min number of episodes and the min number of steps
        while (len(returns) < num_episodes or step < flags.total_steps) and \
                not (flags.return_after_reward_num is not None and len(returns) > flags.return_after_reward_num):
            if flags.mode == "test_render":
                env.gym_env.render()
            with torch.no_grad():  # Inference only; no autograd graph needed.
                policy_outputs, agent_state = model(observation, agent_state)
            observation = env.step(policy_outputs["action"])
            step += 1
            if observation["done"].item():
                returns.append(observation["episode_return"].item())
                logging.info(
                    "Episode ended after %d steps. Return: %.1f",
                    observation["episode_step"].item(),
                    observation["episode_return"].item(),
                )
    finally:
        env.close()

    if returns:
        logging.info(
            "Average returns over %i episodes: %.1f", len(returns), sum(returns) / len(returns)
        )
    else:
        # Possible when num_episodes and total_steps are both 0; avoid dividing by zero.
        logging.warning("No episodes completed; no average return to report.")
    stats = {"episode_returns": returns, "step": step, "num_episodes": num_episodes}
    yield stats, existing_replay_buffers  # Just for simple consistency with train


class AtariNet(nn.Module):
    """Convolutional actor-critic network with an optional LSTM core.

    Frames are expected channel-first: ``observation_shape[0]`` is used as the
    conv input channel count.

    The width of the flattened conv output is computed from
    ``observation_shape`` rather than hard-coded. This makes non-84x84 inputs
    work too. For a (C, 84, 84) observation the width is 3136, so parameter
    shapes match the fixed-size architecture and existing checkpoints still
    load.
    """

    def __init__(self, observation_shape, num_actions, use_lstm=False):
        super(AtariNet, self).__init__()
        self.observation_shape = observation_shape
        self.num_actions = num_actions

        # Feature extraction.
        self.conv1 = nn.Conv2d(
            in_channels=self.observation_shape[0],
            out_channels=32,
            kernel_size=8,
            stride=4,
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        # Fully connected layer. Its input width is the flattened conv output
        # for a single observation (3136 for 84x84 frames).
        self.fc = nn.Linear(self._conv_output_size(), 512)

        # FC output size + one-hot of last action + last reward.
        core_output_size = self.fc.out_features + num_actions + 1

        self.use_lstm = use_lstm
        if use_lstm:
            self.core = nn.LSTM(core_output_size, core_output_size, 2)

        self.policy = nn.Linear(core_output_size, self.num_actions)
        self.baseline = nn.Linear(core_output_size, 1)

    def _conv_features(self, x):
        """Apply the three conv+ReLU layers and flatten to shape (N, features)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return x.view(x.shape[0], -1)

    def _conv_output_size(self):
        """Return the flattened conv feature count for one observation.

        This uses a dummy zero input, which does not touch the RNG state.
        """
        with torch.no_grad():
            dummy = torch.zeros(1, *self.observation_shape)
            return self._conv_features(dummy).shape[1]

    def initial_state(self, batch_size):
        """Return zeroed (h, c) LSTM state, or an empty tuple when no LSTM is used."""
        if not self.use_lstm:
            return tuple()
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def forward(self, inputs, core_state=()):
        """Compute policy logits, baseline and an action for a [T, B] batch.

        Args:
            inputs: dict with ``"frame"`` [T, B, C, H, W], ``"last_action"`` [T, B]
                (integer), ``"reward"`` [T, B] and, when using the LSTM,
                ``"done"`` [T, B] (boolean).
            core_state: recurrent state from ``initial_state`` (LSTM only).

        Returns:
            A pair ``(outputs, core_state)``. ``outputs`` holds
            ``policy_logits`` [T, B, A], ``baseline`` [T, B] and ``action``
            [T, B]. Actions are sampled in training mode and chosen by argmax
            in eval mode.
        """
        x = inputs["frame"]  # [T, B, C, H, W].
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0
        x = self._conv_features(x)
        x = F.relu(self.fc(x))

        one_hot_last_action = F.one_hot(
            inputs["last_action"].view(T * B), self.num_actions
        ).float()
        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1)
        core_input = torch.cat([x, clipped_reward, one_hot_last_action], dim=-1)

        if self.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for step_input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * s for s in core_state)
                output, core_state = self.core(step_input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = core_input
            core_state = tuple()

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        return (
            dict(policy_logits=policy_logits, baseline=baseline, action=action),
            core_state,
        )


# Model class used by the rest of this module; e.g. `test` builds its model via `Net(...)`.
Net = AtariNet


def create_env(flags):
    """Construct the wrapped Atari environment named by ``flags.env``.

    The raw environment from ``make_atari`` is passed through ``wrap_deepmind``
    with these settings:

    * rewards are left unclipped;
    * frame stacking is enabled;
    * observations are not scaled.

    The result is then passed through ``wrap_pytorch``.
    """
    raw_env = atari_wrappers.make_atari(flags.env)
    deepmind_env = atari_wrappers.wrap_deepmind(
        raw_env,
        clip_rewards=False,
        frame_stack=True,
        scale=False,
    )
    return atari_wrappers.wrap_pytorch(deepmind_env)


def main(flags):
    """Run ``train`` or ``test`` (chosen by ``flags.mode``) to completion.

    Before dispatching, it overwrites a few options that are not exposed on
    the command line.
    """
    # If I keep this main-able, move these to args (TODO)
    flags.replay_buffer_size = 0
    flags.use_clear = False
    flags.return_after_reward_num = None
    flags.large_file_path = "tmp"

    generator = train(flags) if flags.mode == "train" else test(flags)

    # Both train and test are generators; drain them fully. A for-loop ends
    # cleanly when the generator is exhausted, whereas calling next() directly
    # would raise StopIteration after the final yield.
    for _ in generator:
        pass


if __name__ == "__main__":
    # Script entry point: parse CLI options with the module-level `parser`,
    # then dispatch to train/test via `main`.
    flags = parser.parse_args()
    main(flags)
