from typing import Tuple, Union

import gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from sb3_contrib import QRDQN, TQC, MaskablePPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete
from sb3_contrib.common.maskable.utils import get_action_masks


class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
    """
    Feature extractor that flattens the input, then applies batch normalization and dropout.

    Both layers behave differently in train and eval mode (batch norm updates its
    running statistics only in train mode; dropout is a no-op in eval mode), which
    is what the tests below rely on to detect the current mode.
    :param observation_space:
    """

    def __init__(self, observation_space: gym.Space):
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        # Applied in this order by ``forward``
        self.flatten = nn.Flatten()
        self.batch_norm = nn.BatchNorm1d(flat_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        flat = self.flatten(observations)
        return self.dropout(self.batch_norm(flat))


def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> Tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the given batch norm layer.

    The copies are detached from the autograd graph, so they are plain value
    snapshots that are unaffected by later training of the layer.
    :param batch_norm: the batch norm layer to snapshot
    :return: the bias and running mean
    """
    # ``bias`` is a Parameter (requires_grad=True); detach so the snapshot does
    # not keep an autograd graph alive. ``running_mean`` is a buffer.
    return batch_norm.bias.detach().clone(), batch_norm.running_mean.detach().clone()


def clone_qrdqn_batch_norm_stats(model: QRDQN) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the quantile network and quantile-target network.
    :param model: QRDQN model whose features extractors expose a ``batch_norm`` layer
    :return: (quantile net bias, quantile net running mean,
        quantile-target net bias, quantile-target net running mean)
    """
    quantile_net_batch_norm = model.policy.quantile_net.features_extractor.batch_norm
    quantile_net_bias, quantile_net_running_mean = clone_batch_norm_stats(quantile_net_batch_norm)

    quantile_net_target_batch_norm = model.policy.quantile_net_target.features_extractor.batch_norm
    quantile_net_target_bias, quantile_net_target_running_mean = clone_batch_norm_stats(quantile_net_target_batch_norm)

    return quantile_net_bias, quantile_net_running_mean, quantile_net_target_bias, quantile_net_target_running_mean


def clone_tqc_batch_norm_stats(
    model: TQC,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and critic-target networks.
    :param model: TQC model whose actor/critic features extractors expose a ``batch_norm`` layer
    :return: (actor bias, actor running mean, critic bias, critic running mean,
        critic-target bias, critic-target running mean)
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean)


def clone_on_policy_batch_norm(model: Union[MaskablePPO]) -> Tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the policy's features extractor.
    :param model: on-policy model whose ``policy.features_extractor`` has a ``batch_norm`` layer
    :return: the bias and running mean
    """
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)


# Dispatch table: algorithm class -> helper that returns a tuple snapshot of that
# model's batch norm stats, so parametrized tests can compare before/after values.
CLONE_HELPERS = {
    QRDQN: clone_qrdqn_batch_norm_stats,
    TQC: clone_tqc_batch_norm_stats,
    MaskablePPO: clone_on_policy_batch_norm,
}


def test_ppo_mask_train_eval_mode():
    """
    MaskablePPO training must update the batch norm stats, while prediction
    must be deterministic and leave the batch norm stats untouched.
    """
    env = InvalidActionEnvDiscrete(dim=20, n_invalid_actions=10)
    model = MaskablePPO(
        "MlpPolicy",
        env,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        seed=1,
    )

    bias_before, running_mean_before = clone_on_policy_batch_norm(model)

    model.learn(total_timesteps=200)

    bias_after, running_mean_after = clone_on_policy_batch_norm(model)

    # Training must have changed at least one element of each stat
    assert not th.allclose(bias_before, bias_after)
    assert not th.allclose(running_mean_before, running_mean_after)

    batch_norm_stats_before = clone_on_policy_batch_norm(model)

    observation = env.reset()
    action_masks = get_action_masks(env)
    # Repeated deterministic predictions must agree (dropout must be inactive)
    first_prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
    for _ in range(5):
        prediction, _ = model.predict(observation, action_masks=action_masks, deterministic=True)
        np.testing.assert_allclose(first_prediction, prediction)

    batch_norm_stats_after = clone_on_policy_batch_norm(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
        assert th.allclose(param_before, param_after)


def test_qrdqn_train_with_batch_norm():
    """
    QRDQN training must update the online quantile network's batch norm stats,
    while the target network's stats stay unchanged (tau=0).
    """
    model = QRDQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        seed=1,
        tau=0,  # do not copy the target
    )

    (
        quantile_net_bias_before,
        quantile_net_running_mean_before,
        quantile_net_target_bias_before,
        quantile_net_target_running_mean_before,
    ) = clone_qrdqn_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        quantile_net_bias_after,
        quantile_net_running_mean_after,
        quantile_net_target_bias_after,
        quantile_net_target_running_mean_after,
    ) = clone_qrdqn_batch_norm_stats(model)

    # Online network: at least one element of each stat must have changed
    assert not th.allclose(quantile_net_bias_before, quantile_net_bias_after)
    assert not th.allclose(quantile_net_running_mean_before, quantile_net_running_mean_after)

    # Target network: unchanged
    assert th.allclose(quantile_net_target_bias_before, quantile_net_target_bias_after)
    assert th.allclose(quantile_net_target_running_mean_before, quantile_net_target_running_mean_after)


def test_tqc_train_with_batch_norm():
    """
    TQC training must update the actor and critic batch norm stats,
    while the critic-target stats stay unchanged (tau=0).
    """
    model = TQC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    (
        actor_bias_before,
        actor_running_mean_before,
        critic_bias_before,
        critic_running_mean_before,
        critic_target_bias_before,
        critic_target_running_mean_before,
    ) = clone_tqc_batch_norm_stats(model)

    model.learn(total_timesteps=200)

    (
        actor_bias_after,
        actor_running_mean_after,
        critic_bias_after,
        critic_running_mean_after,
        critic_target_bias_after,
        critic_target_running_mean_after,
    ) = clone_tqc_batch_norm_stats(model)

    # Actor and critic: at least one element of each stat must have changed
    assert not th.allclose(actor_bias_before, actor_bias_after)
    assert not th.allclose(actor_running_mean_before, actor_running_mean_after)

    assert not th.allclose(critic_bias_before, critic_bias_after)
    assert not th.allclose(critic_running_mean_before, critic_running_mean_after)

    # Critic target: unchanged
    assert th.allclose(critic_target_bias_before, critic_target_bias_after)
    assert th.allclose(critic_target_running_mean_before, critic_target_running_mean_after)


@pytest.mark.parametrize("model_class", [QRDQN, TQC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
    """
    With no gradient steps, only rollout collection runs; it must leave the
    batch norm stats of every network unchanged.
    """
    # QRDQN acts on a discrete action space, TQC on a continuous one
    env_id = "CartPole-v1" if model_class is QRDQN else "Pendulum-v1"
    policy_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)

    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=policy_kwargs,
        learning_starts=10,
        seed=1,
        gradient_steps=0,
        train_freq=1,
    )

    snapshot = CLONE_HELPERS[model_class]
    stats_before = snapshot(model)
    model.learn(total_timesteps=100)
    stats_after = snapshot(model)

    # No change in batch norm params
    assert all(th.allclose(before, after) for before, after in zip(stats_before, stats_after))


# Only valid (algorithm, env) pairs: QRDQN needs a discrete action space and TQC a
# continuous one. Listing the pairs explicitly avoids generating invalid combinations
# that would otherwise "pass" without testing anything.
@pytest.mark.parametrize(
    "model_class,env_id",
    [
        (QRDQN, "CartPole-v1"),
        (TQC, "Pendulum-v1"),
    ],
)
def test_predict_with_dropout_batch_norm(model_class, env_id):
    """
    Deterministic prediction must be repeatable (dropout inactive) and must not
    modify the batch norm stats of any network.
    """
    model_kwargs = dict(seed=1)
    clone_helper = CLONE_HELPERS[model_class]

    if model_class in [QRDQN, TQC]:
        model_kwargs["learning_starts"] = 0
    else:
        # On-policy algorithms take ``n_steps`` instead of ``learning_starts``
        model_kwargs["n_steps"] = 64

    policy_kwargs = dict(
        features_extractor_class=FlattenBatchNormDropoutExtractor,
        net_arch=[16, 16],
    )
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)

    batch_norm_stats_before = clone_helper(model)

    env = model.get_env()
    observation = env.reset()
    first_prediction, _ = model.predict(observation, deterministic=True)
    for _ in range(5):
        prediction, _ = model.predict(observation, deterministic=True)
        np.testing.assert_allclose(first_prediction, prediction)

    batch_norm_stats_after = clone_helper(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
        assert th.allclose(param_before, param_after)
