import torch
torch.backends.cuda.matmul.allow_tf32 = True
import os
import gc
import time
import json
import wandb
import tqdm
import random
import functools
import contextlib
import transformers

import numpy as np
import torch.nn as nn
import tensor_parallel as tp
import torch.nn.functional as F
import torch.distributed as dist

from copy import deepcopy
from omegaconf import DictConfig
from collections import defaultdict
from collections import OrderedDict
from model import PolicyAndValueWrapper
from transformers import GenerationConfig
from typing import Optional, Dict, List, Union, Tuple
from preference_datasets_eval import get_batch_iterator_eval
from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel
from preference_datasets import get_batch_iterator, get_num_samples
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    StateDictType,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from utils import (
    slice_and_move_batch_for_device,
    slice_and_move_model_for_device,
    formatted_dict,
    all_gather_if_needed,
    pad_to_length,
    get_block_class_from_model,
    rank0_print,
    get_local_dir,
    AdaptiveKLController,
    FixedKLController,
    first_true_indices,
    get_reward,
    layer_init,
    pad,
    forward,
    truncate_response,
    masked_sum,
    masked_mean,
    masked_whiten,
    exact_div,
    batch_generation,
    create_reference_model,
    disable_dropout
)

# Placeholder log-probability value.
# NOTE(review): not referenced in the visible part of this file; presumably used to
# fill masked/padded log-prob positions elsewhere — confirm before changing.
INVALID_LOGPROB=1.0

def find_all_linear_names(model):
    """Collect the leaf names of every ``torch.nn.Linear`` found in ``model``.

    Walks ``model.named_modules()`` and keeps only the last dotted component of
    each linear layer's qualified name (``'q_proj'`` for ``'blk.0.q_proj'``),
    de-duplicated. The name ``'lm_head'`` is removed from the result.

    Returns:
        A list of unique leaf names, in arbitrary (set) order.
    """
    linear_leaf_names = {
        qualified_name.rpartition('.')[2]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }
    # The output head is excluded (per the original note: needed for 16-bit).
    linear_leaf_names.discard('lm_head')
    return list(linear_leaf_names)


def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Element counts are taken over ``model.named_parameters()``; the trainable
    count, the total count and the trainable percentage are written to stdout.
    """
    sizes_and_flags = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(size for size, _ in sizes_and_flags)
    trainable_params = sum(size for size, is_trainable in sizes_and_flags if is_trainable)
    trainable_pct = 100 * trainable_params / all_param
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

def reward_loss(policy_chosen_rewards, policy_rejected_reward):
    """Pairwise (Bradley-Terry) reward-model loss, averaged over the batch.

    Computes ``mean(-log(sigmoid(chosen - rejected)))``.

    Args:
        policy_chosen_rewards: Rewards predicted for the chosen responses.
        policy_rejected_reward: Rewards predicted for the rejected responses
            (broadcastable against ``policy_chosen_rewards``).

    Returns:
        A scalar tensor holding the mean loss.
    """
    # logsigmoid is used instead of -log(1 / (1 + exp(rejected - chosen))): the explicit
    # form overflows exp() for large margins and yields an infinite loss.
    margin = policy_chosen_rewards - policy_rejected_reward
    return (-F.logsigmoid(margin)).mean()

def preference_loss(policy_chosen_logps: torch.FloatTensor,
                    policy_rejected_logps: torch.FloatTensor,
                    reference_chosen_logps: torch.FloatTensor,
                    reference_rejected_logps: torch.FloatTensor,
                    beta: float,
                    label_smoothing: float = 0.0,
                    ipo: bool = False,
                    reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Per-example DPO (or IPO) loss from policy and reference log probabilities.

    Args:
        policy_chosen_logps: Policy log probabilities of the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Policy log probabilities of the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Reference log probabilities of the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Reference log probabilities of the rejected responses. Shape: (batch_size,)
        beta: Temperature of the loss, typically in the range 0.1 to 0.5. The reference model is ignored as beta -> 0.
        label_smoothing: Conservativeness of the DPO loss; assumes preferences are flipped with this probability.
        ipo: If True, compute the IPO loss instead of the DPO loss.
        reference_free: If True, the reference log-ratio is replaced by 0 in the loss (an implicit
            reference that assigns equal probability to all responses). The returned rewards
            still use the provided reference log probabilities.

    Returns:
        A tuple ``(losses, chosen_rewards, rejected_rewards)``: the per-example loss and the
        detached, beta-scaled policy-vs-reference log-prob differences for the chosen and
        rejected responses.
    """
    policy_margin = policy_chosen_logps - policy_rejected_logps
    if reference_free:
        reference_margin = 0
    else:
        reference_margin = reference_chosen_logps - reference_rejected_logps

    # also known as h_{\pi_\theta}^{y_w,y_l}
    logits = policy_margin - reference_margin

    if not ipo:
        # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
        scaled_logits = beta * logits
        kept_label_term = -F.logsigmoid(scaled_logits) * (1 - label_smoothing)
        flipped_label_term = F.logsigmoid(-scaled_logits) * label_smoothing
        losses = kept_label_term - flipped_label_term
    else:
        # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
        ipo_target = 1 / (2 * beta)
        losses = (logits - ipo_target) ** 2

    # Implicit rewards are reported for logging only, hence detached from the graph.
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards

# Convert raw output scores (logits) into log probabilities of the label tokens
def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor:
    """Compute the log probabilities of the given labels under the given logits.

    Args:
        logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
        labels: Labels for which to compute the log probabilities. Label tokens with a value of 50277 are ignored. Shape: (batch_size, sequence_length)
        average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.

    Returns:
        A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
    """
    assert logits.shape[:-1] == labels.shape

    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = (labels != 50277)

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == 50277] = 0

    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)


