#!/usr/bin/env python
import os
from pprint import pprint
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import tyro
from dataclasses import dataclass
import setproctitle
from envs.ma_mujoco.multiagent_mujoco.mujoco_multi import MujocoMulti
from envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv
# from tensorboardX import SummaryWriter
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
import copy
from gym import Space
from typing import Union, List, Tuple

# ----------------- #
# Utility Functions #
# ----------------- #

def _t2n(x: torch.Tensor) -> np.ndarray:
    """Convert a tensor to a numpy array."""
    return x.detach().cpu().numpy()

def get_shape_from_obs_space(obs_space: Space) -> Tuple:
    """Return the observation shape described by ``obs_space``.

    Dispatch is on the class *name* of the argument: an object whose class
    is called ``Box`` contributes its ``shape`` attribute, while a plain
    ``list`` is taken to already be the shape and is returned unchanged.

    Raises:
        NotImplementedError: for any other kind of space.
    """
    space_kind = obs_space.__class__.__name__
    if space_kind == 'Box':
        return obs_space.shape
    if space_kind == 'list':
        return obs_space
    raise NotImplementedError

def get_shape_from_act_space(act_space: Space) -> Tuple:
    """Return the per-step action storage shape for ``act_space``.

    Dispatch is on the class name of the space:

    * ``Discrete``      -> ``(1,)`` (a single action index)
    * ``MultiDiscrete`` -> ``(act_space.shape,)``
    * ``Box``           -> ``act_space.shape``
    * ``MultiBinary``   -> ``(act_space.n,)``

    Raises:
        NotImplementedError: for any other kind of space.
    """
    space_kind = act_space.__class__.__name__
    if space_kind == 'Discrete':
        return (1,)
    if space_kind == "MultiDiscrete":
        # NOTE(review): this wraps ``shape`` in a tuple, which presumes the
        # MultiDiscrete class in use exposes an integer ``shape`` - confirm.
        return (act_space.shape,)
    if space_kind == "Box":
        return act_space.shape
    if space_kind == "MultiBinary":
        return (act_space.n,)
    raise NotImplementedError

def init(module: nn.Module, weight_init, bias_init, gain: float = 1):
    """Apply initialisers to ``module`` in place and return it.

    Args:
        module: a module exposing ``weight`` and ``bias`` tensors.
        weight_init: callable invoked as ``weight_init(weight, gain=gain)``.
        bias_init: callable invoked as ``bias_init(bias)``.
        gain: scaling factor forwarded to ``weight_init``.

    Returns:
        The same ``module`` instance, to allow inline use.
    """
    weights, biases = module.weight.data, module.bias.data
    weight_init(weights, gain=gain)
    bias_init(biases)
    return module

def get_gard_norm(it):
    """Return the global L2 norm of the gradients of the parameters in ``it``.

    Parameters whose ``grad`` is ``None`` are ignored. When nothing
    contributes, the result is the plain number ``0.0``; otherwise it is a
    scalar tensor.
    """
    squared_norms = (p.grad.norm() ** 2 for p in it if p.grad is not None)
    return sum(squared_norms) ** 0.5

def huber_loss(e, d):
    """Element-wise Huber loss of the error tensor ``e`` with threshold ``d``.

    The loss is quadratic (``e**2 / 2``) where ``|e| <= d`` and linear
    (``d * (|e| - d / 2)``) where ``|e| > d``.

    The linear branch is selected on ``|e|``: the previous ``e > d`` test
    silently assigned zero loss to errors below ``-d``.

    Args:
        e: tensor of errors.
        d: positive threshold between the quadratic and linear regimes.

    Returns:
        Tensor of per-element losses with the same shape as ``e``.
    """
    abs_e = abs(e)
    quadratic = (abs_e <= d).float()
    linear = (abs_e > d).float()
    return quadratic * e ** 2 / 2 + linear * d * (abs_e - d / 2)

def mse_loss(e):
    """Return half of the squared error, element-wise: ``e**2 / 2``."""
    squared_error = e ** 2
    return squared_error / 2

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    """Set the learning rate of every param group from a warm-up/decay schedule.

    With ``progress = epoch / total_num_epochs`` the rate rises linearly
    from 0 to ``initial_lr`` while ``progress < 0.3`` and then falls
    linearly back to 0 at ``progress == 1``. Note that epoch 0 therefore
    yields a learning rate of exactly 0.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: index of the current epoch.
        total_num_epochs: total number of epochs (must be non-zero).
        initial_lr: peak learning rate, reached at 30% progress.
    """
    progress = epoch / float(total_num_epochs)
    in_warmup = progress < 0.3
    lr = initial_lr * progress / 0.3 if in_warmup else initial_lr * (1 - progress) / 0.7

    for group in optimizer.param_groups:
        group['lr'] = lr

# ------ #
# Config #
# ------ #

@dataclass
class Args:
    """Flat container for every hyperparameter and switch used by this script.

    All fields carry defaults, so ``Args()`` yields a complete configuration.
    Presumably populated from the command line via ``tyro`` (imported at the
    top of the file) - the call site is not visible here, confirm.
    """
    # Prepare parameters
    env_name: str = "mujoco"
    algorithm_name: str = "happo"
    experiment_name: str = "mlp"
    scenario: str = "Hopper-v2"
    agent_conf: str = "3x1"
    agent_obsk: int = 2
    lr: float = 5e-6
    critic_lr: float = 5e-4
    std_x_coef: float = 1
    std_y_coef: float = 0.5
    running_id: int = 1
    n_training_threads: int = 8
    n_rollout_threads: int = 4
    num_mini_batch: int = 40
    episode_length: int = 1000
    num_env_steps: int = 10000000
    ppo_epoch: int = 5
    kl_threshold: float = 1e-4
    use_value_active_masks: bool = True
    use_eval: bool = True
    add_center_xy: bool = True
    use_state_agent: bool = True
    share_policy: bool = True

    # Additional configs from parser
    seed: int = 1
    seed_specify: bool = True
    cuda: bool = True
    cuda_deterministic: bool = True
    n_eval_rollout_threads: int = 1
    n_render_rollout_threads: int = 1
    user_name: str = "marl"
    use_obs_instead_of_state: bool = False
    use_centralized_V: bool = True
    stacked_frames: int = 1
    use_stacked_frames: bool = False
    hidden_size: int = 64
    layer_N: int = 1
    use_ReLU: bool = True
    use_popart: bool = True
    use_valuenorm: bool = True
    use_feature_normalization: bool = True
    use_orthogonal: bool = True
    gain: float = 0.01
    use_naive_recurrent_policy: bool = False
    use_recurrent_policy: bool = False
    recurrent_N: int = 1
    data_chunk_length: int = 10
    opti_eps: float = 1e-5
    weight_decay: float = 0
    ls_step: int = 10
    accept_ratio: float = 0.5
    use_clipped_value_loss: bool = True
    clip_param: float = 0.2
    entropy_coef: float = 0.01
    value_loss_coef: float = 1
    use_max_grad_norm: bool = True
    max_grad_norm: float = 10.0
    use_gae: bool = True
    gamma: float = 0.99
    gae_lambda: float = 0.95
    use_proper_time_limits: bool = False
    use_huber_loss: bool = True
    use_policy_active_masks: bool = True
    huber_delta: float = 10.0
    use_linear_lr_decay: bool = True
    save_interval: int = 1
    log_interval: int = 1
    # NOTE(review): annotated as ``str`` but the default is ``None``.
    model_dir: str = None
    eval_interval: int = 25
    eval_episodes: int = 32
    save_gifs: bool = False
    use_render: bool = False
    render_episodes: int = 5
    ifi: float = 0.1

    # Mujoco-specific arguments
    add_move_state: bool = False
    add_local_obs: bool = False
    add_distance_state: bool = False
    add_enemy_action_state: bool = False
    add_agent_id: bool = False
    add_visible_state: bool = False
    add_xy_state: bool = False
    use_mustalive: bool = True
    use_single_network: bool = False

    # Reward model and preference learning arguments
    use_reward_model: bool = True
    reward_lr: float = 1e-3
    reward_update_freq: int = 1
    preference_buffer_size: int = 10000
    trajectory_length: int = 500
    num_preference_pairs: int = 32
    
    # IPL-specific parameters
    ipl_tau: float = 0.005  # Target network update rate

# ------ #
# PopArt #
# ------ #

