"""This file contains a pipeline for running steps in experiments related to reward learning and evaluation.

This pipeline is basically a custom pipeline tool (e.g., like Apache Airflow). I should probably use an existing
pipeline library, but that might be overkill and I don't know any well enough to use them without the risk
of it not working out.

At a high level, the pipeline runs steps necessary for reward learning and reward evaluation in a way that is general
with respect to the environment / MDP. The steps are:
1. (optional) Learning an (expert) policy.
2. Collecting data from that policy or from a random policy.
3. (optional) Learning one or more reward models from that data.
4. Evaluating learned or manually defined reward models using one or more evaluation algorithms.
5. Visualizing the results of those evaluations.

The pipeline as a whole operates on a single high-level config file. That config file references one or more config
files for each step in the pipeline.

This pipeline is designed for the following requirements (nonexhaustive):
1. Each step of the pipeline should be runnable individually (assuming necessary outputs from prior steps are available).
2. The intermediate outputs from each step should be saved to disk in some format to enable (1) among other reasons.
3. The individual step logic should not be specific to the manner in which data is saved between steps.
4. When multiple algorithms run during a step (e.g., multiple reward learning algorithms are to be run at once),
    they should be runnable individually, and their outputs should be writable independent of previous outputs.
5. It should be easy to parallelize over a single computer (but not a cluster) if it comes to it.
6. Within a step, a configuration file completely and uniquely defines which algorithm to run and how to run it.

The configuration files for this pipeline are stored in a particular manner. Here's the expected structure:

> pipeline.yaml (the high-level config that orchestrates the pipeline)
    > env_name_1
        > expert_policy_learning
            > defaults.yaml
            > algo_1.yaml
        > data_generation
            > random.yaml
        > reward_learning
            > regression
                > defaults.yaml
                > regression_config_1.yaml
            > other_algorithm_dir
        > reward_evaluation
            > defaults.yaml
            > evaluation_algo_1.yaml
            > evaluation_algo_2.yaml
        > policy_evaluation
            > defaults.yaml

The env_name is extracted from the config filepath so the configs have to be stored under the env name.
The defaults.yaml is loaded first and then merged with specific algorithms to allow for reusing config values.
"""
from __future__ import annotations

import concurrent.futures
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
import enum
import os
import pathlib
import random
from typing import Dict, List, Optional, Union

import fire
import numpy as np
from omegaconf import OmegaConf
import torch

from offline_rl.scripts.policy_evaluation.run_policy_evaluation import main as run_policy_evaluation
from offline_rl.scripts.rewards.evaluation.run_gym_reward_evaluation import main as run_reward_evaluation
from offline_rl.scripts.rewards.learning.run_gym_reward_learning import main as run_reward_learning
from offline_rl.scripts.rl.run import ExperimentManager
from offline_rl.utils.file_utils import load_json, save_json, get_datetime_string
from offline_rl.utils.rllib_utils import get_latest_rllib_checkpoint


class DemonstrationDatasetType(str, enum.Enum):
    """Enum of the types of demonstration datasets.

    This inherits from str in order to make it json serializable.

    Note that because of the str mixin, the integer literals assigned below are
    coerced to strings when the members are created, so the member values are
    "1", "2" and "3" (and each member compares equal to that string).
    """
    # A dataset collected by an expert policy.
    EXPERT = 1
    # A dataset collected by randomly sampling an action space, _not_ from a randomly init policy.
    RANDOM = 2
    # A dataset that is a mixture of expert and random data. If other types of datasets exist those
    # should not be included in this type of dataset.
    MIXED = 3


# pylint: disable=pointless-string-statement
"""These storage and output classes essentially play the role of a database in a normal pipeline library."""
@dataclass
class StoredDemonstrationDataset:
    """Represents a demonstration dataset stored on the filesystem.

    We define "dataset" here to refer to the train/val datasets used for reward learning
    as well as the dataset used for reward evaluation.

    Args:
        dataset_type: The type of this dataset.
        reward_learning_train_pattern: The filepath pattern of the dataset used for learning a reward model.
        reward_learning_val_pattern: The filepath pattern of the dataset used for validating a reward model.
        reward_evaluation_pattern: The filepath pattern of the dataset used for reward model evaluation.
    """
    dataset_type: DemonstrationDatasetType
    reward_learning_train_pattern: str
    reward_learning_val_pattern: str
    reward_evaluation_pattern: str

    def data(self) -> Dict:
        """Returns the data in this class from which it can be reconstructed.

        Returns:
            A new dict mapping field name to value. It is a shallow copy, so mutating
            the returned dict does not modify this instance.
        """
        # Copy rather than hand out the live ``__dict__``: callers that edit the result
        # (e.g. before serializing it) must not silently change this object's fields.
        return dict(vars(self))


@dataclass
class StoredRewardModel:
    """Represents a reward model stored on the filesystem.

    Args:
        dataset_type: The type of dataset this reward model was trained on.
        config_filepath: Filepath of config file used to train this reward model.
        checkpoint_filepath: Filepath to the saved weights of this reward model.
    """
    dataset_type: DemonstrationDatasetType
    config_filepath: str
    checkpoint_filepath: str

    def data(self) -> Dict:
        """Returns the data in this class from which it can be reconstructed.

        Returns:
            A new dict mapping field name to value. It is a shallow copy, so mutating
            the returned dict does not modify this instance.
        """
        # Copy rather than expose the live ``__dict__`` so that edits made by a caller
        # to the returned mapping cannot leak back into this object.
        return dict(vars(self))


