# Code in this file is adapted from:
# https://github.com/openai/evolution-strategies-starter.

import gym
import numpy as np
import tree  # pip install dm_tree

import ray
from src.rllib.models import ModelCatalog
from src.rllib.policy.policy_template import build_policy_class
from src.rllib.policy.sample_batch import SampleBatch
from src.rllib.utils.filter import get_filter
from src.rllib.utils.framework import try_import_torch
from src.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
    unbatch
from src.rllib.utils.torch_ops import convert_to_torch_tensor

torch, _ = try_import_torch()


def before_init(policy, observation_space, action_space, config):
    """Prepare ES-specific state on `policy` and install ES methods on its class.

    Stores the action-noise scale, the action-space structure, an observation
    preprocessor and an observation filter on `policy`, then attaches
    `set_flat_weights`, `get_flat_weights`, `compute_actions` and
    `compute_single_action` to `type(policy)`. The methods are installed
    class-wide, so every instance of that class receives them.

    Args:
        policy: The policy instance being constructed.
        observation_space: Observation space used to pick the preprocessor.
        action_space: Action space whose base structure drives noise addition.
        config: Policy config; reads "action_noise_std", "observation_filter"
            and (optionally) "single_threaded".
    """
    # Per-instance attributes. The filter is sized from the preprocessor, so
    # the preprocessor must be created first.
    policy.action_noise_std = config["action_noise_std"]
    policy.action_space_struct = get_base_struct_from_space(action_space)
    policy.preprocessor = ModelCatalog.get_preprocessor_for_space(
        observation_space)
    policy.observation_filter = get_filter(config["observation_filter"],
                                           policy.preprocessor.shape)
    policy.single_threaded = config.get("single_threaded", False)

    def _set_flat_weights(policy, theta):
        """Load a flat parameter vector `theta` into the model.

        Tensors are filled in sorted state-dict key order, using the shapes
        recorded in `policy.param_shapes`.
        """
        current = policy.model.state_dict()
        offset = 0
        loaded = {}
        for name in sorted(current):
            shape = policy.param_shapes[name]
            size = int(np.prod(shape))
            chunk = theta[offset:offset + size]
            loaded[name] = torch.from_numpy(np.reshape(chunk, shape))
            offset += size
        policy.model.load_state_dict(loaded)

    def _get_flat_weights(policy):
        """Return every state-dict tensor, flattened and concatenated in
        sorted key order, as a single CPU numpy array."""
        current = policy.model.state_dict()
        pieces = [
            torch.reshape(current[name], (-1, )) for name in sorted(current)
        ]
        return torch.cat(pieces, dim=0).cpu().numpy()

    def _compute_actions(policy,
                         obs_batch,
                         add_noise=False,
                         update=True,
                         **kwargs):
        """Compute actions for a single observation.

        A one-element list is unwrapped to the observation it contains. The
        preprocessed observation gets a leading axis of size 1 before
        filtering. NOTE(review): longer lists are passed to the preprocessor
        unchanged; presumably only single observations are supported.
        Extra `kwargs` are ignored.

        Returns:
            (unbatched actions, [], {})
        """
        unwrap = isinstance(obs_batch, list) and len(obs_batch) == 1
        raw_obs = obs_batch[0] if unwrap else obs_batch
        flat_obs = policy.preprocessor.transform(raw_obs)
        filtered = policy.observation_filter(flat_obs[None], update=update)

        obs_tensor = convert_to_torch_tensor(filtered, policy.device)
        dist_inputs, _ = policy.model({
            SampleBatch.CUR_OBS: obs_tensor
        }, [], None)
        sampled = policy.dist_class(dist_inputs, policy.model).sample()

        def _finalize(component, component_space):
            # Move to numpy. When requested, add Gaussian noise scaled by
            # action_noise_std, but only to float-typed Box components.
            arr = component.detach().cpu().numpy()
            if not add_noise:
                return arr
            if not isinstance(component_space, gym.spaces.Box):
                return arr
            if not component_space.dtype.name.startswith("float"):
                return arr
            arr += np.random.randn(*arr.shape) * policy.action_noise_std
            return arr

        finished = tree.map_structure(_finalize, sampled,
                                      policy.action_space_struct)
        return unbatch(finished), [], {}

    def _compute_single_action(policy,
                               observation,
                               add_noise=False,
                               update=True,
                               **kwargs):
        """Wrap `observation` in a list, delegate to `compute_actions`, and
        return the first action together with the state outputs and extras."""
        actions, state_outs, extra_fetches = policy.compute_actions(
            [observation], add_noise=add_noise, update=update, **kwargs)
        return actions[0], state_outs, extra_fetches

    # Attach the ES methods to the policy's class (class-wide patch).
    policy_cls = type(policy)
    for attr_name, impl in (
            ("set_flat_weights", _set_flat_weights),
            ("get_flat_weights", _get_flat_weights),
            ("compute_actions", _compute_actions),
            ("compute_single_action", _compute_single_action),
    ):
        setattr(policy_cls, attr_name, impl)


