import os
import os.path
import subprocess
import json
from collections import defaultdict
from abc import ABC
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from openrlhf.models import GPTLMLoss, PolicyLoss, ValueLoss, Actor
from openrlhf.models.utils import masked_mean
from openrlhf.utils.distributed_sampler import DistributedSampler

from openrlhf.trainer.ppo_utils import AdaptiveKLController, FixedKLController, NaiveReplayBuffer
from .experience_maker import Experience, NaiveExperienceMaker
import time

class PPOTrainer(ABC):
    """
    Trainer for Proximal Policy Optimization (PPO) algorithm.

    Args:
        strategy (Strategy): The training strategy to use.
        actor (Actor): The actor model in the PPO algorithm.
        critic (nn.Module): The critic model in the PPO algorithm.
        reward_model (nn.Module): The reward model for calculating rewards in the RLHF setup.
        initial_model (Actor): The initial model for reference logits to limit actor updates in RLHF.
        ema_model (Actor): The exponential moving average model for stable training.
        actor_optim (Optimizer): The optimizer for the actor model.
        critic_optim (Optimizer): The optimizer for the critic model.
        actor_scheduler (Scheduler): The learning rate scheduler for the actor.
        critic_scheduler (Scheduler): The learning rate scheduler for the critic.
        ema_beta (float, defaults to 0.992): EMA decay rate for model stability.
        init_kl_coef (float, defaults to 0.001): Initial coefficient for KL divergence.
        kl_target (float, optional): Target value for KL divergence.
        kl_horizon (int, defaults to 10000): Horizon for KL annealing.
        ptx_coef (float, defaults to 0): Coefficient for supervised loss from pre-trained data.
        micro_train_batch_size (int, defaults to 8): Micro-batch size for actor training.
        buffer_limit (int, defaults to 0): Maximum size of the replay buffer.
        buffer_cpu_offload (bool, defaults to True): If True, offloads replay buffer to CPU.
        eps_clip (float, defaults to 0.2): Clipping coefficient for policy loss.
        value_clip (float, defaults to 0.2): Clipping coefficient for value function loss.
        micro_rollout_batch_size (int, defaults to 8): Micro-batch size for generating rollouts.
        gradient_checkpointing (bool, defaults to False): If True, enables gradient checkpointing.
        max_epochs (int, defaults to 1): Number of epochs to train.
        max_norm (float, defaults to 1.0): Maximum gradient norm for gradient clipping.
        tokenizer (Callable, optional): Tokenizer for input data.
        prompt_max_len (int, defaults to 128): Maximum length for prompts.
        dataloader_pin_memory (bool, defaults to True): If True, pins memory in the data loader.
        remote_rm_url (str, optional): URL for remote reward model API.
        reward_fn (Callable, optional): Custom reward function for computing rewards.
        **generate_kwargs: Additional arguments for model generation.
    """

    def __init__(
        self,
        strategy,
        actor: Actor,
        critic: nn.Module,
        reward_model: nn.Module,
        initial_model: Actor,
        ema_model: Actor,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        actor_scheduler,
        critic_scheduler,
        ema_beta: float = 0.992,
        init_kl_coef: float = 0.001,
        kl_target: float = None,
        kl_horizon: int = 10000,
        ptx_coef: float = 0,
        micro_train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        value_clip: float = 0.2,
        micro_rollout_batch_size: int = 8,
        gradient_checkpointing: bool = False,
        max_epochs: int = 1,
        max_norm: float = 1.0,
        tokenizer: Optional[Callable[[Any], dict]] = None,
        prompt_max_len: int = 128,
        dataloader_pin_memory: bool = True,
        remote_rm_url: str = None,
        reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None,
        **generate_kwargs,
    ) -> None:
        """Store models, optimizers and hyper-parameters and build the PPO helpers.

        Besides keeping every argument on the instance, this creates the loss
        functions, the KL controller (adaptive when ``kl_target`` is truthy,
        fixed otherwise), the experience maker, the replay buffer and, on
        rank 0 only, the optional wandb / TensorBoard logger.
        """
        # A list holding anything other than exactly one reward model needs a
        # reward_fn to combine the individual scores.
        assert not (
            isinstance(reward_model, List) and len(reward_model) != 1 and reward_fn is None
        ), "reward_fn must be specified if using multiple reward models"

        super().__init__()
        args = strategy.args
        self.strategy = strategy
        self.args = args

        # Plain hyper-parameters and generation settings.
        self.micro_rollout_batch_size = micro_rollout_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        self.generate_kwargs = generate_kwargs
        self.dataloader_pin_memory = dataloader_pin_memory
        self.max_norm = max_norm
        self.ptx_coef = ptx_coef
        self.micro_train_batch_size = micro_train_batch_size
        self.kl_target = kl_target
        self.prompt_max_len = prompt_max_len
        self.ema_beta = ema_beta
        self.gradient_checkpointing = gradient_checkpointing
        self.reward_fn = reward_fn

        # Models, optimizers and schedulers.
        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.remote_rm_url = remote_rm_url
        self.initial_model = initial_model
        self.ema_model = ema_model
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        self.actor_scheduler = actor_scheduler
        self.critic_scheduler = critic_scheduler

        # Loss functions.
        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.ptx_loss_fn = GPTLMLoss()

        # The actor is only trained once the global step exceeds this value
        # (see training_step); -1 when the option is absent from args.
        self.freezing_actor_steps = getattr(args, "freezing_actor_steps", -1)

        # Mixtral 8x7b: the auxiliary loss is used only for a non-negligible coefficient.
        self.aux_loss = args.aux_loss_coef > 1e-8

        self.kl_ctl = (
            AdaptiveKLController(init_kl_coef, kl_target, kl_horizon)
            if kl_target
            else FixedKLController(init_kl_coef)
        )

        self.experience_maker = NaiveExperienceMaker(
            actor,
            critic,
            reward_model,
            initial_model,
            tokenizer,
            prompt_max_len,
            self.kl_ctl,
            strategy,
            remote_rm_url,
            reward_fn,
        )
        self.replay_buffer = NaiveReplayBuffer(
            micro_train_batch_size,
            buffer_limit,
            buffer_cpu_offload,
            getattr(args, "packing_samples", False),
        )

        # Logging back-ends; both stay None unless enabled below on rank 0.
        self._wandb = None
        self._tensorboard = None

        if args.use_wandb and strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                # args.use_wandb doubles as the API key passed to login.
                wandb.login(key=args.use_wandb)
            wandb.init(
                entity=args.wandb_org,
                project=args.wandb_project,
                group=args.wandb_group,
                name=args.wandb_run_name,
                config=args.__dict__,
                reinit=True,
            )

            # Separate step axes for training and evaluation metrics.
            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/epoch")
            wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True)

        # TensorBoard is used only when wandb was not set up above.
        if args.use_tensorboard and self._wandb is None and strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            tensorboard_root = args.use_tensorboard
            os.makedirs(tensorboard_root, exist_ok=True)
            self._tensorboard = SummaryWriter(
                log_dir=os.path.join(tensorboard_root, args.wandb_run_name)
            )

    def fit(
        self,
        args,
        prompts_dataloader,
        pretrain_dataloader,
        consumed_samples=0,
        num_update_steps_per_episodes=1,
    ) -> None:
        """Run the PPO training loop for ``args.num_episodes`` episodes.

        For every batch drawn from ``prompts_dataloader`` this method builds
        experiences, fills the replay buffer, normalizes advantages, runs
        ``ppo_train``, updates the KL controller, gathers the per-sample
        experience info from all ranks, computes dataset metrics on rank 0 and
        then logs / checkpoints. The actor is additionally saved on rank 0
        whenever the step counter is a multiple of ``args.eval_steps`` and at
        the end of every episode.

        Args:
            args: Run configuration. ``eval_steps`` and ``save_steps`` are
                overwritten in place when they equal -1.
            prompts_dataloader: Yields ``(prompts, raw_data)`` pairs that are
                forwarded to the experience maker.
            pretrain_dataloader: Stored on the instance; ``training_step_actor``
                pulls batches from it with ``next()`` when it is not None.
            consumed_samples: Number of samples already consumed, used to
                restore the step counter and starting episode.
            num_update_steps_per_episodes: Used to derive the number of
                rollouts per episode.
        """
        num_rollouts_per_episodes = (
            num_update_steps_per_episodes
            * args.train_batch_size
            // args.max_epochs
            // args.rollout_batch_size
            // args.n_samples_per_prompt
        )

        # get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        self.prompts_dataloader = prompts_dataloader
        self.pretrain_dataloader = pretrain_dataloader

        # Restore step and start_epoch
        # NOTE(review): both the division below and the later
        # `steps % args.eval_steps` fail if num_rollouts_per_episodes is 0 — confirm
        # callers always pass values that make it positive.
        steps = consumed_samples // args.rollout_batch_size + 1
        start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
        consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

        for episode in range(start_episode, args.num_episodes):

            # Only the first (resumed) episode skips already consumed samples.
            if isinstance(self.prompts_dataloader.sampler, DistributedSampler):
                self.prompts_dataloader.sampler.set_epoch(
                    episode, consumed_samples=0 if episode > start_episode else consumed_samples
                )
            pbar = tqdm(
                range(self.prompts_dataloader.__len__()),
                desc=f"Episode [{episode + 1}/{args.num_episodes}]",
                disable=not self.strategy.is_rank_0(),
            )

            for rand_prompts, rand_rawdata in self.prompts_dataloader:
                # The experience maker returns the experiences plus a dict of timings.
                experience_list, time_duration = self.experience_maker.make_experience_list(rand_prompts, rand_rawdata, steps, **self.generate_kwargs)
                ppo_train_time_start = time.time()
                for i, experience in enumerate(
                    experience_list
                ):
                    # Print one decoded sample per rollout batch as a sanity check.
                    if i == 0:
                        output = self.tokenizer.batch_decode(
                            experience.sequences[0].unsqueeze(0), skip_special_tokens=True
                        )
                        self.strategy.print(output)
                    self.replay_buffer.append(experience)

                torch.cuda.empty_cache()
                self.replay_buffer.normalize("advantages", self.strategy)
                status = self.ppo_train(steps)
                self.replay_buffer.clear()
                torch.cuda.empty_cache()

                # "kl" is absent from status when the actor was not trained this step.
                if "kl" in status:
                    self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
                pbar.set_postfix(status)
                # calculate metric for different datasets
                dataset_types = self.strategy.args.ppo_dataset_type
                # gather all experience
                experience_info_local = defaultdict(list)
                for experience in experience_list:
                    for key, value in experience.info.items():
                        experience_info_local[key].extend(value)
                experience_info_local = {k: torch.tensor(v) for k, v in experience_info_local.items()}
                # all_gather_object is a collective call, so it runs on every rank;
                # only rank 0 consumes the gathered result below.
                world_size = torch.distributed.get_world_size()
                experience_info_batched_gather = [None for _ in range(world_size)]
                torch.distributed.all_gather_object(
                    object_list=experience_info_batched_gather,
                    obj=experience_info_local
                )
                if self.strategy.is_rank_0():
                    self.strategy.print("experience_info_batched_gather length:", len(experience_info_batched_gather))
                    # Concatenate the per-rank tensors key by key (keys taken from rank 0's dict).
                    experience_info_batched = {k: torch.cat([v[k] for v in experience_info_batched_gather]) for k in experience_info_batched_gather[0].keys()}
                    self.strategy.print("experience_info_batched inside length:", experience_info_batched["scores"].shape)
                    # Early steps report discrimination metrics only; later steps report the full set.
                    if steps <= self.strategy.args.discrimination_only_step:
                        status.update(self.compute_dataset_only_discrimination(experience_info_batched, dataset_types))
                    else:
                        status.update(self.compute_dataset_metrics(experience_info_batched, dataset_types))
                # The measured interval starts right after experience generation, so it
                # covers buffer filling, PPO training, the gather and the metric computation.
                ppo_train_time_end = time.time()
                ppo_train_time = ppo_train_time_end - ppo_train_time_start
                time_duration["ppo_train_time"] = ppo_train_time
                # Print and log the time spent in each phase
                if self._wandb is not None and self.strategy.is_rank_0():
                    print(f"time_duration: {time_duration}")
                    logs = {
                        "time/critique" : time_duration["critique_generate_time"],
                        # "eval/step": step,
                        "time/refinement": time_duration["refinement_generate_time"],
                        "time/ppo_train_time": time_duration["ppo_train_time"]
                    }
                    self._wandb.log(logs)

                # logs/checkpoints
                client_states = {"consumed_samples": steps * args.rollout_batch_size}
                self.save_logs_and_checkpoints(args, steps, pbar, status, client_states)
                pbar.update()
                steps = steps + 1
                # save model by eval_steps
                # NOTE(review): the check uses the already incremented counter, and
                # save_model is called on rank 0 only — if the strategy's save_model
                # involves collective communication this could block; confirm.
                if self.strategy.is_rank_0():
                    if steps % args.eval_steps == 0:
                        self.strategy.print("-------------------------------------------------------------")
                        self.strategy.print(f"Save and Eval at Step {steps}")
                        self.strategy.save_model(
                            self.actor,
                            self.tokenizer,
                            f'{args.save_path}/step{steps}',
                        )
                        self.strategy.print("-------------------------------------------------------------")


            # save model every episode
            if self.strategy.is_rank_0():
                self.strategy.print("-------------------------------------------------------------")
                self.strategy.print(f"Eval and save policy at Episode {episode + 1}")
                self.strategy.save_model(
                    self.actor,
                    self.tokenizer,
                    f'{args.save_path}/episode{episode + 1}',
                )
                self.strategy.print("-------------------------------------------------------------")

        # Shut down whichever logger was created in __init__.
        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    def ppo_train(self, global_steps=0):
        """Train on the current replay buffer and return the averaged step status.

        Iterates ``self.max_epochs`` times over a shuffled loader built from
        the replay buffer, runs ``training_step`` on every batch and averages
        the per-step status dicts key by key.

        Args:
            global_steps: Global step counter forwarded to ``training_step``.

        Returns:
            Dict with the mean of every status entry over all steps, or an
            empty dict when no batch was processed.
        """
        # The buffer content changes between calls, so the loader is rebuilt every time.
        buffer = self.replay_buffer
        loader = DataLoader(
            buffer,
            batch_size=buffer.sample_batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=self.dataloader_pin_memory,
            collate_fn=buffer.collate_fn,
        )
        device = torch.cuda.current_device()

        # (progress-bar label, status key) pairs shown when the actor was trained.
        actor_fields = (
            ("pg", "policy_loss"),
            ("rm", "reward"),
            ("ret", "return"),
            ("glen", "response_length"),
            ("tlen", "total_length"),
            ("kl", "kl"),
            ("act_lr", "actor_lr"),
        )

        collected = []
        for epoch in range(self.max_epochs):
            progress = tqdm(
                loader,
                desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
                disable=not self.strategy.is_rank_0(),
            )
            for batch in progress:
                batch.to_device(device)
                status = self.training_step(batch, global_steps)

                # Data parallel: weight kl by response length before the
                # all-reduce and undo the weighting afterwards.
                if "kl" in status:
                    status["kl"] *= status["response_length"]
                    status = self.strategy.all_reduce(status)
                    status["kl"] /= status["response_length"]

                if "policy_loss" in status:
                    short_status = {label: status[key] for label, key in actor_fields}
                else:
                    short_status = {}

                if "critic_loss" in status:
                    short_status.update(
                        cri=status["critic_loss"],
                        vals=status["values"],
                        cri_lr=status["critic_lr"],
                    )

                if "ptx_loss" in status:
                    short_status["ptx"] = status["ptx_loss"]

                collected.append(status)
                progress.set_postfix(short_status)

        if not collected:
            return {}

        # Accumulate into the first status dict, then divide by the step count.
        totals, *remaining = collected
        for entry in remaining:
            for key, value in entry.items():
                totals[key] += value
        step_count = len(collected)
        for key in totals:
            totals[key] /= step_count
        return totals

    def training_step(self, experience: Experience, global_steps) -> Dict[str, float]:
        status = {}
        if global_steps > self.freezing_actor_steps:
            status = self.training_step_actor(experience)
        if self.critic is not None:
            status.update(self.training_step_critic(experience))
        return status

    def compute_dataset_only_discrimination(self, experience_info, dataset_types: List[str]=None) -> Dict[str, float]:
        """Compute score and discrimination metrics per dataset and overall.

        Args:
            experience_info: Dict of tensors; reads ``discriminations``,
                ``discrimination_acc``, ``scores``, ``origin_scores`` and one
                ``is_<dataset>`` indicator per dataset type.
            dataset_types: Dataset names to report; defaults to
                ``["math", "gsm8k", "aqua"]``. A dataset whose indicator is
                missing is skipped with a printed warning.

        Returns:
            Dict mapping ``"<metric>/<dataset>"`` and ``"<metric>/all"`` to
            floats. Ratios with an empty denominator are reported as 0.0.
        """

        def _sum_item(values):
            # Scalar sum of a tensor after casting to float.
            return values.float().sum().item()

        def _safe_ratio(numerator, denominator):
            # Division that falls back to 0.0 when there is nothing to divide by.
            return numerator / denominator if denominator > 0 else 0.0

        if dataset_types is None:
            dataset_types = ["math", "gsm8k", "aqua"]

        discrimination = experience_info["discriminations"]
        discrimination_acc = experience_info["discrimination_acc"]
        scores = experience_info["scores"]
        origin_correct = experience_info["origin_scores"]
        origin_wrong = origin_correct.logical_not()

        metrics = {}
        for dataset_type in dataset_types:
            is_key = f"is_{dataset_type}"
            membership = experience_info.get(is_key, None)
            if membership is None:
                print(f"Warning: '{is_key}' not found in experience_info.")
                continue
            mask = membership == 1.0

            sample_count = _sum_item(mask)
            correct_count = _sum_item(mask & (origin_correct == 1.0))
            wrong_count = _sum_item(mask & (origin_wrong == 1.0))
            acc_on_correct = _sum_item(mask * discrimination_acc * origin_correct)
            acc_on_wrong = _sum_item(mask * discrimination_acc * origin_wrong)

            metrics[f"discrimination_origin_correct/{dataset_type}"] = _safe_ratio(acc_on_correct, correct_count)
            metrics[f"discrimination_origin_wrong/{dataset_type}"] = _safe_ratio(acc_on_wrong, wrong_count)
            # Per-dataset averages; 0.0 when the dataset has no samples in this batch.
            metrics[f"scores/{dataset_type}"] = _safe_ratio(_sum_item(mask * scores), sample_count)
            metrics[f"discriminations/{dataset_type}"] = _safe_ratio(_sum_item(mask * discrimination), sample_count)
            metrics[f"discrimination_acc/{dataset_type}"] = _safe_ratio(
                _sum_item(mask * discrimination_acc), sample_count
            )

        # Aggregates over every sample regardless of dataset.
        metrics["scores/all"] = scores.float().mean().item()
        metrics["discriminations/all"] = discrimination.float().mean().item()
        metrics["discrimination_acc/all"] = discrimination_acc.float().mean().item()
        metrics["discrimination_origin_correct/all"] = _safe_ratio(
            _sum_item(discrimination_acc * origin_correct), _sum_item(origin_correct)
        )
        metrics["discrimination_origin_wrong/all"] = _safe_ratio(
            _sum_item(discrimination_acc * origin_wrong), _sum_item(origin_wrong)
        )
        return metrics

    def compute_dataset_metrics(self, experience_info, dataset_types: List[str]=None) -> Dict[str, float]:
        """Compute refinement, discrimination and relevance metrics per dataset and overall.

        Args:
            experience_info: Dict of tensors; reads the four transition
                indicators (``correct_to_correct``, ``correct_to_incorrect``,
                ``incorrect_to_correct``, ``incorrect_to_incorrect``),
                ``origin_scores``, ``refinement_scores``, ``discriminations``,
                ``discrimination_acc``, ``delta``, ``scores``, ``change_rate``
                and one ``is_<dataset>`` indicator per dataset type.
            dataset_types: Dataset names to report; defaults to
                ``["math", "gsm8k", "aqua"]``. A dataset whose indicator is
                missing is skipped with a printed warning.

        Returns:
            Dict mapping ``"<metric>/<dataset>"`` and ``"<metric>/all"`` to
            floats. Ratios with an empty denominator are reported as 0.0.
        """

        def _sum_item(values):
            # Scalar sum of a tensor after casting to float.
            return values.float().sum().item()

        def _mean_item(values):
            # Scalar mean of a tensor after casting to float.
            return values.float().mean().item()

        def _safe_ratio(numerator, denominator):
            # Division that falls back to 0.0 when there is nothing to divide by.
            return numerator / denominator if denominator > 0 else 0.0

        if dataset_types is None:
            dataset_types = ["math", "gsm8k", "aqua"]

        c2c = experience_info["correct_to_correct"]
        c2i = experience_info["correct_to_incorrect"]
        i2c = experience_info["incorrect_to_correct"]
        i2i = experience_info["incorrect_to_incorrect"]
        origin_correct = experience_info["origin_scores"]
        origin_wrong = origin_correct.logical_not()
        refined_correct = experience_info["refinement_scores"]
        discrimination = experience_info["discriminations"]
        discrimination_acc = experience_info["discrimination_acc"]
        delta = experience_info["delta"]
        scores = experience_info["scores"]
        change_rate = experience_info["change_rate"]

        metrics = {}
        for dataset_type in dataset_types:
            is_key = f"is_{dataset_type}"
            membership = experience_info.get(is_key, None)
            if membership is None:
                print(f"Warning: '{is_key}' not found in experience_info.")
                continue
            mask = membership == 1.0

            # Transition counts inside this dataset.
            n_i2c = _sum_item(mask & (i2c == 1.0))
            n_c2c = _sum_item(mask & (c2c == 1.0))
            n_c2i = _sum_item(mask & (c2i == 1.0))
            n_i2i = _sum_item(mask & (i2i == 1.0))
            n_samples = _sum_item(mask)
            n_origin_wrong = n_i2c + n_i2i
            n_origin_correct = n_c2i + n_c2c

            # Masked sums of the per-sample quantities.
            acc_total = _sum_item(mask * discrimination_acc)
            acc_on_correct = _sum_item(mask * discrimination_acc * origin_correct)
            acc_on_wrong = _sum_item(mask * discrimination_acc * origin_wrong)
            relevant_total = _sum_item(mask * discrimination_acc * refined_correct)
            relevant_on_wrong = _sum_item(mask * discrimination_acc * origin_wrong * refined_correct)
            relevant_on_correct = _sum_item(mask * discrimination_acc * origin_correct * refined_correct)

            # Transition rates relative to originally wrong / originally correct samples.
            metrics[f"i2c-all i/{dataset_type}"] = _safe_ratio(n_i2c, n_origin_wrong)
            metrics[f"i2i-all i/{dataset_type}"] = _safe_ratio(n_i2i, n_origin_wrong)
            metrics[f"c2i-all c/{dataset_type}"] = _safe_ratio(n_c2i, n_origin_correct)
            metrics[f"c2c-all c/{dataset_type}"] = _safe_ratio(n_c2c, n_origin_correct)
            metrics[f"discrimination_origin_correct/{dataset_type}"] = _safe_ratio(acc_on_correct, n_origin_correct)
            metrics[f"discrimination_origin_wrong/{dataset_type}"] = _safe_ratio(acc_on_wrong, n_origin_wrong)
            metrics[f"relevance_origin_correct/{dataset_type}"] = _safe_ratio(relevant_on_correct, acc_on_correct)
            metrics[f"relevance_origin_wrong/{dataset_type}"] = _safe_ratio(relevant_on_wrong, acc_on_wrong)

            # Per-dataset averages; 0.0 when the dataset has no samples in this batch.
            metrics[f"incorrect_to_correct/{dataset_type}"] = _safe_ratio(n_i2c, n_samples)
            metrics[f"correct_to_correct/{dataset_type}"] = _safe_ratio(n_c2c, n_samples)
            metrics[f"correct_to_incorrect/{dataset_type}"] = _safe_ratio(n_c2i, n_samples)
            metrics[f"incorrect_to_incorrect/{dataset_type}"] = _safe_ratio(n_i2i, n_samples)
            metrics[f"delta/{dataset_type}"] = _safe_ratio(_sum_item(mask * delta), n_samples)
            metrics[f"scores/{dataset_type}"] = _safe_ratio(_sum_item(mask * scores), n_samples)
            metrics[f"change_rate/{dataset_type}"] = _safe_ratio(_sum_item(mask * change_rate), n_samples)
            metrics[f"discriminations/{dataset_type}"] = _safe_ratio(_sum_item(mask * discrimination), n_samples)
            metrics[f"discrimination_acc/{dataset_type}"] = _safe_ratio(acc_total, n_samples)
            metrics[f"relevance/{dataset_type}"] = _safe_ratio(relevant_total, acc_total)

        # Plain means over every sample regardless of dataset.
        metrics["incorrect_to_correct/all"] = _mean_item(i2c)
        metrics["correct_to_correct/all"] = _mean_item(c2c)
        metrics["correct_to_incorrect/all"] = _mean_item(c2i)
        metrics["incorrect_to_incorrect/all"] = _mean_item(i2i)
        metrics["delta/all"] = _mean_item(delta)
        metrics["scores/all"] = _mean_item(scores)
        metrics["change_rate/all"] = _mean_item(change_rate)
        metrics["discriminations/all"] = _mean_item(discrimination)
        metrics["discrimination_acc/all"] = _mean_item(discrimination_acc)

        # Transition rates over all samples.
        all_origin_correct = _sum_item(c2c + c2i)
        all_origin_wrong = _sum_item(i2c + i2i)
        metrics["i2c-all i/all"] = _safe_ratio(_sum_item(i2c), all_origin_wrong)
        metrics["i2i-all i/all"] = _safe_ratio(_sum_item(i2i), all_origin_wrong)
        metrics["c2i-all c/all"] = _safe_ratio(_sum_item(c2i), all_origin_correct)
        metrics["c2c-all c/all"] = _safe_ratio(_sum_item(c2c), all_origin_correct)

        all_acc_on_correct = _sum_item(discrimination_acc * origin_correct)
        all_acc_on_wrong = _sum_item(discrimination_acc * origin_wrong)
        metrics["discrimination_origin_correct/all"] = _safe_ratio(all_acc_on_correct, all_origin_correct)
        metrics["discrimination_origin_wrong/all"] = _safe_ratio(all_acc_on_wrong, all_origin_wrong)

        # Overall relevance; each numerator is evaluated only when its denominator is positive.
        if all_acc_on_correct > 0:
            metrics["relevance_origin_correct/all"] = (
                _sum_item(discrimination_acc * origin_correct * refined_correct) / all_acc_on_correct
            )
        else:
            metrics["relevance_origin_correct/all"] = 0.0
        if all_acc_on_wrong > 0:
            metrics["relevance_origin_wrong/all"] = (
                _sum_item(discrimination_acc * origin_wrong * refined_correct) / all_acc_on_wrong
            )
        else:
            metrics["relevance_origin_wrong/all"] = 0.0
        all_acc = _sum_item(discrimination_acc)
        if all_acc > 0:
            metrics["relevance/all"] = _sum_item(discrimination_acc * refined_correct) / all_acc
        else:
            metrics["relevance/all"] = 0.0
        return metrics
    
    def training_step_actor(self, experience: Experience) -> Dict[str, float]:
        self.actor.train()

        # TODO: this is a bad indicator to say that data is packed...
        if isinstance(experience.sequences, list):
            sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0)
            old_action_log_probs = torch.cat(experience.action_log_probs, dim=0).unsqueeze(0)
            advantages = torch.cat(experience.advantages, dim=0).unsqueeze(0)
            num_actions = [v.numel() for v in experience.advantages]
            packed_seq_lens = [s.numel() for s in experience.sequences]
            attention_mask = torch.cat(
                [torch.full_like(s, i + 1) for i, s in enumerate(experience.sequences)], dim=0
            ).unsqueeze(0)
        else:
            sequences = experience.sequences
            old_action_log_probs = experience.action_log_probs
            advantages = experience.advantages
            num_actions = experience.action_mask.size(1)
            packed_seq_lens = None
            attention_mask = experience.attention_mask

        # actor loss
        action_log_probs, output = self.actor(
            sequences,
            num_actions,
            attention_mask=attention_mask,
            return_output=True,
            packed_seq_lens=packed_seq_lens,
        )

        # loss function
        actor_loss = self.actor_loss_fn(
            action_log_probs,
            old_action_log_probs,
            advantages,
            action_mask=experience.action_mask,
        )
        # mixtral
        if self.aux_loss:
            aux_loss = output.aux_loss
        else:
            aux_loss = 0
        loss = actor_loss + aux_loss * self.args.aux_loss_coef
        self.strategy.backward(loss, self.actor, self.actor_optim)

        # ptx loss
        if self.pretrain_dataloader is not None:
            data = next(self.pretrain_dataloader)
            inputs = data[1].squeeze(1).to(torch.cuda.current_device())
            attention_mask = data[2].squeeze(1).to(torch.cuda.current_device())
            label = torch.where(
                attention_mask.bool(),
                inputs,
                self.ptx_loss_fn.IGNORE_INDEX,
            )

            output = self.actor(inputs, attention_mask=attention_mask, return_output=True)
            ptx_log_probs = output["logits"]

            # loss function
            ptx_loss = self.ptx_loss_fn(ptx_log_probs, label)
            # mixtral
            if self.aux_loss:
                aux_loss = output.aux_loss
            else:
                aux_loss = 0
            loss = ptx_loss + aux_loss * self.args.aux_loss_coef
            self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)

        self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
        if self.ema_model:
            self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cpu")

        # status
        status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
        if self.pretrain_dataloader is not None:
            status["ptx_loss"] = ptx_loss.item()
        for k, v in experience.info.items():
            if k == "kl":
                status[k] = (
                    (v * experience.info["response_length"]).sum() / experience.info["response_length"].sum()
                ).item()
            else:
                status[k] = v.float().mean().item()
        return status

    def training_step_critic(self, experience: Experience) -> Dict[str, float]:
        """Perform a single optimisation step of the critic on one experience batch.

        The critic is switched to train mode, evaluated on the batch, scored with
        ``self.critic_loss_fn`` against the stored values and returns, and then
        updated through ``self.strategy`` (backward + optimizer step).

        Args:
            experience: Batch to train on. When ``experience.sequences`` is a
                list, the batch is treated as packed: the per-sample tensors are
                concatenated into a single row before being fed to the critic.

        Returns:
            Dict with ``critic_loss``, ``values`` (masked mean of the predicted
            values under ``experience.action_mask``) and ``critic_lr``.
        """
        self.critic.train()

        # TODO: this is a bad indicator to say that data is packed...
        if not isinstance(experience.sequences, list):
            # Regular batch: tensors are used exactly as stored.
            sequences = experience.sequences
            old_values = experience.values
            returns = experience.returns
            num_actions = experience.action_mask.size(1)
            packed_seq_lens = None
            attention_mask = experience.attention_mask
        else:
            # Packed batch: join the per-sample tensors end to end and add a
            # leading dimension of size one.
            def _pack(chunks):
                return torch.cat(chunks, dim=0).unsqueeze(0)

            sequences = _pack(experience.sequences)
            old_values = _pack(experience.values)
            returns = _pack(experience.returns)
            num_actions = [adv.numel() for adv in experience.advantages]
            packed_seq_lens = [seq.numel() for seq in experience.sequences]
            # Every token carries the 1-based index of the sample it came from.
            attention_mask = _pack(
                [torch.full_like(seq, sample_id) for sample_id, seq in enumerate(experience.sequences, start=1)]
            )

        # critic forward pass
        values, output = self.critic(
            sequences,
            num_actions=num_actions,
            attention_mask=attention_mask,
            return_output=True,
            packed_seq_lens=packed_seq_lens,
        )

        # value loss against the stored values / returns
        critic_loss = self.critic_loss_fn(values, old_values, returns, action_mask=experience.action_mask)

        # mixtral: add the auxiliary loss only when it is enabled
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = critic_loss + aux_loss * self.args.aux_loss_coef

        self.strategy.backward(loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic")

        # status (read after the optimizer step, like the learning rate)
        return {
            "critic_loss": critic_loss.item(),
            "values": masked_mean(values, experience.action_mask).item(),
            "critic_lr": self.critic_scheduler.get_last_lr()[0],
        }

    def save_logs_and_checkpoints(
        self,
        args,
        global_step,
        step_bar,
        logs_dict: Optional[Dict[str, Any]] = None,
        client_states: Optional[Dict[str, Any]] = None,
    ):
        """Log training metrics and periodically save a checkpoint.

        Metrics are emitted every ``args.logging_steps`` steps, on rank 0 only,
        to wandb if it is configured, otherwise to TensorBoard if that is
        configured. A checkpoint is saved every ``args.save_steps`` steps.

        Args:
            args: Object providing ``logging_steps`` and ``save_steps`` (and
                whatever ``_save_checkpoint`` reads).
            global_step: Current global training step.
            step_bar: Progress bar handle; accepted for interface compatibility
                and not used here.
            logs_dict: Metric name -> value mapping to log. Defaults to empty.
            client_states: Extra state forwarded to ``_save_checkpoint``.
                Defaults to empty.
        """
        # Use None sentinels instead of ``{}`` defaults: a mutable default is
        # shared between calls, so anything written into it downstream (e.g. by
        # the checkpoint code) would leak into every later call.
        logs_dict = {} if logs_dict is None else logs_dict
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None and self.strategy.is_rank_0():
                # Only needed for wandb, so it is looked up here rather than on
                # every rank / backend.
                dataset_types = self.strategy.args.ppo_dataset_type + ["all"]
                logs = {}
                for k, v in {**logs_dict, "global_step": global_step}.items():
                    # Keys whose last "/"-separated segment names a dataset type
                    # (or "all") keep their name; everything else goes under
                    # the "train/" prefix.
                    if k.split("/")[-1] in dataset_types:
                        logs[k] = v
                    else:
                        logs[f"train/{k}"] = v
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in logs_dict.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)
                if self.experience_maker.perf_stats is not None:
                    for k, v in self.experience_maker.perf_stats.items():
                        self._tensorboard.add_scalar(f"perf/experience_maker/{k}", v, global_step)

        # TODO: Add evaluation mechanism for PPO
        # (e.g. run self.evaluate(self.eval_dataloader, global_step) every args.eval_steps)
        # save ckpt
        # TODO: save best model on dev, use loss/perplexity/others on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self._save_checkpoint(args, tag, client_states)

    def _save_checkpoint(self, args, tag, client_states):
        """Save the actor (and the critic, if there is one) through the strategy.

        The actor's underlying model is written to ``<args.ckpt_path>/_actor``
        together with ``client_states``; the critic, when present, is written to
        ``<args.ckpt_path>/_critic`` without client states. Both saves use the
        same ``tag`` and the ``args.max_ckpt_num`` / ``args.max_ckpt_mem`` values.
        """
        ckpt_root = args.ckpt_path
        keep_num = args.max_ckpt_num
        keep_mem = args.max_ckpt_mem

        actor_dir = os.path.join(ckpt_root, "_actor")
        self.strategy.save_ckpt(self.actor.model, actor_dir, tag, keep_num, keep_mem, client_states)

        if self.critic is not None:
            critic_dir = os.path.join(ckpt_root, "_critic")
            self.strategy.save_ckpt(self.critic, critic_dir, tag, keep_num, keep_mem)