from stable_baselines3 import TD3
from typing import Any, ClassVar, Optional, TypeVar, Union, List, Tuple
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
import numpy as np
import torch as th
from torch.nn import functional as F
from gymnasium import spaces
from torch import nn
from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps, polyak_update, get_parameters_by_name
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.td3.policies import Actor, CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
# Self-type used by ``PaveTD3.learn`` (``self: SelfTD3 -> SelfTD3``) so that the
# return value is typed as the caller's own (sub)class.
SelfTD3 = TypeVar("SelfTD3", bound="TD3")
class PaveTD3(TD3):
    """TD3 whose critic loss is augmented with action-gradient ("Q-flow") regularizers.

    On top of the clipped double-Q TD loss, every critic update adds
    ``grad_lamS * mpr + grad_lamT * vfc + grad_lamC * curv`` where, for each of the
    two critics:

    * ``mpr``: MSE between dQ/da at the sampled observations and at the same
      observations perturbed with Gaussian noise of std ``grad_sigma``;
    * ``vfc``: MSE between dQ/da at ``obs`` and at ``next_obs`` (same actions);
    * ``curv``: mean of ``relu(tr_est(H) + grad_delta)`` where ``tr_est(H)`` is a
      one-probe Hutchinson estimate of the trace of the action-Hessian of Q.

    During rollouts the pre-noise action of every env is recorded and, at episode
    end, the mean L2 norm of consecutive action differences is logged as
    ``train/oscillation``.
    """

    def __init__(
        self,
        policy: Union[str, type[TD3Policy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-3,
        buffer_size: int = 1_000_000,
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        policy_delay: int = 2,
        target_policy_noise: float = 0.2,
        target_noise_clip: float = 0.5,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        grad_lamT: float = 0.1,
        grad_lamS: float = 0.1,
        grad_lamC: float = 0.01,
        grad_sigma: float = 0.01,
        grad_delta: float = 1.0,
    ):
        """Create the model; all non-``grad_*`` arguments are forwarded to ``TD3``.

        :param grad_lamT: weight of the ``vfc`` loss (dQ/da at obs vs. next obs).
        :param grad_lamS: weight of the ``mpr`` loss (dQ/da at obs vs. noisy obs).
        :param grad_lamC: weight of the Hessian-trace hinge (``curv``) loss.
        :param grad_sigma: std of the Gaussian observation noise used by ``mpr``.
        :param grad_delta: hinge margin; ``curv`` penalises ``trace + grad_delta > 0``.

        If ``policy_kwargs`` does not set ``activation_fn``, ``nn.SiLU`` is used.
        The caller's ``policy_kwargs`` dict is copied, never mutated.
        """
        # Stored before super().__init__() so they exist before any model setup.
        self.grad_lamT = grad_lamT
        self.grad_lamS = grad_lamS
        self.grad_lamC = grad_lamC
        self.grad_sigma = grad_sigma
        self.grad_delta = grad_delta
        # Copy so the default below does not leak into the caller's dict.
        policy_kwargs = {} if policy_kwargs is None else dict(policy_kwargs)
        # A smooth activation gives the critic a non-zero action-Hessian (with ReLU
        # it is zero almost everywhere), which the curvature loss acts on.
        policy_kwargs.setdefault("activation_fn", nn.SiLU)
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            policy_delay=policy_delay,
            target_policy_noise=target_policy_noise,
            target_noise_clip=target_noise_clip,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

    def learn(
        self: SelfTD3,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "TD3",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfTD3:
        """Train the model (arguments as in ``TD3.learn``).

        Also (re)initialises ``self.pure_actions``, the per-env list of pre-noise
        actions used for the ``train/oscillation`` metric. The history is kept when
        training continues without an env reset (``reset_num_timesteps=False`` and
        an existing last observation), so an in-progress episode is not truncated.
        """
        history = getattr(self, "pure_actions", None)
        if (
            reset_num_timesteps
            or self._last_obs is None
            or history is None
            or len(history) != self.n_envs
        ):
            # self.n_envs (not self.env.num_envs) so a missing env is reported by
            # super().learn() instead of an AttributeError here.
            self.pure_actions = [[] for _ in range(self.n_envs)]
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _sample_action_with_pure(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample one action per env and also return the action before exploration noise.

        Before ``learning_starts`` (unless gSDE warm-up is enabled) actions are drawn
        uniformly from ``self.action_space``; afterwards from ``self.predict``.

        :param learning_starts: number of timesteps of random warm-up actions.
        :param action_noise: optional exploration noise, added in scaled space.
        :param n_envs: number of parallel envs to sample for.
        :return: ``(action, buffer_action, pure_action)``: the action sent to the env;
            the action stored in the replay buffer (for Box spaces scaled with
            ``policy.scale_action`` and, if noise is given, noised and clipped to
            [-1, 1]); and the pre-noise action in the env's action space.
        """
        if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
            unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
        else:
            assert self._last_obs is not None, "self._last_obs was not set"
            unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
        pure_action = unscaled_action
        if isinstance(self.action_space, spaces.Box):
            scaled_action = self.policy.scale_action(unscaled_action)
            if action_noise is not None:
                scaled_action = np.clip(scaled_action + action_noise(), -1, 1)
            buffer_action = scaled_action
            action = self.policy.unscale_action(scaled_action)
        else:
            buffer_action = unscaled_action
            action = buffer_action
        return action, buffer_action, pure_action

    def _record_oscillation(self, history: list[np.ndarray]) -> None:
        """Log the mean L2 norm of consecutive pre-noise action differences of one episode.

        ``record_mean`` averages over all episodes that end between two log dumps
        instead of keeping only the last one. Episodes shorter than two steps are
        skipped.
        """
        if len(history) < 2:
            return
        episode_actions = np.stack(history, axis=0)
        diff_norms = np.linalg.norm(np.diff(episode_actions, axis=0), axis=-1)
        self.logger.record_mean("train/oscillation", float(np.mean(diff_norms)))

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        replay_buffer: ReplayBuffer,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """Step ``env`` and store transitions in ``replay_buffer`` until ``train_freq`` is met.

        Actions come from :meth:`_sample_action_with_pure`; each env's pre-noise
        action is appended to ``self.pure_actions`` and, when that env's episode ends,
        the episode's oscillation is logged (before any log dump, so the dump
        includes it) and the history is cleared.

        NOTE: ``locals()`` is handed to ``callback.update_locals``; local names such
        as ``actions``, ``new_obs``, ``rewards``, ``dones`` and ``infos`` are visible
        to callbacks and should not be renamed.

        :return: number of env steps and episodes collected, and whether to continue.
        """
        self.policy.set_training_mode(False)
        num_collected_steps, num_collected_episodes = 0, 0
        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."
        if env.num_envs > 1:
            assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."
        if self.use_sde:
            self.actor.reset_noise(env.num_envs)
        callback.on_rollout_start()
        continue_training = True
        # learn() normally initialises the history; rebuild it if this method is
        # driven directly or the env count does not match.
        if len(getattr(self, "pure_actions", ())) != env.num_envs:
            self.pure_actions = [[] for _ in range(env.num_envs)]
        pure_actions = self.pure_actions
        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                self.actor.reset_noise(env.num_envs)
            actions, buffer_actions, pure_action = self._sample_action_with_pure(learning_starts, action_noise, env.num_envs)
            for env_idx in range(env.num_envs):
                pure_actions[env_idx].append(pure_action[env_idx])
            new_obs, rewards, dones, infos = env.step(actions)
            self.num_timesteps += env.num_envs
            num_collected_steps += 1
            callback.update_locals(locals())
            if not callback.on_step():
                return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
            self._update_info_buffer(infos, dones)
            self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
            self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
            self._on_step()
            for idx, done in enumerate(dones):
                if done:
                    num_collected_episodes += 1
                    self._episode_num += 1
                    if action_noise is not None:
                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                        action_noise.reset(**kwargs)
                    # Record before a possible dump so this episode is part of it.
                    self._record_oscillation(pure_actions[idx])
                    pure_actions[idx] = []
                    if log_interval is not None and self._episode_num % log_interval == 0:
                        self._dump_logs()
            self._last_obs = new_obs
        callback.on_rollout_end()
        return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)

    @staticmethod
    def _action_grads(q1: th.Tensor, q2: th.Tensor, a_input: th.Tensor) -> tuple[th.Tensor, th.Tensor]:
        """Return (dQ1/da, dQ2/da), keeping the graph so they can be differentiated again."""
        grad_q1 = th.autograd.grad(q1.sum(), a_input, create_graph=True)[0]
        grad_q2 = th.autograd.grad(q2.sum(), a_input, create_graph=True)[0]
        return grad_q1, grad_q2

    def _curvature_penalty(self, grad_q: th.Tensor, a_input: th.Tensor, v: th.Tensor) -> th.Tensor:
        """Hinge ``mean(relu(v^T H v + grad_delta))`` on a one-probe Hutchinson trace estimate.

        Differentiating the batch sum of ``grad_q . v`` yields per-sample
        Hessian-vector products only if each sample's Q depends solely on its own
        action (no cross-sample coupling).
        """
        hessian_vec_prod = th.autograd.grad((grad_q * v).sum(), a_input, create_graph=True)[0]
        trace_approx = (hessian_vec_prod * v).sum(dim=1)
        return th.mean(th.relu(trace_approx + self.grad_delta))

    def _q_flow_losses(
        self,
        obs: th.Tensor,
        next_obs: th.Tensor,
        a_input: th.Tensor,
        q1_pred: th.Tensor,
        q2_pred: th.Tensor,
    ) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Compute the ``(mpr, vfc, curv)`` critic regularizers, each summed over both critics.

        :param obs: replay observations (``q1_pred``/``q2_pred`` were computed on them).
        :param next_obs: replay next observations.
        :param a_input: leaf action tensor with ``requires_grad=True``.
        """
        grad_q1, grad_q2 = self._action_grads(q1_pred, q2_pred, a_input)
        # NOTE(review): assumes obs is a single floating-point tensor; dict or
        # integer (e.g. uint8 image) observations are not handled here.
        noise = th.randn_like(obs) * self.grad_sigma
        q1_noisy, q2_noisy = self.critic(obs + noise, a_input)
        grad_q1_noisy, grad_q2_noisy = self._action_grads(q1_noisy, q2_noisy, a_input)
        mpr_loss = F.mse_loss(grad_q1, grad_q1_noisy) + F.mse_loss(grad_q2, grad_q2_noisy)
        q1_next, q2_next = self.critic(next_obs, a_input)
        grad_q1_next, grad_q2_next = self._action_grads(q1_next, q2_next, a_input)
        vfc_loss = F.mse_loss(grad_q1, grad_q1_next) + F.mse_loss(grad_q2, grad_q2_next)
        # One Rademacher (+/-1) probe shared by both critics.
        v = (th.randint_like(a_input, high=2) * 2 - 1).to(dtype=a_input.dtype)
        curv_loss = self._curvature_penalty(grad_q1, a_input, v) + self._curvature_penalty(grad_q2, a_input, v)
        return mpr_loss, vfc_loss, curv_loss

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run ``gradient_steps`` critic updates (and delayed actor/target updates).

        Critic loss = clipped double-Q TD loss
        + ``grad_lamS * mpr + grad_lamT * vfc + grad_lamC * curv``.
        Every ``policy_delay`` updates the actor is trained to maximise Q1 and the
        target networks / batch-norm statistics are updated.

        :param gradient_steps: number of gradient updates.
        :param batch_size: replay minibatch size.
        """
        self.policy.set_training_mode(True)
        self._update_learning_rate([self.actor.optimizer, self.critic.optimizer])
        actor_losses, critic_losses, td_losses = [], [], []
        mpr_losses, vfc_losses, curv_losses = [], [], []
        for _ in range(gradient_steps):
            self._n_updates += 1
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
            # Fresh leaf tensor so dQ/da and Hessian-vector products can be taken.
            a_input = replay_data.actions.clone().detach().requires_grad_(True)
            q1_pred, q2_pred = self.critic(replay_data.observations, a_input)
            mpr_loss, vfc_loss, curv_loss = self._q_flow_losses(
                replay_data.observations, replay_data.next_observations, a_input, q1_pred, q2_pred
            )
            q_flow_loss = (
                self.grad_lamS * mpr_loss
                + self.grad_lamT * vfc_loss
                + self.grad_lamC * curv_loss
            )
            with th.no_grad():
                # Target policy smoothing: clipped Gaussian noise on target actions.
                noise_act = replay_data.actions.clone().data.normal_(0, self.target_policy_noise)
                noise_act = noise_act.clamp(-self.target_noise_clip, self.target_noise_clip)
                next_actions = (self.actor_target(replay_data.next_observations) + noise_act).clamp(-1, 1)
                # Clipped double-Q: minimum over the target critics.
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
            critic_td_loss = F.mse_loss(q1_pred, target_q_values) + F.mse_loss(q2_pred, target_q_values)
            critic_loss = critic_td_loss + q_flow_loss
            critic_losses.append(critic_loss.item())
            td_losses.append(critic_td_loss.item())
            mpr_losses.append(mpr_loss.item())
            vfc_losses.append(vfc_loss.item())
            curv_losses.append(curv_loss.item())
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()
            if self._n_updates % self.policy_delay == 0:
                actor_action = self.actor(replay_data.observations)
                actor_loss = -self.critic.q1_forward(replay_data.observations, actor_action).mean()
                actor_losses.append(actor_loss.item())
                self.actor.optimizer.zero_grad()
                actor_loss.backward()
                self.actor.optimizer.step()
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
                # Running batch-norm statistics are copied verbatim (tau = 1.0).
                polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
                polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        if actor_losses:
            self.logger.record("train/actor_loss", np.mean(actor_losses))
        if critic_losses:
            # critic_loss includes the Q-flow terms; critic_td_loss is the TD part only.
            self.logger.record("train/critic_loss", np.mean(critic_losses))
            self.logger.record("train/critic_td_loss", np.mean(td_losses))
            self.logger.record("train/qflow_mpr_loss", np.mean(mpr_losses))
            self.logger.record("train/qflow_vfc_loss", np.mean(vfc_losses))
            self.logger.record("train/qflow_curv_loss", np.mean(curv_losses))

    def _log_hessian_stats(self, obs: th.Tensor) -> None:
        """Log the batch-mean max eigenvalue and trace of d^2 Q1 / da^2 at a = actor(obs).

        Row ``j`` of every sample's Hessian is obtained with a single backward pass
        of the batch sum of dQ1/da_j (``action_dim`` passes in total). Like the
        Hutchinson estimator used in :meth:`train`, this assumes each sample's Q
        depends only on its own action (no cross-sample coupling).

        :param obs: batch of observations (assumed a single tensor, not a dict).
        """
        with th.set_grad_enabled(True):
            obs_temp = obs.detach()
            action = self.actor(obs_temp)
            q_sum = self.critic.q1_forward(obs_temp, action).sum()
            grads = th.autograd.grad(q_sum, action, create_graph=True)[0]
            action_dim = action.shape[1]
            # Each entry has shape (batch, action_dim): row j of every sample's Hessian.
            hessian_rows = [
                th.autograd.grad(grads[:, j].sum(), action, retain_graph=True)[0].detach()
                for j in range(action_dim)
            ]
        # (batch, action_dim, action_dim); hessians[i, j, k] = d^2 Q_i / da_ij da_ik
        hessians = th.stack(hessian_rows, dim=1).cpu().numpy()
        eigenvalues = np.linalg.eigvalsh(hessians)  # (batch, action_dim), batched over samples
        self.logger.record("train/hessian_max_eigen", np.mean(np.max(eigenvalues, axis=-1)))
        self.logger.record("train/hessian_trace", np.mean(np.sum(eigenvalues, axis=-1)))