import os

import hydra
import numpy as np
from omegaconf import dictconfig, OmegaConf
from ray.rllib.policy.sample_batch import SampleBatch
import torch

from offline_rl.data.rllib_data_utils import load_sample_batches
from offline_rl.rewards.evaluation.model_collection import ModelCollection
from offline_rl.scripts.rewards.evaluation.run_gym_reward_evaluation import get_reward_models
from offline_rl.scripts.rewards.learning.common import get_env
from offline_rl.scripts.rl.run import ExperimentManager
from offline_rl.utils.file_utils import save_json, get_datetime_string


def collect_data_from_policy(config):
    """Roll out a restored policy without training it and return the dataset directory.

    Args:
        config: Run configuration. Reads ``policy.config_filepath``,
            ``policy.checkpoint_filepath``,
            ``policy_evaluation.num_training_steps``,
            ``policy_evaluation.num_evaluation_steps`` and ``output_dir``.

    Returns:
        The dataset directory passed to ``ExperimentManager.collect``, i.e.
        ``<output_dir>/dataset/<datetime string>``.

    Raises:
        AssertionError: If either policy file does not exist or either step
            count is not strictly positive.
    """
    # Validate the pieces of the config that dataset collection depends on.
    policy_cfg = config.policy
    assert os.path.exists(policy_cfg.config_filepath)
    assert os.path.exists(policy_cfg.checkpoint_filepath)
    eval_cfg = config.policy_evaluation
    assert eval_cfg.num_training_steps > 0
    assert eval_cfg.num_evaluation_steps > 0

    manager = ExperimentManager(os.path.join(config.output_dir, "experiment"))

    # Each run gets its own dataset directory, keyed by a datetime string.
    # This assumes runs happen serially, so the strings never collide.
    run_stamp = get_datetime_string()
    dataset_dir = os.path.join(config.output_dir, "dataset", run_stamp)

    # Overrides that let the policy be rolled out while leaving it untouched.
    # Start from the provided checkpoint ...
    rollout_overrides = {"restore": policy_cfg.checkpoint_filepath}
    # ... guarantee that no update is applied to the policy ...
    rollout_overrides.update({"config.lr": 0.0, "config.do_not_update_model": True})
    # ... and gather whole episodes rather than fragments.
    rollout_overrides["config.batch_mode"] = "complete_episodes"

    # Restoring also restores the training state, so the evaluation steps have
    # to be added on top of the training steps taken so far. This is hacky and
    # suggests switching to rllib's evaluate.py; the drawbacks there are that it
    # was not parallelized (when last checked) and that it stores all rollouts
    # as a pickle.
    steps_so_far = eval_cfg.num_training_steps
    steps_to_evaluate = eval_cfg.num_evaluation_steps
    manager.collect(
        config_filepath=policy_cfg.config_filepath,
        size=steps_so_far + steps_to_evaluate,
        dataset_dir=dataset_dir,
        **rollout_overrides,
    )

    return dataset_dir


def get_reward_models_for_env(env_name, reward_labels):
    """Build the reward models for an environment, restricted to the given labels.

    Args:
        env_name: Name of the environment; used both to construct the
            environment and as ``rewards.env_name`` in the config handed to
            ``get_reward_models``.
        reward_labels: Iterable of labels of the reward models to keep.

    Returns:
        A ``ModelCollection`` holding one model per requested label.
    """
    env = get_env(env_name)
    rewards_config = OmegaConf.create({"rewards": {"env_name": env_name}})
    all_models = get_reward_models(env.observation_space, env.action_space, rewards_config)

    # Narrow the full collection down to the labels that were asked for.
    selected_models = {label: all_models.get_model(label) for label in reward_labels}
    return ModelCollection(selected_models)


def compute_episode_return(episode):
    """Return the undiscounted sum of the rewards stored in ``episode``."""
    # TODO(redacted): Add option for discounted return?
    episode_rewards = episode[SampleBatch.REWARDS]
    return episode_rewards.sum()


