# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import os
import statistics
import time
import torch
import warnings
from collections import deque

from typing import Union
import math

from torch.utils.data import TensorDataset, DataLoader
import pytorch_lightning as pl

import rsl_rl
from rsl_rl.algorithms import PPO
from rsl_rl.env import VecEnv
from rsl_rl.modules import (
    ActorCriticBase, 
    ActorCritic, 
    ActorCriticRecurrent, 
    resolve_rnd_config, resolve_symmetry_config, 
    ExtendableActorCritic, 
    HierarchicalActorCritic, 
    ActorCriticForAnalysis,
    ExtendableActorCriticRecurrent, 
    GatedActorCriticWithINV, 
    GatedMultiActorCritic, 
    DLStdActorCritic,
    P4RLAsymmetricActorCritic,
    ActorCriticConstrainedStd,
    ActorZeroOutputPretrainLightning,
    P4RLHamburgerCritic

)
# from rsl_rl.modules import *
from rsl_rl.utils import resolve_obs_groups, store_code_state
from isaaclab_rl.rsl_rl import RslRlVecEnvWrapper
from isaaclab.utils.buffers import CircularBuffer
import h5py



class OnPolicyRunner:
    """On-policy runner for training and evaluation of actor-critic methods.

    Besides the standard rsl_rl rollout / update loop, this runner adds P4RL extensions:

    * optional recording of finished episodes (``inv_dynamics_input`` observations and actions) into an
      in-memory HDF5 file that is dumped to ``<log_dir>/dataset/inv_dynamics_dataset.h5``;
    * periodic retraining of the algorithm's inverse-dynamics ensemble from that dataset;
    * separate start iterations for actor and critic updates, with an optional critic burn-in phase;
    * optional "zero output" pretraining of the actor.
    """

    def __init__(self, env: RslRlVecEnvWrapper, train_cfg: dict, log_dir: str | None = None, device="cpu", save_trajectories_prob: float = 0.0):
        """Create the runner, the algorithm and (optionally) the trajectory-recording buffers.

        Args:
            env: Vectorized environment wrapper.
            train_cfg: Training configuration. Must contain ``"algorithm"``, ``"policy"``,
                ``"num_steps_per_env"``, ``"save_interval"`` and ``"obs_groups"`` entries.
            log_dir: Directory for logs, checkpoints and the inverse-dynamics dataset. ``None`` disables logging.
            device: Torch device used for training.
            save_trajectories_prob: Probability with which each finished episode is stored in the
                inverse-dynamics dataset. ``0.0`` disables recording.

        Raises:
            ValueError: If trajectory recording is requested without a ``log_dir``, or if the dataset
                file already exists.
        """
        self.cfg = train_cfg
        self.alg_cfg = train_cfg["algorithm"]
        self.policy_cfg = train_cfg["policy"]
        self.device = device
        self.env = env
        if save_trajectories_prob > 0.0:
            if log_dir is None:
                raise ValueError(
                    "save_trajectories_prob > 0 requires a log_dir to store the inverse dynamics dataset."
                )
            # save trajectories for inverse dynamics training
            self.save_trajectories_prob = save_trajectories_prob
            self.inv_dataset_path = os.path.join(log_dir, "dataset/inv_dynamics_dataset.h5")
            # creates the dataset file, the in-memory HDF5 file and resets inv_data_counter
            self.init_inv_dataset()
            # per-env history buffers holding a whole episode; keep an eye on memory usage
            # (max_episode_length x num_envs entries each)
            self.inv_input_buffer = CircularBuffer(
                max_len=self.env.max_episode_length,
                batch_size=self.env.num_envs,
                device=self.device,
            )
            self.action_hist_buffer = CircularBuffer(
                max_len=self.env.max_episode_length,
                batch_size=self.env.num_envs,
                device=self.device,
            )
        else:
            self.save_trajectories_prob = None

        # check if multi-gpu is enabled
        self._configure_multi_gpu()

        # store training configuration
        self.num_steps_per_env = self.cfg["num_steps_per_env"]
        self.save_interval = self.cfg["save_interval"]

        # query observations from environment for algorithm construction
        obs = self.env.get_observations()
        default_sets = ["critic"]
        if "rnd_cfg" in self.alg_cfg and self.alg_cfg["rnd_cfg"] is not None:
            default_sets.append("rnd_state")
        self.cfg["obs_groups"] = resolve_obs_groups(obs, self.cfg["obs_groups"], default_sets)

        # create the algorithm
        self.alg = self._construct_algorithm(obs)

        # Decide whether to disable logging
        # We only log from the process with rank 0 (main process)
        self.disable_logs = self.is_distributed and self.gpu_global_rank != 0

        # Logging
        self.log_dir = log_dir
        self.writer = None
        self.tot_timesteps = 0
        self.tot_time = 0
        self.current_learning_iteration = 0
        self.git_status_repos = [rsl_rl.__file__]

    def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False,
              unfreeze_all_params_at_iteration: int | None = None):  # noqa: C901
        """Run the rollout / update loop for ``num_learning_iterations`` iterations.

        Args:
            num_learning_iterations: Number of iterations to run, starting at ``self.current_learning_iteration``.
            init_at_random_ep_len: If True, randomize the environments' initial episode lengths.
            unfreeze_all_params_at_iteration: If given, ``policy.unfreeze_all()`` is called this many
                iterations after the first iteration of this call.

        The config keys ``start_actor_RL_at_iteration``, ``start_critic_RL_at_iteration`` and
        ``actor_zero_output_pretrain`` are required. They are compared against absolute iteration numbers.
        """
        # initialize writer
        self._prepare_logging_writer()

        # randomize initial episode lengths (for exploration)
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(
                self.env.episode_length_buf, high=int(self.env.max_episode_length)
            )

        # start learning
        obs = self.env.get_observations().to(self.device)
        self.train_mode()  # switch to train mode (for dropout for example)

        # Book keeping
        ep_infos = []
        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        # intrinsic rewards come either from RND or from the inverse-dynamics reward
        track_intrinsic = bool(self.alg.rnd or self.inv_dynamics_reward_enabled)

        # create buffers for logging extrinsic and intrinsic rewards
        if track_intrinsic:
            erewbuffer = deque(maxlen=100)
            irewbuffer = deque(maxlen=100)
            cur_ereward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
            cur_ireward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        # Ensure all parameters are in-synced
        if self.is_distributed:
            print(f"Synchronizing parameters for rank {self.gpu_global_rank}...")
            self.alg.broadcast_parameters()

        start_iter = self.current_learning_iteration
        tot_iter = start_iter + num_learning_iterations
        start_actor_it = self.cfg["start_actor_RL_at_iteration"]
        start_critic_it = self.cfg["start_critic_RL_at_iteration"]
        zero_output_pretrain_its = self.cfg["actor_zero_output_pretrain"]
        if zero_output_pretrain_its is not None:
            assert zero_output_pretrain_its <= start_actor_it, (
                "actor_zero_output_pretrain must be less than or equal to start_actor_RL_at_iteration"
            )

        # P4RL: critic burn-in.
        # If the critic starts learning before the actor, the pretrained module inside a P4RLHamburgerCritic is
        # kept frozen during the burn-in phase and restored to its pre-defined trainable state (frozen or not)
        # at start_actor_RL_at_iteration. Only freeze when training starts before that restore point;
        # otherwise the restore would never run and the module would stay frozen for good.
        critic_burn_in = (
            start_actor_it > start_critic_it
            and start_iter < start_actor_it
            and isinstance(self.alg.policy.critic, P4RLHamburgerCritic)
        )
        if critic_burn_in:
            self.alg.policy.critic.freeze_pretrained_module()

        # Start training
        for it in range(start_iter, tot_iter):
            start = time.time()

            if unfreeze_all_params_at_iteration is not None and it == start_iter + unfreeze_all_params_at_iteration:
                self.alg.policy.unfreeze_all()

            # Rollout
            with torch.inference_mode():
                for _ in range(self.num_steps_per_env):
                    # Sample actions
                    actions = self.alg.act(obs)
                    # Step the environment
                    obs, rewards, dones, extras = self.env.step(actions.to(self.env.device))
                    # Move to device
                    obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device))

                    if self.save_trajectories_prob:
                        # P4RL: keep per-env histories for inverse dynamics. For envs with done==1 the returned
                        # observations already belong to the next episode, so the finished episode is dumped
                        # and its buffers are reset before the new observation is appended.
                        if (dones > 0).any():
                            idx = dones.nonzero(as_tuple=False).squeeze(1)
                            self.record_trajectories(idx, it)
                            self.inv_input_buffer.reset(idx)
                            self.action_hist_buffer.reset(idx)
                        self.inv_input_buffer.append(obs["inv_dynamics_input"])
                        self.action_hist_buffer.append(actions)
                        self.alg.process_env_step(
                            obs,
                            rewards,
                            dones,
                            extras,
                            buffers={
                                "inv_input_buffer": self.inv_input_buffer,
                                "action_hist_buffer": self.action_hist_buffer,
                            },
                        )
                    else:
                        # Process env step and store in buffer
                        self.alg.process_env_step(obs, rewards, dones, extras)

                    # Intrinsic rewards (extracted here only for logging)!
                    intrinsic_rewards = self.alg.intrinsic_rewards if track_intrinsic else None

                    if self.log_dir is not None:
                        if "episode" in extras:
                            ep_infos.append(extras["episode"])
                        elif "log" in extras:
                            ep_infos.append(extras["log"])
                        # Update rewards
                        if track_intrinsic:
                            cur_ereward_sum += rewards
                            cur_ireward_sum += intrinsic_rewards  # type: ignore
                            cur_reward_sum += rewards + intrinsic_rewards
                        else:
                            cur_reward_sum += rewards
                        # Update episode length
                        cur_episode_length += 1
                        # Clear data for completed episodes
                        # -- common
                        new_ids = (dones > 0).nonzero(as_tuple=False)
                        rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                        lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                        cur_reward_sum[new_ids] = 0
                        cur_episode_length[new_ids] = 0
                        # -- intrinsic and extrinsic rewards
                        if track_intrinsic:
                            erewbuffer.extend(cur_ereward_sum[new_ids][:, 0].cpu().numpy().tolist())
                            irewbuffer.extend(cur_ireward_sum[new_ids][:, 0].cpu().numpy().tolist())
                            cur_ereward_sum[new_ids] = 0
                            cur_ireward_sum[new_ids] = 0

                stop = time.time()
                collection_time = stop - start
                start = stop

                # compute returns
                self.alg.compute_returns(obs)

            # zero output actor pretraining: collect the policy observations of the first
            # `actor_zero_output_pretrain` iterations and pretrain once, at the last of them
            if zero_output_pretrain_its is not None:
                if it < zero_output_pretrain_its:
                    self.add_data_to_actor_zero_output_pretrain_dataset()
                if it == zero_output_pretrain_its - 1:
                    self.actor_zero_output_pretrain()

            # end of critic burn-in: restore the pre-defined trainable state of the pretrained module
            if critic_burn_in and it == start_actor_it:
                self.alg.policy.critic.reset_trainable_state_of_pretrained_module()

            # update policy
            loss_dict = self.alg.update(
                update_actor=(it >= start_actor_it),
                update_critic=(it >= start_critic_it),
            )

            if isinstance(self.alg.policy, (GatedActorCriticWithINV, GatedMultiActorCritic)):
                mean_gating_value = self.alg.policy.gating_value.mean().item()

            stop = time.time()
            learn_time = stop - start
            self.current_learning_iteration = it

            # P4RL: periodically dump the dataset and retrain the inverse dynamics model
            if self.alg.inv_ensemble and (it + 1) % self.inv_retrain_interval == 0:
                self.save_dataset_to_file()
                if self.alg.inv_ensemble.check_if_data_is_sufficient_for_training(self.inv_dataset_path):
                    print(f"Retraining inverse dynamics model at iteration {it+1} with dataset {self.inv_dataset_path}.")
                    inv_dynamics_retrain_epoch_errors, inv_train_samples_num = self.alg.inv_ensemble.retrain_models(
                        dataset_path=self.inv_dataset_path,
                        model_save_dir=os.path.join(self.log_dir, "inv_ensemble_models"),
                        iteration_num=it,
                        epochs=5,  # TODO: make it configurable
                        window_size=self.alg.inv_input_timesteps,
                    )
                else:
                    print(f"Skipping retraining of inverse dynamics model at iteration {it+1} due to insufficient data in {self.inv_dataset_path}.")

            # log info
            if self.log_dir is not None and not self.disable_logs:
                # Log information (log() reads the local variables of this method)
                self.log(locals())
                # Save model
                if it % self.save_interval == 0:
                    self.save(os.path.join(self.log_dir, f"model_{it}.pt"))

            # Clear episode infos
            ep_infos.clear()
            # Save code state (requires a log directory; `logger_type` only exists when logging is enabled)
            if it == start_iter and self.log_dir is not None and not self.disable_logs:
                # obtain all the diff files
                git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
                # if possible store them to wandb
                if self.logger_type in ["wandb", "neptune"] and git_file_paths:
                    for path in git_file_paths:
                        self.writer.save_file(path)

            if it == tot_iter - 1 and self.save_trajectories_prob is not None:
                self.save_dataset_to_file()

        # Save the final model after training
        if self.log_dir is not None and not self.disable_logs:
            self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))

    def log(self, locs: dict, width: int = 80, pad: int = 35):
        """Write one iteration's metrics to the writer and print a summary to the terminal.

        Args:
            locs: The ``locals()`` of :meth:`learn`. Keys such as ``it``, ``collection_time``, ``learn_time``,
                ``loss_dict``, ``ep_infos``, ``rewbuffer`` and ``lenbuffer`` are read from it.
            width: Width of the printed summary.
            pad: Right-alignment width of the printed labels.
        """
        # Compute the collection size
        collection_size = self.num_steps_per_env * self.env.num_envs * self.gpu_world_size
        iteration_time = locs["collection_time"] + locs["learn_time"]
        # Update total time-steps and time
        self.tot_timesteps += collection_size
        self.tot_time += iteration_time

        # -- Episode info
        ep_string = ""
        if locs["ep_infos"]:
            for key in locs["ep_infos"][0]:
                infotensor = torch.tensor([], device=self.device)
                for ep_info in locs["ep_infos"]:
                    # handle scalar and zero dimensional tensor infos
                    if key not in ep_info:
                        continue
                    if not isinstance(ep_info[key], torch.Tensor):
                        ep_info[key] = torch.Tensor([ep_info[key]])
                    if len(ep_info[key].shape) == 0:
                        ep_info[key] = ep_info[key].unsqueeze(0)
                    infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
                value = torch.mean(infotensor)
                # log to logger and terminal
                if "/" in key:
                    self.writer.add_scalar(key, value, locs["it"])
                    ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
                else:
                    self.writer.add_scalar("Episode/" + key, value, locs["it"])
                    ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""

        mean_std = self.alg.policy.action_std.mean()
        fps = int(collection_size / iteration_time)
        rnd_enabled = hasattr(self.alg, "rnd") and self.alg.rnd

        # -- Losses
        for key, value in locs["loss_dict"].items():
            self.writer.add_scalar(f"Loss/{key}", value, locs["it"])
        self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])

        # -- Policy
        self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
        if "mean_gating_value" in locs:
            self.writer.add_scalar("Policy/gating_value", locs["mean_gating_value"], locs["it"])

        # -- Performance
        self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
        self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
        self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])

        # -- Training
        if len(locs["rewbuffer"]) > 0:
            # separate logging for intrinsic and extrinsic rewards
            if rnd_enabled:
                self.writer.add_scalar("Rnd/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), locs["it"])
                self.writer.add_scalar("Rnd/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), locs["it"])
                self.writer.add_scalar("Rnd/weight", self.alg.rnd.weight, locs["it"])
            if self.inv_dynamics_reward_enabled:
                self.writer.add_scalar("INV/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), locs["it"])
                self.writer.add_scalar("INV/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), locs["it"])
                self.writer.add_scalar("INV/weight", self.alg.inv_reward_scale, locs["it"])
            # everything else
            self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"])
            self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"])
            if self.logger_type != "wandb":  # wandb does not support non-integer x-axis logging
                self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time)
                self.writer.add_scalar(
                    "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
                )

        # P4RL: log inverse dynamics metrics
        if self.alg.inv_ensemble:
            if hasattr(self, "inv_data_counter"):
                self.writer.add_scalar("InvDynamics/num_dataset", self.inv_data_counter, locs["it"])
            # present once the first retraining happened in learn(); the value of the last retraining is repeated
            if "inv_dynamics_retrain_epoch_errors" in locs:
                self.writer.add_scalar("InvDynamics/inv_train_samples_num", locs["inv_train_samples_num"], locs["it"])

        # -- print log string
        title = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
        header = (
            f"""{'#' * width}\n"""
            f"""{title.center(width, ' ')}\n\n"""
            f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs['collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
            f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
        )

        log_string = header
        if len(locs["rewbuffer"]) > 0:
            # -- Losses
            for key, value in locs["loss_dict"].items():
                log_string += f"""{f'Mean {key} loss:':>{pad}} {value:.4f}\n"""
            # -- Rewards (extrinsic / intrinsic split exists for both RND and the inverse-dynamics reward)
            if rnd_enabled or self.inv_dynamics_reward_enabled:
                log_string += (
                    f"""{'Mean extrinsic reward:':>{pad}} {statistics.mean(locs['erewbuffer']):.2f}\n"""
                    f"""{'Mean intrinsic reward:':>{pad}} {statistics.mean(locs['irewbuffer']):.2f}\n"""
                )
            log_string += f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
            # -- episode info
            log_string += f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
        else:
            for key, value in locs["loss_dict"].items():
                log_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""

        elapsed_str = time.strftime("%H:%M:%S", time.gmtime(self.tot_time))
        eta_seconds = (
            self.tot_time / (locs["it"] - locs["start_iter"] + 1)
            * (locs["start_iter"] + locs["num_learning_iterations"] - locs["it"])
        )
        eta_str = time.strftime("%H:%M:%S", time.gmtime(eta_seconds))

        log_string += ep_string
        log_string += (
            f"""{'-' * width}\n"""
            f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
            f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
            f"""{'Time elapsed:':>{pad}} {elapsed_str}\n"""
            f"""{'ETA:':>{pad}} {eta_str}\n"""
        )
        print(log_string)

    def save(self, path: str, infos=None):
        """Save policy and actor/critic optimizer states (plus RND states if used) with ``torch.save``.

        Args:
            path: Destination file.
            infos: Arbitrary extra information stored under the ``"infos"`` key.
        """
        # -- Save model
        saved_dict = {
            "model_state_dict": self.alg.policy.state_dict(),
            "actor_optimizer_state_dict": self.alg.actor_optimizer.state_dict(),
            "critic_optimizer_state_dict": self.alg.critic_optimizer.state_dict(),
            "iter": self.current_learning_iteration,
            "infos": infos,
        }
        # -- Save RND model if used
        if hasattr(self.alg, "rnd") and self.alg.rnd:
            saved_dict["rnd_state_dict"] = self.alg.rnd.state_dict()
            saved_dict["rnd_optimizer_state_dict"] = self.alg.rnd_optimizer.state_dict()
        torch.save(saved_dict, path)

        # Uploading the model to an external logging service is disabled to save cloud storage:
        # if self.logger_type in ["neptune", "wandb"] and not self.disable_logs:
        #     self.writer.save_model(path, self.current_learning_iteration)

    def load(self, path: str, load_optimizer: bool = True, load_only_pretrain: bool = False, load_pretrain_percentage: float = 1.0, trim_ratio: Union[float, None] = None, reset_noise_params=False, map_location: str | None = None):
        """
        Load a previously saved model.

        Args:
            path (str): Path to the saved model.
            load_optimizer (bool): If True, load the actor/critic (and RND) optimizer states.
            load_only_pretrain (bool): Intended to load only the pre-trained module.
                NOTE: currently not used by this method.
            load_pretrain_percentage (float): Intended fraction (0.0 to 1.0) of pre-trained weights to load
                from shallow to deep. NOTE: currently not used by this method.
            trim_ratio (float | None): Intended to keep only the largest ``trim_ratio`` weights.
                NOTE: currently not used by this method.
            reset_noise_params (bool): If True, overwrite the saved action-noise parameter (``"std"``, or
                ``"log_std"`` for log parameterizations) with ``policy_cfg["init_noise_std"]`` before loading.
            map_location (str | None): Forwarded to ``torch.load``.

        Returns:
            The ``"infos"`` entry of the checkpoint.

        Raises:
            KeyError: If ``reset_noise_params`` is set but the checkpoint has no noise parameter.
            RuntimeError: If optimizer states are requested but the checkpoint was written with a
                single-optimizer (main-rsl) version of the code.
        """
        loaded_dict = torch.load(path, weights_only=False, map_location=map_location)

        # reset noise parameters if needed
        if reset_noise_params:
            model_state = loaded_dict["model_state_dict"]
            init_std = self.policy_cfg["init_noise_std"]
            if "std" in model_state:
                model_state["std"].fill_(init_std)
            elif "log_std" in model_state:
                model_state["log_std"].fill_(math.log(init_std))
            else:
                raise KeyError("reset_noise_params=True but the checkpoint has neither a 'std' nor a 'log_std' entry.")

        # -- Load model
        resumed_training = self.alg.policy.load_state_dict(loaded_dict["model_state_dict"])

        # -- Load RND model if used
        if hasattr(self.alg, "rnd") and self.alg.rnd:
            self.alg.rnd.load_state_dict(loaded_dict["rnd_state_dict"])
        # -- load optimizer if used
        if load_optimizer and resumed_training:
            try:
                # -- algorithm optimizer
                self.alg.actor_optimizer.load_state_dict(loaded_dict["actor_optimizer_state_dict"])
                self.alg.critic_optimizer.load_state_dict(loaded_dict["critic_optimizer_state_dict"])
            except KeyError as err:
                if "optimizer_state_dict" in loaded_dict:
                    raise RuntimeError(
                        "P4RL: Failed to load optimizer states but found 'optimizer_state_dict' in the loaded "
                        "dictionary, which means that the model was saved with a main-rsl branch version of the code "
                        "where the updates of actor and critic are handled by one single optimizer; this is not "
                        "supported here. You can either use the main-rsl branch version of the code to load this "
                        "model, or disable load_optimizer in the load function."
                    ) from err
                # best effort: continue without optimizer states, but make it visible
                warnings.warn(
                    f"P4RL: optimizer state {err} not found in checkpoint '{path}'; optimizer states were not loaded."
                )
            # -- RND optimizer if used
            if hasattr(self.alg, "rnd") and self.alg.rnd:
                self.alg.rnd_optimizer.load_state_dict(loaded_dict["rnd_optimizer_state_dict"])
        # -- the learning iteration is intentionally not restored from the checkpoint:
        # if resumed_training:
        #     self.current_learning_iteration = loaded_dict["iter"]
        return loaded_dict["infos"]

    def get_inference_policy(self, device=None):
        """Switch to eval mode, optionally move the policy to ``device``, and return ``policy.act_inference``."""
        self.eval_mode()  # switch to evaluation mode (dropout for example)
        if device is not None:
            self.alg.policy.to(device)
        return self.alg.policy.act_inference

    def train_mode(self):
        """Put the policy (and the RND module, if used) into training mode."""
        # -- PPO
        self.alg.policy.train()
        # -- RND
        if hasattr(self.alg, "rnd") and self.alg.rnd:
            self.alg.rnd.train()

    def eval_mode(self):
        """Put the policy (and the RND module, if used) into evaluation mode."""
        # -- PPO
        self.alg.policy.eval()
        # -- RND
        if hasattr(self.alg, "rnd") and self.alg.rnd:
            self.alg.rnd.eval()

    def add_git_repo_to_log(self, repo_file_path):
        """Register another repository whose code state is stored at the first logged iteration."""
        self.git_status_repos.append(repo_file_path)

    """
    Helper functions.
    """

    def _configure_multi_gpu(self):
        """Configure multi-gpu training from the ``WORLD_SIZE`` / ``LOCAL_RANK`` / ``RANK`` env variables."""
        # check if distributed training is enabled
        self.gpu_world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.is_distributed = self.gpu_world_size > 1

        # if not distributed training, set local and global rank to 0 and return
        if not self.is_distributed:
            self.gpu_local_rank = 0
            self.gpu_global_rank = 0
            self.multi_gpu_cfg = None
            return

        # get rank and world size
        self.gpu_local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.gpu_global_rank = int(os.getenv("RANK", "0"))

        # make a configuration dictionary
        self.multi_gpu_cfg = {
            "global_rank": self.gpu_global_rank,  # rank of the main process
            "local_rank": self.gpu_local_rank,  # rank of the current process
            "world_size": self.gpu_world_size,  # total number of processes
        }

        # check if user has device specified for local rank
        if self.device != f"cuda:{self.gpu_local_rank}":
            raise ValueError(
                f"Device '{self.device}' does not match expected device for local rank '{self.gpu_local_rank}'."
            )
        # validate multi-gpu configuration
        if self.gpu_local_rank >= self.gpu_world_size:
            raise ValueError(
                f"Local rank '{self.gpu_local_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
            )
        if self.gpu_global_rank >= self.gpu_world_size:
            raise ValueError(
                f"Global rank '{self.gpu_global_rank}' is greater than or equal to world size '{self.gpu_world_size}'."
            )

        # initialize torch distributed
        torch.distributed.init_process_group(backend="nccl", rank=self.gpu_global_rank, world_size=self.gpu_world_size)
        # set device to the local rank
        torch.cuda.set_device(self.gpu_local_rank)

    def _construct_algorithm(self, obs) -> PPO:
        """Construct the actor-critic, the algorithm and its rollout storage from the configuration."""
        # resolve RND config
        self.alg_cfg = resolve_rnd_config(self.alg_cfg, obs, self.cfg["obs_groups"], self.env)

        # resolve symmetry config
        self.alg_cfg = resolve_symmetry_config(self.alg_cfg, self.env)

        # P4RL: inverse-dynamics configuration (input/action dims are taken from the env)
        if self.alg_cfg.get("inv_dynamics_cfg") is not None:
            input_inv_dim = obs["inv_dynamics_input"].shape[-1]
            self.inv_retrain_interval = self.alg_cfg["inv_dynamics_cfg"]["retrain_interval"]
            self.alg_cfg["inv_dynamics_cfg"]["dim_states"] = input_inv_dim
            self.alg_cfg["inv_dynamics_cfg"]["dim_actions"] = self.env.num_actions
            self.inv_dynamics_reward_enabled = self.alg_cfg["inv_dynamics_cfg"]["reward_scale"] > 0.0
        else:
            self.inv_dynamics_reward_enabled = False

        # resolve deprecated normalization config
        if self.cfg.get("empirical_normalization") is not None:
            warnings.warn(
                "The `empirical_normalization` parameter is deprecated. Please set `actor_obs_normalization` and "
                "`critic_obs_normalization` as part of the `policy` configuration instead.",
                DeprecationWarning,
            )
            if self.policy_cfg.get("actor_obs_normalization") is None:
                self.policy_cfg["actor_obs_normalization"] = self.cfg["empirical_normalization"]
            if self.policy_cfg.get("critic_obs_normalization") is None:
                self.policy_cfg["critic_obs_normalization"] = self.cfg["empirical_normalization"]

        # initialize the actor-critic
        # NOTE(review): class names are resolved with eval() on config values; only use trusted configs.
        actor_critic_class = eval(self.policy_cfg.pop("class_name"))
        actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class(
            obs, self.cfg["obs_groups"], self.env.num_actions, **self.policy_cfg
        ).to(self.device)

        # initialize the algorithm
        alg_class = eval(self.alg_cfg.pop("class_name"))
        alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg)

        # initialize the storage
        alg.init_storage(
            "rl",
            self.env.num_envs,
            self.num_steps_per_env,
            obs,
            [self.env.num_actions],
        )

        return alg

    def _prepare_logging_writer(self):
        """Create the configured logging writer (tensorboard, neptune or wandb) once, on the logging rank."""
        if self.log_dir is not None and self.writer is None and not self.disable_logs:
            # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard.
            self.logger_type = self.cfg.get("logger", "tensorboard")
            self.logger_type = self.logger_type.lower()

            if self.logger_type == "neptune":
                from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter

                self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
                self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
            elif self.logger_type == "wandb":
                from rsl_rl.utils.wandb_utils import WandbSummaryWriter

                self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
                self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
            elif self.logger_type == "tensorboard":
                from torch.utils.tensorboard import SummaryWriter

                self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
            else:
                raise ValueError("Logger type not found. Please choose 'neptune', 'wandb' or 'tensorboard'.")

    def init_inv_dataset(self):
        """Initialize the inverse dynamics dataset.

        Creates an empty HDF5 file at ``self.inv_dataset_path``, resets ``inv_data_counter`` and opens an
        in-memory HDF5 file (``driver='core'``, no backing store) that collects the trajectories until
        :meth:`save_dataset_to_file` dumps it.

        Raises:
            ValueError: If a dataset file already exists at ``self.inv_dataset_path``.
        """
        os.makedirs(os.path.dirname(self.inv_dataset_path), exist_ok=True)
        if os.path.exists(self.inv_dataset_path):
            raise ValueError("the inv dataset object already exists, please remove it before initializing a new one.")
        with h5py.File(self.inv_dataset_path, "w"):
            pass
        # initialize the counter
        self.inv_data_counter = 0
        self.inv_dataset_obj_memory = h5py.File('inv_dataset.h5', driver='core', backing_store=False, mode='w')

    def record_trajectories(self, idx: torch.Tensor, iteration_num: int):
        """Dump the buffered trajectories of the envs in ``idx`` before their buffers are cleared.

        Each non-empty trajectory is kept with probability ``self.save_trajectories_prob`` and written as
        an HDF5 group with ``inv_input`` and ``actions`` datasets (no compression).

        Args:
            idx: 1-D tensor of env indices whose episode just ended.
            iteration_num: Learning iteration, used in the group name.
        """
        if idx.numel() == 0:
            return
        inv_lengths = self.inv_input_buffer.current_length[idx]
        action_lengths = self.action_hist_buffer.current_length[idx]
        # skip empty buffers (env initialized to a terminating pose) and sub-sample the rest
        sampled = torch.rand(idx.shape[0], device=inv_lengths.device) <= self.save_trajectories_prob
        keep = (inv_lengths > 0) & sampled
        if not bool(keep.any()):
            return
        env_ids = idx[keep].tolist()
        inv_lengths = inv_lengths[keep].tolist()
        action_lengths = action_lengths[keep].tolist()

        # NOTE: in Isaac Lab, `CircularBuffer.buffer` builds a re-ordered copy of the whole buffer
        # (oldest entry first, newest last) on every access, so fetch it once per call.
        inv_history = self.inv_input_buffer.buffer
        action_history = self.action_hist_buffer.buffer
        for env_id, inv_len, act_len in zip(env_ids, inv_lengths, action_lengths):
            self.inv_data_counter += 1
            # the newest `length` entries hold the finished episode
            inv_input = inv_history[env_id, -inv_len:]
            actions = action_history[env_id, -act_len:]

            # write to the in-memory HDF5 file
            grp = self.inv_dataset_obj_memory.create_group(f"it_{iteration_num:04d}_trajectory_{self.inv_data_counter:06d}")
            grp.create_dataset("inv_input", data=inv_input.cpu().numpy())
            grp.create_dataset("actions", data=actions.cpu().numpy())

    def save_dataset_to_file(self):
        """Write the in-memory HDF5 dataset to ``self.inv_dataset_path``, replacing the previous dump."""
        self.inv_dataset_obj_memory.flush()
        image = self.inv_dataset_obj_memory.id.get_file_image()
        # write to a temporary file first so an interrupted dump never leaves a truncated dataset behind
        tmp_path = self.inv_dataset_path + ".tmp"
        with open(tmp_path, "wb") as out_file:
            out_file.write(image)
        os.replace(tmp_path, self.inv_dataset_path)
        print(f"Dataset: Saved {self.inv_data_counter} inverse dynamics trajectories to {self.inv_dataset_path}!")

    def add_data_to_actor_zero_output_pretrain_dataset(self):
        """Append a snapshot of the current rollout's ``"policy"`` observations to the pretraining dataset.

        The observations are cloned: the rollout storage is presumably refilled in place by the next
        rollout, so keeping a reference would make every stored chunk alias the latest rollout.
        """
        policy_obs = self.alg.storage.observations["policy"].clone()
        if not hasattr(self, "actor_zero_output_pretrain_dataset"):
            self.actor_zero_output_pretrain_dataset = policy_obs
        else:
            self.actor_zero_output_pretrain_dataset = torch.cat(
                (self.actor_zero_output_pretrain_dataset, policy_obs), dim=0
            )

    def actor_zero_output_pretrain(self):
        """Pretrain the actor with ``ActorZeroOutputPretrainLightning`` on the collected observations.

        Runs a PyTorch Lightning trainer for 5 epochs (batch size 1024, last incomplete batch dropped),
        moves the actor back to ``self.device``, frees the collected dataset and prints a subsample of
        the recorded training losses.
        """
        assert hasattr(self.alg.policy.actor, "pretrained_module"), "Actor must have a pretrained_module for zero output pretraining."
        zo_pt_dataset = TensorDataset(self.actor_zero_output_pretrain_dataset.flatten(0, 1))
        print(f"Starting actor zero output pretraining with {len(zo_pt_dataset)} samples...")
        zo_pt_dataloader = DataLoader(zo_pt_dataset, batch_size=1024, shuffle=True, drop_last=True)
        l_module = ActorZeroOutputPretrainLightning(actor=self.alg.policy.actor)
        trainer = pl.Trainer(max_epochs=5, accelerator="auto")
        trainer.fit(l_module, zo_pt_dataloader)
        # Lightning may leave the module on another device after fit; bring it back to the runner's device
        self.alg.policy.actor.to(self.device)
        # the collected observations are no longer needed
        del self.actor_zero_output_pretrain_dataset
        loss_seq_tensor = torch.tensor(l_module.train_losses, device=self.device)
        # print about 10 evenly spaced losses; step >= 1 so short loss histories do not raise
        step = max(1, loss_seq_tensor.shape[0] // 10)
        print(f"Actor zero output pretraining finished. Losses sequence: {loss_seq_tensor[::step]}")





