import warnings
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MultiInputPolicy, QNetwork, SQLPolicy, SoftQNetwork
import math
# from multiprocessing import Process, Queue, cpu_count

SelfDQN = TypeVar("SelfDQN", bound="DQN")

import pdb
import wandb

class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network (DQN)

    Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
    Default hyperparameters are taken from the Nature paper,
    except for the optimizer and learning rate that were taken from Stable Baselines defaults.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    :param r_dim: Number of reward components. It is forwarded to the parent class; ``train`` reads
        ``self.r_dim`` and uses the scalar DQN target when it is 1, and a vector-valued target
        (Q-network output of shape ``[batch, n_actions, r_dim]``) when it is greater than 1.
    :param r_dim_policy: Forwarded unchanged to the parent class; not used directly by this class.
    :param ent_alpha: Forwarded unchanged to the parent class; not used directly by this class.
    """

    policy_aliases: ClassVar[Dict[str, Type[BasePolicy]]] = {
        "DQNPolicy": DQNPolicy,
        "SQLPolicy": SQLPolicy,  # Newly added for our algorithm
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }
    # Linear schedule will be defined in `_setup_model()`
    exploration_schedule: Schedule
    q_net: Union[QNetwork, SoftQNetwork]
    q_net_target: Union[QNetwork, SoftQNetwork]
    policy: Union[DQNPolicy, SQLPolicy]

    def __init__(
        self,
        policy: Union[str, Type[Union[DQNPolicy, SQLPolicy]]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,  # 1e6
        learning_starts: int = 50000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        stats_window_size: int = 100,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
        r_dim: int = 1,
        r_dim_policy: int = 1,
        ent_alpha: Optional[float] = None,
    ) -> None:
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,  # No action noise
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            policy_kwargs=policy_kwargs,
            stats_window_size=stats_window_size,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(spaces.Discrete,),
            support_multi_env=True,
            r_dim=r_dim,
            r_dim_policy=r_dim_policy,
            ent_alpha=ent_alpha,
        )
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.target_update_interval = target_update_interval
        # For updating the target network with multiple envs:
        self._n_calls = 0
        self.max_grad_norm = max_grad_norm
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the networks, cache batch-norm running stats and create the epsilon schedule."""
        super()._setup_model()
        self._create_aliases()
        # Copy running stats, see GH issue #996
        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
        self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )

        if self.n_envs > 1:
            if self.n_envs > self.target_update_interval:
                warnings.warn(
                    "The number of environments used is greater than the target network "
                    f"update interval ({self.n_envs} > {self.target_update_interval}), "
                    "therefore the target network will be updated after each call to env.step() "
                    f"which corresponds to {self.n_envs} steps."
                )

    def _create_aliases(self) -> None:
        """Expose the policy's online and target Q-networks directly on the algorithm."""
        self.q_net = self.policy.q_net
        self.q_net_target = self.policy.q_net_target

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        self._n_calls += 1
        # Account for multiple environments
        # each call to step() corresponds to n_envs transitions
        if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0:
            polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
            # Copy running stats, see GH issue #996
            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        self.logger.record("rollout/exploration_rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` Q-learning updates on minibatches sampled from the replay buffer.

        With ``self.r_dim == 1`` the usual max-over-actions TD target is used. With
        ``self.r_dim > 1`` the next action is the one maximizing the minimum reward
        component of ``reward + gamma * Q_target`` and the target is vector-valued.
        Loss statistics are sent to the SB3 logger and, when a wandb run is active,
        to wandb as well.

        :param gradient_steps: number of gradient updates to perform
        :param batch_size: number of transitions sampled per update
        :raises NotImplementedError: if ``self.r_dim`` is smaller than 1
        """
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            with th.no_grad():
                # Compute the next Q-values using the target network
                # [batch, ac_dim] if r_dim == 1, [batch, ac_dim, r_dim] if r_dim > 1
                next_q_values = self.q_net_target.forward(replay_data.next_observations)

                if self.r_dim > 1:  # Does not use double q-learning.
                    multi_dim_reward = replay_data.rewards  # [batch, r_dim]
                    # [batch, 1, r_dim] + [batch, ac_dim, r_dim] = [batch, ac_dim, r_dim]
                    act_select_target = multi_dim_reward.unsqueeze(dim=1) + self.gamma * next_q_values
                    # Scalarize with the worst reward component: [batch, ac_dim]
                    act_select_target_scal, _ = th.min(act_select_target, dim=-1)
                    # [batch] -> [batch, 1], values index the action space
                    selected_act = th.argmax(act_select_target_scal, dim=1).unsqueeze(dim=-1)
                    # [batch, r_dim] -> [batch, 1, r_dim]
                    tiled_selected_act = th.tile(selected_act, (1, self.r_dim)).unsqueeze(dim=1)
                    # [batch, 1, r_dim] -> [batch, r_dim]
                    selected_target = th.gather(next_q_values, dim=1, index=tiled_selected_act).squeeze(dim=1)

                    # replay_data.rewards: [batch, r_dim], replay_data.dones: [batch, 1]
                    target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * selected_target

                elif self.r_dim == 1:
                    # Follow greedy policy: use the one with the highest value
                    next_q_values, _ = next_q_values.max(dim=1)  # [batch]
                    # Avoid potential broadcast issue
                    next_q_values = next_q_values.reshape(-1, 1)  # [batch, 1]
                    # 1-step TD target
                    # replay_data.rewards: [batch, r_dim], replay_data.dones: [batch, 1]
                    target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
                else:
                    raise NotImplementedError

            # Get current Q-values estimates
            # [batch, ac_dim] if r_dim == 1, [batch, ac_dim, r_dim] if r_dim > 1
            current_q_values = self.q_net.forward(replay_data.observations)

            # Retrieve the q-values for the actions from the replay buffer
            idx = replay_data.actions.long()  # [batch, 1]
            if self.r_dim > 1:
                idx = th.tile(idx, (1, self.r_dim)).unsqueeze(dim=1)  # [batch, r_dim] -> [batch, 1, r_dim]

            current_q_values = th.gather(current_q_values, dim=1, index=idx)  # [batch, 1] or [batch, 1, r_dim]
            if self.r_dim > 1:
                current_q_values = th.squeeze(current_q_values, dim=1)
            assert len(current_q_values.shape) == 2  # [batch, r_dim]

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")

        if not losses:
            # No gradient step ran: there is neither a loss nor a Q-value batch to report.
            return

        mean_loss = np.mean(losses)
        self.logger.record("train/loss", mean_loss)

        # Wandb log (skipped when no run is active, since wandb.log requires one)
        if wandb.run is None:
            return

        # Statistics are computed on the Q-values of the last gradient step only.
        last_q_values = current_q_values.detach()
        scalar_q, _ = th.min(last_q_values, dim=-1)  # [batch]
        scalar_expected_q = th.min(th.mean(last_q_values, dim=0))  # [r_dim] -> scalar
        wandb.log(
            {
                'loss': mean_loss,
                'Expected Min Q': th.mean(scalar_q).item(),
                'Min Expected Q': scalar_expected_q.item(),
                'Total Mean Q': th.mean(last_q_values).item(),
            }
        )

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        state: Optional[Tuple[np.ndarray, ...]] = None,
        episode_start: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param episode_start: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        if not deterministic and np.random.rand() < self.exploration_rate:
            if self.policy.is_vectorized_observation(observation):
                if isinstance(observation, dict):
                    n_batch = observation[next(iter(observation.keys()))].shape[0]
                else:
                    n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            action, state = self.policy.predict(observation, state, episode_start, deterministic)
        return action, state

    def learn(
        self: SelfDQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDQN:
        """Delegate to the parent ``learn`` with DQN-specific defaults for logging."""
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            progress_bar=progress_bar,
        )

    def _excluded_save_params(self) -> List[str]:
        """Exclude the Q-network aliases from saving, in addition to the parent's exclusions."""
        return [*super()._excluded_save_params(), "q_net", "q_net_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Return the attributes saved via ``state_dict`` and the (empty) list of raw tensors."""
        state_dicts = ["policy", "policy.optimizer"]

        return state_dicts, []



###### Maxmin variant: MaxminMFQ (below) subclasses DQN and overrides `train`


# Define a custom neural network model class
class Weight(th.nn.Module):
    """
    Trainable column vector of shape ``[r_dim, 1]`` used to scalarize vector rewards.

    ``forward`` computes ``input @ weight``; ``step`` applies a manual additive
    update ``weight += lr * grad`` outside of autograd.

    :param r_dim: number of reward components (rows of the weight vector)
    :param initialize: initialization scheme; only ``'uniform'`` (every entry
        equal to ``1 / r_dim``) is implemented
    :raises NotImplementedError: if ``initialize`` is not ``'uniform'``
    """

    def __init__(self,
                 r_dim: int=2,
                 initialize: str='uniform'
        ):
        super().__init__()

        if initialize != 'uniform':
            raise NotImplementedError(f"Unsupported weight initialization: {initialize!r}")

        # Uniform start: each reward component gets the same weight 1 / r_dim
        self.weight = th.nn.Parameter(th.full((r_dim, 1), 1 / r_dim, dtype=th.float32))

    def step(self, lr, grad):
        """
        Add ``lr * grad`` to the weight in place, without tracking gradients.

        :param lr: step size multiplying ``grad``
        :param grad: update direction; a tensor, NumPy array or nested sequence
            whose shape is compatible with an in-place add to ``[r_dim, 1]``
        """
        with th.no_grad():
            # Accept array-likes as well as tensors, on the parameter's dtype and device
            update = th.as_tensor(grad, dtype=self.weight.dtype, device=self.weight.device)
            self.weight.add_(update * lr)

    def forward(self, input):
        """
        Project a batch of reward vectors onto the weight.

        :param input: 2-D tensor whose last dimension matches ``r_dim``
        :return: tensor of shape ``[input.shape[0], 1]``
        :raises ValueError: if ``input`` is not 2-D
        """
        if len(input.shape) != 2:
            raise ValueError(f"Expected a 2-D input, got shape {tuple(input.shape)}")
        return th.matmul(input, self.weight)


class MaxminMFQ(DQN):
    """
    Maxmin Model-free Q-learning (MaxminMFQ)
    Default hyperparameters are taken from the DQN paper,
    except for the optimizer and learning rate that were taken from Stable Baselines defaults.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
        If ``None``, it will be automatically selected.
    :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
        the reported success rate, mean episode length, and mean reward over
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
        debug messages
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """
    def __init__(
        self,
        policy: Union[str, Type[Union[DQNPolicy, SQLPolicy]]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 50000,
        train_freq: Union[int, Tuple[int, str]] = 4,
        target_update_interval: int = 10000,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        verbose: int = 0,
        seed: Optional[int] = None,
        r_dim: int = 1,
        r_dim_policy: int = 1,
        ent_alpha: float = 0.1,
        #### w update threshold
        soft_q_init_fraction: float = 0.1,
        #### w update params
        perturb_w_learning_rate: Union[float, Schedule] = 1e-4,  # Maybe require Scheduler
        #### w grad calculation
        period_cal_w_grad: int = 2,
        perturb_q_copy_num: int = 20,  # N_p > r_dim + 1
        perturb_std_dev: float = 0.01,
        perturb_q_batch_size: int = 128,
        perturb_q_learning_rate: Union[float, Schedule] = 1e-4,  # Passed to Adam as-is; a constant lr is assumed OK.
    ) -> None:
        """
        Build the DQN base model plus the reward weight and the perturbed Q-network copies.

        Parameters shared with ``DQN`` are forwarded to it unchanged. Additional parameters:

        :param ent_alpha: temperature of the soft (log-sum-exp) value used in ``train``
        :param soft_q_init_fraction: fraction of the total timesteps that must elapse
            before ``train`` starts the weight update
        :param perturb_w_learning_rate: learning rate for the weight update; stored on the
            instance (its use is outside this method)
        :param period_cal_w_grad: the weight gradient is computed when ``num_timesteps``
            is a multiple of this value
        :param perturb_q_copy_num: number of perturbed Q-network copies; must be greater
            than ``r_dim + 1`` (needed for the linear regression over perturbed weights)
        :param perturb_std_dev: standard deviation of the Gaussian noise added to the weight
        :param perturb_q_batch_size: replay batch size used for the perturbed update
        :param perturb_q_learning_rate: learning rate of the Adam optimizer shared by
            all perturbed Q-network copies
        :raises ValueError: if ``perturb_q_copy_num <= r_dim + 1``
        """
        # Validate before the (expensive) base-class construction. An explicit raise is
        # used instead of ``assert`` so the check survives ``python -O``.
        # np.linalg.det(X.T @ X) deviates from 0 if perturb_q_copy_num increases.
        if perturb_q_copy_num <= r_dim + 1:
            raise ValueError(
                f"perturb_q_copy_num ({perturb_q_copy_num}) must be greater than r_dim + 1 ({r_dim + 1})"
            )

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            train_freq=train_freq,
            target_update_interval=target_update_interval,
            exploration_initial_eps=exploration_initial_eps,
            exploration_final_eps=exploration_final_eps,
            verbose=verbose,
            seed=seed,
            r_dim=r_dim,
            r_dim_policy=r_dim_policy,
            ent_alpha=ent_alpha,
        )

        ### Add new parameters: weight
        # NOTE(review): Weight and init_state_tensor are created on torch's default device
        # and never moved; confirm this matches the device of the replay data and networks.
        self.weight = Weight(r_dim=r_dim)
        self.ent_alpha = ent_alpha

        # One-hot row vector of length 21 with the first entry set.
        # NOTE(review): the length is hard-coded; presumably it matches the
        # environment's observation size - confirm.
        self.init_state_tensor = th.tensor([1.0] + [0.0] * 20).unsqueeze(0)

        ### Params for perturbation
        self.soft_q_init_fraction = soft_q_init_fraction
        self.perturb_std_dev = perturb_std_dev  # std dev for Gaussian noise
        self.perturb_q_copy_num = perturb_q_copy_num
        # Used to scale the gradient-clipping norm over the combined parameters
        self.sqrt_q_copy_num = round(math.sqrt(self.perturb_q_copy_num), 2)
        self.period_cal_w_grad = period_cal_w_grad  # Period for gradient calculation

        self.perturb_q_batch_size = perturb_q_batch_size
        self.perturb_w_learning_rate = perturb_w_learning_rate

        ### Perturbed Q-network copies, all driven by a single optimizer
        self.perturb_q_net_list = [self.policy.make_q_net() for _ in range(self.perturb_q_copy_num)]

        self.combined_parameters = [
            param for perturb_q_net in self.perturb_q_net_list for param in perturb_q_net.parameters()
        ]
        self.perturb_q_optimizer = th.optim.Adam(self.combined_parameters, lr=perturb_q_learning_rate)

        for perturb_q_net in self.perturb_q_net_list:
            perturb_q_net.set_training_mode(False)  # Temporary

        ### Target Q
        self.perturb_q_net_target = self.policy.make_q_net()
        self.perturb_q_net_target.set_training_mode(False)  # Permanent

        ### Calculated gradient of w
        self.proj_grad = np.zeros((r_dim, 1), dtype=np.float32)

        ### The weighted reward is only fed in the 'train' function.

    ### Function Define for multi-processing
    def train_perturbed_q(self, q_network, optimizer, input_obs, ac_idx, target_q_values):
        """
        Perform one gradient step on a single perturbed Q-network copy.

        The copy is reset to the current parameters of ``self.q_net`` first, and is
        switched to training mode only for the duration of the update.

        :param q_network: perturbed Q-network to update
        :param optimizer: optimizer whose ``zero_grad``/``step`` are invoked
            (presumably built over ``q_network``'s parameters - verify against caller)
        :param input_obs: observations passed to ``q_network.forward``
        :param ac_idx: action indices gathered along dim 1 of the network output
        :param target_q_values: regression targets for the gathered Q-values
        """
        # Restart from the main network's weights every time this is called.
        q_network.set_training_mode(True)
        q_network.load_state_dict(self.q_net.state_dict())

        # Q-values for every action, then only those of the actions actually taken.
        all_action_values = q_network.forward(input_obs)  # [perturb_q_batch_size, ac_dim]
        taken_action_values = all_action_values.gather(1, ac_idx)  # [perturb_q_batch_size, 1]
        assert taken_action_values.dim() == 2

        # Huber loss (less sensitive to outliers)
        td_loss = F.smooth_l1_loss(taken_action_values, target_q_values)

        # Calling zero_grad after the forward pass but before backward is fine.
        # See: https://github.com/yunjey/pytorch-tutorial/issues/238
        optimizer.zero_grad()
        td_loss.backward()
        # Clip the gradient norm before applying the update
        th.nn.utils.clip_grad_norm_(q_network.parameters(), self.max_grad_norm)
        optimizer.step()

        q_network.set_training_mode(False)

    ### Overriding train function. Here, we should use both (i) soft-q learning and (ii) weight update.
    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        # Switch to train mode (this affects batch norm / dropout)
        self.policy.set_training_mode(True)
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        ### Main Q update begins
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]

            with th.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target.forward(replay_data.next_observations)  # [32,ac_dim] if r_dim_policy=1 (not r_dim)
                ## soft-q
                next_q_values = self.ent_alpha * th.logsumexp(next_q_values / self.ent_alpha, dim=-1, keepdim=True)  # [32,1]

                # 1-step TD target
                # target_q_values: [32,1] // replay_data.rewards: [32, r_dim], replay_data.dones: [32,1]
                target_q_values = self.weight.forward(replay_data.rewards) + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates
            current_q_values = self.q_net.forward(replay_data.observations)  # [32*4,ac_dim] if r_dim_policy=1 (not r_dim)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())  # [32*4,1]
            assert len(current_q_values.shape) == 2  # [32,1]

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            # loss = F.mse_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()
        ### Main Q update ends

        ### w update begins
        if ( (self.num_timesteps >= int(self.soft_q_init_fraction*self._total_timesteps))
                and not(th.any(self.weight.weight < 0) or th.any(self.weight.weight > 1)) ): ### Initialize soft q-learning
            if self.num_timesteps%self.period_cal_w_grad == 0: ## calculate projected gradient of w

                ###### Ver 2, 3. Simultaneous
                self.perturb_q_net_target.load_state_dict(self.q_net_target.state_dict())

                # Sample replay buffer - For now we only give perturbation w.r.t. weight, so we fix sampled buffer.
                ### Increse buffer size since we perform only one gradient step. Increase accuracy.
                replay_data = self.replay_buffer.sample(self.perturb_q_batch_size,env=self._vec_normalize_env)  # type: ignore[union-attr]

                # target value calculation.
                with th.no_grad():
                    # Compute the next Q-values using the target network
                    next_q_values = self.perturb_q_net_target.forward(replay_data.next_observations)  # [perturb_q_batch_size,ac_dim]
                    ## soft-q Target
                    next_q_values = self.ent_alpha * th.logsumexp(next_q_values / self.ent_alpha, dim=-1,keepdim=True)  # [perturb_q_batch_size,1]
                    ## Perturb weight
                    perturbed_weight = self.weight.weight + th.randn(self.weight.weight.shape[0], self.perturb_q_copy_num) * self.perturb_std_dev # [r_dim, q_copy_num]
                    perturbed_weight = th.clamp(perturbed_weight, min=0.0) # nonnegative
                    # 1-step Soft TD target
                    # replay_data.rewards: [perturb_q_batch_size, r_dim], replay_data.dones: [perturb_q_batch_size,1]
                    # target_q_values: [perturb_q_batch_size, q_copy_num]
                    target_q_values = th.matmul(replay_data.rewards, perturbed_weight) + (1 - replay_data.dones) * self.gamma * next_q_values

                ### Ver 3. Single process
                current_q_values_p_list = []
                for i, perturb_q_net in enumerate(self.perturb_q_net_list):
                    perturb_q_net.set_training_mode(True)
                    perturb_q_net.load_state_dict(self.q_net.state_dict())

                    # Get current Q-values estimates
                    current_q_values_p = perturb_q_net.forward(replay_data.observations)  # [perturb_q_batch_size,ac_dim]
                    current_q_values_p_list.append(current_q_values_p)

                current_q_values_p_list = th.stack(current_q_values_p_list, dim=2) # [perturb_q_batch_size, ac_dim, q_copy_num]=[128,4,20]
                # Retrieve the q-values for the actions from the replay buffer
                idx = replay_data.actions.long().unsqueeze(-1)
                idx = idx.repeat(1,1,self.perturb_q_copy_num) # [perturb_q_batch_size,1,self.perturb_q_copy_num]
                current_q_values = th.gather(current_q_values_p_list, dim=1, index=idx).squeeze(dim=1)  # [perturb_q_batch_size, q_copy_num]

                # Compute Huber loss (less sensitive to outliers)
                loss = F.smooth_l1_loss(current_q_values, target_q_values)
                # loss = F.mse_loss(current_q_values_p, target_q_values)

                # Optimize the Q
                self.perturb_q_optimizer.zero_grad()
                loss.backward()
                # Clip gradient norm
                th.nn.utils.clip_grad_norm_(self.combined_parameters, self.sqrt_q_copy_num*self.max_grad_norm)
                self.perturb_q_optimizer.step()

                # ### Ver 2. Multi-processing
                # # Create separate processes for training each Q-network copy
                # processes = []
                # # num_processes = min(self.perturb_q_copy_num, cpu_count())  # Use at most the number of available CPU cores
                # # output_queue = Queue()
                #
                # for i, (perturb_q_net, optimizer) in enumerate(zip(self.perturb_q_net_list, self.perturb_q_optimizer_list)):
                #     # if i < num_processes:
                #     p = Process(target=self.train_perturbed_q, args=(perturb_q_net, optimizer, replay_data.observations, idx, target_q_values[:,i]))
                #     processes.append(p)
                #     p.start()
                # # Wait for all processes to finish
                # for p in processes:
                #     p.join()

                ## Since preserving w order and Q order is crucial and multiprocessing may not necessarily preserve order,
                ## We explicitly use order
                perturb_q_list = []
                for perturb_q_net in self.perturb_q_net_list:
                    perturb_q_net.set_training_mode(False)
                    new_perturb_q = perturb_q_net.forward(self.init_state_tensor)  # [1,ac_dim]
                    new_perturb_q = self.ent_alpha * th.logsumexp(new_perturb_q / self.ent_alpha, dim=-1)  # [1,]
                    perturb_q_list.append(new_perturb_q.item())
                X = perturbed_weight.transpose(0, 1).numpy() # [q_copy_num, r_dim]

                ###### Ver 1. Sequential
                # self.perturb_q_net.set_training_mode(True) ## Explicitly write
                # self.perturb_q_net_target.load_state_dict(self.q_net_target.state_dict())
                #
                # # Sample replay buffer - For now we only perturb the weight, so the sampled batch is kept fixed.
                # ### Increase the sampled batch size since we perform only one gradient step; this increases accuracy.
                # ## 4*batch_size
                # replay_data = self.replay_buffer.sample(self.perturb_q_batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
                #
                # # Then the target value is the same.
                # with th.no_grad():
                #     # Compute the next Q-values using the target network
                #     next_q_values = self.perturb_q_net_target.forward(replay_data.next_observations)  # [32*4,ac_dim] if r_dim_policy=1 (not r_dim)
                #     ## soft-q Target
                #     next_q_values = self.ent_alpha * th.logsumexp(next_q_values / self.ent_alpha, dim=-1,keepdim=True)  # [32,1]
                #
                #     next_q_values = (1 - replay_data.dones) * self.gamma * next_q_values
                #
                # idx = replay_data.actions.long()  # replay_data.actions.long(): [32,1]
                #
                # # Perturbation iteration begins. Generate (w+eps, Q(w+eps))
                # perturb_w_list = []
                # perturb_q_list = []
                # current_weight = self.weight.weight
                # for _ in range(self.perturb_q_copy_num): # For now we use sequential update
                #     ### Conduct SQL using perturbed_weight ### Later parallelize
                #     self.perturb_q_net.load_state_dict(self.q_net.state_dict())
                #
                #     with th.no_grad():
                #         # 1-step TD target
                #         # target_q_values: [32,1] // replay_data.rewards: [32, r_dim], replay_data.dones: [32,1]
                #         ### perturb weight
                #         perturbed_weight = current_weight + th.randn_like(current_weight) * self.perturb_std_dev
                #         perturbed_weight = th.clamp(perturbed_weight, min=0.0)
                #         perturb_w_list.append(perturbed_weight)
                #         target_q_values = th.matmul(replay_data.rewards, perturbed_weight) + next_q_values
                #
                #     # Get current Q-values estimates
                #     current_q_values_p = self.perturb_q_net.forward(replay_data.observations)  # [32,ac_dim] if r_dim_policy=1 (not r_dim)
                #
                #     # Retrieve the q-values for the actions from the replay buffer
                #     current_q_values_p = th.gather(current_q_values_p, dim=1, index=idx)  # [32,1]
                #     assert len(current_q_values_p.shape) == 2  # [32,1]
                #
                #     # Compute Huber loss (less sensitive to outliers)
                #     loss = F.smooth_l1_loss(current_q_values_p, target_q_values)
                #     # loss = F.mse_loss(current_q_values_p, target_q_values)
                #
                #     # Optimize the policy
                #     self.perturb_q_optimizer.zero_grad()
                #     loss.backward()
                #     # Clip gradient norm
                #     th.nn.utils.clip_grad_norm_(self.perturb_q_net.parameters(), self.max_grad_norm)
                #     self.perturb_q_optimizer.step()
                #
                #     #######
                #     # Initial state save!  For now we implement as follows:
                #     new_perturb_q = self.perturb_q_net.forward(self.init_state_tensor)  # [1,ac_dim]
                #     new_perturb_q = self.ent_alpha * th.logsumexp(new_perturb_q / self.ent_alpha, dim=-1)  # [1,]
                #     perturb_q_list.append(new_perturb_q.item())
                #
                # self.perturb_q_net.set_training_mode(False)

                # X = th.cat(perturb_w_list, dim=1).transpose(0, 1).numpy() # [N_p, r_dim]


                ####### Conduct linear regression.
                y = np.array(perturb_q_list) # [N_p,]

                X = np.column_stack((np.ones(X.shape[0]), X)) # [N_p, r_dim+1]
                try:
                    coefficients = np.linalg.inv(X.T @ X) @ X.T @ y # [r_dim+1,]. Assume N_p > r_dim + 1 and inverse exists.
                except: # LinAlgError("Singular matrix")
                    coefficients = np.linalg.pinv(X.T @ X) @ X.T @ y
                    print("P-Inv is used") ## Sometimes this is used.
                    print()
                intercept, feature_coef = coefficients[0], coefficients[1:]

                # print("Intercept:", intercept)
                # print("Feature Coefficients:", feature_coefficients) #  [0.01222464 0.01084723 0.00959038 0.02531738]

                ###
                self.proj_grad = feature_coef - np.mean(feature_coef)*np.ones_like(feature_coef) # array([ 0.00190182, -0.00757042,  0.00457813,  0.00109046])
                self.proj_grad = self.proj_grad.reshape(-1,1) # [r_dim,1]
                self.proj_grad = self.proj_grad.astype(np.float32)
            ### Now we update parameter using proj_grad
            self.weight.step(lr=self.perturb_w_learning_rate, grad=self.proj_grad)

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))

        # Wandb log
        wegh = self.weight.weight
        wandb.log({
            'loss': np.mean(losses),
            'Total Mean Q': th.mean(current_q_values).item(),
            'Weight 0': wegh[0].item(),
            'Weight 1': wegh[1].item(),
            'Weight 2': wegh[2].item(),
            'Weight 3': wegh[3].item(),
        }
        )


