import warnings
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union, Tuple

import numpy as np
import torch as th
# from th import Tensor
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, update_learning_rate

SelfPPO = TypeVar("SelfPPO", bound="PPO")

import pdb
import wandb
import time
import math

class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization (PPO), clipped-surrogate variant, with support
    for vector-valued rewards.

    References:
        - Paper: https://arxiv.org/abs/1707.06347
        - Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html
        - Code lineage: OpenAI Spinning Up (https://github.com/openai/spinningup/),
          https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
          Stable Baselines PPO2 (https://github.com/hill-a/stable-baselines)

    Multi-objective behaviour: when ``r_dim > 1``, every call to :meth:`train`
    evaluates the critic on the environment's initial states, averages the
    predicted values over environments, and uses the advantage column of the
    reward dimension with the smallest averaged value for the policy loss.

    :param policy: Policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: Environment to learn from (a str is accepted if registered in Gym)
    :param learning_rate: Learning rate; may be a function of the current
        progress remaining (from 1 to 0)
    :param n_steps: Number of steps to run for each environment per update
        (the rollout buffer holds ``n_steps * n_envs`` transitions).
        NOTE: ``n_steps * n_envs`` must be greater than 1 because of the advantage
        normalization, see https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epochs when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Bias/variance trade-off factor of the Generalized Advantage Estimator
    :param clip_range: Clipping parameter; may be a function of the current
        progress remaining (from 1 to 0)
    :param clip_range_vf: Clipping parameter for the value function; may be a
        function of the current progress remaining (from 1 to 0). Specific to the
        OpenAI implementation. With ``None`` (default) the value function is not clipped.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize the advantage
    :param ent_coef: Entropy coefficient in the loss
    :param vf_coef: Value function coefficient in the loss
    :param max_grad_norm: Maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE.
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit on the KL divergence between updates, because clipping
        alone does not prevent large updates
        (cf https://github.com/hill-a/stable-baselines/issues/213).
        By default there is no limit.
    :param stats_window_size: Number of episodes over which the rollout logging
        averages the reported success rate, mean episode length and mean reward
    :param tensorboard_log: Log location for tensorboard (if None, no logging)
    :param policy_kwargs: Additional arguments passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages
        (such as device or wrappers used), 2 for debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run
    :param _init_setup_model: Whether to build the network when the instance is created
    :param r_dim: Number of reward dimensions; also forwarded to the base class
    :param r_dim_wise_normalize: If True, advantages are normalized separately for
        each reward dimension (statistics over the batch axis); otherwise a single
        mean/std over the whole minibatch is used
    :param env_name: Forwarded unchanged to the base class constructor
    """

    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": ActorCriticPolicy,  # modified to use the multi-objective critic
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "cpu",
        _init_setup_model: bool = True,
        r_dim: int = 1,  # multi-objective extension
        r_dim_wise_normalize: bool = False,  # multi-objective extension
        env_name: Optional[str] = None
    ):
        # Model construction is deferred (``_init_setup_model=False``) so that the
        # PPO-specific attributes below exist before ``_setup_model`` runs.
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
            r_dim=r_dim,
            env_name=env_name
        )

        # Advantage normalization over a single sample yields noisy gradients and NaN.
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Same NaN concern for the rollout buffer as a whole.
            buffer_size = self.env.num_envs * self.n_steps
            assert buffer_size > 1 or (
                not normalize_advantage
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Warn when the buffer does not split evenly into minibatches.
            untruncated_batches = buffer_size // batch_size
            has_truncated_batch = buffer_size % batch_size > 0
            if has_truncated_batch:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )

        # Standard PPO hyper-parameters.
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        # Multi-objective settings.
        self.r_dim = r_dim
        self.r_dim_wise_normalize = r_dim_wise_normalize

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the model via the base class and turn the clip ranges into schedules."""
        super()._setup_model()

        self.clip_range = get_schedule_fn(self.clip_range)

        if self.clip_range_vf is None:
            # Value-function clipping is disabled; nothing else to convert.
            return

        if isinstance(self.clip_range_vf, (float, int)):
            assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
        self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def train(self) -> None:
        """
        Update the policy using the currently gathered rollout buffer.

        Runs up to ``n_epochs`` passes over the buffer in minibatches of
        ``batch_size``, minimising the clipped policy loss plus the weighted
        entropy and value losses, and stops early once the approximate KL
        divergence exceeds ``1.5 * target_kl`` (when ``target_kl`` is set).
        """
        # Train mode affects batch norm / dropout.
        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)

        # Resolve the schedules for the current training progress.
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        # Per-minibatch statistics, averaged for logging at the end.
        entropy_losses, pg_losses, value_losses, clip_fractions = [], [], [], []

        continue_training = True
        single_objective = self.r_dim == 1
        multi_objective = self.r_dim > 1

        if multi_objective:
            # Pick the reward dimension whose expected initial-state value is lowest.
            with th.no_grad():
                # Original author's note: shape is (1, n_env, ob_dim) because the env
                # is a SubprocVecEnv (for MO-Mujoco, subproc_vec_env.py must be revised).
                initial_state = self.env.initial_states()
                initial_obs = obs_as_tensor(np.squeeze(initial_state, axis=0), self.device)
                initial_state_value = self.policy.predict_values(initial_obs)  # (n_env, r_dim)
                averaged_state_value = th.mean(initial_state_value, axis=0)  # (r_dim,)
                scalarized_index = th.argmin(averaged_state_value).item()

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # One complete pass over the rollout buffer.
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                if isinstance(self.action_space, spaces.Discrete):
                    # Discrete actions are stored as floats; convert to long indices.
                    actions = rollout_data.actions.long().flatten()
                else:
                    actions = rollout_data.actions

                # The log_std has changed, so the gSDE noise matrix is re-sampled.
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                # Evaluate the stored actions under the current parameters.
                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                if single_objective:
                    values = values.flatten()

                # Advantages, [batch_size, r_dim] according to the original author's note.
                advantages = rollout_data.advantages

                # Normalization is meaningless for a single sample (GH issue #325).
                # Following the Fair-RL code, the default normalizes over the batch
                # *and* the reward dimensions; this happens for every sampled minibatch.
                if self.normalize_advantage and len(advantages) > 1:
                    if self.r_dim_wise_normalize:
                        adv_mean, adv_std = advantages.mean(axis=0), advantages.std(axis=0)
                    else:
                        adv_mean, adv_std = advantages.mean(), advantages.std()
                    advantages = (advantages - adv_mean) / (adv_std + 1e-8)

                if multi_objective:
                    advantages = advantages[:, scalarized_index]  # [batch_size]

                # Probability ratio between new and data-collecting policy
                # (equal to one on the first iteration).
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # Clipped surrogate objective.
                surrogate_unclipped = advantages * ratio
                surrogate_clipped = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(surrogate_unclipped, surrogate_clipped).mean()

                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    values_pred = values
                else:
                    # Limit how far the new value may move from the old one.
                    # NOTE: this depends on the reward scaling.
                    value_delta = th.clamp(values - rollout_data.old_values, -clip_range_vf, clip_range_vf)
                    values_pred = rollout_data.old_values + value_delta

                # Value loss against the TD(gae_lambda) target
                # (returns = GAE advantage + value predicted at collection time).
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy bonus; approximated from log-probs when no analytical form exists.
                entropy_loss = -th.mean(-log_prob) if entropy is None else -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Approximate reverse KL divergence for early stopping, see
                # https://github.com/DLR-RM/stable-baselines3/issues/417,
                # https://github.com/DLR-RM/stable-baselines3/pull/419 and
                # http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                kl_limit_reached = self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl
                if kl_limit_reached:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Gradient step with global norm clipping.
                self.policy.optimizer.zero_grad()
                loss.backward()
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Training diagnostics.
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self: SelfPPO,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "PPO",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfPPO:
        """Delegate to the base class ``learn``; only the default ``tb_log_name`` is PPO-specific."""
        learn_kwargs = dict(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )
        return super().learn(**learn_kwargs)


class Weight(th.nn.Module):
    """
    Weight vector over the ``r_dim`` reward dimensions, stored as an ``[r_dim, 1]`` parameter.

    The update methods (:meth:`step`, :meth:`step_linear_obj_MD`) overwrite
    ``weight.data`` under ``th.no_grad()``: with ``reg='kl2'`` the new weights are a
    softmax, with ``reg='quad2'`` they are an additive step followed by a Euclidean
    projection onto the probability simplex. ``w_avg`` holds the running average of
    all weight iterates, including the initial one.

    :param r_dim: Number of reward dimensions (K in the paper)
    :param initialize: ``'uniform'`` (all entries ``1 / r_dim``), ``'dirichlet'``
        (one sample from a flat Dirichlet) or an explicit list of ``r_dim`` floats
    :param use_ci: Selects, together with ``use_md``, which update rule is applied;
        when True the ``corr_vector`` argument enters the update
    :param use_md: When True the previous weights enter the update
    :param corr: Stored on the instance; not read by any method of this class
    :param reg: ``'kl2'`` or ``'quad2'``; any other value raises
        ``NotImplementedError`` at update time if ``use_ci`` or ``use_md`` is set
    :param eta: Coefficient used by the ``use_ci`` rules and by :meth:`step`
    :param lam: Coefficient used by the ``use_md`` rules
    :param beta: Temperature of the plain softmax used when both flags are False
    :param lr: Step size of the ``'quad2'`` updates
    :param device: Device on which all tensors of this module are created
    """

    def __init__(self,
                 r_dim: int = 2,  # r_dim = K in the paper
                 initialize: Union[str, List[float]] = 'uniform',
                 use_ci: bool = True,
                 use_md: bool = True,
                 corr: str = 'inner_product',
                 reg: str = 'kl2',
                 eta: float = 1.0,
                 lam: float = 1.0,
                 beta: float = 0.5,
                 lr: float = 1e-3,
                 device: Union[th.device, str] = th.device("cpu"),
        ):
        super().__init__()

        self.use_ci = use_ci
        self.use_md = use_md
        self.corr = corr
        self.reg = reg
        self.eta = eta
        self.lam = lam
        self.beta = beta
        self.r_dim = r_dim
        self.lr = lr
        self.device = device

        # Number of update calls so far; drives the running average ``w_avg``.
        self.w_update_cnt = 0

        # Initial weights as a column vector [r_dim, 1].
        if initialize == 'uniform':
            initial = th.full((r_dim, 1), 1 / r_dim, dtype=th.float32, device=self.device)
        elif initialize == 'dirichlet':
            # One random draw from the flat Dirichlet distribution.
            flat_dirichlet = th.distributions.dirichlet.Dirichlet(th.ones(r_dim, dtype=th.float32))
            initial = flat_dirichlet.sample().to(self.device).unsqueeze(dim=-1)
        elif isinstance(initialize, list):
            initial = th.tensor(initialize, dtype=th.float32, device=self.device).unsqueeze(dim=-1)
        else:
            raise NotImplementedError(f"Unsupported `initialize` value: {initialize!r}")
        self.weight = th.nn.Parameter(initial)
        print("Weight", self.weight)
        self.w_avg = self.weight.data.clone()

        # Helpers for the simplex projection: row i of ``matrix`` averages the first
        # i + 1 entries of a vector, and ``intercept[i]`` equals 1 / (i + 1).
        prefix_lengths = th.arange(1, r_dim + 1, device=self.device).unsqueeze(dim=-1)
        self.matrix = th.tril(th.ones(r_dim, r_dim, device=self.device)) / prefix_lengths
        self.intercept = self.matrix[:, 0].unsqueeze(dim=-1)  # [r_dim, 1]

        self.zero_vector = th.zeros(r_dim, device=self.device)

    def proj_on_simplex(self, x):   # [r_dim,1] -> [r_dim,1]
        """
        Project a column vector onto the probability simplex (sort-and-threshold method).

        :param x: Tensor of shape ``[r_dim, 1]``
        :return: Tensor of shape ``[r_dim, 1]`` with non-negative entries; the
            shift ``lmbda`` is chosen so that the kept entries sum to one
        """
        x = x.squeeze(dim=-1)
        sorted_x, _ = th.sort(x, descending=True)  # [r_dim]
        sorted_x = sorted_x.unsqueeze(dim=-1)  # [r_dim, 1]
        # criterion[i] = sorted_x[i] - (sum(sorted_x[:i + 1]) - 1) / (i + 1)
        criterion = sorted_x - th.matmul(self.matrix, sorted_x) + self.intercept  # [r_dim, 1]
        # Number of entries that stay positive after the shift (1 .. r_dim).
        threshold_idx = th.sum(criterion > 0).item()
        # Common shift: (1 - sum of the `threshold_idx` largest entries) / threshold_idx.
        lmbda = criterion[threshold_idx - 1] - sorted_x[threshold_idx - 1]  # [1]
        projected_x = th.max(x + lmbda, self.zero_vector).unsqueeze(dim=-1)
        return projected_x

    def _update_ci_md(self, value_vector, corr_vector) -> None:
        """Update rule when both ``use_ci`` and ``use_md`` are set (caller holds ``no_grad``)."""
        c1 = self.lam * self.eta / (self.lam + self.eta)
        c2 = self.eta / (self.lam + self.eta)
        c3 = self.lam / (self.lam + self.eta)
        if self.reg == 'kl2':
            self.weight.data = th.softmax(- c1 * value_vector + c2 * th.log(self.weight.data + 1e-8) + c3 * th.log(corr_vector + 1e-8), dim=0)
        elif self.reg == 'quad2':
            # Additive step first, projection back onto the simplex afterwards.
            self.weight.data -= self.lr * (- c1 * value_vector + (c2-1) * self.weight.data + c3 * corr_vector)/(c1 + 1e-8)
            self.weight.data = self.proj_on_simplex(self.weight.data)
        else:
            raise NotImplementedError

    def _update_ci_only(self, value_vector, corr_vector) -> None:
        """Update rule for ``use_ci`` without ``use_md`` (the lam == infinity case)."""
        if self.reg == 'kl2':
            self.weight.data = th.softmax(- self.eta * value_vector + th.log(corr_vector + 1e-8), dim=0)
        elif self.reg == 'quad2':
            self.weight.data -= self.lr * (- value_vector + corr_vector/(self.eta+1e-8) - self.weight.data/(self.eta+1e-8))
            self.weight.data = self.proj_on_simplex(self.weight.data)
        else:
            raise NotImplementedError

    def _update_md_only(self, value_vector, regularized: bool) -> None:
        """
        Update rule for ``use_md`` without ``use_ci``.

        :param regularized: False reproduces ``step_linear_obj_MD`` (only ``lam`` is
            used); True reproduces ``step`` (mirror descent on the regularized objective).
        """
        if self.reg == 'kl2':
            if regularized:
                beta_new = self.eta / (self.lam + self.eta)
                # th.log (not np.log) keeps the computation on the tensor's device;
                # NumPy cannot read CUDA tensors.
                self.weight.data = th.softmax(beta_new * th.log(self.weight.data + 1e-8) - (1-beta_new)*self.eta * value_vector, dim=0)
            else:
                self.weight.data = th.softmax(- self.lam * value_vector + th.log(self.weight.data + 1e-8), dim=0)
        elif self.reg == 'quad2':
            # NOTE(review): this (like the other quad2 rules) moves the weights towards
            # larger value entries, whereas the kl2 rules favour smaller ones - confirm
            # the intended sign.
            self.weight.data -= - self.lr * value_vector
            self.weight.data = self.proj_on_simplex(self.weight.data)
        else:
            raise NotImplementedError

    def _apply_update(self, value_vector, corr_vector, regularized_md: bool) -> None:
        """Shared body of ``step`` and ``step_linear_obj_MD``; they differ only in the MD-only rule."""
        # The counter is bumped before validation, so a rejected call still counts.
        self.w_update_cnt += 1
        assert corr_vector.shape[0] == self.r_dim, "corr_vector must be a column vector"
        assert len(corr_vector.shape) == 2
        with th.no_grad():
            value_vector = value_vector.to(self.device)
            corr_vector = corr_vector.to(self.device)
            if self.use_ci and self.use_md:
                self._update_ci_md(value_vector, corr_vector)
            elif self.use_ci:
                self._update_ci_only(value_vector, corr_vector)
            elif self.use_md:
                self._update_md_only(value_vector, regularized_md)
            else:
                # Entropy-regularized (beta * H(w)) solution only.
                self.weight.data = th.softmax(- value_vector / (self.beta + 1e-8), dim=0)  # [r_dim,1]

            # Running mean over all iterates seen so far (the initial weights included).
            self.w_avg = (self.weight.data + self.w_update_cnt * self.w_avg) / (self.w_update_cnt + 1)

    def step_linear_obj_MD(self, value_vector, corr_vector):    # without any regularization, thus, md uses only lam
        """
        Update the weights in closed form (softmax instead of projected gradient descent).

        Identical to :meth:`step` except when ``use_md`` is set without ``use_ci``
        and ``reg == 'kl2'``: here the objective is treated as linear and only
        ``lam`` enters the update.

        :param value_vector: Per-dimension values; presumably ``[r_dim, 1]`` so that it
            broadcasts against the weights - TODO confirm against callers
        :param corr_vector: 2-D tensor with ``r_dim`` rows (asserted)
        :raises NotImplementedError: if ``reg`` is neither ``'kl2'`` nor ``'quad2'``
            while ``use_ci`` or ``use_md`` is set
        """
        self._apply_update(value_vector, corr_vector, regularized_md=False)

    def step(self, value_vector, corr_vector):
        """
        Update the weights in closed form (softmax instead of projected gradient descent).

        Original author's notes: ecw = tau_w = 1/nu = 1/eta, where this eta is not the
        NPG learning rate but the earlier eta (now called nu). Unlike
        :meth:`step_linear_obj_MD`, the MD-only ``'kl2'`` rule here performs mirror
        descent on the regularized objective and therefore uses both ``eta`` and ``lam``.

        :param value_vector: Per-dimension values; presumably ``[r_dim, 1]`` so that it
            broadcasts against the weights - TODO confirm against callers
        :param corr_vector: 2-D tensor with ``r_dim`` rows (asserted)
        :raises NotImplementedError: if ``reg`` is neither ``'kl2'`` nor ``'quad2'``
            while ``use_ci`` or ``use_md`` is set
        """
        self._apply_update(value_vector, corr_vector, regularized_md=True)

    def forward(self, input):
        """
        Scalarize a reward vector with the current weights.

        :param input: 2-D tensor whose last dimension is ``r_dim`` (e.g. ``[1, r_dim]``)
        :return: ``input @ weight``, i.e. one scalar per input row
        """
        assert len(input.shape) == 2  # [1, r_dim]
        output = th.matmul(input, self.weight)
        return output

class MaxminPPO(PPO):
    """ 
    Maxmin Proximal Policy Optimization algorithm (MaxminPPO) (clip version)

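    TODO: the references and parameter descriptions below were copied from the base PPO
    docstring and are still to be edited for MaxminPPO.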

    Paper: https://arxiv.org/abs/1707.06347
    Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
    Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param ent_coef_weight: Entropy coefficient for the weight vector (i.e. the adversary player)
    :param n_init_states: (currently commented out of the signature) after training, when computing
        \\sum_s v(s)\\mu(s), replace it with \\sum_{i=1}^{n_init_states} v(s_i) / n_init_states
    """
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.1,  # NOTE: differs from the usual PPO default of 0.0
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "cpu",
        _init_setup_model: bool = True,
        r_dim: int = 1,
        r_dim_wise_normalize: bool = False,
        env_name: Optional[str] = None,
        weight_initialize: Union[str, List[float]] = 'uniform',
        ent_coef_weight: float = 0.1,
        # n_init_states: int = 10,  # seems not needed, since already sampling n_env init states
        use_ci: bool = True,
        use_md: bool = True,
        corr: str = 'inner_product',
        reg: str = 'kl2',
        ci_coef: float = 2.0,
        md_coef: float = 1.0,
        corr_period: int = 1,
        util_ppo: bool = False,
    ) -> None:
        """
        Build the PPO learner and the adversarial weight player.

        All standard PPO arguments are forwarded unchanged to the base class
        constructor; only the arguments specific to this class are described here.

        :param r_dim: number of reward dimensions; forwarded to the base class and to ``Weight``
        :param r_dim_wise_normalize: forwarded to the base class; ``train`` uses it to choose
            between per-dimension and global advantage normalization
        :param env_name: short environment alias. Must be one of ``'DST'``, ``'Four-room'``,
            ``'reacher'``, ``'traffic'``, ``'traffic-big'``, ``'traffic-asym'``,
            ``'traffic-asym-4000'`` or ``'sc'``. The default ``None`` is rejected, so the
            argument is effectively required.
        :param weight_initialize: forwarded to ``Weight`` as ``initialize``
        :param ent_coef_weight: forwarded to ``Weight`` as ``beta`` and stored as ``self.beta``
        :param use_ci: forwarded to ``Weight``; also enables the reward-correlation computation
            in ``train``
        :param use_md: forwarded to ``Weight``
        :param corr: correlation measure used in ``train`` when ``use_ci`` is set
            (``'inner_product'``, ``'cos_sim'``, ``'pearson_corr'`` or ``'pearson_dist'``)
        :param reg: forwarded to ``Weight``
        :param ci_coef: forwarded to ``Weight`` as ``eta``
        :param md_coef: forwarded to ``Weight`` as ``lam``
        :param corr_period: the correlation vector is recomputed in ``train`` whenever
            ``(weight.w_update_cnt + 1)`` is a multiple of this value, using the last
            ``corr_period`` rows of the rollout reward buffer
        :param util_ppo: when True, ``train`` never calls ``Weight.step``
        :raises ValueError: if ``env_name`` is not one of the known aliases
        :raises NotImplementedError: if no initial state is defined for the resolved environment
        """
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            clip_range=clip_range,
            clip_range_vf=clip_range_vf,
            normalize_advantage=normalize_advantage,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            target_kl=target_kl,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
            r_dim=r_dim,
            r_dim_wise_normalize=r_dim_wise_normalize,
            env_name=env_name,
        )

        # Adversary player: the weight over the reward dimensions.
        # NOTE(review): `learning_rate` is passed through as-is, so a Schedule (callable)
        # would reach Weight as `lr` -- confirm Weight only ever receives a float.
        self.weight = Weight(
            r_dim=r_dim,
            initialize=weight_initialize,
            use_ci=use_ci,
            use_md=use_md,
            corr=corr,
            reg=reg,
            eta=ci_coef,
            lam=md_coef,
            beta=ent_coef_weight,
            lr=learning_rate,
        )
        self.beta = ent_coef_weight
        self.use_ci = use_ci
        if self.use_ci:
            self.corr = corr
        self.corr_period = corr_period
        # `self.device` is deliberately NOT reassigned from the raw `device` argument:
        # the base class has already resolved it (e.g. "auto" -> a concrete torch device),
        # and overwriting it with the unresolved value breaks the `.to(self.device)` /
        # `device=self.device` calls in this class.
        self.util_ppo = util_ppo

        # Log the initial weights to wandb
        init_w_dict = {}
        for i in range(self.r_dim):
            init_w_dict[f"Weight {i}"] = self.weight.weight.data[i].item()
            init_w_dict[f"Avg_Weight {i}"] = self.weight.w_avg[i].item()
        wandb.log(init_w_dict)

        # Accumulated wall-clock time spent in train()
        self.time = 0.0

        # Resolve the short alias to the canonical environment id
        env_aliases = {
            'DST': 'deep-sea-treasure-sparse-v0',
            'Four-room': 'four-room-truncated-v0',
            'reacher': 'mo-reacher-v4',
            'traffic': 'SUMO',
            'traffic-big': 'SUMO',
            'traffic-asym': 'SUMO',
            'traffic-asym-4000': 'SUMO',
            'sc': 'sc',
        }
        canonical_name = env_aliases.get(env_name) if isinstance(env_name, str) else None
        if canonical_name is None:
            raise ValueError(f"Invalid Env Name: {env_name!r}")
        env_name = canonical_name
        self.env_name = env_name

        # Fixed initial state of each supported environment
        if env_name == 'mo-reacher-v4':
            init_angle = np.array([0, 3.1415 / 2])
            init_state = np.concatenate([
                np.cos(init_angle),
                np.sin(init_angle),
                np.zeros(2)])
            self.init_state_tensor = th.tensor(init_state).unsqueeze(0)
        elif env_name == 'mo-mountaincar-v0':
            # NOTE(review): currently unreachable -- no alias above maps to this id.
            init_state = np.array([-0.5, 0])
            self.init_state_tensor = th.tensor(init_state).unsqueeze(0)
        elif env_name == 'four-room-truncated-v0':
            # for asym 7x7 map, with 2+5 items
            self.init_state_tensor = th.tensor([3, 3] + [0 for _ in range(7)]).unsqueeze(0).to(self.device)
        elif env_name == 'deep-sea-treasure-sparse-v0':
            init_state = np.array([0, 0], dtype=np.int32)
            self.init_state_tensor = th.tensor(init_state).unsqueeze(0)
        elif env_name == 'SUMO':
            # for big-intersection (the 2-way-intersection variant used 20 trailing zeros)
            self.init_state_tensor = th.tensor([1.] + [0. for _ in range(36)]).unsqueeze(0)
        elif env_name == 'sc':
            init_state = np.array([14194392.6530898,  11231918.30015633,  8887734.17771035,  7032798.55544704,
                          5565001.66373161,  4403544.79699775,  3484492.53553488, 2757253.25229394,
                          2181794.16614764,  8272049.37361304,        0.        ])
            self.init_state_tensor = th.tensor(init_state).unsqueeze(0)
        else:
            raise NotImplementedError

        print(f"use correlation: {self.use_ci}")
        print(f"use mirror descent: {use_md}")

    
    def train(self, batch_size: int = 100) -> None:  # one step of training for maxmin_PPO
        """
        Update the policy (agent player) and the weight (adversary player) on the
        current rollout buffer.

        For every epoch and every minibatch drawn from the rollout buffer this method:

        1. evaluates, without gradients, the value of the environment's initial states
           under the current policy and averages it over the returned states;
        2. performs one clipped-surrogate PPO step, with the advantages scalarized by the
           current weight when ``r_dim > 1``;
        3. optionally recomputes the reward-correlation vector and, unless ``util_ppo``
           is set, calls ``self.weight.step`` with the value vector from step 1.

        Afterwards the initial-state value is evaluated once more and metrics are written
        to ``self.logger`` and to wandb.

        :param batch_size: not read by this method; minibatches are drawn with
            ``self.batch_size``
        :raises NotImplementedError: if ``self.corr`` or ``self.env_name`` is not one of
            the values handled below
        """
        start_train_time = time.time()

        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # Value vector of the current policy at the initial states; refreshed before every
        # minibatch update and consumed by the weight update of that same minibatch.
        value_vector_for_weight_update: Optional[th.Tensor] = None

        ## Simultaneous updates for policy(agent player) and weight(adversary player)
        # w_t, theta_t -> theta_{t+1} & theta_t -> w_{t+1}

        # Start from a uniform correlation vector (softmax of zeros). It is only replaced
        # when the corr_period condition inside the loop is met.
        # NOTE(review): this tensor is created on self.device, whereas the recomputed one
        # further down is created without a device argument -- confirm Weight.step copes with both.
        corr_vector = np.zeros((self.r_dim,1))
        corr_vector = th.softmax(th.tensor(corr_vector, dtype=th.float32, device=self.device), dim=0)  # uniform at initial

        for epoch in range(self.n_epochs):
            # NOTE(review): reset every epoch, so the logged "train/approx_kl" only covers
            # the last epoch that ran.
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            # (the rollout buffer is refilled by learn() in on_policy_algorithm.py)
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # Before each (theta_t, w_t) update, prepare v^{pi_{theta_t}}, which will be used for the w_t update
                with th.no_grad():
                    initial_state = self.env.initial_states()  # parallel initial states due to env = SubprocVecEnv # (1, n_env, ob_dim)
                    # assert initial_state.shape[0] == 1 ### only for traffic intersection. For MO-Mujoco, revise subproc_vec_env.py
                    initial_state_value = self.policy.predict_values(obs_as_tensor(np.squeeze(initial_state, axis=0), self.device))  # (n_env,r_dim)
                    averaged_state_value = th.mean(initial_state_value, axis=0)  # average value of initial state # (r_dim,)
                    if self.r_dim > 1:
                        assert averaged_state_value.dim() == 1, "MO dim error"
                        value_vector_for_weight_update = th.unsqueeze(averaged_state_value, dim = 1)  # (r_dim,1)
                    else:
                        assert averaged_state_value.dim() == 0, "SO dim error"
                        value_vector_for_weight_update = th.tensor([[averaged_state_value]])  # (1,1)

                ## PPO-update
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                # Quantities under the current (being-optimized) policy parameters
                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)

                if self.r_dim == 1:
                    values = values.flatten()

                # Normalize advantage
                advantages = rollout_data.advantages # [batch_size, r_dim]

                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                # We follow the Fair-RL Code using normalization over batch 'and' r_dim
                # Note that normalization is conducted for every sampled_batch of size 'self.batch_size'
                if self.normalize_advantage and len(advantages) > 1:
                    if self.r_dim_wise_normalize:
                        advantages = (advantages - advantages.mean(axis=0)) / (advantages.std(axis=0) + 1e-8)
                    else: # default
                        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                if self.r_dim > 1:
                    # Scalarize the vector advantage with the current weight: A_w = A @ w.
                    # `.data` is used, so no gradient flows into the weight from the policy loss.
                    weight_temp = self.weight.weight.data
                    assert weight_temp.dim()==2, "error in w dot adv"
                    # The weight may live on a different device than the advantages
                    weight_temp = weight_temp.to(advantages.device)
                    advantages = th.matmul(advantages, weight_temp)
                    advantages = th.squeeze(advantages,dim=-1)  # [batch_size] # scalarized advantage

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob) # old_log_prob: log-prob under the policy that generated the data

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values    # state value by current value network
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target
                value_loss = F.mse_loss(rollout_data.returns, values_pred) # rollout_data.returns = {A(s,a) by gae} + {V(s) from value network}
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                # Early stop: leaves the minibatch loop before the optimizer step AND before
                # the weight update of this minibatch.
                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                ## Policy(agent player) update; w_t, theta_t -> theta_{t+1}
                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

                ## weight(adversary player); theta_t -> w_{t+1}
                with th.no_grad():
                    # Recompute the correlation vector only every `corr_period` weight updates;
                    # otherwise the previous corr_vector is reused.
                    if (self.weight.w_update_cnt + 1) % self.corr_period == 0:
                        # Zeros by default, so without use_ci the softmax below yields a uniform vector
                        corr_vector = np.zeros((self.r_dim,1))
                        if self.use_ci:
                            # Last `corr_period` rows of the rollout reward buffer
                            reward_buffer = self.rollout_buffer.rewards[-self.corr_period:].copy()  # or [-self.batch_size:] or [-1:]
                            reward_buffer = reward_buffer.reshape(-1, self.r_dim)   # [corr_period*n_envs, r_dim]
                            # Reward dimension with the smallest current initial-state value;
                            # every other dimension is correlated against this one.
                            min_index = th.argmin(value_vector_for_weight_update).item()
                            if self.corr == 'inner_product':
                                corr_vector = np.matmul(reward_buffer[:, min_index].reshape(1,-1), reward_buffer).reshape(-1,1)  # [r_dim, 1]
                                corr_vector = corr_vector / len(reward_buffer)
                            elif self.corr == 'cos_sim':
                                norms = np.linalg.norm(reward_buffer, axis=0)
                                normalize = norms * norms[min_index]
                                normalize = normalize.reshape(-1,1)
                                corr_vector = np.matmul(reward_buffer[:, min_index].reshape(1,-1), reward_buffer).reshape(-1,1)  # [r_dim, 1]
                                corr_vector = corr_vector / (normalize + 1e-8)
                            elif self.corr == 'pearson_corr':
                                # Computed by hand instead of np.corrcoef, which gives nan when std=0
                                for i in range(self.r_dim):
                                    cov = np.cov(reward_buffer[:, min_index], reward_buffer[:, i])[0,1]
                                    std_x = np.std(reward_buffer[:, min_index], ddof=1)
                                    std_y = np.std(reward_buffer[:, i], ddof=1)
                                    corr_vector[i] = np.clip(cov/(std_x*std_y + 1e-8),-1,1)
                            elif self.corr == 'pearson_dist':
                                # sqrt((1 - rho) / 2), with rho the clipped Pearson correlation
                                for i in range(self.r_dim):
                                    cov = np.cov(reward_buffer[:, min_index], reward_buffer[:, i])[0,1]
                                    std_x = np.std(reward_buffer[:, min_index], ddof=1)
                                    std_y = np.std(reward_buffer[:, i], ddof=1)
                                    cor = np.clip(cov/(std_x*std_y + 1e-8),-1,1)
                                    corr_vector[i] = np.sqrt((1-cor)/2)
                            else:
                                raise NotImplementedError
                        # Normalize onto the simplex. NOTE(review): no device argument here,
                        # unlike the initial corr_vector created before the epoch loop.
                        corr_vector = th.softmax(th.tensor(corr_vector, dtype=th.float32), dim=0)
                    if not self.util_ppo:
                        self.weight.step(value_vector_for_weight_update, corr_vector) # (r_dim,1) here, w_update_cnt increases

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Accumulate the wall-clock time spent in the update loops above
        end_train_time = time.time()
        self.time += end_train_time - start_train_time

        # Re-evaluate the initial-state value after training; used for logging only
        with th.no_grad():
            self.policy.set_training_mode(False)    # eval mode

            initial_state = self.env.initial_states()  # parallel initial states due to env = SubprocVecEnv # (1, n_env, ob_dim)
            # assert initial_state.shape[0] == 1 ### only for traffic intersection. For MO-Mujoco, revise subproc_vec_env.py
            initial_state_value = self.policy.predict_values(obs_as_tensor(np.squeeze(initial_state, axis=0), self.device))  # (n_env,r_dim)
            averaged_state_value = th.mean(initial_state_value, axis=0)  # average value of initial state # (r_dim,)
            if self.r_dim > 1:
                current_state_value_vector = th.unsqueeze(averaged_state_value, dim = 1).to(self.device)  # (r_dim,1)
            else:
                current_state_value_vector = th.tensor([[averaged_state_value]]).to(self.device)  # (1,1)

            # Weighted (scalarized) initial-state value: v^T w
            wegh_data = self.weight.weight.data
            assert wegh_data.dim()==2, "error in w shape"
            weighted_value = th.matmul(current_state_value_vector.view(1,-1), wegh_data)
            # Still a 0-dim tensor (no .item()); it is passed to wandb in that form below
            weighted_value = weighted_value[0][0]

        # Logs
        # NOTE(review): `loss` and `approx_kl_divs` are only bound inside the loops above;
        # they would be undefined here if no minibatch/epoch was processed.
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

        # Wandb log
        wegh = self.weight.weight

        # Only the SUMO branch iterates over all r_dim objectives; the other branches
        # index a fixed number of objectives (2 or 4) and therefore require r_dim to be
        # at least that large.
        if self.env_name == 'SUMO':
            wandb_dict = {}
            weight_log = self.weight.weight.data
            weight_avg_log = self.weight.w_avg
            for i in range(self.r_dim):
                wandb_dict[f"Weight {i}"] = weight_log[i].item()
                wandb_dict[f"Avg_Weight {i}"] = weight_avg_log[i].item()
                wandb_dict[f'Corr {i}'] = corr_vector[i].item()
                wandb_dict[f"Value {i}"] = current_state_value_vector[i].item()
            wandb_dict["loss"] = loss.item()
            wandb_dict["Total Mean V(mu)"] = th.mean(current_state_value_vector).item()
            wandb_dict["Total Mean Q"] = weighted_value
            wandb_dict["Total train time"] = self.time
            wandb.log(wandb_dict)
        elif self.env_name == 'four-room-truncated-v0' or self.env_name == 'sc':
            wandb.log({
                'loss': loss.item(),
                'Total Mean V(mu)': th.mean(current_state_value_vector).item(),
                'Value 0': current_state_value_vector[0].item(),
                'Value 1': current_state_value_vector[1].item(),
                'Weight 0': wegh[0].item(),
                'Weight 1': wegh[1].item(),
                'Total Mean Q': weighted_value,
                'Total train time': self.time,
                'Corr 0': corr_vector[0].item(),
                'Corr 1': corr_vector[1].item(),
            }
            )
        elif self.env_name == 'deep-sea-treasure-sparse-v0':
            # No correlation entries are logged for this environment
            wandb.log({
                'loss': loss.item(),
                'Total Mean V(mu)': th.mean(current_state_value_vector).item(),
                'Value 0': current_state_value_vector[0].item(),
                'Value 1': current_state_value_vector[1].item(),
                'Weight 0': wegh[0].item(),
                'Weight 1': wegh[1].item(),
                'Total Mean Q': weighted_value,
                'Total train time': self.time,
            }
            )
        elif self.env_name == 'mo-reacher-v4':
            wandb.log({
                'loss': loss.item(),
                'Total Mean V(mu)': th.mean(current_state_value_vector).item(),
                'Value 0': current_state_value_vector[0].item(),
                'Value 1': current_state_value_vector[1].item(),
                'Value 2': current_state_value_vector[2].item(),
                'Value 3': current_state_value_vector[3].item(),
                'Weight 0': wegh[0].item(),
                'Weight 1': wegh[1].item(),
                'Weight 2': wegh[2].item(),
                'Weight 3': wegh[3].item(),
                'Total Mean Q': weighted_value,
                'Total train time': self.time,
                'Corr 0': corr_vector[0].item(),
                'Corr 1': corr_vector[1].item(),
                'Corr 2': corr_vector[2].item(),
                'Corr 3': corr_vector[3].item(),
            }
            )
        else:
            raise NotImplementedError