@dataclass
class PolicyLearningOutput:
    """Record of what policy learning produced.

    Wrapping the result in a class hides how policy checkpoints are tracked and loaded.
    Further benefits:
    - A policy can be hard-coded and handed to data generation.
    - Storing a policy on s3 could be added without touching data generation.
    - Policy learning could later emit several checkpoints.

    The on-disk representation is a json file holding the config and checkpoint paths.

    Args:
        config_filepath: The filepath of the rllib config file used to train the policy.
        checkpoint_filepath: The filepath of the checkpoint to restore (rllib).
    """
    config_filepath: str
    checkpoint_filepath: str

    @staticmethod
    def save(
            output_filepath: str,
            config_filepath: str,
            checkpoint_filepath: str,
    ) -> None:
        """Writes the information needed to rebuild this class to a json file.

        Both referenced paths must already exist; this is checked with assertions.

        Args:
            output_filepath: Where to save the data.
            config_filepath: The filepath of the rllib config file used to train the policy.
            checkpoint_filepath: The filepath of the checkpoint to restore (rllib).
        """
        assert os.path.exists(config_filepath)
        assert os.path.exists(checkpoint_filepath)
        payload = {"checkpoint": checkpoint_filepath, "config": config_filepath}
        save_json(output_filepath, payload)

    @classmethod
    def load(cls: type, filepath: str) -> PolicyLearningOutput:
        """Rebuilds this class from a json file written by `save`.

        Args:
            filepath: The json file to read.

        Returns:
            An instance whose paths have been asserted to exist on disk.
        """
        stored = load_json(filepath)
        checkpoint_filepath = stored["checkpoint"]
        assert os.path.exists(checkpoint_filepath)
        config_filepath = stored["config"]
        assert os.path.exists(config_filepath)
        return cls(config_filepath, checkpoint_filepath)


@dataclass
class DataGenerationOutput:
    """Record of what data generation produced.

    Like the other output classes, this persists itself to and restores itself from a
    file, standing in for a database.

    Args:
        stored_expert_dataset: The dataset generated from an expert.
        stored_random_dataset: The dataset generated from a random sampling of the action space.
    """
    stored_expert_dataset: StoredDemonstrationDataset
    stored_random_dataset: StoredDemonstrationDataset

    def get_dataset(self, dataset_type: DemonstrationDatasetType) -> StoredDemonstrationDataset:
        """Looks up the stored dataset for a dataset type.

        Args:
            dataset_type: The type of the dataset to get.

        Returns:
            The stored dataset.

        Raises:
            NotImplementedError: If the type is neither EXPERT nor RANDOM.
        """
        assert isinstance(dataset_type, DemonstrationDatasetType)
        if dataset_type == DemonstrationDatasetType.EXPERT:
            return self.stored_expert_dataset
        if dataset_type == DemonstrationDatasetType.RANDOM:
            return self.stored_random_dataset
        # Any other type (e.g. MIXED) has no stored dataset here.
        raise NotImplementedError()

    @staticmethod
    def save(
            output_filepath: str,
            expert_dataset: StoredDemonstrationDataset,
            random_dataset: StoredDemonstrationDataset,
    ) -> None:
        """Writes both datasets' reconstruction data to a json file.

        Args:
            output_filepath: Where to save the data.
            expert_dataset: The expert dataset to save.
            random_dataset: The random dataset to save.
        """
        payload = {}
        payload["expert"] = expert_dataset.data()
        payload["random"] = random_dataset.data()
        save_json(output_filepath, payload)

    @classmethod
    def load(cls: type, filepath: str) -> DataGenerationOutput:
        """Rebuilds this class from a json file written by `save`.

        Args:
            filepath: The filepath from which to load the class contents.

        Returns:
            An instantiation of this class loaded from the file.
        """
        stored = load_json(filepath)
        expert = StoredDemonstrationDataset(**stored["expert"])
        rand = StoredDemonstrationDataset(**stored["random"])
        return cls(expert, rand)


@dataclass
class RewardLearningOutput:
    """Record of what reward learning produced.

    Args:
        models: A mapping from reward learning algorithm name to a stored reward model.
    """
    models: Dict[str, StoredRewardModel]

    @staticmethod
    def save(output_filepath: str, models: Dict[str, StoredRewardModel]) -> None:
        """Persists models so they can be loaded later.

        When the output file is already present its contents are kept, and the given
        models are added to it (replacing entries that share a name).

        Args:
            output_filepath: Where to save the data.
            models: The models to save. This may be a partial set of models.
        """
        # Start from whatever was saved before so partial runs accumulate.
        merged = load_json(output_filepath) if os.path.exists(output_filepath) else dict()
        for model_name, stored_model in models.items():
            merged[model_name] = stored_model.data()
        save_json(output_filepath, merged)

    @classmethod
    def load(cls: type, filepath: str) -> RewardLearningOutput:
        """Rebuilds this class from a json file written by `save`.

        Args:
            filepath: File from which to load this class.

        Returns:
            This class with models populated.
        """
        stored = load_json(filepath)
        restored = {}
        for model_name, fields in stored.items():
            restored[model_name] = StoredRewardModel(**fields)
        return cls(restored)