def after_init(policy, observation_space, action_space, config):
    """Record the model's per-tensor shapes and its total parameter count.

    Sets `policy.param_shapes`, a dict mapping each state-dict key (inserted
    in sorted key order) to its shape tuple. Also sets `policy.num_params`,
    the total element count across those tensors as a plain Python int.

    Args:
        policy: Policy whose `model` has already been built.
        observation_space: Unused here.
        action_space: Unused here.
        config: Unused here.
    """
    state_dict = policy.model.state_dict()
    policy.param_shapes = {
        k: tuple(state_dict[k].size())
        for k in sorted(state_dict)
    }
    # np.prod(()) is the float 1.0 for 0-d tensors and an np.int64 otherwise;
    # cast each term so num_params is always a plain int (mirrors the
    # int(np.prod(shape)) used when unflattening weights).
    policy.num_params = sum(
        int(np.prod(s)) for s in policy.param_shapes.values())


def make_model_and_action_dist(policy, observation_space, action_space,
                               config):
    """Build the policy network and its deterministic action distribution.

    The model is built on `policy.preprocessor.observation_space`, so the
    preprocessor must already exist on `policy`. All model parameters are
    switched to `requires_grad=False`.

    Args:
        policy: Policy that already carries a `preprocessor`.
        observation_space: Unused here; the preprocessor's space is used.
        action_space: Action space for both the distribution and the model.
        config: Policy config; `config["model"]` holds the model options.

    Returns:
        (model, dist_class)
    """
    model_options = config["model"]
    dist_class, dist_dim = ModelCatalog.get_action_dist(
        action_space,
        model_options,
        dist_type="deterministic",
        framework="torch")
    model = ModelCatalog.get_model_v2(
        policy.preprocessor.observation_space,
        action_space,
        num_outputs=dist_dim,
        model_config=model_options,
        framework="torch")
    # Weights are never trained by backprop here, so turn off gradient
    # tracking on every parameter.
    for param in model.parameters():
        param.requires_grad_(False)
    return model, dist_class


def _get_es_default_config():
    """Return the ES trainer's DEFAULT_CONFIG from the vendored RLlib tree.

    Imported lazily rather than at module load. Presumably this avoids a
    circular import with the ES trainer module; confirm. The import uses the
    same `src.rllib` package as the rest of this file instead of
    `ray.rllib`, which a bare `import ray` does not load.
    """
    from src.rllib.agents.es.es import DEFAULT_CONFIG
    return DEFAULT_CONFIG


# Torch ES policy: no loss function (weights are set directly through
# set_flat_weights); setup and method installation happen in the hooks above.
ESTorchPolicy = build_policy_class(
    name="ESTorchPolicy",
    framework="torch",
    loss_fn=None,
    get_default_config=_get_es_default_config,
    before_init=before_init,
    after_init=after_init,
    make_model_and_action_dist=make_model_and_action_dist)
