import torch

torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn.functional as F
import torch.nn as nn
import transformers
from omegaconf import DictConfig
import logging
from imdb_sentiment import get_pos_sentiment_rewards
import collections

import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    StateDictType,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import torch.distributed as dist
import tensor_parallel as tp
import contextlib

from preference_datasets import get_batch_iterator
from utils import (
    slice_and_move_batch_for_device,
    formatted_dict,
    all_gather_if_needed,
    pad_to_length,
    get_block_class_from_model,
    rank0_print,
    get_local_dir,
)
import numpy as np
import wandb
import tqdm

import random
import os
from collections import defaultdict
import time
import json
import functools
from typing import Optional, Dict, List, Union, Tuple


def preference_loss(policy_chosen_logps: torch.FloatTensor,
                    policy_rejected_logps: torch.FloatTensor,
                    reference_chosen_logps: torch.FloatTensor,
                    reference_rejected_logps: torch.FloatTensor,
                    beta: float,
                    label_smoothing: float = 0.0,
                    ipo: bool = False,
                    reference_free: bool = False,
                    offset: bool = False,
                    ratio: bool = False,
                    alpha: float = 1.,
                    chosen_rewards: torch.FloatTensor = None,
                    rejected_rewards: torch.FloatTensor = None) -> Tuple[
    torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute per-example preference losses and implicit rewards.

    This is the single definition of the loss. Previously a second, reduced
    definition with the same name shadowed the first one, which turned the
    ``ipo`` / ``offset`` / ``ratio`` / ``alpha`` arguments into silent no-ops.
    With all flags at their defaults the result is the (label-smoothed) DPO
    loss, exactly as before.

    Args:
        policy_chosen_logps: Policy log-probs of the chosen responses.
        policy_rejected_logps: Policy log-probs of the rejected responses.
        reference_chosen_logps: Reference log-probs of the chosen responses.
        reference_rejected_logps: Reference log-probs of the rejected responses.
        beta: Scale applied to the log-ratio difference (and to the rewards).
        label_smoothing: Weight of the flipped-label term in the DPO branch;
            0 gives the original DPO loss.
        ipo: If True, use the squared IPO loss instead of the sigmoid loss.
        reference_free: If True, the reference log-ratio is replaced by 0 in
            the loss (the returned rewards still use the reference log-probs).
        offset: If True (and ``ipo`` is False), subtract ``alpha * log(margin)``
            inside the sigmoid, where the margin is built from
            ``chosen_rewards`` and ``rejected_rewards``.
        ratio: With ``offset``, use the ratio chosen/rejected as the margin
            instead of the difference.
        alpha: Multiplier of the log-margin in the offset branch.
        chosen_rewards: Sequence of tensors accepted by ``torch.stack``;
            required when ``offset`` is True.
        rejected_rewards: Same as ``chosen_rewards`` for the rejected side.

    Returns:
        ``(losses, chosen_rewards, rejected_rewards)``. The rewards are the
        detached ``beta * (policy_logps - reference_logps)`` values for the
        chosen and rejected responses respectively.

    Raises:
        ValueError: If ``offset`` is requested without both reward sequences.
    """
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    if reference_free:
        ref_logratios = 0
    else:
        ref_logratios = reference_chosen_logps - reference_rejected_logps

    logits = pi_logratios - ref_logratios  # also known as h_{\pi_\theta}^{y_w,y_l}

    if ipo:
        # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
        losses = (logits - 1 / (2 * beta)) ** 2
    elif offset:
        if chosen_rewards is None or rejected_rewards is None:
            raise ValueError(
                'offset=True requires both chosen_rewards and rejected_rewards')
        stacked_chosen = torch.stack(chosen_rewards, dim=0)
        stacked_rejected = torch.stack(rejected_rewards, dim=0)
        if ratio:
            logging.warning("using ratio")
            margin = stacked_chosen / stacked_rejected
        else:
            margin = stacked_chosen - stacked_rejected
        # NOTE(review): log() yields NaN/-inf for non-positive margins; the
        # caller is presumably expected to supply chosen > rejected -- confirm.
        margin = torch.log(margin.to(logits.device))

        losses = -F.logsigmoid(beta * logits - alpha * margin)
    else:
        # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO
        # (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
        losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                  - F.logsigmoid(-beta * logits) * label_smoothing)

    # Local names avoid clobbering the reward *arguments* used by the offset branch.
    implicit_chosen = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    implicit_rejected = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, implicit_chosen, implicit_rejected


def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor,
                     average_log_prob: bool = False, pad_id: int = -100) -> torch.FloatTensor:
    """Score each sequence's labels under the given next-token logits.

    The labels are shifted left by one position so that the logits at step t
    are matched with the label at step t + 1.

    Args:
        logits: Unnormalized model outputs. Shape: (batch_size, sequence_length, vocab_size)
        labels: Target token ids. Shape: (batch_size, sequence_length). Positions equal to
            ``pad_id`` are excluded from the result.
        average_log_prob: If True, divide the summed log-probability by the number of
            non-masked positions; otherwise return the plain sum.
        pad_id: Label value that marks positions to skip (default -100).

    Returns:
        A tensor of shape (batch_size,) with the summed or averaged log-probabilities.
    """
    assert logits.shape[:-1] == labels.shape

    # Work on a copy so the caller's label tensor is never modified.
    targets = labels[:, 1:].clone()
    shifted_logits = logits[:, :-1, :]
    keep = targets != pad_id  # mask out padding tokens, was == -100 before

    # -100 is not a valid gather index; swap in a dummy id 0.
    targets[targets == -100] = 0

    log_probs = shifted_logits.log_softmax(-1)
    token_logps = torch.gather(log_probs, dim=2, index=targets.unsqueeze(2)).squeeze(2)

    summed = (token_logps * keep).sum(-1)
    if not average_log_prob:
        return summed
    return summed / keep.sum(-1)



def concatenated_inputs(batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[
    str, torch.LongTensor]:
    """Stack the chosen and rejected tensors of a batch along the batch dimension.

    Every tensor entry whose key starts with 'chosen' is padded to the longer of the
    chosen/rejected input lengths and stored under the same key with 'chosen' replaced
    by 'concatenated'. The matching 'rejected' tensors are padded the same way and
    appended below them (dim 0). Keys containing 'labels' are padded with -100, all
    others with 0. Non-tensor entries are ignored.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and
            'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).

    Returns:
        A dictionary with the 'concatenated_*' tensors (chosen rows first, then rejected).
    """
    target_len = max(batch['chosen_input_ids'].shape[1],
                     batch['rejected_input_ids'].shape[1])

    def _padded(key):
        fill = -100 if 'labels' in key else 0
        return pad_to_length(batch[key], target_len, pad_value=fill)

    merged = {}
    # First pass: chosen tensors create the entries.
    for key, value in batch.items():
        if key.startswith('chosen') and isinstance(value, torch.Tensor):
            merged[key.replace('chosen', 'concatenated')] = _padded(key)

    # Second pass: rejected tensors are appended to the entries created above.
    for key, value in batch.items():
        if key.startswith('rejected') and isinstance(value, torch.Tensor):
            out_key = key.replace('rejected', 'concatenated')
            merged[out_key] = torch.cat((merged[out_key], _padded(key)), dim=0)

    return merged



class BasicTrainer(object):
    def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str,
                 reference_model: Optional[nn.Module] = None, rank: int = 0,
                 world_size: int = 1):
        """Set up a trainer for a language model (SFT or DPO-style training).

        Loads the tokenizer named in the config (falling back to the model name),
        uses the EOS token as padding when no pad token is defined, builds the
        train and eval batch iterators, and materializes all eval batches up front.
        The train iterator is built for a single epoch regardless of the config.
        """
        self.seed, self.rank, self.world_size = seed, rank, world_size
        self.config = config
        self.run_dir = run_dir

        tokenizer_source = config.model.tokenizer_name_or_path
        if not tokenizer_source:
            tokenizer_source = config.model.name_or_path
        rank0_print(f'Loading tokenizer {tokenizer_source}')
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer_source, cache_dir=get_local_dir(config.local_dirs))
        self.tokenizer = tokenizer
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        # Arguments shared by the train and eval iterators.
        shared_iterator_kwargs = {
            'names': config.datasets,
            'tokenizer': tokenizer,
            'shuffle': True,
            'max_length': config.max_length,
            'max_prompt_length': config.max_prompt_length,
            'sft_mode': config.loss.name == 'sft',
        }

        self.policy = policy
        self.reference_model = reference_model

        self.train_iterator = get_batch_iterator(
            **shared_iterator_kwargs,
            split='train',
            n_epochs=1,  # config.n_epochs is deliberately not used here
            n_examples=config.n_examples,
            batch_size=config.batch_size,
            silent=rank != 0,
            cache_dir=get_local_dir(config.local_dirs),
            loss=config.loss.name,
        )
        self.eval_iterator = get_batch_iterator(
            **shared_iterator_kwargs,
            split='test',
            n_examples=config.n_eval_examples,
            batch_size=config.eval_batch_size,
            silent=rank != 0,
            cache_dir=get_local_dir(config.local_dirs),
            loss=config.loss.name,
        )

        # The eval set is consumed once and cached; test batches alias the same list.
        eval_batches = list(self.eval_iterator)
        self.eval_batches = eval_batches
        self.test_batches = eval_batches

    def get_batch_samples(self, batch: Dict[str, torch.LongTensor]) -> Tuple[
            List[str], List[str], torch.LongTensor, Optional[torch.LongTensor]]:
        """Sample continuations for the batch prompts from the policy and, for
        preference losses ('dpo', 'ipo', 'odpo'), from the reference model too.

        Returns:
            ``(policy_texts, reference_texts, policy_ids, reference_ids)``.
            The token-id tensors are padded to ``config.max_length`` with the
            tokenizer's pad id. For other losses no reference samples are drawn:
            ``reference_texts`` is ``[]`` and ``reference_ids`` is ``None``
            (previously the return statement raised ``UnboundLocalError``).
        """
        def _sample(model):
            # FSDP generation according to https://github.com/pytorch/pytorch/issues/100069
            if 'FSDP' in self.config.trainer:
                ctx = FSDP.summon_full_params(model, writeback=False, recurse=False)
            else:
                ctx = contextlib.nullcontext()
            with ctx:
                # TODO: maxlen->self.config.max_length
                return model.generate(
                    batch['prompt_input_ids'],
                    attention_mask=batch['prompt_attention_mask'],
                    max_new_tokens=100, do_sample=True,
                    temperature=self.config.temperature,
                    top_p=self.config.topp,
                    pad_token_id=self.tokenizer.pad_token_id)

        uses_reference = self.config.loss.name in {'dpo', 'ipo', 'odpo'}

        policy_output = _sample(self.policy)
        reference_output = _sample(self.reference_model) if uses_reference else None

        policy_output = pad_to_length(policy_output, self.config.max_length,
                                      self.tokenizer.pad_token_id)
        policy_output_decoded = self.tokenizer.batch_decode(policy_output,
                                                            skip_special_tokens=True)

        if uses_reference:
            reference_output = pad_to_length(reference_output, self.config.max_length,
                                             self.tokenizer.pad_token_id)
            reference_output_decoded = self.tokenizer.batch_decode(reference_output,
                                                                   skip_special_tokens=True)
        else:
            reference_output_decoded = []

        return policy_output_decoded, reference_output_decoded, policy_output, reference_output

    def concatenated_forward(self, model: nn.Module,
                             batch: Dict[str, Union[List, torch.LongTensor]], loss: str,
                             k: int = 3) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Score chosen and rejected responses with a single forward pass.

        The chosen and rejected inputs are stacked into one batch, run through
        ``model`` once, and the per-sequence *average* token log-probabilities are
        split back into the chosen part and the rejected part.

        ``loss`` and ``k`` are accepted but not used by this implementation.

        Returns:
            ``(chosen_logps, rejected_logps)``.
        """
        merged = concatenated_inputs(batch)
        outputs = model(merged['concatenated_input_ids'],
                        attention_mask=merged['concatenated_attention_mask'])
        logits_fp32 = outputs.logits.to(torch.float32)
        sequence_logps = _get_batch_logps(logits_fp32, merged['concatenated_labels'],
                                          average_log_prob=True)

        # Chosen rows come first in the stacked batch.
        num_chosen = batch['chosen_input_ids'].shape[0]
        return sequence_logps[:num_chosen], sequence_logps[num_chosen:]


    def get_eval_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]],
                               loss_config: DictConfig, train=True) -> Tuple[
            torch.FloatTensor, Dict, torch.FloatTensor]:
        """Compute the loss and logging metrics for one batch without sample removal.

        Args:
            batch: Batch already sliced and moved to this rank's device.
            loss_config: Loss section of the config; ``name`` must be 'dpo', 'ipo'
                or 'sft'.
            train: Selects the 'train' or 'eval' suffix/prefix of the metric keys.

        Returns:
            ``(mean_loss, metrics, losses)`` where ``metrics`` maps metric names to
            lists of per-example values gathered across ranks and ``losses`` are
            the unfiltered per-example losses of this rank.

        Raises:
            ValueError: If ``loss_config.name`` is not a supported loss.
        """
        metrics = {}
        train_test = 'train' if train else 'eval'

        if loss_config.name in {'ipo', 'dpo'}:
            policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch, 'dpo')
            with torch.no_grad():
                reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch, 'dpo')

            loss_kwargs = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free,
                           'label_smoothing': loss_config.label_smoothing, 'ipo': loss_config.name == 'ipo'}
            losses, chosen_rewards, rejected_rewards = preference_loss(
                policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs)

            # No filtering is applied to preference losses during evaluation.
            filtered_losses = losses
            reward_accuracies = (chosen_rewards > rejected_rewards).float()

            chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size)
            rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size)
            reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size)

            metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist()
            metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist()
            metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist()
            metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist()

            gathered_rejected_logps = all_gather_if_needed(policy_rejected_logps.detach(), self.rank, self.world_size)
            metrics[f'logps_{train_test}/rejected'] = gathered_rejected_logps.cpu().numpy().tolist()

        elif loss_config.name == 'sft':
            policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
            policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)

            # Per-example SFT loss; previously left unbound, which made the
            # return statement below raise UnboundLocalError for 'sft'.
            losses = -policy_chosen_logps

            # Sort by loss, largest first, and drop the `cutoff_index` highest-loss
            # examples. NOTE(review): with <= 4 local examples nothing is left and
            # the mean below is NaN -- confirm the intended cutoff.
            sorted_indices = torch.argsort(-policy_chosen_logps, descending=True)
            cutoff_index = 4
            selected_indices = sorted_indices[cutoff_index:]

            filtered_losses = losses[selected_indices]

        else:
            raise ValueError(f'unsupported loss for evaluation: {loss_config.name!r}')

        gathered_chosen_logps = all_gather_if_needed(policy_chosen_logps.detach(), self.rank, self.world_size)
        metrics[f'logps_{train_test}/chosen'] = gathered_chosen_logps.cpu().numpy().tolist()

        all_devices_losses = all_gather_if_needed(filtered_losses.detach(), self.rank, self.world_size)
        metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()

        return filtered_losses.mean(), metrics, losses

    def get_train_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig,
                                epoch_idx, train=True, remove_num=0):
        """Compute the training loss for one microbatch, dropping selected examples.

        Candidates for removal are the local examples for which
        ``exp(reference_chosen_logps - reference_rejected_logps) < 1``. Their
        losses are collected on rank 0 from all ranks, the ``remove_num`` largest
        values are picked and broadcast, and every local example whose loss value
        equals one of the broadcast values is excluded from the returned loss.

        Caveats visible in the code:
          * matching is by float equality on the loss *value*, so any example with
            an identical loss is dropped too;
          * unused slots of the broadcast buffer stay 0, so a loss that is exactly
            0 is also dropped when fewer than ``remove_num`` candidates exist;
          * the point-to-point and collective calls require an initialized
            ``torch.distributed`` process group.

        Args:
            batch: Batch already sliced and moved to this rank's device.
            loss_config: Loss section of the config; ``name`` must be 'dpo' or 'ipo'.
            epoch_idx: Accepted for interface compatibility; not used here.
            train: Selects the 'train' or 'eval' suffix/prefix of the metric keys.
            remove_num: Number of highest-loss candidates to remove globally.

        Returns:
            ``(mean_remaining_loss, metrics, losses)`` where ``losses`` are all
            per-example losses of this rank before removal.

        Raises:
            ValueError: If ``loss_config.name`` is not 'dpo' or 'ipo'. Previously
                the method fell through and returned ``None``, which crashed the
                caller's tuple unpacking.
        """
        if loss_config.name not in {'ipo', 'dpo'}:
            raise ValueError(
                f'get_train_batch_metrics only supports dpo/ipo losses, got {loss_config.name!r}')

        metrics = {}
        train_test = 'train' if train else 'eval'

        policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch, 'dpo')
        with torch.no_grad():
            reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch, 'dpo')
        loss_kwargs = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free,
                       'label_smoothing': loss_config.label_smoothing, 'ipo': loss_config.name == 'ipo'}
        losses, chosen_rewards, rejected_rewards = preference_loss(
            policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs)

        # Removal candidates: examples where the reference model assigns the
        # chosen response a lower (average) log-prob than the rejected one.
        reference_ratio = reference_chosen_logps - reference_rejected_logps
        condition = (torch.exp(reference_ratio) < 1.0)
        filtered_indices = torch.nonzero(condition, as_tuple=True)[0]
        filtered_losses = losses[filtered_indices] if filtered_indices.numel() > 0 else torch.tensor([], device=losses.device)
        removed_losses = torch.zeros(remove_num, dtype=losses.dtype, device=losses.device)
        logging.warning(f"[DEBUG get_train_batch_metrics]Rank {self.rank}: All filtered losses before all_gather: {filtered_losses.detach().cpu().numpy().tolist()}")

        # Ranks may hold different numbers of candidates, so exchange the shapes
        # first and then ship the variable-length tensors point-to-point to rank 0.
        local_shape = torch.tensor(filtered_losses.shape, device=filtered_losses.device)
        shapes = [torch.empty_like(local_shape) for _ in range(self.world_size)]
        dist.all_gather(shapes, local_shape)

        if self.rank == 0:
            gathered = []
            for src_rank in range(self.world_size):
                if src_rank != self.rank:
                    shape = tuple(shapes[src_rank].cpu().numpy().tolist())
                    received = torch.empty(shape, dtype=filtered_losses.dtype, device=filtered_losses.device)
                    dist.recv(received, src=src_rank)
                else:
                    received = filtered_losses
                gathered.append(received)
            all_filtered_losses = torch.cat([t.flatten() for t in gathered], dim=0)
            logging.warning(f"[DEBUG get_train_batch_metrics] all_filtered_losses: {all_filtered_losses}")

            # Keep the `remove_num` largest candidate losses (fewer if not enough).
            num_to_remove = min(remove_num, all_filtered_losses.numel())
            if num_to_remove > 0:
                order = torch.argsort(all_filtered_losses, descending=True)
                removed_losses[:num_to_remove] = all_filtered_losses[order[:num_to_remove]]
        else:
            dist.send(filtered_losses, dst=0)

        dist.broadcast(removed_losses, src=0)
        logging.warning(f"[DEBUG get_train_batch_metrics] Rank {self.rank}: Removed_losses: {removed_losses}")

        # Drop every local loss equal to one of the broadcast values. This is the
        # vectorized form of the former nested Python loop (one device sync per
        # comparison); the mask is built on detached values, gradients still flow
        # through the kept entries of `losses`.
        is_removed = (losses.detach().unsqueeze(1) == removed_losses.detach().unsqueeze(0)).any(dim=1)
        final_losses = losses[~is_removed]

        logging.warning(f"[DEBUG get_train_batch_metrics]Rank {self.rank}:\n   All training loss:{losses.size(), losses} \n   Removed_loss: {removed_losses}\n   Remained training loss: {final_losses.size(), final_losses}")

        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size)
        rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size)
        reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size)
        logging.warning(f"[DEBUG get_train_batch_metrics]Rank {self.rank}: Final Losses After Removal: {final_losses.detach().mean()}")
        all_devices_losses = all_gather_if_needed(final_losses.detach().mean(), self.rank, self.world_size)
        logging.warning(f"[DEBUG get_train_batch_metrics]Rank {self.rank}: all_devices_losses: {all_devices_losses}")

        metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist()
        metrics[f'loss/{train_test}'] = all_devices_losses.mean().cpu().item()

        return final_losses.mean(), metrics, losses

   

    def calc_kl(self, batch):
        """Return a per-sequence divergence term between reference and policy.

        Both models score ``batch`` (token ids; positions equal to the tokenizer's
        pad id are masked out) and the per-token-averaged log-probabilities are
        compared. With ``r = reference_logps - policy_logps`` the result is
        ``(exp(r) - 1) - r``, which is non-negative and 0 where both agree.

        The reference model is run once; the earlier version ran a second,
        redundant forward pass only to build the log message.
        """
        with torch.no_grad():
            pad_id = self.tokenizer.pad_token_id
            attention_mask = batch != pad_id

            policy_logits = self.policy(batch, attention_mask=attention_mask).logits.to(
                torch.float32)
            policy_logps = _get_batch_logps(policy_logits, batch, average_log_prob=True,
                                            pad_id=pad_id)

            reference_logits = self.reference_model(batch,
                                                    attention_mask=attention_mask).logits
            logging.warning(f'reference policy logits {reference_logits}')
            reference_logps = _get_batch_logps(reference_logits.to(torch.float32), batch,
                                               average_log_prob=True, pad_id=pad_id)

            r = reference_logps - policy_logps

            return (torch.exp(r) - 1) - r

    def train(self):
        """Run one pass over the training data with periodic evaluation.

        Creates the optimizer and a linear-warmup LR schedule, seeds the RNGs,
        rebuilds ``self.train_iterator`` (single epoch, progress output silenced)
        and then, per batch: evaluates on ``self.eval_batches`` when due,
        accumulates gradients over ``gradient_accumulation_steps`` microbatches,
        clips the gradient, steps optimizer and scheduler, and logs averaged
        metrics to wandb from rank 0, rate-limited by ``minimum_log_interval_secs``.
        """
        def _warmup_factor(step):
            # Linear ramp up to 1.0 over `warmup_steps`, constant afterwards.
            return min(1.0, (step + 1) / (self.config.warmup_steps + 1))

        def _flat_mean(values):
            # Metric entries are either numbers or lists of numbers; lists are
            # flattened before taking the mean.
            if isinstance(values[0], list):
                values = [item for sublist in values for item in sublist]
            return sum(values) / len(values)

        rank0_print(f'Using {self.config.optimizer} optimizer')
        optimizer_cls = getattr(torch.optim, self.config.optimizer)
        self.optimizer = optimizer_cls(self.policy.parameters(), lr=self.config.lr)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer,
                                                           lr_lambda=_warmup_factor)

        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        if self.config.loss.name in {'dpo', 'ipo', 'odpo'}:
            self.reference_model.eval()

        self.removed_reference_ratios = []
        self.example_counter = 0
        self.batch_counter = 0
        last_log = None
        epoch_idx = 0

        shared_iterator_kwargs = {
            'names': self.config.datasets,
            'tokenizer': self.tokenizer,
            'shuffle': True,
            'max_length': self.config.max_length,
            'max_prompt_length': self.config.max_prompt_length,
            'sft_mode': self.config.loss.name == 'sft',
        }
        self.train_iterator = get_batch_iterator(
            **shared_iterator_kwargs,
            split='train',
            n_epochs=1,  # config.n_epochs is deliberately not used here
            n_examples=self.config.n_examples,
            batch_size=self.config.batch_size,
            silent=True,
            cache_dir=get_local_dir(self.config.local_dirs),
            loss=self.config.loss.name,
        )

        for batch in self.train_iterator:
            #### BEGIN EVALUATION ####
            eval_due = self.example_counter % self.config.eval_every == 0 and (
                self.example_counter > 0 or self.config.do_first_eval)
            if eval_due:
                logging.warning(
                    f'Running evaluation after {self.example_counter} train examples')
                self.policy.eval()

                all_eval_metrics = defaultdict(list)
                if self.config.sample_during_eval:
                    # NOTE: these containers are created and logged below but are
                    # never filled inside this method.
                    all_policy_samples, all_reference_samples = [], []
                    policy_text_table = wandb.Table(columns=["step", "prompt", "sample"])
                    if self.config.loss.name in {'dpo', 'ipo', 'odpo'}:
                        reference_text_table = wandb.Table(
                            columns=["step", "prompt", "sample"])

                if self.rank == 0:
                    eval_stream = tqdm.tqdm(self.eval_batches, desc='Computing eval metrics')
                else:
                    eval_stream = self.eval_batches

                for eval_batch in eval_stream:
                    local_eval_batch = slice_and_move_batch_for_device(
                        eval_batch, self.rank, self.world_size, self.rank)
                    with torch.no_grad():
                        _, eval_metrics, _ = self.get_eval_batch_metrics(
                            local_eval_batch, self.config.loss, train=False)

                    for name, value in eval_metrics.items():
                        all_eval_metrics[name].append(value)

                mean_eval_metrics = {
                    name: _flat_mean(values) for name, values in all_eval_metrics.items()
                }
                logging.warning(
                    f'eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}')
                if self.config.sample_during_eval:
                    logging.warning(json.dumps(all_policy_samples[:10], indent=2))
                    if self.config.loss.name in {'dpo', 'ipo'}:
                        logging.warning(json.dumps(all_reference_samples[:10], indent=2))

                if self.config.wandb.enabled and self.rank == 0:
                    wandb.log(mean_eval_metrics, step=self.example_counter)

                    if self.config.sample_during_eval:
                        wandb.log({"policy_samples": policy_text_table},
                                  step=self.example_counter)
                        if self.config.loss.name in {'dpo', 'ipo'}:
                            wandb.log({"reference_samples": reference_text_table},
                                      step=self.example_counter)
            #### END EVALUATION ####

            #### BEGIN TRAINING ####
            self.policy.train()

            start_time = time.time()
            batch_metrics = defaultdict(list)
            for microbatch_idx in range(self.config.gradient_accumulation_steps):
                # First slice the batch into the microbatch for this accumulation
                # step, then into this rank's share of that microbatch.
                global_microbatch = slice_and_move_batch_for_device(
                    batch, microbatch_idx, self.config.gradient_accumulation_steps,
                    self.rank)
                local_microbatch = slice_and_move_batch_for_device(
                    global_microbatch, self.rank, self.world_size, self.rank)

                loss, metrics, loss_each = self.get_train_batch_metrics(
                    local_microbatch, self.config.loss,
                    epoch_idx=epoch_idx, train=True, remove_num=self.config.remove_num)

                # Scale so that the accumulated gradient is the mean over microbatches.
                (loss / self.config.gradient_accumulation_steps).backward()

                for name, value in metrics.items():
                    batch_metrics[name].append(value)

            grad_norm = self.clip_gradient()
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

            step_time = time.time() - start_time
            examples_per_second = self.config.batch_size / step_time
            batch_metrics['examples_per_second'].append(examples_per_second)
            batch_metrics['grad_norm'].append(grad_norm)

            self.batch_counter += 1
            self.example_counter += self.config.batch_size

            if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs:
                mean_train_metrics = {
                    name: _flat_mean(values) for name, values in batch_metrics.items()
                }

                mean_train_metrics['counters/examples'] = self.example_counter
                mean_train_metrics['counters/updates'] = self.batch_counter
                if self.config.wandb.enabled and self.rank == 0:
                    wandb.log(mean_train_metrics, step=self.example_counter)

                last_log = time.time()

        


    def clip_gradient(self):
        """Clip the gradient norm of the parameters of a non-FSDP policy."""
        return torch.nn.utils.clip_grad_norm_(self.policy.parameters(),
                                              self.config.max_grad_norm).item()

    def write_state_dict(self, step: int, state: Dict[str, torch.Tensor], metrics: Dict,
                         filename: str, dir_name: Optional[str] = None):
        """Save a checkpoint file containing the step index, state and metrics.

        Args:
            step: Stored under 'step_idx'.
            state: Stored under 'state'.
            metrics: Stored under 'metrics'; ``None`` is stored as an empty dict.
            filename: File name inside the target directory.
            dir_name: Target directory; defaults to ``<run_dir>/LATEST``. It is
                created if it does not exist.
        """
        target_dir = os.path.join(self.run_dir, 'LATEST') if dir_name is None else dir_name
        os.makedirs(target_dir, exist_ok=True)

        output_path = os.path.join(target_dir, filename)
        rank0_print(f'writing checkpoint to {output_path}...')

        checkpoint = {
            'step_idx': step,
            'state': state,
            'metrics': {} if metrics is None else metrics,
        }
        torch.save(checkpoint, output_path)

    def run_classifier_evals(self, policy_samples, reference_samples):
        """Score both sample sets with ``get_pos_sentiment_rewards`` and accumulate the results.

        The two values returned by the classifier helper for each sample set are appended to
        ``self.metrics`` (plain and ``*_binary`` keys), and the samples themselves are appended
        to ``self.all_policy_samples`` / ``self.all_reference_samples``.

        Args:
            policy_samples: Samples generated by the policy.
            reference_samples: Samples generated by the reference model.
        """
        logging.warning("in the tldr eval function")

        # (metrics key, binary metrics key, samples) -- policy is scored before reference.
        scoring_plan = (
            ('policy_rewards', 'policy_rewards_binary', policy_samples),
            ('reference_rewards', 'reference_rewards_binary', reference_samples),
        )
        for reward_key, binary_key, samples in scoring_plan:
            # The first config dataset name is passed through to the classifier helper.
            rewards, binary_rewards = get_pos_sentiment_rewards(samples, self.config.datasets[0])
            self.metrics[reward_key].extend(rewards)
            self.metrics[binary_key].extend(binary_rewards)

        self.all_policy_samples.extend(policy_samples)
        self.all_reference_samples.extend(reference_samples)

    def run_evals(self, example_counter=0, output_file_path="tldr_dpo_drop1.csv"):
        """Sample from the policy and reference model on eval prompts and report the results.

        For every selected batch of ``self.test_batches`` this draws samples through
        ``self.get_batch_samples``, accumulates the values returned by ``self.calc_kl`` under
        the ``'KL'`` metric and, for the 'imdb' / 'toxicity' datasets, the classifier rewards.
        The (prompt, policy sample, reference sample) triples are written to a CSV file and,
        when wandb is enabled, logged from rank 0 together with the per-metric means.

        Side effects: resets ``self.all_policy_samples``, ``self.all_reference_samples`` and
        ``self.metrics``.

        Args:
            example_counter: Used as the wandb ``step`` for everything logged here.
            output_file_path: Destination of the CSV with the generated samples. Defaults to
                the previously hard-coded ``"tldr_dpo_drop1.csv"``.
        """
        logging.warning("in the sample an write function")

        self.all_policy_samples, self.all_reference_samples = [], []
        self.metrics = collections.defaultdict(list)
        text_table = wandb.Table(columns=["prompt", "policy_sample", "reference_sample"])
        results = []

        if self.config.n_eval_model_samples < self.config.eval_batch_size:
            rank0_print(
                f'Warning: n_eval_model_samples ({self.config.n_eval_model_samples}) < eval_batch_size ({self.config.eval_batch_size}). Sampling from the first complete eval batch of prompts.')
            sample_batches = self.test_batches[:1]
        else:
            n_sample_batches = self.config.n_eval_model_samples // self.config.eval_batch_size
            sample_batches = self.test_batches[:n_sample_batches]

        # Only rank 0 shows a progress bar.
        batch_iter = (tqdm.tqdm(sample_batches, desc='Generating samples...')
                      if self.rank == 0 else sample_batches)
        for eval_batch in batch_iter:
            local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank,
                                                               self.world_size, self.rank)
            policy_samples, reference_samples, policy_output, _ = self.get_batch_samples(
                local_eval_batch)
            logging.warning(f'policy samples: {policy_samples}')
            logging.warning(f'policy output {policy_output}')
            self.metrics['KL'].extend(self.calc_kl(policy_output))

            if self.config.datasets[0] in ('imdb', 'toxicity'):
                self.run_classifier_evals(policy_samples, reference_samples)

            for prompt, policy_sample, reference_sample in zip(eval_batch['prompt'],
                                                               policy_samples,
                                                               reference_samples):
                text_table.add_data(prompt, policy_sample, reference_sample)
                results.append({
                    "prompt": prompt,
                    "policy_sample": policy_sample,
                    "reference_sample": reference_sample,
                })
                print("log success")

        # Only rank 0 writes the CSV: with several processes, every rank would otherwise
        # write the same path concurrently. This mirrors the rank-0-only wandb logging below.
        if self.rank == 0:
            import pandas as pd
            pd.DataFrame(results).to_csv(output_file_path, index=False)
            print(f"Generated samples saved to {output_file_path}")

        logging.warning(f'length of policy samples {len(self.all_policy_samples)}')
        # Use .get() here: indexing the defaultdict would insert an empty 'KL' list when no
        # batch was sampled, which previously made the mean below divide by zero.
        logging.warning(f'length of kl {len(self.metrics.get("KL", []))}')

        if self.config.wandb.enabled and self.rank == 0:
            logging.warning('LOG samples')
            wandb.log({"samples": text_table}, step=example_counter)

        logging.warning(f'metrics: {self.metrics}')
        # Skip metrics that collected no values instead of raising ZeroDivisionError.
        mean_eval_metrics = {k: sum(v) / len(v) for k, v in self.metrics.items() if v}
        if self.config.wandb.enabled and self.rank == 0:
            logging.warning(f'LOG RESULTS {mean_eval_metrics}')
            wandb.log(mean_eval_metrics, step=example_counter)

    def save(self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None):
        """Checkpoint the policy, optimizer and scheduler state, in that order.

        Each component's ``state_dict()`` is handed to ``self.write_state_dict`` together
        with ``self.example_counter`` as the step.

        Args:
            output_dir: Forwarded to ``write_state_dict`` as the target directory.
            metrics: Forwarded to ``write_state_dict`` and stored alongside each state.
        """
        # (attribute holding the component, checkpoint file name)
        components = (
            ('policy', 'policy.pt'),
            ('optimizer', 'optimizer.pt'),
            ('scheduler', 'scheduler.pt'),
        )
        for attr_name, filename in components:
            component_state = getattr(self, attr_name).state_dict()
            self.write_state_dict(self.example_counter, component_state, metrics, filename,
                                  output_dir)
            # Drop the reference before the next component's state dict is built, so two
            # state dicts are never kept alive by this method at the same time.
            del component_state


