              
                                                      
                                                                 

import logging
from typing import Literal, Optional, List, Dict
from typing_extensions import override
from copy import deepcopy

import torch
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange

from megatron.core import mpu
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.parallel_state import get_tensor_model_parallel_group

from gpatch.core.tensor_parallel.mappings import reduce_from_context_parallel_region
from gpatch.core.tensor_parallel.mappings import (
    all_gather_to_context_parallel_region, )


class GptDpoModel(LanguageModule):

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
        # DPO-specific options
        beta: float = 0.1,
        label_smoothing: float = 0,
        loss_type: str = 'sigmoid',
        ftx_gamma: float = 0.,
        model_using: str = 'both',
        golden_margin: float = 0.5,
        forward_without_loss=False,
        orpo_loss=False,
        golden_loss=False,
        pair_gamma: float = 1.0,
    ) -> None:
        """Build the policy / reference GPT sub-models and record the DPO settings.

        Args:
            config: Transformer configuration. Also read for ``dpo_reward_models_cnt``,
                which must currently be 0.
            transformer_layer_spec, vocab_size, max_sequence_length, pre_process,
            post_process, fp16_lm_cross_entropy, parallel_output,
            share_embeddings_and_output_weights, position_embedding_type, rotary_percent,
            rotary_base, rope_scaling, rope_scaling_factor,
            scatter_embedding_sequence_parallel, seq_len_interpolation_factor:
                Passed unchanged to every ``GPTModel`` constructed here.
            beta: Scale applied to the policy-minus-reference log-prob difference in
                ``dpo_loss``.
            label_smoothing: Weight of the flipped-label term in ``dpo_loss``.
            loss_type: Loss variant; ``dpo_loss`` only accepts ``'sigmoid'``.
            ftx_gamma: Weight of the extra chosen-log-prob term added in ``dpo``
                (only applied when larger than 1e-6).
            model_using: ``'both'``, ``'policy'`` or ``'ref'``; selects which sub-models
                are instantiated. At least one must result.
            golden_margin: Stored as ``self.dpo_golden_margin``.
            forward_without_loss: When true, ``forward`` returns the output of
                ``margin_or_logps`` instead of the DPO loss.
            orpo_loss: Must be false (asserted).
            golden_loss: Must be false (asserted).
            pair_gamma: Stored as ``self.pair_gamma``.
        """
        super().__init__(config=config)

        # Architecture settings mirrored on the wrapper.
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type
        self.rope_scaling = rope_scaling
        self.rope_scaling_factor = rope_scaling_factor
        self.scatter_embedding_sequence_parallel = scatter_embedding_sequence_parallel

        # DPO hyper-parameters.
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.loss_type = loss_type
        self.ftx_gamma = ftx_gamma
        self.pair_gamma = pair_gamma
        self.forward_without_loss = forward_without_loss
        self.orpo_loss = orpo_loss
        self.golden_loss = golden_loss
        # These two loss variants are rejected outright.
        assert not self.golden_loss
        assert not self.orpo_loss
        self.dpo_golden_margin = golden_margin

        # Announce the configuration once (rank 0, or when not running distributed).
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"GptDpoModel init {golden_loss=} {ftx_gamma=} {pair_gamma=}")

        # Every sub-model is a GPTModel built from exactly the same arguments.
        gpt_kwargs = dict(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=vocab_size,
            max_sequence_length=max_sequence_length,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=fp16_lm_cross_entropy,
            parallel_output=parallel_output,
            share_embeddings_and_output_weights=share_embeddings_and_output_weights,
            position_embedding_type=position_embedding_type,
            rotary_percent=rotary_percent,
            rotary_base=rotary_base,
            rope_scaling=rope_scaling,
            rope_scaling_factor=rope_scaling_factor,
            scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
        )

        # Trainable policy model.
        self.policy_model = None
        if model_using in ('both', 'policy'):
            self.policy_model = GPTModel(**gpt_kwargs)

        # Frozen reference model.
        self.ref_model = None
        if model_using in ('both', 'ref'):
            self.ref_model = GPTModel(**gpt_kwargs)
            self.ref_model.requires_grad_(False)
        assert self.policy_model is not None or self.ref_model is not None

        # Frozen reward models; the assert below currently forbids having any.
        self.reward_models = None
        self.dpo_reward_models_cnt = config.dpo_reward_models_cnt
        assert self.dpo_reward_models_cnt == 0, f"current forbidden"
        if self.dpo_reward_models_cnt > 0:
            self.reward_models = torch.nn.ModuleList()
            for _ in range(self.dpo_reward_models_cnt):
                frozen_reward_model = GPTModel(**gpt_kwargs)
                frozen_reward_model.requires_grad_(False)
                self.reward_models.append(frozen_reward_model)

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.

        See megatron.model.transformer.set_input_tensor()

        Args:
            input_tensor (Tensor): Sets the input tensor for the model.
        """
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        assert len(input_tensor) == 1
        if input_tensor[0] is None:
            if self.policy_model:
                self.policy_model.set_input_tensor(None)
            if self.ref_model:
                self.ref_model.set_input_tensor(None)
            if self.reward_models:
                for reward_model in self.reward_models:
                    reward_model.set_input_tensor(None)
        else:
            tensor_cnt = 2 if self.policy_model and self.ref_model else 1
            assert input_tensor[0].shape[
                0] == tensor_cnt + self.dpo_reward_models_cnt, f"shape:{input_tensor[0].shape} len: {tensor_cnt}-{self.dpo_reward_models_cnt}"
            _idx = 0
            if self.policy_model:
                self.policy_model.set_input_tensor(input_tensor=input_tensor[0][_idx])
                _idx += 1
            if self.ref_model:
                self.ref_model.set_input_tensor(input_tensor=input_tensor[0][_idx])
                _idx += 1
            if self.reward_models:
                for i in range(self.dpo_reward_models_cnt):
                    self.reward_models[i].set_input_tensor(input_tensor=input_tensor[0][i + _idx])

    @override
    def shared_embedding_or_output_weight(self) -> Tensor:
        """Gets the emedding weight or output logit weights when share embedding and output weights set to True.

        Returns:
            Tensor: During pre processing it returns the input embeddings weight while during post processing it returns the final output layers weight
        """
        if self.pre_process:
            return self.policy_model.embedding.word_embeddings.weight
        elif self.post_process:
            return self.policy_model.output_layer.weight
        return None

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        loss_mask_full: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        dpo_weights: Tensor = None,
        w0_weights: Tensor = None,
        reward_margins: Tensor = None,
        ref_logps: Tensor = None,
        extra_block_kwargs: dict = None,
    ) -> Tensor:
        hidden_states = None
        if self.policy_model:
            hidden_states = self.policy_model(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                decoder_input=decoder_input,
                labels=None,
                inference_params=inference_params,
                packed_seq_params=packed_seq_params,
                extra_block_kwargs=extra_block_kwargs,
            )
        with torch.no_grad():
            ref_hidden_states = None
            if self.ref_model:
                ref_hidden_states = self.ref_model(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    decoder_input=decoder_input,
                    labels=None,
                    inference_params=inference_params,
                    packed_seq_params=packed_seq_params,
                    extra_block_kwargs=extra_block_kwargs,
                )
            reward_hidden_states = []
            assert self.dpo_reward_models_cnt == 0, f"current forbidden"
            for i in range(self.dpo_reward_models_cnt):
                reward_hidden_states.append(self.reward_models[i](
                    input_ids=input_ids,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    decoder_input=decoder_input,
                    labels=None,
                    inference_params=inference_params,
                    packed_seq_params=packed_seq_params,
                    extra_block_kwargs=extra_block_kwargs,
                ))

        if not self.post_process:
            ret = []
            if self.policy_model:
                ret.append(hidden_states)
            if self.ref_model:
                ret.append(ref_hidden_states)
            ret = ret + reward_hidden_states
            return torch.stack(ret), {}

                         
                                                        
        logits = hidden_states
        ref_logits = ref_hidden_states
        reward_logits_list = reward_hidden_states

        if labels is None:
                     
            ret = []
            if self.policy_model:
                ret.append(logits)
            if self.ref_model:
                ret.append(ref_logits)
            return torch.stack(ret + reward_logits_list), {}

        if self.forward_without_loss:
            return self.margin_or_logps(labels, logits, ref_logits, loss_mask_full), {}

        return self.dpo(labels, logits, ref_logits, loss_mask_full, reward_logits_list, dpo_weights,
                        w0_weights, reward_margins, ref_logps)

    def convert_logit(self, logits):
        logits = logits.float().contiguous()
        return logits

    def margin_or_logps(
        self,
        labels,
        logits,
        ref_logits,
        loss_mask_full,
    ):
                              
        logps = None
        if self.policy_model:
            logits = self.convert_logit(logits)
            logps = self.get_batch_logps(logits, labels, loss_mask_full)

        ref_logps = None
        if self.ref_model:
            ref_logits = self.convert_logit(ref_logits)
            ref_logps = self.get_batch_logps(ref_logits, labels, loss_mask_full)

        if logps is None:
            return ref_logps
        if ref_logps is None:
            return logps

        return logps - ref_logps

    def dpo(self, labels, logits, ref_logits, loss_mask_full, reward_logits_list, dpo_weights,
            w0_weights, reward_margins, ref_logps):
        """Compute per-pair DPO losses and logging metrics for a chosen/rejected batch.

        The batch is split in half along dim 0: the first half is treated as the chosen
        samples and the second half as the rejected samples.

        Args:
            labels: Target token ids, passed to ``get_batch_logps``.
            logits: Policy model output for the whole batch.
            ref_logits: Reference model output; only read when ``self.ref_model`` exists.
            loss_mask_full: Per-token mask multiplied into the per-token log-probs.
            reward_logits_list: Currently unused.
            dpo_weights: Currently unused.
            w0_weights: Currently unused.
            reward_margins: Currently unused.
            ref_logps: Precomputed reference sequence log-probs. Used only when there is
                no reference model; otherwise recomputed from ``ref_logits``.

        Returns:
            tuple: ``(losses, metrics)`` where ``losses`` has one entry per pair and
            ``metrics`` maps ``"dpo-metrics/..."`` names to scalar tensors.

        Raises:
            ValueError: If ``logits`` is ``None``, the batch size is not a positive even
                number, or no reference log-probs can be obtained.
        """
        if logits is None:
            raise ValueError("dpo() requires policy logits, got None (no policy model?)")

        logits = self.convert_logit(logits)

        batch_size = logits.shape[0]
        if batch_size < 2 or batch_size % 2 != 0:
            raise ValueError(
                f"dpo() expects an even batch of chosen/rejected pairs, got {batch_size}")
        rbs = batch_size // 2
        policy_chosen_logits, policy_rejected_logits = logits.split(rbs, dim=0)

        # Mean logits are logged only; detach them from the autograd graph.
        policy_chosen_logits_mean = policy_chosen_logits.detach().mean().float()
        policy_rejected_logits_mean = policy_rejected_logits.detach().mean().float()

        # Average the logged means over the context-parallel group (when used) and then
        # over the tensor-parallel group. The collective order is the same on every rank.
        reduce_groups = []
        if self.config.context_parallel_size > 1:
            reduce_groups.append(mpu.get_context_parallel_group())
        reduce_groups.append(get_tensor_model_parallel_group())
        for group in reduce_groups:
            for logged_mean in (policy_chosen_logits_mean, policy_rejected_logits_mean):
                torch.distributed.all_reduce(
                    logged_mean,
                    op=torch.distributed.ReduceOp.AVG,
                    group=group,
                )

        logps = self.get_batch_logps(logits, labels, loss_mask_full)

        if self.ref_model is not None:
            ref_logits = self.convert_logit(ref_logits)
            ref_logps = self.get_batch_logps(ref_logits, labels, loss_mask_full)
        elif ref_logps is None:
            raise ValueError(
                "dpo() needs reference log-probs: build a reference model or pass ref_logps")

        policy_chosen_logps, policy_rejected_logps = logps.split(rbs, dim=0)
        reference_chosen_logps, reference_rejected_logps = ref_logps.split(rbs, dim=0)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        # Optional extra term: subtract the chosen log-prob, normalised by the chosen
        # mask sum, scaled by ftx_gamma.
        if self.ftx_gamma > 1e-6:
            chosen_loss_mask_full = loss_mask_full.split(rbs, dim=0)[0]
            chosen_loss_mask_full_sum = chosen_loss_mask_full.sum(-1)
            # A fully masked chosen sample has a zero log-prob sum; divide by 1 instead of
            # 0 so the term is 0 rather than NaN.
            safe_denominator = torch.where(
                chosen_loss_mask_full_sum > 0,
                chosen_loss_mask_full_sum,
                torch.ones_like(chosen_loss_mask_full_sum),
            )
            losses = losses - self.ftx_gamma * policy_chosen_logps / safe_denominator

        metrics = {
            "dpo-metrics/rewards-accuracies": reward_accuracies.mean().float(),
            "dpo-metrics/logps-rejected": policy_rejected_logps.detach().mean().float(),
            "dpo-metrics/logps-chosen": policy_chosen_logps.detach().mean().float(),
            "dpo-metrics/logits-rejected": policy_rejected_logits_mean,
            "dpo-metrics/logits-chosen": policy_chosen_logits_mean,
            "dpo-metrics/rewards-chosen": chosen_rewards.mean().float(),
            "dpo-metrics/rewards-rejected": rejected_rewards.mean().float(),
            "dpo-metrics/rewards-margins": (chosen_rewards - rejected_rewards).mean().float(),
        }
        return losses, metrics

    def get_batch_logps(self, logits, labels, loss_mask_full, average_log_prob=False):
        """Compute the masked log-probability of ``labels`` for each sequence.

        Args:
            logits: 3-D tensor laid out as (batch, sequence, vocab-dim).
            labels: 2-D tensor laid out as (batch, sequence).
            loss_mask_full: Per-token mask multiplied into the per-token log-probs before
                summing over the last dimension.
            average_log_prob: When true, return the masked mean per sequence (the sum
                divided by the mask sum) instead of the masked sum. Defaults to False,
                which keeps the summed behaviour.

        Returns:
            One value per sequence: the masked sum, or the masked mean when
            ``average_log_prob`` is set.
        """
        # vocab_parallel_cross_entropy is fed sequence-first tensors here.
        logits = rearrange(logits, 'b s h -> s b h').contiguous()
        labels = rearrange(labels, 'b s -> s b').contiguous()
        # Negated cross entropy, used as the per-token log-prob of each label.
        per_token_logps = -1 * tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
        per_token_logps = rearrange(per_token_logps, 's b -> b s').contiguous()

        if self.config.context_parallel_size > 1:
            # NOTE(review): presumably gathers the sequence shards so the result lines up
            # with the full-length loss mask; confirm against the helper.
            per_token_logps = all_gather_to_context_parallel_region(per_token_logps)

        logps = (per_token_logps * loss_mask_full).sum(-1)
        if average_log_prob:
            token_count = loss_mask_full.sum(-1)
            # Fully masked rows have a zero sum; divide by 1 so they yield 0, not NaN.
            token_count = torch.where(token_count > 0, token_count,
                                      torch.ones_like(token_count))
            logps = logps / token_count
        return logps

    def dpo_loss(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps,
                 reference_rejected_logps):
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)
        logits = chosen_rewards - rejected_rewards

                 
        if self.loss_type == 'sigmoid':
            losses = (-F.logsigmoid(logits) * (1 - self.label_smoothing) -
                      F.logsigmoid(-logits) * self.label_smoothing)
        else:
            raise ValueError(f'unknown loss type {self.loss_type}')

        return losses, chosen_rewards.detach(), rejected_rewards.detach()

    def sharded_state_dict(self,
                           prefix: str = '',
                           sharded_offsets: tuple = (),
                           metadata: Optional[Dict] = None) -> ShardedStateDict:
        sharded_state_dict = {}
        if self.policy_model:
            sharded_state_dict.update(
                self.policy_model.sharded_state_dict(f'{prefix}policy_model.', sharded_offsets,
                                                     metadata))
        if self.ref_model:
            sharded_state_dict.update(
                self.ref_model.sharded_state_dict(f'{prefix}ref_model.', sharded_offsets, metadata))

        if self.reward_models:
            for i in range(self.dpo_reward_models_cnt):
                sharded_state_dict.update(self.reward_models[i].sharded_state_dict(
                    f'{prefix}reward_models.{i}.', sharded_offsets, metadata))
        return sharded_state_dict
