import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from stable_baselines3 import TD3
from custom_td3 import CustomTD3
from stable_baselines3.td3.policies import TD3Policy, Actor
from typing import Any, ClassVar, Optional, TypeVar, Union, Tuple
import numpy as np
import torch as th
from torch import nn
from gymnasium import spaces
from torch.nn import functional as F
import math
from copy import deepcopy
import warnings
from typing import Any, ClassVar, Optional, TypeVar, Union, NamedTuple
from stable_baselines3.common.buffers import ReplayBuffer, BaseBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, PyTorchObs
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    CombinedExtractor,
    FlattenExtractor,
    NatureCNN,
    create_mlp,
    get_actor_critic_arch,
)
from stable_baselines3.common.vec_env import VecNormalize
try:
    import psutil
except ImportError:
    psutil = None
class GradBuffer_Samples(NamedTuple):
    """One mini-batch sampled from ``GradBuffer``, converted to torch tensors.

    Field order is significant: instances may be built positionally, so do not
    reorder the fields below.
    """

    # Observations o_t at which the stored actions were taken.
    observations: th.Tensor
    # Actions a_t stored for each sampled transition.
    actions: th.Tensor
    # Observations o_{t-1} from the step before o_t (used for temporal smoothness).
    prev_observations: th.Tensor
    # Observations o_{t+1} reached after taking a_t.
    next_observations: th.Tensor
    # Termination flags with time-limit truncations masked out; reshaped to (batch, 1) by the buffer.
    dones: th.Tensor
    # Rewards (normalized when a VecNormalize env is passed to the buffer); reshaped to (batch, 1).
    rewards: th.Tensor
