import warnings
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union

import numpy as np
import torch as th
# from th import Tensor
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
from stable_baselines3.ppo.ppo import PPO

# TypeVar bound to PPO (presumably for self-returning method annotations; it is not
# referenced anywhere in the code visible in this file).
SelfPPO = TypeVar("SelfPPO", bound="PPO")

import pdb
import wandb
import time

### Maybe gru part should be here, not in policy_dummy.py
class GRUNetwork(th.nn.Module):  # TBE
    """Encode a sequence with a single-layer GRU and project the output of its last step.

    :param input_dim: size of each element of the input sequence.
    :param gru_hidden_size: hidden size of the GRU.
    :param output_dim: size of the feature vector returned by ``forward``.
    """

    def __init__(self, input_dim, gru_hidden_size, output_dim):
        super().__init__()
        # Creation order (GRU first, then Linear) determines the order in which the
        # random initial parameters are drawn.
        self.gru = th.nn.GRU(input_dim, gru_hidden_size, batch_first=True)
        self.fc = th.nn.Linear(gru_hidden_size, output_dim)

    def forward(self, x):
        """Return ``fc`` applied to the GRU output at the final time step.

        Assumes ``x`` is (batch_size, seq_len, input_dim), since the GRU is batch-first.
        """
        hidden_states = self.gru(x)[0]  # per-step outputs; the final hidden state is discarded
        final_step = hidden_states[:, -1]
        return self.fc(final_step)



