import os
import yaml

import dotmap
import fire
from ray.rllib.train import create_parser as train_create_parser, run as train_run
from ray.rllib.rollout import create_parser as rollout_create_parser, run as rollout_run

from offline_rl.agents.load_custom_agents import load_custom_agents
from offline_rl.agents.registry import CUSTOM_AGENTS
from offline_rl.envs.load_custom_envs import register_ray_envs


class ExperimentConfig:
    """A single named experiment loaded from / written to a YAML config file.

    The YAML document is expected to map an experiment name to that experiment's
    settings, e.g.::

        my_experiment:
          env: <env name>
          config: {...}

    Only the first top-level key is kept. Its value is wrapped in a
    `dotmap.DotMap` so nested settings can be read and written as attributes
    (e.g. `cfg.config.config.output`).
    """

    # Root directory searched by `get_config_filepath`; resolved against the CWD.
    CONFIG_DIR = "./configs"

    def __init__(self, config):
        """
        Args:
            config: Mapping whose first key is the experiment name and whose value
                is that experiment's settings. Any further top-level keys are ignored.

        Raises:
            ValueError: If `config` is empty or `None` (e.g. loaded from an empty YAML file).
        """
        if not config:
            raise ValueError(
                "Experiment config is empty; expected a mapping of experiment name to settings.")
        self.experiment_name = next(iter(config))
        # Store config without name to make access easier.
        self.config = dotmap.DotMap(config[self.experiment_name])

    @classmethod
    def from_yaml(cls, filepath):
        """Load an experiment config from a YAML file (parsed with `yaml.safe_load`)."""
        with open(filepath, "r", encoding="utf-8") as infile:
            config = yaml.safe_load(infile)
        return cls(config)

    @classmethod
    def get_config_filepath(cls, env_name, algo_name):
        """Return the absolute path `<CONFIG_DIR>/<env_name>/<algo_name>.yaml`.

        Both names are lower-cased. `cls.CONFIG_DIR` is used so subclasses can
        point at a different config root.
        """
        return os.path.join(os.path.abspath(cls.CONFIG_DIR), env_name.lower(), f"{algo_name.lower()}.yaml")

    @classmethod
    def from_env_algo(cls, env_name, algo_name):
        """Load the config stored for `env_name` / `algo_name` under `CONFIG_DIR`."""
        filepath = cls.get_config_filepath(env_name, algo_name)
        return cls.from_yaml(filepath)

    def write(self, filepath):
        """Write this config to `filepath` as YAML `{experiment_name: settings}` in block style."""
        data = {self.experiment_name: self.config.toDict()}
        with open(filepath, "w", encoding="utf-8") as outfile:
            yaml.dump(data, outfile, default_flow_style=False)

    def update(self, d):
        """Apply overrides from mapping `d` in place.

        A dotted key such as `"config.model.fcnet_activation"` addresses a nested
        setting. Every segment except the last is walked with item access, and
        the last segment is assigned.

        NOTE: Missing intermediate levels rely on DotMap's default dynamic mode,
        which creates them on access. A plain dict would raise `KeyError` instead.
        """
        for key, value in d.items():
            *parents, leaf = key.split(".")
            cfg = self.config
            for subkey in parents:
                cfg = cfg[subkey]
            cfg[leaf] = value

    def __repr__(self):
        return f"{self.config}"


