              
                                                      
                                                                 
                 

import logging
from typing import Literal, Optional, List, Dict
from typing_extensions import override
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange

from megatron.core import mpu
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.parallel_state import get_tensor_model_parallel_group

from gpatch.core.tensor_parallel.mappings import reduce_from_context_parallel_region
from gpatch.core.tensor_parallel.mappings import (
    all_gather_to_context_parallel_region, )


class GptDpoModel(GPTModel):
    """GPT model that turns language-model logits into a DPO preference loss.

    ``forward`` runs the regular ``GPTModel`` forward pass and, on the
    post-processing stage when labels are given, converts the logits into
    either per-sequence log-probabilities or the sigmoid DPO loss.

    Batches are laid out with the *chosen* samples in the first half of dim 0
    and the *rejected* samples in the second half (see ``modpo``).

    Attributes:
        beta: Scale applied to the policy/reference log-prob difference.
        label_smoothing: Weight of the flipped-label term in the sigmoid loss.
        loss_type: Loss variant; only ``'sigmoid'`` is implemented.
        ftx_gamma: Weight of the extra chosen-sample log-likelihood term
            (applied only when larger than ``1e-6``).
        pair_gamma: Stored, not used by the code in this class.
        forward_without_loss: When true, ``forward`` returns log-probabilities
            instead of the loss.
        orpo_loss: Stored, not used by the code in this class.
        golden_loss: Must be falsy (asserted in ``__init__``).
        dpo_golden_margin: Value of the ``golden_margin`` argument; stored,
            not used by the code in this class.
        dpo_reward_models_cnt: Copied from ``config.dpo_reward_models_cnt``.
        last_dpo_metrics: ``None`` until ``modpo`` runs, then a dict of
            detached scalar tensors keyed by ``"dpo-metrics/..."`` describing
            the most recent loss computation.
    """

    def __init__(
        self,
        *args,
        beta: float = 0.1,
        label_smoothing: float = 0,
        loss_type: str = 'sigmoid',
        ftx_gamma: float = 0.,
        model_using: str = 'both',
        golden_margin: float = 0.5,
        forward_without_loss=False,
        orpo_loss=False,
        golden_loss=False,
        pair_gamma: float = 1.0,
        **kwargs,
    ) -> None:
        """Build the underlying ``GPTModel`` and record the DPO hyper-parameters.

        Positional arguments and any keyword not listed below are forwarded
        to ``GPTModel.__init__``.

        NOTE(review): ``model_using`` must be ``"policy"`` or ``"ref"``; the
        declared default ``'both'`` is rejected by the assertion below, so
        callers have to pass the argument explicitly. Confirm whether the
        default should change.

        Raises:
            AssertionError: If ``model_using`` is not ``"policy"``/``"ref"``
                or if ``golden_loss`` is truthy.
        """
        super().__init__(*args, **kwargs)
        assert model_using in ("policy", "ref"), (
            f"model_using must be 'policy' or 'ref', got {model_using!r}")
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.loss_type = loss_type
        self.ftx_gamma = ftx_gamma
        self.pair_gamma = pair_gamma
        self.forward_without_loss = forward_without_loss
        self.orpo_loss = orpo_loss
        self.golden_loss = golden_loss
        assert not self.golden_loss, "golden_loss is not supported"
        self.dpo_golden_margin = golden_margin
        # Filled by modpo() so the metrics it computes are reachable by callers.
        self.last_dpo_metrics = None

        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"GptDpoModel init {golden_loss=} {ftx_gamma=} {pair_gamma=}")

        self.dpo_reward_models_cnt = self.config.dpo_reward_models_cnt

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        loss_mask_full: Tensor,
        attention_mask: Tensor,
        decoder_input: Optional[Tensor] = None,
        labels: Optional[Tensor] = None,
        inference_params: Optional[InferenceParams] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
        dpo_weights: Optional[Tensor] = None,
        w0_weights: Optional[Tensor] = None,
        reward_margins: Optional[Tensor] = None,
        ref_logps: Optional[Tensor] = None,
        extra_block_kwargs: Optional[dict] = None,
    ) -> Tensor:
        """Run the GPT forward pass and, if requested, the DPO post-processing.

        The parent ``forward`` is always called with ``labels=None`` so that
        it yields logits rather than a language-model loss.

        Args:
            loss_mask_full: Mask multiplied with the per-token log-probs before
                they are summed over the last dim. Presumably covers the whole
                (un-sharded) sequence, since it is applied after the
                context-parallel all-gather — TODO confirm.
            labels: Target token ids; when ``None`` the raw logits are returned.
            ref_logps: Reference-model log-probs, required for the loss path.
            dpo_weights, w0_weights, reward_margins: Passed through to
                ``modpo``, which currently does not use them.
            Remaining arguments are forwarded to ``GPTModel.forward``.

        Returns:
            * the parent's output unchanged when ``self.post_process`` is false
              or ``labels`` is ``None``;
            * per-sequence log-probs when ``self.forward_without_loss`` is set;
            * otherwise the per-pair DPO losses from ``modpo``.
        """
        output = super().forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            decoder_input=decoder_input,
            labels=None,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            extra_block_kwargs=extra_block_kwargs,
        )

        # Stages without the output layer just hand their activations on.
        if not self.post_process:
            return output

        logits = output
        if labels is None:
            return logits

        if self.forward_without_loss:
            return self.margin_or_logps(labels, loss_mask_full, logits)

        return self.modpo(labels, logits, loss_mask_full, dpo_weights, w0_weights, reward_margins,
                          ref_logps)

    def convert_logit(self, logits):
        """Return ``logits`` cast to float32 in contiguous memory."""
        return logits.float().contiguous()

    def margin_or_logps(self, labels, loss_mask_full, logits):
        """Return the masked per-sequence log-probabilities of ``labels``."""
        return self.get_batch_logps(self.convert_logit(logits), labels, loss_mask_full)

    @staticmethod
    def _nonzero_token_count(loss_mask):
        """Sum ``loss_mask`` over the last dim, mapping exact zeros to one.

        Used as a divisor: a fully masked row has a masked log-prob sum of
        zero, so dividing by one instead of zero yields 0 rather than NaN.
        """
        token_cnt = loss_mask.sum(-1)
        return torch.where(token_cnt == 0, torch.ones_like(token_cnt), token_cnt)

    def modpo(self, labels, logits, loss_mask_full, dpo_weights, w0_weights, reward_margins,
              ref_logps):
        """Compute the per-pair DPO loss for a chosen/rejected batch.

        Args:
            labels: Target token ids, ``(batch, seq)``.
            logits: Model logits, ``(batch, seq, vocab_shard)``; the first half
                of dim 0 is treated as chosen, the second half as rejected.
            loss_mask_full: Mask applied to the per-token log-probs.
            dpo_weights, w0_weights, reward_margins: Accepted for interface
                compatibility; not used here.
            ref_logps: Reference-model per-sequence log-probs with the same
                chosen/rejected layout along dim 0.

        Returns:
            Tensor of per-pair losses (one entry per chosen/rejected pair).

        Side effects:
            Stores detached scalar metrics in ``self.last_dpo_metrics``.
            Issues all-reduce collectives on the tensor-parallel group (and on
            the context-parallel group when its size is greater than one).

        Raises:
            AssertionError: If ``ref_logps`` is ``None``.
            ValueError: If the batch dimension is empty or odd.
        """
        # Fail before any collective or cross-entropy work is launched.
        assert ref_logps is not None, "ref_logps is required to compute the DPO loss"

        logits = self.convert_logit(logits)

        batch = logits.shape[0]
        if batch == 0 or batch % 2 != 0:
            raise ValueError(
                f"DPO batch must hold an even, non-zero number of samples "
                f"(chosen then rejected), got {batch}")
        rbs = batch // 2

        policy_chosen_logits, policy_rejected_logits = logits.split(rbs, dim=0)
        policy_chosen_logits_mean = policy_chosen_logits.detach().mean().float()
        policy_rejected_logits_mean = policy_rejected_logits.detach().mean().float()

        # Average the logit means across ranks. The group order (context
        # parallel first, then tensor parallel) and the tensor order within
        # each group match on every rank, keeping the collectives aligned.
        reduce_groups = []
        if self.config.context_parallel_size > 1:
            reduce_groups.append(mpu.get_context_parallel_group())
        reduce_groups.append(get_tensor_model_parallel_group())
        for group in reduce_groups:
            for logits_mean in (policy_chosen_logits_mean, policy_rejected_logits_mean):
                torch.distributed.all_reduce(
                    logits_mean,
                    op=torch.distributed.ReduceOp.AVG,
                    group=group,
                )

        logps = self.get_batch_logps(logits, labels, loss_mask_full)

        policy_chosen_logps, policy_rejected_logps = logps.split(rbs, dim=0)
        reference_chosen_logps, reference_rejected_logps = ref_logps.split(rbs, dim=0)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if self.ftx_gamma > 1e-6:
            # Extra term: subtract the length-normalised chosen log-likelihood,
            # i.e. add its negative to the loss, weighted by ftx_gamma.
            chosen_loss_mask_full = loss_mask_full.split(rbs, dim=0)[0]
            chosen_token_cnt = self._nonzero_token_count(chosen_loss_mask_full)
            # Out-of-place so the tensor returned by dpo_loss is not mutated.
            losses = losses - self.ftx_gamma * policy_chosen_logps / chosen_token_cnt

        # Every entry is detached, so keeping the dict does not retain the
        # autograd graph.
        self.last_dpo_metrics = {
            "dpo-metrics/rewards-accuracies": reward_accuracies.mean().float(),
            "dpo-metrics/logps-rejected": policy_rejected_logps.detach().mean().float(),
            "dpo-metrics/logps-chosen": policy_chosen_logps.detach().mean().float(),
            "dpo-metrics/logits-rejected": policy_rejected_logits_mean,
            "dpo-metrics/logits-chosen": policy_chosen_logits_mean,
            "dpo-metrics/rewards-chosen": chosen_rewards.mean().float(),
            "dpo-metrics/rewards-rejected": rejected_rewards.mean().float(),
            "dpo-metrics/rewards-margins": (chosen_rewards - rejected_rewards).mean().float(),
        }

        return losses

    def get_batch_logps(self, logits, labels, loss_mask_full, average_log_prob=False):
        """Return the masked log-probability of ``labels`` for every sequence.

        Args:
            logits: ``(batch, seq, vocab_shard)`` logits.
            labels: ``(batch, seq)`` target token ids.
            loss_mask_full: Mask multiplied element-wise with the per-token
                log-probs; must broadcast against them after the optional
                context-parallel all-gather.
            average_log_prob: If true, divide each sequence's summed log-prob
                by its mask sum (fully masked rows divide by one) instead of
                returning the plain sum.

        Returns:
            Tensor with one value per sequence (last dim reduced).
        """
        # vocab_parallel_cross_entropy is fed sequence-first tensors here.
        logits = rearrange(logits, 'b s h -> s b h').contiguous()
        labels = rearrange(labels, 'b s -> s b').contiguous()
        # Negated cross-entropy is the log-probability of the target token.
        per_token_logps = -1 * tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
        per_token_logps = rearrange(per_token_logps, 's b -> b s').contiguous()

        if self.config.context_parallel_size > 1:
            # Presumably reassembles the full sequence from the per-rank
            # shards so it lines up with loss_mask_full — TODO confirm.
            per_token_logps = all_gather_to_context_parallel_region(per_token_logps)

        logps = (per_token_logps * loss_mask_full).sum(-1)
        if average_log_prob:
            logps = logps / self._nonzero_token_count(loss_mask_full)
        return logps

    def dpo_loss(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,
                 reference_rejected_logps):
        """Compute the sigmoid DPO loss and the implied rewards.

        Rewards are ``beta * (policy_logp - reference_logp)``; the loss is
        ``-logsigmoid(margin)`` weighted by ``1 - label_smoothing`` plus
        ``-logsigmoid(-margin)`` weighted by ``label_smoothing``, where
        ``margin`` is the chosen reward minus the rejected reward.

        Returns:
            Tuple ``(losses, chosen_rewards, rejected_rewards)``; both reward
            tensors are detached from the autograd graph.

        Raises:
            ValueError: If ``self.loss_type`` is not ``'sigmoid'``.
        """
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)
        logits = chosen_rewards - rejected_rewards

        if self.loss_type == 'sigmoid':
            losses = (-F.logsigmoid(logits) * (1 - self.label_smoothing) -
                      F.logsigmoid(-logits) * self.label_smoothing)
        else:
            raise ValueError(f'unknown loss type {self.loss_type}')

        return losses, chosen_rewards.detach(), rejected_rewards.detach()