from typing import Any, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union

import io
import os
import gym
import numpy as np
import pathlib
import torch as th
from torch.nn import functional as F

from stable_baselines3 import SAC
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.save_util import recursive_getattr, save_to_zip_file
from stable_baselines3.common.utils import polyak_update
from stable_baselines3.sac.policies import SACPolicy

from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.type_aliases import (
    GymEnv,
    MaybeCallback,
    RolloutReturn,
    Schedule,
    TrainFreq,
    TrainFrequencyUnit,
)
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps

import matplotlib.pyplot as plt


# Type variable bound to ``CustomSAC`` (forward reference by name, since the
# class is defined below). NOTE(review): it is not referenced in the visible
# part of this file -- presumably meant for annotating methods that return
# ``self``; confirm it is still needed.
SelfCustomSAC = TypeVar("SelfCustomSAC", bound="CustomSAC")


class CustomSAC(SAC):
    def __init__(self, save_path, *args, **kwargs):
        """
        Create the agent and its bookkeeping containers.

        :param save_path: base directory used by ``save_occupancy`` to create
            its ``occupancy`` sub-folder
        :param args: positional arguments forwarded unchanged to ``SAC.__init__``
        :param kwargs: keyword arguments forwarded unchanged to ``SAC.__init__``
        """
        # All custom attributes are assigned before ``super().__init__`` runs,
        # so they already exist while the parent constructor executes.
        self.save_path = save_path
        # Occupancy grid; stays ``None`` until ``collect_rollouts`` fills it.
        self.occupancy = None
        # Diagnostic histories written to disk by ``save``.
        self.std_history, self.q_val_history, self.q_val_norm_history = [], [], []
        # Number of times ``_on_step`` has been called.
        self._n_calls = 0

        super().__init__(*args, **kwargs)

    def _moving_average(self, a, n=1000):
        return np.convolve(a, np.ones(n) / n, mode="valid")[-1]

    def _get_sac_std(self):
        """
        Return the mean action standard deviation for ``self._last_obs``.

        The actor's ``get_action_dist_params`` is queried for the last stored
        observation; its second output (the log std) is exponentiated, moved
        to ``self.device`` and averaged over all elements.

        :return: 0-dim tensor holding the mean standard deviation
        """
        with th.no_grad():
            obs_tensor, _ = self.policy.obs_to_tensor(self._last_obs)
            _mean, log_std, _extra = self.policy.actor.get_action_dist_params(
                obs_tensor
            )
            mean_std = log_std.exp().to(self.device).mean()
        return mean_std

    def _compute_mean_q_values(self, observations):
        """
        Compute summary statistics of the policy on a batch of observations.

        Actions are sampled stochastically from the actor and evaluated with
        the target critics; the element-wise minimum over critics is used as
        the Q-value of each observation.

        :param observations: batch of observations, already as tensors
        :return: tuple of Python floats ``(mean_std, mean_q, mean_q_norm)``
            where ``mean_q_norm`` is the mean of the Q-values after min-max
            normalisation over the batch. The normalisation divides by
            ``max - min``, so it is NaN when all Q-values are equal.
        """
        with th.no_grad():
            actor = self.policy.actor
            _, log_std, _ = actor.get_action_dist_params(observations)
            mean_std = th.exp(log_std).mean()

            sampled_actions = actor(observations, deterministic=False)
            per_critic_q = self.critic_target(observations, sampled_actions)
            # Minimum over the critic ensemble, one value per observation
            min_q = th.cat(per_critic_q, dim=1).min(dim=1).values

            q_low, q_high = min_q.min(), min_q.max()
            min_q_norm = (min_q - q_low) / (q_high - q_low)

        return mean_std.item(), min_q.mean().item(), min_q_norm.mean().item()

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        self._n_calls += 1

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Pick the next action(s) for exploration.

        During warm-up (``num_timesteps < learning_starts``, unless SDE is
        enabled together with ``use_sde_at_warmup``) actions are drawn uniformly
        from the action space; afterwards they come from a non-deterministic
        ``self.predict`` call on ``self._last_obs``. For ``Box`` action spaces
        the action is rescaled to [-1, 1], optionally perturbed by
        ``action_noise`` and clipped, then unscaled again for the environment.

        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param action_noise: Optional noise added to the scaled action
            (only used for ``Box`` action spaces).
        :param n_envs: Number of random actions to draw during warm-up.
        :return: ``(action, buffer_action)``: the action to take in the
            environment and the action to store in the replay buffer. For
            ``Box`` spaces the latter is the scaled (and possibly noisy) action;
            otherwise both are the same array.
        """
        random_warmup = self.num_timesteps < learning_starts and not (
            self.use_sde and self.use_sde_at_warmup
        )

        if random_warmup:
            # Warm-up: uniform samples, one per environment
            unscaled_action = np.array(
                [self.action_space.sample() for _ in range(n_envs)]
            )
        else:
            # The policy is assumed to squash continuous actions with tanh;
            # SAC explores by sampling, hence deterministic=False.
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

        if not isinstance(self.action_space, gym.spaces.Box):
            # Discrete case, no need to normalize or clip
            return unscaled_action, unscaled_action

        # Map the action from [low, high] to [-1, 1]
        scaled_action = self.policy.scale_action(unscaled_action)

        if action_noise is not None:
            # Extra exploration noise, kept inside the normalized bounds
            scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

        # The scaled action goes to the buffer, the unscaled one to the env
        action = self.policy.unscale_action(scaled_action)

        # ---- Disabled experiment: action transformation -------------------
        # # action = action + 0.5 * np.sin(2 * np.pi * action)
        # # action = action - 0.5 - np.floor(2 * action)
        # if self.num_timesteps == 50000:
        #     self.replay_buffer = self.replay_buffer_class(
        #         self.buffer_size,
        #         self.observation_space,
        #         self.action_space,
        #         device=self.device,
        #         n_envs=self.n_envs,
        #         optimize_memory_usage=self.optimize_memory_usage,
        #         **self.replay_buffer_kwargs,
        #     )
        # if self.num_timesteps >= 50000:
        #     action = -action
        return action, scaled_action

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """
        Run up to ``gradient_steps`` SAC updates on batches from the replay buffer.

        Each performed step updates, in order: the entropy coefficient (when it
        is learned), the critics, the actor and - every
        ``target_update_interval`` steps - the target critics.

        A gradient step is skipped while the replay buffer holds fewer than
        ``batch_size`` transitions. Only steps that were actually performed are
        added to ``self._n_updates``, and the mean losses are logged only when
        at least one step was performed (otherwise the means would be taken
        over empty lists and come out as NaN with a NumPy warning).

        :param gradient_steps: maximum number of gradient steps to perform
        :param batch_size: number of transitions sampled per gradient step
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        # ---- Disabled diagnostics (would fill the *_history lists) --------
        # if self._n_calls % 1000 == 0:
        #     replay_data = self.replay_buffer.sample(
        #         batch_size, env=self._vec_normalize_env
        #     )
        #     mean_std, mean_q, mean_q_norm = self._compute_mean_q_values(
        #         replay_data.observations
        #     )
        #     print(mean_std, mean_q, mean_q_norm)
        #     self.std_history.append(mean_std)
        #     self.q_val_history.append(mean_q)
        #     self.q_val_norm_history.append(mean_q_norm)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []
        performed_updates = 0

        for gradient_step in range(gradient_steps):
            # Not enough data for a full batch yet: skip this step
            if self.replay_buffer.size() < batch_size:
                continue

            # Sample replay buffer
            replay_data = self.replay_buffer.sample(
                batch_size, env=self._vec_normalize_env
            )

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(
                replay_data.observations
            )
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(
                    self.log_ent_coef * (log_prob + self.target_entropy).detach()
                ).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(
                    replay_data.next_observations
                )
                # Compute the next Q values: min over all critics targets
                next_q_values = th.cat(
                    self.critic_target(replay_data.next_observations, next_actions),
                    dim=1,
                )
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # add entropy term
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(
                    -1, 1
                )
                # td error + entropy term
                target_q_values = (
                    replay_data.rewards
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )

            # Get current Q-values estimates for each critic network
            # using action from the replay buffer
            current_q_values = self.critic(
                replay_data.observations, replay_data.actions
            )

            # Compute critic loss
            critic_loss = 0.5 * sum(
                F.mse_loss(current_q, target_q_values)
                for current_q in current_q_values
            )
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
            # Min over all critic networks
            q_values_pi = th.cat(
                self.critic(replay_data.observations, actions_pi), dim=1
            )
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(
                    self.critic.parameters(),
                    self.critic_target.parameters(),
                    self.tau,
                )
                # Copy running stats, see GH issue #996
                polyak_update(
                    self.batch_norm_stats, self.batch_norm_stats_target, 1.0
                )

            performed_updates += 1

        # Count only the updates that really happened
        self._n_updates += performed_updates

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        if performed_updates > 0:
            # The three lists below get exactly one entry per performed update
            self.logger.record("train/ent_coef", np.mean(ent_coefs))
            self.logger.record("train/actor_loss", np.mean(actor_losses))
            self.logger.record("train/critic_loss", np.mean(critic_losses))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def save(
        self,
        path: Union[str, pathlib.Path, io.BufferedIOBase],
        exclude: Optional[Iterable[str]] = None,
        include: Optional[Iterable[str]] = None,
    ) -> None:
        """
        Save all the attributes of the object and the model parameters in a zip-file.

        In addition to the zip archive, four NumPy side files are written next
        to it when ``path`` is a string or path-like object:
        ``{path}_std.npy``, ``{path}_q.npy``, ``{path}_q_norm.npy`` and
        ``{path}_occupancy.npy``. When ``path`` is a file-like object no file
        name can be derived from it, so the side files are skipped and a
        warning is emitted (formatting the object into a file name would only
        produce a name based on its ``repr``).

        :param path: path to the file where the rl agent should be saved
        :param exclude: name of parameters that should be excluded in addition to the default ones
        :param include: name of parameters that might be excluded but should be included anyway
        """
        # Copy parameter list so we don't mutate the original dict
        data = self.__dict__.copy()

        # Exclude is union of specified parameters (if any) and standard exclusions
        if exclude is None:
            exclude = []
        exclude = set(exclude).union(self._excluded_save_params())

        # Do not exclude params if they are specifically included
        if include is not None:
            exclude = exclude.difference(include)

        state_dicts_names, torch_variable_names = self._get_torch_save_params()
        # Anything stored as a torch variable / state dict must not be saved
        # again through ``data``; only the top-most attribute name matters.
        exclude.update(
            torch_var.split(".")[0]
            for torch_var in state_dicts_names + torch_variable_names
        )

        # Remove parameter entries of parameters which are to be excluded
        for param_name in exclude:
            data.pop(param_name, None)

        # Build dict of torch variables
        pytorch_variables = None
        if torch_variable_names is not None:
            pytorch_variables = {
                name: recursive_getattr(self, name) for name in torch_variable_names
            }

        # Build dict of state_dicts
        params_to_save = self.get_parameters()

        # Save custom data next to the archive
        if isinstance(path, (str, os.PathLike)):
            side_files = {
                "std": self.std_history,
                "q": self.q_val_history,
                "q_norm": self.q_val_norm_history,
                "occupancy": self.occupancy,
            }
            for suffix, values in side_files.items():
                np.save(f"{path}_{suffix}", values)
        else:
            import warnings

            warnings.warn(
                "CustomSAC.save() received a file-like object: the "
                "_std/_q/_q_norm/_occupancy .npy side files were not written."
            )

        save_to_zip_file(
            path, data=data, params=params_to_save, pytorch_variables=pytorch_variables
        )

    def save_occupancy(self, episode):
        occupancy_folder = os.path.join(self.save_path, 'occupancy')
        os.makedirs(occupancy_folder, exist_ok=True)
        # Save the file in the occupancy folder
        np.save(os.path.join(occupancy_folder, f"{episode}.npy"), self.occupancy)

    def collect_rollouts(
        self,
        env: VecEnv,
        callback,
        train_freq: TrainFreq,
        replay_buffer: ReplayBuffer,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        On top of the transition collection, every step adds the result of
        ``counter(...)`` for ``self._last_obs`` to ``self.occupancy``, and the
        occupancy grid is written to disk through ``save_occupancy`` each time
        the episode counter reaches a multiple of 10.

        NOTE: the local variables of this method are handed to the callback
        through ``callback.update_locals(locals())``, so local names are
        visible to callbacks and should not be renamed casually.

        :param env: The training environment (must be a ``VecEnv``)
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param replay_buffer: Buffer the collected transitions are stored in
        :param log_interval: Log data every ``log_interval`` episodes
        :return: ``RolloutReturn`` built from the number of collected steps
            (multiplied by ``env.num_envs``), the number of finished episodes
            and a flag that is ``False`` only when the callback asked to stop.
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if env.num_envs > 1:
            assert (
                train_freq.unit == TrainFrequencyUnit.STEP
            ), "You must use only one env when doing episodic training."

        # Vectorize action noise if needed
        if (
            action_noise is not None
            and env.num_envs > 1
            and not isinstance(action_noise, VectorizedActionNoise)
        ):
            action_noise = VectorizedActionNoise(action_noise, env.num_envs)

        if self.use_sde:
            self.actor.reset_noise(env.num_envs)

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(
            train_freq, num_collected_steps, num_collected_episodes
        ):
            if (
                self.use_sde
                and self.sde_sample_freq > 0
                and num_collected_steps % self.sde_sample_freq == 0
            ):
                # Sample a new noise matrix
                self.actor.reset_noise(env.num_envs)

            # Select action randomly or according to policy
            actions, buffer_actions = self._sample_action(
                learning_starts, action_noise, env.num_envs
            )

            # Rescale and perform action
            new_obs, rewards, dones, infos = env.step(actions)
            # print("step done")
            # img = env.render("rgb_array")

            # plt.imshow(img)
            # plt.show()

            # Occupancy bookkeeping: count the observation held in
            # ``self._last_obs`` (not ``new_obs``). ``self._last_obs`` is not
            # reassigned in this method before this point -- presumably it is
            # still the observation the action was computed from; confirm
            # against ``_store_transition``.
            cur_obs, _ = self.policy.obs_to_tensor(self._last_obs)
            if isinstance(self.occupancy, np.ndarray):
                # Grid already exists: accumulate in place
                self.occupancy += counter(cur_obs, self.device)
            else:
                # First step (occupancy is still None): start the grid
                self.occupancy = counter(cur_obs, self.device)

            self.num_timesteps += env.num_envs
            num_collected_steps += 1

            # Give access to local variables
            callback.update_locals(locals())
            # Only stop training if return value is False, not when it is None.
            if callback.on_step() is False:
                return RolloutReturn(
                    num_collected_steps * env.num_envs,
                    num_collected_episodes,
                    continue_training=False,
                )

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, dones)

            # Store data in replay buffer (normalized action and unnormalized observation)
            self._store_transition(
                replay_buffer, buffer_actions, new_obs, rewards, dones, infos
            )

            self._update_current_progress_remaining(
                self.num_timesteps, self._total_timesteps
            )

            # For DQN, check if the target network should be updated
            # and update the exploration schedule
            # For SAC/TD3, the update is done at the same time as the gradient update
            # see https://github.com/hill-a/stable-baselines/issues/900
            # In this class ``_on_step`` only increments ``self._n_calls``.
            self._on_step()

            for idx, done in enumerate(dones):
                if done:
                    # Update stats
                    num_collected_episodes += 1
                    self._episode_num += 1

                    if action_noise is not None:
                        # With several envs only the noise of the finished env is reset
                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                        action_noise.reset(**kwargs)

                    # Log training infos
                    if (
                        log_interval is not None
                        and self._episode_num % log_interval == 0
                    ):
                        self._dump_logs()

                    # Every 10th episode (hard-coded period), dump the grid to disk
                    if self._episode_num % 10 == 0:
                        # Save occupancy
                        self.save_occupancy(self._episode_num)

        callback.on_rollout_end()

        return RolloutReturn(
            num_collected_steps * env.num_envs,
            num_collected_episodes,
            continue_training,
        )