def compute_mean_returns(batches, rewards):
    """Summarize per-episode returns under each labeled set of rewards.

    Args:
        batches: Collected transitions. Must support item assignment and
            ``split_by_episode()``. NOTE: this object is mutated in place; its
            rewards column is overwritten once per label and is left holding
            the rewards of the last label processed.
        rewards: Object exposing ``label_to_rewards``, a mapping from label to
            a torch tensor of rewards (presumably one entry per transition in
            ``batches`` -- confirm against the reward model collection).

    Returns:
        A dict mapping each label to a dict with the ``"mean"`` and ``"std"``
        of the episode returns and the ``"num_episodes"`` they were computed
        over.
    """
    label_returns = {}
    for label, label_rewards in rewards.label_to_rewards.items():
        # Detach before converting: Tensor.numpy() raises a RuntimeError for
        # tensors that require grad (e.g. the output of a learned reward model
        # evaluated outside of torch.no_grad()). For tensors that do not
        # require grad this is a no-op.
        batches[SampleBatch.REWARDS] = label_rewards.detach().to("cpu").numpy()
        returns = [
            compute_episode_return(episode)
            for episode in batches.split_by_episode()
        ]
        label_returns[label] = {
            "mean": np.mean(returns).astype(float),
            "std": np.std(returns).astype(float),
            "num_episodes": len(returns),
        }
    return label_returns


def convert_to_tensor(arr, device, float_dtype=torch.float32):
    """Copy ``arr`` into a new torch tensor of dtype ``float_dtype`` on ``device``."""
    tensor_options = {"dtype": float_dtype, "device": device}
    return torch.tensor(arr, **tensor_options)


def compute_rewards(config, batches, device="cuda"):
    """Relabel the collected transitions with every configured reward model.

    Args:
        config: Run configuration. Reads ``env_name`` and ``rewards`` (the
            labels of the reward models to evaluate).
        batches: Collected transitions, indexed by the ``SampleBatch`` column
            names for observations, actions, next observations and dones.
        device: Torch device the inputs are placed on. If a CUDA device is
            requested but CUDA is not available, the CPU is used instead.

    Returns:
        Whatever ``ModelCollection.rewards`` returns for the converted inputs.
    """
    # The default is "cuda" and the script entry point never overrides it, so
    # without this fallback the script cannot run on a CPU-only host.
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    reward_models = get_reward_models_for_env(config.env_name, config.rewards)
    # Every column, including the done flags, is converted to a float tensor.
    rewards = reward_models.rewards(
        convert_to_tensor(batches[SampleBatch.OBS], device=device),
        convert_to_tensor(batches[SampleBatch.ACTIONS], device=device),
        convert_to_tensor(batches[SampleBatch.NEXT_OBS], device=device),
        convert_to_tensor(batches[SampleBatch.DONES], device=device),
    )
    return rewards


@hydra.main(config_path="configs")
def main(config: dictconfig.DictConfig) -> None:
    """Evaluate a policy on an environment under a set of reward models.

    Collects rollouts from the policy, relabels them with each reward model,
    and writes the per-label return statistics to ``returns.json`` inside
    ``config.output_dir``.
    """
    os.makedirs(config.output_dir, exist_ok=True)

    # Gather rollouts from the policy, then read them back from disk.
    rollout_dir = collect_data_from_policy(config)
    load_options = {
        "debug_size": config.policy_evaluation.num_evaluation_steps,
        "debug_size_mode": "ordered",
    }
    batches = load_sample_batches(rollout_dir, **load_options)

    # Score the rollouts under each reward model and summarize the returns.
    relabeled_rewards = compute_rewards(config, batches)
    returns_by_label = compute_mean_returns(batches, relabeled_rewards)

    save_json(os.path.join(config.output_dir, "returns.json"), returns_by_label)


# Script entry point: only run the evaluation when executed directly.
if __name__ == "__main__":
    # main() is called without arguments because it is wrapped by @hydra.main;
    # the decorator is expected to supply `config`, hence the pylint suppression.
    # pylint: disable=no-value-for-parameter
    main()