class GradBuffer(BaseBuffer):
    """Replay buffer that also keeps the observation *preceding* each transition.

    Every entry holds ``(obs, prev_obs, next_obs, action, reward, done)`` so that a
    single sample gives three consecutive observations. Arrays are laid out as
    ``(buffer_size, n_envs, ...)`` like SB3's ``ReplayBuffer``.

    With ``optimize_memory_usage=True`` the previous/next observations are not stored
    in separate arrays; they are read from the neighbouring slots of ``observations``.
    """

    observations: np.ndarray
    prev_observations: np.ndarray
    next_observations: np.ndarray
    actions: np.ndarray
    rewards: np.ndarray
    dones: np.ndarray
    timeouts: np.ndarray

    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        """Allocate all storage arrays up front.

        Args:
            buffer_size: Total capacity; split evenly across ``n_envs`` (at least 1 row).
            observation_space: Space the observations come from (its dtype is kept).
            action_space: Space the actions come from (float64 is stored as float32).
            device: Torch device samples are moved to.
            n_envs: Number of parallel environments writing into the buffer.
            optimize_memory_usage: Share a single observation array for prev/obs/next.
            handle_timeout_termination: Record ``TimeLimit.truncated`` so truncations
                are not treated as terminal when sampling.

        Raises:
            ValueError: If both ``optimize_memory_usage`` and
                ``handle_timeout_termination`` are enabled.
        """
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
        self.buffer_size = max(buffer_size // n_envs, 1)
        # Query free RAM before the arrays below are allocated.
        mem_available = psutil.virtual_memory().available if psutil is not None else None

        if optimize_memory_usage and handle_timeout_termination:
            raise ValueError(
                "ReplayBuffer does not support optimize_memory_usage = True "
                "and handle_timeout_termination = True simultaneously."
            )
        self.optimize_memory_usage = optimize_memory_usage

        obs_rows = (self.buffer_size, self.n_envs, *self.obs_shape)
        step_rows = (self.buffer_size, self.n_envs)
        self.observations = np.zeros(obs_rows, dtype=observation_space.dtype)
        if not optimize_memory_usage:
            self.next_observations = np.zeros(obs_rows, dtype=observation_space.dtype)
            self.prev_observations = np.zeros(obs_rows, dtype=observation_space.dtype)
        action_dtype = self._maybe_cast_dtype(action_space.dtype)
        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_dtype)
        self.rewards = np.zeros(step_rows, dtype=np.float32)
        self.dones = np.zeros(step_rows, dtype=np.float32)
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros(step_rows, dtype=np.float32)

        if psutil is not None:
            # Best-effort check only: warn, never fail, when the buffer looks too large.
            tracked = [self.observations, self.actions, self.rewards, self.dones]
            if not optimize_memory_usage:
                tracked.extend((self.next_observations, self.prev_observations))
            total_memory_usage: float = sum(array.nbytes for array in tracked)
            if total_memory_usage > mem_available:
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        prev_obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> None:
        """Write one transition per env into the slot at ``self.pos`` and advance it.

        Args:
            obs: Observation o_t for each env.
            prev_obs: Observation o_{t-1} for each env.
            next_obs: Observation o_{t+1} for each env.
            action: Action a_t, reshaped to ``(n_envs, action_dim)``.
            reward: Reward for each env.
            done: Done flag for each env.
            infos: Per-env info dicts; ``"TimeLimit.truncated"`` is read when
                timeout handling is enabled.
        """
        if isinstance(self.observation_space, spaces.Discrete):
            per_env_shape = (self.n_envs, *self.obs_shape)
            obs, next_obs, prev_obs = (o.reshape(per_env_shape) for o in (obs, next_obs, prev_obs))
        action = action.reshape((self.n_envs, self.action_dim))

        slot = self.pos
        self.observations[slot] = np.array(obs)
        if self.optimize_memory_usage:
            # Neighbouring slots double as next/prev storage; write order is kept
            # (obs, then next, then prev) because the slots coincide for tiny buffers.
            self.observations[(slot + 1) % self.buffer_size] = np.array(next_obs)
            self.observations[(slot - 1) % self.buffer_size] = np.array(prev_obs)
        else:
            self.next_observations[slot] = np.array(next_obs)
            self.prev_observations[slot] = np.array(prev_obs)
        self.actions[slot] = np.array(action)
        self.rewards[slot] = np.array(reward)
        self.dones[slot] = np.array(done)
        if self.handle_timeout_termination:
            self.timeouts[slot] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        # Ring-buffer advance: wrapping back to slot 0 means every slot has been written.
        self.pos = (slot + 1) % self.buffer_size
        if self.pos == 0:
            self.full = True

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> GradBuffer_Samples:
        """Draw ``batch_size`` transitions uniformly at random.

        Args:
            batch_size: Number of transitions to draw.
            env: Optional VecNormalize wrapper used to normalize obs/rewards.

        Returns:
            A ``GradBuffer_Samples`` batch of torch tensors.
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        if not self.full:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        else:
            # Offsets start at 1, so the slot at self.pos is never drawn: in the shared
            # layout it already holds the newest next_obs rather than its own obs.
            offsets = np.random.randint(1, self.buffer_size, size=batch_size)
            batch_inds = (offsets + self.pos) % self.buffer_size
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> GradBuffer_Samples:
        """Gather the rows at ``batch_inds`` (one random env per row) as tensors."""
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))
        if self.optimize_memory_usage:
            raw_next = self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :]
            raw_prev = self.observations[(batch_inds - 1) % self.buffer_size, env_indices, :]
        else:
            raw_next = self.next_observations[batch_inds, env_indices, :]
            raw_prev = self.prev_observations[batch_inds, env_indices, :]
        next_obs = self._normalize_obs(raw_next, env)
        prev_obs = self._normalize_obs(raw_prev, env)
        obs = self._normalize_obs(self.observations[batch_inds, env_indices, :], env)
        actions = self.actions[batch_inds, env_indices, :]
        # Time-limit truncations are zeroed out of `dones` so they do not look terminal.
        not_timeout = 1 - self.timeouts[batch_inds, env_indices]
        dones = (self.dones[batch_inds, env_indices] * not_timeout).reshape(-1, 1)
        rewards = self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env)
        return GradBuffer_Samples(
            observations=self.to_torch(obs),
            actions=self.to_torch(actions),
            prev_observations=self.to_torch(prev_obs),
            next_observations=self.to_torch(next_obs),
            dones=self.to_torch(dones),
            rewards=self.to_torch(rewards),
        )

    @staticmethod
    def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
        """Map float64 to float32; every other dtype is returned unchanged."""
        return np.float32 if dtype == np.float64 else dtype
class ASAPActor(Actor):
    """TD3 actor with a second head that predicts the action for the next step.

    A shared MLP trunk (``latent_pi``) feeds two tanh-squashed linear heads:
    ``mu`` produces the action for the given observation (``forward``), and
    ``mu_next`` produces the next-action prediction (``predict_next``).
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        net_arch: list[int],
        features_extractor: nn.Module,
        features_dim: int,
        activation_fn: type[nn.Module] = nn.ReLU,
        normalize_images: bool = True,
    ):
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            net_arch=net_arch,
            features_extractor=features_extractor,
            features_dim=features_dim,
            activation_fn=activation_fn,
            normalize_images=normalize_images,
        )
        self.net_arch = net_arch
        self.features_dim = features_dim
        self.activation_fn = activation_fn

        n_actions = get_action_dim(self.action_space)
        # output_dim=-1: the trunk stops at the last hidden layer; heads are added below.
        self.latent_pi = nn.Sequential(*create_mlp(features_dim, -1, net_arch, activation_fn))
        head_in = features_dim if not net_arch else net_arch[-1]
        # Build `mu` before `mu_next`: this order fixes the weight-init RNG draws and
        # the parameter order seen by the optimizer and target updates.
        self.mu = self._tanh_head(head_in, n_actions)
        self.mu_next = self._tanh_head(head_in, n_actions)

    @staticmethod
    def _tanh_head(in_dim: int, out_dim: int) -> nn.Sequential:
        """Return a linear layer followed by tanh, bounding outputs to [-1, 1]."""
        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.Tanh())

    def _trunk(self, obs: PyTorchObs) -> th.Tensor:
        """Run the features extractor and the shared hidden layers."""
        return self.latent_pi(self.extract_features(obs, self.features_extractor))

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Return the (tanh-bounded) action for ``obs``."""
        return self.mu(self._trunk(obs))

    def predict_next(self, obs: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """Return the ``mu_next`` head's prediction from ``obs``.

        ``deterministic`` is accepted for API symmetry and ignored.
        """
        return self.mu_next(self._trunk(obs))
class ASAPPolicy(TD3Policy):
    """TD3 policy using ``ASAPActor`` plus a slowly-updated copy of that actor.

    ``next_actor_target`` is a frozen deep copy of the actor. ``_polyak_update_targets``
    moves it towards the live actor, and its ``mu_next`` head supplies target
    next-action predictions.
    """

    actor: ASAPActor

    def _build(self, lr_schedule: Schedule) -> None:
        """Build the standard TD3 networks, then the frozen predictor target network."""
        super()._build(lr_schedule=lr_schedule)
        self.next_actor_target = deepcopy(self.actor)
        for p in self.next_actor_target.parameters():
            p.requires_grad = False
        # Keep this target in eval mode, as SB3 does for its other target networks,
        # so layers such as BatchNorm/Dropout (if configured) are not in training mode.
        self.next_actor_target.set_training_mode(False)

    def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ASAPActor:
        """Create an ``ASAPActor`` on the policy's device from the stored actor kwargs."""
        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
        return ASAPActor(**actor_kwargs).to(self.device)

    def _predict_next(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """Next-action prediction from the live actor (gradients flow)."""
        return self.actor.predict_next(observation, deterministic)

    def _predict_next_target(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """Next-action prediction from the frozen target actor, computed without autograd."""
        with th.no_grad():
            return self.next_actor_target.predict_next(observation, deterministic)

    def _polyak_update_targets(self, tau: float) -> None:
        """Move ``next_actor_target`` towards the live actor.

        Applies ``target <- (1 - tau) * target + tau * online`` to each parameter pair,
        using the same SB3 helper as the TD3 target updates.

        Args:
            tau: Interpolation coefficient in [0, 1].
        """
        polyak_update(self.actor.parameters(), self.next_actor_target.parameters(), tau)
class ASAPTD3(CustomTD3):
    """TD3 with ASAP action-smoothness regularization of the actor.

    Each delayed actor update adds three weighted terms to the usual TD3 actor loss:

    * temporal (``asap_lamT``): squared second difference of the policy's actions over
      ``(prev_obs, obs, next_obs)``, down-weighted where ``a(next_obs)`` and
      ``a(prev_obs)`` differ a lot;
    * spatial (``asap_lamS``): pulls ``a(next_obs)`` towards the predictor target's
      next-action guess from ``obs``;
    * predictor (``asap_lamP``): trains the actor's ``mu_next`` head to predict
      ``a(next_obs)`` from ``obs``.

    The replay buffer is always ``GradBuffer`` so the previous observation is available.
    """

    policy: ASAPPolicy

    def __init__(
        self,
        policy: Union[str, type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-3,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        asap_lamT: float = 0.1,
        asap_lamS: float = 0.5,
        asap_lamP: float = 0.1,
        asap_tau: float = 0.01,
        asap_target_update_interval: int = 1,
    ):
        """Store the ASAP hyper-parameters and forward everything else to ``CustomTD3``.

        ``replay_buffer_class`` is accepted for signature compatibility, but
        ``_setup_model`` always replaces it with ``GradBuffer``.

        Args:
            asap_lamT: Weight of the temporal smoothness loss.
            asap_lamS: Weight of the spatial smoothness loss.
            asap_lamP: Weight of the next-action predictor loss.
            asap_tau: Polyak coefficient for the predictor target network.
            asap_target_update_interval: Update the predictor target every this many
                gradient steps.
        """
        self.asap_lamT = asap_lamT
        self.asap_lamS = asap_lamS
        self.asap_lamP = asap_lamP
        self.asap_tau = asap_tau
        self.asap_target_update_interval = asap_target_update_interval
        # Observation history one step further back than `_last_obs`. Initialised here
        # so `_setup_learn` can test it even on the very first `learn(reset_num_timesteps=False)`
        # (otherwise the attribute would not exist yet and raise AttributeError).
        self._last2_obs = None
        self._last2_original_obs = None
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            policy_delay=policy_delay,
            target_policy_noise=target_policy_noise,
            target_noise_clip=target_noise_clip,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

    def _setup_model(self):
        """Force ``GradBuffer`` (which stores prev_obs) before the base model setup."""
        self.replay_buffer_class = GradBuffer
        return super()._setup_model()

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` TD3 updates with the ASAP actor regularizers.

        Every step updates the twin critics. Every ``policy_delay``-th step also updates
        the actor (TD3 objective plus the weighted ASAP terms) and Polyak-averages the
        critic/actor targets with ``tau``. Every ``asap_target_update_interval``-th step
        moves the predictor target towards the live actor with ``asap_tau``.

        Args:
            gradient_steps: Number of gradient steps to take.
            batch_size: Mini-batch size sampled from the replay buffer per step.
        """
        self.policy.set_training_mode(True)
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
        actor_losses, critic_losses = [], []
        asap_losses = []
        for _ in range(gradient_steps):
            self._n_updates += 1
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
            with th.no_grad():
                # Target policy smoothing: clipped Gaussian noise on the target action.
                noise = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
                noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)
                # Bootstrap from the smaller of the two target critics' estimates.
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
            current_q_values = self.critic(replay_data.observations, replay_data.actions)
            critic_loss = sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            assert isinstance(critic_loss, th.Tensor)
            critic_losses.append(critic_loss.item())
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()
            if self._n_updates % self.policy_delay == 0:
                temporal_loss, spatial_loss, predictor_loss = self._asap_losses(replay_data)
                asap_losses.append(spatial_loss.item())
                q_term = -self.critic.q1_forward(
                    replay_data.observations, self.actor(replay_data.observations)
                ).mean()
                actor_loss = (
                    q_term
                    + temporal_loss * self.asap_lamT
                    + spatial_loss * self.asap_lamS
                    + predictor_loss * self.asap_lamP
                )
                actor_losses.append(actor_loss.item())
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
                # Batch-norm running statistics are copied verbatim (tau = 1.0).
                polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
                polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)
            if self._n_updates % self.asap_target_update_interval == 0:
                self.policy._polyak_update_targets(self.asap_tau)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        if actor_losses:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        # ASAP losses exist only on actor-update steps; a call may contain none
        # (e.g. gradient_steps < policy_delay), and np.mean([]) would log NaN.
        if asap_losses:
            self.logger.record("train/asap_loss", np.mean(asap_losses))

    def _asap_losses(self, replay_data: GradBuffer_Samples) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Compute the three ASAP regularizers for one replay batch.

        Args:
            replay_data: Batch providing prev/current/next observations.

        Returns:
            ``(temporal_loss, spatial_loss, predictor_loss)`` as scalar tensors. The
            first two back-propagate into the actor's action head. The last one
            back-propagates into the ``mu_next`` head and the shared trunk.
        """
        prev_actions = self.policy._predict(replay_data.prev_observations, deterministic=True).type(th.float32)
        now_actions = self.policy._predict(replay_data.observations, deterministic=True).type(th.float32)
        # Computed once and reused below; autograd accumulates gradients from every use.
        next_actions = self.policy._predict(replay_data.next_observations, deterministic=True).type(th.float32)

        # Temporal term: squared second difference of consecutive actions.
        second_diff = 0.5 * ((2 * now_actions - next_actions - prev_actions) ** 2)
        # Weight tanh(1/delta^2) -> 1 when a(next) ~ a(prev) and -> 0 as they diverge;
        # even if delta hits exactly 0 the result is tanh(inf) = 1, not NaN.
        delta = next_actions - prev_actions + 1e-4
        weight = th.tanh((1 / delta) ** 2).detach()
        temporal_loss = th.mean(second_diff * weight)

        # Spatial term: match a(next_obs) to the frozen predictor target's guess from obs.
        predicted_next_target = self.policy._predict_next_target(replay_data.observations, True).detach()
        spatial_loss = 0.5 * F.mse_loss(next_actions, predicted_next_target)

        # Predictor term: train mu_next(obs) to match the (detached) actual a(next_obs).
        predicted_next = self.policy._predict_next(replay_data.observations, True)
        predictor_loss = 0.5 * F.mse_loss(predicted_next, next_actions.detach())
        return temporal_loss, spatial_loss, predictor_loss

    def _store_transition(
        self,
        replay_buffer: GradBuffer,
        buffer_action: np.ndarray,
        new_obs: Union[np.ndarray, dict[str, np.ndarray]],
        reward: np.ndarray,
        dones: np.ndarray,
        infos: list[dict[str, Any]],
    ) -> None:
        """Store ``(o_t, o_{t-1}, o_{t+1}, a_t, r, done)`` and shift the obs history.

        With a VecNormalize wrapper, unnormalized observations and rewards are stored.
        For envs that just finished, ``next_obs`` is taken from
        ``infos[i]["terminal_observation"]`` instead of ``new_obs``.

        Args:
            replay_buffer: Buffer receiving the transition.
            buffer_action: Action as stored in the buffer.
            new_obs: Observation returned by the env step.
            reward: Reward returned by the env step.
            dones: Done flags returned by the env step.
            infos: Info dicts returned by the env step.
        """
        if self._vec_normalize_env is not None:
            new_obs_ = self._vec_normalize_env.get_original_obs()
            reward_ = self._vec_normalize_env.get_original_reward()
        else:
            self._last2_original_obs = self._last2_obs
            self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward
        next_obs = deepcopy(new_obs_)
        for i, done in enumerate(dones):
            if done and infos[i].get("terminal_observation") is not None:
                if isinstance(next_obs, dict):
                    next_obs_ = infos[i]["terminal_observation"]
                    if self._vec_normalize_env is not None:
                        next_obs_ = self._vec_normalize_env.unnormalize_obs(next_obs_)
                    for key in next_obs.keys():
                        next_obs[key][i] = next_obs_[key]
                else:
                    next_obs[i] = infos[i]["terminal_observation"]
                    if self._vec_normalize_env is not None:
                        next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])
        replay_buffer.add(
            self._last_original_obs,
            self._last2_original_obs,
            next_obs,
            buffer_action,
            reward_,
            dones,
            infos,
        )
        # NOTE(review): the history is shifted even for envs that just reset, so the first
        # transition of a new episode gets a prev_obs from the end of the previous episode.
        # Confirm this is acceptable for the temporal smoothness term.
        self._last2_obs = self._last_obs
        self._last_obs = new_obs
        if self._vec_normalize_env is not None:
            self._last2_original_obs = self._last_original_obs
            self._last_original_obs = new_obs_

    def _setup_learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "run",
        progress_bar: bool = False,
    ) -> tuple[int, BaseCallback]:
        """Run the base setup, then seed the t-1 observation history.

        When starting fresh (``reset_num_timesteps``) or when no history exists yet,
        the t-1 observations are set to copies of the current ones. The first stored
        transition therefore has ``prev_obs == obs``.
        """
        total_timesteps, callback = super()._setup_learn(
            total_timesteps,
            callback,
            reset_num_timesteps,
            tb_log_name,
            progress_bar,
        )
        if reset_num_timesteps or self._last2_obs is None:
            assert self.env is not None
            if self._last_obs is not None:
                self._last2_obs = deepcopy(self._last_obs)
            if self._last_original_obs is not None:
                self._last2_original_obs = deepcopy(self._last_original_obs)
        return total_timesteps, callback