import pdb

import torch
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn.functional as F
import torch.nn as nn
import transformers
from omegaconf import DictConfig

import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    StateDictType,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.api import FullStateDictConfig, FullOptimStateDictConfig
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import tensor_parallel as tp
import contextlib

from preference_datasets import get_batch_iterator
from utils import (
    slice_and_move_batch_for_device,
    formatted_dict,
    all_gather_if_needed,
    pad_to_length,
    get_block_class_from_model,
    rank0_print,
    get_local_dir,
)
import numpy as np
import wandb
import tqdm

import random
import os
from collections import defaultdict
import time
import json
import functools
from typing import Optional, Dict, List, Union, Tuple

from Term import Term, calculateTerm
from OrdinalModel import OrdinalModel
from trainers import concatenated_inputs
from torch.optim.lr_scheduler import LambdaLR


def calculateSingleCategoryLoss(term, losses):
    """Collapse a batch of per-example losses into a single value.

    When ``term`` is 0 the result is the arithmetic mean of ``losses``.
    For any other ``term`` the result is
    ``(1 / term) * logsumexp(term * losses)`` reduced over dimension 0.
    No ``log(N)`` normalisation is applied to the log-sum-exp.
    """
    if term != 0:
        # Tilted aggregation: scale, log-sum-exp over the batch, rescale.
        scaledLosses = term * losses
        return 1 / term * torch.logsumexp(scaledLosses, dim=0)

    return losses.mean()


