from stable_baselines3.sac import SAC
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gym import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy

import io
import pathlib
import sys
import time
import warnings
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gym import spaces

from stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.her.her_replay_buffer import HerReplayBuffer

from mnsesac.reshape_action_policies import ReshapeActionSACPolicy
from easydict import EasyDict
from mnsesac.train_theta_with_varing_policy.buffers import ActionProbReplayBuffer
from mnsesac.train_theta_with_varing_policy.train_inverse import EmbeddingInvNetwork
# Self-type for fluent SAC methods; the TypeVar name must match the variable it is bound to.
SelfSAC = TypeVar("SelfSAC", bound="SAC")


class MnseSAC(SAC):
    """
    SAC variant that fits an action-mapping layer (AML) before policy training.

    Training proceeds in two stages (see :meth:`train`):

    1. While ``num_timesteps`` is below ``train_aml_interval`` no SAC gradient
       update is performed. Every transition, together with the action sent to
       the environment and its log-probability, is additionally stored in
       ``action_prob_replay_buffer``.
    2. On the first ``train`` call with ``num_timesteps >= train_aml_interval``
       an inverse model (``EmbeddingInvNetwork``) is trained on that buffer and
       then used to fit the mapping parameters ``theta``. The fitted ``theta``
       is installed as ``self.actor.theta``, the SAC replay buffer is reset,
       and ordinary SAC updates run on every later ``train`` call.
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
        "ReshapeActionSACPolicy": ReshapeActionSACPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[SACPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 100,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[ActionNoise] = None,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Union[str, float] = "auto",
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        additional_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Create the agent.

        All arguments except ``additional_config`` are forwarded unchanged to
        :class:`stable_baselines3.SAC`.

        :param tensorboard_log: Root log directory. Inverse-model / AML logs are
            written under ``<tensorboard_log>/ActionMapping/<timestamp>``; when it
            is ``None`` the current directory is used as the root.
        :param additional_config: Optional extra configuration, wrapped in an
            ``EasyDict`` and stored as ``self.additional_config`` (``None`` when
            not given). It is not read anywhere in this class.
        """
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,  # 1e6
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            ent_coef=ent_coef,
            target_update_interval=target_update_interval,
            target_entropy=target_entropy,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

        # Always define the attribute so later reads never hit AttributeError.
        self.additional_config = EasyDict(additional_config) if additional_config else None

        # Env steps collected before the inverse model / AML are trained; also the
        # capacity of the action-probability buffer.
        self.train_aml_interval = 100000
        self.collect_data_for_inv = True
        self.aml_trained_flag = False
        # Number of discretisation bins per action dimension (shared by __init__ and train).
        self.action_bins = 20

        if self.collect_data_for_inv:
            device = self.device
            # Use the spaces resolved by the base class: ``env`` may be a string id.
            self.action_prob_replay_buffer = ActionProbReplayBuffer(
                self.train_aml_interval,
                self.observation_space,
                self.action_space,
                handle_timeout_termination=False,
            )

            # ``action_bins + 1`` equally spaced bin edges over [-1, 1].
            self.bins = th.linspace(-1, 1, steps=self.action_bins + 1).to(device)

            obs_shape = self.action_prob_replay_buffer.obs_shape
            action_dim = self.action_prob_replay_buffer.action_dim

            self.inverse_net = EmbeddingInvNetwork(
                obs_shape=obs_shape[0], action_shape=action_dim, action_bins=self.action_bins
            ).to(device)
            self.inverse_optimizer = th.optim.Adam(self.inverse_net.parameters(), lr=1e-3)
            self.inv_batch_size = 64
            self.inv_gradient_steps = 100000

            from torch.utils.tensorboard import SummaryWriter

            time_flag = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
            # ``tensorboard_log`` defaults to None; concatenating it would raise TypeError.
            log_root = tensorboard_log if tensorboard_log is not None else "."
            self.exp_path = log_root + "/ActionMapping/" + time_flag
            self.inv_writer = SummaryWriter(self.exp_path)
            self.log_interval = 100
            self.fig_log_interval = 10000
            self.save_model_interval = 50000

            # Flat mapping parameters: ``action_bins`` values per action dimension,
            # viewed as (action_dim, action_bins) in ``_train_action_mapping``.
            self.theta = th.nn.Parameter(th.zeros(action_dim * self.action_bins, device=device), requires_grad=True)
            self.theta_opt = th.optim.Adam([self.theta], lr=1e-2)

    def _sample_action(
        self,
        learning_starts: int,
        action_noise: Optional[ActionNoise] = None,
        n_envs: int = 1,
        return_action_logprob: bool = False,
    ) -> Tuple[np.ndarray, ...]:
        """
        Sample an action according to the exploration policy.

        During warm-up a uniform random action is drawn from the action space;
        afterwards the (stochastic) policy is queried through :meth:`predict`.

        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param n_envs: Number of parallel environments (one action per env).
        :param return_action_logprob: Also return the log-probability of the sampled action.
        :return: ``(action, buffer_action)``, or ``(action, buffer_action, action_logprob)``
            when ``return_action_logprob`` is True. ``action`` is sent to the env,
            ``buffer_action`` is the scaled action stored in the replay buffer.
        """
        if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
            # Warm-up phase: uniform random actions (assumes a Box-like space with low/high).
            unscaled_action = np.array([self.action_space.sample() for _ in range(n_envs)])
            theta = None

            # Log-density of the uniform distribution: sum over dims of log(1 / (high - low)).
            action_len = self.action_space.high - self.action_space.low
            action_logprob = np.log(1.0 / action_len).sum()
            action_logprob = np.full(n_envs, action_logprob)
        else:
            # Note: when using continuous actions,
            # we assume that the policy uses tanh to scale the action
            # We use non-deterministic action in the case of SAC, for TD3, it does not matter
            unscaled_action, theta, action_logprob = self.predict(
                self._last_obs, deterministic=False, return_action_logprob=True
            )
            action_logprob = action_logprob.detach().cpu().numpy()

        if isinstance(self.action_space, spaces.Box):
            # Map to the buffer's normalised action space (``theta`` is None during warm-up).
            scaled_action = self.policy.scale_action(unscaled_action, theta=theta)

            # Add noise to the action (improve exploration)
            if action_noise is not None:
                scaled_action = np.clip(scaled_action + action_noise(), -1, 1)

            # We store the scaled action in the buffer
            buffer_action = scaled_action
            action = self.policy.unscale_action(scaled_action, theta=theta)
        else:
            # Discrete case, no need to normalize or clip
            buffer_action = unscaled_action
            action = buffer_action

        if return_action_logprob:
            return action, buffer_action, action_logprob
        return action, buffer_action

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        train_freq: TrainFreq,
        replay_buffer: ReplayBuffer,
        action_noise: Optional[ActionNoise] = None,
        learning_starts: int = 0,
        log_interval: Optional[int] = None,
    ) -> RolloutReturn:
        """
        Collect experiences and store them into a ``ReplayBuffer``.

        Same as the SB3 implementation, except that the sampled action's
        log-probability and the executed env action are forwarded to
        :meth:`_store_transition` so they can be recorded for the inverse model.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param train_freq: How much experience to collect
            by doing rollouts of current policy.
            Either ``TrainFreq(<n>, TrainFrequencyUnit.STEP)``
            or ``TrainFreq(<n>, TrainFrequencyUnit.EPISODE)``
            with ``<n>`` being an integer greater than 0.
        :param replay_buffer: Buffer that receives the SAC transitions.
        :param action_noise: Action noise that will be used for exploration
            Required for deterministic policy (e.g. TD3). This can also be used
            in addition to the stochastic policy for SAC.
        :param learning_starts: Number of steps before learning for the warm-up phase.
        :param log_interval: Log data every ``log_interval`` episodes
        :return: Number of collected env steps/episodes and whether to continue training.
        """
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        num_collected_steps, num_collected_episodes = 0, 0

        assert isinstance(env, VecEnv), "You must pass a VecEnv"
        assert train_freq.frequency > 0, "Should at least collect one step or episode."

        if env.num_envs > 1:
            assert train_freq.unit == TrainFrequencyUnit.STEP, "You must use only one env when doing episodic training."

        # Vectorize action noise if needed
        if action_noise is not None and env.num_envs > 1 and not isinstance(action_noise, VectorizedActionNoise):
            action_noise = VectorizedActionNoise(action_noise, env.num_envs)

        if self.use_sde:
            self.actor.reset_noise(env.num_envs)

        callback.on_rollout_start()
        continue_training = True

        while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
            if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.actor.reset_noise(env.num_envs)

            # Select action randomly or according to policy
            actions, buffer_actions, action_logprob = self._sample_action(
                learning_starts, action_noise, env.num_envs, return_action_logprob=True
            )

            # Rescale and perform action
            new_obs, rewards, dones, infos = env.step(actions)

            self.num_timesteps += env.num_envs
            num_collected_steps += 1

            # Give access to local variables
            callback.update_locals(locals())
            # Only stop training if return value is False, not when it is None.
            if callback.on_step() is False:
                return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)

            # Retrieve reward and episode length if using Monitor wrapper
            self._update_info_buffer(infos, dones)

            # Store data in replay buffer (normalized action and unnormalized observation)
            self._store_transition(
                replay_buffer, buffer_actions, new_obs, rewards, dones, infos,
                action_logprob=action_logprob, real_action=actions,
            )

            self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

            # For DQN, check if the target network should be updated
            # and update the exploration schedule
            # For SAC/TD3, the update is dones as the same time as the gradient update
            # see https://github.com/hill-a/stable-baselines/issues/900
            self._on_step()

            for idx, done in enumerate(dones):
                if done:
                    # Update stats
                    num_collected_episodes += 1
                    self._episode_num += 1

                    if action_noise is not None:
                        kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                        action_noise.reset(**kwargs)

                    # Log training infos
                    if log_interval is not None and self._episode_num % log_interval == 0:
                        self._dump_logs()
        callback.on_rollout_end()

        return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)

    def _store_transition(
        self,
        replay_buffer: ReplayBuffer,
        buffer_action: np.ndarray,
        new_obs: Union[np.ndarray, Dict[str, np.ndarray]],
        reward: np.ndarray,
        dones: np.ndarray,
        infos: List[Dict[str, Any]],
        action_logprob: np.ndarray,
        real_action: np.ndarray,
    ) -> None:
        """
        Store transition in the replay buffer.
        We store the normalized action and the unnormalized observation.
        It also handles terminal observations (because VecEnv resets automatically).

        When ``collect_data_for_inv`` is set, the same transition is also added to
        ``action_prob_replay_buffer`` with the executed env action and its log-probability.

        :param replay_buffer: Replay buffer object where to store the transition.
        :param buffer_action: normalized action
        :param new_obs: next observation in the current episode
            or first observation of the episode (when dones is True)
        :param reward: reward for the current transition
        :param dones: Termination signal
        :param infos: List of additional information about the transition.
            It may contain the terminal observations and information about timeout.
        :param action_logprob: Log-probability of the sampled action, one entry per env.
        :param real_action: The action actually passed to ``env.step``.
        """
        # Store only the unnormalized version
        if self._vec_normalize_env is not None:
            new_obs_ = self._vec_normalize_env.get_original_obs()
            reward_ = self._vec_normalize_env.get_original_reward()
        else:
            # Avoid changing the original ones
            self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward

        # Avoid modification by reference
        next_obs = deepcopy(new_obs_)
        # As the VecEnv resets automatically, new_obs is already the
        # first observation of the next episode
        for i, done in enumerate(dones):
            if done and infos[i].get("terminal_observation") is not None:
                if isinstance(next_obs, dict):
                    next_obs_ = infos[i]["terminal_observation"]
                    # VecNormalize normalizes the terminal observation
                    if self._vec_normalize_env is not None:
                        next_obs_ = self._vec_normalize_env.unnormalize_obs(next_obs_)
                    # Replace next obs for the correct envs
                    for key in next_obs.keys():
                        next_obs[key][i] = next_obs_[key]
                else:
                    next_obs[i] = infos[i]["terminal_observation"]
                    # VecNormalize normalizes the terminal observation
                    if self._vec_normalize_env is not None:
                        next_obs[i] = self._vec_normalize_env.unnormalize_obs(next_obs[i, :])

        replay_buffer.add(
            self._last_original_obs,
            next_obs,
            buffer_action,
            reward_,
            dones,
            infos,
        )

        if self.collect_data_for_inv:
            self.action_prob_replay_buffer.add(
                self._last_original_obs,
                next_obs,
                real_action,
                reward_,
                dones,
                infos,
                log_prob=action_logprob,
            )

        self._last_obs = new_obs
        # Save the unnormalized observation
        if self._vec_normalize_env is not None:
            self._last_original_obs = new_obs_

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
        return_action_logprob: bool = False,
    ) -> Tuple[Any, ...]:
        """
        Get the policy action from an observation (and optional hidden state).
        Includes sugar-coating to handle different observations (e.g. normalizing images).

        :param observation: the input observation
        :param state: The last hidden states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
            this correspond to beginning of episodes,
            where the hidden states of the RNN must be reset.
        :param deterministic: Whether or not to return deterministic actions.
        :param return_action_logprob: Forwarded to ``self.policy.predict``. With True,
            ``_sample_action`` unpacks the result as ``(action, theta, action_logprob)``.
        :return: whatever ``self.policy.predict`` returns; presumably ``(action, state)``
            when ``return_action_logprob`` is False (TODO confirm against the policy).
        """
        return self.policy.predict(
            observation, state, episode_start, deterministic, return_action_logprob=return_action_logprob
        )

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """
        Run one training call.

        SAC gradient updates run only once the action-mapping layer has been
        trained (or when inverse-data collection is disabled). The first call
        with ``num_timesteps >= train_aml_interval`` instead trains the inverse
        model and the action-mapping parameters, exactly once.

        :param gradient_steps: Number of SAC gradient steps.
        :param batch_size: SAC minibatch size.
        """
        # Without inverse-data collection there is no AML stage, so SAC must always train.
        if self.aml_trained_flag or not self.collect_data_for_inv:
            self._train_sac(gradient_steps, batch_size)

        # Set to True to run a short inverse/AML training on every call (debugging only).
        debug = False

        train_aml_flag = False
        if self.collect_data_for_inv and not self.aml_trained_flag and self.num_timesteps >= self.train_aml_interval:
            train_aml_flag = True
            self.aml_trained_flag = True

        if train_aml_flag or debug:
            self._train_inverse_model(debug)
            self._train_action_mapping(debug)

    def _train_sac(self, gradient_steps: int, batch_size: int) -> None:
        """Standard SAC update (entropy coefficient, critics, actor, target networks) plus logging."""
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizers learning rate
        optimizers = [self.actor.optimizer, self.critic.optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers += [self.ent_coef_optimizer]

        # Update learning rate according to lr schedule
        self._update_learning_rate(optimizers)

        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # We need to sample because `log_std` may have changed between two gradient steps
            if self.use_sde:
                self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.ent_coef_optimizer is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = th.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.ent_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None:
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()

            with th.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute the next Q values: min over all critics targets
                next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
                # add entropy term
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates for each critic network
            # using action from the replay buffer
            current_q_values = self.critic(replay_data.observations, replay_data.actions)

            # Compute critic loss
            critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            critic_losses.append(critic_loss.item())

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            # Alternative: actor_loss = th.mean(log_prob - qf1_pi)
            # Min over all critic networks
            q_values_pi = th.cat(self.critic(replay_data.observations, actions_pi), dim=1)
            min_qf_pi, _ = th.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Copy running stats, see GH issue #996
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", np.mean(ent_coefs))
        self.logger.record("train/actor_loss", np.mean(actor_losses))
        self.logger.record("train/critic_loss", np.mean(critic_losses))
        if len(ent_coef_losses) > 0:
            self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))

    def _train_inverse_model(self, debug: bool = False) -> None:
        """
        Fit ``self.inverse_net`` on ``action_prob_replay_buffer`` (step 2.2).

        Each action dimension is discretised into ``action_bins`` classes using
        ``self.bins``, and the summed per-dimension cross-entropy is minimised.
        NOTE(review): the stored actions are the env actions (``real_action``);
        this assumes they lie in [-1, 1], since they are clipped to that range.
        """
        from mnsesac.train_theta_with_varing_policy.train_inverse import test_inverse_dist

        replay_buffer = self.action_prob_replay_buffer
        inv_gradient_steps = 10 if debug else self.inv_gradient_steps
        device = self.device
        inverse_net = self.inverse_net
        inverse_optimizer = self.inverse_optimizer
        writer = self.inv_writer
        bins = self.bins
        action_bins = self.action_bins
        action_dim = replay_buffer.action_dim
        criterion = th.nn.CrossEntropyLoss()

        # Histogram of target classes; only accumulated/printed in debug mode because
        # printing a tensor every inner iteration floods stdout and forces device syncs.
        count_every_bin = th.zeros(action_bins, device=device)

        for i in range(inv_gradient_steps):
            inv_replay_data = replay_buffer.sample(self.inv_batch_size)
            state = inv_replay_data.observations.to(device)
            action = inv_replay_data.actions.to(device)
            s_prime = inv_replay_data.next_observations.to(device)

            inv_loss = 0
            for sample_action_dim in range(action_dim):
                predict = inverse_net.forward_ith(state, s_prime, action, action_dim=sample_action_dim)
                predict = predict.view(-1, action_bins)
                # Keep targets strictly inside (-1, 1) so bucketize yields 1..action_bins;
                # subtracting 1 gives 0-based class indices.
                target = th.clip(action[:, sample_action_dim], max=1 - 1e-5, min=-1 + 1e-5)
                discrete_target = (th.bucketize(target, bins) - 1).view(-1)

                if debug:
                    count_every_bin = count_every_bin + th.bincount(discrete_target, minlength=action_bins)
                    print(count_every_bin)

                inv_loss = inv_loss + criterion(predict, discrete_target)

            inverse_optimizer.zero_grad()
            inv_loss.backward()
            inverse_optimizer.step()

            if i % self.log_interval == 0:
                writer.add_scalar('Training Loss/inv_loss', inv_loss.item(), i)
                print("Iter: " + str(i) + " Inv Loss: " + str(inv_loss.item()))
            if i % self.fig_log_interval == 0:
                for actuator_id in range(action_dim):
                    test_inverse_dist(replay_buffer, inverse_net, writer, epoch=i, actuator_id=actuator_id,
                                      device=device, bins=bins)
            if i % self.save_model_interval == 0:
                th.save(inverse_net, self.exp_path + '/model.pth')

    def _train_action_mapping(self, debug: bool = False) -> None:
        """
        Fit ``self.theta`` with the trained inverse model and install it on the actor (steps 2.3-2.4).

        For each action dimension the loss is the negative importance-weighted mean of
        ``log p(a|s,s') - correct_const * log|de/da|``, with weights ``|de/da|`` obtained from
        ``reshape_action_for_tensor``. Afterwards ``self.actor.theta`` is replaced and the
        SAC replay buffer is reset.
        """
        from mnsesac.train_theta_with_varing_policy.reshape_action_layer import reshape_action_for_tensor
        from mnsesac.train_theta_with_varing_policy.train_aml_with_perfect_inverse import save_theta

        replay_buffer = self.action_prob_replay_buffer
        device = self.device
        inverse_net = self.inverse_net
        writer = self.inv_writer
        bins = self.bins
        action_dim = replay_buffer.action_dim
        # View (not copy) so in-place optimiser updates of self.theta are visible here.
        theta = self.theta.view(action_dim, self.action_bins)
        theta_opt = self.theta_opt

        batch_size = 256
        steps = 10 if debug else 10000
        # Strength of the log|de/da| correction applied to log p(a|s,s').
        correct_flag = True
        correct_const = 0.3
        epsilon = 1e-8

        log_interval = 100
        fig_log_interval = max(steps // 50, 1)

        for i in range(steps):
            replay_data = replay_buffer.sample(batch_size)
            state = replay_data.observations.to(device)
            action = replay_data.actions.to(device)
            s_prime = replay_data.next_observations.to(device)

            # The inverse model is frozen here: only theta receives gradients.
            with th.no_grad():
                predict = inverse_net(state, s_prime, action)
            target = th.clip(action, max=1 - 1e-5, min=-1 + 1e-5)
            discrete_target = th.bucketize(target, bins) - 1

            total_loss = 0
            for actuator_id in range(action_dim):
                test_predict = predict[:, actuator_id, :]
                test_target = target[:, actuator_id]
                test_discrete_target = discrete_target[:, actuator_id]
                theta_ith = theta[actuator_id, :]

                # log p(a|s,s') from the inverse model, without gradient.
                with th.no_grad():
                    log_pa = - F.cross_entropy(test_predict, test_discrete_target, reduction='none')

                # e = h(a) under the current theta; slopes = de/da.
                _, slopes = reshape_action_for_tensor(x=test_target, theta=theta_ith, low=-1, high=1)

                # log p(e|s,s') = log p(a|s,s') - log |de/da| (scaled by correct_const)
                log_slopes = th.log(th.abs(slopes) + epsilon)
                if correct_flag:
                    log_pe = log_pa - correct_const * log_slopes
                else:
                    log_pe = log_pa

                # Importance-sampling weights |de/da| correct for sampling in a-space.
                IS_weight = th.abs(slopes) + epsilon
                IS_mean = (log_pe * IS_weight).sum() / IS_weight.sum()
                total_loss += -IS_mean

            loss = total_loss
            theta_opt.zero_grad()
            loss.backward()
            theta_opt.step()

            if i % log_interval == 0:
                writer.add_scalar('Training Loss/Action_Mapping_loss', loss.item(), i)
                print("Iter: " + str(i) + " Action_Mapping_loss: " + str(loss.item()))
                if correct_flag:
                    writer.add_scalar('Training Loss/Correct const', correct_const, i)
            if i % fig_log_interval == 0:
                with th.no_grad():
                    for actuator_id in range(action_dim):
                        save_theta(replay_buffer=replay_buffer, inverse_net=inverse_net, writer=writer, epoch=i,
                                   theta_ith=theta[actuator_id, :], actuator_id=actuator_id, device=device,
                                   bins=bins)

        # Install the fitted mapping on the actor.
        # NOTE(review): the actor optimizer was built before this assignment, so it presumably
        # does not update this new Parameter; confirm this is intended.
        self.actor.theta = th.nn.Parameter(theta, requires_grad=True)
        # Discard earlier transitions (presumably collected under the previous mapping).
        self.replay_buffer.reset()

    def _excluded_save_params(self) -> List[str]:
        """Exclude the inverse network from the parameters that SB3 would serialise."""
        return super()._excluded_save_params() + ["inverse_net"]

    def save(
        self,
        path=None,
        exclude=None,
        include=None,
    ) -> None:
        """
        Saving is disabled for this class: nothing is written.

        A ``UserWarning`` is emitted so callers do not silently assume that a
        checkpoint exists at ``path``.
        """
        warnings.warn(
            f"{type(self).__name__}.save() is disabled; nothing was written to {path!r}.",
            UserWarning,
        )