def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of a batch along the batch dimension.

    Every tensor entry whose key starts with 'chosen' is padded (via ``pad_to_length``)
    to the longer of the chosen / rejected sequence lengths and stored under the same
    key with 'chosen' replaced by 'concatenated'. The matching 'rejected' tensor is
    padded the same way and appended below it (dim 0). Keys containing 'labels' are
    padded with 50277, all other keys with 0.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).

    Returns:
        A dictionary of the concatenated tensors, e.g. under the key 'concatenated_input_ids'.
    """
    target_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])

    def _padded(key):
        fill_value = 50277 if 'labels' in key else 0
        return pad_to_length(batch[key], target_length, pad_value=fill_value)

    merged = {}
    # All chosen entries are laid down first so every rejected entry finds its counterpart.
    for prefix in ('chosen', 'rejected'):
        for key, value in batch.items():
            if not (key.startswith(prefix) and isinstance(value, torch.Tensor)):
                continue
            merged_key = key.replace(prefix, 'concatenated')
            if prefix == 'chosen':
                merged[merged_key] = _padded(key)
            else:
                merged[merged_key] = torch.cat((merged[merged_key], _padded(key)), dim=0)
    return merged

class BasicTrainer(object):
    def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str, reference_model: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1):
        """Set up the tokenizer, models and data iterators for SFT or DPO/IPO training.

        Args:
            policy: The model being trained.
            config: Run configuration. This constructor reads ``loss.name``, the model /
                reward-model tokenizer names, ``datasets``, ``max_length``,
                ``max_prompt_length``, the batch sizes, ``n_epochs``, ``n_examples``,
                ``n_eval_examples`` and ``local_dirs``.
            seed: Random seed; stored here and applied in ``train``.
            run_dir: Directory under which checkpoints are written.
            reference_model: Optional reference model used by the DPO/IPO code paths.
            rank: Index of this process; also passed to the device-placement helper.
            world_size: Number of processes.

        NOTE(review): the original docstring states that with multiple GPUs the model is
        naively split across them (more memory, no parallel computation). Placement is
        delegated to ``slice_and_move_model_for_device`` — confirm against that helper.
        """
        self.seed = seed
        self.rank = rank
        self.world_size = world_size
        self.config = config
        self.run_dir = run_dir

        # Setup tokenizers
        # The 'reward_loss' and 'advbon' losses use the reward model's tokenizer; every
        # other loss uses the policy model's. In both cases the model name is the fallback.
        tokenizer_name_or_path = (config.reward_model.tokenizer_name_or_path or config.reward_model.name_or_path) \
            if config.loss.name == 'reward_loss' or config.loss.name == 'advbon' else (config.model.tokenizer_name_or_path or config.model.name_or_path)
        rank0_print(f'Loading tokenizer {tokenizer_name_or_path}')
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name_or_path, cache_dir=get_local_dir(config.local_dirs))
        # if self.tokenizer.pad_token_id is None:
        #     self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # A dedicated "[PAD]" token is always registered (the EOS fallback above is disabled).
        # NOTE(review): if "[PAD]" is new to the vocabulary, the model's embedding matrix
        # may need resizing — confirm where that is handled.
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # Setup data iterators to sample training and evaluation batches
        data_iterator_kwargs = dict(
            names=config.datasets,
            tokenizer=self.tokenizer,
            shuffle=True,
            max_length=config.max_length,
            max_prompt_length=config.max_prompt_length,
            sft_mode=config.loss.name == 'sft',
        )

        # self.policy = policy
        self.policy = slice_and_move_model_for_device(policy, self.rank) 
        # self.reference_model = reference_model
        # NOTE(review): reference_model may be None (its default); this relies on the
        # helper accepting None — confirm.
        self.reference_model = slice_and_move_model_for_device(reference_model, self.rank) 

        self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs, n_examples=config.n_examples, batch_size=config.batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs))
        rank0_print(f'Loaded train data iterator')
        self.total_num_train_samples = get_num_samples(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs, n_examples=config.n_examples, batch_size=config.batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs))
        rank0_print(f'There are {self.total_num_train_samples} samples in train_iterator')
        # Eval batches are materialised up front for every loss except 'adv'.
        # NOTE(review): the tokenizer branch above checks 'advbon', not 'adv' — confirm
        # which name is intended; with 'adv', self.eval_batches is never set.
        if self.config.loss.name != "adv":
            self.eval_iterator = get_batch_iterator(**data_iterator_kwargs, split='test', n_examples=config.n_eval_examples, batch_size=config.eval_batch_size, silent=rank != 0, cache_dir=get_local_dir(config.local_dirs))
            self.eval_batches = list(self.eval_iterator)
            rank0_print(f'Loaded {len(self.eval_batches)} eval batches of size {config.eval_batch_size}')

    # Generate samples for a batch of prompts and decode them to text
    def get_batch_samples(self, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
        """Sample completions for the batch prompts from the policy (and, for DPO/IPO, the reference model).

        Both models are sampled first; the outputs are then padded to ``config.max_length``,
        gathered across ranks and decoded with special tokens stripped.

        Returns:
            ``(policy_samples, reference_samples)`` as lists of decoded strings;
            ``reference_samples`` is an empty list unless the loss is 'dpo' or 'ipo'.
        """
        wants_reference = self.config.loss.name in {'dpo', 'ipo'}
        pad_id = self.tokenizer.pad_token_id

        def _sample_from(model):
            # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069
            if self.config.is_distributed:
                params_ctx = FSDP.summon_full_params(model, writeback=False, recurse=False)
            else:
                params_ctx = contextlib.nullcontext()
            with params_ctx:
                return model.generate(
                    batch['prompt_input_ids'],
                    attention_mask=batch['prompt_attention_mask'],
                    max_length=self.config.max_length,
                    do_sample=True,
                    pad_token_id=pad_id,
                )

        def _gather_and_decode(token_ids):
            # Pad to a common length so outputs from all ranks can be gathered together.
            token_ids = pad_to_length(token_ids, self.config.max_length, pad_id)
            token_ids = all_gather_if_needed(token_ids, self.rank, self.world_size)
            return self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)

        # Generation for both models happens before any gathering, as in the original flow.
        policy_output = _sample_from(self.policy)
        if wants_reference:
            reference_output = _sample_from(self.reference_model)

        policy_output_decoded = _gather_and_decode(policy_output)
        reference_output_decoded = _gather_and_decode(reference_output) if wants_reference else []

        return policy_output_decoded, reference_output_decoded
    
    # Concatenate chosen and rejected answers
    def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Score chosen and rejected responses with one forward pass of ``model``.

        The chosen and rejected inputs are stacked into a single batch (a single
        forward pass is faster for FSDP than two), the logits are cast to float32,
        and the summed label log probabilities are split back into the two halves.

        Returns:
            ``(chosen_logps, rejected_logps)``.
        """
        stacked = concatenated_inputs(batch)
        model_output = model(
            stacked['concatenated_input_ids'],
            attention_mask=stacked['concatenated_attention_mask'],
        )
        logits_fp32 = model_output.logits.to(torch.float32)
        sequence_logps = _get_batch_logps(logits_fp32, stacked['concatenated_labels'], average_log_prob=False)

        # The first `num_chosen` rows belong to the chosen responses, the rest to the rejected ones.
        num_chosen = batch['chosen_input_ids'].shape[0]
        return sequence_logps[:num_chosen], sequence_logps[num_chosen:]

    # Compute the loss and metrics for one batch (used for both training and evaluation)
    def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):
        """Compute the SFT or DPO loss and other metrics for the given batch of inputs."""

        metrics = {}
        train_test = 'train' if train else 'eval'

        if loss_config.name in {'dpo', 'ipo'}:
            policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)
            with torch.no_grad():
                reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch)

            if loss_config.name == 'dpo':
                loss_kwargs = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free, 'label_smoothing': loss_config.label_smoothing, 'ipo': False}
            elif loss_config.name == 'ipo':
                loss_kwargs = {'beta': loss_config.beta, 'ipo': True}
            else:
                raise ValueError(f'unknown loss {loss_config.name}')

            losses, chosen_rewards, rejected_rewards = preference_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs)

            reward_accuracies = (chosen_rewards > rejected_rewards).float()

            chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size)
            rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size)
            reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size)

            metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist()
            metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist()
            metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist()
            metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist()

            policy_rejected_logps = all_gather_if_needed(policy_rejected_logps.detach(), self.rank, self.world_size)
            metrics[f'logps_{train_test}/rejected'] = policy_rejected_logps.cpu().numpy().tolist()

        elif loss_config.name == 'sft':
            attention_mask = batch['chosen_input_ids'] != self.tokenizer.pad_token_id
            input_ids = torch.masked_fill(batch['chosen_input_ids'], ~attention_mask, 0)
            logits = self.policy(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).logits
            labels = batch['chosen_input_ids'].masked_fill(batch['chosen_input_ids'] == self.tokenizer.pad_token_id, -1)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            losses = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-1).unsqueeze(dim=-1)

        if loss_config.name in {'dpo', 'ipo'}:
            policy_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size)
            metrics[f'logps_{train_test}/chosen'] = policy_chosen_logps.cpu().numpy().tolist()

        all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
        metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()

        return losses.mean(), metrics

    def train(self):
        """Begin either SFT or DPO training, with periodic evaluation.

        Builds the optimizer and a linear-warmup learning-rate schedule, seeds the RNGs,
        then iterates over ``self.train_iterator``. Before each training batch, when
        ``example_counter`` is a multiple of ``config.eval_every`` (and is either non-zero
        or ``config.do_first_eval`` is set), the policy is evaluated on
        ``self.eval_batches``, samples are optionally generated, results are logged and,
        for a non-zero counter outside debug mode, rank 0 writes a checkpoint. Each batch
        is then split into ``config.gradient_accumulation_steps`` microbatches whose
        gradients are accumulated before one clipped optimizer step.
        """

        # Setup optimizer
        rank0_print(f'Using {self.config.optimizer} optimizer')
        # The optimizer class is looked up by name in torch.optim.
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr, eps=self.config.optimizer_eps)
        # Linear warmup: the LR multiplier grows to 1.0 over warmup_steps and then stays at 1.0.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (self.config.warmup_steps + 1)))
        # total_num_train_samples = 92858 * self.config.n_epochs
        # training_steps = total_num_train_samples // self.config.batch_size + 1
        # self.scheduler = transformers.get_scheduler("cosine", optimizer=self.optimizer, num_warmup_steps=self.config.warmup_steps, num_training_steps=training_steps)

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        if self.config.loss.name in {'dpo', 'ipo'}:
            self.reference_model.eval()

        self.example_counter = 0
        self.batch_counter = 0
        # Wall-clock time of the last train-metrics log; None until the first log.
        last_log = None

        for batch in self.train_iterator:
            #### BEGIN EVALUATION ####
            # example_counter advances by batch_size per step, so this only fires when the
            # counter lands exactly on a multiple of eval_every.
            if self.example_counter % self.config.eval_every == 0 and (self.example_counter > 0 or self.config.do_first_eval):
                rank0_print(f'Running evaluation after {self.example_counter} train examples')
                self.policy.eval()

                all_eval_metrics = defaultdict(list)
                if self.config.sample_during_eval:
                    all_policy_samples, all_reference_samples = [], []
                    policy_text_table = wandb.Table(columns=["step", "prompt", "sample"])
                    if self.config.loss.name in {'dpo', 'ipo'}:
                        reference_text_table = wandb.Table(columns=["step", "prompt", "sample"])

                # Only rank 0 shows a progress bar.
                for eval_batch in (tqdm.tqdm(self.eval_batches, desc='Computing eval metrics') if self.rank == 0 else self.eval_batches):
                    # presumably selects this rank's shard of the batch and moves it to this
                    # rank's device — confirm against slice_and_move_batch_for_device
                    local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank)
                    with torch.no_grad():
                        _, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False)

                    for k, v in eval_metrics.items():
                        all_eval_metrics[k].extend(v)

                if self.config.sample_during_eval:
                    # Sampling works on whole eval batches: at least one, otherwise
                    # n_eval_model_samples // eval_batch_size of them.
                    if self.config.n_eval_model_samples < self.config.eval_batch_size:
                        rank0_print(f'Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < eval_batch_size ({self.config.eval_batch_size}). Sampling from the first complete eval batch of prompts.')
                        sample_batches = self.eval_batches[:1]
                    else:
                        n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size
                        sample_batches = self.eval_batches[:n_sample_batches]
                    for eval_batch in (tqdm.tqdm(sample_batches, desc='Generating samples...') if self.rank == 0 else sample_batches):
                        local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank)
                        policy_samples, reference_samples = self.get_batch_samples(local_eval_batch)

                        all_policy_samples.extend(policy_samples)
                        all_reference_samples.extend(reference_samples)

                        # Samples are gathered across ranks inside get_batch_samples, so they
                        # are paired with the prompts of the full (unsliced) eval batch.
                        for prompt, sample in zip(eval_batch['prompt'], policy_samples):
                            policy_text_table.add_data(self.example_counter, prompt, sample)
                        if self.config.loss.name in {'dpo', 'ipo'}:
                            for prompt, sample in zip(eval_batch['prompt'], reference_samples):
                                reference_text_table.add_data(self.example_counter, prompt, sample)

                mean_eval_metrics = {k: sum(v) / len(v) for k, v in all_eval_metrics.items()}
                rank0_print(f'eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}')
                if self.config.sample_during_eval:
                    rank0_print(json.dumps(all_policy_samples[:10], indent=2))
                    if self.config.loss.name in {'dpo', 'ipo'}:
                        rank0_print(json.dumps(all_reference_samples[:10], indent=2))

                if self.config.wandb.enabled and self.rank == 0:
                    wandb.log(mean_eval_metrics, step=self.example_counter)

                    if self.config.sample_during_eval:
                        wandb.log({"policy_samples": policy_text_table}, step=self.example_counter)
                        if self.config.loss.name in {'dpo', 'ipo'}:
                            wandb.log({"reference_samples": reference_text_table}, step=self.example_counter)

                # No checkpoint is written for the initial (counter == 0) evaluation.
                if self.example_counter > 0:
                    if self.config.debug:
                        rank0_print('skipping save in debug mode')
                    else:
                        if self.rank == 0:
                            output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}')
                            rank0_print(f'creating checkpoint to write to {output_dir}...')
                            self.save(output_dir, mean_eval_metrics)
            #### END EVALUATION ####

            #### BEGIN TRAINING ####
            self.policy.train()

            start_time = time.time()
            batch_metrics = defaultdict(list)
            # Compute loss and update model
            for microbatch_idx in range(self.config.gradient_accumulation_steps):
                # Two-level slicing: first pick this accumulation step's microbatch, then
                # this rank's share of it (helper semantics to be confirmed in utils).
                global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, self.config.gradient_accumulation_steps, self.rank)
                local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank)
                loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
                # Scale so the accumulated gradient equals the mean over all microbatches.
                (loss / self.config.gradient_accumulation_steps).backward()

                for k, v in metrics.items():
                    batch_metrics[k].extend(v)

            grad_norm = self.clip_gradient()
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

            # Post processing
            step_time = time.time() - start_time
            examples_per_second = self.config.batch_size / step_time
            batch_metrics['examples_per_second'].append(examples_per_second)
            batch_metrics['grad_norm'].append(grad_norm)

            self.batch_counter += 1
            self.example_counter += self.config.batch_size

            # Train metrics are logged at most once per minimum_log_interval_secs.
            if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs:
                mean_train_metrics = {k: sum(v) / len(v) for k, v in batch_metrics.items()}
                mean_train_metrics['counters/examples'] = self.example_counter
                mean_train_metrics['counters/updates'] = self.batch_counter
                rank0_print(f'train stats after {self.example_counter} examples: {formatted_dict(mean_train_metrics)}')

                if self.config.wandb.enabled and self.rank == 0:
                    wandb.log(mean_train_metrics, step=self.example_counter)

                last_log = time.time()
            else:
                rank0_print(f'skipping logging after {self.example_counter} examples to avoid logging too frequently')
            #### END TRAINING ####

    def clip_gradient(self):
        """Clip the gradient norm of the parameters of a non-FSDP policy."""
        return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm).item()

    def write_state_dict(self, step: int, state: Dict[str, torch.Tensor], metrics: Dict, filename: str, dir_name: Optional[str] = None):
        """Serialize one state dict, with its step index and metrics, to disk.

        Args:
            step: Value stored under 'step_idx'.
            state: State dict stored under 'state'.
            metrics: Metrics stored under 'metrics'; ``None`` is stored as an empty dict.
            filename: File name inside the target directory.
            dir_name: Target directory; defaults to ``<run_dir>/LATEST``. Created if missing.
        """
        target_dir = os.path.join(self.run_dir, 'LATEST') if dir_name is None else dir_name
        os.makedirs(target_dir, exist_ok=True)

        output_path = os.path.join(target_dir, filename)
        rank0_print(f'writing checkpoint to {output_path}...')

        payload = {
            'step_idx': step,
            'state': state,
            'metrics': {} if metrics is None else metrics,
        }
        torch.save(payload, output_path)
    
    def save(self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None):
        """Write the policy, optimizer and scheduler states to ``output_dir``.

        Produces 'policy.pt', 'optimizer.pt' and 'scheduler.pt' (in that order) through
        ``write_state_dict``, each tagged with the current ``example_counter`` and ``metrics``.
        """
        checkpoint_parts = (
            ('policy', 'policy.pt'),
            ('optimizer', 'optimizer.pt'),
            ('scheduler', 'scheduler.pt'),
        )
        for attribute, filename in checkpoint_parts:
            # Each state dict is fetched only when its turn comes and dropped right after
            # writing, so at most one of them is held at a time.
            state = getattr(self, attribute).state_dict()
            self.write_state_dict(self.example_counter, state, metrics, filename, output_dir)
            del state

