from typing import Dict, Optional

import gym
import numpy as np
from omegaconf import OmegaConf
from ray.rllib.agents.ppo.ppo_tf_policy import setup_config
from ray.rllib.agents.ppo.ppo_torch_policy import (
    KLCoeffMixin,
    kl_and_loss_stats,
    ppo_surrogate_loss,
    setup_mixins,
    ValueNetworkMixin,
    vf_preds_fetches,
)
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, LearningRateSchedule
from ray.rllib.utils.torch_ops import apply_grad_clipping
from ray.rllib.utils.typing import AgentID, TrainerConfigDict
import torch

import offline_rl
from offline_rl.envs.bouncing_balls_env import BouncingBallsEnvRewardModel
from offline_rl.scripts.rewards.evaluation.run_gym_reward_evaluation import get_reward_models


def convert_to_tensor(arr: np.ndarray, device: torch.device) -> torch.Tensor:
    """Converts a numpy array to a torch tensor, inferring the tensor dtype from the array dtype.

    Dtype mapping:
        * any floating dtype (float16/float32/float64/...) -> torch.float32
        * any integer dtype, signed or unsigned (e.g. uint8 pixels) -> torch.int64
        * bool -> torch.bool

    This logic isn't generally applicable (e.g. floats are always converted to float32),
    which is why this function is defined in this file.

    Args:
        arr: Array (or array-like, passed through ``np.asarray``) to convert.
        device: Device on which to allocate the returned tensor.

    Returns:
        A new tensor on ``device`` holding a copy of ``arr``'s data.

    Raises:
        ValueError: If ``arr``'s dtype is not floating, integer or bool (e.g. complex, str, object).
    """
    arr = np.asarray(arr)
    if np.issubdtype(arr.dtype, np.floating):
        np_dtype, dtype = np.float32, torch.float32
    elif np.issubdtype(arr.dtype, np.integer):
        # Cast in numpy first: some torch versions cannot build tensors directly from
        # numpy uint16/uint32/uint64 arrays. Note uint64 values above int64 max would wrap.
        np_dtype, dtype = np.int64, torch.int64
    elif np.issubdtype(arr.dtype, np.bool_):
        np_dtype, dtype = np.bool_, torch.bool
    else:
        raise ValueError(f"Don't know how to convert dtype: {arr.dtype} to tensor dtype.")

    # copy=False skips the numpy copy when the dtype already matches; torch.tensor copies anyway.
    return torch.tensor(arr.astype(np_dtype, copy=False), dtype=dtype, device=device)


def reward_overwriting_postprocess_fn(policy: Policy,
                                      sample_batch: SampleBatch,
                                      other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
                                      episode: Optional[MultiAgentEpisode] = None) -> SampleBatch:
    """Overwrites rewards with the output from the RewardModel stored on the policy.

    The batch's observations, actions, next observations and dones are fed to
    ``policy.reward_model.reward``. The flattened output replaces ``sample_batch[REWARDS]``
    in place, and the batch is then passed to RLlib's ``compute_gae_for_sample_batch``.

    Args:
        policy: Policy carrying a ``reward_model`` attribute.
        sample_batch: Trajectory batch whose rewards are replaced (mutated in place).
        other_agent_batches: Forwarded unchanged to ``compute_gae_for_sample_batch``.
        episode: Forwarded unchanged to ``compute_gae_for_sample_batch``.

    Returns:
        The batch returned by ``compute_gae_for_sample_batch``.

    Raises:
        ValueError: If the flattened reward-model output does not match the shape of the
            batch's existing rewards.
    """
    # Run on the cpu since this runs on individual rollouts.
    # TODO(redacted): Change this to an option to handle the image case.
    device = torch.device("cpu")
    # Inference only: without no_grad, an output that requires grad would make `.numpy()`
    # raise, and the autograd graph would be kept alive needlessly.
    with torch.no_grad():
        rewards = policy.reward_model.reward(
            convert_to_tensor(sample_batch[SampleBatch.OBS], device),
            convert_to_tensor(sample_batch[SampleBatch.ACTIONS], device),
            convert_to_tensor(sample_batch[SampleBatch.NEXT_OBS], device),
            convert_to_tensor(sample_batch[SampleBatch.DONES], device),
        )
    rewards = rewards.cpu().numpy().reshape(-1)
    expected_shape = sample_batch[SampleBatch.REWARDS].shape
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if rewards.shape != expected_shape:
        raise ValueError(
            f"Reward model output has shape {rewards.shape} after flattening, "
            f"but the sample batch rewards have shape {expected_shape}.")
    sample_batch[SampleBatch.REWARDS] = rewards
    return compute_gae_for_sample_batch(policy, sample_batch, other_agent_batches, episode)