class RewardEvaluationOutput:
    """Placeholder for the output of reward evaluation; currently has no fields or methods."""


class Pipeline:
    """The pipeline that orchestrates the individual steps.
        
    Args:
        experiment_dir: The directory in which all experiments outputs should be stored.
        debug: If True, runs all steps in a debug mode such that it runs quickly.
            This is useful for checking for bugs.
        seed: Random seed to use for experiments.
    """
    # The strings corresponding to the steps in order.
    ORDERED_STEPS = [
        "expert_policy_learning",
        "data_generation",
        "reward_learning",
        "reward_evaluation",
        "arbitrary_reward_policy_learning",
        "policy_evaluation",
        "visualization",
    ]

    # Debug constants. These replace the configured values when `self.debug` is True.
    # Used as the "stop.timesteps_total" override for rllib policy training.
    DEBUG_POLICY_TRAINING_TIMESTEPS = 10000
    # Used for every collected dataset size in data generation, and as `data.debug_size`
    # in reward learning.
    DEBUG_DATASET_SIZE = 1000
    # Written to `data.debug_size_mode` in reward learning and reward evaluation configs.
    DEBUG_DATASET_LOADING_MODE = "ordered"
    # Written to `training.trainer_args.max_epochs` in reward learning configs.
    DEBUG_REWARD_TRAINING_EPOCHS = 2
    # Written to `data.debug_size` in reward evaluation configs.
    DEBUG_REWARD_EVALUATION_SIZE = 1000
    # NOTE(review): presumably consumed by the policy evaluation step; confirm against its usage.
    DEBUG_POLICY_EVALUATION_STEPS = 1000

    def __init__(self, experiment_dir: str, debug: bool = False, seed: int = 0) -> None:
        """Stores the run settings, stamps this execution with a unique id and sets the seed.

        Args:
            experiment_dir: The directory in which all experiments outputs should be stored.
            debug: If True, runs all steps in a debug mode such that it runs quickly.
            seed: Random seed to use for experiments.
        """
        # Base directory; the `experiment_dir` property appends the seed to it.
        self._experiment_dir = experiment_dir
        self.debug = debug
        # Use datetime as a unique identifier for each execution of any step in the pipeline.
        self.unique_id = get_datetime_string()
        # Seed set internally
        # NOTE(review): presumably `_set_seed` also stores the value on `self.seed`, which
        # `experiment_dir` and `gpu_index` read — confirm against its definition.
        self._set_seed(seed)

    @property
    def experiment_dir(self):
        """Returns the experiment directory with the random seed included in the path."""
        return os.path.join(self._experiment_dir, f"{self.seed:04d}")

    @property
    def gpu_index(self) -> int:
        """Returns the index of the gpu to use for this experiment.

        This assumes that each experiment should run on a single gpu.
        """
        num_gpus = torch.cuda.device_count()
        gpu_index = self.seed % num_gpus
        return gpu_index

    def _set_gpu_environ(self) -> None:
        """Sets the gpu environment variables."""
        # This works for ray I believe b/c ray.init is what checks this environment variable.
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{self.gpu_index}"

    def _get_policy_learning_dir(self, env_name: str) -> str:
        """Gets the directory where policy learning outputs are stored."""
        return os.path.join(self.experiment_dir, env_name, "expert_policy_learning")

    def _get_policy_learning_output_filepath(self, env_name: str) -> str:
        """Gets the file where policy learning outputs are tracked."""
        return os.path.join(self._get_policy_learning_dir(env_name), "output.json")

    def run_policy_learning(self, config_filepath: str, env_name: str) -> None:
        """Runs the policy learning step.

        This step saves a file containing references to its outputs that the next step can load.
        That file is saved in a fixed location (relative to `self.experiment_dir`) and loaded as
        such from subsequent steps.

        Args:
            config_filepath: Filepath to rllib config file used to train the policy.
            env_name: Name of the environment to train in.
        """
        manager = ExperimentManager(self.experiment_dir)
        output_dir = self._get_policy_learning_dir(env_name)
        self._set_gpu_environ()
        config_overrides = {}
        if self.debug:
            config_overrides["stop.timesteps_total"] = self.DEBUG_POLICY_TRAINING_TIMESTEPS
            config_overrides["config.num_workers"] = 0
        manager.train(config_filepath, output_dir=output_dir, **config_overrides)
        checkpoint_filepath = get_latest_rllib_checkpoint(output_dir)

        assert checkpoint_filepath is not None, "No policy checkpoint output from rllib policy training."
        output_filepath = self._get_policy_learning_output_filepath(env_name)
        PolicyLearningOutput.save(output_filepath, config_filepath, checkpoint_filepath)

    def _get_data_generation_dir(self, env_name: str) -> str:
        """Gets the directory where data generation outputs are stored."""
        return os.path.join(self.experiment_dir, env_name, "data_generation")

    def _get_data_generation_output_filepath(self, env_name: str) -> str:
        """Gets the file where data generation outputs are tracked."""
        return os.path.join(self._get_data_generation_dir(env_name), "output.json")

    def _run_data_generation_for_dataset(
            self,
            dataset_type: DemonstrationDatasetType,
            config_filepath: str,
            env_name: str,
            reward_learning_train_size: int,
            reward_learning_val_size: int,
            reward_evaluation_size: int,
            config_overrides: Dict,
    ) -> StoredDemonstrationDataset:
        """Runs data generation according to an abitrary config file."""
        if self.debug:
            reward_learning_train_size = self.DEBUG_DATASET_SIZE
            reward_learning_val_size = self.DEBUG_DATASET_SIZE
            reward_evaluation_size = self.DEBUG_DATASET_SIZE

        manager = ExperimentManager(self.experiment_dir)
        output_dir = self._get_data_generation_dir(env_name)
        dataset_dir = os.path.join(output_dir, self.unique_id, dataset_type.name.lower())

        reward_learning_train_dir = os.path.join(dataset_dir, "reward_learning_train")
        manager.collect(
            config_filepath,
            size=reward_learning_train_size,
            dataset_dir=reward_learning_train_dir,
            **config_overrides,
        )

        reward_learning_val_dir = os.path.join(dataset_dir, "reward_learning_val")
        manager.collect(
            config_filepath,
            size=reward_learning_val_size,
            dataset_dir=reward_learning_val_dir,
            **config_overrides,
        )

        reward_evaluation_dir = os.path.join(dataset_dir, "reward_evaluation")
        manager.collect(
            config_filepath,
            size=reward_evaluation_size,
            dataset_dir=reward_evaluation_dir,
            **config_overrides,
        )

        return StoredDemonstrationDataset(
            dataset_type,
            reward_learning_train_dir,
            reward_learning_val_dir,
            reward_evaluation_dir,
        )

    def _get_num_steps_from_rllib_config(self, config_filepath: str) -> int:
        """Extracts the number of training steps from an rllib config."""
        config = OmegaConf.load(config_filepath)
        assert len(config) == 1, "rllib config should have one top-level key."
        key = next(iter(config))
        # pylint: disable=unsubscriptable-object
        num_steps = config[key].stop.timesteps_total
        assert isinstance(num_steps, int)
        return num_steps

    def _run_expert_data_generation(
            self,
            env_name: str,
            reward_learning_train_size: int,
            reward_learning_val_size: int,
            reward_evaluation_size: int,
    ) -> StoredDemonstrationDataset:
        """Runs data generation with an expert policy."""
        policy_learning_output_filepath = self._get_policy_learning_output_filepath(env_name)
        assert os.path.exists(
            policy_learning_output_filepath), "Expert policy must exist for collection of expert dataset"
        policy_learning_output = PolicyLearningOutput.load(policy_learning_output_filepath)

        config_overrides = {
            "restore": policy_learning_output.checkpoint_filepath,
            "config.lr": 0.0,
        }

        # Get the number of training timesteps to add to the sizes to collect.
        num_training_steps = self._get_num_steps_from_rllib_config(policy_learning_output.config_filepath)

        return self._run_data_generation_for_dataset(
            DemonstrationDatasetType.EXPERT,
            policy_learning_output.config_filepath,
            env_name,
            reward_learning_train_size + num_training_steps,
            reward_learning_val_size + num_training_steps,
            reward_evaluation_size + num_training_steps,
            config_overrides,
        )

    def _run_random_data_generation(
            self,
            env_name: str,
            config_filepath: str,
            reward_learning_train_size: int,
            reward_learning_val_size: int,
            reward_evaluation_size: int,
    ) -> StoredDemonstrationDataset:
        """Runs data generation with a random policy."""
        return self._run_data_generation_for_dataset(
            DemonstrationDatasetType.RANDOM,
            config_filepath,
            env_name,
            reward_learning_train_size,
            reward_learning_val_size,
            reward_evaluation_size,
            config_overrides={},
        )

    def run_data_generation(
            self,
            env_name: str,
            random_config_filepath: str,
            reward_learning_train_size: int,
            reward_learning_val_size: int,
            reward_evaluation_size: int,
    ) -> None:
        """Runs the data generation step.

        This step always runs data generation for both a random and expert policy.
        The reason for this is that generating both datasets is probably necessary, and only
        doing it based on whether one or the other is strictly necessary complicates matters
        to an extent that it's not worth handling.

        Args:
            env_name: Name of the env for which to generate data.
            random_config_filepath: Filepath of config for generating random data.
                The expert config filepath is already stored in a known location.
            reward_learning_train_size: Size of dataset to collect for reward learning training.
            reward_learning_val_size: Size of dataset to collect for reward learning validation.
            reward_evaluation_size: Size of dataset to collect for reward evaluation.
        """
        stored_expert_dataset = self._run_expert_data_generation(
            env_name,
            reward_learning_train_size,
            reward_learning_val_size,
            reward_evaluation_size,
        )
        stored_random_dataset = self._run_random_data_generation(
            env_name,
            random_config_filepath,
            reward_learning_train_size,
            reward_learning_val_size,
            reward_evaluation_size,
        )
        DataGenerationOutput.save(
            self._get_data_generation_output_filepath(env_name),
            stored_expert_dataset,
            stored_random_dataset,
        )

    def _overwite_reward_learning_datasets(
            self,
            env_name: str,
            conf: OmegaConf,
            dataset_type: DemonstrationDatasetType,
    ) -> OmegaConf:
        """Overwrites the dataset paths in the provided config for the reward lerning case."""
        data_generation_output = DataGenerationOutput.load(self._get_data_generation_output_filepath(env_name))
        stored_dataset = data_generation_output.get_dataset(dataset_type)
        conf.data.train_dataset_filepath = stored_dataset.reward_learning_train_pattern
        conf.data.val_dataset_filepath = stored_dataset.reward_learning_val_pattern
        return conf

    def _get_reward_learning_dir(self, env_name: str) -> str:
        """Gets the directory where reward learning outputs are stored."""
        return os.path.join(self.experiment_dir, env_name, "reward_learning")

    def _get_reward_learning_output_filepath(self, env_name: str) -> str:
        """Gets the file where reward learning outputs are tracked."""
        return os.path.join(self._get_reward_learning_dir(env_name), "output.json")

    def run_reward_learning(self, algo_config_filepath: str, env_name: str, dataset_type: str, algo_name: str) -> None:
        """Runs reward learning for single reward learning algorithm on a specified dataset.

        Args:
            algo_config_filepath: Filepath to reward learning algorithm to run.
            env_name: Name of environment on which to run.
            dataset_type: Type of the dataset to learn on.
            algo_name: What to call the resulting reward model (probably depends on both algo and dataset).
        """
        algo_config = self._load_config_with_defaults(algo_config_filepath)
        if self.debug:
            algo_config.training.trainer_args.max_epochs = self.DEBUG_REWARD_TRAINING_EPOCHS
            algo_config.data.debug_size = self.DEBUG_DATASET_SIZE
            algo_config.data.debug_size_mode = self.DEBUG_DATASET_LOADING_MODE

        dataset_type = DemonstrationDatasetType[dataset_type.upper()]
        algo_config = self._overwite_reward_learning_datasets(env_name, algo_config, dataset_type)

        # Only run on one gpu.
        algo_config.training.trainer_args.gpus = [self.gpu_index]

        algo_output_dir = os.path.join(self._get_reward_learning_dir(env_name), algo_name)
        algo_config.training.output_dir = algo_output_dir

        best_model_checkpoint_filepath = run_reward_learning(algo_config)

        RewardLearningOutput.save(
            self._get_reward_learning_output_filepath(env_name), {
                algo_name:
                StoredRewardModel(
                    dataset_type,
                    os.path.join(algo_output_dir, "config.yaml"),
                    best_model_checkpoint_filepath,
                )
            })

    def multirun_reward_learning(self, config_filepath: str, env_name: str) -> None:
        """Runs reward learning for all algorithms listed in a meta config file.

        Saves results in a predefined location (relative to self.experiment_dir).

        Args:
            config_filepath: The meta config filepath.
            env_name: Name of env to run on.
        """
        config = OmegaConf.load(config_filepath)
        env_config = self._get_config_section_for_env(config, env_name)
        for algo_meta_config in env_config.reward_learning.algorithms:
            self.run_reward_learning(
                self._expand_relative_config_path(
                    config_filepath,
                    env_name,
                    algo_meta_config.config,
                ),
                env_name,
                algo_meta_config.dataset_type,
                algo_meta_config.name,
            )

    def _overwite_reward_evaluation_dataset(
            self,
            env_name: str,
            conf: OmegaConf,
            dataset_type: DemonstrationDatasetType,
    ) -> OmegaConf:
        """Overwrites dataset paths in reward evaluation config."""
        data_generation_output = DataGenerationOutput.load(self._get_data_generation_output_filepath(env_name))
        stored_dataset = data_generation_output.get_dataset(dataset_type)
        conf.data.dataset_filepath = stored_dataset.reward_evaluation_pattern
        return conf

    def _overwrite_reward_evaluation_learned_reward_models(self, env_name: str, conf: OmegaConf) -> OmegaConf:
        """Overwrites the paths of the learned reward models in a reward evaluation config."""
        reward_learning_output = RewardLearningOutput.load(self._get_reward_learning_output_filepath(env_name))
        conf.rewards.learned = {model_name: model.data() for model_name, model in reward_learning_output.models.items()}
        return conf

    def _get_reward_evaluation_dir(self, env_name: str) -> str:
        """Gets the directory where reward evaluation outputs are stored."""
        return os.path.join(self.experiment_dir, env_name, "reward_evaluation")

    def _get_reward_evaluation_output_filepath(self, env_name: str) -> str:
        """Gets the file where reward evaluation outputs are tracked."""
        return os.path.join(self._get_reward_evaluation_dir(env_name), "output.json")

    def run_reward_evaluation(
            self,
            algo_config_filepath: str,
            env_name: str,
            dataset_type: str,
            algo_name: str,
    ) -> None:
        """Runs reward evaluation for a single evaluation algorithm.

        Args:
            algo_config_filepath: Filepath to reward evaluation algorithm to run.
            env_name: Name of env to run evaluation on.
            dataset_type: Type of the dataset to run evaluation on.
            algo_name: What to call this evaluation result (probably depends on algo and dataset).
        """
        algo_config = self._load_config_with_defaults(algo_config_filepath)

        if self.debug:
            algo_config.data.debug_size = self.DEBUG_REWARD_EVALUATION_SIZE
            algo_config.data.debug_size_mode = self.DEBUG_DATASET_LOADING_MODE

        dataset_type = DemonstrationDatasetType[dataset_type.upper()]
        algo_config = self._overwite_reward_evaluation_dataset(env_name, algo_config, dataset_type)
        algo_config = self._overwrite_reward_evaluation_learned_reward_models(env_name, algo_config)

        # Run evaluation only on this gpu index.
        algo_config.common.device = f"cuda:{self.gpu_index}"

        algo_output_dir = os.path.join(self._get_reward_evaluation_dir(env_name), algo_name)
        algo_config.visualization.output_dir = algo_output_dir
        run_reward_evaluation(algo_config)

    def multirun_reward_evaluation(self, config_filepath: str, env_name: str) -> None:
        """Runs reward evaluation of multiple algorithms based on a meta config file.

        Args:
            config_filepath: Config file specifying which algorithms to run on which datasets.
            env_name: Name of env to run evaluation on.
        """
        config = OmegaConf.load(config_filepath)
        env_config = self._get_config_section_for_env(config, env_name)
        for algo_meta_config in env_config.reward_evaluation.algorithms:
            self.run_reward_evaluation(
                self._expand_relative_config_path(
                    config_filepath,
                    env_name,
                    algo_meta_config.config,
                ),
                env_name,
                algo_meta_config.dataset_type,
                algo_meta_config.name,
            )

    def _overwrite_arbitrary_reward_policy_learning_learned_reward_models(
            self,
            env_name: str,
            conf: OmegaConf,
            reward_model_name: str,
    ) -> OmegaConf:
        """Overwrites the paths of the learned reward models in a policy learning config."""
        reward_learning_output = RewardLearningOutput.load(self._get_reward_learning_output_filepath(env_name))
        assert len(conf.keys()) == 1, "Policy learning config should have one top-level key"
        key = next(iter(conf.keys()))
        conf[key].config.reward_models.rewards.learned = {
            model_name: model.data()
            for model_name, model in reward_learning_output.models.items()
        }
        conf[key].config.reward_models.model_name = reward_model_name
        return conf

    def _get_arbitrary_reward_policy_learning_dir(self, env_name: str) -> str:
        """Gets the directory where reward evaluation outputs are stored."""
        return os.path.join(self.experiment_dir, env_name, "arbitrary_reward_policy_learning")

    def run_arbitrary_reward_policy_learning(self, config_filepath: str, env_name: str, reward_model_name: str) -> None:
        """Runs policy learning on an arbitrary reward model.

        Args:
            config_filepath: Config file for rllib-based policy learning.
            env_name: Name of env to run evaluation on.
            reward_model_name: The name of the reward model to use for policy learning.
        """
        config = OmegaConf.load(config_filepath)
        config = self._overwrite_arbitrary_reward_policy_learning_learned_reward_models(
            env_name,
            config,
            reward_model_name,
        )

        basedir = self._get_arbitrary_reward_policy_learning_dir(env_name)
        output_dir = os.path.join(basedir, reward_model_name)
        os.makedirs(output_dir, exist_ok=True)
        # Save the config to a filepath to pass to the manager.
        policy_learning_config_filepath = os.path.join(output_dir, "policy_learning.yaml")
        OmegaConf.save(config, policy_learning_config_filepath)
        self._set_gpu_environ()
        manager = ExperimentManager(self.experiment_dir)
        config_overrides = {}
        if self.debug:
            config_overrides["stop.timesteps_total"] = self.DEBUG_POLICY_TRAINING_TIMESTEPS
            config_overrides["config.num_workers"] = 0
        manager.train(policy_learning_config_filepath, output_dir=output_dir, **config_overrides)
        checkpoint_filepath = get_latest_rllib_checkpoint(output_dir)

        assert checkpoint_filepath is not None, "No policy checkpoint output from rllib policy training."
        output_filepath = os.path.join(output_dir, "policy_learning_result.json")
        PolicyLearningOutput.save(output_filepath, policy_learning_config_filepath, checkpoint_filepath)

    def multirun_arbitrary_reward_policy_learning(self, config_filepath: str, env_name: str) -> None:
        """Learns one policy per reward listed for an env in the pipeline config.

        The env's expert policy learning config is used as the base config for every run.

        Args:
            config_filepath: Pipeline config file.
            env_name: Name of the env whose configured rewards should be used for policy learning.
        """
        env_section = self._get_config_section_for_env(OmegaConf.load(config_filepath), env_name)

        # Every run starts from the expert's policy learning config.
        base_policy_config_filepath = self._expand_relative_config_path(
            config_filepath,
            env_name,
            env_section.expert_policy_learning.config,
        )

        # One policy learning run per configured reward model.
        reward_names = env_section.arbitrary_reward_policy_learning.rewards
        for reward_name in reward_names:
            self.run_arbitrary_reward_policy_learning(base_policy_config_filepath, env_name, reward_name)

    def _get_policy_evaluation_dir(self, env_name: str) -> str:
        """Gets the directory where reward evaluation outputs are stored."""
        return os.path.join(self.experiment_dir, env_name, "policy_evaluation")

    def run_policy_evaluation(self, config_filepath: str, env_name: str, policy_name: str) -> None:
        """Evaluates a single policy produced by arbitrary-reward policy learning.

        If no policy learning result file exists for the policy, a message is printed and nothing is evaluated.
        Otherwise the evaluation config is filled in with the policy's checkpoint, config, and training step count,
        saved as `config.yaml` in the evaluation output directory, and passed to `run_policy_evaluation`.

        Args:
            config_filepath: Config file for policy evaluation.
            env_name: Name of env to run evaluation on.
            policy_name: The name of the policy to evaluate.
        """
        # Locate the policy learning result of the requested policy.
        result_path = (
            pathlib.Path(self._get_arbitrary_reward_policy_learning_dir(env_name))
            / policy_name
            / "policy_learning_result.json"
        )
        if not result_path.exists():
            print(f"Skipping policy evaluation of policy {policy_name} because output is missing.")
            return
        learned_policy = PolicyLearningOutput.load(str(result_path))
        # Policy evaluation needs to know how many timesteps the policy was trained for.
        trained_steps = self._get_num_steps_from_rllib_config(learned_policy.config_filepath)

        self._set_gpu_environ()

        # Fill in the evaluation config.
        eval_config = OmegaConf.load(config_filepath)
        eval_config.policy.checkpoint_filepath = learned_policy.checkpoint_filepath
        eval_config.policy.config_filepath = learned_policy.config_filepath
        eval_config.policy_evaluation.num_training_steps = trained_steps
        if self.debug:
            # Debug mode replaces both step counts with the debug value.
            eval_config.policy_evaluation.num_training_steps = self.DEBUG_POLICY_EVALUATION_STEPS
            eval_config.policy_evaluation.num_evaluation_steps = self.DEBUG_POLICY_EVALUATION_STEPS
        eval_output_dir = os.path.join(self._get_policy_evaluation_dir(env_name), policy_name)
        eval_config.output_dir = eval_output_dir

        # Save the final config next to the outputs, then run the evaluation.
        os.makedirs(eval_output_dir, exist_ok=True)
        OmegaConf.save(eval_config, os.path.join(eval_output_dir, "config.yaml"))
        run_policy_evaluation(eval_config)

    def multirun_policy_evaluation(self, config_filepath: str, env_name: str) -> None:
        """Runs policy evaluation on each policy listed for an env in the pipeline config.

        Args:
            config_filepath: Pipeline config file listing the policies to evaluate and the evaluation config to use.
            env_name: Name of env to run evaluation on.
        """
        env_section = self._get_config_section_for_env(OmegaConf.load(config_filepath), env_name)
        evaluation_section = env_section.policy_evaluation
        evaluation_config_filepath = self._expand_relative_config_path(
            config_filepath,
            env_name,
            evaluation_section.config,
        )
        policy_names = env_section.policy_evaluation.policies
        for name in policy_names:
            self.run_policy_evaluation(evaluation_config_filepath, env_name, name)

    def run_visualization(self) -> None:
        pass

    @staticmethod
    def _expand_relative_config_path(base_config_filepath: str, env_name: str, relative_config_filepath: str) -> str:
        """Expands a relative config path based on a base config filepath."""
        basedir = os.path.dirname(base_config_filepath)
        return os.path.join(basedir, env_name, relative_config_filepath)

    @staticmethod
    def _get_config_section_for_env(config: OmegaConf, env_name: str) -> OmegaConf:
        """Gets the section of the provided config associated with the provided env_name."""
        filtered_envs = [env for env in config.envs if env.name == env_name]
        assert len(filtered_envs) == 1, f"Expected one matching env with name {env_name}, but got {len(filtered_envs)}"
        env_config = filtered_envs[0]
        return env_config

    @staticmethod
    def _load_config_with_defaults(config_filepath: str) -> OmegaConf:
        """Loads a config file, layered on top of an optional `defaults.yaml` in the same directory.

        The expected file structure is:

        > foo
            > config_file_to_load.yaml
            > defaults.yaml

        If `defaults.yaml` exists, it is loaded first and then merged with config_file_to_load.yaml, so that values
        from the latter take precedence for shared parameters. If it does not exist, an empty config is used as
        the base.

        Args:
            config_filepath: Filepath of the config file to load.

        Returns:
            The merged configuration.
        """
        defaults_filepath = os.path.join(os.path.dirname(config_filepath), "defaults.yaml")
        if os.path.exists(defaults_filepath):
            base_config = OmegaConf.load(defaults_filepath)
        else:
            base_config = OmegaConf.create()

        specific_config = OmegaConf.load(config_filepath)
        return OmegaConf.merge(base_config, specific_config)

    def _step_is_between_inclusive(self, step: str, start: str, stop: str) -> bool:
        assert step in self.ORDERED_STEPS
        assert start in self.ORDERED_STEPS
        assert stop in self.ORDERED_STEPS
        step_index = self.ORDERED_STEPS.index(step)
        start_index = self.ORDERED_STEPS.index(start)
        stop_index = self.ORDERED_STEPS.index(stop)
        assert start_index <= stop_index
        return step_index >= start_index and step_index <= stop_index

    def _set_seed(self, seed: int) -> None:
        """Sets the random seed on this class.

        This also sets various global random seeds.

        Args:
            seed: The random seed to set.
        """
        self.seed = seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        random.seed(self.seed)