class RewardTrainer(BasicTrainer):
    def __init__(self, *args, **kwargs):
        """Build the base trainer, then re-initialize the policy's ``score`` head.

        All arguments are forwarded unchanged to ``BasicTrainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        # self.policy.config.pad_token_id = self.policy.config.eos_token_id

        # initialize linear head using N(0, 1/sqrt(d+1))
        score_head = self.policy.score
        hidden_size = self.policy.config.hidden_size
        head_std = 1 / np.sqrt(hidden_size + 1)
        self.policy.score = layer_init(score_head, std=head_std)


    def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]], loss_type = 'reward_loss') -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """Predict rewards for the batch with ``model`` via ``get_reward``.

        Args:
            model: Reward model to score with.
            batch: Must contain 'chosen_input_ids'; for the pairwise loss types it must
                also contain 'rejected_input_ids'.
            loss_type: 'reward_loss' and 'reward_gap' are pairwise: chosen and rejected
                inputs are stacked and scored in one call. Any other value scores only
                the chosen inputs.

        Returns:
            ``(chosen_rewards, rejected_rewards)`` for the pairwise loss types,
            otherwise ``(predicted_rewards, None)``.
        """
        is_pairwise = loss_type in ['reward_loss', 'reward_gap']
        if is_pairwise:
            query_responses = torch.cat((batch['chosen_input_ids'], batch['rejected_input_ids']), dim=0)
        else:
            query_responses = batch['chosen_input_ids']

        _, predicted_rewards, _ = get_reward(model=model, query_responses=query_responses, pad_token_id=self.tokenizer.pad_token_id, context_length=0)

        if not is_pairwise:
            return predicted_rewards, None

        # The first `num_chosen` rows are the chosen responses, the remainder the rejected ones.
        num_chosen = batch['chosen_input_ids'].shape[0]
        return predicted_rewards[:num_chosen], predicted_rewards[num_chosen:]

    def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):
        """Compute the reward-model loss and metrics for one batch.

        Loss by ``loss_config.name``:
          * 'reward_loss': mean of ``-logsigmoid(chosen - rejected)``.
          * 'reward_gap': mean of ``edge_weight * (predicted gap - target gap) ** 2``, where
            the target gap is ``batch['chosen_reward'] - batch['rejected_reward']``.
          * anything else: mean of ``edge_weight * (prediction - batch['chosen_reward']) ** 2``
            on the chosen inputs only; no accuracy is reported.

        Returns:
            ``(loss, metrics)``; ``metrics`` holds single-element lists under
            'loss/<split>' and, for the pairwise losses, 'accuracy/<split>'.
        """
        split = 'train' if train else 'eval'
        loss_name = loss_config.name

        chosen_pred, rejected_pred = self.concatenated_forward(self.policy, batch, loss_name)
        # losses = reward_loss(policy_chosen_rewards, policy_rejected_reward)
        if loss_name == 'reward_loss':
            losses = -F.logsigmoid(chosen_pred - rejected_pred).mean()
            accuracy = (chosen_pred > rejected_pred).float().mean()
        elif loss_name == 'reward_gap':
            losses = (batch['edge_weight'] * ((chosen_pred - rejected_pred) - (batch['chosen_reward'] - batch['rejected_reward'])) ** 2).mean()
            accuracy = (chosen_pred > rejected_pred).float().mean()
        else:
            losses = (batch['edge_weight'] * (chosen_pred - batch['chosen_reward']) ** 2).mean()
            accuracy = None

        metrics = {}
        gathered_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
        metrics[f'loss/{split}'] = [gathered_losses.mean().cpu().numpy()]

        if accuracy is not None:
            gathered_accuracy = all_gather_if_needed(accuracy.detach(), self.rank, self.world_size)
            metrics[f'accuracy/{split}'] = [gathered_accuracy.mean().cpu().numpy()]

        return losses.mean(), metrics
    
    
    
import math
    
    
    

class AdversarialBONTrainer(BasicTrainer):
    # def __init__(self, *args, **kwargs):
    def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str, reference_model: Optional[nn.Module] = None, policy_ref: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1):
        """Initialise the base trainer, attach the sampling model and reset the score head.

        Args:
            policy: Model being trained; its ``score`` module is re-initialised
                here and the other methods of this class score sequences with it.
            config: Trainer configuration, forwarded to ``BasicTrainer``.
            seed: Random seed, forwarded to ``BasicTrainer``.
            run_dir: Output directory, forwarded to ``BasicTrainer``.
            reference_model: Optional reference model, forwarded to ``BasicTrainer``.
            policy_ref: Model whose ``generate`` is used to sample candidates in
                ``bon_generation``.
            rank: Rank of this process.
            world_size: Number of processes.
        """
        super().__init__(
            policy=policy,
            config=config,
            seed=seed,
            run_dir=run_dir,
            reference_model=reference_model,
            rank=rank,
            world_size=world_size,
        )

        # Place the sampling model for this rank (exact placement is decided by
        # the helper).
        self.policy_ref = slice_and_move_model_for_device(policy_ref, self.rank)

        # Re-initialise the scoring head with std = 1 / sqrt(hidden_size + 1).
        hidden_size = self.policy.config.hidden_size
        score_std = 1 / np.sqrt(hidden_size + 1)
        self.policy.score = layer_init(self.policy.score, std=score_std)

        # NOTE: an earlier draft wrapped ``self.policy`` in a LoRA/PEFT adapter
        # and overrode the tokenizer here; that code is disabled and the policy
        # is used as provided.


    def concatenated_forward(self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Score the chosen and rejected sequences with one ``get_reward`` call.

        ``batch['chosen_input_ids']`` and ``batch['rejected_input_ids']`` are
        stacked along the batch dimension (chosen first), scored together, and
        the resulting rewards are split back at the chosen batch size.

        Args:
            model: Model handed to ``get_reward``.
            batch: Batch holding ``chosen_input_ids`` and ``rejected_input_ids``.

        Returns:
            ``(chosen_rewards, rejected_rewards)``.
        """
        chosen_ids = batch['chosen_input_ids']
        rejected_ids = batch['rejected_input_ids']
        num_chosen = chosen_ids.shape[0]

        stacked_ids = torch.cat((chosen_ids, rejected_ids), dim=0)
        _, rewards, _ = get_reward(
            model=model,
            query_responses=stacked_ids,
            pad_token_id=self.tokenizer.pad_token_id,
            context_length=0,
        )

        return rewards[:num_chosen], rewards[num_chosen:]
    
    def bon_generation(
        self,
        batch: Dict[str, Union[List, torch.LongTensor]],
        n_samples: int = 16,
        temp: float = 0.7,
        num_chunks: int = 64,
        max_seq_len: int = 512,
        reward_chunk_size: int = 2,
    ) -> torch.Tensor:
        """Best-of-N sampling: return the reward of the best candidate per prompt.

        Stage 1 (no grad): ``self.policy_ref`` samples ``n_samples`` completions
        per prompt, a few at a time, and ``self.policy`` scores each of them via
        ``get_reward``. Candidates and scores are parked on the CPU.
        Stage 2 (grad): the highest-scoring candidate of every prompt is scored
        again with gradient tracking enabled, and that reward is returned.

        Args:
            batch: Batch containing ``prompt_input_ids`` (one row per prompt).
            n_samples: Number of candidates sampled per prompt (must be >= 1).
            temp: Sampling temperature passed to ``generate``.
            num_chunks: Upper bound on the number of ``generate`` calls; each
                call samples ``ceil(n_samples / num_chunks)`` candidates per prompt.
            max_seq_len: ``max_length`` for ``generate`` and the length every
                candidate is padded to before scoring.
            reward_chunk_size: Number of sequences scored per ``get_reward`` call.

        Returns:
            The second output of ``get_reward`` for the best candidates (one
            value per prompt), attached to the autograd graph of ``self.policy``.

        Raises:
            ValueError: If ``n_samples``, ``num_chunks`` or ``reward_chunk_size``
                is smaller than 1.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        if num_chunks < 1 or reward_chunk_size < 1:
            raise ValueError("num_chunks and reward_chunk_size must be >= 1")

        prompt_ids = batch["prompt_input_ids"]
        batch_size = prompt_ids.shape[0]

        # NOTE(review): length padding and scoring use token id 0, whereas
        # generate() below and concatenated_forward() use
        # self.tokenizer.pad_token_id. Kept as-is because get_reward's handling
        # of padded prompts is not visible here -- confirm this is intended.
        score_pad_id = 0

        samples_per_call = math.ceil(n_samples / num_chunks)

        candidate_chunks = []
        reward_chunks = []
        scoring_device = prompt_ids.device

        with torch.no_grad():
            for produced in range(0, n_samples, samples_per_call):
                # The last call may need fewer samples than a full chunk.
                width = min(samples_per_call, n_samples - produced)

                # Repeat every prompt ``width`` times: [batch_size * width, prompt_len].
                expanded_ids = prompt_ids.unsqueeze(1).expand(batch_size, width, -1)
                expanded_ids = expanded_ids.contiguous().view(batch_size * width, -1)

                generated = self.policy_ref.generate(
                    input_ids=expanded_ids,
                    max_length=max_seq_len,
                    pad_token_id=self.tokenizer.pad_token_id,
                    temperature=temp,
                    top_k=0,
                    top_p=0.95,
                    do_sample=True,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
                # Pad to a common length so chunks can be concatenated later.
                generated = pad_to_length(generated, max_seq_len, pad_value=score_pad_id)
                scoring_device = generated.device

                # Score in small slices to bound reward-model memory use.
                slice_rewards = []
                for start in range(0, generated.shape[0], reward_chunk_size):
                    _, rewards, _ = get_reward(
                        model=self.policy,
                        query_responses=generated[start:start + reward_chunk_size],
                        pad_token_id=score_pad_id,
                        context_length=0,
                    )
                    slice_rewards.append(rewards)
                flat_rewards = torch.cat(slice_rewards, dim=0)

                reward_chunks.append(flat_rewards.view(batch_size, width).cpu())
                candidate_chunks.append(generated.reshape(batch_size, width, -1).cpu())

            # [batch_size, n_samples, seq_len] and [batch_size, n_samples].
            all_candidates = torch.cat(candidate_chunks, dim=1)
            all_rewards = torch.cat(reward_chunks, dim=1)

            # Pick, for every prompt, the candidate with the highest reward.
            best_indices = torch.argmax(all_rewards, dim=1)
            gather_index = best_indices.view(batch_size, 1, 1).expand(
                batch_size, 1, all_candidates.size(2)
            )
            best_candidates = torch.gather(all_candidates, 1, gather_index).squeeze(1)

        # Stage 2: candidates were parked on the CPU; move them back to the
        # device that was used for stage-1 scoring before the gradient pass.
        best_candidates = best_candidates.to(scoring_device)
        _, best_rewards_grad, _ = get_reward(
            model=self.policy,
            query_responses=best_candidates,
            pad_token_id=score_pad_id,
            context_length=0,
        )

        return best_rewards_grad


    def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True, beta=0.1):
        """Compute the adversarial best-of-N loss and logging metrics for one batch.

        The loss is ``beta * (reward_bon - reward_ref) + pairwise_loss`` averaged
        over the batch, where ``pairwise_loss`` is the mean negative log-sigmoid
        of the (chosen - rejected) reward margin, ``reward_bon`` is the reward of
        the best of 64 samples drawn at temperature 0.7 and ``reward_ref`` is the
        reward of a single sample drawn at temperature 0.1 (both produced by
        ``bon_generation``).

        Args:
            batch: Batch with chosen/rejected input ids and prompt input ids.
            loss_config: Unused here; kept for interface parity with the other
                trainers' ``get_batch_metrics``.
            train: Chooses the ``train`` / ``eval`` suffix of the metric keys.
            beta: Weight of the best-of-N reward gap in the loss.

        Returns:
            ``(loss, metrics)`` where ``loss`` is the scalar loss with its
            autograd graph and ``metrics`` maps metric names to one-element
            lists of numpy scalars.
        """
        train_test = 'train' if train else 'eval'

        policy_chosen_rewards, policy_rejected_reward = self.concatenated_forward(self.policy, batch)
        losses = -F.logsigmoid(policy_chosen_rewards - policy_rejected_reward).mean()

        reward_bon = self.bon_generation(batch=batch, n_samples=64, temp=0.7)
        reward_ref = self.bon_generation(batch=batch, n_samples=1, temp=0.1)
        diff = reward_bon - reward_ref

        final_losses = beta * diff + losses

        # Logging only: work on detached values so that the collectives and the
        # numpy conversions never touch the autograd graph.
        with torch.no_grad():
            all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
            # Reduce to a scalar so the ratio below stays a scalar on any world size.
            mean_reward_loss = all_devices_losses.mean().cpu().numpy()
            mean_diff = diff.detach().mean().cpu().numpy()

            # Fraction of prompts whose best-of-N reward is at least the reference reward.
            bon_acc = (reward_bon >= reward_ref).float().mean()
            all_devices_acc = all_gather_if_needed(bon_acc.detach(), self.rank, self.world_size)

        metrics = {
            f'loss/{train_test}': [final_losses.mean().detach().cpu().numpy()],
            f"ratio/{train_test}": [mean_diff / mean_reward_loss],
            f"loss/diff_{train_test}": [mean_diff],
            f"loss/reward_loss_{train_test}": [mean_reward_loss],
            f'accuracy/{train_test}': [all_devices_acc.mean().cpu().numpy()],
        }

        return final_losses.mean(), metrics
    
    
    

class PPOTrainer(BasicTrainer):
    def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str, reference_model: Optional[nn.Module] = None, reward_model: Optional[nn.Module] = None, value_model: Optional[nn.Module] = None, rank: int = 0, world_size: int = 1):
        """Prepare policy, reference, reward and value models for PPO.

        Steps performed after the ``BasicTrainer`` initialisation:

        1. Clear ``eos_token_id`` / ``pad_token_id`` on the policy's generation
           config so rollouts are generated without truncation or padding.
        2. If ``config.use_peft`` is set, wrap the policy in a LoRA adapter built
           from ``config.loss.lora`` and disable dropout on it.
        3. Replace ``self.reference_model``: ``None`` for a PEFT policy,
           otherwise a copy created by ``create_reference_model`` with dropout
           disabled.
        4. Place the reward and value models for this rank and bundle policy and
           value model in a ``PolicyAndValueWrapper`` (``self.model``).
        5. When ``config.loss.stop_token`` is ``"eos"``, set
           ``config.loss.stop_token_id`` to the tokenizer's EOS id.

        Args:
            policy: Policy model to optimise.
            config: Trainer configuration.
            seed: Random seed.
            run_dir: Output directory for checkpoints.
            reference_model: Forwarded to ``BasicTrainer``; replaced afterwards.
            reward_model: Model that scores the generated responses.
            value_model: Model that predicts per-token values.
            rank: Rank of this process.
            world_size: Number of processes.
        """
        super().__init__(
            policy=policy,
            config=config,
            seed=seed,
            run_dir=run_dir,
            reference_model=reference_model,
            rank=rank,
            world_size=world_size,
        )

        # Generate tokens without truncation / padding: clear both special ids.
        generation_cfg = self.policy.generation_config
        generation_cfg.eos_token_id = None
        generation_cfg.pad_token_id = None

        if self.config.use_peft:
            lora_cfg = self.config.loss.lora
            peft_config = LoraConfig(
                task_type=lora_cfg.lora_task_type,
                r=lora_cfg.lora_r,
                target_modules=find_all_linear_names(self.policy),
                lora_alpha=lora_cfg.lora_alpha,
                lora_dropout=lora_cfg.lora_dropout,
                bias="none",
                use_rslora=lora_cfg.use_rslora,
                modules_to_save=lora_cfg.lora_modules_to_save,
            )
            self.policy = get_peft_model(self.policy, peft_config)
            self.policy.enable_input_require_grads()
            disable_dropout(self.policy)
            print_trainable_parameters(self.policy)

        # Drop whatever reference model the base class stored and rebuild it.
        del self.reference_model
        if not isinstance(self.policy, PeftModel):
            reference = create_reference_model(self.policy)
            reference = slice_and_move_model_for_device(reference, self.rank)
            disable_dropout(reference)
            self.reference_model = reference
        else:
            # For a PEFT policy, ``train`` runs the policy with its adapter
            # disabled wherever a reference model would be used.
            self.reference_model = None

        self.reward_model = slice_and_move_model_for_device(reward_model, self.rank)
        self.value_model = slice_and_move_model_for_device(value_model, self.rank)
        self.model = PolicyAndValueWrapper(policy=self.policy, value_model=self.value_model)

        loss_cfg = self.config.loss
        if loss_cfg.stop_token == "eos":
            loss_cfg.stop_token_id = self.tokenizer.eos_token_id
        
    def train(self):
        """Run the PPO training loop over ``self.train_iterator``.

        For every batch the loop

        1. samples responses from the policy and, without gradients, collects
           policy / reference log-probabilities, value estimates and
           reward-model scores;
        2. builds per-token rewards (KL penalty plus the score at the final
           token), passes them through ``masked_whiten`` and computes GAE
           advantages and returns;
        3. performs ``config.loss.num_ppo_epochs`` passes of clipped
           policy / value updates, with one optimizer step per pass after
           accumulating gradients over the mini-batches;
        4. gathers statistics across ranks and logs them (stdout and, on rank 0,
           wandb) at most once per ``config.minimum_log_interval_secs``.

        A checkpoint is written through ``self.save`` whenever
        ``self.example_counter`` is a positive multiple of
        ``config.loss.save_every`` (skipped in debug mode).
        """
        # Setup optimizer
        rank0_print(f'Using {self.config.optimizer} optimizer')
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.model.parameters(), lr=self.config.lr, eps=self.config.optimizer_eps)
        training_steps = self.total_num_train_samples // self.config.batch_size + 1
        self.scheduler = transformers.get_scheduler("linear", optimizer=self.optimizer, num_warmup_steps=self.config.warmup_steps, num_training_steps=training_steps)

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        # From here on ``config`` is the loss sub-config; trainer-level options
        # are still read from ``self.config``.
        config = self.config.loss

        # The small epsilon keeps the temperature strictly positive.
        generation_config = transformers.GenerationConfig(
            max_new_tokens=config.response_length,
            temperature=(config.temperature + 1e-7),
            top_k=0.0,
            top_p=1.0,
            do_sample=True,
        )

        start_time = time.time()
        # Statistics buffers indexed by (PPO epoch, gradient-accumulation step);
        # they are overwritten on every batch.
        stats_shape = (config.num_ppo_epochs, self.config.gradient_accumulation_steps)
        approxkl_stats = torch.zeros(stats_shape, device=self.rank)
        pg_clipfrac_stats = torch.zeros(stats_shape, device=self.rank)
        pg_loss_stats = torch.zeros(stats_shape, device=self.rank)
        vf_loss_stats = torch.zeros(stats_shape, device=self.rank)
        vf_clipfrac_stats = torch.zeros(stats_shape, device=self.rank)
        entropy_stats = torch.zeros(stats_shape, device=self.rank)
        ratio_stats = torch.zeros(stats_shape, device=self.rank)
        self.model.train()

        self.example_counter = 0
        self.batch_counter = 0
        last_log = None
        # PPO training
        for batch in (tqdm.tqdm(self.train_iterator, desc="Training") if self.rank == 0 else self.train_iterator):
            # Periodic checkpoint (never before the first batch).
            if self.example_counter % config.save_every == 0 and (self.example_counter > 0):
                if self.config.debug:
                    rank0_print('skipping save in debug mode')
                else:
                    output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}')
                    rank0_print(f'creating checkpoint to write to {output_dir}...')
                    self.save(output_dir)

            metrics = {}
            # Presumably: move the full batch to this rank's device, then keep
            # only this rank's shard of it -- verify against the helper.
            global_microbatch = slice_and_move_batch_for_device(batch, 0, 1, self.rank)
            local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank)

            # ---- Rollout phase: everything below is computed without gradients ----
            with torch.no_grad():
                queries = local_microbatch["prompt_input_ids"]
                local_batch_size = queries.shape[0]
                context_length = queries.shape[1]
                responses = []
                postprocessed_responses = []
                logprobs = []
                ref_logprobs = []
                scores = []
                sequence_lengths = []
                values = []

                # Sample responses for all queries; ``logitss`` holds the logits
                # from which the log-probs of the sampled tokens are gathered below.
                query_responses, logitss = batch_generation(
                    self.model.policy,
                    queries,
                    config.local_rollout_forward_batch_size,
                    self.tokenizer.pad_token_id,
                    generation_config,
                )

                # Process the rollouts in slices of local_rollout_forward_batch_size.
                for i in range(0, queries.shape[0], config.local_rollout_forward_batch_size):
                    query = queries[i : i + config.local_rollout_forward_batch_size]
                    query_response = query_responses[i : i + config.local_rollout_forward_batch_size]
                    response = query_response[:, context_length:]
                    logits = logitss[i : i + config.local_rollout_forward_batch_size]
                    # Log-probability of each sampled response token under the policy.
                    all_logprob = F.log_softmax(logits, dim=-1)
                    logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
                    del logits, all_logprob
                    # torch.cuda.empty_cache()

                    # Without a separate reference model, run the policy with its
                    # adapter disabled to obtain the reference outputs.
                    if self.reference_model is None:
                        with self.model.policy.disable_adapter():
                            ref_output = forward(self.model.policy, query_response, self.tokenizer.pad_token_id)
                    else:
                        ref_output = forward(self.reference_model, query_response, self.tokenizer.pad_token_id)
                    # ref_output = forward(self.reference_model, query_response, self.tokenizer.pad_token_id)
                    # Logits at position t predict token t + 1, hence the shift by one.
                    ref_logits = ref_output.logits[:, context_length - 1 : -1]
                    ref_logits /= config.temperature + 1e-7
                    ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
                    ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
                    del ref_output, ref_logits, ref_all_logprob
                    # torch.cuda.empty_cache()

                    # Response Processing 1. truncate the response at the stop token, if configured
                    postprocessed_response = response
                    if config.stop_token_id is not None:  # handle the edge case when stop_token_id exists but is 0
                        postprocessed_response = truncate_response(
                            config.stop_token_id, self.tokenizer.pad_token_id, response
                        )

                    # Response Processing 2. run reward model on the truncated responses
                    postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
                    # Index of the token just before the first pad in the
                    # post-processed response (per first_true_indices' name -- verify).
                    sequence_length = first_true_indices(postprocessed_response == self.tokenizer.pad_token_id) - 1
                    # Values come from the untruncated query_response; scores from
                    # the truncated one.
                    full_value, _, _ = get_reward(
                        self.model.value_model, query_response, self.tokenizer.pad_token_id, context_length
                    )
                    value = full_value[:, context_length - 1 : -1].squeeze(-1)
                    _, score, _ = get_reward(
                        self.reward_model, postprocessed_query_response, self.tokenizer.pad_token_id, context_length
                    )

                    responses.append(response)
                    postprocessed_responses.append(postprocessed_response)
                    logprobs.append(logprob)
                    ref_logprobs.append(ref_logprob)
                    sequence_lengths.append(sequence_length)
                    scores.append(score)
                    values.append(value)
                # Stitch the per-slice results back into batch-sized tensors.
                responses = torch.cat(responses, 0)
                postprocessed_responses = torch.cat(postprocessed_responses, 0)
                logprobs = torch.cat(logprobs, 0)
                ref_logprobs = torch.cat(ref_logprobs, 0)
                sequence_lengths = torch.cat(sequence_lengths, 0)
                scores = torch.cat(scores, 0)
                values = torch.cat(values, 0)
                del (logprob, ref_logprob, full_value, value, score)
                # torch.cuda.empty_cache()
                gc.collect()

                # Response Processing 3. penalise responses that contain no EOS token
                contain_eos_token = torch.any(postprocessed_responses == self.tokenizer.eos_token_id, dim=-1)
                if config.missing_eos_penalty is not None:
                    scores[~contain_eos_token] -= config.missing_eos_penalty

                # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
                # padding_mask marks positions after sequence_lengths;
                # padding_mask_p1 starts one position later and is the mask
                # applied to values and rewards.
                response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
                padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
                logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
                ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
                sequence_lengths_p1 = sequence_lengths + 1
                padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
                values = torch.masked_fill(values, padding_mask_p1, 0)

                # 4. compute rewards
                # Per-token KL penalty everywhere, plus the score added at one
                # position per row: sequence_lengths_p1 when that index is inside
                # the response, otherwise sequence_lengths.
                kl = logprobs - ref_logprobs
                non_score_reward = -config.kl_coef * kl
                rewards = non_score_reward.clone()
                actual_start = torch.arange(rewards.size(0), device=rewards.device)
                actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
                rewards[[actual_start, actual_end]] += scores

                # 5. whiten rewards
                rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
                rewards = torch.masked_fill(rewards, padding_mask_p1, 0)

                # 6. compute advantages and returns
                # Generalised advantage estimation, accumulated backwards in time.
                lastgaelam = 0
                advantages_reversed = []
                gen_length = responses.shape[1]
                for t in reversed(range(gen_length)):
                    nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
                    delta = rewards[:, t] + config.gamma * nextvalues - values[:, t]
                    lastgaelam = delta + config.gamma * config.lam * lastgaelam
                    advantages_reversed.append(lastgaelam)
                advantages = torch.stack(advantages_reversed[::-1], axis=1)
                returns = advantages + values
                advantages = masked_whiten(advantages, ~padding_mask)
                advantages = torch.masked_fill(advantages, padding_mask, 0)
                # torch.cuda.empty_cache()

            # ---- Optimisation phase ----
            # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
            for ppo_epoch_idx in range(config.num_ppo_epochs):
                b_inds = np.random.permutation(local_batch_size)
                # NOTE(review): if local_batch_size is not a multiple of
                # gradient_accumulation_steps, the loop below yields more
                # mini-batches than the stats buffers have columns (and a step of
                # 0 when local_batch_size is smaller) -- confirm configs keep it divisible.
                local_batch_size_per_step = local_batch_size // self.config.gradient_accumulation_steps
                gradient_accumulation_idx = 0
                for mini_batch_start in range(0, local_batch_size, local_batch_size_per_step):
                    mini_batch_end = mini_batch_start + local_batch_size_per_step
                    mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]

                    mb_advantage = advantages[mini_batch_inds]
                    mb_responses = responses[mini_batch_inds]
                    mb_query_responses = query_responses[mini_batch_inds]
                    mb_logprobs = logprobs[mini_batch_inds]
                    mb_return = returns[mini_batch_inds]
                    mb_values = values[mini_batch_inds]

                    # Joint forward pass through the policy / value wrapper.
                    output, vpred_temp = forward(self.model, mb_query_responses, self.tokenizer.pad_token_id)
                    logits = output.logits[:, context_length - 1 : -1]
                    logits /= config.temperature + 1e-7
                    new_all_logprobs = F.log_softmax(logits, dim=-1)
                    new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
                    new_logprobs = torch.masked_fill(
                        new_logprobs, padding_mask[mini_batch_inds], INVALID_LOGPROB
                    )
                    # Value loss: squared error against the returns, taking the
                    # larger of the clipped and unclipped variants.
                    vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
                    vpred = torch.masked_fill(vpred, padding_mask_p1[mini_batch_inds], 0)
                    vpredclipped = torch.clamp(
                        vpred,
                        mb_values - config.cliprange_value,
                        mb_values + config.cliprange_value,
                    )
                    vf_losses1 = torch.square(vpred - mb_return)
                    vf_losses2 = torch.square(vpredclipped - mb_return)
                    vf_loss_max = torch.max(vf_losses1, vf_losses2)
                    vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[mini_batch_inds])
                    vf_clipfrac = masked_mean(
                        (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[mini_batch_inds]
                    )
                    # Policy loss: clipped surrogate objective on the probability ratio.
                    logprobs_diff = new_logprobs - mb_logprobs
                    ratio = torch.exp(logprobs_diff)
                    pg_losses = -mb_advantage * ratio
                    pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - config.cliprange, 1.0 + config.cliprange)
                    pg_loss_max = torch.max(pg_losses, pg_losses2)
                    pg_loss = masked_mean(pg_loss_max, ~padding_mask[mini_batch_inds])
                    loss = pg_loss + config.vf_coef * vf_loss
                    # Scale so the accumulated gradient is an average over mini-batches.
                    (loss / self.config.gradient_accumulation_steps).backward()

                    # Diagnostics only; stored per (epoch, accumulation step).
                    with torch.no_grad():
                        pg_clipfrac = masked_mean(
                            (pg_losses2 > pg_losses).float(), ~padding_mask[mini_batch_inds]
                        )
                        prob_dist = torch.nn.functional.softmax(logits, dim=-1)
                        entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
                        approxkl = 0.5 * (logprobs_diff**2).mean()
                        approxkl_stats[ppo_epoch_idx, gradient_accumulation_idx] = approxkl
                        pg_clipfrac_stats[ppo_epoch_idx, gradient_accumulation_idx] = (
                            pg_clipfrac
                        )
                        pg_loss_stats[ppo_epoch_idx, gradient_accumulation_idx] = pg_loss
                        vf_loss_stats[ppo_epoch_idx, gradient_accumulation_idx] = vf_loss
                        vf_clipfrac_stats[ppo_epoch_idx, gradient_accumulation_idx] = (
                            vf_clipfrac
                        )
                        entropy_stats[ppo_epoch_idx, gradient_accumulation_idx] = entropy.mean()
                        ratio_stats[ppo_epoch_idx, gradient_accumulation_idx] = ratio.mean()
                    gradient_accumulation_idx += 1

                    # Drop references to the mini-batch tensors before the next step.
                    del (
                        output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
                        vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
                        pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
                        mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
                    )

                    # torch.cuda.empty_cache()

                # One optimizer step per PPO epoch, after all mini-batches
                # have contributed their gradients.
                self.optimizer.step()
                self.optimizer.zero_grad()

            # ---- Metrics: reduce locally, then average across ranks ----
            with torch.no_grad():
                mean_kl = kl.sum(1).mean()
                mean_entropy = (-logprobs).sum(1).mean()
                mean_non_score_reward = non_score_reward.sum(1).mean()
                rlhf_reward = mean_non_score_reward + scores.mean()
                # Examples per second since training started.
                eps = int(self.example_counter / (time.time() - start_time))
                metrics = {}
                metrics["eps"] = eps
                metrics["objective/kl"] = all_gather_if_needed(mean_kl, self.rank, self.world_size).mean().item()
                metrics["objective/entropy"] = all_gather_if_needed(mean_entropy, self.rank, self.world_size).mean().item()
                metrics["objective/non_score_reward"] = all_gather_if_needed(mean_non_score_reward, self.rank, self.world_size).mean().item()
                metrics["objective/rlhf_reward"] = all_gather_if_needed(rlhf_reward, self.rank, self.world_size).mean().item()
                metrics["objective/scores"] = all_gather_if_needed(scores.mean(), self.rank, self.world_size).mean().item()
                metrics["policy/approxkl_avg"] = all_gather_if_needed(approxkl_stats, self.rank, self.world_size).mean().item()
                metrics["policy/clipfrac_avg"] = all_gather_if_needed(pg_clipfrac_stats, self.rank, self.world_size).mean().item()
                metrics["loss/policy_avg"] = all_gather_if_needed(pg_loss_stats, self.rank, self.world_size).mean().item()
                metrics["loss/value_avg"] = all_gather_if_needed(vf_loss_stats, self.rank, self.world_size).mean().item()
                metrics["val/clipfrac_avg"] = all_gather_if_needed(vf_clipfrac_stats, self.rank, self.world_size).mean().item()
                metrics["policy/entropy_avg"] = all_gather_if_needed(entropy_stats, self.rank, self.world_size).mean().item()
                metrics["val/ratio"] = all_gather_if_needed(ratio_stats, self.rank, self.world_size).mean().item()
                metrics["val/ratio_var"] = all_gather_if_needed(ratio_stats, self.rank, self.world_size).var().item()
                # Counted on this rank only (not gathered).
                metrics["val/num_eos_tokens"] = (responses == self.tokenizer.eos_token_id).sum().item()
                metrics["lr"] = self.scheduler.get_last_lr()[0]
                metrics["episode"] = self.example_counter

            # The learning-rate schedule advances once per batch.
            self.scheduler.step()
            del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, non_score_reward
            # torch.cuda.empty_cache()
            gc.collect()

            # Drop references to the rollout tensors before the next batch.
            del (
                query_responses,
                responses,
                postprocessed_responses,
                logprobs,
                ref_logprobs,
                values,
                sequence_lengths,
                contain_eos_token,
                sequence_lengths_p1,
                response_idxs,
                padding_mask,
                padding_mask_p1,
                rewards,
                actual_start,
                actual_end,
                advantages,
                returns,
            )
            # torch.cuda.empty_cache()

            self.batch_counter += 1
            self.example_counter += self.config.batch_size

            # Rate-limited logging: at most once per minimum_log_interval_secs.
            if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs:
                metrics['counters/examples'] = self.example_counter
                metrics['counters/updates'] = self.batch_counter
                to_print_metrics = {
                    "RLHF_reward": metrics["objective/rlhf_reward"],
                    "KL": metrics["objective/kl"],
                    "rm_score": metrics["objective/scores"],
                    "approx_kl": metrics["policy/approxkl_avg"]
                }
                rank0_print(f'train stats after {self.example_counter} examples: {formatted_dict(to_print_metrics)}')

                if self.config.wandb.enabled and self.rank == 0:
                    wandb.log(metrics, step=self.example_counter)

                last_log = time.time()
            else:
                rank0_print(f'skipping logging after {self.example_counter} examples to avoid logging too frequently')

            del metrics
        
class AdversarialDPOTrainer(BasicTrainer):
    """Trainer whose per-example loss combines a DPO-style preference term on
    the dataset's chosen/rejected pairs with log-ratio terms evaluated on
    responses freshly sampled from the policy and from the reference model.
    """

    # Value written into the label positions of sampled sequences that cover
    # the prompt or padding.
    # NOTE(review): presumably the "ignore" label that concatenated_forward
    # masks out when summing log-probs -- confirm against BasicTrainer.
    LABEL_MASK_VALUE = 50277

    def _sample_responses(self, model, batch: Dict[str, torch.LongTensor]):
        """Sample one completion per prompt in `batch` from `model`.

        Returns whatever `model.generate` returns (prompt followed by the
        sampled continuation, at most `config.max_length` tokens long).
        """
        # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069
        if self.config.is_distributed:
            ctx = FSDP.summon_full_params(model, writeback=False, recurse=False)
        else:
            ctx = contextlib.nullcontext()
        with ctx:
            return model.generate(
                batch['prompt_input_ids'],
                attention_mask=batch['prompt_attention_mask'],
                max_length=self.config.max_length,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

    def _pack_sampled_responses(self, output, prompt_attention_mask, prompt_len: int):
        """Build `(input_ids, attention_mask, labels)` for one set of generations.

        All three tensors are padded along the last dimension to
        `config.max_length`. In `labels`, prompt positions and pad tokens are
        replaced by `LABEL_MASK_VALUE`.
        """
        pad_id = self.tokenizer.pad_token_id
        max_length = self.config.max_length

        # Attention mask: keep the prompt's own mask, attend to every non-pad response token.
        response = output[:, prompt_len:]
        response_mask = (response != pad_id).long()  # B, response_length
        attention_mask = torch.cat((prompt_attention_mask, response_mask), dim=-1)
        attention_mask = pad_to_length(attention_mask, max_length, pad_value=0)

        # Labels: work on a copy so `output` itself stays untouched.
        labels = output.clone()
        labels[:, :prompt_len] = self.LABEL_MASK_VALUE
        labels[labels == pad_id] = self.LABEL_MASK_VALUE
        labels = pad_to_length(labels, max_length, pad_value=self.LABEL_MASK_VALUE)

        input_ids = pad_to_length(output, max_length, pad_id)
        return input_ids, attention_mask, labels

    def _get_batch_samples(self, batch: Dict[str, torch.LongTensor]):
        """Sample responses from the policy and the reference model for the prompts in `batch`.

        Args:
            batch: must contain 'prompt_input_ids' and 'prompt_attention_mask'.

        Returns:
            A dict using the chosen/rejected key naming expected by
            concatenated_inputs: the 'chosen_*' entries hold the policy's
            samples and the 'rejected_*' entries hold the reference model's
            samples.
        """
        # Generate in the same order as before (policy first) so the RNG stream is unchanged.
        policy_output = self._sample_responses(self.policy, batch)
        reference_output = self._sample_responses(self.reference_model, batch)

        prompt_attention_mask = batch['prompt_attention_mask']
        prompt_len = batch['prompt_input_ids'].shape[1]

        policy_output, policy_attn_mask, policy_label = self._pack_sampled_responses(
            policy_output, prompt_attention_mask, prompt_len)
        reference_output, ref_attn_mask, ref_label = self._pack_sampled_responses(
            reference_output, prompt_attention_mask, prompt_len)

        # sanity check: every tensor must share the same sequence length
        seq_lens = [
            t.shape[1]
            for t in (policy_output, policy_attn_mask, policy_label, reference_output, ref_attn_mask, ref_label)
        ]
        assert len(set(seq_lens)) == 1, f"got {', '.join(str(n) for n in seq_lens)}"

        return {  # just to align with the naming used in concatenated_inputs function, so called chosen and rejected
            "chosen_input_ids": policy_output,
            "chosen_attention_mask": policy_attn_mask,
            "chosen_labels": policy_label,
            "rejected_input_ids": reference_output,
            "rejected_attention_mask": ref_attn_mask,
            "rejected_labels": ref_label,
        }

    def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):
        """Compute the adversarial DPO loss and metrics for the given batch of inputs.

        Args:
            batch: preference batch; it is passed to concatenated_forward and
                its prompt tensors are used to sample fresh responses.
            loss_config: config providing `alpha`, the weight of the
                preference term in `loss`.
            train: selects the metric key suffix ('train' or 'eval').

        Returns:
            `(mean_loss, metrics)`, where `metrics['loss/<train|eval>']` is the
            list of per-example losses gathered via all_gather_if_needed.
        """
        metrics = {}
        train_test = 'train' if train else 'eval'

        # Log-probs of the dataset's chosen/rejected responses.
        policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)
        with torch.no_grad():
            reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch)

        sampled_batch = self._get_batch_samples(batch)
        # Explanation for variable names
        # policy_response_policy_logps: \pi (a|x), where a \sim \pi
        # ref_response_policy_logps: \pi (a'|x), where a' \sim \pi_0
        # policy_response_ref_logps: \pi_0(a|x), where a \sim \pi
        # ref_response_ref_logps: \pi_0(a'|x), where a' \sim \pi_0
        # These must be scored on the *sampled* sequences; scoring `batch`
        # again would only duplicate the chosen/rejected log-probs above.
        policy_response_policy_logps, ref_response_policy_logps = self.concatenated_forward(self.policy, sampled_batch)
        with torch.no_grad():
            policy_response_ref_logps, ref_response_ref_logps = self.concatenated_forward(self.reference_model, sampled_batch)

        losses = self.loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            policy_response_policy_logps,
            ref_response_policy_logps,
            policy_response_ref_logps,
            ref_response_ref_logps,
            alpha=loss_config.alpha,
        )

        all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
        metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()

        return losses.mean(), metrics

    def loss(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, policy_response_policy_logps, ref_response_policy_logps, policy_response_ref_logps, ref_response_ref_logps, alpha):
        """Return the per-example loss.

        Computes::

            (log pi(a|x) - log pi_0(a|x))          # a  sampled from the policy
            - (log pi(a'|x) - log pi_0(a'|x))      # a' sampled from the reference
            - alpha * logsigmoid(logits)

        where `logits` is the policy's chosen-minus-rejected log-ratio minus
        the reference model's chosen-minus-rejected log-ratio.
        """
        policy_response_logratios = policy_response_policy_logps - policy_response_ref_logps
        ref_response_logratios = ref_response_policy_logps - ref_response_ref_logps

        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios

        losses = policy_response_logratios - ref_response_logratios - alpha * F.logsigmoid(logits)

        return losses
    

def OfflineRLTrainer():
    """Unimplemented placeholder: takes no arguments, does nothing, returns None.

    NOTE(review): the name reads like a trainer class, but this is currently
    a plain function stub.
    """
    return None

def BRPACTrainer():
    """Unimplemented placeholder: takes no arguments, does nothing, returns None.

    NOTE(review): the name reads like a trainer class, but this is currently
    a plain function stub.
    """
    return None
