              
                                                      
                                                                 

from abc import ABC, abstractclassmethod
from copy import deepcopy
from typing import Literal, Optional, Tuple, Union, Dict, Callable, List
from enum import Enum
from datetime import datetime
import logging
import random
import math
import sys

from einops import rearrange
from torch import Tensor
from unittest.mock import patch
import torch
import torch.nn.functional as F

from megatron.core import mpu, parallel_state, tensor_parallel, InferenceParams
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.tensor_parallel.layers import RowParallelLinear
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import divide
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint

from gpatch.core.aligner_helper import get_iterator_k_split
from gpatch.core.aligner_helper import get_ltor_masks_and_position_ids
from gpatch.core.aligner_interface import SupervisedInterface, Inferrable, ResetArgsMixin
from gpatch.core.device_type import is_wxacc1
from gpatch.core.model_parallel_config import ForcedConfig


class StateDictState(Enum):
    """Tag selecting between the critic and the reward flavour of a state dict.

    Only the two members below are defined here; where and how the tag is consumed is
    not visible in this part of the file.
    """

    CRITIC = 0  # critic flavour
    REWARD = 1  # reward-model flavour


class RewardModelHead(RowParallelLinear):
    """
    Reward model head to convert from output_size to scalar reward.

    A ``RowParallelLinear`` projection from the backbone hidden size (``input_size``) to
    ``output_size`` attribute scores. Hidden states are consumed in [S x B x H] layout
    (sequence-major), as the indexing in ``_pool_hidden_states`` requires.

    Two outputs can be enabled independently:

    * ``output_sequence`` -- per-token attribute scores, [B x S x output_size]
      (the critic-style output).
    * ``output_scalar`` -- one attribute vector per sample, [B x output_size], taken from
      the last valid token or, with ``use_avg_pool``, from the mean over the valid tokens
      (prompt tokens excluded when ``prompt_lens`` is given).

    ``attributes_weights`` (uniform by default) combines the attributes into one score for
    the sequence output, and for the scalar output when ``merge_attributes`` is set.
    """

    def __init__(
        self,
        input_size,
        output_size,
        *,
        config: TransformerConfig,
        init_method: Callable,
        bias: bool,
        input_is_parallel: bool,
        skip_bias_add: bool,
        stride: int = 1,
        # Reward-head specific options.
        output_sequence: bool = False,
        output_scalar: bool = True,
        use_avg_pool: bool = False,
        dtype: torch.dtype = torch.float32,
        merge_attributes: bool = False,
        attributes_weights: Optional[List[Union[float, int]]] = None,
    ):
        """
        Args:
            input_size: hidden size of the incoming hidden states.
            output_size: number of attribute scores produced; must be > 0.
            config: transformer config. The parent layer is built from
                ``ForcedConfig(config, params_dtype=dtype, sequence_parallel=False)``
                (presumably overriding those two fields); the original object is kept
                in ``self.yet_forced_config``.
            init_method, bias, skip_bias_add, stride: forwarded to ``RowParallelLinear``.
            input_is_parallel: must be False.
            output_sequence: produce the per-token output.
            output_scalar: produce the per-sample output.
            use_avg_pool: average-pool valid tokens for the per-sample output instead of
                taking the last valid token.
            dtype: parameter dtype of the head, also used as the autocast dtype.
            merge_attributes: collapse the per-sample attributes with
                ``attributes_weights`` into a single [B x 1] score.
            attributes_weights: per-attribute weights of length ``output_size``;
                defaults to ``1 / output_size`` for every attribute.
        """
        assert output_size > 0, "Output size of reward model head should be greater than zero"
        # Only the non-parallel-input configuration is supported by this head.
        assert not input_is_parallel
        self.yet_forced_config = config
        forced_config = ForcedConfig(config, params_dtype=dtype, sequence_parallel=False)

        super().__init__(
            input_size,
            output_size,
            config=forced_config,
            init_method=init_method,
            bias=bias,
            input_is_parallel=input_is_parallel,
            skip_bias_add=skip_bias_add,
            stride=stride,
        )

        self.output_sequence = output_sequence
        self.output_scalar = output_scalar

        self.use_avg_pool = use_avg_pool
        self.dtype = dtype
        self.merge_attributes = merge_attributes

        # Plain fp32 tensor (not a parameter/buffer); moved to the output device in forward.
        if attributes_weights is None:
            self.attributes_weights = torch.full((self.output_size, ), 1.0 / self.output_size)
        else:
            self.attributes_weights = torch.tensor(attributes_weights, dtype=torch.float)

        assert self.attributes_weights.size(0) == self.output_size

    @torch.no_grad()
    def debug_values(self, hidden_states, lengths):
        """Debug helper: occasionally print the range of masked per-token head outputs.

        Prints only on global rank 7 and only when ``random.randint(0, 1000) < 50``
        (about 5% of calls). The ``squeeze(-1)`` followed by the [S x B] mask multiply
        expects ``output_size == 1``.
        """
        seqlen = hidden_states.size(0)
        mask = torch.arange(seqlen, device=lengths.device).unsqueeze(-1) < lengths

        per_token_rewards = super().forward(hidden_states.to(self.weight.dtype))[0]
        per_token_rewards = (per_token_rewards.squeeze(-1) * mask.float()).permute(1, 0)

        if torch.distributed.get_rank() == 7:
            if random.randint(0, 1000) < 50:
                print(
                    f'per_token_rewards {torch.min(per_token_rewards)} {torch.max(per_token_rewards)}'
                )

    def _pool_hidden_states(self, hidden_states, lengths, prompt_lens):
        """Reduce [S x B x H] hidden states to one [B x H] vector per sample.

        Without ``use_avg_pool`` this is the hidden state at position ``lengths - 1``.
        With it, the hidden states at positions ``[0, lengths)`` -- or
        ``[prompt_lens, lengths)`` when ``prompt_lens`` is given -- are averaged over
        exactly the tokens that were summed (at least 1, so an empty span yields zeros).
        """
        if not self.use_avg_pool:
            batch_idx = torch.arange(lengths.shape[0], device=hidden_states.device)
            return hidden_states[lengths - 1, batch_idx, :]

        positions = torch.arange(hidden_states.size(0), device=lengths.device).unsqueeze(-1)
        mask = positions < lengths  # [S x B]
        if prompt_lens is not None:
            mask = torch.logical_and(mask, positions >= prompt_lens)

        summed = (hidden_states * mask.unsqueeze(-1)).sum(0)
        # Divide by the number of pooled tokens (not by ``lengths``) so that masking the
        # prompt still yields a true mean over the response tokens.
        token_counts = mask.sum(0).clamp(min=1)
        return summed / token_counts.unsqueeze(-1)

    def _compute_attributes(self, hidden_states, lengths, prompt_lens):
        """
        for critic, return a tensor with shape [B x S x self.output_size]
        for reward, return a tensor with shape [B x self.output_size]

        Either element of the returned ``(seq_out, scal_out)`` pair is None when the
        corresponding output is disabled.
        """
        # The linear projections run under autocast with the head dtype.
        autocast_context = torch.autocast(device_type=hidden_states.device.type, dtype=self.dtype)

        seq_out = None
        if self.output_sequence:
            with autocast_context:
                output = super().forward(hidden_states.to(self.weight.dtype))[0]
            # [S x B x output_size] -> [B x S x output_size]
            seq_out = output.permute(1, 0, 2).contiguous()

        scal_out = None
        if self.output_scalar:
            # Add a singleton sequence dim for the linear layer, removed again afterwards.
            pooled = self._pool_hidden_states(hidden_states, lengths, prompt_lens).unsqueeze(0)
            with autocast_context:
                scal_out = super().forward(pooled.to(self.weight.dtype))[0].squeeze(0)

        return seq_out, scal_out

    def forward(self, hidden_states, lengths, prompt_lens, return_attributes=False):
        """Score hidden states.

        Args:
            hidden_states: [S x B x H] backbone output.
            lengths: per-sample number of valid tokens.
            prompt_lens: per-sample prompt lengths or None (only used by average pooling).
            return_attributes: return the raw fp32 per-attribute outputs instead of the
                weighted scores.

        Returns:
            ``(scores_seq, scores_scal)``. ``scores_seq`` is [B x S] (attributes combined
            with ``attributes_weights``). ``scores_scal`` is [B x output_size], or [B x 1]
            with ``merge_attributes``. A disabled output is None. All tensors are fp32.
        """
        assert self.output_sequence or self.output_scalar

        seq_out, scal_out = self._compute_attributes(hidden_states, lengths, prompt_lens)

        if return_attributes:
            return (None if seq_out is None else seq_out.float(),
                    None if scal_out is None else scal_out.float())

        # Keep the weights on the same device as the head output.
        attributes = seq_out if seq_out is not None else scal_out
        self.attributes_weights = self.attributes_weights.to(attributes.device)
        assert self.attributes_weights.dtype == torch.float32

        scores_seq = None
        if seq_out is not None:
            assert seq_out.dim() == 3, \
                "for critic, attributes should have shape [B x S x self.output_size]"
            scores_seq = (seq_out @ self.attributes_weights.to(seq_out.dtype)).float()

        scores_scal = None
        if scal_out is not None:
            assert scal_out.dim() == 2, \
                "for reward, attributes should have shape [B x self.output_size]"
            if self.merge_attributes:
                weights = self.attributes_weights.to(scal_out.dtype)
                scores_scal = (scal_out @ weights).unsqueeze(-1).float()
            else:
                scores_scal = scal_out.float()

        return scores_seq, scores_scal