def build_anchor(size, device):
    """
    Build the grid of unit-cell centres covering ``[-size, size]`` on both axes.

    The result has shape ``(2 * size, 2 * size, 2)``. Entry ``[i, j]`` holds
    ``(y[j], x[i])`` where ``y`` runs over ``-size + 0.5, ..., size - 0.5``
    (increasing with ``j``) and ``x`` runs over ``size - 0.5, ..., -size + 0.5``
    (decreasing with ``i``).

    The tensor is created directly on ``device``. It is no longer moved with
    an unconditional ``.cuda()`` call, which ignored ``device`` and raised on
    machines without CUDA.

    :param size: half-width of the grid, in cells
    :param device: torch device (or device string) on which to build the grid
    :return: float tensor of cell centres located on ``device``
    """
    y = th.arange(-size, size, device=device) + 0.5
    x = th.arange(size, -size, -1.0, device=device) - 0.5
    x, y = th.meshgrid(x, y, indexing="ij")
    return th.stack([y, x], dim=-1)


def counter(obs, device):
    """
    Count, for every cell of a 10 x 10 anchor grid, how many observations hit it.

    ``obs`` is flattened to ``(-1, obs.shape[-1])``; an observation hits a cell
    when its first two coordinates are each strictly less than 0.5 away from
    the cell centre given by ``build_anchor(5, device)``.

    :param obs: tensor of observations; its last dimension must broadcast
        against the two coordinates of the anchor grid
    :param device: device passed on to ``build_anchor``
    :return: float NumPy array with one hit count per grid cell
    """
    grid = build_anchor(5, device)
    points = obs.reshape(-1, obs.shape[-1])
    # Per-cell, per-observation absolute offset on each coordinate
    offsets = (points[None, None, :, :] - grid[:, :, None, :]).abs()
    inside = (offsets[..., 0] < 0.5) & (offsets[..., 1] < 0.5)
    hits = inside.sum(dim=-1)
    return hits.float().detach().cpu().numpy()