class PopArt(nn.Module):
    """Running mean/variance normaliser with debiased exponential averages.

    Exponentially weighted averages of the first and second moments are kept
    over the leading ``norm_axes`` dimensions of the inputs. A debiasing term
    corrects the zero initialisation of those averages.

    Args:
        input_shape: shape of the statistics (the trailing input dimensions).
        norm_axes: number of leading dimensions that are averaged over.
        beta: decay of the running averages.
        per_element_update: if True the decay applied per call is
            ``beta ** batch_size`` instead of ``beta``.
        epsilon: lower bound applied to the debiasing term.
        device: device on which the statistics are created.
    """

    def __init__(self, input_shape: Tuple, norm_axes: int = 1, beta: float = 0.99999, per_element_update: bool = False, epsilon: float = 1e-5, device: torch.device = torch.device("cpu")):
        super(PopArt, self).__init__()
        self.input_shape = input_shape
        self.norm_axes = norm_axes
        self.epsilon = epsilon
        self.beta = beta
        self.per_element_update = per_element_update

        # Stored as non-trainable parameters so they are part of state_dict
        # and follow the module across devices.
        self.running_mean = nn.Parameter(torch.zeros(input_shape, device=device), requires_grad=False)
        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape, device=device), requires_grad=False)
        self.debiasing_term = nn.Parameter(torch.tensor(0.0, device=device), requires_grad=False)

    def reset_parameters(self):
        """Zero all running statistics in place."""
        self.running_mean.zero_()
        self.running_mean_sq.zero_()
        self.debiasing_term.zero_()

    def running_mean_var(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the debiased mean and variance (variance floored at 1e-2)."""
        debias = self.debiasing_term.clamp(min=self.epsilon)
        debiased_mean = self.running_mean / debias
        debiased_mean_sq = self.running_mean_sq / debias
        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)
        return debiased_mean, debiased_var

    def forward(self, input_vector: torch.Tensor, train: bool = True) -> torch.Tensor:
        """Normalise ``input_vector``; when ``train`` is True, update the statistics first.

        Gradients still flow through the returned tensor to ``input_vector``;
        only the statistics update uses a detached copy.
        """
        if train:
            detached_input = input_vector.detach()
            reduce_dims = tuple(range(self.norm_axes))
            batch_mean = detached_input.mean(dim=reduce_dims)
            batch_sq_mean = (detached_input ** 2).mean(dim=reduce_dims)

            if self.per_element_update:
                batch_size = np.prod(detached_input.size()[:self.norm_axes])
                weight = self.beta ** batch_size
            else:
                weight = self.beta

            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))

        mean, var = self.running_mean_var()
        # Indexing with (None,) * norm_axes prepends broadcast dimensions.
        expand = (None,) * self.norm_axes
        out = (input_vector - mean[expand]) / torch.sqrt(var)[expand]
        return out

    def denormalize(self, input_vector: torch.Tensor) -> np.ndarray:
        """Map normalised values back to the original scale as a NumPy array.

        The result is detached before conversion so that inputs which are
        part of an autograd graph do not make ``.numpy()`` raise.
        """
        mean, var = self.running_mean_var()
        expand = (None,) * self.norm_axes
        out = input_vector * torch.sqrt(var)[expand] + mean[expand]
        return out.detach().cpu().numpy()

# ------------------- #
# Action Distribution #
# ------------------- #

class FixedCategorical(torch.distributions.Categorical):
    """Categorical distribution whose outputs carry a trailing unit dimension."""

    def sample(self):
        """Draw one action index per batch element, shaped ``(..., 1)``."""
        drawn = super().sample()
        return drawn.unsqueeze(-1)

    def log_probs(self, actions):
        """Return log-probabilities of ``actions`` as a ``(batch, 1)`` tensor."""
        batch = actions.size(0)
        flat_log_prob = super().log_prob(actions.squeeze(-1))
        summed = flat_log_prob.view(batch, -1).sum(-1)
        return summed.unsqueeze(-1)

    def mode(self):
        """Return the most probable action index, shaped ``(..., 1)``."""
        return torch.argmax(self.probs, dim=-1, keepdim=True)

class FixedNormal(torch.distributions.Normal):
    """Normal distribution exposing the helper names used by the action layer."""

    def log_probs(self, actions):
        """Return the per-dimension log-density of ``actions`` (no reduction)."""
        per_dim = super().log_prob(actions)
        return per_dim

    def entrop(self):
        """Return the entropy summed over the last dimension.

        The method name is ``entrop`` (sic); it does not override
        ``entropy()``, which keeps returning per-dimension values.
        """
        per_dim = super().entropy()
        return per_dim.sum(-1)

    def mode(self):
        """Return the distribution mean, which is also its mode."""
        return self.mean

class Categorical(nn.Module):
    """Linear head producing a :class:`FixedCategorical` over discrete actions.

    Args:
        num_inputs: size of the incoming feature vector.
        num_outputs: number of discrete actions.
        use_orthogonal: orthogonal weight init if True, Xavier-uniform otherwise.
        gain: gain passed to the weight initialiser.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super().__init__()
        weight_init = nn.init.orthogonal_ if use_orthogonal else nn.init.xavier_uniform_
        self.linear = init(
            nn.Linear(num_inputs, num_outputs),
            weight_init,
            lambda bias: nn.init.constant_(bias, 0),
            gain,
        )

    def forward(self, x, available_actions=None):
        """Return the action distribution for features ``x``.

        Where ``available_actions`` equals 0 the logit is overwritten with
        ``-1e10`` so that the action is effectively never selected.
        """
        logits = self.linear(x)
        if available_actions is not None:
            logits[available_actions == 0] = -1e10
        return FixedCategorical(logits=logits)

class DiagGaussian(nn.Module):
    """Linear head producing a diagonal :class:`FixedNormal` over continuous actions.

    The standard deviation is state-independent:
    ``sigmoid(log_std / std_x_coef) * std_y_coef``, where ``log_std`` is a
    learnable vector initialised to ``std_x_coef``.

    Args:
        num_inputs: size of the incoming feature vector.
        num_outputs: action dimensionality.
        use_orthogonal: orthogonal weight init if True, Xavier-uniform otherwise.
        gain: gain passed to the weight initialiser.
        args: optional config providing ``std_x_coef`` and ``std_y_coef``;
            when falsy the values 1.0 and 0.5 are used.
    """

    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01, args=None):
        super().__init__()
        weight_init = nn.init.orthogonal_ if use_orthogonal else nn.init.xavier_uniform_

        if args:
            self.std_x_coef = args.std_x_coef
            self.std_y_coef = args.std_y_coef
        else:
            self.std_x_coef = 1.0
            self.std_y_coef = 0.5

        self.fc_mean = init(
            nn.Linear(num_inputs, num_outputs),
            weight_init,
            lambda bias: nn.init.constant_(bias, 0),
            gain,
        )
        self.log_std = torch.nn.Parameter(torch.ones(num_outputs) * self.std_x_coef)

    def forward(self, x, available_actions=None):
        """Return the action distribution for features ``x``.

        ``available_actions`` is accepted for signature parity with the
        discrete head and is not used.
        """
        mean = self.fc_mean(x)
        squashed = torch.sigmoid(self.log_std / self.std_x_coef)
        return FixedNormal(mean, squashed * self.std_y_coef)

# ---------- #
# ACT Layer #
# ---------- #

class ACTLayer(nn.Module):
    """Action head that picks the distribution matching the action space.

    ``Discrete`` spaces use :class:`Categorical`, ``Box`` spaces use
    :class:`DiagGaussian`; anything else raises ``NotImplementedError``.
    """

    def __init__(self, action_space, inputs_dim, use_orthogonal, gain, args=None):
        super().__init__()
        self.action_type = action_space.__class__.__name__
        if self.action_type == "Discrete":
            self.action_out = Categorical(inputs_dim, action_space.n, use_orthogonal, gain)
        elif self.action_type == "Box":
            self.action_out = DiagGaussian(inputs_dim, action_space.shape[0], use_orthogonal, gain, args)
        else:
            raise NotImplementedError(f"Action space {self.action_type} not supported.")

    def forward(self, x, available_actions=None, deterministic=False):
        """Choose actions for features ``x``.

        Returns:
            ``(actions, action_log_probs)``; the distribution mode is used
            when ``deterministic`` is True, otherwise a sample is drawn.
        """
        dist = self.action_out(x, available_actions)
        if deterministic:
            actions = dist.mode()
        else:
            actions = dist.sample()
        return actions, dist.log_probs(actions)

    def evaluate_actions(self, x, action, available_actions=None, active_masks=None):
        """Score given actions under the current policy.

        Returns:
            ``(action_log_probs, dist_entropy)``. Without ``active_masks``
            the entropy is a plain mean; with them it is the mask-weighted
            sum divided by the mask total.
        """
        dist = self.action_out(x, available_actions)
        action_log_probs = dist.log_probs(action)

        if active_masks is None:
            return action_log_probs, dist.entropy().mean()

        # The categorical entropy has no trailing unit dim, so the mask's is dropped.
        if self.action_type == "Discrete":
            weights = active_masks.squeeze(-1)
        else:
            weights = active_masks
        dist_entropy = (dist.entropy() * weights).sum() / active_masks.sum()
        return action_log_probs, dist_entropy

# --- #
# MLP #
# --- #

class MLPLayer(nn.Module):
    """Stack of ``Linear -> activation -> LayerNorm`` blocks.

    One input block maps ``input_dim`` to ``hidden_size`` and is followed by
    ``layer_N`` hidden blocks of width ``hidden_size``.

    Args:
        input_dim: size of the input features.
        hidden_size: width of every block.
        layer_N: number of hidden blocks after the input block.
        use_orthogonal: orthogonal weight init if True, Xavier-uniform otherwise.
        use_ReLU: ReLU activation if True, Tanh otherwise.
    """

    def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, use_ReLU):
        super().__init__()
        self._layer_N = layer_N

        if use_ReLU:
            activation, nonlinearity = nn.ReLU(), 'relu'
        else:
            activation, nonlinearity = nn.Tanh(), 'tanh'
        weight_init = nn.init.orthogonal_ if use_orthogonal else nn.init.xavier_uniform_
        gain = nn.init.calculate_gain(nonlinearity)

        def make_block(in_features):
            # A single activation instance is shared by every block.
            linear = init(
                nn.Linear(in_features, hidden_size),
                weight_init,
                lambda bias: nn.init.constant_(bias, 0),
                gain=gain,
            )
            return nn.Sequential(linear, activation, nn.LayerNorm(hidden_size))

        self.fc1 = make_block(input_dim)
        self.fc_h = nn.ModuleList([make_block(hidden_size) for _ in range(self._layer_N)])

    def forward(self, x):
        """Run ``x`` through the input block and then each hidden block in order."""
        hidden = self.fc1(x)
        for block in self.fc_h:
            hidden = block(hidden)
        return hidden

class MLPBase(nn.Module):
    """Feature extractor: optional input LayerNorm followed by an :class:`MLPLayer`.

    Args:
        args: configuration; reads ``use_feature_normalization``,
            ``hidden_size``, ``layer_N``, ``use_orthogonal`` and ``use_ReLU``.
        obs_shape: observation shape; only its first entry is used.
    """

    def __init__(self, args: Args, obs_shape):
        super().__init__()
        self._use_feature_normalization = args.use_feature_normalization
        self.hidden_size = args.hidden_size
        input_dim = obs_shape[0]

        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(input_dim)

        self.mlp = MLPLayer(
            input_dim,
            self.hidden_size,
            args.layer_N,
            args.use_orthogonal,
            args.use_ReLU,
        )

    def forward(self, x):
        """Return hidden features for observations ``x``."""
        features = self.feature_norm(x) if self._use_feature_normalization else x
        return self.mlp(features)

# ---------------- #
# Actor and Critic #
# ---------------- #

