import warnings
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union, Tuple

import numpy as np
import torch as th
# from th import Tensor
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, update_learning_rate

# Type variable bound to PPO so that ``PPO.learn`` can be annotated as returning the
# caller's own (sub)class type rather than plain ``PPO``.
SelfPPO = TypeVar("SelfPPO", bound="PPO")

import pdb
import wandb
import time
import math

class PPO(OnPolicyAlgorithm):
    """
    Proximal Policy Optimization algorithm (PPO) (clip version), extended to vector-valued rewards.

    Paper: https://arxiv.org/abs/1707.06347
    Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
    Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible. The default here is "cpu".
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param r_dim: Number of reward components (presumably the width of the value and advantage
        tensors, cf. the ``[batch_size, r_dim]`` shape comments in ``train``). With ``r_dim > 1``,
        ``train`` optimizes only the advantage component whose mean predicted value over the
        environments' initial states is smallest; the value loss still uses all components.
    :param r_dim_wise_normalize: When advantage normalization is enabled, normalize each reward
        component separately over the mini-batch (True) instead of jointly over the mini-batch
        and all components (False, default).
    :param env_name: Environment name forwarded to ``OnPolicyAlgorithm.__init__``; its use is
        defined there, not in this class.
    """

    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "MlpPolicy": ActorCriticPolicy,         # modified to use MO critic
        "CnnPolicy": ActorCriticCnnPolicy,
        "MultiInputPolicy": MultiInputActorCriticPolicy,
    }

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]], #
        env: Union[GymEnv, str], #
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048, #
        batch_size: int = 64, #
        n_epochs: int = 10, #
        gamma: float = 0.99, ## default
        gae_lambda: float = 0.95, ## default
        clip_range: Union[float, Schedule] = 0.2, ## default
        clip_range_vf: Union[None, float, Schedule] = None, ## default
        normalize_advantage: bool = True, ## default
        ent_coef: float = 0.0,
        vf_coef: float = 0.5, ## default
        max_grad_norm: float = 0.5, ## default
        use_sde: bool = False, ## default
        sde_sample_freq: int = -1, ## default
        target_kl: Optional[float] = None, ## default
        stats_window_size: int = 100, ### neglect
        tensorboard_log: Optional[str] = None, ## default
        policy_kwargs: Optional[Dict[str, Any]] = None, ## default
        verbose: int = 0, #
        seed: Optional[int] = None, #
        device: Union[th.device, str] = "cpu", ## default
        _init_setup_model: bool = True, ## default
        r_dim: int = 1, ### Newly added: number of reward components
        r_dim_wise_normalize: bool = False, ### Newly added: per-component advantage normalization
        env_name: Optional[str] = None
    ):
        # The model is always built lazily here (``_init_setup_model=False``) so that PPO-specific
        # attributes below exist before ``_setup_model`` runs.
        super().__init__(
            policy,
            env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            gamma=gamma,
            gae_lambda=gae_lambda,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            device=device,
            seed=seed,
            _init_setup_model=False,
            supported_action_spaces=(
                spaces.Box,
                spaces.Discrete,
                spaces.MultiDiscrete,
                spaces.MultiBinary,
            ),
            r_dim=r_dim,
            env_name=env_name
        )

        # Sanity check, otherwise it will lead to noisy gradient and NaN
        # because of the advantage normalization
        if normalize_advantage:
            assert (
                batch_size > 1
            ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440"

        if self.env is not None:
            # Check that `n_steps * n_envs > 1` to avoid NaN
            # when doing advantage normalization
            buffer_size = self.env.num_envs * self.n_steps
            assert buffer_size > 1 or (
                not normalize_advantage
            ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}"
            # Check that the rollout buffer size is a multiple of the mini-batch size
            untruncated_batches = buffer_size // batch_size
            if buffer_size % batch_size > 0:
                warnings.warn(
                    f"You have specified a mini-batch size of {batch_size},"
                    f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`,"
                    f" after every {untruncated_batches} untruncated mini-batches,"
                    f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n"
                    f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n"
                    f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})"
                )
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.clip_range = clip_range
        self.clip_range_vf = clip_range_vf
        self.normalize_advantage = normalize_advantage
        self.target_kl = target_kl

        ### newly added
        self.r_dim = r_dim
        self.r_dim_wise_normalize = r_dim_wise_normalize

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Build the model via the parent class, then turn ``clip_range`` and (if set)
        ``clip_range_vf`` into schedule callables of the progress remaining.
        """
        super()._setup_model()

        # Initialize schedules for policy/value clipping
        self.clip_range = get_schedule_fn(self.clip_range)
        if self.clip_range_vf is not None:
            if isinstance(self.clip_range_vf, (float, int)):
                assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"

            self.clip_range_vf = get_schedule_fn(self.clip_range_vf)

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        With ``r_dim > 1`` the policy-gradient term uses only one advantage component: the one
        whose predicted value, averaged over the initial states of all parallel environments,
        is smallest (computed once per call, before the epoch loop).
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        if self.r_dim > 1: ## find argmin index w.r.t. expected critic
            # Select the worst-off reward component (lowest mean predicted value at the initial
            # states); `scalarized_index` is only defined, and only used, when r_dim > 1.
            # NOTE(review): the policy is already in training mode here, which would update
            # batch-norm running statistics if the policy had any.
            with th.no_grad():
                # `initial_states` is provided by the (custom) vec env; shape below is the original author's note.
                initial_state = self.env.initial_states()  ### parallel initial state due to env = SubprocVecEnv # (1, n_env, ob_dim)
                # assert initial_state.shape[0] == 1 ### only for traffic intersection. For MO-Mujoco, revise subproc_vec_env.py
                initial_state_value = self.policy.predict_values(obs_as_tensor(np.squeeze(initial_state, axis=0), self.device))  # (n_env,r_dim)
                averaged_state_value = th.mean(initial_state_value, axis=0)  # average value of initial state # (r_dim,)
                scalarized_index = th.argmin(averaged_state_value).item()

        # train for n_epochs epochs
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions) ## Current parameterized ones

                # Scalar-reward case: flatten to match the standard SB3 value-loss shapes.
                if self.r_dim == 1:
                    values = values.flatten()

                # Normalize advantage
                advantages = rollout_data.advantages # [batch_size, r_dim]

                # Normalization does not make sense if mini batchsize == 1, see GH issue #325
                # We follow the Fair-RL Code using normalization over batch 'and' r_dim
                # Note that normalization is conducted for every sampled_batch of size 'self.batch_size'
                if self.normalize_advantage and len(advantages) > 1:
                    if self.r_dim_wise_normalize:
                        # Per-component statistics (reduce over the batch axis only).
                        advantages = (advantages - advantages.mean(axis=0)) / (advantages.std(axis=0) + 1e-8)
                    else: # default
                        # Joint statistics over batch and all reward components.
                        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                # Keep only the selected component for the policy-gradient term.
                if self.r_dim > 1:
                    advantages = advantages[:, scalarized_index] # [batch_size]

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob) # policy prob value of which the policy is used for generating current data

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None: ##### Here!
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target; with r_dim > 1 it averages over all components.
                value_loss = F.mse_loss(rollout_data.returns, values_pred) # rollout_data.returns = A(s,a) by gae + V(s) from value network
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else: #### Here!
                    entropy_loss = -th.mean(entropy)

                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                # Early stop happens before the optimizer step, so this mini-batch is not applied.
                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                # Optimization step
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

            # Counts epochs (not mini-batch gradient steps).
            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        # Logs (approx_kl covers only the last epoch run; train/loss is the last mini-batch's loss)
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

    def learn(
        self: SelfPPO,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        tb_log_name: str = "PPO",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfPPO:
        """Delegate to ``OnPolicyAlgorithm.learn`` with the same arguments and return its result."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )


class Weight(th.nn.Module):
    """Adversary weight vector over ``r_dim`` reward components.

    Holds a trainable ``[r_dim, 1]`` parameter ``weight`` that :meth:`step` updates in
    closed form. The ``'kl2'`` regularizer gives softmax (multiplicative) updates. The
    ``'quad2'`` regularizer gives an additive step followed by Euclidean projection onto
    the probability simplex. ``w_avg`` is the running average of every weight seen so far,
    including the initial one.

    :param r_dim: number of reward components (K in the paper).
    :param initialize: ``'uniform'`` (all entries ``1/r_dim``), ``'dirichlet'`` (one sample
        from a flat Dirichlet), or an explicit list/tuple of ``r_dim`` floats.
    :param use_ci: enable the correlation term (coefficient ``eta``) in :meth:`step`.
    :param use_md: enable the proximal / mirror-descent term (coefficient ``lam``).
    :param corr: name of the correlation measure; stored only, not used by this class.
    :param reg: regularizer type, ``'kl2'`` or ``'quad2'``.
    :param eta: coefficient of the correlation term.
    :param lam: coefficient of the proximal term.
    :param beta: temperature used when both ``use_ci`` and ``use_md`` are False.
    :param lr: step size of the ``'quad2'`` updates.
    :param device: device on which the weight and helper tensors are allocated.
    :raises ValueError: if ``initialize`` is a list/tuple whose length differs from ``r_dim``.
    :raises NotImplementedError: if ``initialize`` is not a supported option.
    """

    def __init__(self,
                 r_dim: int = 2,  # r_dim = K in the paper
                 initialize: Union[str, List[float]] = 'uniform',
                 use_ci=True,
                 use_md=True,
                 corr='pearson_corr',
                 reg='kl2',
                 eta=1.0,
                 lam=1.0,
                 beta=0.5,
                 lr=1e-3,
                 device=th.device("cpu"),
                 ):
        super().__init__()

        self.use_ci = use_ci
        self.use_md = use_md
        self.corr = corr
        self.reg = reg
        self.eta = eta
        self.lam = lam
        self.beta = beta
        self.r_dim = r_dim
        self.lr = lr
        self.device = device

        # Number of completed `step` calls; drives the running average `w_avg`.
        self.w_update_cnt = 0

        # Trainable weight of shape [r_dim, 1].
        if isinstance(initialize, (list, tuple)):
            # Catch a length mismatch here instead of as an obscure shape error (or a silent
            # broadcast when the length is 1) inside `step` / `forward`.
            if len(initialize) != r_dim:
                raise ValueError(
                    f"`initialize` has {len(initialize)} entries, expected r_dim={r_dim}"
                )
            init_values = th.tensor(initialize, dtype=th.float32, device=self.device).unsqueeze(dim=-1)
        elif initialize == 'uniform':
            init_values = th.full((r_dim, 1), 1 / r_dim, dtype=th.float32, device=self.device)
        elif initialize == 'dirichlet':
            # Random point on the simplex from a flat Dirichlet(1, ..., 1).
            flat_dirichlet = th.distributions.dirichlet.Dirichlet(th.ones(r_dim, dtype=th.float32))
            init_values = flat_dirichlet.sample().to(self.device).unsqueeze(dim=-1)
        else:
            raise NotImplementedError(f"Unsupported weight initialization: {initialize!r}")
        self.weight = th.nn.Parameter(init_values)
        print("Weight", self.weight)
        self.w_avg = self.weight.data.clone()

        # Lower-triangular averaging matrix: row i holds 1/(i+1) in columns 0..i, so
        # (matrix @ u)[i] is the mean of the first i+1 entries of u. Used by `proj_on_simplex`.
        lower = th.tril(th.ones(r_dim, r_dim, device=self.device))
        row_counts = th.arange(1, r_dim + 1, device=self.device, dtype=lower.dtype).unsqueeze(dim=-1)
        self.matrix = lower / row_counts
        self.intercept = self.matrix[:, 0].unsqueeze(dim=-1)  # [r_dim, 1], entries 1/(i+1)

        self.zero_vector = th.zeros(r_dim, device=self.device)

    def proj_on_simplex(self, x):   # [r_dim,1] -> [r_dim,1]
        """Euclidean projection of ``x`` (shape ``[r_dim, 1]``) onto the probability simplex.

        Sort-based algorithm: with ``u`` sorted in descending order, ``rho`` is the number of
        indices ``i`` with ``u_i + (1 - sum_{j<=i} u_j) / (i+1) > 0``. The shift is
        ``lambda = (1 - sum_{j<rho} u_j) / rho`` and the result is ``max(x + lambda, 0)``.
        """
        x = x.squeeze(dim=-1)
        sorted_x, _ = th.sort(x, descending=True)  # [r_dim, ]
        sorted_x = sorted_x.unsqueeze(dim=-1)  # [r_dim, 1]
        # criterion[i] = u_i - mean(u_0..u_i) + 1/(i+1) = u_i + (1 - sum_{j<=i} u_j)/(i+1)
        criterion = sorted_x - th.matmul(self.matrix, sorted_x) + self.intercept  # [r_dim, 1]
        threshold_idx = th.sum(criterion > 0).item()  # range 1 to r_dim (criterion[0] == 1)
        lmbda = criterion[threshold_idx - 1] - sorted_x[threshold_idx - 1]  # [1,]
        projected_x = th.max(x + lmbda, self.zero_vector).unsqueeze(dim=-1)
        return projected_x

    def step(self, value_vector, corr_vector):
        """Perform one closed-form update of ``weight`` and refresh ``w_avg``.

        ``'kl2'`` branches set the weight to a softmax over components (multiplicative
        update). ``'quad2'`` branches take an additive step of size ``lr`` and project back
        onto the simplex with :meth:`proj_on_simplex`. When ``use_ci`` and ``use_md`` are
        both False, the weight becomes ``softmax(-value / beta)`` regardless of ``reg``.

        :param value_vector: per-component values; assumed to be ``[r_dim, 1]`` (not checked).
        :param corr_vector: ``[r_dim, 1]`` tensor (checked to be 2-D with ``r_dim`` rows);
            it enters the update only when ``use_ci`` is True.
        :raises AssertionError: if ``corr_vector`` has the wrong shape.
        :raises NotImplementedError: for an unsupported ``reg`` when ``use_ci`` or ``use_md`` is set.
        """
        assert corr_vector.shape[0] == self.r_dim, "corr_vector must be a column vector"
        assert len(corr_vector.shape) == 2
        # Increment only once the inputs are validated, so a rejected call does not skew w_avg.
        self.w_update_cnt += 1
        with th.no_grad():
            value_vector = value_vector.to(self.device)
            corr_vector = corr_vector.to(self.device)
            if self.use_ci and self.use_md:
                c1 = self.lam * self.eta / (self.lam + self.eta)
                c2 = self.eta / (self.lam + self.eta)
                c3 = self.lam / (self.lam + self.eta)
                if self.reg == 'kl2':
                    self.weight.data = th.softmax(- c1 * value_vector + c2 * th.log(self.weight.data + 1e-8) + c3 * th.log(corr_vector + 1e-8), dim=0)
                elif self.reg == 'quad2':
                    self.weight.data -= self.lr * (- c1 * value_vector + (c2 - 1) * self.weight.data + c3 * corr_vector) / (c1 + 1e-8)  # before projection
                    self.weight.data = self.proj_on_simplex(self.weight.data)
                else:
                    raise NotImplementedError(f"Unsupported reg: {self.reg!r}")
            elif self.use_ci and not self.use_md:   # lam == infinity case
                if self.reg == 'kl2':
                    self.weight.data = th.softmax(- self.eta * value_vector + th.log(corr_vector + 1e-8), dim=0)
                elif self.reg == 'quad2':
                    self.weight.data -= self.lr * (- value_vector + corr_vector / (self.eta + 1e-8) - self.weight.data / (self.eta + 1e-8))
                    self.weight.data = self.proj_on_simplex(self.weight.data)
                else:
                    raise NotImplementedError(f"Unsupported reg: {self.reg!r}")
            elif not self.use_ci and self.use_md:   # eta == infinity case
                if self.reg == 'kl2':
                    self.weight.data = th.softmax(- self.lam * value_vector + th.log(self.weight.data + 1e-8), dim=0)
                elif self.reg == 'quad2':
                    self.weight.data -= - self.lr * value_vector
                    self.weight.data = self.proj_on_simplex(self.weight.data)
                else:
                    raise NotImplementedError(f"Unsupported reg: {self.reg!r}")
            else:   # with beta*H(w) only
                self.weight.data = th.softmax(- value_vector / (self.beta + 1e-8), dim=0)  # [r_dim,1]

            # Running mean over the initial weight and all updated weights.
            self.w_avg = (self.weight.data + self.w_update_cnt * self.w_avg) / (self.w_update_cnt + 1)

    def forward(self, input):
        """Scalarize a ``[1, r_dim]`` reward row vector with the current weight -> ``[1, 1]``."""
        assert len(input.shape) == 2  # [1, r_dim]
        output = th.matmul(input, self.weight)
        return output

class MaxminPPO(PPO):
    """ 
    Maxmin Proximal Policy Optimization algorithm (MaxminPPO) (clip version)

    TBE below

    Paper: https://arxiv.org/abs/1707.06347
    Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/)
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail and
    Stable Baselines (PPO2 from https://github.com/hill-a/stable-baselines)

    Introduction to PPO: https://spinningup.openai.com/en/latest/algorithms/ppo.html

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel)
        NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization)
        See https://github.com/pytorch/pytorch/issues/29372
    :param batch_size: Minibatch size
    :param n_epochs: Number of epoch when optimizing the surrogate loss
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
    :param clip_range: Clipping parameter, it can be a function of the current progress
        remaining (from 1 to 0).
    :param clip_range_vf: Clipping parameter for the value function,
        it can be a function of the current progress remaining (from 1 to 0).
        This is a parameter specific to the OpenAI implementation. If None is passed (default),
        no clipping will be done on the value function.
        IMPORTANT: this clipping depends on the reward scaling.
    :param normalize_advantage: Whether to normalize or not the advantage
    :param ent_coef: Entropy coefficient for the loss calculation
    :param vf_coef: Value function coefficient for the loss calculation
    :param max_grad_norm: The maximum value for the gradient clipping
    :param use_sde: Whether to use generalized State Dependent Exploration (gSDE)
        instead of action noise exploration (default: False)
    :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
        Default: -1 (only sample at the beginning of the rollout)
    :param target_kl: Limit the KL divergence between updates,
        because the clipping is not enough to prevent large update
        see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
        By default, there is no limit on the kl div.
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param ent_coef_weight: Entropy coefficient for weight(i.e. adversary player)
    :param n_init_states: after train, when computing \sum_s v(s)\mu(s), replace with \sum_{i=1}^n_init_states v(s)/n_init_states
    """
    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 3e-4,
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.1,  # PPO's default is 0.0
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "cpu",
        _init_setup_model: bool = True,
        r_dim: int = 1,
        r_dim_wise_normalize: bool = False,
        env_name: Optional[str] = None,
        weight_initialize: Union[str, List[float]] = 'uniform',
        ent_coef_weight: float = 0.1,
        # n_init_states: int = 10,  # seems not needed, since already sampling n_env init states
        use_ci = True,
        use_md = True,
        corr = 'pearson_corr',
        reg = 'kl2',
        ci_coef = 2.0,  # 2->beta=0.5, 10->beta=0.1
        md_coef = 1.0,
    ) -> None:
        """Build the PPO learner plus the adversary :class:`Weight` over reward components.

        PPO arguments are forwarded unchanged to ``PPO.__init__``. Additional arguments:

        :param weight_initialize: initialization of the adversary weight (see ``Weight``).
        :param ent_coef_weight: entropy temperature ``beta`` of the adversary weight.
        :param use_ci: enable the correlation term in the weight update.
        :param use_md: enable the mirror-descent (proximal) term in the weight update.
        :param corr: correlation measure name (stored as ``self.corr`` when ``use_ci``).
        :param reg: weight regularizer, ``'kl2'`` or ``'quad2'``.
        :param ci_coef: correlation coefficient, passed to ``Weight`` as ``eta``.
        :param md_coef: proximal coefficient, passed to ``Weight`` as ``lam``.
        :raises ValueError: if ``env_name`` is not one of the supported aliases.
        :raises NotImplementedError: if the environment has no hard-coded initial state
            (currently ``'reacher'``).
        """
        # Resolve and validate the environment before the (expensive) model build and any logging.
        env_aliases = {
            'DST': 'deep-sea-treasure-sparse-v0',
            'Four-room': 'four-room-truncated-v0',
            'reacher': 'mo-reacher-v4',
            'traffic': 'SUMO',
            'traffic-big': 'SUMO',
            'traffic-asym': 'SUMO',
        }
        if env_name not in env_aliases:
            raise ValueError("Invalid Env Name")
        resolved_env_name = env_aliases[env_name]

        # Hard-coded initial observation per environment.
        if resolved_env_name == 'four-room-truncated-v0':
            init_state_tensor = th.tensor([6] + [0 for _ in range(5)]).unsqueeze(0)
        elif resolved_env_name == 'deep-sea-treasure-sparse-v0':
            init_state_tensor = th.tensor(np.array([0, 0], dtype=np.int32)).unsqueeze(0)
        elif resolved_env_name == 'SUMO':
            # NOTE(review): this 37-dim state is the big-intersection layout and is used for every
            # 'traffic*' alias; the 2-way intersection used [1.] + [0.] * 20 — confirm per variant.
            init_state_tensor = th.tensor([1.] + [0. for _ in range(36)]).unsqueeze(0)
        else:
            raise NotImplementedError(f"No initial state defined for environment '{resolved_env_name}'")

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            clip_range=clip_range,
            clip_range_vf=clip_range_vf,
            normalize_advantage=normalize_advantage,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            target_kl=target_kl,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
            r_dim=r_dim,
            r_dim_wise_normalize=r_dim_wise_normalize,
            env_name=env_name,
        )

        # `Weight.step` multiplies by `lr` directly in its 'quad2' branches, so a learning-rate
        # schedule (callable of progress remaining, 1 -> 0) is evaluated at its starting value.
        weight_lr = learning_rate(1.0) if callable(learning_rate) else learning_rate
        self.weight = Weight(r_dim=r_dim, initialize=weight_initialize, use_ci=use_ci, use_md=use_md,
                             corr=corr, reg=reg, eta=ci_coef, lam=md_coef, beta=ent_coef_weight, lr=weight_lr)
        self.beta = ent_coef_weight
        self.use_ci = use_ci
        if self.use_ci:
            self.corr = corr

        # Log the initial weight (and its running average, equal at this point) to wandb.
        init_w_dict = {}
        for i in range(self.r_dim):
            init_w_dict[f"Weight {i}"] = self.weight.weight.data[i].item()
            init_w_dict[f"Avg_Weight {i}"] = self.weight.w_avg[i].item()
        wandb.log(init_w_dict)

        self.time = 0.0
        # Overrides whatever the parent stored with the resolved (canonical) name.
        self.env_name = resolved_env_name
        self.init_state_tensor = init_state_tensor

        print(f"use correlation: {self.use_ci}")
        print(f"use mirror descent: {use_md}")

    
    def train(self, batch_size: int = 100) -> None:  # one step of training for maxmin_PPO
        """
        Run one round of simultaneous policy (agent) / weight (adversary) updates.

        For every minibatch of every epoch (until ``target_kl`` early stopping, if enabled):
          1. the initial-state value vector v^{pi_theta_t} is computed with the policy *before* the step,
          2. the policy takes one clipped-PPO step on the advantages; for ``r_dim > 1`` the advantages
             are first scalarised with the current weight w_t  ((w_t, theta_t) -> theta_{t+1}),
          3. the weight takes one step from v^{pi_theta_t} and a correlation vector  (theta_t -> w_{t+1}).
        Afterwards training statistics are recorded in the SB3 logger and sent to wandb.

        :param batch_size: unused; kept for interface compatibility (minibatches use ``self.batch_size``).
        :raises NotImplementedError: if ``self.env_name`` has no wandb logging branch, or if
            ``self.use_ci`` is set and ``self.corr`` is not a known correlation measure.
        """
        start_train_time = time.time()

        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update optimizer learning rate
        self._update_learning_rate(self.policy.optimizer)
        # Compute current clip range
        clip_range = self.clip_range(self._current_progress_remaining)  # type: ignore[operator]
        # Optional: clip range for the value function
        if self.clip_range_vf is not None:
            clip_range_vf = self.clip_range_vf(self._current_progress_remaining)  # type: ignore[operator]

        entropy_losses = []
        pg_losses, value_losses = [], []
        clip_fractions = []

        continue_training = True

        # Initial-state value of the policy *before* each policy step; consumed by the weight step.
        value_vector_for_weight_update: Optional[th.Tensor] = None
        # Last correlation vector passed to the weight step. Stays None when early stopping fires
        # before the first weight update; the "Corr i" wandb entries are then skipped instead of crashing.
        corr_vector: Optional[th.Tensor] = None

        # Simultaneous updates: (w_t, theta_t) -> theta_{t+1}  and  theta_t -> w_{t+1}
        for epoch in range(self.n_epochs):
            approx_kl_divs = []
            # Do a complete pass on the rollout buffer
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # v^{pi_{theta_t}} has to be evaluated before the policy step below.
                with th.no_grad():
                    value_vector_for_weight_update = self._initial_state_value_vector()

                ## PPO update
                actions = rollout_data.actions
                if isinstance(self.action_space, spaces.Discrete):
                    # Convert discrete action from float to long
                    actions = rollout_data.actions.long().flatten()

                # Re-sample the noise matrix because the log_std has changed
                if self.use_sde:
                    self.policy.reset_noise(self.batch_size)

                values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
                if self.r_dim == 1:
                    values = values.flatten()

                advantages = rollout_data.advantages  # [batch_size, r_dim]
                # Normalization does not make sense if mini batchsize == 1, see GH issue #325.
                # Default (Fair-RL convention): normalize jointly over the batch *and* r_dim;
                # r_dim_wise_normalize normalizes every reward component separately.
                if self.normalize_advantage and len(advantages) > 1:
                    if self.r_dim_wise_normalize:
                        advantages = (advantages - advantages.mean(axis=0)) / (advantages.std(axis=0) + 1e-8)
                    else:
                        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                if self.r_dim > 1:
                    # Scalarize the multi-objective advantage with the current weight w_t (.data: no grad into w).
                    weight_temp = self.weight.weight.data
                    assert weight_temp.dim() == 2, "error in w dot adv"
                    weight_temp = weight_temp.to(advantages.device)
                    advantages = th.squeeze(th.matmul(advantages, weight_temp), dim=-1)  # [batch_size]

                # ratio between old and new policy, should be one at the first iteration
                ratio = th.exp(log_prob - rollout_data.old_log_prob)

                # clipped surrogate loss
                policy_loss_1 = advantages * ratio
                policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
                policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()

                # Logging
                pg_losses.append(policy_loss.item())
                clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
                clip_fractions.append(clip_fraction)

                if self.clip_range_vf is None:
                    # No clipping
                    values_pred = values
                else:
                    # Clip the difference between old and new value
                    # NOTE: this depends on the reward scaling
                    values_pred = rollout_data.old_values + th.clamp(
                        values - rollout_data.old_values, -clip_range_vf, clip_range_vf
                    )
                # Value loss using the TD(gae_lambda) target (returns = GAE advantage + old value estimate)
                value_loss = F.mse_loss(rollout_data.returns, values_pred)
                value_losses.append(value_loss.item())

                # Entropy loss favor exploration
                if entropy is None:
                    # Approximate entropy when no analytical form
                    entropy_loss = -th.mean(-log_prob)
                else:
                    entropy_loss = -th.mean(entropy)
                entropy_losses.append(entropy_loss.item())

                loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

                # Calculate approximate form of reverse KL Divergence for early stopping
                # see issue #417: https://github.com/DLR-RM/stable-baselines3/issues/417
                # and discussion in PR #419: https://github.com/DLR-RM/stable-baselines3/pull/419
                # and Schulman blog: http://joschu.net/blog/kl-approx.html
                with th.no_grad():
                    log_ratio = log_prob - rollout_data.old_log_prob
                    approx_kl_div = th.mean((th.exp(log_ratio) - 1) - log_ratio).cpu().numpy()
                    approx_kl_divs.append(approx_kl_div)

                if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
                    continue_training = False
                    if self.verbose >= 1:
                        print(f"Early stopping at step {epoch} due to reaching max kl: {approx_kl_div:.2f}")
                    break

                ## Policy (agent) step: (w_t, theta_t) -> theta_{t+1}
                self.policy.optimizer.zero_grad()
                loss.backward()
                # Clip grad norm
                th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.policy.optimizer.step()

                ## Weight (adversary) step: theta_t -> w_{t+1}
                with th.no_grad():
                    min_index = th.argmin(value_vector_for_weight_update).item()
                    corr_vector = self._correlation_vector(min_index)
                    self.weight.step(value_vector_for_weight_update, corr_vector)  # (r_dim, 1)

            self._n_updates += 1
            if not continue_training:
                break

        explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten())

        self.time += time.time() - start_train_time

        # Value of the initial state after this round of training (logging only).
        with th.no_grad():
            self.policy.set_training_mode(False)  # eval mode
            current_state_value_vector = self._initial_state_value_vector().to(self.device)  # (r_dim, 1)

            weight_data = self.weight.weight.data
            assert weight_data.dim() == 2, "error in w shape"
            # The weight is not guaranteed to share the value network's device (the minibatch loop
            # moves it explicitly as well), so align devices before the product.
            weight_data = weight_data.to(current_state_value_vector.device)
            # Scalarized initial-state value w^T v, as a plain float for logging.
            weighted_value = th.matmul(current_state_value_vector.view(1, -1), weight_data)[0][0].item()

        # Logs
        self.logger.record("train/entropy_loss", np.mean(entropy_losses))
        self.logger.record("train/policy_gradient_loss", np.mean(pg_losses))
        self.logger.record("train/value_loss", np.mean(value_losses))
        self.logger.record("train/approx_kl", np.mean(approx_kl_divs))
        self.logger.record("train/clip_fraction", np.mean(clip_fractions))
        self.logger.record("train/loss", loss.item())
        self.logger.record("train/explained_variance", explained_var)
        if hasattr(self.policy, "log_std"):
            self.logger.record("train/std", th.exp(self.policy.log_std).mean().item())

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/clip_range", clip_range)
        if self.clip_range_vf is not None:
            self.logger.record("train/clip_range_vf", clip_range_vf)

        # Wandb log
        weight_param = self.weight.weight

        if self.env_name == 'SUMO':
            wandb_dict = {}
            weight_log = self.weight.weight.data
            weight_avg_log = self.weight.w_avg
            for i in range(self.r_dim):
                wandb_dict[f"Weight {i}"] = weight_log[i].item()
                wandb_dict[f"Avg_Weight {i}"] = weight_avg_log[i].item()
                if corr_vector is not None:
                    wandb_dict[f'Corr {i}'] = corr_vector[i].item()
                wandb_dict[f"Value {i}"] = current_state_value_vector[i].item()
            wandb_dict["loss"] = loss.item()
            wandb_dict["Total Mean V(mu)"] = th.mean(current_state_value_vector).item()
            wandb_dict["Total Mean Q"] = weighted_value
            wandb_dict["Total train time"] = self.time
            wandb.log(wandb_dict)
        elif self.env_name == 'four-room-truncated-v0':
            wandb.log({
                'loss': loss.item(),
                'Total Mean V(mu)': th.mean(current_state_value_vector).item(),
                'Value 0': current_state_value_vector[0].item(),
                'Value 1': current_state_value_vector[1].item(),
                'Weight 0': weight_param[0].item(),
                'Weight 1': weight_param[1].item(),
                'Total Mean Q': weighted_value,
                'Total train time': self.time,
            })
        elif self.env_name == 'deep-sea-treasure-sparse-v0':
            wandb.log({
                'loss': loss.item(),
                'Total Mean V(mu)': th.mean(current_state_value_vector).item(),
                'Value 0': current_state_value_vector[0].item(),
                'Value 1': current_state_value_vector[1].item(),
                'Weight 0': weight_param[0].item(),
                'Weight 1': weight_param[1].item(),
                'Total Mean Q': weighted_value,
                'Total train time': self.time,
            })
        else:
            raise NotImplementedError

    def _initial_state_value_vector(self) -> th.Tensor:
        """
        Return the value network's estimate for the environments' initial states as a column vector.

        ``self.env.initial_states()`` is expected to return an array with a leading axis of size 1
        (it is squeezed away), i.e. one initial observation per parallel env. The predicted values
        are averaged over the envs. Callers wrap this in ``th.no_grad()``.

        :return: tensor of shape ``(r_dim, 1)``, on the device of ``self.policy.predict_values``'s output.
        """
        initial_state = self.env.initial_states()  # parallel initial states (SubprocVecEnv): (1, n_env, ob_dim)
        # NOTE: initial_state.shape[0] == 1 holds for traffic intersection; for MO-Mujoco, revise subproc_vec_env.py
        initial_state_value = self.policy.predict_values(
            obs_as_tensor(np.squeeze(initial_state, axis=0), self.device)
        )  # (n_env, r_dim)
        averaged_state_value = th.mean(initial_state_value, axis=0)  # (r_dim,)
        if self.r_dim > 1:
            assert averaged_state_value.dim() == 1, "MO dim error"
            return th.unsqueeze(averaged_state_value, dim=1)  # (r_dim, 1)
        assert averaged_state_value.dim() == 0, "SO dim error"
        # reshape (instead of re-wrapping with th.tensor) keeps the value on its original device/dtype.
        return averaged_state_value.reshape(1, 1)  # (1, 1)

    def _correlation_vector(self, min_index: int) -> th.Tensor:
        """
        Build the softmax-normalised correlation vector passed to ``self.weight.step``.

        When ``self.use_ci`` is False the raw vector is all zeros, so the result is uniform ``1 / r_dim``.
        Otherwise each reward component is compared with component ``min_index`` using ``self.corr``:
        ``'inner_product'``, ``'cos_sim'``, ``'pearson_corr'`` (clipped to [-1, 1]) or
        ``'pearson_dist'`` (sqrt((1 - corr) / 2)).

        NOTE(review): the rewards used are the last ``self.batch_size`` entries along the first axis of
        ``self.rollout_buffer.rewards`` (presumably time steps, giving batch_size * n_envs rows after the
        reshape -- confirm); they do not depend on the minibatch currently being trained on.

        :param min_index: index of the reward component with the smallest current value.
        :return: float32 CPU tensor of shape ``(r_dim, 1)`` whose entries sum to 1.
        :raises NotImplementedError: if ``self.use_ci`` is set and ``self.corr`` is unknown.
        """
        corr_vector = np.zeros((self.r_dim, 1))
        if self.use_ci:
            reward_buffer = self.rollout_buffer.rewards[-self.batch_size:].reshape(-1, self.r_dim)
            reference = reward_buffer[:, min_index]
            if self.corr == 'inner_product':
                corr_vector = np.matmul(reference.reshape(1, -1), reward_buffer).reshape(-1, 1)  # [r_dim, 1]
                corr_vector = corr_vector / len(reward_buffer)
            elif self.corr == 'cos_sim':
                norms = np.linalg.norm(reward_buffer, axis=0)
                normalize = (norms * norms[min_index]).reshape(-1, 1)
                corr_vector = np.matmul(reference.reshape(1, -1), reward_buffer).reshape(-1, 1)  # [r_dim, 1]
                corr_vector = corr_vector / (normalize + 1e-8)
            elif self.corr in ('pearson_corr', 'pearson_dist'):
                # Manual Pearson correlation: np.corrcoef returns nan when a component has zero std.
                std_x = np.std(reference, ddof=1)
                for i in range(self.r_dim):
                    cov = np.cov(reference, reward_buffer[:, i])[0, 1]
                    std_y = np.std(reward_buffer[:, i], ddof=1)
                    cor = np.clip(cov / (std_x * std_y + 1e-8), -1, 1)
                    corr_vector[i] = cor if self.corr == 'pearson_corr' else np.sqrt((1 - cor) / 2)
            else:
                raise NotImplementedError
        return th.softmax(th.tensor(corr_vector, dtype=th.float32), dim=0)







class AblationPPO(PPO):
    """
    Ablation of the prior-work framework that uses PPO instead of soft Q-learning (SQL) to compute Q^*_{w_current} / Q^*_{w_perturbed}.
    Default hyperparameters are taken from the DQN paper,
    except for the optimizer and learning rate that were taken from Stable Baselines defaults.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """
    def __init__(
        self,
        # --- arguments of the perturbation framework ---
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 50000,
        train_freq: Union[int, Tuple[int, str]] = 4,
        target_update_interval: int = 10000,
        tau: float = 0.001,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        exploration_fraction: float = 0.1,
        verbose: int = 0,
        seed: Optional[int] = None,
        r_dim: int = 1,
        r_dim_policy: int = 1,
        ent_alpha: float = 0.1,
        weight_decay: float = 0,  # default is 0
        # w update threshold
        soft_q_init_fraction: float = 0.1,
        # w update params
        perturb_w_learning_rate: Union[float, Schedule] = 1e-4,  # may require a scheduler
        # w grad calculation
        period_cal_w_grad: int = 1,
        perturb_q_copy_num: int = 20,  # N_p > r_dim + 1
        perturb_std_dev: float = 0.01,
        # perturbed Q update
        perturb_q_learning_rate: Union[float, Schedule] = 1e-4,  # fed to Adam; constant lr is assumed OK
        perturb_grad_step: int = 1,
        perturb_q_batch_size: int = 32,
        # main Q update after the initialization phase - gradient steps
        q_grad_st_after_init: int = 1,
        explicit_w_input: bool = False,
        weight_initialize: Union[str, List[float]]='uniform',
        w_schedule_option: str = 'sqrt_inverse',
        stats_window_size: int = 100,
        # alpha scheduling for SQL variants
        ent_alpha_act_init: float = 0.5,
        annealing_step: int = 10000,
        env_name: Optional[str] = None,
        device: Union[th.device, str] = "cpu",

        # --- arguments forwarded to PPO ---
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: Union[float, Schedule] = 0.2,
        clip_range_vf: Union[None, float, Schedule] = None,
        normalize_advantage: bool = True,
        ent_coef: float = 0.1,  # SB3 default: 0.0
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        target_kl: Optional[float] = None,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        _init_setup_model: bool = True,
        r_dim_wise_normalize: bool = False,
        ent_coef_weight: float = 0.1,

    ) -> None:
        """
        Build the PPO base, the weight ``w`` and the perturbed Q-network copies.

        Forwards the PPO arguments to ``PPO.__init__``, then creates ``self.weight``, maps the
        user-facing ``env_name`` alias to an internal id, builds the initial-state tensor and
        allocates ``perturb_q_copy_num`` Q-network copies, a shared Adam optimizer and a target copy.

        Not referenced in this constructor (kept for interface compatibility): ``buffer_size``,
        ``learning_starts``, ``train_freq``, ``target_update_interval``, ``tau``, the ``exploration_*``
        arguments, ``r_dim_policy``, ``weight_decay``, ``explicit_w_input``, ``ent_alpha_act_init``,
        ``annealing_step`` and ``ent_coef_weight``.

        :raises AssertionError: if ``perturb_q_copy_num <= r_dim + 1``.
        :raises Exception: if ``env_name`` is not one of the known aliases.
        :raises NotImplementedError: if the alias has no initial-state definition (e.g. ``'reacher'``).
        """
        ppo_kwargs = dict(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            n_steps=n_steps,
            batch_size=batch_size,
            n_epochs=n_epochs,
            gamma=gamma,
            gae_lambda=gae_lambda,
            clip_range=clip_range,
            clip_range_vf=clip_range_vf,
            normalize_advantage=normalize_advantage,
            ent_coef=ent_coef,
            vf_coef=vf_coef,
            max_grad_norm=max_grad_norm,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            target_kl=target_kl,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
            r_dim=r_dim,
            r_dim_wise_normalize=r_dim_wise_normalize,
            env_name=env_name,
        )
        super().__init__(**ppo_kwargs)

        # The linear regression over perturbed copies needs more samples than unknowns;
        # np.linalg.det(X.T @ X) moves away from 0 as perturb_q_copy_num grows.
        assert perturb_q_copy_num > r_dim + 1

        self.weight = Weight(r_dim=r_dim, initialize=weight_initialize, device=self.device)
        self.ent_alpha = ent_alpha
        self.time = 0.0

        # User-facing alias -> internal environment id.
        env_aliases = {
            'DST': 'deep-sea-treasure-sparse-v0',
            'Four-room': 'four-room-truncated-v0',
            'reacher': 'mo-reacher-v4',
            'traffic': 'SUMO',
            'traffic-big': 'SUMO',
            'traffic-asym': 'SUMO',
        }
        if env_name not in env_aliases:
            raise Exception("Invalid Env Name")
        env_name = env_aliases[env_name]
        self.env_name = env_name

        # Single initial observation, batched to shape (1, ob_dim) on self.device.
        if env_name == 'four-room-truncated-v0':
            start_obs = th.tensor([6] + [0] * 5)
        elif env_name == 'deep-sea-treasure-sparse-v0':
            start_obs = th.tensor(np.array([0, 0], dtype=np.int32))
        elif env_name == 'SUMO':
            start_obs = th.tensor([1.] + [0.] * 36)  # big intersection (2-way intersection would use 20 zeros)
        else:
            raise NotImplementedError
        self.init_state_tensor = start_obs.unsqueeze(0).to(self.device)

        # --- perturbation hyper-parameters ---
        self.soft_q_init_fraction = soft_q_init_fraction
        self.perturb_std_dev = perturb_std_dev  # std dev of the Gaussian noise added to w
        self.perturb_q_copy_num = perturb_q_copy_num
        self.sqrt_q_copy_num = round(math.sqrt(perturb_q_copy_num), 2)
        self.period_cal_w_grad = period_cal_w_grad  # period (in timesteps) of the w-gradient computation

        self.perturb_q_batch_size = perturb_q_batch_size
        self.perturb_w_learning_rate = perturb_w_learning_rate
        self.perturb_q_learning_rate = perturb_q_learning_rate
        self.perturb_grad_step = perturb_grad_step
        self.q_grad_st_after_init = q_grad_st_after_init

        self.perturb_q_lr_schedule = get_schedule_fn(perturb_q_learning_rate)

        # NOTE(review): make_q_net is a DQN-policy method in SB3; confirm the policy class used here provides it.
        self.perturb_q_net_list = [self.policy.make_q_net() for _ in range(perturb_q_copy_num)]
        # One Adam optimizer drives the parameters of every perturbed copy.
        self.combined_parameters = [
            param for q_copy in self.perturb_q_net_list for param in q_copy.parameters()
        ]
        self.perturb_q_optimizer = th.optim.Adam(self.combined_parameters, lr=self.perturb_q_lr_schedule(1))
        for q_copy in self.perturb_q_net_list:
            q_copy.set_training_mode(False)  # switched back on while a w-gradient is being computed

        # Target network shared by the perturbed copies; always kept in eval mode.
        self.perturb_q_net_target = self.policy.make_q_net()
        self.perturb_q_net_target.set_training_mode(False)

        # Most recently computed gradient of w.
        self.w_grad = th.zeros(r_dim, dtype=th.float32)

        self.w_schedule_option = w_schedule_option

    def update_perturb_q_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.optim.Optimizer]) -> None:
        """
        Apply the scheduled perturbed-Q learning rate to one optimizer or a list of optimizers.

        Each optimizer receives ``self.perturb_q_lr_schedule(self._current_progress_remaining)``.

        :param optimizers: a single optimizer, or a list of them.
        """
        optimizer_list = optimizers if isinstance(optimizers, list) else [optimizers]
        for opt in optimizer_list:
            scheduled_lr = self.perturb_q_lr_schedule(self._current_progress_remaining)
            update_learning_rate(opt, scheduled_lr)

    ### Overriding train function. Here, we should use both (i) soft-q learning and (ii) weight update.
    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """Run one training call: an optional reward-weight update, then main soft-Q updates.

        Before ``int(soft_q_init_fraction * _total_timesteps)`` timesteps, only the main
        Q-network is trained, for ``gradient_steps`` steps. From that point on, each call also:

        1. re-estimates ``self.w_grad`` whenever ``num_timesteps`` is a multiple of
           ``period_cal_w_grad`` (heavy part: perturbed Q-copies + linear regression);
        2. steps the reward weight with the most recent ``self.w_grad`` (light part,
           done on every call);

        and the main Q-network then receives ``q_grad_st_after_init`` gradient steps
        instead of ``gradient_steps``.

        :param gradient_steps: Number of main-Q gradient steps during the init phase
            (replaced by ``self.q_grad_st_after_init`` afterwards).
        :param batch_size: Minibatch size for the main-Q updates.
        :raises NotImplementedError: If ``w_schedule_option`` is unknown, or ``env_name``
            has no wandb logging layout.
        """
        start_train_time = time.time()

        init_timesteps = int(self.soft_q_init_fraction * self._total_timesteps)
        if self.num_timesteps >= init_timesteps:
            if self.num_timesteps % self.period_cal_w_grad == 0:
                self.w_grad = self._estimate_w_grad_by_perturbation()
            # 1-based index of the current timestep within the post-init phase.
            self._step_reward_weight(self.num_timesteps - init_timesteps + 1)
            # After the init phase, the main Q uses a fixed number of gradient steps.
            gradient_steps = self.q_grad_st_after_init

        losses, mean_q = self._update_main_q_net(gradient_steps, batch_size)

        self.time += time.time() - start_train_time

        # Increase update counter
        self._n_updates += gradient_steps

        # np.mean([]) would warn and return nan; report nan explicitly when no step ran.
        mean_loss = float(np.mean(losses)) if losses else float('nan')
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", mean_loss)

        self._log_train_to_wandb(mean_loss, mean_q)

    def _estimate_w_grad_by_perturbation(self) -> th.Tensor:
        """Estimate grad_w by fitting perturbed Q-copies and regressing their initial soft values on w.

        Each copy in ``self.perturb_q_net_list`` is warm-started from ``self.q_net`` and is
        trained for ``perturb_grad_step`` soft-TD steps against its own randomly perturbed
        reward weight. The copies share one target net and one replay batch per step.
        The soft value of ``self.init_state_tensor`` under each copy is then linearly
        regressed (with intercept) on the perturbed weights. The slope is the estimate.

        :return: The regression slope, with one entry per reward dimension (the intercept is dropped).
        """
        # One Gaussian perturbation per Q-copy, stored as columns.
        noise = th.randn(self.weight.weight.shape[0], self.perturb_q_copy_num, device=self.device)
        perturbed_weight = self.weight.weight + noise * self.perturb_std_dev  # [r_dim, q_copy_num]

        # Make copies of the main Q-net and of the target Q-net.
        for perturb_q_net in self.perturb_q_net_list:
            perturb_q_net.set_training_mode(True)
            perturb_q_net.load_state_dict(self.q_net.state_dict())
        self.perturb_q_net_target.load_state_dict(self.q_net_target.state_dict())

        self.update_perturb_q_learning_rate(self.perturb_q_optimizer)

        for _ in range(self.perturb_grad_step):
            # A single batch is shared by all copies, so the copies differ only through their weight perturbation.
            replay_data = self.replay_buffer.sample(self.perturb_q_batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            with th.no_grad():
                next_q_values = self.perturb_q_net_target.forward(replay_data.next_observations)  # [batch, ac_dim]
                # Soft state value: alpha * logsumexp(Q / alpha).
                next_q_values = self.ent_alpha * th.logsumexp(next_q_values / self.ent_alpha, dim=-1, keepdim=True)  # [batch, 1]
                # 1-step soft TD target; each column uses one perturbed weight.
                target_q_values = (
                    th.matmul(replay_data.rewards, perturbed_weight)
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )  # [batch, q_copy_num]

            all_q_values = th.stack(
                [perturb_q_net.forward(replay_data.observations) for perturb_q_net in self.perturb_q_net_list],
                dim=2,
            )  # [batch, ac_dim, q_copy_num]
            # Select the taken action's Q-value in every copy.
            action_idx = replay_data.actions.long().unsqueeze(-1).repeat(1, 1, self.perturb_q_copy_num)  # [batch, 1, q_copy_num]
            current_q_values = th.gather(all_q_values, dim=1, index=action_idx).squeeze(dim=1)  # [batch, q_copy_num]

            # Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)

            self.perturb_q_optimizer.zero_grad()
            loss.backward()
            # The clipping norm is scaled by sqrt(copies) because all copies' parameters are clipped jointly.
            th.nn.utils.clip_grad_norm_(self.combined_parameters, self.sqrt_q_copy_num * self.max_grad_norm)
            self.perturb_q_optimizer.step()

        with th.no_grad():
            # Soft value of the initial state under each fitted copy.
            init_values = []
            for perturb_q_net in self.perturb_q_net_list:
                perturb_q_net.set_training_mode(False)
                init_q = perturb_q_net.forward(self.init_state_tensor)  # [1, ac_dim]
                init_v = self.ent_alpha * th.logsumexp(init_q / self.ent_alpha, dim=-1)  # [1,]
                init_values.append(init_v.item())

            # Linear regression y ~ b0 + X @ b via the normal equations.
            X = perturbed_weight.transpose(0, 1).to(self.device)  # [q_copy_num, r_dim]
            y = th.tensor(init_values, dtype=th.float32, device=self.device)  # [q_copy_num,]
            ones_column = th.ones(X.size(0), dtype=X.dtype, device=self.device)
            X = th.cat((ones_column.unsqueeze(1), X), dim=1)  # [q_copy_num, r_dim + 1]
            gram = X.T @ X
            try:
                coefficients = th.linalg.inv(gram) @ X.T @ y  # [r_dim + 1,]
            except RuntimeError:
                # torch.linalg.LinAlgError (a RuntimeError) on a singular Gram matrix,
                # e.g. fewer copies than r_dim + 1; fall back to the pseudo-inverse.
                coefficients = th.linalg.pinv(gram) @ X.T @ y

        return coefficients[1:]  # drop intercept -> [r_dim,]

    def _step_reward_weight(self, current_timestep_w: int) -> None:
        """Take one step on the reward weight with ``self.w_grad``, using the scheduled step size.

        The update is delegated to ``self.weight.step``. The surrounding design notes
        describe it as projected gradient descent.

        :param current_timestep_w: 1-based timestep index within the post-init phase.
        :raises NotImplementedError: If ``self.w_schedule_option`` is not recognised.
        """
        base_lr = self.perturb_w_learning_rate
        if self.w_schedule_option == 'sqrt_inverse':
            lr = base_lr / math.sqrt(current_timestep_w)
        elif self.w_schedule_option == 'inverse':
            lr = base_lr / current_timestep_w
        elif self.w_schedule_option == 'linear':
            # Linear decay from 1.0x to 0.1x of the base rate over the post-init phase.
            # max(..., 1) avoids division by zero when soft_q_init_fraction == 1. The
            # progress is clamped so the rate never falls below the 0.1 end ratio.
            max_step_w = max(int((1 - self.soft_q_init_fraction) * self._total_timesteps), 1)
            progress = min((current_timestep_w - 1) / max_step_w, 1.0)
            lr = base_lr * (1 + (0.1 - 1) * progress)
        else:
            raise NotImplementedError(f"Unknown w_schedule_option: {self.w_schedule_option!r}")
        self.weight.step(lr=lr, grad=self.w_grad)

    def _update_main_q_net(self, gradient_steps: int, batch_size: int) -> Tuple[List[float], float]:
        """Run ``gradient_steps`` soft-Q regression steps on the main Q-network.

        :param gradient_steps: Number of optimizer steps to take (may be 0).
        :param batch_size: Replay minibatch size.
        :return: The per-step Huber losses, and the mean Q-value of the last sampled
            batch (``nan`` if no step was taken).
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)

        losses: List[float] = []
        last_q_values = None
        for _ in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            with th.no_grad():
                next_q_values = self.q_net_target.forward(replay_data.next_observations)  # [batch, ac_dim]
                # Soft state value: alpha * logsumexp(Q / alpha).
                next_q_values = self.ent_alpha * th.logsumexp(next_q_values / self.ent_alpha, dim=-1, keepdim=True)  # [batch, 1]
                # 1-step soft TD target; the scalar reward is self.weight applied to the reward vector.
                target_q_values = (
                    self.weight.forward(replay_data.rewards)
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )

            current_q_values = self.q_net.forward(replay_data.observations)  # [batch, ac_dim]
            current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())  # [batch, 1]
            assert len(current_q_values.shape) == 2

            # Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            self.policy.optimizer.zero_grad()
            loss.backward()
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

            last_q_values = current_q_values

        mean_q = th.mean(last_q_values).item() if last_q_values is not None else float('nan')
        return losses, mean_q

    def _log_train_to_wandb(self, mean_loss: float, mean_q: float) -> None:
        """Log loss, mean Q, leading reward-weight entries and cumulative train time to wandb.

        :param mean_loss: Mean main-Q loss of this call.
        :param mean_q: Mean Q-value of the last main-Q batch.
        :raises NotImplementedError: If ``self.env_name`` has no logging layout.
        """
        # Number of leading reward-weight entries logged for each supported environment.
        num_logged_weights = {
            'SUMO': 4,
            'four-room-truncated-v0': 2,
            'deep-sea-treasure-sparse-v0': 2,
        }.get(self.env_name)
        if num_logged_weights is None:
            raise NotImplementedError(f"wandb logging is not implemented for env_name={self.env_name!r}")

        weight = self.weight.weight
        log_dict = {
            'loss': mean_loss,
            'Total Mean Q': mean_q,
        }
        for i in range(num_logged_weights):
            log_dict[f'Weight {i}'] = weight[i].item()
        log_dict['Total train time'] = self.time
        wandb.log(log_dict)





