from typing import Any, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union

import io
import numpy as np
import scipy.stats
import pathlib
import torch as th
from torch.nn import functional as F

from stable_baselines3 import TD3
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.save_util import recursive_getattr, save_to_zip_file
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import SACPolicy

# Type variable bound to ``CustomTD3``; the bound is written as a string
# because the class is only defined further down in this module.
# NOTE(review): not referenced anywhere in the code visible here --
# presumably meant for annotating methods that return ``self``/``cls``;
# confirm it is still needed.
SelfCustomTD3 = TypeVar("SelfCustomTD3", bound="CustomTD3")


class CustomTD3(TD3):
    def __init__(self, *args, **kwargs):
        """Create the diagnostic history buffers, then run the parent constructor.

        Every positional and keyword argument is forwarded unchanged to
        ``super().__init__``.
        """
        self._n_calls = 0

        # One empty list per diagnostic series that ``train`` appends to.
        # They are created before the parent constructor is invoked.
        for history_name in (
            "std_history",
            "std_alt_history",
            "q_val_history",
            "q_val_norm_history",
            "q_fn_history",
            "act_samp_history",
            "act_pdf_history",
            "obs_samp_history",
        ):
            setattr(self, history_name, [])

        super().__init__(*args, **kwargs)

    def _moving_average(self, a, n=1000):
        return np.convolve(a, np.ones(n) / n, mode="valid")[-1]

    def _compute_mean_q_values(self, observations):
        (
            actions_sampled,
            actions_alt_sampled,
            q_values_sampled,
            q_values_norm_sampled,
        ) = ([], [], [], [])
        observations_numpy = observations.cpu().detach().numpy()
        with th.no_grad():
            for i, observation in enumerate(observations_numpy[:20]):
                actions_sampled_obs, q_values_sampled_obs = [], []
                for _ in range(100):
                    unscaled_action = self.policy._predict(
                        observations[i].unsqueeze(0).to(th.float32),
                    )
                    action = self.policy.scale_action(
                        unscaled_action.cpu().detach().numpy()
                    )

                    # Add noise to the action (improve exploration)
                    if self.action_noise is not None:
                        action = np.clip(action + self.action_noise(), -1, 1)

                    actions_sampled_obs.append(action)
                    action = th.tensor(action).to(self.device)
                    q_values = self.critic.q1_forward(
                        observations[i].unsqueeze(0).to(th.float32),
                        action.to(th.float32),
                    )
                    q_values_sampled_obs.append(q_values.item())

                q_values_norm_sampled_obs = (
                    q_values_sampled_obs - np.min(q_values_sampled_obs)
                ) / (np.max(q_values_sampled_obs) - np.min(q_values_sampled_obs))

                actions_alt_sampled.append(
                    np.mean(
                        (
                            np.array(actions_sampled_obs)
                            - np.mean(actions_sampled_obs, axis=0)
                        )
                        ** 2
                    )
                    ** 0.5
                )
                actions_sampled.append(np.mean(np.std(actions_sampled_obs, axis=0)))
                q_values_sampled.append(np.mean(q_values_sampled_obs))
                q_values_norm_sampled.append(np.mean(q_values_norm_sampled_obs))

        print(
            np.mean(actions_sampled),
            np.mean(actions_alt_sampled),
            np.mean(q_values_sampled),
            np.mean(q_values_norm_sampled),
        )
        return (
            np.mean(actions_sampled),
            np.mean(actions_alt_sampled),
            np.mean(q_values_sampled),
            np.mean(q_values_norm_sampled),
        )

    def _plot_q_values(self, observations):
        actions_plot = np.linspace(-1.0, 1.0, 100)
        actions_sampled = []
        q_values_sampled = []
        observations_sampled = []
        with th.no_grad():
            for observation in observations[:1]:
                q_values_sampled_obs = []
                for action in actions_plot:
                    action = th.tensor(action).to(self.device)
                    q_values = self.critic.q1_forward(
                        observation.unsqueeze(0).to(th.float32),
                        action.unsqueeze(0).unsqueeze(0).to(th.float32),
                    )
                    q_values_sampled_obs.append(q_values.item())
                q_values_sampled.append(q_values_sampled_obs)
                observations_sampled.append(observation.cpu().detach().numpy())
                for _ in range(100):
                    unscaled_action = self.policy._predict(
                        observation.unsqueeze(0).to(th.float32)
                    )
                    action = self.policy.scale_action(
                        unscaled_action.cpu().detach().numpy()
                    )
                    # Add noise to the action (improve exploration)
                    if self.action_noise is not None:
                        action = np.clip(action + self.action_noise(), -1, 1)

                    actions_sampled.append(action)
                actions_pdf = scipy.stats.norm(
                    np.mean(actions_sampled, axis=0), 0.1
                ).pdf(actions_sampled)

        return q_values_sampled, actions_sampled, actions_pdf, observations_sampled

    def _record_diagnostics(self, batch_size: int) -> None:
        """Sample one replay batch and append its diagnostics to the histories.

        Each call adds exactly one entry to every ``*_history`` list, so the
        lists grow by one element per ``train`` call.
        """
        replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
        std, std_alt, mean_q, mean_q_norm = self._compute_mean_q_values(
            replay_data.observations
        )
        q_fn, act_samp, act_pdf, obs_samp = self._plot_q_values(
            replay_data.observations
        )

        self.std_history.append(std)
        self.std_alt_history.append(std_alt)
        self.q_val_history.append(mean_q)
        self.q_val_norm_history.append(mean_q_norm)
        self.q_fn_history.append(q_fn)
        self.act_samp_history.append(act_samp)
        self.act_pdf_history.append(act_pdf)
        self.obs_samp_history.append(obs_samp)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Record diagnostics, then run ``gradient_steps`` TD3 updates.

        Before any optimisation, one replay batch is sampled and summarised
        into the ``*_history`` lists. Each gradient step then samples a fresh
        batch, updates the critics against the clipped-noise target, and --
        every ``self.policy_delay`` updates -- updates the actor and the
        target networks.

        :param gradient_steps: number of optimisation steps to perform
        :param batch_size: number of transitions sampled per step (also used
            for the diagnostic batch)
        """
        # Diagnostics query the actor and critic on single-observation
        # batches, so run them in evaluation mode: in training mode dropout
        # or batch-norm layers would be active during these forward passes.
        self.policy.set_training_mode(False)
        self._record_diagnostics(batch_size)

        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)

        # Update learning rate according to lr schedule
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])

        actor_losses, critic_losses = [], []

        for _ in range(gradient_steps):

            self._n_updates += 1
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(
                batch_size, env=self._vec_normalize_env
            )

            with th.no_grad():
                # Select action according to policy and add clipped noise
                noise = replay_data.actions.clone().data.normal_(
                    0, self.target_policy_noise
                )
                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (
                    self.actor_target(replay_data.next_observations) + noise
                ).clamp(-1, 1)

                # Compute the next Q-values: min over all critics targets
                next_q_values = th.cat(
                    self.critic_target(replay_data.next_observations, next_actions),
                    dim=1,
                )
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                target_q_values = (
                    replay_data.rewards
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )

            # Get current Q-values estimates for each critic network
            current_q_values = self.critic(
                replay_data.observations, replay_data.actions
            )

            # Compute critic loss
            critic_loss = sum(
                F.mse_loss(current_q, target_q_values) for current_q in current_q_values
            )
            critic_losses.append(critic_loss.item())

            # Optimize the critics
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Delayed policy updates
            if self._n_updates % self.policy_delay == 0:
                # Compute actor loss
                actor_loss = -self.critic.q1_forward(
                    replay_data.observations, self.actor(replay_data.observations)
                ).mean()
                actor_losses.append(actor_loss.item())

                # Optimize the actor
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()

                polyak_update(
                    self.critic.parameters(), self.critic_target.parameters(), self.tau
                )
                polyak_update(
                    self.actor.parameters(), self.actor_target.parameters(), self.tau
                )
                # Copy running stats, see GH issue #996
                polyak_update(
                    self.critic_batch_norm_stats,
                    self.critic_batch_norm_stats_target,
                    1.0,
                )
                polyak_update(
                    self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0
                )

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        if actor_losses:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
        # With zero gradient steps there is no loss to average; skip the
        # record rather than logging NaN.
        if critic_losses:
            self.logger.record("train/critic_loss", np.mean(critic_losses))

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.

        When ``path`` is a string or ``pathlib.Path``, each diagnostic history
        is additionally written next to the archive with ``np.save`` under the
        name ``f"{path}_<suffix>"`` (suffixes: ``std``, ``std_alt``, ``q``,
        ``q_norm``, ``q_fn``, ``act_samp``, ``act_pdf``, ``obs_samp``). When
        ``path`` is a file-like object there is no file name to derive the
        sidecar names from, so they are skipped with a warning.

        :param path: path to the file where the rl agent should be saved
        :param exclude: name of parameters that should be excluded in addition to the default ones
        :param include: name of parameters that might be excluded but should be included anyway
        """
        # Copy parameter list so we don't mutate the original dict
        data = self.__dict__.copy()

        # Exclude is union of specified parameters (if any) and standard exclusions
        if exclude is None:
            exclude = []
        exclude = set(exclude).union(self._excluded_save_params())

        # Do not exclude params if they are specifically included
        if include is not None:
            exclude = exclude.difference(include)

        state_dicts_names, torch_variable_names = self._get_torch_save_params()
        all_pytorch_variables = state_dicts_names + torch_variable_names
        for torch_var in all_pytorch_variables:
            # We need to get only the name of the top most module as we'll remove that
            var_name = torch_var.split(".")[0]
            # Any params that are in the save vars must not be saved by data
            exclude.add(var_name)

        # Remove parameter entries of parameters which are to be excluded
        for param_name in exclude:
            data.pop(param_name, None)

        # Build dict of torch variables
        pytorch_variables = None
        if torch_variable_names is not None:
            pytorch_variables = {
                name: recursive_getattr(self, name) for name in torch_variable_names
            }

        # Build dict of state_dicts
        params_to_save = self.get_parameters()

        # Save custom data
        if isinstance(path, (str, pathlib.Path)):
            histories = {
                "std": self.std_history,
                # Collected in train() alongside the others; previously never
                # written out.
                "std_alt": self.std_alt_history,
                "q": self.q_val_history,
                "q_norm": self.q_val_norm_history,
                "q_fn": self.q_fn_history,
                "act_samp": self.act_samp_history,
                "act_pdf": self.act_pdf_history,
                "obs_samp": self.obs_samp_history,
            }
            for suffix, history in histories.items():
                np.save(f"{path}_{suffix}", history)
        else:
            import warnings

            # Formatting a file object into an f-string would use its repr
            # as a file name, so no sidecar files are written here.
            warnings.warn(
                "Diagnostic histories are only written when `path` is a str or "
                "pathlib.Path; skipping them for file-like objects."
            )

        save_to_zip_file(
            path, data=data, params=params_to_save, pytorch_variables=pytorch_variables
        )