def run_pipeline(
        experiment_dir: str,
        config_filepath: str,
        env_name: str,
        seed: int = 0,
        start: str = "expert_policy_learning",
        stop: str = "visualization",
        debug: bool = False,
) -> None:
    """Runs the pipeline steps between `start` and `stop` (both inclusive) for one environment.

    Steps are attempted in this order: expert policy learning, data generation, reward learning,
    reward evaluation, arbitrary reward policy learning, policy evaluation, and visualization.
    Steps outside of the requested range are skipped.

    Args:
        experiment_dir: The directory in which all experiments outputs should be stored.
        config_filepath: Filepath of pipeline config.
        env_name: Name of env to run pipeline for.
        seed: The random seed to use for the full pipeline of experiments.
        start: Step at which to begin running the pipeline.
        stop: Step at which to stop running the pipeline (inclusive, i.e., it does run this step).
        debug: If True, runs all steps in a debug mode such that it runs quickly.
    """
    pipeline = Pipeline(experiment_dir, debug, seed)
    env_config = pipeline._get_config_section_for_env(OmegaConf.load(config_filepath), env_name)

    def should_run(step: str) -> bool:
        """Returns whether the given step falls within the requested [start, stop] range."""
        return pipeline._step_is_between_inclusive(step, start, stop)

    def resolve(relative_config_filepath: str) -> str:
        """Expands a step config path relative to the pipeline config and env."""
        return pipeline._expand_relative_config_path(config_filepath, env_name, relative_config_filepath)

    # Policy learning.
    if should_run("expert_policy_learning"):
        pipeline.run_policy_learning(resolve(env_config.expert_policy_learning.config), env_name)

    # Data generation.
    if should_run("data_generation"):
        data_config = env_config.data_generation
        pipeline.run_data_generation(
            env_name,
            resolve(data_config.random_config),
            data_config.reward_learning_train_size,
            data_config.reward_learning_val_size,
            data_config.reward_evaluation_size,
        )

    # Reward model learning.
    if should_run("reward_learning"):
        pipeline.multirun_reward_learning(config_filepath, env_name)

    # Reward model evaluation.
    if should_run("reward_evaluation"):
        pipeline.multirun_reward_evaluation(config_filepath, env_name)

    # Learning policies on the learned and/or other rewards.
    if should_run("arbitrary_reward_policy_learning"):
        pipeline.multirun_arbitrary_reward_policy_learning(config_filepath, env_name)

    # Policy evaluation.
    if should_run("policy_evaluation"):
        pipeline.multirun_policy_evaluation(config_filepath, env_name)

    # Visualization.
    if should_run("visualization"):
        pipeline.run_visualization()