class GameWeight(th.nn.Module):
    """
    History-based weight over the ``r_dim`` objectives (the adversary player in GamePPO).

    The weight is stored as an ``(r_dim, 1)`` parameter. Each call to :meth:`step` computes a
    smoothed best response ``softmax(-value_vector / beta)`` (more mass on objectives with a
    lower value) and mixes it into the stored weight according to ``history_type['type']``:

    * ``'histogram'``: running average of the initial weight and every best response so far
      (the initial weight counts as one sample).
    * ``'moving_avg'``: exponential moving average with rate ``history_type['coeff']``.
    * ``'rnn'``: experimental (TBE) GRU-based update; needs ``'input_dim'``, ``'hidden_dim'``
      and ``'output_dim'`` keys.

    :param r_dim: number of objectives (K in the paper).
    :param initialize: ``'uniform'`` (1/r_dim each), ``'dirichlet'`` (one sample of a flat
        Dirichlet) or an explicit list/tuple of ``r_dim`` floats.
    :param device: torch device for the weight; ``'auto'`` selects CUDA when available, else CPU.
    :param history_type: dict with a ``'type'`` key plus type-specific keys;
        ``None`` means ``{'type': 'histogram'}``.
    :raises ValueError: on an unknown history type or a list initializer of the wrong length.
    :raises NotImplementedError: on an unknown ``initialize`` string.
    """

    def __init__(self,
                 r_dim: int = 2,  # r_dim = K in the paper
                 initialize: Union[str, List[float]] = 'uniform',
                 device: Union[str, th.device] = 'auto',
                 history_type: Optional[Dict[str, Union[str, float]]] = None,
        ):
        super().__init__()

        # torch cannot create tensors on the string "auto"; resolve it to a concrete device.
        if isinstance(device, str) and device == 'auto':
            device = 'cuda' if th.cuda.is_available() else 'cpu'
        self.device = device
        # Number of samples in the running average; the initial weight counts as the first one.
        self.cnt: int = 1
        if history_type is None:
            history_type = {'type': 'histogram'}
        self.history_dict = history_type

        # Declare the trainable weight parameter, shape [r_dim, 1].
        if initialize == 'uniform':
            self.weight = th.nn.Parameter(th.full((r_dim, 1), 1 / r_dim, dtype=th.float32, device=self.device))
            print("Weight", self.weight)
        elif initialize == 'dirichlet':
            # One random sample from a flat Dirichlet distribution.
            dirichlet_distribution = th.distributions.dirichlet.Dirichlet(th.ones(r_dim, dtype=th.float32))
            samples = dirichlet_distribution.sample().to(self.device)  # [r_dim]
            self.weight = th.nn.Parameter(samples.unsqueeze(dim=-1))  # [r_dim, 1]
            print("Weight", self.weight)
        elif isinstance(initialize, (list, tuple)):
            if len(initialize) != r_dim:
                raise ValueError(f"initialize has {len(initialize)} entries but r_dim={r_dim}")
            self.weight = th.nn.Parameter(
                th.tensor(list(initialize), dtype=th.float32, device=self.device).unsqueeze(dim=-1)
            )  # [r_dim, 1]
            print("Weight", self.weight)
        else:
            raise NotImplementedError

        # Dispatch on the *type string*, not on the dict itself.
        self.history_type: str = history_type['type']
        if self.history_type == 'histogram':
            pass  # nothing extra to configure
        elif self.history_type == 'moving_avg':
            self.coeff = history_type['coeff']
        elif self.history_type == 'rnn':  # TBE
            self.input_dim = history_type['input_dim']
            self.hidden_dim = history_type['hidden_dim']
            self.output_dim = history_type['output_dim']
            self.gru = GRUNetwork(input_dim=self.input_dim, gru_hidden_size=self.hidden_dim, output_dim=self.output_dim)
        else:
            raise ValueError(f"Unknown history type: {self.history_type!r}")

    def step(self, beta, value_vector):
        """Mix a softmax best response to ``value_vector`` into the stored weight.

        :param beta: softmax temperature; a small ``beta`` approaches a hard argmin over objectives.
        :param value_vector: per-objective values, expected shape ``[r_dim, 1]`` on the weight's device.
        """
        with th.no_grad():
            new_weight = th.softmax(- value_vector / (beta + 1e-8), dim=0)  # [r_dim, 1]
            assert new_weight.device == self.weight.data.device, "match devices"

            if self.history_type == 'histogram':
                # Running average: after this step the weight is the mean of cnt + 1 samples.
                mix = 1 / (self.cnt + 1)
                self.weight.data = (1 - mix) * self.weight.data + mix * new_weight  # [r_dim, 1]
            elif self.history_type == 'moving_avg':
                self.weight.data = (1 - self.coeff) * self.weight.data + self.coeff * new_weight  # [r_dim, 1]
            elif self.history_type == 'rnn':  # TBE
                # NOTE(review): unfinished; GRUNetwork.forward expects (batch, seq, dim) input,
                # which a [r_dim, 1] tensor does not provide.
                self.weight.data = self.gru(new_weight)
            self.cnt += 1

    # NOTE: what should train the GRU? Learning it from the Bellman error is not appropriate,
    # since SFP is an independent process: w should only carry information about the minimum
    # over objectives, not about policy/value approximation.
    # Alternative: compress the weight history without learning, so that it stays close to SFP,
    # i.e. sbr(w_history) = ppo-update(w_current).