class BasicTrainer(object):
    def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str,
                 rank: int = 0, world_size: int = 1,
                 ordinalModel: Optional[nn.Module] = None):
        """Build the tokenizer, data iterators and loss bookkeeping for reward training.

        Args:
            policy: Model to train. Other methods of this class read its
                ``.logits`` output as rewards.
            config: Run configuration (model, datasets, loss, batch sizes,
                ``term`` settings, ...).
            seed: Random seed, stored and applied later in ``train``.
            run_dir: Directory under which ``train`` writes checkpoints.
            rank: Index of this process; forwarded to the data iterators as
                ``silent=rank != 0``.
            world_size: Number of participating processes.
            ordinalModel: Optional module supplying the thresholds used by the
                ``ordinal`` / ``allThreshold`` losses.

        Raises:
            NotImplementedError: If the constructed ``Term`` reports that it is
                not initialized.
        """
        self.seed = seed
        self.rank = rank
        self.world_size = world_size
        self.config = config
        self.run_dir = run_dir
        self.ordinalModel = ordinalModel

        tokenizer_name_or_path = config.model.tokenizer_name_or_path or config.model.name_or_path
        rank0_print(f'Loading tokenizer {tokenizer_name_or_path}')
        # Credentials must not live in source control: the Hugging Face access
        # token is read from the HF_TOKEN environment variable. When it is unset,
        # ``token=None`` leaves the library to its own default credential lookup.
        # NOTE(review): the token that used to be hard-coded here should be revoked.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            tokenizer_name_or_path,
            token=os.environ.get('HF_TOKEN'),
            cache_dir=get_local_dir(config.local_dirs),
        )
        if self.tokenizer.pad_token_id is None:
            # Reuse EOS as the padding token when the tokenizer defines none.
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Arguments shared by the train and eval iterators.
        data_iterator_kwargs = dict(
            names=config.datasets,
            tokenizer=self.tokenizer,
            shuffle=True,
            max_length=config.max_length,
            max_prompt_length=config.max_prompt_length,
            sft_mode=config.loss.name == 'sft',
        )

        self.policy = policy

        # Only the ordinal-style losses take these two dataset options from the
        # config; every other loss uses the defaults below.
        symmetrizeDataset = False
        makeScoresPositive = True
        if self.config.loss.name in {'ordinal', 'allThreshold'}:
            symmetrizeDataset = self.config.loss.symmetrizeDataset
            makeScoresPositive = self.config.loss.makeScoresPositive

        # Note: flipProbability is only passed to the training iterator.
        self.train_iterator = get_batch_iterator(**data_iterator_kwargs, split='train', n_epochs=config.n_epochs,
                                                 n_examples=config.n_examples, batch_size=config.batch_size,
                                                 silent=rank != 0, cache_dir=get_local_dir(config.local_dirs),
                                                 flipProbability=config.flipProbability,
                                                 makeScoresPositive=makeScoresPositive,
                                                 symmetrizeDataset=symmetrizeDataset
                                                 )
        rank0_print('Loaded train data iterator')

        self.eval_iterator = get_batch_iterator(**data_iterator_kwargs, split='test', n_examples=config.n_eval_examples,
                                                batch_size=config.eval_batch_size, silent=rank != 0,
                                                cache_dir=get_local_dir(config.local_dirs),
                                                makeScoresPositive=makeScoresPositive,
                                                symmetrizeDataset=symmetrizeDataset
                                                )
        # Materialise the eval batches once so every evaluation reuses the same list.
        self.eval_batches = list(self.eval_iterator)
        rank0_print(f'Loaded {len(self.eval_batches)} eval batches of size {config.eval_batch_size}')

        # One Term slot per loss: a single slot unless multi-objective training is configured.
        numberOfLosses = 1 if not self.config.loss.multiObjective else len(self.config.loss.preferences)
        self.term = Term(numberOfLosses, self.config.term, averageTerm=self.config.averageTerm)
        if not self.term.initialized:
            raise NotImplementedError

        # Ordinal update interval control (defaults to every batch when not configured)
        self.ordinal_update_interval = getattr(self.config.loss, 'ordinal_update_interval', 1)
        self.ordinal_update_counter = 0
        rank0_print(f'Ordinal model will be updated every {self.ordinal_update_interval} batches')

        # L2 regularization weight for ordinal parameters (defaults to 0.0 when not configured)
        self.ordinal_l2_weight = getattr(self.config.loss, 'ordinal_l2_weight', 0.0)
        if self.ordinal_l2_weight > 0 and self.ordinalModel is not None:
            rank0_print(f'L2 regularization for ordinal parameters: {self.ordinal_l2_weight}')

    def concatenated_forward(self, batch: Dict[str, Union[List, torch.LongTensor]])\
            -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Score the chosen and rejected sequences with one policy call.

           The batch is merged by ``concatenated_inputs`` so the policy runs a single
           forward pass instead of two (noted by the original author as faster under
           FSDP). The float32 logits are then split back: the first
           ``batch['chosen_input_ids'].shape[0]`` rows are the chosen rewards, the
           remaining rows the rejected ones, each with the last dimension squeezed.
        """
        merged = concatenated_inputs(batch)
        outputs = self.policy(merged['concatenated_input_ids'],
                              attention_mask=merged['concatenated_attention_mask'])
        rewards = outputs.logits.to(torch.float32)

        # Chosen examples come first in the merged batch.
        numChosen = batch['chosen_input_ids'].shape[0]
        return rewards[:numChosen].squeeze(-1), rewards[numChosen:].squeeze(-1)

    def compute_ordinal_performance_metric(self, rewardDifference, batch, metrics, train_test):
        """Record how well thresholded reward differences match the labelled strengths.

        Each reward difference is placed into the interval of the ordinal model's
        thresholds it falls in; the interval index minus ``offset`` is the
        predicted score, which is compared with ``batch['strength']``.

        Args:
            rewardDifference: Chosen-minus-rejected rewards. A 0-dim tensor is
                ignored. Assumed 1-D (one value per example) otherwise — TODO confirm.
            batch: Batch dict; only ``batch['strength']`` is read.
            metrics: Dict updated in place with per-example value lists under
                ``ordinal_performance_{train_test}/mean_absolute_error`` and
                ``ordinal_performance_{train_test}/accuracy_within_{0,1,2}``.
            train_test: Tag interpolated into the metric keys.

        Returns:
            None. Results are written into ``metrics``.
        """
        if rewardDifference.ndim == 0:
            return

        # Get thetas and offset from the ordinal model
        thetas, offset = self.ordinalModel.calculateThetas()
        thetas = thetas.to(rewardDifference.device)

        # Count, for every example, how many thresholds it is greater than or equal to.
        num_greater_equal = (rewardDifference.unsqueeze(1) >= thetas.unsqueeze(0)).float().sum(dim=1)

        # The interval index is that count minus one, clamped to the valid interval range.
        interval_indices = torch.clamp(num_greater_equal - 1, 0, thetas.shape[0] - 2)

        # Undo the index shift to obtain a score comparable with the labels.
        predicted_scores = interval_indices - offset
        actual_scores = torch.tensor(batch['strength'], device=rewardDifference.device, dtype=torch.float32)
        score_diff = torch.abs(predicted_scores - actual_scores)

        # A single gather is enough: the signed differences, predictions and labels
        # were previously gathered as well but never used, which only cost extra
        # collective communication.
        score_diff_gathered = all_gather_if_needed(score_diff.detach(), self.rank, self.world_size)

        # Per-example values are stored; the caller averages them later.
        metrics[f'ordinal_performance_{train_test}/mean_absolute_error'] = score_diff_gathered.cpu().numpy().tolist()

        # Accuracy at different tolerance levels. Thresholding is element-wise, so
        # applying it to the gathered errors equals gathering the local indicators.
        for tolerance in (0, 1, 2):
            accuracy_gathered = (score_diff_gathered <= tolerance).float()
            metrics[f'ordinal_performance_{train_test}/accuracy_within_{tolerance}'] = accuracy_gathered.cpu().numpy().tolist()

    # def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True,
    #                       lossInitializationPhase=False):
    #     # print("TRAIN/TEST --------------------- ", train)
    #     """Compute the SFT or DPO loss and other metrics for the given batch of inputs."""

    #     metrics = {}
    #     train_test = 'train' if train else 'eval'

    #     chosen_rewards, rejected_rewards = self.concatenated_forward(batch)
    #     rewardDifference = chosen_rewards - rejected_rewards

    #     print("SHAPES", chosen_rewards.shape, rejected_rewards.shape, rewardDifference.shape)

    #     # Compute ordinal performance metrics for ordinal and allThreshold losses
    #     if loss_config.name in ["ordinal", "allThreshold"] and self.ordinalModel is not None and not lossInitializationPhase:
    #         self.compute_ordinal_performance_metric(
    #             rewardDifference=rewardDifference,
    #             batch=batch,
    #             metrics=metrics,
    #             train_test=train_test
    #         )

    #     if loss_config.name == "dpo":
    #         # this has nothing to do with dpo. Just didn't want to add any other names to the loss configurations.
    #         losses = - F.logsigmoid(rewardDifference)
    #     elif loss_config.name == "marginBT":
    #         losses = - F.logsigmoid(rewardDifference
    #                                 - torch.tensor(batch['strength']).to(rewardDifference.device))
    #     elif loss_config.name == "scaledBT":
    #         losses = -torch.tensor(batch['strength']).to(rewardDifference.device) * F.logsigmoid(rewardDifference)
    #     elif loss_config.name == "softLabel":
    #         p = torch.tensor([0.65, 0.75, 0.85, 0.95]).to(rewardDifference.device)
    #         strengths = np.array(batch['strength'])
    #         probabilities = p[strengths]
    #         losses = (-probabilities * F.logsigmoid(rewardDifference)
    #                   - (1 - probabilities) * F.logsigmoid(-rewardDifference))
    #     elif loss_config.name == "ordinal":
    #         # Negative Log-Likelihood (NLL) loss for ordinal regression
    #         thetas, offset = self.ordinalModel.calculateThetas()
    #         thetas = thetas.to(rewardDifference.device)
    #         strengths = np.array(batch['strength']) + offset
    #         x = thetas[strengths + 1] - rewardDifference
    #         y = thetas[strengths] - rewardDifference
    #         innerTerm = F.sigmoid(x) - F.sigmoid(y)
    #         losses = -torch.log(innerTerm + 1e-6)

    #         # Enhanced ordinal logging
    #         detachedThetas = thetas[1:-1].detach()
    #         for i, t in enumerate(detachedThetas):
    #             metrics[f'ordinal/thetas_{i}'] = [t.cpu().numpy().tolist()]

    #     elif loss_config.name == "allThreshold":
    #         # All-Threshold loss implementation
    #         if self.ordinalModel is None:
    #             raise ValueError("allThreshold loss requires an OrdinalModel instance.")

    #         thetas, offset = self.ordinalModel.calculateThetas()
    #         thetas = thetas.to(rewardDifference.device)  # shape (2K+1,)
    #         strengths = torch.tensor(batch['strength'], device=rewardDifference.device)
    #         strengths_offset = strengths + offset  # shift so that -K..K maps to 0..2K

    #         # Exclude -inf and +inf endpoints
    #         inner_thetas = thetas[1:-1]  # shape (2K-1,)
    #         num_thresholds = inner_thetas.shape[0]

    #         # Prepare sign matrix ν(l;z)
    #         threshold_indices = torch.arange(num_thresholds, device=rewardDifference.device).unsqueeze(0)  # (1, 2K-1)
    #         sign_matrix = torch.where(threshold_indices < strengths_offset.unsqueeze(1),
    #                                   -1.0,  # l < z  -> -1
    #                                   1.0)   # l >= z -> +1

    #         # Broadcast theta and reward difference for vectorised computation
    #         zetas_expand = inner_thetas.unsqueeze(0).expand(rewardDifference.shape[0], -1)  # (B, 2K-1)
    #         rewards_expand = rewardDifference.unsqueeze(1).expand_as(zetas_expand)  # (B, 2K-1)

    #         # All-threshold logistic penalty
    #         losses_each = -F.logsigmoid(sign_matrix * (zetas_expand - rewards_expand))  # (B, 2K-1)
    #         losses = losses_each.sum(dim=1)  # (B,)

    #         # Enhanced allThreshold logging
    #         detachedThetas = thetas[1:-1].detach()
    #         for i, t in enumerate(detachedThetas):
    #             metrics[f'ordinal/thetas_{i}'] = [t.cpu().numpy().tolist()]
            
    #         # Log per-threshold loss contributions
    #         losses_each_mean = losses_each.mean(dim=0).detach()
    #         for i, loss_contrib in enumerate(losses_each_mean):
    #             metrics[f'allThreshold/loss_contrib_threshold_{i}'] = [loss_contrib.cpu().numpy().tolist()]

    #     else:
    #         raise NotImplementedError(f"Unknown loss name: {loss_config.name}")

    #     # Save losses before regularization for reporting
    #     losses_no_reg = losses.detach().clone()

    #     # Add L2 regularization for ordinal parameters if specified
    #     l2_reg_loss = torch.tensor(0.0, device=losses.device)
    #     if self.ordinal_l2_weight > 0 and self.ordinalModel is not None:
    #         l2_reg_loss = self.ordinal_l2_weight * sum(
    #             param.pow(2).sum() for param in self.ordinalModel.parameters()
    #         )
    #         # Ensure L2 regularization loss is on the same device as main losses
    #         l2_reg_loss = l2_reg_loss.to(losses.device)
    #         losses = losses + l2_reg_loss
            
    #         # Track L2 regularization loss
    #         l2_reg_gathered = all_gather_if_needed(l2_reg_loss.detach().unsqueeze(0), self.rank, self.world_size)
    #         metrics[f'regularization_{train_test}/l2_ordinal'] = l2_reg_gathered.cpu().numpy().tolist()

    #     chosen_rewards = chosen_rewards.detach()
    #     rejected_rewards = rejected_rewards.detach()

    #     reward_accuracies = (chosen_rewards > rejected_rewards).float()

    #     if loss_config.name in {"ordinal", "allThreshold"}:
    #         strengthTensor = torch.tensor(batch['strength']) * 1.0
    #         strengthTensor = strengthTensor.to(reward_accuracies.device)
    #         if not reward_accuracies.shape == strengthTensor.shape:
    #             strengthTensor = torch.zeros_like(reward_accuracies)
    #         multipliedRewards = (2 * reward_accuracies - 1) * strengthTensor
    #         multipliedRewards[strengthTensor == 0] = reward_accuracies[strengthTensor == 0]  # so not to skew the results because of those with 0 strength
    #         multipliedRewards[multipliedRewards > 0] = 1.
    #         multipliedRewards[multipliedRewards < 0] = 0.
    #         reward_accuracies = multipliedRewards

    #     # Gather metrics across devices
    #     chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size)
    #     rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size)
    #     reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size)

    #     metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist()
    #     metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist()
    #     metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist()
    #     metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist()

    #     # Report loss without regularization
    #     all_devices_losses_no_reg = all_gather_if_needed(losses_no_reg, self.rank, self.world_size)
    #     metrics[f'loss_no_reg/{train_test}'] = all_devices_losses_no_reg.cpu().numpy().tolist()

    #     # Report loss with regularization (if any)
    #     all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
    #     metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()

    #     if loss_config.multiObjective:
    #         raise NotImplementedError
    #     else:
    #         if lossInitializationPhase:
    #             return [losses]
            
    #         # Compute Term loss on non-regularized losses separately for clean reporting
    #         # finalLoss_no_reg = calculateSingleCategoryLoss(self.term.specificTerms[0], losses_no_reg)
    #         # Update Term with regularized losses for actual training
    #         finalLoss = self.term.updateSpecificTerm(0, losses)
            
    #         # Term loss logging (without regularization)
    #         # all_devices_final_loss_no_reg = all_gather_if_needed(finalLoss_no_reg.unsqueeze(0).detach(), self.rank, self.world_size)
    #         # metrics[f'TermLoss_no_reg/{train_test}'] = all_devices_final_loss_no_reg.cpu().numpy().tolist()
            
    #         # Term loss logging (with regularization)
    #         all_devices_final_loss = all_gather_if_needed(finalLoss.unsqueeze(0).detach(), self.rank, self.world_size)
    #         metrics[f'TermLoss/{train_test}'] = all_devices_final_loss.cpu().numpy().tolist()

    #     return finalLoss, 


    def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True,
                        lossInitializationPhase=False):
        """Compute the preference loss and logging metrics for one batch.

        Rewards for the chosen and rejected sequences come from
        ``concatenated_forward``; their difference ``diff`` feeds the per-example
        loss selected by ``loss_config.name``:

        * ``dpo``: ``-logsigmoid(diff)`` (a plain pairwise loss despite the name).
        * ``marginBT``: ``-logsigmoid(diff - strength)``.
        * ``scaledBT``: ``-strength * logsigmoid(diff)``.
        * ``softLabel``: cross-entropy against a soft target looked up in a fixed
          four-entry probability table indexed by ``strength``.
        * ``ordinal``: ``-log(sigmoid(theta[s + 1] - diff) - sigmoid(theta[s] - diff) + 1e-6)``
          with ``s = strength + offset``.
        * ``allThreshold``: sum over the inner thresholds of a signed logistic penalty.

        Args:
            batch: Batch passed to ``concatenated_forward``. Every loss except
                ``dpo`` also reads ``batch['strength']`` (presumably one integer
                per example — verify against the data iterator).
            loss_config: Loss section of the config; ``name`` and
                ``multiObjective`` are read here.
            train: Selects the ``train`` / ``eval`` tag used in metric keys.
            lossInitializationPhase: When true, the ordinal performance metrics are
                skipped and the raw per-example losses are returned early.

        Returns:
            ``[losses]`` (a one-element list holding the per-example loss tensor)
            when ``lossInitializationPhase`` is true. Otherwise a tuple
            ``(finalLoss, metrics)``: ``finalLoss`` is the value returned by
            ``self.term.updateSpecificTerm(0, losses)`` plus the optional L2 penalty
            on the ordinal model's parameters, and ``metrics`` maps metric names to
            lists of Python numbers.

        Raises:
            NotImplementedError: If ``loss_config.name`` is unknown or
                ``loss_config.multiObjective`` is truthy.
            ValueError: If ``allThreshold`` is requested without an ordinal model.
        """

        metrics = {}
        train_test = 'train' if train else 'eval'

        chosen_rewards, rejected_rewards = self.concatenated_forward(batch)
        rewardDifference = chosen_rewards - rejected_rewards

        # Compute ordinal performance metrics for ordinal and allThreshold losses
        if loss_config.name in ["ordinal", "allThreshold"] and self.ordinalModel is not None and not lossInitializationPhase:
            self.compute_ordinal_performance_metric(
                rewardDifference=rewardDifference,
                batch=batch,
                metrics=metrics,
                train_test=train_test
            )

        if loss_config.name == "dpo":
            # this has nothing to do with dpo. Just didn't want to add any other names to the loss configurations.
            losses = - F.logsigmoid(rewardDifference)
        elif loss_config.name == "marginBT":
            # The labelled strength acts as a margin subtracted from the reward difference.
            losses = - F.logsigmoid(rewardDifference
                                    - torch.tensor(batch['strength']).to(rewardDifference.device))
        elif loss_config.name == "scaledBT":
            # The labelled strength scales the pairwise loss of each example.
            losses = -torch.tensor(batch['strength']).to(rewardDifference.device) * F.logsigmoid(rewardDifference)
        elif loss_config.name == "softLabel":
            # Fixed table of target probabilities, indexed by strength
            # (so strengths are assumed to lie in 0..3 here — TODO confirm).
            p = torch.tensor([0.65, 0.75, 0.85, 0.95]).to(rewardDifference.device)
            strengths = np.array(batch['strength'])
            probabilities = p[strengths]
            losses = (-probabilities * F.logsigmoid(rewardDifference)
                    - (1 - probabilities) * F.logsigmoid(-rewardDifference))
        elif loss_config.name == "ordinal":
            # Negative Log-Likelihood (NLL) loss for ordinal regression
            thetas, offset = self.ordinalModel.calculateThetas()
            thetas = thetas.to(rewardDifference.device)
            strengths = np.array(batch['strength']) + offset
            # Probability mass between the two thresholds bracketing the labelled class.
            x = thetas[strengths + 1] - rewardDifference
            y = thetas[strengths] - rewardDifference
            innerTerm = F.sigmoid(x) - F.sigmoid(y)
            # The 1e-6 keeps the log finite when the mass is zero.
            losses = -torch.log(innerTerm + 1e-6)

            # Log the current inner thresholds (first and last entries are excluded)
            detachedThetas = thetas[1:-1].detach()
            for i, t in enumerate(detachedThetas):
                metrics[f'ordinal/zeta_{i}'] = [t.cpu().numpy().tolist()]

        elif loss_config.name == "allThreshold":
            # All-Threshold loss implementation
            if self.ordinalModel is None:
                raise ValueError("allThreshold loss requires an OrdinalModel instance.")

            thetas, offset = self.ordinalModel.calculateThetas()
            thetas = thetas.to(rewardDifference.device)  # shape (2K+1,)
            strengths = torch.tensor(batch['strength'], device=rewardDifference.device)
            strengths_offset = strengths + offset  # shift so that -K..K maps to 0..2K

            # Exclude -inf and +inf endpoints
            inner_thetas = thetas[1:-1]  # shape (2K-1,)
            num_thresholds = inner_thetas.shape[0]

            # Prepare sign matrix ν(l;z): -1 for thresholds whose index is below the
            # shifted strength, +1 for the others.
            threshold_indices = torch.arange(num_thresholds, device=rewardDifference.device).unsqueeze(0)  # (1, 2K-1)
            sign_matrix = torch.where(threshold_indices < strengths_offset.unsqueeze(1),
                                    -1.0,  # l < z  -> -1
                                    1.0)   # l >= z -> +1

            # Broadcast theta and reward difference for vectorised computation
            zetas_expand = inner_thetas.unsqueeze(0).expand(rewardDifference.shape[0], -1)  # (B, 2K-1)
            rewards_expand = rewardDifference.unsqueeze(1).expand_as(zetas_expand)  # (B, 2K-1)

            # All-threshold logistic penalty
            losses_each = -F.logsigmoid(sign_matrix * (zetas_expand - rewards_expand))  # (B, 2K-1)
            losses = losses_each.sum(dim=1)  # (B,)

            # Log the current inner thresholds
            detachedThetas = thetas[1:-1].detach()
            for i, t in enumerate(detachedThetas):
                metrics[f'ordinal/zeta_{i}'] = [t.cpu().numpy().tolist()]

            # Log per-threshold loss contributions (batch mean of each threshold's penalty)
            losses_each_mean = losses_each.mean(dim=0).detach()
            for i, loss_contrib in enumerate(losses_each_mean):
                metrics[f'allThreshold/loss_contrib_threshold_{i}'] = [loss_contrib.cpu().numpy().tolist()]

        else:
            raise NotImplementedError(f"Unknown loss name: {loss_config.name}")

        # Save losses before any modifications for reporting (these are the clean losses)
        losses_no_reg = losses.detach().clone()

        chosen_rewards = chosen_rewards.detach()
        rejected_rewards = rejected_rewards.detach()

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if loss_config.name in {"ordinal", "allThreshold"}:
            # With signed strengths an example is correct when the sign of the reward
            # difference agrees with the sign of its strength.
            strengthTensor = torch.tensor(batch['strength']) * 1.0
            strengthTensor = strengthTensor.to(reward_accuracies.device)
            if not reward_accuracies.shape == strengthTensor.shape:
                # Shapes disagree: use all-zero strengths, which makes the steps below
                # fall back to the plain chosen > rejected accuracy.
                strengthTensor = torch.zeros_like(reward_accuracies)
            # Map accuracy {0, 1} to {-1, +1} and multiply by the strength: the product
            # is positive exactly when the signs agree.
            multipliedRewards = (2 * reward_accuracies - 1) * strengthTensor
            multipliedRewards[strengthTensor == 0] = reward_accuracies[strengthTensor == 0]  # so not to skew the results because of those with 0 strength
            # Binarise in place; entries equal to 0 are left as 0.
            multipliedRewards[multipliedRewards > 0] = 1.
            multipliedRewards[multipliedRewards < 0] = 0.
            reward_accuracies = multipliedRewards

        # Gather metrics across devices
        chosen_rewards = all_gather_if_needed(chosen_rewards, self.rank, self.world_size)
        rejected_rewards = all_gather_if_needed(rejected_rewards, self.rank, self.world_size)
        reward_accuracies = all_gather_if_needed(reward_accuracies, self.rank, self.world_size)

        metrics[f'rewards_{train_test}/chosen'] = chosen_rewards.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/rejected'] = rejected_rewards.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/accuracies'] = reward_accuracies.cpu().numpy().tolist()
        metrics[f'rewards_{train_test}/margins'] = (chosen_rewards - rejected_rewards).cpu().numpy().tolist()

        # Report loss without regularization (clean losses)
        all_devices_losses_no_reg = all_gather_if_needed(losses_no_reg, self.rank, self.world_size)
        metrics[f'loss_no_reg/{train_test}'] = all_devices_losses_no_reg.cpu().numpy().tolist()

        # `losses` is not modified after the snapshot above (L2 regularization is only
        # added to the final scalar loss below), so `loss/...` currently carries the
        # same values as `loss_no_reg/...`.
        all_devices_losses = all_gather_if_needed(losses.detach(), self.rank, self.world_size)
        metrics[f'loss/{train_test}'] = all_devices_losses.cpu().numpy().tolist()

        if loss_config.multiObjective:
            raise NotImplementedError
        else:
            if lossInitializationPhase:
                # Early exit: only the per-example losses are returned, wrapped in a
                # list; the metrics collected so far are discarded.
                return [losses]

            # Aggregate the per-example losses through Term slot 0.
            # NOTE(review): the method name suggests this also updates Term state — confirm.
            finalLoss_no_reg = self.term.updateSpecificTerm(0, losses)

            # NOW add L2 regularization to the final scalar loss (not per-example losses)
            l2_reg_loss = torch.tensor(0.0, device=finalLoss_no_reg.device)
            if self.ordinal_l2_weight > 0 and self.ordinalModel is not None:
                # Weighted sum of squares over every ordinal-model parameter.
                l2_reg_loss = self.ordinal_l2_weight * sum(
                    param.pow(2).sum() for param in self.ordinalModel.parameters()
                )
                l2_reg_loss = l2_reg_loss.to(finalLoss_no_reg.device)

                # Track L2 regularization loss
                l2_reg_gathered = all_gather_if_needed(l2_reg_loss.detach().unsqueeze(0), self.rank, self.world_size)
                metrics[f'regularization_{train_test}/l2_ordinal'] = l2_reg_gathered.cpu().numpy().tolist()

            # Final loss for backprop (includes L2 reg)
            finalLoss = finalLoss_no_reg + l2_reg_loss

            # Term loss logging (without regularization)
            all_devices_final_loss_no_reg = all_gather_if_needed(finalLoss_no_reg.unsqueeze(0).detach(), self.rank, self.world_size)
            metrics[f'TermLoss_no_reg/{train_test}'] = all_devices_final_loss_no_reg.cpu().numpy().tolist()

            # Term loss logging (with regularization)
            all_devices_final_loss = all_gather_if_needed(finalLoss.unsqueeze(0).detach(), self.rank, self.world_size)
            metrics[f'TermLoss/{train_test}'] = all_devices_final_loss.cpu().numpy().tolist()

        return finalLoss, metrics


    def train(self):
        """Begin either SFT or DPO training, with periodic evaluation."""

        rank0_print(f'Using {self.config.optimizer} optimizer')
        self.optimizer = getattr(torch.optim, self.config.optimizer)(self.policy.parameters(), lr=self.config.lr)
        self.scheduler = (
            torch.optim.lr_scheduler.LambdaLR(self.optimizer,
                                              lr_lambda=lambda step:
                                              min(1.0, (step + 1) / (self.config.warmup_steps + 1))))

        self.ordinalOptimizer = None
        self.ordinalScheduler = None

        if self.ordinalModel is not None:
            self.ordinalOptimizer = getattr(torch.optim, self.config.optimizer)(self.ordinalModel.parameters(),
                                                                                lr=self.config.loss.ordinalLr)
            # self.ordinalScheduler = torch.optim.lr_scheduler.ExponentialLR(self.ordinalOptimizer,
                                                                        #    gamma=self.config.loss.schedulerGamma)
            self.ordinalScheduler = LambdaLR(self.ordinalOptimizer, lr_lambda=lambda _: 1.0)
    
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        random.seed(self.seed)

        self.example_counter = 0
        self.batch_counter = 0
        last_log = None

        for batch in self.train_iterator:

            #### BEGIN EVALUATION ####
            if self.example_counter % self.config.eval_every == 0 and (self.example_counter > 0 or self.config.do_first_eval):
                rank0_print(f'Running evaluation after {self.example_counter} train examples')

                self.policy.eval()
                all_eval_metrics = defaultdict(list)
                if self.config.sample_during_eval:
                    raise NotImplementedError(
                        'Sampling during eval is not for implemented for reward training pipeline!')

                for eval_batch in (tqdm.tqdm(self.eval_batches, desc='Computing eval metrics') if self.rank == 0 else self.eval_batches):
                    local_eval_batch = slice_and_move_batch_for_device(eval_batch, self.rank, self.world_size, self.rank)
                    with torch.no_grad():
                        _, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False)

                    for k, v in eval_metrics.items():
                        all_eval_metrics[k].extend(v)

                mean_eval_metrics = {k: sum(v) / len(v) for k, v in all_eval_metrics.items()}

                rank0_print(f'eval after {self.example_counter}: {formatted_dict(mean_eval_metrics)}')

                if self.config.wandb.enabled and self.rank == 0:
                    wandb.log(mean_eval_metrics, step=self.example_counter)

                if self.example_counter > 0 and self.config.saveCheckpoints:
                    if self.config.debug:
                        rank0_print('skipping save in debug mode')
                    else:
                        output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}')
                        rank0_print(f'creating checkpoint to write to {output_dir}...')
                        self.save(output_dir, mean_eval_metrics)
            #### END EVALUATION ####

            #### BEGIN TRAINING ####
            self.policy.train()

            start_time = time.time()
            batch_metrics = defaultdict(list)
            
            # ------ MICRO-BATCH (gradient accumulation) LOOP ------
            for microbatch_idx in range(self.config.gradient_accumulation_steps):
                global_microbatch = slice_and_move_batch_for_device(batch, microbatch_idx, self.config.gradient_accumulation_steps, self.rank)
                local_microbatch = slice_and_move_batch_for_device(global_microbatch, self.rank, self.world_size, self.rank)
                loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
                (loss / self.config.gradient_accumulation_steps).backward()

                for k, v in metrics.items():
                    batch_metrics[k].extend(v)

            # -------- AFTER ACCUMULATION --------
            # 1️⃣ Compute full raw norm BEFORE clipping (policy)
            raw_policy_grad_norm = torch.sqrt(
                sum(
                    (p.grad.detach() ** 2).sum()
                    for p in self.policy.parameters()
                    if p.grad is not None
                )
            ).item()
            batch_metrics['grad_norm_raw/policy'].append(raw_policy_grad_norm)

            # 2️⃣ Clip gradients and record the *pre-clip* norm (as returned by PyTorch)
            pre_clip_norm = torch.nn.utils.clip_grad_norm_(
                self.policy.parameters(), self.config.max_grad_norm
            ).item()
            batch_metrics['grad_norm/policy_pre_clip'].append(pre_clip_norm)

            # 3️⃣ Compute the full norm AFTER clipping (what will be used by optimizer)
            post_clip_norm = torch.sqrt(
                sum(
                    (p.grad.detach() ** 2).sum()
                    for p in self.policy.parameters()
                    if p.grad is not None
                )
            ).item()
            batch_metrics['grad_norm/policy_clipped'].append(post_clip_norm)

            # --- Handle ordinal model gradients separately if it exists ---
            if self.ordinalModel is not None:
                ord_grad_norm = torch.sqrt(
                    sum(
                        (p.grad.detach() ** 2).sum()
                        for p in self.ordinalModel.parameters()
                        if p.grad is not None
                    )
                ).item()
                batch_metrics['grad_norm/ordinal'].append(ord_grad_norm)

            # Determine if we should update ordinal model this batch
            self.ordinal_update_counter += 1
            should_update_ordinal = (self.ordinal_update_counter % self.ordinal_update_interval == 0)
            
            # Track ordinal update decisions
            if self.ordinalModel is not None:
                batch_metrics['ordinal_updates/should_update'].append(1.0 if should_update_ordinal else 0.0)
                batch_metrics['ordinal_updates/update_counter'].append(self.ordinal_update_counter)

            # optimizer + scheduler step for both policy and ordinal
            for optimizer, scheduler, name in zip(
                [self.optimizer, self.ordinalOptimizer],
                [self.scheduler, self.ordinalScheduler],
                ['policy', 'ordinal']
            ):
                if optimizer is None:
                    continue
                
                # For ordinal model, only update if it's time; otherwise zero gradients
                if name == 'ordinal':
                    if should_update_ordinal:
                        optimizer.step()
                        scheduler.step()
                        batch_metrics['ordinal_updates/actual_update'].append(1.0)
                    else:
                        # Zero out gradients without updating
                        optimizer.zero_grad()
                        batch_metrics['ordinal_updates/actual_update'].append(0.0)
                else:
                    # Always update policy
                    optimizer.step()
                    scheduler.step()

                # --- PARAMETER NORMS (always report for consistency) ---
                if name == 'policy':
                    param_norm = torch.sqrt(
                        sum(
                            (p.data ** 2).sum()
                            for p in self.policy.parameters()
                        )
                    ).item()
                    batch_metrics['param_norm/policy'].append(param_norm)
                elif name == 'ordinal':
                    # Always compute and report ordinal param norm for consistency
                    param_norm_ord = torch.sqrt(
                        sum(
                            (p.data ** 2).sum()
                            for p in self.ordinalModel.parameters()
                        )
                    ).item()
                    batch_metrics['param_norm/ordinal'].append(param_norm_ord)
                    
                    # --- LR MULTIPLIER (always report for consistency) ---
                    last_lr = scheduler.get_last_lr()[-1]
                    multiplier = last_lr / self.config.loss.ordinalLr
                    batch_metrics['lr_multiplier/ordinal'].append(multiplier)

                # LR logging (always report for consistency)
                last_lr_val = (
                    scheduler.get_last_lr()[0]
                    if isinstance(scheduler, torch.optim.lr_scheduler.LambdaLR)
                    else scheduler.get_last_lr()[-1]
                )
                batch_metrics[f'lr/{name}'].append(last_lr_val)

                # Zero gradients (for policy always, for ordinal only if we updated)
                if name == 'policy' or (name == 'ordinal' and should_update_ordinal):
                    optimizer.zero_grad()

            step_time = time.time() - start_time
            examples_per_second = self.config.batch_size / step_time
            batch_metrics['examples_per_second'].append(examples_per_second)

            self.batch_counter += 1
            self.example_counter += self.config.batch_size

            if last_log is None or time.time() - last_log > self.config.minimum_log_interval_secs:
                mean_train_metrics = {k: sum(v) / len(v) for k, v in batch_metrics.items()}

                mean_train_metrics['counters/examples'] = self.example_counter
                mean_train_metrics['counters/updates'] = self.batch_counter

                if self.config.wandb.enabled and self.rank == 0:
                    wandb.log(mean_train_metrics, step=self.example_counter)

                last_log = time.time()
            #### END TRAINING ####


    def clip_gradient(self):
        """Clip the policy's gradients in place (non-FSDP path).

        Hands ``self.policy.parameters()`` and ``self.config.max_grad_norm`` to
        ``torch.nn.utils.clip_grad_norm_``.

        Returns:
            The total gradient norm reported by ``clip_grad_norm_``, converted
            to a plain Python number via ``.item()``.
        """
        policy_params = self.policy.parameters()
        norm_limit = self.config.max_grad_norm
        total_norm = torch.nn.utils.clip_grad_norm_(policy_params, norm_limit)
        return total_norm.item()

    def write_state_dict(self, step: int, state: Dict[str, torch.Tensor], metrics: Dict, filename: str, dir_name: Optional[str] = None):
        """Serialize a single checkpoint file with ``torch.save``.

        Args:
            step: Stored under the ``'step_idx'`` key of the saved payload.
            state: Stored under the ``'state'`` key.
            metrics: Stored under the ``'metrics'`` key; ``None`` is replaced
                by an empty dict.
            filename: Name of the checkpoint file inside the target directory.
            dir_name: Target directory, created if it does not exist. When
                ``None``, ``<self.run_dir>/LATEST`` is used.
        """
        target_dir = os.path.join(self.run_dir, 'LATEST') if dir_name is None else dir_name
        os.makedirs(target_dir, exist_ok=True)

        checkpoint_path = os.path.join(target_dir, filename)
        rank0_print(f'writing checkpoint to {checkpoint_path}...')

        payload = {
            'step_idx': step,
            'state': state,
            'metrics': {} if metrics is None else metrics,
        }
        torch.save(payload, checkpoint_path)
    
    def save(self, output_dir: Optional[str] = None, metrics: Optional[Dict] = None):
        """Checkpoint the policy weights and, when present, the ordinal model weights.

        The policy's ``state_dict()`` is written as ``rewardModel.pt``; if
        ``self.ordinalModel`` is not ``None`` its ``state_dict()`` is written as
        ``ordinalModel.pt``. Both go through ``write_state_dict`` with
        ``self.example_counter`` as the step.

        Args:
            output_dir: Forwarded to ``write_state_dict`` as the target directory.
            metrics: Forwarded to ``write_state_dict`` and stored with each file.
        """
        step = self.example_counter

        weights = self.policy.state_dict()
        self.write_state_dict(step, weights, metrics, 'rewardModel.pt', output_dir)
        del weights  # drop the reference before building the next state dict

        # Nothing more to write when no ordinal model is attached.
        if self.ordinalModel is None:
            return

        weights = self.ordinalModel.state_dict()
        self.write_state_dict(step, weights, metrics, 'ordinalModel.pt', output_dir)
        del weights


class FSDPTrainer(BasicTrainer):
    """BasicTrainer variant whose policy is wrapped in PyTorch FSDP."""

    def __init__(self, policy: nn.Module, config: DictConfig, seed: int, run_dir: str,
                 rank: int = 0, world_size: int = 1,
                 ordinalModel: Optional[nn.Module] = None):
        """Initialise the base trainer, then wrap the policy with FSDP.

        The policy is wrapped with a transformer auto-wrap policy keyed on the
        block class named by ``config.model.block_name`` (required). The
        ``ordinalModel`` is only forwarded to the base class; it is not wrapped
        here. When ``config.activation_checkpointing`` is truthy, activation
        checkpointing is applied to the policy's blocks on a best-effort basis.
        The constructor ends with ``dist.barrier()``.
        """

        super().__init__(policy, config, seed, run_dir, rank, world_size, ordinalModel)
        assert config.model.block_name is not None, 'must specify model.block_name (e.g., GPT2Block or GPTNeoXLayer) for FSDP'

        block_cls = get_block_class_from_model(policy, config.model.block_name)
        wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})

        fsdp_kwargs = {
            'auto_wrap_policy': wrap_policy,
            'sharding_strategy': ShardingStrategy.FULL_SHARD,
            'cpu_offload': CPUOffload(offload_params=False),
            'backward_prefetch': BackwardPrefetch.BACKWARD_PRE,
            'device_id': rank,
            'ignored_modules': None,
            'limit_all_gathers': False,
            'use_orig_params': config.peft,
            'sync_module_states': False,
        }

        rank0_print('Sharding policy...')
        # config.model.fsdp_policy_mp names a torch dtype attribute (or is None).
        precision_name = config.model.fsdp_policy_mp
        if precision_name is not None:
            precision = getattr(torch, precision_name)
        else:
            precision = None
        mixed = MixedPrecision(param_dtype=precision, reduce_dtype=precision, buffer_dtype=precision)
        self.policy = FSDP(policy, mixed_precision=mixed, **fsdp_kwargs)

        if config.activation_checkpointing:
            rank0_print('Attempting to enable activation checkpointing...')
            # Stays None if the checkpointing utilities cannot be set up.
            wrapper_fn = None
            try:
                from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                    checkpoint_wrapper,
                    apply_activation_checkpointing,
                    CheckpointImpl,
                )
                wrapper_fn = functools.partial(
                    checkpoint_wrapper,
                    offload_to_cpu=False,
                    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
                )
            except Exception as err:
                # Best effort: report the failure and continue without checkpointing.
                rank0_print('FSDP activation checkpointing not available:', err)

            if wrapper_fn is not None:
                def is_block(submodule):
                    return isinstance(submodule, block_cls)

                rank0_print('Applying activation checkpointing wrapper to policy...')
                apply_activation_checkpointing(self.policy, checkpoint_wrapper_fn=wrapper_fn, check_fn=is_block)
                rank0_print('FSDP activation checkpointing enabled!')

        print('Loaded model on rank', rank)
        dist.barrier()

    def clip_gradient(self):
        """Clip the FSDP policy's gradients via its own ``clip_grad_norm_`` method.

        Returns:
            The norm returned by ``self.policy.clip_grad_norm_``, converted to a
            Python number with ``.item()``.
        """
        total_norm = self.policy.clip_grad_norm_(self.config.max_grad_norm)
        return total_norm.item()

    def save(self, output_dir=None, metrics=None):
        """Checkpoint the policy and, when present, the ordinal model; only rank 0 writes.

        The policy state dict is collected under ``StateDictType.FULL_STATE_DICT``
        with ``offload_to_cpu=True`` and ``rank0_only=True``. Every rank calls
        ``state_dict()`` and reaches the ``dist.barrier()`` after each file.
        """
        is_writer = self.rank == 0

        full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(self.policy, StateDictType.FULL_STATE_DICT, state_dict_config=full_cfg):
            weights = self.policy.state_dict()

        if is_writer:
            self.write_state_dict(self.example_counter, weights, metrics,
                                  'rewardModel.pt', output_dir)
        del weights
        dist.barrier()

        if self.ordinalModel is None:
            return

        # The ordinal model's state dict is read directly, outside the FSDP context above.
        ordinal_weights = self.ordinalModel.state_dict()
        if is_writer:
            self.write_state_dict(self.example_counter, ordinal_weights, metrics,
                                  'ordinalModel.pt', output_dir)
        del ordinal_weights
        dist.barrier()


class TensorParallelTrainer(BasicTrainer):
    def __init__(self, policy, config, seed, run_dir, rank=0, world_size=1,
                 ordinalModel=None):
        """A trainer subclass that uses TensorParallel to shard the model across multiple GPUs.

           Based on https://github.com/BlackSamorez/tensor_parallel. Note sampling is extremely slow,
              see https://github.com/BlackSamorez/tensor_parallel/issues/66.

           Only the policy is passed to ``tp.tensor_parallel``; ``ordinalModel`` is
           forwarded to the base class unchanged.
        """
        super().__init__(policy, config, seed, run_dir, rank, world_size, ordinalModel)
        
        rank0_print('Sharding policy...')
        self.policy = tp.tensor_parallel(policy, sharded=True)

    def save(self, output_dir=None, metrics=None):
        """Save the (unsharded) policy state and, if present, the ordinal model state to disk.

        Args:
            output_dir: Forwarded to ``write_state_dict`` as the target directory.
            metrics: Forwarded to ``write_state_dict`` and stored with each file.
        """
        with tp.save_tensor_parallel(self.policy):
            policy_state_dict = self.policy.state_dict()

        # NOTE(review): the other trainers in this file write the policy as
        # 'rewardModel.pt'; 'policy.pt' is kept here so existing loaders keep
        # working. Confirm which name downstream code expects.
        self.write_state_dict(self.example_counter, policy_state_dict, metrics, 'policy.pt', output_dir)
        del policy_state_dict

        # Also checkpoint the ordinal model, as the sibling trainers' save()
        # methods do; previously it was never written by this trainer.
        if self.ordinalModel is not None:
            ordinal_state_dict = self.ordinalModel.state_dict()
            self.write_state_dict(self.example_counter, ordinal_state_dict, metrics,
                                  'ordinalModel.pt', output_dir)
            del ordinal_state_dict