class Actor(nn.Module):
    """Policy network: an :class:`MLPBase` feature extractor plus an :class:`ACTLayer` head.

    Args:
        args: configuration; reads ``hidden_size``, ``use_orthogonal`` and
            ``gain`` here and is forwarded to the sub-modules.
        obs_space: observation space of the agent.
        action_space: action space of the agent.
    """

    def __init__(self, args: Args, obs_space: Space, action_space: Space):
        super().__init__()
        self.hidden_size = args.hidden_size
        self.base = MLPBase(args, get_shape_from_obs_space(obs_space))
        self.act = ACTLayer(action_space, self.hidden_size, args.use_orthogonal, args.gain, args)

    @torch.compile(fullgraph=True)
    def forward(self, obs: torch.Tensor, available_actions: torch.Tensor = None, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select actions for ``obs``.

        Returns:
            ``(actions, action_log_probs)`` as produced by the action head.
        """
        features = self.base(obs.float())
        actions, action_log_probs = self.act(features, available_actions, deterministic)
        return actions, action_log_probs

    @torch.compile(fullgraph=True)
    def evaluate_actions(self, obs: torch.Tensor, action: torch.Tensor, available_actions: torch.Tensor = None, active_masks: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score ``action`` under the current policy.

        Returns:
            ``(action_log_probs, dist_entropy)`` as produced by the action head.
        """
        features = self.base(obs.float())
        return self.act.evaluate_actions(features, action, available_actions, active_masks)


class Critic(nn.Module):
    """Value network: an :class:`MLPBase` feature extractor plus a scalar linear head.

    Args:
        args: configuration; reads ``hidden_size`` and ``use_orthogonal``
            here and is forwarded to the feature extractor.
        cent_obs_space: space of the (centralised) observation fed to the critic.
    """

    def __init__(self, args: Args, cent_obs_space: Space):
        super().__init__()
        self.hidden_size = args.hidden_size
        weight_init = nn.init.orthogonal_ if args.use_orthogonal else nn.init.xavier_uniform_

        self.base = MLPBase(args, get_shape_from_obs_space(cent_obs_space))
        # The value head uses the default gain of ``init``.
        self.v_out = init(
            nn.Linear(self.hidden_size, 1),
            weight_init,
            lambda bias: nn.init.constant_(bias, 0),
        )

    @torch.compile(fullgraph=True)
    def forward(self, cent_obs: torch.Tensor) -> torch.Tensor:
        """Return value predictions with a trailing dimension of size 1."""
        features = self.base(cent_obs.float())
        return self.v_out(features)

# ------ #
# Policy #
# ------ #

class HAPPO_Policy:
    """Bundle of one actor, one critic and an Adam optimizer for each.

    Args:
        args: configuration; reads ``lr``, ``critic_lr``, ``opti_eps`` and
            ``weight_decay`` here and is forwarded to the networks.
        obs_space: observation space consumed by the actor.
        cent_obs_space: observation space consumed by the critic.
        act_space: action space of the agent.
        device: device the networks are moved to.
    """

    def __init__(self, args: Args, obs_space: Space, cent_obs_space: Space, act_space: Space, device: torch.device = torch.device("cpu")):
        self.device = device
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay

        self.obs_space = obs_space
        self.share_obs_space = cent_obs_space
        self.act_space = act_space

        self.actor = Actor(args, self.obs_space, self.act_space).to(self.device)
        self.critic = Critic(args, self.share_obs_space).to(self.device)

        self.actor_optimizer = self._make_adam(self.actor, self.lr)
        self.critic_optimizer = self._make_adam(self.critic, self.critic_lr)

    def _make_adam(self, network, learning_rate):
        """Build an Adam optimizer for ``network`` with the shared eps and weight decay."""
        return torch.optim.Adam(
            network.parameters(),
            lr=learning_rate,
            eps=self.opti_eps,
            weight_decay=self.weight_decay,
        )

    def lr_decay(self, episode: int, episodes: int):
        """Apply the learning-rate schedule to both optimizers."""
        schedule_targets = (
            (self.actor_optimizer, self.lr),
            (self.critic_optimizer, self.critic_lr),
        )
        for optimizer, base_lr in schedule_targets:
            update_linear_schedule(optimizer, episode, episodes, base_lr)

    def get_actions(self, cent_obs: torch.Tensor, obs: torch.Tensor, available_actions: torch.Tensor = None, deterministic: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(values, actions, action_log_probs)`` for a rollout step."""
        actions, action_log_probs = self.actor(obs, available_actions, deterministic)
        return self.critic(cent_obs), actions, action_log_probs

    def get_values(self, cent_obs: torch.Tensor) -> torch.Tensor:
        """Return the critic's value predictions for ``cent_obs``."""
        values = self.critic(cent_obs)
        return values

    def evaluate_actions(self, cent_obs: torch.Tensor, obs: torch.Tensor, action: torch.Tensor, available_actions: torch.Tensor = None, active_masks: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(values, action_log_probs, dist_entropy)`` for stored transitions."""
        action_log_probs, dist_entropy = self.actor.evaluate_actions(obs, action, available_actions, active_masks)
        return self.critic(cent_obs), action_log_probs, dist_entropy

    def act(self, obs: torch.Tensor, available_actions: torch.Tensor = None, deterministic: bool = False) -> torch.Tensor:
        """Return actions only, discarding the log-probabilities."""
        return self.actor(obs, available_actions, deterministic)[0]

# --------------- #
# Reward Model    #
# --------------- #

class IPLRewardModel(nn.Module):
    """
    IPL (Inverse Preference Learning) Reward Model
    Based on the approach from "Inverse Preference Learning: Preference-based RL without a Reward Function"

    The model holds an observation encoder, two Q-heads, a value head and
    frozen target copies of the encoder and Q-heads. Rewards are derived
    from them via the Bellman equation:
    R(s,a) = Q(s,a) - γ * V(s')

    Args:
        args: configuration; reads ``gamma``, ``hidden_size``, ``reward_lr``
            and ``ipl_tau`` (soft-update rate, 0.005 if the attribute is absent).
        obs_space: observation space; its first shape entry is the input width.
        action_space: action space; its first shape entry is the action width.
        device: device the whole module is moved to.
    """

    def __init__(self, args: Args, obs_space: Space, action_space: Space, device: torch.device = torch.device("cpu")):
        super(IPLRewardModel, self).__init__()
        self.device = device
        self.gamma = args.gamma
        # Follow the configured soft-update rate instead of a duplicated
        # constant; the fallback keeps arg objects without the field working.
        self.tau = getattr(args, "ipl_tau", 0.005)
        self.chi2_coeff = 0.5  # Chi-squared regularization coefficient

        obs_shape = get_shape_from_obs_space(obs_space)
        act_shape = get_shape_from_act_space(action_space)
        obs_dim, act_dim, hidden = obs_shape[0], act_shape[0], args.hidden_size

        # Construction order (online nets first, then targets) is kept fixed
        # so that seeded weight initialisation stays reproducible.
        self.encoder = self._build_encoder(obs_dim, hidden)
        # Ensemble of 2 Q-functions for stability.
        self.q_net1 = self._build_head(hidden + act_dim, hidden)
        self.q_net2 = self._build_head(hidden + act_dim, hidden)
        # Value network for computing V(s) from encoded observations.
        self.value_net = self._build_head(hidden, hidden)

        self.target_encoder = self._build_encoder(obs_dim, hidden)
        self.target_q_net1 = self._build_head(hidden + act_dim, hidden)
        self.target_q_net2 = self._build_head(hidden + act_dim, hidden)

        # Targets start as exact copies of the online nets and are never
        # trained directly.
        for online, target in self._target_pairs():
            target.load_state_dict(online.state_dict())
            for param in target.parameters():
                param.requires_grad = False

        # Optimizers
        self.q_optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.q_net1.parameters()) + list(self.q_net2.parameters()),
            lr=args.reward_lr
        )
        self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=args.reward_lr)

        # Loss functions
        self.reward_criterion = nn.BCEWithLogitsLoss(reduction='none')
        self.expectile = 0.7  # For IQL-style value learning

        # Running reward statistics (maintained by update_reward_stats).
        self.reward_mean = nn.Parameter(torch.tensor(0.0, device=device), requires_grad=False)
        self.reward_std = nn.Parameter(torch.tensor(1.0, device=device), requires_grad=False)
        self.reward_count = nn.Parameter(torch.tensor(0.0, device=device), requires_grad=False)
        self.reward_momentum = 0.99  # Momentum for running statistics

        self.to(device)

    @staticmethod
    def _build_encoder(obs_dim: int, hidden: int) -> nn.Sequential:
        """Build the ``Linear -> ReLU -> LayerNorm`` observation encoder."""
        return nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.LayerNorm(hidden)
        )

    @staticmethod
    def _build_head(input_dim: int, hidden: int) -> nn.Sequential:
        """Build a three-layer ReLU MLP with a scalar output (Q or V head)."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )

    def _target_pairs(self):
        """Return ``(online, target)`` module pairs that are kept in sync."""
        return (
            (self.encoder, self.target_encoder),
            (self.q_net1, self.target_q_net1),
            (self.q_net2, self.target_q_net2),
        )

    def _iql_loss(self, v_values: torch.Tensor, target_q: torch.Tensor, expectile: float) -> torch.Tensor:
        """IQL expectile loss for value function"""
        diff = target_q - v_values
        # Under-estimates (diff > 0) are weighted by ``expectile``,
        # over-estimates by ``1 - expectile``.
        weight = torch.where(diff > 0, expectile, 1 - expectile)
        return (weight * (diff ** 2)).mean()

    def encode_obs(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode observations with the online encoder."""
        return self.encoder(obs.float())

    def encode_obs_target(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode observations with the target encoder (no gradients)."""
        with torch.no_grad():
            return self.target_encoder(obs.float())

    def compute_q_values(self, obs_encoded: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Compute Q-values of both online heads, stacked along a new dim 0."""
        actions = actions.float()
        q_input = torch.cat([obs_encoded, actions], dim=-1)
        q1 = self.q_net1(q_input)
        q2 = self.q_net2(q_input)
        return torch.stack([q1.squeeze(-1), q2.squeeze(-1)], dim=0)  # Shape: (2, batch_size)

    def compute_target_q_values(self, obs_encoded: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Compute Q-values of both target heads (no gradients), stacked along dim 0."""
        with torch.no_grad():
            actions = actions.float()
            q_input = torch.cat([obs_encoded, actions], dim=-1)
            q1 = self.target_q_net1(q_input)
            q2 = self.target_q_net2(q_input)
            return torch.stack([q1.squeeze(-1), q2.squeeze(-1)], dim=0)  # Shape: (2, batch_size)

    def compute_value(self, obs_encoded: torch.Tensor) -> torch.Tensor:
        """Compute state values from encoded observations (trailing dim squeezed)."""
        return self.value_net(obs_encoded).squeeze(-1)

    def predict_rewards(self, obs: torch.Tensor, actions: torch.Tensor, next_obs: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
        """
        Predict rewards using the IPL approach: R(s,a) = Q(s,a) - γ * (1 - done) * V(s')

        ``obs`` is encoded with the online encoder and ``next_obs`` with the
        target encoder. The raw reward is passed through ``normalize_rewards``
        (currently a ``tanh`` squash) and returned with a trailing unit
        dimension. Runs entirely without gradients.
        """
        with torch.no_grad():
            obs_encoded = self.encode_obs(obs)
            next_obs_encoded = self.encode_obs_target(next_obs)

            # Minimum over the two Q-heads for a conservative estimate.
            q_values = self.compute_q_values(obs_encoded, actions)
            q_min = torch.min(q_values, dim=0)[0]

            next_values = self.compute_value(next_obs_encoded)

            # Terminal transitions do not bootstrap from the next state.
            dones = dones.float()
            raw_rewards = q_min - self.gamma * (1 - dones) * next_values

            # NOTE: the running reward statistics are intentionally not
            # updated here; see update_reward_stats.
            normalized_rewards = self.normalize_rewards(raw_rewards)

            return normalized_rewards.unsqueeze(-1)  # Shape: (batch_size, 1)

    def update_reward_stats(self, rewards: torch.Tensor):
        """Update the running mean/std of rewards with momentum.

        Empty batches are ignored. For a single-element batch the spread is
        taken as 0, because ``Tensor.std()`` would return NaN and poison the
        running statistic.
        """
        with torch.no_grad():
            count = rewards.numel()
            if count == 0:
                return

            batch_mean = rewards.mean()
            spread = rewards.std() if count > 1 else torch.zeros_like(batch_mean)
            batch_std = spread + 1e-8

            # Update running statistics with momentum
            self.reward_mean.data = self.reward_momentum * self.reward_mean.data + (1 - self.reward_momentum) * batch_mean
            self.reward_std.data = self.reward_momentum * self.reward_std.data + (1 - self.reward_momentum) * batch_std
            self.reward_count.data += count

    def normalize_rewards(self, rewards: torch.Tensor) -> torch.Tensor:
        """Squash rewards into (-1, 1) with ``tanh``.

        The running mean/std are not applied here; only the ``tanh`` bound
        is active.
        """
        with torch.no_grad():
            return torch.tanh(rewards)

    def update_target_networks(self):
        """Soft-update every target network: ``target <- tau * online + (1 - tau) * target``."""
        for online, target in self._target_pairs():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
# Backward-compatible alias: code that still refers to `RewardModel`
# gets the IPL implementation.
RewardModel = IPLRewardModel

class IPLPreferenceBuffer:
    """Ring buffer of fixed-length trajectory segments used to build preference pairs.

    Segments are stored in pre-allocated NumPy arrays; once ``buffer_size``
    segments have been added, the oldest are overwritten. Preference labels
    are synthesised by comparing the summed stored rewards of two segments.

    Args:
        args: configuration; reads ``preference_buffer_size``,
            ``trajectory_length`` and ``gamma`` (0.99 if the attribute is absent).
        obs_space: observation space; its first shape entry is the stored width.
        action_space: action space; its first shape entry is the stored width.
    """

    def __init__(self, args: Args, obs_space: Space, action_space: Space):
        self.buffer_size = args.preference_buffer_size
        self.trajectory_length = args.trajectory_length
        self.segment_size = args.trajectory_length  # IPL typically uses full trajectory segments
        # Discount reported with every pair. Taken from the configured gamma
        # so it cannot drift from the rest of the pipeline (was a literal 0.99).
        self.discount = getattr(args, "gamma", 0.99)

        obs_shape = get_shape_from_obs_space(obs_space)
        act_shape = get_shape_from_act_space(action_space)
        capacity, horizon = self.buffer_size, self.trajectory_length

        # Observations hold one extra step so the final next_obs is available.
        self.obs_buffer = np.zeros((capacity, horizon + 1, obs_shape[0]), dtype=np.float32)
        self.action_buffer = np.zeros((capacity, horizon, act_shape[0]), dtype=np.float32)
        self.reward_buffer = np.zeros((capacity, horizon, 1), dtype=np.float32)
        self.done_buffer = np.zeros((capacity, horizon, 1), dtype=np.float32)

        # Preference dataset storage (filled by add_preference_pairs).
        self.preference_obs_1 = []
        self.preference_obs_2 = []
        self.preference_action_1 = []
        self.preference_action_2 = []
        self.preference_labels = []

        self.ptr = 0
        self.size = 0

    def add_trajectory(self, obs_segment: np.ndarray, action_segment: np.ndarray, reward_segment: np.ndarray, done_segment: np.ndarray = None):
        """Store one trajectory segment, overwriting the oldest when full.

        ``obs_segment`` may have ``trajectory_length + 1`` rows, or exactly
        ``trajectory_length`` rows, in which case the last observation is
        repeated as the final next_obs. Segments whose observation, action
        or reward length does not fit are skipped without error. A missing
        or wrongly sized ``done_segment`` is stored as all zeros (not done).
        """
        horizon = self.trajectory_length

        if obs_segment.shape[0] == horizon:
            # Pad with the last observation for next_obs
            obs_segment = np.concatenate([obs_segment, obs_segment[-1:]], axis=0)
        elif obs_segment.shape[0] != horizon + 1:
            return  # Skip if we can't handle the trajectory

        if action_segment.shape[0] != horizon or reward_segment.shape[0] != horizon:
            return  # Skip if action or reward length doesn't match

        slot = self.ptr
        self.obs_buffer[slot] = obs_segment
        self.action_buffer[slot] = action_segment
        self.reward_buffer[slot] = reward_segment

        if done_segment is not None and done_segment.shape[0] == horizon:
            self.done_buffer[slot] = done_segment
        else:
            # Clear whatever an overwritten segment left in this slot.
            self.done_buffer[slot] = 0.0

        self.ptr = (self.ptr + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def generate_preference_pairs(self, num_pairs: int):
        """Build ``num_pairs`` labelled pairs of distinct stored segments.

        Returns:
            A list of dicts with keys ``obs_1``, ``obs_2``, ``action_1``,
            ``action_2``, ``label``, ``discount``, ``traj1_return`` and
            ``traj2_return``; empty when fewer than two segments are stored.
            ``label`` is 1.0 when segment 2 has the strictly larger summed
            reward and 0.0 otherwise (ties count as segment 1 preferred).
        """
        if self.size < 2:
            return []

        pairs = []
        for _ in range(num_pairs):
            # Sample two different trajectories
            idx1, idx2 = np.random.choice(self.size, 2, replace=False)

            # Trajectory returns are the sums of the stored rewards.
            traj1_return = np.sum(self.reward_buffer[idx1])
            traj2_return = np.sum(self.reward_buffer[idx2])

            pairs.append({
                'obs_1': self.obs_buffer[idx1].copy(),        # Shape: (T+1, obs_dim)
                'obs_2': self.obs_buffer[idx2].copy(),        # Shape: (T+1, obs_dim)
                'action_1': self.action_buffer[idx1].copy(),  # Shape: (T, act_dim)
                'action_2': self.action_buffer[idx2].copy(),  # Shape: (T, act_dim)
                'label': 1.0 if traj2_return > traj1_return else 0.0,
                'discount': self.discount,
                'traj1_return': traj1_return,
                'traj2_return': traj2_return
            })

        return pairs

    def add_preference_pairs(self, pairs):
        """Append the given pairs to the persistent preference dataset lists."""
        for pair in pairs:
            self.preference_obs_1.append(pair['obs_1'])
            self.preference_obs_2.append(pair['obs_2'])
            self.preference_action_1.append(pair['action_1'])
            self.preference_action_2.append(pair['action_2'])
            self.preference_labels.append(pair['label'])

    def sample_preference_batch(self, batch_size: int):
        """Return a freshly generated batch of preference pairs as stacked arrays.

        Pairs are always generated from the current trajectory buffer; the
        lists filled by ``add_preference_pairs`` are not consulted.

        Returns:
            A dict of arrays keyed ``obs_1``, ``obs_2``, ``action_1``,
            ``action_2``, ``label`` and ``discount``, or ``None`` when no
            pair could be generated.
        """
        new_pairs = self.generate_preference_pairs(batch_size)
        if not new_pairs:
            return None

        batch_keys = ('obs_1', 'obs_2', 'action_1', 'action_2', 'label', 'discount')
        return {key: np.array([pair[key] for pair in new_pairs]) for key in batch_keys}

# Backward-compatible alias: code that still refers to `PreferenceBuffer`
# gets the IPL implementation.
PreferenceBuffer = IPLPreferenceBuffer

# ------------ #
# ReplayBuffer #
# ------------ #

class SeparatedReplayBuffer:
    """Rollout storage for one agent, laid out as (step, rollout thread, feature).

    Arrays with ``episode_length + 1`` entries along the first axis hold one
    more slot than there are transitions; ``after_update`` carries that last
    slot over to index 0 so the next rollout continues from it.
    """

    def __init__(self, args: Args, obs_space, share_obs_space, act_space, device="cuda"):
        """Allocate the storage arrays.

        Args:
            args: Run configuration. The attributes read here are
                ``episode_length``, ``n_rollout_threads``, ``gamma``,
                ``gae_lambda``, ``use_gae``, ``use_popart`` and
                ``use_proper_time_limits``.
            obs_space: Space that fixes the trailing shape of ``obs``.
            share_obs_space: Space that fixes the trailing shape of ``share_obs``.
            act_space: Space that fixes the trailing shape of ``actions`` and
                ``action_log_probs``.
            device: Device the value predictions are moved to before they are
                handed to the value normalizer in ``compute_returns``.
        """
        self.episode_length = args.episode_length
        self.n_rollout_threads = args.n_rollout_threads
        self.gamma = args.gamma
        self.gae_lambda = args.gae_lambda
        self._use_gae = args.use_gae
        self._use_popart = args.use_popart
        self._use_proper_time_limits = args.use_proper_time_limits
        self.device = device

        obs_shape = get_shape_from_obs_space(obs_space)
        share_obs_shape = get_shape_from_obs_space(share_obs_space)
        act_shape = get_shape_from_act_space(act_space)

        self.share_obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, *share_obs_shape), dtype=np.float32)
        self.obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, *obs_shape), dtype=np.float32)
        self.value_preds = np.zeros((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        self.returns = np.zeros((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        self.actions = np.zeros((self.episode_length, self.n_rollout_threads, *act_shape), dtype=np.float32)
        self.action_log_probs = np.zeros((self.episode_length, self.n_rollout_threads, *act_shape), dtype=np.float32)
        self.rewards = np.zeros((self.episode_length, self.n_rollout_threads, 1), dtype=np.float32)
        self.masks = np.ones((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)
        self.bad_masks = np.ones_like(self.masks)
        self.active_masks = np.ones_like(self.masks)
        # Optional per-sample weight set through update_factor().
        self.factor = None
        # Index of the next transition to be written by insert().
        self.step = 0

    def update_factor(self, factor):
        """Store a copy of ``factor`` to be yielded alongside each mini-batch."""
        self.factor = factor.copy()

    def insert(self, share_obs, obs, actions, action_log_probs, value_preds, rewards, masks, bad_masks=None, active_masks=None):
        """Write one transition at the current step and advance the step counter.

        Observations and masks describe the state *after* the transition and are
        therefore written at ``step + 1``; actions, log-probs, value predictions
        and rewards are written at ``step``. ``bad_masks`` and ``active_masks``
        keep their previous contents when passed as ``None``.
        """
        self.share_obs[self.step + 1] = share_obs.copy()
        self.obs[self.step + 1] = obs.copy()
        self.actions[self.step] = actions.copy()
        self.action_log_probs[self.step] = action_log_probs.copy()
        self.value_preds[self.step] = value_preds.copy()
        self.rewards[self.step] = rewards.copy()
        self.masks[self.step + 1] = masks.copy()
        if bad_masks is not None:
            self.bad_masks[self.step + 1] = bad_masks.copy()
        if active_masks is not None:
            self.active_masks[self.step + 1] = active_masks.copy()
        # Wrap around so the buffer is overwritten by the next rollout.
        self.step = (self.step + 1) % self.episode_length

    def after_update(self):
        """Copy the last slot of the observation and mask arrays to slot 0."""
        self.share_obs[0] = self.share_obs[-1].copy()
        self.obs[0] = self.obs[-1].copy()
        self.masks[0] = self.masks[-1].copy()
        self.bad_masks[0] = self.bad_masks[-1].copy()
        self.active_masks[0] = self.active_masks[-1].copy()

    def compute_returns(self, next_value, value_normalizer=None):
        """Fill ``self.returns`` by walking the stored rewards backwards.

        Without GAE the return is the discounted reward sum bootstrapped from
        ``next_value``. With GAE the advantage estimate is accumulated and the
        value prediction is added back; when ``use_proper_time_limits`` is set
        the accumulated advantage is additionally multiplied by ``bad_masks``.

        Args:
            next_value: Value estimate for the slot after the last transition.
            value_normalizer: Object with a ``denormalize`` method; used only
                when both GAE and popart are enabled.
        """
        num_steps = self.rewards.shape[0]

        if not self._use_gae:
            # Both time-limit settings use the same discounted-sum recursion here.
            self.returns[-1] = next_value
            for step in reversed(range(num_steps)):
                self.returns[step] = self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step]
            return

        self.value_preds[-1] = next_value
        if self._use_popart:
            # value_preds is not modified inside the loop, so the tensor
            # conversion is done once instead of once per step.
            value_preds_tensor = torch.from_numpy(self.value_preds).to(self.device)

            def value_at(index):
                return value_normalizer.denormalize(value_preds_tensor[index])
        else:
            def value_at(index):
                return self.value_preds[index]

        gae = 0
        for step in reversed(range(num_steps)):
            next_step_value = value_at(step + 1)
            # NOTE(review): the value at `step` is computed once and reused for
            # both the TD error and the return; this assumes denormalize() has
            # no side effects - confirm against the normalizer implementation.
            step_value = value_at(step)
            delta = self.rewards[step] + self.gamma * next_step_value * self.masks[step + 1] - step_value
            gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae
            if self._use_proper_time_limits:
                gae = gae * self.bad_masks[step + 1]
            self.returns[step] = gae + step_value

    def feed_forward_generator(self, advantages, num_mini_batch=None, mini_batch_size=None):
        """Yield shuffled mini-batches over all (step, thread) samples.

        Args:
            advantages: Array that is reshaped to ``(-1, 1)`` and sliced with
                the same indices as the stored data.
            num_mini_batch: Number of mini-batches to yield. Derived from
                ``mini_batch_size`` when omitted.
            mini_batch_size: Samples per mini-batch. Derived from
                ``num_mini_batch`` when omitted.

        Yields:
            Tuple ``(share_obs, obs, actions, value_preds, returns, masks,
            active_masks, old_action_log_probs, advantages, factor)`` for each
            mini-batch; ``factor`` is ``None`` if no factor has been stored.

        Raises:
            ValueError: If neither ``num_mini_batch`` nor ``mini_batch_size``
                is given.
        """
        episode_length, n_rollout_threads = self.rewards.shape[0:2]
        batch_size = n_rollout_threads * episode_length
        if mini_batch_size is None:
            if num_mini_batch is None:
                raise ValueError("feed_forward_generator requires num_mini_batch or mini_batch_size.")
            assert batch_size >= num_mini_batch, "PPO requires the number of processes * number of steps to be greater than or equal to the number of PPO mini batches."
            mini_batch_size = batch_size // num_mini_batch
        elif num_mini_batch is None:
            # Passing only mini_batch_size used to fail on range(None) below.
            num_mini_batch = batch_size // mini_batch_size

        rand = torch.randperm(batch_size).numpy()
        sampler = [rand[i * mini_batch_size:(i + 1) * mini_batch_size] for i in range(num_mini_batch)]

        # Flatten the (step, thread) axes; the extra trailing slot is dropped.
        share_obs = self.share_obs[:-1].reshape(-1, *self.share_obs.shape[2:])
        obs = self.obs[:-1].reshape(-1, *self.obs.shape[2:])
        actions = self.actions.reshape(-1, self.actions.shape[-1])
        value_preds = self.value_preds[:-1].reshape(-1, 1)
        returns = self.returns[:-1].reshape(-1, 1)
        masks = self.masks[:-1].reshape(-1, 1)
        active_masks = self.active_masks[:-1].reshape(-1, 1)
        action_log_probs = self.action_log_probs.reshape(-1, self.action_log_probs.shape[-1])
        advantages = advantages.reshape(-1, 1)
        factor = self.factor.reshape(-1, self.factor.shape[-1]) if self.factor is not None else None

        for indices in sampler:
            factor_batch = factor[indices] if factor is not None else None
            yield (
                share_obs[indices],
                obs[indices],
                actions[indices],
                value_preds[indices],
                returns[indices],
                masks[indices],
                active_masks[indices],
                action_log_probs[indices],
                advantages[indices],
                factor_batch,
            )

# ------- #
# Trainer #
# ------- #

class HAPPO:
    """Trainer that runs clipped-surrogate policy updates and value updates for one policy."""

    def __init__(self, args: Args, policy: HAPPO_Policy, device: torch.device = torch.device("cpu")):
        """Store the policy and the hyper-parameters read from ``args``.

        Args:
            args: Run configuration providing the clipping, epoch, loss and
                gradient-norm settings read below.
            policy: Policy whose actor and critic (and their optimizers) are trained.
            device: Device that sampled batches are moved to.
        """
        self.device = device
        self.policy = policy
        self.clip_param = args.clip_param
        self.ppo_epoch = args.ppo_epoch
        self.num_mini_batch = args.num_mini_batch
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm
        self.huber_delta = args.huber_delta
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_popart = args.use_popart
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks

        self.value_normalizer = PopArt((1,), device=self.device) if self._use_popart else None

    def cal_value_loss(self, values: torch.Tensor, value_preds_batch: torch.Tensor, return_batch: torch.Tensor, active_masks_batch: torch.Tensor) -> torch.Tensor:
        """Compute the (optionally clipped, optionally masked) value loss.

        Args:
            values: Current value predictions.
            value_preds_batch: Value predictions stored at rollout time.
            return_batch: Return targets.
            active_masks_batch: Weights used for the masked average when
                ``use_value_active_masks`` is enabled.

        Returns:
            Scalar value loss.
        """
        # Limit how far the new prediction may move from the stored one.
        value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
        if self._use_popart:
            # NOTE(review): the normalizer is deliberately still invoked twice;
            # its call may update running statistics - confirm before merging.
            error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
            error_original = self.value_normalizer(return_batch) - values
        else:
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        if self._use_huber_loss:
            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
            value_loss_original = huber_loss(error_original, self.huber_delta)
        else:
            value_loss_clipped = mse_loss(error_clipped)
            value_loss_original = mse_loss(error_original)

        if self._use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        if self._use_value_active_masks:
            return (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
        return value_loss.mean()

    def ppo_update(self, sample: Tuple, update_actor: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one actor step and one critic step on a single mini-batch.

        Args:
            sample: The 10-tuple yielded by the buffer's mini-batch generator.
            update_actor: When ``False`` the policy loss is not back-propagated
                and the actor optimizer is not stepped.

        Returns:
            ``(value_loss, critic_grad_norm, policy_loss, dist_entropy,
            actor_grad_norm, imp_weights)``.
        """
        share_obs_batch, obs_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, factor_batch = sample

        old_action_log_probs_batch = torch.from_numpy(old_action_log_probs_batch).to(self.device)
        adv_targ = torch.from_numpy(adv_targ).to(self.device)
        value_preds_batch = torch.from_numpy(value_preds_batch).to(self.device)
        return_batch = torch.from_numpy(return_batch).to(self.device)
        active_masks_batch = torch.from_numpy(active_masks_batch).to(self.device)
        if factor_batch is not None:
            factor_batch = torch.from_numpy(factor_batch).to(self.device)

        values, action_log_probs, dist_entropy = self.policy.evaluate_actions(
            torch.from_numpy(share_obs_batch).to(self.device),
            torch.from_numpy(obs_batch).to(self.device),
            torch.from_numpy(actions_batch).to(self.device),
            active_masks=active_masks_batch
        )

        # Joint probability ratio over all action dimensions.
        imp_weights = torch.prod(torch.exp(action_log_probs - old_action_log_probs_batch), dim=-1, keepdim=True)
        surr1 = imp_weights * adv_targ
        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ

        clipped_surrogate = torch.min(surr1, surr2)
        if factor_batch is not None:
            # A buffer without a stored factor yields None; treat it as weight 1
            # instead of failing on `None * tensor`.
            clipped_surrogate = factor_batch * clipped_surrogate
        per_sample_loss = -torch.sum(clipped_surrogate, dim=-1, keepdim=True)
        if self._use_policy_active_masks:
            policy_action_loss = (per_sample_loss * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            policy_action_loss = per_sample_loss.mean()
        policy_loss = policy_action_loss

        self.policy.actor_optimizer.zero_grad()
        if update_actor:
            (policy_loss - dist_entropy * self.entropy_coef).backward()

        if self._use_max_grad_norm:
            actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
        else:
            actor_grad_norm = get_gard_norm(self.policy.actor.parameters())
        if update_actor:
            # Stepping with cleared gradients could still change parameters
            # (optimizer state, weight decay), so skip it when frozen.
            self.policy.actor_optimizer.step()

        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)
        self.policy.critic_optimizer.zero_grad()
        (value_loss * self.value_loss_coef).backward()
        if self._use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
        else:
            critic_grad_norm = get_gard_norm(self.policy.critic.parameters())
        self.policy.critic_optimizer.step()

        return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights

    def train(self, buffer: SeparatedReplayBuffer, update_actor: bool = True) -> dict:
        """Run ``ppo_epoch`` passes of mini-batch updates over ``buffer``.

        Advantages are ``returns - value predictions`` and are standardised
        with statistics computed over samples whose active mask is non-zero.

        Args:
            buffer: Rollout storage to train on.
            update_actor: Forwarded to ``ppo_update``.

        Returns:
            Dict of plain-float metrics averaged over all updates, with the keys
            ``value_loss``, ``policy_loss``, ``dist_entropy``,
            ``actor_grad_norm``, ``critic_grad_norm`` and ``ratio``.
        """
        if self._use_popart:
            baseline = self.value_normalizer.denormalize(torch.from_numpy(buffer.value_preds[:-1]).to(self.device))
        else:
            baseline = buffer.value_preds[:-1]
        advantages = buffer.returns[:-1] - baseline

        # Inactive samples are replaced by NaN so the nan-aware statistics skip them.
        advantages_copy = advantages.copy()
        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
        mean_advantages = np.nanmean(advantages_copy)
        std_advantages = np.nanstd(advantages_copy)
        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)

        train_info = {'value_loss': 0, 'policy_loss': 0, 'dist_entropy': 0, 'actor_grad_norm': 0, 'critic_grad_norm': 0, 'ratio': 0}

        for _ in range(self.ppo_epoch):
            data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)
            for sample in data_generator:
                value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights = self.ppo_update(sample, update_actor)
                # Accumulate Python floats so no autograd graph or device
                # tensor is kept alive inside the metrics dict.
                train_info['value_loss'] += value_loss.item()
                train_info['policy_loss'] += policy_loss.item()
                train_info['dist_entropy'] += dist_entropy.item()
                train_info['actor_grad_norm'] += float(actor_grad_norm)
                train_info['critic_grad_norm'] += float(critic_grad_norm)
                train_info['ratio'] += imp_weights.mean().item()

        num_updates = self.ppo_epoch * self.num_mini_batch
        for k in train_info:
            train_info[k] /= num_updates
        return train_info

    def prep_training(self):
        """Put the actor and critic into training mode."""
        self.policy.actor.train()
        self.policy.critic.train()

    def prep_rollout(self):
        """Put the actor and critic into evaluation mode."""
        self.policy.actor.eval()
        self.policy.critic.eval()

# ------ #
# Runner #
# ------ #

class MujocoRunner:
    def __init__(self, config: dict):
        """Build per-agent policies, trainers, buffers and optional reward models.

        Args:
            config: Dict with the keys ``'all_args'``, ``'envs'``,
                ``'eval_envs'``, ``'device'`` and ``'num_agents'``.
        """
        pprint(config)
        self.all_args: Args = config['all_args']
        self.envs: ShareSubprocVecEnv = config['envs']
        self.eval_envs: ShareSubprocVecEnv = config['eval_envs']
        self.device = config['device']
        self.num_agents = config['num_agents']

        # parameters
        self.env_name = self.all_args.env_name
        self.algorithm_name = self.all_args.algorithm_name
        self.experiment_name = self.all_args.experiment_name
        self.use_centralized_V = self.all_args.use_centralized_V
        self.num_env_steps = self.all_args.num_env_steps
        self.episode_length = self.all_args.episode_length
        self.n_rollout_threads = self.all_args.n_rollout_threads
        self.use_linear_lr_decay = self.all_args.use_linear_lr_decay
        self.save_interval = self.all_args.save_interval
        self.use_eval = self.all_args.use_eval
        self.eval_interval = self.all_args.eval_interval
        self.log_interval = self.all_args.log_interval
        self.model_dir = self.all_args.model_dir

        # Logs go to logs/<scenario>, checkpoints to logs/<scenario>/models.
        self.log_dir = f"logs/{self.all_args.scenario}"
        os.makedirs(self.log_dir, exist_ok=True)
        self.writter = SummaryWriter(self.log_dir)
        self.save_dir = f"{self.log_dir}/models"
        os.makedirs(self.save_dir, exist_ok=True)

        # policy
        self.policy: list[HAPPO_Policy] = []
        for agent_id in range(self.num_agents):
            # Without a centralized critic the agent's own observation space
            # doubles as the shared one.
            share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]
            po = HAPPO_Policy(self.all_args, self.envs.observation_space[agent_id], share_observation_space, self.envs.action_space[agent_id], device=self.device)
            self.policy.append(po)

        self.trainer: list[HAPPO] = []
        self.buffer: list[SeparatedReplayBuffer] = []
        for agent_id in range(self.num_agents):
            tr = HAPPO(self.all_args, self.policy[agent_id], device=self.device)
            share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]
            bu = SeparatedReplayBuffer(self.all_args, self.envs.observation_space[agent_id], share_observation_space, self.envs.action_space[agent_id], device=self.device)
            self.buffer.append(bu)
            self.trainer.append(tr)

        # Initialize IPL reward models and preference buffers if using reward model
        if self.all_args.use_reward_model:
            self.reward_models: list[IPLRewardModel] = []
            self.preference_buffers: list[IPLPreferenceBuffer] = []
            self.episode_trajectories: list = [[] for _ in range(self.num_agents)]  # Store current episode data

            for agent_id in range(self.num_agents):
                reward_model = IPLRewardModel(self.all_args, self.envs.observation_space[agent_id],
                                              self.envs.action_space[agent_id], device=self.device)
                preference_buffer = IPLPreferenceBuffer(self.all_args, self.envs.observation_space[agent_id],
                                                        self.envs.action_space[agent_id])
                self.reward_models.append(reward_model)
                self.preference_buffers.append(preference_buffer)

        # Restore last: restore() loads into self.policy and, when reward
        # models are enabled, into self.reward_models, so both must exist.
        if self.model_dir is not None:
            self.restore()

    def run(self):
        """Main training loop: collect a rollout, update the agents, log, save and evaluate.

        One loop iteration ("episode") collects ``episode_length`` steps from
        every rollout thread, computes returns and trains all agents. With
        ``use_reward_model`` enabled, the rewards stored for policy training are
        the reward models' predictions, while the environment rewards are kept
        in the episode trajectories for preference learning.
        """
        self.warmup()
        start = time.time()
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads
        # Running environment return of each rollout thread. Allocated as an
        # array so the in-place `+=` below is an element-wise add from the start.
        train_episode_rewards = np.zeros(self.n_rollout_threads, dtype=np.float64)

        for episode in range(episodes):
            if self.use_linear_lr_decay:
                for agent_id in range(self.num_agents):
                    self.trainer[agent_id].policy.lr_decay(episode, episodes)

            # Reset episode trajectories
            if self.all_args.use_reward_model:
                for agent_id in range(self.num_agents):
                    self.episode_trajectories[agent_id] = []

            done_episodes_rewards = []
            # Optional reward-model scalars to log; nothing fills it at present.
            predicted_rewards_info = {}

            for step in tqdm(range(self.episode_length), "Collecting ...", ncols=100, leave=False, disable=True):
                values, actions, action_log_probs = self.collect(step)
                obs, share_obs, rewards, dones, infos, _ = self.envs.step(actions)

                # Store ORIGINAL environment rewards in episode trajectories for preference learning
                if self.all_args.use_reward_model:
                    for agent_id in range(self.num_agents):
                        if len(self.episode_trajectories[agent_id]) > 0:
                            # collect() appended this step's entry; attach its reward.
                            self.episode_trajectories[agent_id][-1]['reward'] = rewards[:, agent_id].reshape(-1, 1).copy()

                # Episode statistics always use the environment rewards.
                dones_env = np.all(dones, axis=1)
                reward_env = np.mean(rewards, axis=1).flatten()
                train_episode_rewards += reward_env
                for t in range(self.n_rollout_threads):
                    if dones_env[t]:
                        done_episodes_rewards.append(train_episode_rewards[t])
                        train_episode_rewards[t] = 0

                if self.all_args.use_reward_model:
                    # The policy is trained on predicted rewards only (no
                    # blending with the environment reward).
                    rewards = self.get_predicted_rewards(obs, actions)

                data = obs, share_obs, rewards, dones, values, actions, action_log_probs
                self.insert(data)

            # Process episode trajectories for preference learning
            if self.all_args.use_reward_model:
                self.process_episode_trajectories()

            self.compute()
            train_infos = self.train()

            # Train reward model periodically
            if self.all_args.use_reward_model and episode % self.all_args.reward_update_freq == 0 and episode > 0:
                print(f"Training reward models at episode {episode}...")
                reward_train_infos = self.train_reward_models()
                if reward_train_infos:
                    for agent_id, info in enumerate(reward_train_infos):
                        if len(train_infos) > agent_id:
                            train_infos[agent_id].update(info)
                    print(f"Reward model training completed. Info: {reward_train_infos}")
                else:
                    print("No reward model training info returned")

            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads
            if episode % self.save_interval == 0 or episode == episodes - 1:
                self.save()

            if episode % self.log_interval == 0:
                end = time.time()
                print(f"\n Scenario {self.all_args.scenario} Algo {self.algorithm_name} Exp {self.experiment_name} updates {episode}/{episodes} episodes, total num timesteps {total_num_steps}/{self.num_env_steps}, FPS {int(total_num_steps / (end - start))}.\n")
                self.log_train(train_infos, total_num_steps)
                if len(done_episodes_rewards) > 0:
                    aver_episode_rewards = np.mean(done_episodes_rewards)
                    print("some episodes done, average rewards: ", aver_episode_rewards)
                    self.writter.add_scalar("train_episode_rewards/aver_rewards", aver_episode_rewards, total_num_steps)

                # Log reward model information
                if self.all_args.use_reward_model and predicted_rewards_info:
                    for k, v in predicted_rewards_info.items():
                        self.writter.add_scalar(f"reward_model/{k}", v, total_num_steps)

            if episode % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps)

    def warmup(self):
        """Reset the training environments and write the first observations into slot 0 of every buffer."""
        obs, share_obs, _ = self.envs.reset()
        # Without a centralized critic the local observation stands in for the shared one.
        if not self.use_centralized_V:
            share_obs = obs
        for agent_id in range(self.num_agents):
            agent_buffer = self.buffer[agent_id]
            agent_buffer.share_obs[0] = share_obs[:, agent_id].copy()
            agent_buffer.obs[0] = obs[:, agent_id].copy()

    def collect(self, step: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Query every agent's policy on the observations stored at ``step``.

        With ``use_reward_model`` enabled, the observation and chosen action of
        each agent are also appended to its episode trajectory.

        Returns:
            ``(values, actions, action_log_probs)``, each stacked over agents and
            then transposed so the agent axis is second.
        """
        value_list = []
        action_list = []
        log_prob_list = []
        for agent_id in range(self.num_agents):
            agent_buffer = self.buffer[agent_id]
            value, action, action_log_prob = self._collect(
                agent_id, agent_buffer.share_obs[step], agent_buffer.obs[step]
            )
            value_list.append(value)
            action_list.append(action)
            log_prob_list.append(action_log_prob)

            # Store trajectory data for preference learning
            if self.all_args.use_reward_model:
                self.episode_trajectories[agent_id].append({
                    'obs': self.buffer[agent_id].obs[step].copy(),
                    'action': action.copy(),
                })

        # Stacking puts the agent axis first; move it behind the leading axis.
        stacked_values = np.array(value_list)
        stacked_actions = np.array(action_list)
        stacked_log_probs = np.array(log_prob_list)
        return (
            stacked_values.transpose(1, 0, 2),
            stacked_actions.transpose(1, 0, 2),
            stacked_log_probs.transpose(1, 0, 2),
        )

    @torch.no_grad()
    def _collect(self, agent_id, share_obs, obs):
        """Run one agent's policy in rollout mode and return its outputs as NumPy arrays.

        Args:
            agent_id: Index of the agent whose trainer/policy is used.
            share_obs: NumPy array passed to the policy as shared observation.
            obs: NumPy array passed to the policy as local observation.

        Returns:
            ``(value, action, action_log_prob)`` as NumPy arrays.
        """
        self.trainer[agent_id].prep_rollout()
        policy_inputs = [torch.from_numpy(array).to(self.device) for array in (share_obs, obs)]
        value, action, action_log_prob = self.trainer[agent_id].policy.get_actions(*policy_inputs)
        return _t2n(value), _t2n(action), _t2n(action_log_prob)

    def insert(self, data):
        """Derive masks from the done flags and push one transition into every agent's buffer.

        Args:
            data: Tuple ``(obs, share_obs, rewards, dones, values, actions,
                action_log_probs)``.
        """
        obs, share_obs, rewards, dones, values, actions, action_log_probs = data
        env_finished = np.all(dones, axis=1) == True
        agent_finished = dones == True
        mask_shape = (self.n_rollout_threads, self.num_agents, 1)

        # masks: zero for every agent of a thread in which all agents are done.
        masks = np.ones(mask_shape, dtype=np.float32)
        masks[env_finished] = 0.0

        # active_masks: zero for an individually finished agent, but set back
        # to one for all agents of a thread in which everyone is done.
        active_masks = np.ones(mask_shape, dtype=np.float32)
        active_masks[agent_finished] = 0.0
        active_masks[env_finished] = 1.0

        if not self.use_centralized_V:
            share_obs = obs

        for agent_id in range(self.num_agents):
            self.buffer[agent_id].insert(
                share_obs[:, agent_id],
                obs[:, agent_id],
                actions[:, agent_id],
                action_log_probs[:, agent_id],
                values[:, agent_id],
                rewards[:, agent_id],
                masks[:, agent_id],
                active_masks=active_masks[:, agent_id],
            )

    @torch.no_grad()
    def compute(self):
        """Compute returns for every agent, bootstrapping from the critic's value of the last stored shared observation."""
        for agent_id in range(self.num_agents):
            agent_trainer = self.trainer[agent_id]
            agent_trainer.prep_rollout()
            agent_buffer = self.buffer[agent_id]
            last_share_obs = torch.from_numpy(agent_buffer.share_obs[-1]).to(self.device)
            bootstrap_value = _t2n(agent_trainer.policy.get_values(last_share_obs))
            agent_buffer.compute_returns(bootstrap_value, agent_trainer.value_normalizer)

    def train(self) -> List[dict]:
        """Update all agents one after another in a random order.

        Each agent trains with the current ``factor`` stored in its buffer.
        After its update, ``factor`` is multiplied by the ratio of the agent's
        new to old action probabilities on the stored rollout, so agents later
        in the order see the product of the earlier agents' ratios.

        Returns:
            List of per-agent training metrics where entry ``i`` belongs to
            agent ``i``, regardless of the order in which agents were updated.
        """
        # Indexed by agent id: callers look metrics up as train_infos[agent_id],
        # so appending in the shuffled update order would mislabel them.
        train_infos = [None] * self.num_agents
        factor = np.ones((self.episode_length, self.n_rollout_threads, 1), dtype=np.float32)
        for agent_id in torch.randperm(self.num_agents).tolist():
            agent_trainer = self.trainer[agent_id]
            agent_buffer = self.buffer[agent_id]
            agent_trainer.prep_training()
            agent_buffer.update_factor(factor)

            # Flatten the (step, thread) axes of the stored rollout.
            obs_tensor = torch.from_numpy(agent_buffer.obs[:-1].reshape(-1, *agent_buffer.obs.shape[2:])).to(self.device)
            actions_tensor = torch.from_numpy(agent_buffer.actions.reshape(-1, *agent_buffer.actions.shape[2:])).to(self.device)
            active_masks_tensor = torch.from_numpy(agent_buffer.active_masks[:-1].reshape(-1, *agent_buffer.active_masks.shape[2:])).to(self.device)

            # The log-probs are only used as numbers, so no autograd graph is needed.
            with torch.no_grad():
                old_actions_logprob, _ = agent_trainer.policy.actor.evaluate_actions(obs_tensor, actions_tensor, active_masks=active_masks_tensor)
            train_info = agent_trainer.train(agent_buffer)
            with torch.no_grad():
                new_actions_logprob, _ = agent_trainer.policy.actor.evaluate_actions(obs_tensor, actions_tensor, active_masks=active_masks_tensor)

            factor = factor * _t2n(torch.prod(torch.exp(new_actions_logprob - old_actions_logprob), dim=-1).reshape(self.episode_length, self.n_rollout_threads, 1))
            train_infos[agent_id] = train_info
            agent_buffer.after_update()
        return train_infos

    def save(self):
        """Write each agent's actor and critic weights (and reward model, if enabled) to ``save_dir``."""
        for agent_id in range(self.num_agents):
            agent_policy = self.trainer[agent_id].policy
            torch.save(agent_policy.actor.state_dict(), f"{self.save_dir}/actor_agent{agent_id}.pt")
            torch.save(agent_policy.critic.state_dict(), f"{self.save_dir}/critic_agent{agent_id}.pt")

            # Reward models are only present when the feature is enabled.
            if not self.all_args.use_reward_model:
                continue
            torch.save(self.reward_models[agent_id].state_dict(), f"{self.save_dir}/reward_model_agent{agent_id}.pt")

    def restore(self):
        """Load each agent's actor and critic weights from ``model_dir``.

        With ``use_reward_model`` enabled, a reward-model checkpoint is loaded
        as well if its file exists. All checkpoints are mapped onto
        ``self.device`` while loading, so files saved on a different device
        can still be read.
        """
        for agent_id in range(self.num_agents):
            policy_actor_state_dict = torch.load(str(self.model_dir) + f'/actor_agent{agent_id}.pt', map_location=self.device)
            self.policy[agent_id].actor.load_state_dict(policy_actor_state_dict)
            policy_critic_state_dict = torch.load(str(self.model_dir) + f'/critic_agent{agent_id}.pt', map_location=self.device)
            self.policy[agent_id].critic.load_state_dict(policy_critic_state_dict)

            # Restore reward model if using it and file exists
            if self.all_args.use_reward_model:
                reward_model_path = str(self.model_dir) + f'/reward_model_agent{agent_id}.pt'
                if os.path.exists(reward_model_path):
                    reward_model_state_dict = torch.load(reward_model_path, map_location=self.device)
                    self.reward_models[agent_id].load_state_dict(reward_model_state_dict)

    def log_train(self, train_infos, total_num_steps):
        """Add the mean buffered reward to each agent's metrics and write them all to the summary writer.

        Args:
            train_infos: Sequence of metric dicts indexed by agent id; each dict
                gains an ``"average_step_rewards"`` entry.
            total_num_steps: Step value passed to the writer.
        """
        first_agent_mean_reward = np.mean(self.buffer[0].rewards)
        print(f"average_step_rewards is {first_agent_mean_reward}.")
        for agent_id in range(self.num_agents):
            agent_metrics = train_infos[agent_id]
            agent_metrics["average_step_rewards"] = np.mean(self.buffer[agent_id].rewards)
            for metric_name, metric_value in agent_metrics.items():
                self.writter.add_scalar(f"agent{agent_id}/{metric_name}", metric_value, total_num_steps)

    @torch.no_grad()
    def eval(self, total_num_steps: int):
        """Run deterministic evaluation episodes and log their returns.

        Rewards are summed over agents and accumulated per evaluation thread
        until that thread reports all agents done, so every recorded value is
        the return of a whole episode rather than the reward of its final step.

        Args:
            total_num_steps: Step value used when logging the results.
        """
        eval_episode_rewards = []
        n_eval_threads = self.all_args.n_eval_rollout_threads
        # Return accumulated so far in the running episode of each thread.
        running_returns = np.zeros(n_eval_threads, dtype=np.float64)
        eval_obs, _, _ = self.eval_envs.reset()

        # Switching to rollout mode is needed once, not on every step.
        for agent_id in range(self.num_agents):
            self.trainer[agent_id].prep_rollout()

        p = tqdm(total=self.all_args.eval_episodes, desc="Evaluating...", ncols=100, leave=False, disable=True)
        while len(eval_episode_rewards) < self.all_args.eval_episodes:
            eval_actions = []
            for agent_id in range(self.num_agents):
                eval_obs_tensor = torch.from_numpy(eval_obs[:, agent_id]).to(self.device)
                action = self.trainer[agent_id].policy.act(eval_obs_tensor, deterministic=True)
                eval_actions.append(_t2n(action))
            eval_actions = np.array(eval_actions).transpose(1, 0, 2)

            eval_obs, _, eval_rewards, eval_dones, _, _ = self.eval_envs.step(eval_actions)
            eval_dones_env = np.all(eval_dones, axis=1)

            for i in range(n_eval_threads):
                running_returns[i] += np.sum(eval_rewards[i])
                if eval_dones_env[i]:
                    eval_episode_rewards.append(running_returns[i])
                    running_returns[i] = 0.0
                    p.update(1)
        p.close()

        eval_env_infos = {'eval_average_episode_rewards': np.mean(eval_episode_rewards), 'eval_max_episode_rewards': np.max(eval_episode_rewards)}
        self.log_env(eval_env_infos, total_num_steps)
        print(f"eval_average_episode_rewards is {np.mean(eval_episode_rewards)}.")

    def log_env(self, env_infos, total_num_steps):
        """Write environment statistics to the summary writer.

        List values are reduced to their mean before logging; any other value
        is logged as-is. Empty lists are skipped, because there is nothing to
        average and passing an empty list to ``add_scalar`` is not a valid
        scalar.

        Args:
            env_infos: Mapping from tag name to a scalar or a list of scalars.
            total_num_steps: Global step passed to ``add_scalar``.
        """
        for k, v in env_infos.items():
            if isinstance(v, list):
                if not v:
                    # No samples collected for this tag: nothing to log.
                    continue
                v = np.mean(v)
            self.writter.add_scalar(k, v, total_num_steps)

    def get_predicted_rewards(self, obs: np.ndarray, actions: np.ndarray) -> np.ndarray:
        """Query each agent's reward model for the rewards of the given step.

        Args:
            obs: Array indexed as ``obs[:, agent_id]``; must be at least 3-D
                because the disabled path slices ``obs[:, :, 0:1]``.
            actions: Array indexed as ``actions[:, agent_id]``.

        Returns:
            A float32 array of shape ``(obs.shape[0], obs.shape[1], 1)`` filled
            per agent with the model predictions, or an all-zero array shaped
            like ``obs[:, :, 0:1]`` when ``all_args.use_reward_model`` is off.
        """
        if not self.all_args.use_reward_model:
            # Reward model disabled: hand back zeros.
            return np.zeros_like(obs[:, :, 0:1])

        rewards_out = np.zeros((obs.shape[0], obs.shape[1], 1), dtype=np.float32)

        # The reward model is called with (obs, action, next_obs, done).
        for agent_id in range(self.num_agents):
            obs_t = torch.from_numpy(obs[:, agent_id]).float().to(self.device)
            act_t = torch.from_numpy(actions[:, agent_id]).float().to(self.device)

            # NOTE(review): the successor observation is read from the rollout
            # buffer at index step + 1; whether that slot already holds the true
            # next observation depends on when the caller invokes this relative
            # to the buffer insert — confirm.
            agent_buffer = self.buffer[agent_id]
            step_idx = agent_buffer.step
            if step_idx >= self.episode_length - 1:
                # Last slot of the rollout: fall back to the current observation.
                next_obs_t = obs_t
            else:
                next_obs_t = torch.from_numpy(agent_buffer.obs[step_idx + 1]).float().to(self.device)

            # Every transition is treated as non-terminal.
            not_done = torch.zeros(obs_t.shape[0], dtype=torch.float32, device=self.device)

            with torch.no_grad():
                model_out = self.reward_models[agent_id].predict_rewards(
                    obs_t, act_t, next_obs_t, not_done
                )
            rewards_out[:, agent_id] = _t2n(model_out)

        return rewards_out

    def process_episode_trajectories(self):
        """Cut the recorded episode into segments and store them for preference learning.

        For every agent, the step records in ``self.episode_trajectories`` are
        split into consecutive, non-overlapping windows of
        ``all_args.trajectory_length`` steps (a trailing remainder is dropped).
        Each window is then split per rollout thread and handed to that agent's
        preference buffer via ``add_trajectory``. Does nothing when
        ``all_args.use_reward_model`` is off.
        """
        args = self.all_args
        if not args.use_reward_model:
            return

        seg_len = args.trajectory_length
        added_count = 0

        for agent_id in range(self.num_agents):
            steps = self.episode_trajectories[agent_id]
            n_steps = len(steps)

            if n_steps < seg_len:
                print(f"Agent {agent_id}: Episode too short ({n_steps} < {seg_len})")
                continue

            for seg_idx in range(n_steps // seg_len):
                first = seg_idx * seg_len
                window = steps[first:first + seg_len]

                # Stack the per-step records along a new leading (time) axis.
                obs_stack, act_stack, rew_stack = (
                    np.array([record[key] for record in window])
                    for key in ('obs', 'action', 'reward')
                )

                for env_idx in range(self.n_rollout_threads):
                    env_obs = obs_stack[:, env_idx]
                    env_act = act_stack[:, env_idx]
                    env_rew = rew_stack[:, env_idx]

                    # Normalise rewards to a 2-D column layout.
                    if env_rew.ndim > 2:
                        env_rew = env_rew.squeeze()
                    if env_rew.ndim == 1:
                        env_rew = env_rew.reshape(-1, 1)

                    # The buffer wants one more observation than actions; the
                    # final observation is repeated to stand in for the successor.
                    padded_obs = np.concatenate((env_obs, env_obs[-1:]), axis=0)

                    # Segments are stored as if no step inside them terminated.
                    no_dones = np.zeros((seg_len, 1), dtype=np.float32)

                    self.preference_buffers[agent_id].add_trajectory(
                        padded_obs, env_act, env_rew, no_dones
                    )
                    added_count += 1

        print(f"Total trajectories added this episode: {added_count}")

    def train_reward_models(self) -> List[dict]:
        """Run one preference-learning update of every agent's IPL reward model.

        For each agent this samples ``all_args.num_preference_pairs`` pairs from
        the agent's preference buffer and then performs, in order: a value
        function update, a Q-function update (preference loss plus a squared
        reward penalty), a target-network update, and a reward-statistics
        update. Implicit rewards are formed as ``Q(s, a) - discount * V(s')``.

        Returns:
            One metrics dict per agent, or an empty list when
            ``all_args.use_reward_model`` is off. When no preference pairs are
            available for an agent, that agent's dict only holds zeroed
            ``q_loss``, ``chi2_loss``, ``value_loss`` and ``reward_accuracy``
            (the ``avg_reward`` / ``reward_*`` keys are absent).
        """
        if not self.all_args.use_reward_model:
            return []
        
        reward_train_infos = []
        
        for agent_id in range(self.num_agents):
            reward_model = self.reward_models[agent_id]
            preference_buffer = self.preference_buffers[agent_id]
            
            print(f"Agent {agent_id}: Buffer size = {preference_buffer.size}")
            
            # Generate fresh preference pairs for IPL training
            preference_pairs = preference_buffer.generate_preference_pairs(self.all_args.num_preference_pairs)
            
            if len(preference_pairs) == 0:
                # Nothing to train on: record placeholder metrics and move to the next agent.
                print(f"Agent {agent_id}: No preference pairs generated (buffer size: {preference_buffer.size})")
                reward_train_infos.append({'q_loss': 0.0, 'chi2_loss': 0.0, 'value_loss': 0.0, 'reward_accuracy': 0.0})
                continue
                
            print(f"Agent {agent_id}: Generated {len(preference_pairs)} preference pairs")
            
            # Convert the list of pair dicts to one stacked array per field
            preference_batch = {
                'obs_1': np.array([pair['obs_1'] for pair in preference_pairs]),
                'obs_2': np.array([pair['obs_2'] for pair in preference_pairs]),
                'action_1': np.array([pair['action_1'] for pair in preference_pairs]),
                'action_2': np.array([pair['action_2'] for pair in preference_pairs]),
                'label': np.array([pair['label'] for pair in preference_pairs]),
                'discount': np.array([pair['discount'] for pair in preference_pairs])
            }
                
            # Convert to float tensors on the training device
            obs_1 = torch.from_numpy(preference_batch['obs_1']).float().to(self.device)
            obs_2 = torch.from_numpy(preference_batch['obs_2']).float().to(self.device)
            action_1 = torch.from_numpy(preference_batch['action_1']).float().to(self.device)
            action_2 = torch.from_numpy(preference_batch['action_2']).float().to(self.device)
            labels = torch.from_numpy(preference_batch['label']).float().to(self.device)
            discount = torch.from_numpy(preference_batch['discount']).float().to(self.device)
            
            B, S = obs_1.shape[:2]  # Batch size, number of stored observations per segment
            S_minus_1 = S - 1  # Number of transitions: the last observation only serves as next_obs
            
            # Flatten for batch processing
            flat_obs_1 = obs_1[:, :-1].reshape(B * S_minus_1, -1)  # Current observations
            flat_obs_2 = obs_2[:, :-1].reshape(B * S_minus_1, -1)
            flat_next_obs_1 = obs_1[:, 1:].reshape(B * S_minus_1, -1)  # Next observations
            flat_next_obs_2 = obs_2[:, 1:].reshape(B * S_minus_1, -1)
            # assumes each action segment holds exactly S - 1 steps per pair — TODO confirm against the buffer
            flat_action_1 = action_1.reshape(B * S_minus_1, -1)
            flat_action_2 = action_2.reshape(B * S_minus_1, -1)
            
            # Concatenate all data for joint processing (segment-1 rows first, then segment-2 rows)
            all_obs = torch.cat([flat_obs_1, flat_obs_2], dim=0)
            all_next_obs = torch.cat([flat_next_obs_1, flat_next_obs_2], dim=0)
            all_actions = torch.cat([flat_action_1, flat_action_2], dim=0)
            
            # Encode observations (next observations go through the target encoder)
            obs_encoded = reward_model.encode_obs(all_obs)
            next_obs_encoded = reward_model.encode_obs_target(all_next_obs)
            
            # ============ Value Function Update ============
            with torch.no_grad():
                target_qs = reward_model.compute_target_q_values(obs_encoded, all_actions)
                target_q = torch.min(target_qs, dim=0)[0]  # Min over ensemble
            
            v_values = reward_model.compute_value(obs_encoded)
            value_loss = reward_model._iql_loss(v_values, target_q, reward_model.expectile)
            
            reward_model.value_optimizer.zero_grad()
            value_loss.backward()
            reward_model.value_optimizer.step()
            
            # ============ Q-Function Update ============
            # Re-encode observations to avoid gradient graph issues
            # (the first encoding was already used by value_loss.backward() above)
            obs_encoded_q = reward_model.encode_obs(all_obs)
            
            # Compute Q-values
            q_values = reward_model.compute_q_values(obs_encoded_q, all_actions)
            
            # Compute next state values for reward calculation
            with torch.no_grad():
                next_values = reward_model.compute_value(next_obs_encoded)
            
            # Compute rewards: R = Q(s,a) - γ * V(s')
            # Note: We assume no terminal states in trajectory segments for simplicity
            # NOTE(review): only the first entry of the discount batch is used for every row — confirm all pairs share one discount
            rewards = q_values - discount[0] * next_values.unsqueeze(0)  # Shape: (2, 2*B*(S-1)), as required by the views below
            
            # Split rewards back into trajectory segments
            # NOTE(review): the literal 2 in the views/expand below is presumably the Q-ensemble size — confirm against the reward model
            r1, r2 = torch.chunk(rewards, 2, dim=1)  # Each shape: (2, B*(S-1))
            r1 = r1.view(2, B, S_minus_1)  # Shape: (ensemble, batch, seq)
            r2 = r2.view(2, B, S_minus_1)
            
            # Sum over sequence dimension to get trajectory returns
            traj_return_1 = r1.sum(dim=2)  # Shape: (ensemble, batch)
            traj_return_2 = r2.sum(dim=2)  # Shape: (ensemble, batch)
            
            # Compute preference logits: preference for trajectory 2 over trajectory 1
            logits = traj_return_2 - traj_return_1  # Shape: (ensemble, batch)
            
            # Preference loss (average over ensemble)
            expanded_labels = labels.unsqueeze(0).expand(2, -1)  # Shape: (ensemble, batch)
            q_loss = reward_model.reward_criterion(logits, expanded_labels).mean()
            
            # Chi-squared regularization to prevent rewards from being too large
            chi2_loss = reward_model.chi2_coeff * (rewards ** 2).mean()
            
            # Total Q-function loss
            total_q_loss = q_loss + chi2_loss
            
            reward_model.q_optimizer.zero_grad()
            total_q_loss.backward()
            reward_model.q_optimizer.step()
            
            # ============ Target Network Update ============
            reward_model.update_target_networks()
            
            # ============ Compute Metrics ============
            # Metrics use the logits computed before the optimizer step above.
            with torch.no_grad():
                probs = torch.sigmoid(logits.mean(dim=0))  # Average over ensemble
                predicted_preferences = (probs > 0.5).float()
                accuracy = (predicted_preferences == labels).float().mean().item()
            
            # Update reward statistics for normalization
            # NOTE(review): the tensor is passed without detach(); verify the helper does not retain the autograd graph
            reward_model.update_reward_stats(rewards.mean(dim=0))  # Average over ensemble
            
            info = {
                'q_loss': q_loss.item(),
                'chi2_loss': chi2_loss.item(),
                'value_loss': value_loss.item(),
                'reward_accuracy': accuracy,
                'avg_reward': rewards.mean().item(),
                'reward_std': rewards.std().item(),
                'reward_norm_mean': reward_model.reward_mean.item(),
                'reward_norm_std': reward_model.reward_std.item()
            }
            
            reward_train_infos.append(info)
            print(f"Agent {agent_id}: Q-loss={q_loss.item():.4f}, Value-loss={value_loss.item():.4f}, Accuracy={accuracy:.3f}")
        
        return reward_train_infos

# ----------------- #
# Environment Setup #
# ----------------- #

def make_train_env(all_args):
    """Build the vectorised training environment.

    A single rollout thread runs in-process through ``ShareDummyVecEnv``; any
    other count uses ``ShareSubprocVecEnv`` with one environment factory per
    thread. Each factory seeds its environment with
    ``all_args.seed + rank * 1000``.

    Raises:
        NotImplementedError: Raised by a factory, when it is invoked, if
            ``all_args.env_name`` is not ``"mujoco"``.
    """
    def get_env_fn(rank):
        def init_env():
            # Only the multi-agent MuJoCo suite is wired up.
            if all_args.env_name != "mujoco":
                raise NotImplementedError(f"Can not support the {all_args.env_name} environment.")
            env = MujocoMulti(env_args={
                "scenario": all_args.scenario,
                "agent_conf": all_args.agent_conf,
                "agent_obsk": all_args.agent_obsk,
                "episode_limit": 1000,
            })
            # Distinct seed per rank.
            env.seed(all_args.seed + rank * 1000)
            return env
        return init_env

    if all_args.n_rollout_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    env_fns = [get_env_fn(rank) for rank in range(all_args.n_rollout_threads)]
    return ShareSubprocVecEnv(env_fns)

def make_eval_env(all_args):
    """Build the vectorised evaluation environment.

    A single evaluation thread runs in-process through ``ShareDummyVecEnv``;
    any other count uses ``ShareSubprocVecEnv`` with one environment factory
    per thread. Each factory seeds its environment with
    ``all_args.seed * 50000 + rank * 10000``.

    Raises:
        NotImplementedError: Raised by a factory, when it is invoked, if
            ``all_args.env_name`` is not ``"mujoco"``.
    """
    def get_env_fn(rank):
        def init_env():
            # Only the multi-agent MuJoCo suite is wired up.
            if all_args.env_name != "mujoco":
                raise NotImplementedError(f"Can not support the {all_args.env_name} environment.")
            env = MujocoMulti(env_args={
                "scenario": all_args.scenario,
                "agent_conf": all_args.agent_conf,
                "agent_obsk": all_args.agent_obsk,
                "episode_limit": 1000,
            })
            # Distinct seed per rank.
            env.seed(all_args.seed * 50000 + rank * 10000)
            return env
        return init_env

    if all_args.n_eval_rollout_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    env_fns = [get_env_fn(rank) for rank in range(all_args.n_eval_rollout_threads)]
    return ShareSubprocVecEnv(env_fns)

# ---- #
# Main #
# ---- #

def main():
    """Parse the CLI arguments, seed the RNGs, build the environments and train.

    The seed is ``all_args.running_id`` when ``all_args.seed_specify`` is set,
    otherwise a random integer in ``[1000, 10000)``. Environments and the
    runner's summary writer are released in a ``finally`` block, so they are
    also closed when runner construction or training raises.
    """
    all_args = tyro.cli(Args)
    if all_args.seed_specify:
        all_args.seed = all_args.running_id
    else:
        all_args.seed = np.random.randint(1000, 10000)
    print(f"seed is : {all_args.seed}")

    if all_args.cuda and torch.cuda.is_available():
        print("choose to use gpu...")
        device = torch.device("cuda:0")
        if all_args.cuda_deterministic:
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
            torch.set_float32_matmul_precision('high')
    else:
        print("choose to use cpu...")
        device = torch.device("cpu")
    # Same thread cap on both device paths.
    torch.set_num_threads(all_args.n_training_threads)

    setproctitle.setproctitle(f"{all_args.algorithm_name}-{all_args.env_name}-{all_args.experiment_name}@{all_args.user_name}")

    torch.manual_seed(all_args.seed)
    torch.cuda.manual_seed_all(all_args.seed)
    np.random.seed(all_args.seed)

    envs = make_train_env(all_args)
    eval_envs = None
    runner = None
    try:
        eval_envs = make_eval_env(all_args) if all_args.use_eval else None
        num_agents = envs.n_agents

        config = {
            "all_args": all_args, "envs": envs, "eval_envs": eval_envs,
            "num_agents": num_agents, "device": device
        }

        runner = MujocoRunner(config)
        runner.run()
    finally:
        # Release resources on failure too; the writer is closed even if
        # shutting down an environment raises.
        try:
            envs.close()
            if eval_envs is not None and eval_envs is not envs:
                eval_envs.close()
        finally:
            if runner is not None:
                runner.writter.close()

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