# the only change from MaxminPPO is to use history-based weight instead of current weight
class GamePPO(PPO):
    """
    SFP-based Maxmin Proximal Policy Optimization algorithm (GamePPO), clip version.

    Multi-objective PPO cast as a two-player game. The policy (agent player) takes clipped-PPO
    steps on the advantage scalarized by a weight over the ``r_dim`` objectives. The weight
    (adversary player, ``GameWeight``) moves toward a softmax best response that puts more mass
    on objectives whose value at the initial states is lower. The weight is history-based
    (e.g. a running average of past best responses) rather than only the current best response;
    that is the only change relative to MaxminPPO.

    Paper (PPO): https://arxiv.org/abs/1707.06347
    Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
    Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the policy loss (default 0.1 here)
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param r_dim: Number of objectives (K in the paper)
    :param r_dim_wise_normalize: If True, normalize advantages separately per objective;
        otherwise over the whole (batch, r_dim) minibatch
    :param env_name: One of 'DST', 'Four-room', 'reacher', 'traffic', 'traffic-big', 'traffic-asym'
    :param weight_initialize: Initial adversary weight: 'uniform', 'dirichlet' or a list of r_dim floats
    :param ent_coef_weight: Softmax temperature (entropy coefficient) of the weight, i.e. the adversary player
    :param history_type: How the weight aggregates past best responses, e.g. {'type': 'histogram'},
        {'type': 'moving_avg', 'coeff': 0.1} or {'type': 'rnn', ...} (see GameWeight)
    """

    # Short environment names accepted by ``env_name`` -> canonical ids used by this class.
    _ENV_NAME_ALIASES: ClassVar[Dict[str, str]] = {
        "DST": "deep-sea-treasure-sparse-v0",
        "Four-room": "four-room-truncated-v0",
        "reacher": "mo-reacher-v4",
        "traffic": "SUMO",
        "traffic-big": "SUMO",
        "traffic-asym": "SUMO",
    }

    # Canonical env ids for which ``train`` knows how to log to wandb.
    _WANDB_ENVS: ClassVar[tuple] = ("SUMO", "four-room-truncated-v0", "deep-sea-treasure-sparse-v0")

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.1,  # SB3 default is 0.0
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        r_dim: int = 1,
        r_dim_wise_normalize: bool = False,
        env_name: Optional[str] = None,
        weight_initialize: Union[str, List[float]] = 'uniform',
        ent_coef_weight: float = 0.1,
        history_type: Optional[Dict[str, Union[str, float]]] = None,  # 'histogram', 'rnn', 'moving_avg'
    ) -> None:
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            clip_range=clip_range,
            clip_range_vf=clip_range_vf,
            normalize_advantage=normalize_advantage,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            target_kl=target_kl,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
            r_dim=r_dim,
            r_dim_wise_normalize=r_dim_wise_normalize,
            env_name=env_name,
        )

        # Adversary player. Pass self.device (already resolved by the base class) rather than the
        # raw ``device`` argument: torch cannot create tensors on the string "auto", and the weight
        # must live on the same device as the advantages and values it is combined with.
        self.weight = GameWeight(
            r_dim=r_dim, initialize=weight_initialize, device=self.device, history_type=history_type
        )
        # Softmax temperature of the adversary's best response (see GameWeight.step).
        self.beta = ent_coef_weight
        # Accumulated wall-clock seconds spent in ``train`` before its post-update evaluation.
        self.time = 0.0

        if env_name not in self._ENV_NAME_ALIASES:
            raise Exception("Invalid Env Name")
        self.env_name = self._ENV_NAME_ALIASES[env_name]

        # Fixed initial observation per environment (not read by ``train`` below).
        if self.env_name == 'four-room-truncated-v0':
            self.init_state_tensor = th.tensor([6] + [0] * 5).unsqueeze(0)
        elif self.env_name == 'deep-sea-treasure-sparse-v0':
            self.init_state_tensor = th.tensor(np.array([0, 0], dtype=np.int32)).unsqueeze(0)
        elif self.env_name == 'SUMO':
            # NOTE(review): 37-dim state of the big intersection; the 2-way intersection used 21 dims.
            self.init_state_tensor = th.tensor([1.] + [0.] * 36).unsqueeze(0)
        else:
            # e.g. 'mo-reacher-v4': no initial state is defined for it here.
            raise NotImplementedError

    def _initial_state_value_vector(self) -> th.Tensor:
        """Return the critic's value at the envs' initial states, averaged over envs, as (r_dim, 1).

        Assumes ``self.env.initial_states()`` returns an array shaped (1, n_env, obs_dim), as for the
        custom SubprocVecEnv used with this class. ``reshape(-1, 1)`` accepts both an (n_env, r_dim)
        and an (n_env,) value head and, unlike ``th.tensor([[x]])``, keeps the result on the
        policy's device.
        """
        with th.no_grad():
            initial_state = self.env.initial_states()
            obs = obs_as_tensor(np.squeeze(initial_state, axis=0), self.device)
            initial_state_value = self.policy.predict_values(obs)
            value_vector = th.mean(initial_state_value, dim=0).reshape(-1, 1)
        assert value_vector.shape[0] == self.r_dim, "value dim error"
        return value_vector

    def _log_to_wandb(self, loss: th.Tensor, value_vector: th.Tensor, weighted_value: float) -> None:
        """Log the last loss, per-objective values/weights and the scalarized value to wandb.

        :raises NotImplementedError: if ``self.env_name`` is not one of ``_WANDB_ENVS``.
        """
        if self.env_name not in self._WANDB_ENVS:
            raise NotImplementedError
        weight = self.weight.weight.detach()
        log_dict = {
            'loss': loss.item(),
            'Total Mean V(mu)': th.mean(value_vector).item(),
            'Total Mean Q': weighted_value,
            'Total train time': self.time,
        }
        # One entry per objective, instead of a hard-coded count per environment.
        for i in range(self.r_dim):
            log_dict[f'Value {i}'] = value_vector[i].item()
            log_dict[f'Weight {i}'] = weight[i].item()
        wandb.log(log_dict)

    def train(self, batch_size: int = 100) -> None:
        """
        Run one round of simultaneous updates for the policy (agent) and the weight (adversary).

        For every minibatch of every epoch:
        1. record the value vector of the current policy theta_t at the initial states;
        2. take one clipped-PPO step on the advantage scalarized by the current weight w_t
           (w_t, theta_t -> theta_{t+1});
        3. take one ``GameWeight.step`` with the value vector from step 1 (theta_t -> w_{t+1}).
        Afterwards the updated policy is evaluated at the initial states and logged.

        :param batch_size: Unused; kept for backward compatibility. The minibatch size is
            ``self.batch_size``.
        """
        start_train_time = time.time()

        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer (refilled by learn() between train() calls).
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # Value of the pre-update policy theta_t; used for the w_t update at the end of this iteration.
                value_vector_for_weight_update = self._initial_state_value_vector()  # (r_dim, 1)

                ## PPO update
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                if self.r_dim == 1:
                    values = values.flatten()

                advantages = rollout_data.advantages  # [batch_size, r_dim]

                # Normalization does not make sense if mini batchsize == 1, see GH issue #325.
                # Following the Fair-RL code, the default normalizes over batch *and* r_dim;
                # statistics are computed per minibatch of size self.batch_size.
                if self.normalize_advantage and len(advantages) > 1:
                    if self.r_dim_wise_normalize:
                        advantages = (advantages - advantages.mean(dim=0)) / (advantages.std(dim=0) + 1e-8)
                    else:
                        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                if self.r_dim > 1:
                    # The weight is already history-based by GameWeight's update rule; .data carries no grad.
                    weight_temp = self.weight.weight.data
                    assert weight_temp.dim() == 2, "error in w dot adv"
                    advantages = th.matmul(advantages, weight_temp).squeeze(dim=-1)  # [batch_size], scalarized

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target (returns = GAE advantage + old value estimate)
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                ## Policy (agent player) update: w_t, theta_t -> theta_{t+1}
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

                ## Weight (adversary player) update: theta_t -> w_{t+1}
                with th.no_grad():
                    self.weight.step(self.beta, value_vector_for_weight_update)

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        self.time += time.time() - start_train_time

        # Evaluate the updated policy at the initial states.
        self.policy.set_training_mode(False)
        current_state_value_vector = self._initial_state_value_vector()  # (r_dim, 1)
        with th.no_grad():
            weight_data = self.weight.weight.data
            assert weight_data.dim() == 2, "error in w shape"
            # Scalarized value w^T v as a Python float.
            weighted_value = th.matmul(current_state_value_vector.view(1, -1), weight_data)[0, 0].item()

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

        # Wandb log
        self._log_to_wandb(loss, current_state_value_vector, weighted_value)