class ExperimentManager:
    """CLI entry points (exposed through `fire`) for collecting datasets and for
    training, evaluating and visualizing RLlib agents.

    Generated artifacts are written beneath `experiment_dir`:

    - `runs/collection/<env>/` for data-collection runs,
    - `datasets/<env>/<dataset_name>/` for collected datasets (unless `dataset_dir` is given),
    - `runs/train/<env>/<run_name>/` for training runs (unless `output_dir` is given).
    """

    def __init__(self, experiment_dir):
        """
        Args:
            experiment_dir: Root directory for runs and datasets; created if missing.
        """
        os.makedirs(experiment_dir, exist_ok=True)
        self.experiment_dir = experiment_dir

    def _run_train(self, config_filepath):
        """Invoke RLlib's `train` entry point on `config_filepath`; all other CLI args keep their defaults."""
        parser = train_create_parser()
        # Parse an empty argv to obtain the defaults, then point it at our config file.
        args = parser.parse_args(args=[])
        args.config_file = config_filepath
        train_run(args, parser)

    def _register_envs_agents(self):
        """Register custom environments (`register_ray_envs`) and load the agents in `CUSTOM_AGENTS`."""
        register_ray_envs()
        load_custom_agents(CUSTOM_AGENTS)

    @staticmethod
    def _write_experiment_config(config, local_dir):
        """Write `config` to `<local_dir>/<experiment_name>/experiment_config.yaml` and return that path."""
        experiment_filepath = os.path.join(local_dir, config.experiment_name, "experiment_config.yaml")
        os.makedirs(os.path.dirname(experiment_filepath), exist_ok=True)
        config.write(experiment_filepath)
        return experiment_filepath

    def collect(self, config_filepath, size=None, dataset_dir=None, dataset_name=None, **config_overrides):
        """Collect and save a dataset of demonstrations.

        Args:
            config_filepath: Filepath of the base config to use for collection.
            size: If provided, sets `stop.timesteps_total`. Must be positive.
            dataset_dir: If provided, write the dataset to this directory.
            dataset_name: If provided, write the dataset to
                `<experiment_dir>/datasets/<env>/<dataset_name>`.
                Exactly one of `dataset_dir` / `dataset_name` must be given.
            config_overrides: Config values to override (dotted keys address nested values).

        Raises:
            ValueError: If both or neither of `dataset_dir` / `dataset_name` are given,
                or if `size` is not positive.

        Notes:
        - If you want to restore from a checkpoint, add `--restore=<path to checkpoint file>`.
        """
        # Validate before any registration or filesystem work. Explicit raises
        # (rather than `assert`) keep the checks active under `python -O`.
        if (dataset_name is None) == (dataset_dir is None):
            raise ValueError("Exactly one of dataset_name or dataset_dir must be provided.")
        if size is not None and size <= 0:
            raise ValueError(f"size must be positive, got {size}.")
        self._register_envs_agents()

        config = ExperimentConfig.from_yaml(config_filepath)
        config.update(config_overrides)

        env_name = config.config.env.lower()
        local_dir = os.path.join(self.experiment_dir, "runs", "collection", env_name)
        os.makedirs(local_dir, exist_ok=True)
        config.config.local_dir = local_dir

        if dataset_dir is None:
            dataset_dir = os.path.join(self.experiment_dir, "datasets", env_name, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)
        # RLlib's `output` setting controls where sampled experiences are saved.
        config.config.config.output = dataset_dir

        if size is not None:
            config.config.stop.timesteps_total = size

        experiment_filepath = self._write_experiment_config(config, local_dir)
        self._run_train(experiment_filepath)

    def train(self,
              config_filepath,
              run_name=None,
              output_dir=None,
              dataset_filepath_or_pattern=None,
              **config_overrides):
        """Train an agent, either online or from a previously collected dataset.

        Args:
            config_filepath: Filepath of base config to use for training.
            run_name: If provided, name to use for directory containing outputs from this run.
                Exactly one of this argument or `output_dir` must be given.
            output_dir: If provided, save outputs to this directory.
            dataset_filepath_or_pattern: Filepath or pattern for dataset(s) to use for training
                (sets the `input` setting). If `None`, the config's own `input` is left unchanged.
            config_overrides: Config values to override.
                Examples: `--config.cql_alpha=10`
                          `--config.model.fcnet_activation="tanh"`
                          `--stop.timesteps_total=50000`

        Raises:
            ValueError: If both or neither of `run_name` / `output_dir` are given.
        """
        # Explicit raise (rather than `assert`) keeps the check active under `python -O`.
        if (run_name is None) == (output_dir is None):
            raise ValueError("Exactly one of run_name or output_dir must be provided.")
        self._register_envs_agents()

        config = ExperimentConfig.from_yaml(config_filepath)
        config.update(config_overrides)

        if output_dir is None:
            env_name = config.config.env.lower()
            output_dir = os.path.join(self.experiment_dir, "runs", "train", env_name, str(run_name))
        os.makedirs(output_dir, exist_ok=True)
        config.config.local_dir = output_dir

        if dataset_filepath_or_pattern is not None:
            config.config.config.input = dataset_filepath_or_pattern

        experiment_filepath = self._write_experiment_config(config, output_dir)
        self._run_train(experiment_filepath)

    def _run_rollout(self, checkpoint_filepath, algo_name, episodes, no_render, video_dir=None):
        """Invoke RLlib's `rollout` entry point on a checkpoint in local mode.

        Args:
            checkpoint_filepath: Checkpoint to restore.
            algo_name: Algorithm name passed as `--run`.
            episodes: Number of episodes to roll out.
            no_render: Value forwarded to the rollout `no_render` argument.
            video_dir: If provided, value forwarded to the rollout `video_dir` argument.
        """
        # Register envs/agents here too, as `collect`/`train` do, so evaluation
        # and visualization also work when this class is used as a library
        # rather than only when the module runs as a script.
        self._register_envs_agents()
        parser = rollout_create_parser()
        args = parser.parse_args(args=["--run", algo_name])
        args.checkpoint = checkpoint_filepath
        args.episodes = episodes
        args.local_mode = True
        args.no_render = no_render
        # No limit on the number of steps.
        args.steps = 0
        if video_dir is not None:
            args.video_dir = video_dir
        rollout_run(args, parser)

    def evaluate(self, algo_name, checkpoint_filepath, num_episodes=1):
        """Roll out `num_episodes` episodes from a checkpoint without rendering."""
        self._run_rollout(
            checkpoint_filepath=checkpoint_filepath,
            algo_name=algo_name,
            episodes=num_episodes,
            no_render=True,
        )

    def visualize(self, algo_name, checkpoint_filepath, num_episodes=1):
        """Roll out episodes with rendering on, passing `<checkpoint dir>/videos` (created if needed) as `video_dir`."""
        video_dir = os.path.join(os.path.dirname(checkpoint_filepath), "videos")
        os.makedirs(video_dir, exist_ok=True)
        self._run_rollout(
            checkpoint_filepath=checkpoint_filepath,
            algo_name=algo_name,
            episodes=num_episodes,
            video_dir=video_dir,
            no_render=False,
        )


def _main():
    """Script entry point: register custom envs and agents, then serve `ExperimentManager` as a Fire CLI."""
    register_ray_envs()
    load_custom_agents(CUSTOM_AGENTS)
    fire.Fire(ExperimentManager)


if __name__ == "__main__":
    _main()