class FSDPTrainer(BasicTrainer):
    def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str,
                 reference_model: Optional[nn.Module] = None, rank: int = 0,
                 world_size: int = 1):
        """Trainer variant that shards its models across GPUs with PyTorch FSDP.

           The policy is always wrapped; the reference model is wrapped only for the
           'dpo', 'ipo' and 'odpo' losses. Wrapping happens per transformer block, where the
           block class is resolved from ``config.model.block_name``.
        """
        super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size)
        assert config.model.block_name is not None, 'must specify model.block_name (e.g., GPT2Block or GPTNeoXLayer) for FSDP'

        block_cls = get_block_class_from_model(policy, config.model.block_name)
        logging.warning(block_cls)

        # Settings shared by the policy and the reference model wrappers.
        fsdp_common = {
            'auto_wrap_policy': functools.partial(transformer_auto_wrap_policy,
                                                  transformer_layer_cls={block_cls}),
            'sharding_strategy': ShardingStrategy.FULL_SHARD,
            'cpu_offload': CPUOffload(offload_params=False),
            'backward_prefetch': BackwardPrefetch.BACKWARD_PRE,
            'device_id': rank,
            'ignored_modules': None,
            'limit_all_gathers': False,
            'use_orig_params': False,
            'sync_module_states': False,
        }

        rank0_print('Sharding policy...')
        # config.model.fsdp_policy_mp names a torch dtype attribute (or is None).
        mp_name = config.model.fsdp_policy_mp
        policy_dtype = None if mp_name is None else getattr(torch, mp_name)
        precision = MixedPrecision(param_dtype=policy_dtype, reduce_dtype=policy_dtype,
                                   buffer_dtype=policy_dtype)
        self.policy = FSDP(policy, mixed_precision=precision, **fsdp_common)

        if config.activation_checkpointing:
            rank0_print('Attempting to enable activation checkpointing...')
            try:
                # Activation checkpointing, following:
                # https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/
                #
                # The import doubles as the availability check for FSDP activation support.
                from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                    checkpoint_wrapper,
                    apply_activation_checkpointing,
                    CheckpointImpl,
                )
                non_reentrant_wrapper = functools.partial(
                    checkpoint_wrapper,
                    offload_to_cpu=False,
                    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                )
            except Exception as err:
                # Best effort: training continues without activation checkpointing.
                rank0_print('FSDP activation checkpointing not available:', err)
            else:
                def is_transformer_block(submodule):
                    return isinstance(submodule, block_cls)

                rank0_print('Applying activation checkpointing wrapper to policy...')
                apply_activation_checkpointing(self.policy,
                                               checkpoint_wrapper_fn=non_reentrant_wrapper,
                                               check_fn=is_transformer_block)
                rank0_print('FSDP activation checkpointing enabled!')

        needs_reference = config.loss.name in {'dpo', 'ipo', 'odpo'}
        if needs_reference:
            rank0_print('Sharding reference model...')
            # The reference model gets the shared settings only (no mixed-precision policy).
            self.reference_model = FSDP(reference_model, **fsdp_common)

        print('Loaded model on rank', rank)
        dist.barrier()

    def clip_gradient(self):
        """Clip the FSDP policy's gradient norm to ``config.max_grad_norm`` using FSDP's own
        ``clip_grad_norm_`` and return the resulting norm as a Python float."""
        total_norm = self.policy.clip_grad_norm_(self.config.max_grad_norm)
        return total_norm.item()

    def save(self, output_dir=None, metrics=None):
        """Gather full policy and optimizer state and write checkpoints from rank 0 only.

        The policy, optimizer and scheduler are handled in that order; every rank waits at a
        barrier after each step.
        """
        is_writer = self.rank == 0

        # --- policy ---
        full_state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT,
                                  state_dict_config=full_state_cfg):
            # Called on every rank, not just the writer.
            policy_state = self.policy.state_dict()
        if is_writer:
            self.write_state_dict(self.example_counter, policy_state, metrics,
                                  'policy.pt', output_dir)
        del policy_state
        dist.barrier()

        # --- optimizer ---
        full_optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT,
                                  optim_state_dict_config=full_optim_cfg):
            optim_state = FSDP.optim_state_dict(self.policy, self.optimizer)
        if is_writer:
            self.write_state_dict(self.example_counter, optim_state, metrics,
                                  'optimizer.pt', output_dir)
        del optim_state
        dist.barrier()

        # --- scheduler (read and written on rank 0 only) ---
        if is_writer:
            self.write_state_dict(self.example_counter, self.scheduler.state_dict(), metrics,
                                  'scheduler.pt', output_dir)
        dist.barrier()