class MultiLayerRMHead(MegatronModule):
    """Two-layer MLP reward head.

    ``score`` is ``Linear(input_size, config.hidden_size) -> ReLU ->
    Linear(config.hidden_size, output_size)`` built from plain ``torch.nn.Linear`` modules,
    i.e. without tensor parallelism. Only the last-token scalar output is supported;
    average pooling and ``return_attributes`` are rejected.
    """

    def __init__(
        self,
        input_size,
        output_size,
        *,
        config: TransformerConfig,
        init_method: Callable,
        stride: int = 1,
        # Reward-head specific options.
        output_sequence: bool = False,
        output_scalar: bool = True,
        use_avg_pool: bool = False,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
            input_size: hidden size of the incoming hidden states.
            output_size: number of scores produced by the final linear layer.
            config: transformer config; ``config.hidden_size`` is the MLP width.
            init_method: applied to the weight of each linear layer (biases keep the
                ``torch.nn.Linear`` default initialisation).
            stride: stored only.
            output_sequence: produce the per-token output.
            output_scalar: produce the per-sample (last-token) output.
            use_avg_pool: must stay False if ``output_scalar`` is used.
            dtype: autocast dtype used in ``forward``.
        """
        super().__init__(config=config)
        self.input_size = input_size
        self.output_size = output_size

        self.stride = stride
        self.output_sequence = output_sequence
        self.output_scalar = output_scalar
        self.use_avg_pool = use_avg_pool
        self.dtype = dtype
        self.config = config
        self.init_method = init_method

        self.score = torch.nn.Sequential(torch.nn.Linear(self.input_size, self.config.hidden_size),
                                         torch.nn.ReLU(),
                                         torch.nn.Linear(self.config.hidden_size, self.output_size))
        for module in self.score:
            if isinstance(module, torch.nn.Linear):
                self.init_method(module.weight)

    def forward(self, hidden_states, lengths, prompt_lens, return_attributes=False):
        """Score [S x B x H] hidden states.

        Args:
            hidden_states: [S x B x H] backbone output.
            lengths: per-sample number of valid tokens; the scalar output is read at
                position ``lengths - 1``.
            prompt_lens: accepted for interface compatibility; unused because average
                pooling is not supported.
            return_attributes: must be False.

        Returns:
            ``(seq_out, scal_out)`` as fp32 tensors (None when disabled). ``seq_out`` is
            [B x S x output_size] with a trailing size-1 dim squeezed away; ``scal_out``
            is [B x output_size].
        """
        assert not return_attributes
        autocast_context = torch.autocast(device_type=hidden_states.device.type, dtype=self.dtype)
        weight_dtype = self.score[0].weight.dtype

        seq_out = None
        if self.output_sequence:
            with autocast_context:
                output = self.score(hidden_states.to(weight_dtype))
            # [S x B x output_size] -> [B x S x output_size]
            seq_out = output.permute(1, 0, 2).contiguous()
            seq_out = seq_out.squeeze(-1).float()

        scal_out = None
        if self.output_scalar:
            assert not self.use_avg_pool, "qwen-rm-72B multi-layer rm does not support avg_pool"
            # Hidden state of the last valid token of each sample, with a singleton
            # sequence dim added for the MLP and removed afterwards.
            batch_idx = torch.arange(lengths.shape[0], device=hidden_states.device)
            last_state = hidden_states[lengths - 1, batch_idx, :].unsqueeze(0)
            with autocast_context:
                scal_out = self.score(last_state.to(weight_dtype)).squeeze(0).float()

        return seq_out, scal_out


                                                         
class GptRewardModel(GPTModel):
    """GPT backbone topped with a reward / critic head.

    ``forward`` runs the ``GPTModel`` backbone with ``post_process`` temporarily forced to
    False so it yields hidden states, and then, on a stage where ``self.post_process`` is
    True, scores them with ``self.rm_head``. The remaining methods are loss functions over
    the head outputs (pairwise, focal, triplet, softmax, multi-objective, ODIN-style).
    """

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
        # Reward-head options.
        output_sequence: bool = False,
        output_scalar: bool = True,
        use_avg_pool: bool = False,
        head_dtype: torch.dtype = None,
        num_attributes: int = 1,
        attribute_weights: Optional[List[Union[float, int]]] = None,
        merge_attributes: bool = False,
        mask_prompt: bool = False,
    ):
        """
        The backbone arguments are forwarded unchanged to ``GPTModel.__init__``.

        Args:
            output_sequence: head produces per-token (critic-style) scores.
            output_scalar: head produces one score vector per sample.
            use_avg_pool: pool valid tokens instead of taking the last token.
            head_dtype: dtype of the head; defaults to ``config.params_dtype``.
            num_attributes: number of attribute scores produced by the head.
            attribute_weights: attribute weights (single-layer head only).
            merge_attributes: merge scalar attributes into one score (single-layer head
                only).
            mask_prompt: pass ``prompt_lens`` to the head (used to exclude prompt tokens
                when average pooling); ``forward`` then requires ``prompt_lens``.

        The head type is selected by ``config.rm_head_arch`` ("single_layer" or
        "multi_layers") and is only built when ``post_process`` is set.
        """
        GPTModel.__init__(
            self,
            config,
            transformer_layer_spec,
            vocab_size,
            max_sequence_length,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=fp16_lm_cross_entropy,
            parallel_output=parallel_output,
            share_embeddings_and_output_weights=share_embeddings_and_output_weights,
            position_embedding_type=position_embedding_type,
            rotary_percent=rotary_percent,
            rotary_base=rotary_base,
            rope_scaling=rope_scaling,
            rope_scaling_factor=rope_scaling_factor,
            scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
        )
        self.num_attributes = num_attributes
        self.mask_prompt = mask_prompt

        assert self.config.expert_model_parallel_size >= 1

        rm_head_dtype = config.params_dtype if head_dtype is None else head_dtype
        if self.post_process and self.config.rm_head_arch == "single_layer":
            self.rm_head = RewardModelHead(
                self.config.hidden_size,
                num_attributes,
                config=config,
                init_method=self.config.init_method,
                bias=False,
                input_is_parallel=False,
                skip_bias_add=False,
                output_sequence=output_sequence,
                output_scalar=output_scalar,
                use_avg_pool=use_avg_pool,
                dtype=rm_head_dtype,
                merge_attributes=merge_attributes,
                attributes_weights=attribute_weights,
            )
        elif self.post_process and self.config.rm_head_arch == "multi_layers":
            self.rm_head = MultiLayerRMHead(
                self.config.hidden_size,
                num_attributes,
                config=config,
                init_method=self.config.init_method,
                output_sequence=output_sequence,
                output_scalar=output_scalar,
                use_avg_pool=use_avg_pool,
                dtype=rm_head_dtype,
            )

    def forward(
        self,
        input_ids: Tensor,
        lengths: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        prompt_lens: Tensor = None,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params=None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        return_attributes: bool = False,
    ):
        """Run the backbone and, on a post-process stage, the reward head.

        Args:
            lengths: per-sample number of valid tokens.
            prompt_lens: per-sample prompt lengths; required when ``mask_prompt`` is set,
                ignored otherwise.
            packed_seq_params, extra_block_kwargs: must be None (unsupported).
            return_attributes: forwarded to the head.

        Returns:
            The backbone hidden states on a non-post-process stage. Otherwise the head's
            sequence output and/or scalar output: a single tensor when only one is enabled,
            a ``(sequence, scalar)`` tuple when both are.
        """
        assert packed_seq_params is None and extra_block_kwargs is None

        # Temporarily disable post_process so GPTModel.forward hands back hidden states
        # (presumably skipping its LM output layer) for the reward head below.
        with patch.object(self, "post_process", False):
            hidden_states = super().forward(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                decoder_input=decoder_input,
                labels=labels,
                inference_params=inference_params,
            )
        if not self.post_process:
            return hidden_states

        if self.config.sequence_parallel:
            # The head works on the full sequence, so gather the sequence-parallel shards.
            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states, tensor_parallel_output_grad=False)

        if self.mask_prompt:
            assert prompt_lens is not None
        else:
            prompt_lens = None

        out_seq, out_scalar = self.rm_head(hidden_states,
                                           lengths,
                                           prompt_lens,
                                           return_attributes=return_attributes)

        rets = []
        if self.rm_head.output_sequence:
            rets.append(out_seq)
        if self.rm_head.output_scalar:
            rets.append(out_scalar)
        return rets[0] if len(rets) == 1 else tuple(rets)

    def split_output_tensor(self, output_tensor):
        """Split a stacked batch along dim 0 into fp32 ``(chosen, rejected, golden)``.

        The batch is laid out as [chosen | rejected], or as [chosen | rejected | golden]
        when ``config.rm_use_triplet_loss`` is set. ``golden`` is None in the pairwise case,
        but a 3-tuple is always returned.
        """
        if self.config.rm_use_triplet_loss:
            assert output_tensor.shape[0] % 3 == 0, "mbs must be divisible by 3"
            rbs = output_tensor.shape[0] // 3
            out_chosen, out_rejected, out_golden = torch.split(output_tensor.float(), rbs, dim=0)
            return out_chosen, out_rejected, out_golden

        out_chosen, out_rejected = torch.split(output_tensor.float(),
                                               output_tensor.shape[0] // 2,
                                               dim=0)
        return out_chosen, out_rejected, None

    def focal_loss(self, focal_loss_lambda, focal_loss_gamma, focal_loss_range, ranking_coef,
                   out_chosen, out_rejected):
        """Focal-style pairwise ranking loss plus a score-range penalty.

        With ``p = sigmoid(chosen - rejected)`` and ``c = ranking_coef``:
        ``l_rank = -(1 - relu(p - c) / (1 - c)) ** gamma * log(p)`` down-weights pairs whose
        ranking probability already exceeds ``c``, and
        ``l_penalty = -0.5 * sum over {chosen, rejected} of
        (logsigmoid(x + range) + logsigmoid(range - x))`` penalises scores far outside
        ``(-range, range)``.

        Returns:
            Mean of ``l_rank + focal_loss_lambda * l_penalty`` over all pairs (and
            attributes) in the micro-batch. ``ranking_coef`` must be > 0 and != 1.
        """
        if is_wxacc1():
            # Device-specific handling of a leading singleton dim; the reason is not
            # visible here. It does not change the averaged result.
            out_chosen = out_chosen.squeeze(0)
            out_rejected = out_rejected.squeeze(0)

        assert ranking_coef > 0
        ranking_inverse_one_minus_coef = 1. / (1 - ranking_coef)

        p_ij = torch.sigmoid(out_chosen - out_rejected)
        focal_weight = (1. - ranking_inverse_one_minus_coef * F.relu(p_ij - ranking_coef))
        l_rank = -(focal_weight**focal_loss_gamma) * torch.log(p_ij)

        l_panalty_c = -(F.logsigmoid(out_chosen + focal_loss_range) +
                        F.logsigmoid(focal_loss_range - out_chosen))
        l_panalty_r = -(F.logsigmoid(out_rejected + focal_loss_range) +
                        F.logsigmoid(focal_loss_range - out_rejected))
        l_panalty = 0.5 * (l_panalty_c + l_panalty_r)

        # Average over every pair; indexing [0] here would silently keep only the first
        # pair of the micro-batch.
        return (l_rank + focal_loss_lambda * l_panalty).mean()

    def rm_triplet_loss(self, focal_loss_lambda, focal_loss_gamma, focal_loss_range, out_chosen,
                        out_rejected, out_golden, acc_chosen):
        """Triplet-margin loss (golden as anchor) blended with three focal ranking losses.

        The focal terms rank chosen>rejected, golden>chosen and golden>rejected using the
        coefficients in ``config.rm_focal_loss_ranking_coef``, weighted by
        ``config.rm_triplet_focal_coef``. The final loss is
        ``rm_triplet_coef * triplet + (1 - rm_triplet_coef) * focal``.

        Returns:
            ``(loss, acc_chosen, triplet_loss, loss_cr, loss_gc, loss_gr, acc_gc, acc_gr)``.
        """
        triplet_loss = F.triplet_margin_loss(out_golden,
                                             out_chosen,
                                             out_rejected,
                                             margin=self.config.rm_golden_margin)
        ranking_coefs = self.config.rm_focal_loss_ranking_coef
        loss_cr = self.focal_loss(focal_loss_lambda, focal_loss_gamma, focal_loss_range,
                                  ranking_coefs[0], out_chosen, out_rejected)
        loss_gc = self.focal_loss(focal_loss_lambda, focal_loss_gamma, focal_loss_range,
                                  ranking_coefs[1], out_golden, out_chosen)
        loss_gr = self.focal_loss(focal_loss_lambda, focal_loss_gamma, focal_loss_range,
                                  ranking_coefs[2], out_golden, out_rejected)

        triplet_focal_loss_coef = self.config.rm_triplet_focal_coef
        focal_total = (triplet_focal_loss_coef[0] * loss_cr +
                       triplet_focal_loss_coef[1] * loss_gc +
                       triplet_focal_loss_coef[2] * loss_gr)
        loss = (self.config.rm_triplet_coef * triplet_loss +
                (1 - self.config.rm_triplet_coef) * focal_total)

        acc_gc_comp = out_golden > out_chosen
        acc_gc = torch.sum(acc_gc_comp) / acc_gc_comp.shape[0]
        acc_gr_comp = out_golden > out_rejected
        acc_gr = torch.sum(acc_gr_comp) / acc_gr_comp.shape[0]

        return loss, acc_chosen, triplet_loss, loss_cr, loss_gc, loss_gr, acc_gc, acc_gr

    def rm_loss(self, output_tensor, use_focal_loss, focal_loss_lambda, focal_loss_gamma,
                focal_loss_range):
        """Pairwise reward-model loss.

        Uses the triplet loss when ``config.rm_use_triplet_loss`` is set (returning its
        8-tuple), otherwise the focal loss or the Bradley-Terry loss
        ``-logsigmoid(chosen - rejected).mean()``.

        Returns:
            ``(loss, acc_chosen)`` in the pairwise case, where ``acc_chosen`` is the
            fraction of pairs with chosen scored above rejected.
        """
        out_chosen, out_rejected, out_golden = self.split_output_tensor(output_tensor)
        if not self.config.rm_use_triplet_loss:
            assert out_golden is None

        comp = out_chosen > out_rejected
        acc_chosen = torch.sum(comp) / comp.shape[0]

        if self.config.rm_use_triplet_loss:
            return self.rm_triplet_loss(focal_loss_lambda, focal_loss_gamma, focal_loss_range,
                                        out_chosen, out_rejected, out_golden, acc_chosen)
        if not use_focal_loss:
            loss = -F.logsigmoid(out_chosen - out_rejected).mean()
        else:
            loss = self.focal_loss(focal_loss_lambda, focal_loss_gamma, focal_loss_range,
                                   self.config.rm_focal_loss_ranking_coef[0], out_chosen,
                                   out_rejected)

        return loss, acc_chosen

    def rm_sentence_loss(self, score, scattered_labels):
        """Element-wise BCE-with-logits between scores and labels, both viewed as
        [N x num_attributes]. Returns the unreduced loss tensor."""
        score = score.view(-1, self.num_attributes)
        scattered_labels = scattered_labels.view(-1, self.num_attributes).float()

        fn = torch.nn.BCEWithLogitsLoss(reduction='none')
        losses = fn(score, scattered_labels)
        return losses

    def rm_loss_softmax(self, output_tensor):
        """Two-way softmax cross-entropy with chosen (column 0 after concat) as the target.

        Returns:
            ``(loss, acc_chosen)``.
        """
        # split_output_tensor always returns three values; golden is unused here.
        out_chosen, out_rejected, _ = self.split_output_tensor(output_tensor)
        comp = out_chosen > out_rejected
        acc_chosen = torch.sum(comp) / comp.shape[0]

        fn = torch.nn.LogSoftmax(dim=1)
        out = torch.concat([out_chosen, out_rejected], dim=1)
        out = fn(out)
        label = torch.zeros([out.shape[0]], device=out.device).long()
        loss = F.nll_loss(out, label)
        return loss, acc_chosen

    def rm_qa_dt_loss(self, score, scattered_labels):
        """Mean BCE-with-logits over positions whose label is not -1 (ignored)."""
        score = score.view(-1)
        scattered_labels = scattered_labels.view(-1).float()
        fn = torch.nn.BCEWithLogitsLoss(reduction='none')

        target_ignore = torch.where(scattered_labels != -1, 1.0, 0.0).float()
        # Epsilon guards an all-ignored batch; it is added once to the count (as in
        # rm_loss_mo), not once per element.
        loss = torch.sum(
            fn(score, scattered_labels) * target_ignore) / (torch.sum(target_ignore) + 1e-5)
        return loss

    def rm_loss_mo(self, output_tensor, p_labels, masks):
        """Multi-objective pairwise loss over attributes 0, 1 and 2.

        For each attribute ``idx`` a masked 2-way NLL is computed over
        ``stack([chosen[:, idx], rejected[:, idx]])`` with target ``p_labels[:, idx]``
        (class 0 = chosen column, 1 = rejected column). Only the first half of ``p_labels``
        and ``masks`` (the chosen half of the batch) is used.

        Returns:
            ``(sum of the three losses, accuracy of chosen > rejected on attribute 2)``.
        """
        out_chosen, out_rejected, _ = self.split_output_tensor(output_tensor)
        p_labels = p_labels[:p_labels.shape[0] // 2]
        masks = masks[:masks.shape[0] // 2]

        def get_pair_wise_loss(out_chosen, out_rejected, p_labels, masks, idx):
            fn = torch.nn.LogSoftmax(dim=1)
            out = torch.stack([out_chosen[:, idx], out_rejected[:, idx]], dim=1)
            pred = fn(out)
            label = p_labels[:, idx]
            mask = masks[:, idx]
            loss = F.nll_loss(pred, label, reduction='none')
            return torch.sum(loss * mask) / (torch.sum(mask) + 1e-5)

        qa_use_loss = get_pair_wise_loss(out_chosen, out_rejected, p_labels, masks, 0)
        qa_read_loss = get_pair_wise_loss(out_chosen, out_rejected, p_labels, masks, 1)
        qa_total_loss = get_pair_wise_loss(out_chosen, out_rejected, p_labels, masks, 2)

        loss = qa_use_loss + qa_read_loss + qa_total_loss

        comp = out_chosen[:, 2] > out_rejected[:, 2]
        acc_chosen = torch.sum(comp) / comp.shape[0]

        return loss, acc_chosen

    def rm_loss_odin(self, tokens, output_tensor):
        """Bradley-Terry ranking loss plus a length-decorrelation term (ODIN-style).

        Token counts and head outputs are all-gathered across the data-parallel group.
        ``length_loss1 = 1 - corr(len, out[:, 0])`` is computed for logging only;
        ``length_loss2 = |corr(len, out[:, 1])|`` is divided by the DP world size and added
        to the loss. Lengths count non-zero token ids (assumes 0 is padding -- TODO confirm).

        Returns:
            ``(loss, acc_chosen, ranking_loss, length_loss1, length_loss2)`` with the last
            three detached.
        """
        tokens_lens = torch.sum(tokens != 0, dim=1)

        # split_output_tensor always returns three values; golden is unused here.
        out_chosen, out_rejected, _ = self.split_output_tensor(output_tensor)
        ranking_loss = -F.logsigmoid(out_chosen - out_rejected).mean()

        dp_group = parallel_state.get_data_parallel_group()
        dp_world_size = torch.distributed.get_world_size(dp_group)
        dp_rank = torch.distributed.get_rank(group=dp_group)
        device = torch.cuda.current_device()

        tokens_lens = tokens_lens.to(device=device, dtype=torch.float32)
        tokens_lens_gather_list = [torch.empty_like(tokens_lens) for _ in range(dp_world_size)]
        torch.distributed.all_gather(tokens_lens_gather_list, tokens_lens, group=dp_group)
        tokens_lens_gather = torch.concat(tokens_lens_gather_list, dim=0)

        output_tensor = output_tensor.to(device=device, dtype=torch.float32)
        output_tensor_gather_list = [torch.empty_like(output_tensor) for _ in range(dp_world_size)]
        torch.distributed.all_gather(output_tensor_gather_list, output_tensor, group=dp_group)
        # torch.distributed.all_gather does not record autograd history. Put the local
        # tensor back into its own slot so the correlation term backpropagates into this
        # rank's outputs; otherwise length_loss2 would contribute no gradient.
        output_tensor_gather_list[dp_rank] = output_tensor
        output_tensor_gather = torch.concat(output_tensor_gather_list, dim=0)

        length_corr_matrix1 = torch.stack((tokens_lens_gather, output_tensor_gather[:, 0]))
        length_corr1 = torch.corrcoef(length_corr_matrix1.float())[0, 1]
        length_loss1 = 1 - length_corr1

        length_corr_matrix2 = torch.stack((tokens_lens_gather, output_tensor_gather[:, 1]))
        length_corr2 = torch.corrcoef(length_corr_matrix2.float())[0, 1]
        length_loss2 = torch.abs(length_corr2)

        loss = ranking_loss + length_loss2 / dp_world_size

        comp = out_chosen > out_rejected
        acc_chosen = torch.sum(comp) / comp.shape[0]

        return loss, acc_chosen, ranking_loss.detach(), length_loss1.detach(), length_loss2.detach()