def load_reward_model_before_init(
        policy: Policy,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict,
) -> None:
    """Initializes the reward model on the policy.

    First runs PPO's ``setup_config``. Then it builds the reward models described by
    ``config["reward_models"]`` and stores the one named by its ``"model_name"`` entry
    on ``policy.reward_model``.

    Args:
        policy: Policy being constructed; receives the ``reward_model`` attribute.
        obs_space: Observation space, forwarded to ``setup_config`` and ``get_reward_models``.
        action_space: Action space, forwarded to ``setup_config`` and ``get_reward_models``.
        config: Trainer config. It must contain ``"reward_models"`` (with a ``"rewards"``
            section and a ``"model_name"`` entry) and ``"env"``.

    Raises:
        ValueError: If ``config`` has no ``"reward_models"`` entry.
    """
    setup_config(policy, obs_space, action_space, config)
    # Explicit check (not `assert`) so misconfiguration is still reported under `python -O`.
    if "reward_models" not in config:
        raise ValueError(
            "RewardOverwritingPolicy requires a 'reward_models' entry in the trainer config.")
    reward_models_config = OmegaConf.create(config["reward_models"])
    # Replaces any user-supplied "common" section entirely. The model is pinned to the CPU,
    # matching the CPU device used when rewards are recomputed during postprocessing.
    reward_models_config["common"] = {"device": "cpu"}
    reward_models_config["rewards"]["env_name"] = config["env"]
    reward_models = get_reward_models(obs_space, action_space, reward_models_config)
    policy.reward_model = reward_models.get_model(reward_models_config["model_name"])

def maybe_zero_grad_apply_grad_clipping(policy, optimizer, loss):
    """Optionally discards the computed gradients, then delegates to `apply_grad_clipping`.

    When ``policy.config["do_not_update_model"]`` is truthy, the optimizer's gradients are
    zeroed before clipping. This lets the training code be reused for evaluation without
    changing the weights; per the original author, a zero learning rate alone was not enough.

    Args and return value are the same as for `apply_grad_clipping`.
    """
    freeze_weights = policy.config["do_not_update_model"]
    if freeze_weights:
        # Clear the gradients produced by this backward pass so the update has nothing to apply.
        optimizer.zero_grad()
    grad_info = apply_grad_clipping(policy, optimizer, loss)
    return grad_info

RewardOverwritingPolicy = build_policy_class(
    name="RewardOverwritingPolicy",
    framework="torch",
    # The lambda defers the attribute lookup of the trainer module's DEFAULT_CONFIG until the
    # config is requested (presumably to avoid a circular import with the trainer module).
    get_default_config=lambda: offline_rl.agents.reward_overwriting.reward_overwriting_trainer.DEFAULT_CONFIG,
    # Construction hooks: load the reward model, then set up the PPO mixins.
    before_init=load_reward_model_before_init,
    before_loss_init=setup_mixins,
    mixins=[LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, ValueNetworkMixin],
    # Sampling: extra action outputs from `vf_preds_fetches`, and a postprocess step that
    # swaps env rewards for reward-model output before GAE.
    extra_action_out_fn=vf_preds_fetches,
    postprocess_fn=reward_overwriting_postprocess_fn,
    # Optimization: the PPO surrogate loss and stats, plus optional grad zeroing before clipping.
    loss_fn=ppo_surrogate_loss,
    stats_fn=kl_and_loss_stats,
    extra_grad_process_fn=maybe_zero_grad_apply_grad_clipping,
)