class TensorParallelTrainer(BasicTrainer):
    def __init__(self, policy, config, seed, run_dir, reference_model=None, rank=0,
                 world_size=1):
        """Trainer variant that shards its models across GPUs with the tensor_parallel package.

           Built on https://github.com/BlackSamorez/tensor_parallel. Be aware that sampling is
              extremely slow, see https://github.com/BlackSamorez/tensor_parallel/issues/66.
        """
        super().__init__(policy, config, seed, run_dir, reference_model, rank, world_size)

        rank0_print('Sharding policy...')
        self.policy = tp.tensor_parallel(policy, sharded=True)

        # Only these losses get a tensor-parallel reference model.
        if config.loss.name not in {'dpo', 'ipo', 'odpo'}:
            return

        rank0_print('Sharding reference model...')
        self.reference_model = tp.tensor_parallel(reference_model, sharded=False)

    def save(self, output_dir=None, metrics=None):
        """Write the policy state, read inside ``tp.save_tensor_parallel``, to 'policy.pt'.

        Only the policy is checkpointed here; no optimizer or scheduler state is written.
        """
        with tp.save_tensor_parallel(self.policy):
            unsharded_state = self.policy.state_dict()

        self.write_state_dict(self.example_counter, unsharded_state, metrics, 'policy.pt',
                              output_dir)
        del unsharded_state