from typing import Tuple, Union

import gym
import numpy as np
import pytest
import torch as th
import torch.nn as nn

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

# Algorithm classes used to parametrize the predict test, in parametrization order
MODEL_LIST = [PPO, A2C, TD3, SAC, DQN]


class FlattenBatchNormDropoutExtractor(BaseFeaturesExtractor):
    """
    Features extractor that flattens the input, then applies
    batch normalization followed by dropout.

    Batch normalization and dropout both behave differently in training
    and evaluation mode, which is what the tests in this file rely on.

    :param observation_space: the observation space; its flattened dimension
        is passed to the base class as the features dimension
    """

    def __init__(self, observation_space: gym.Space):
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)
        self.flatten = nn.Flatten()
        # presumably ``_features_dim`` is set by the base constructor from ``flat_dim``
        self.batch_norm = nn.BatchNorm1d(self._features_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """
        Flatten the observations, normalize them and apply dropout.

        :param observations: the batch of observations
        :return: the extracted features
        """
        flat = self.flatten(observations)
        return self.dropout(self.batch_norm(flat))


def clone_batch_norm_stats(batch_norm: nn.BatchNorm1d) -> Tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the given batch norm layer.

    Both tensors are copied with ``clone()``, so the returned values keep
    their contents when the layer is later modified in place.

    :param batch_norm: the batch norm layer to read from
    :return: the bias and running mean
    """
    return batch_norm.bias.clone(), batch_norm.running_mean.clone()


def clone_dqn_batch_norm_stats(model: DQN) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the Q-network and target network.

    The statistics are read from the ``batch_norm`` layer of each network's
    features extractor.

    :param model: the DQN model to read from
    :return: the bias and running mean of the Q-network, followed by
        the bias and running mean of the target network
    """
    q_net_batch_norm = model.policy.q_net.features_extractor.batch_norm
    q_net_bias, q_net_running_mean = clone_batch_norm_stats(q_net_batch_norm)

    q_net_target_batch_norm = model.policy.q_net_target.features_extractor.batch_norm
    q_net_target_bias, q_net_target_running_mean = clone_batch_norm_stats(q_net_target_batch_norm)

    return q_net_bias, q_net_running_mean, q_net_target_bias, q_net_target_running_mean


def clone_td3_batch_norm_stats(
    model: TD3,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and actor-target and critic-target networks.

    The statistics are read from the ``batch_norm`` layer of each network's
    features extractor.

    :param model: the TD3 model to read from
    :return: the (bias, running mean) pairs of the actor, the critic,
        the actor target and the critic target, in that order
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    actor_target_batch_norm = model.actor_target.features_extractor.batch_norm
    actor_target_bias, actor_target_running_mean = clone_batch_norm_stats(actor_target_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (
        actor_bias,
        actor_running_mean,
        critic_bias,
        critic_running_mean,
        actor_target_bias,
        actor_target_running_mean,
        critic_target_bias,
        critic_target_running_mean,
    )


def clone_sac_batch_norm_stats(
    model: SAC,
) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the actor and critic networks and critic-target networks.

    The statistics are read from the ``batch_norm`` layer of each network's
    features extractor.

    :param model: the SAC model to read from
    :return: the (bias, running mean) pairs of the actor, the critic
        and the critic target, in that order
    """
    actor_batch_norm = model.actor.features_extractor.batch_norm
    actor_bias, actor_running_mean = clone_batch_norm_stats(actor_batch_norm)

    critic_batch_norm = model.critic.features_extractor.batch_norm
    critic_bias, critic_running_mean = clone_batch_norm_stats(critic_batch_norm)

    critic_target_batch_norm = model.critic_target.features_extractor.batch_norm
    critic_target_bias, critic_target_running_mean = clone_batch_norm_stats(critic_target_batch_norm)

    return (actor_bias, actor_running_mean, critic_bias, critic_running_mean, critic_target_bias, critic_target_running_mean)


def clone_on_policy_batch_norm(model: Union[A2C, PPO]) -> Tuple[th.Tensor, th.Tensor]:
    """
    Clone the bias and running mean from the batch norm layer of the policy's features extractor.

    :param model: the A2C or PPO model to read from
    :return: the bias and running mean
    """
    return clone_batch_norm_stats(model.policy.features_extractor.batch_norm)


# Maps each algorithm class to the helper that clones its batch norm statistics
CLONE_HELPERS = {
    # A2C and PPO share the same helper
    A2C: clone_on_policy_batch_norm,
    PPO: clone_on_policy_batch_norm,
    # DQN, SAC and TD3 each have a dedicated helper
    DQN: clone_dqn_batch_norm_stats,
    SAC: clone_sac_batch_norm_stats,
    TD3: clone_td3_batch_norm_stats,
}


def test_dqn_train_with_batch_norm():
    """
    Check that training DQN changes the batch norm statistics of the
    Q-network and leaves those of the target network unchanged.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=extractor_kwargs,
        learning_starts=0,
        seed=1,
        tau=0,  # do not clone the target
    )

    stats_before = clone_dqn_batch_norm_stats(model)
    model.learn(total_timesteps=200)
    stats_after = clone_dqn_batch_norm_stats(model)

    q_bias_0, q_mean_0, target_bias_0, target_mean_0 = stats_before
    q_bias_1, q_mean_1, target_bias_1, target_mean_1 = stats_after

    # Q-network: at least one entry of each tensor must have changed
    assert not th.allclose(q_bias_0, q_bias_1)
    assert not th.allclose(q_mean_0, q_mean_1)

    # Target network: every entry must be unchanged
    assert th.allclose(target_bias_0, target_bias_1)
    assert th.allclose(target_mean_0, target_mean_1)


def test_td3_train_with_batch_norm():
    """
    Check that training TD3 changes the batch norm statistics of the actor
    and critic and leaves those of both target networks unchanged.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = TD3(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=extractor_kwargs,
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    stats_before = clone_td3_batch_norm_stats(model)
    model.learn(total_timesteps=200)
    stats_after = clone_td3_batch_norm_stats(model)

    # The helper returns (bias, running mean) pairs ordered as:
    # actor, critic, actor target, critic target.
    # First four tensors (actor, critic): at least one entry of each must have changed
    for stat_before, stat_after in zip(stats_before[:4], stats_after[:4]):
        assert not th.allclose(stat_before, stat_after)

    # Last four tensors (actor target, critic target): every entry must be unchanged
    for stat_before, stat_after in zip(stats_before[4:], stats_after[4:]):
        assert th.allclose(stat_before, stat_after)


def test_sac_train_with_batch_norm():
    """
    Check that training SAC changes the batch norm statistics of the actor
    and critic and leaves those of the critic target unchanged.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = SAC(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=extractor_kwargs,
        learning_starts=0,
        tau=0,  # do not copy the target
        seed=1,
    )

    stats_before = clone_sac_batch_norm_stats(model)
    model.learn(total_timesteps=200)
    stats_after = clone_sac_batch_norm_stats(model)

    # The helper returns (bias, running mean) pairs ordered as:
    # actor, critic, critic target.
    # First four tensors (actor, critic): at least one entry of each must have changed
    for stat_before, stat_after in zip(stats_before[:4], stats_after[:4]):
        assert not th.allclose(stat_before, stat_after)

    # Last two tensors (critic target): every entry must be unchanged
    for stat_before, stat_after in zip(stats_before[4:], stats_after[4:]):
        assert th.allclose(stat_before, stat_after)


@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_train_with_batch_norm(model_class, env_id):
    """
    Check that training A2C/PPO changes the batch norm statistics
    of the policy's features extractor.
    """
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class("MlpPolicy", env_id, policy_kwargs=extractor_kwargs, seed=1)

    initial_bias, initial_running_mean = clone_on_policy_batch_norm(model)
    model.learn(total_timesteps=200)
    trained_bias, trained_running_mean = clone_on_policy_batch_norm(model)

    # At least one entry of each tensor must have changed
    assert not th.allclose(initial_bias, trained_bias)
    assert not th.allclose(initial_running_mean, trained_running_mean)


@pytest.mark.parametrize("model_class", [DQN, TD3, SAC])
def test_offpolicy_collect_rollout_batch_norm(model_class):
    """
    Check that the batch norm statistics of an off-policy model are unchanged
    after ``learn()`` when the model is created with ``gradient_steps=0``.
    """
    env_id = "CartPole-v1" if model_class is DQN else "Pendulum-v1"
    snapshot_stats = CLONE_HELPERS[model_class]

    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class(
        "MlpPolicy",
        env_id,
        policy_kwargs=extractor_kwargs,
        learning_starts=10,
        seed=1,
        gradient_steps=0,
        train_freq=1,
    )

    stats_before = snapshot_stats(model)
    model.learn(total_timesteps=100)
    stats_after = snapshot_stats(model)

    # No change in batch norm params
    for stat_before, stat_after in zip(stats_before, stats_after):
        assert th.allclose(stat_before, stat_after)


@pytest.mark.parametrize("model_class", [A2C, PPO])
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_a2c_ppo_collect_rollouts_with_batch_norm(model_class, env_id):
    """
    Check that calling ``collect_rollouts()`` on A2C/PPO leaves the
    batch norm statistics of the policy's features extractor unchanged.
    """
    n_steps = 64
    extractor_kwargs = dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor)
    model = model_class("MlpPolicy", env_id, policy_kwargs=extractor_kwargs, seed=1, n_steps=n_steps)

    initial_bias, initial_running_mean = clone_on_policy_batch_norm(model)

    # Only the callback is needed; the returned number of timesteps is ignored
    _, callback = model._setup_learn(total_timesteps=2 * n_steps, eval_env=model.get_env())

    # Collect two rollouts without any call to learn()
    for _ in range(2):
        model.collect_rollouts(model.get_env(), callback, model.rollout_buffer, n_rollout_steps=model.n_steps)

    final_bias, final_running_mean = clone_on_policy_batch_norm(model)

    # Every entry must be unchanged
    assert th.allclose(initial_bias, final_bias)
    assert th.allclose(initial_running_mean, final_running_mean)


@pytest.mark.parametrize("model_class", MODEL_LIST)
@pytest.mark.parametrize("env_id", ["Pendulum-v1", "CartPole-v1"])
def test_predict_with_dropout_batch_norm(model_class, env_id):
    """
    Check that repeated deterministic ``predict()`` calls on the same observation
    return the same action and leave the batch norm statistics unchanged.

    Combinations that are not run (SAC/TD3 on CartPole-v1, DQN on Pendulum-v1)
    are reported as skipped instead of silently passing.
    """
    # NOTE(review): presumably these pairs are excluded because the algorithm
    # does not support the environment's action space -- confirm
    if env_id == "CartPole-v1":
        if model_class in [SAC, TD3]:
            pytest.skip(f"{model_class.__name__} is not tested on {env_id}")
    elif model_class in [DQN]:
        pytest.skip(f"{model_class.__name__} is not tested on {env_id}")

    model_kwargs = dict(seed=1)
    clone_helper = CLONE_HELPERS[model_class]

    if model_class in [DQN, TD3, SAC]:
        model_kwargs["learning_starts"] = 0
    else:
        model_kwargs["n_steps"] = 64

    policy_kwargs = dict(
        features_extractor_class=FlattenBatchNormDropoutExtractor,
        net_arch=[16, 16],
    )
    model = model_class("MlpPolicy", env_id, policy_kwargs=policy_kwargs, verbose=1, **model_kwargs)

    batch_norm_stats_before = clone_helper(model)

    env = model.get_env()
    observation = env.reset()
    first_prediction, _ = model.predict(observation, deterministic=True)
    # The same observation must always give the same prediction
    for _ in range(5):
        prediction, _ = model.predict(observation, deterministic=True)
        np.testing.assert_allclose(first_prediction, prediction)

    batch_norm_stats_after = clone_helper(model)

    # No change in batch norm params
    for param_before, param_after in zip(batch_norm_stats_before, batch_norm_stats_after):
        assert th.isclose(param_before, param_after).all()