def multirun_pipeline(
        experiment_dir: str,
        config_filepath: str,
        env_name: Optional[str] = None,
        start: str = "expert_policy_learning",
        stop: str = "visualization",
        seeds: Union[int, List[int]] = 1,
        num_parallel: int = 1,
        debug: bool = False,
) -> None:
    """Runs the pipeline multiple times (across envs and across random seeds).

    One pipeline run is executed per (env, seed) pair. With `num_parallel == 1` the runs execute sequentially in
    this process; otherwise they are distributed over a process pool with `num_parallel` workers.

    Args:
        experiment_dir: The directory in which all experiments outputs should be stored.
        config_filepath: Pipeline config file to run.
        env_name: If provided, only run for this environment.
        start: Step at which to begin running the pipeline.
        stop: Step at which to stop running the pipeline (inclusive, i.e., it does run this step).
        seeds: If an int, then this is the number of random seeds with which to run the full pipeline per environment.
            If this is a list of ints, then the ints in the list are the seeds to run.
        num_parallel: Number of pipeline runs to do in parallel. Should be about num cores // 8.
        debug: If True, runs all steps in a debug mode such that it runs quickly.

    Raises:
        AssertionError: If `seeds` is neither an int nor a list.
        ValueError: If `num_parallel` is less than 1.
        Exception: In the parallel case, the first exception raised by any pipeline run is re-raised after all
            runs have finished.
    """
    if isinstance(seeds, int):
        seeds = list(range(seeds))
    if not isinstance(seeds, list):
        # Explicit raise so the check is not stripped under `python -O`; type kept for backward compatibility.
        raise AssertionError("Seeds should be a list of seeds at this point.")
    if num_parallel < 1:
        # Previously this case silently ran nothing at all.
        raise ValueError(f"num_parallel must be at least 1, but got {num_parallel}")

    config = OmegaConf.load(config_filepath)
    # One set of `run_pipeline` keyword arguments per (env, seed) pair, in env-major order.
    runs = [
        dict(
            experiment_dir=experiment_dir,
            config_filepath=config_filepath,
            env_name=env.name,
            seed=seed,
            start=start,
            stop=stop,
            debug=debug,
        )
        for env in config.envs
        if env_name is None or env_name == env.name
        for seed in seeds
    ]

    # Separate code paths to make debugging easier.
    if num_parallel == 1:
        for run_kwargs in runs:
            run_pipeline(**run_kwargs)
        return

    # The context manager guarantees the pool is shut down even if submitting a run fails.
    with ProcessPoolExecutor(num_parallel) as pool:
        futures = [pool.submit(run_pipeline, **run_kwargs) for run_kwargs in runs]
        concurrent.futures.wait(futures)

    # `wait` alone never surfaces worker exceptions; retrieving the results re-raises the first failure so
    # that a crashed pipeline run does not go unnoticed.
    for future in futures:
        future.result()


# Command-line entry point: hand multirun_pipeline to python-fire when this file is run as a script.
if __name__ == "__main__":
    fire.Fire(multirun_pipeline)
