              
                                                      
                       

from typing import List, Optional, Dict

import torch
import torch.nn.functional as F

from megatron.core import InferenceParams
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.multimodal.llava_model import LLaVAModel
from megatron.core.models.gpt import GPTModel
from megatron.core import tensor_parallel

from gpatch.core.tensor_parallel.mappings import reduce_from_context_parallel_region
from gpatch.core.transformer.transformer_config import GpatchTransformerConfig

DEFAULT_IMAGE_TOKEN_INDEX = -200


class LLaVAModelDPO(MegatronModule):
    """LLaVA DPO multi-modal model.

    Args: the model arguments are the same as for LLaVAModel; see
        megatron/core/models/multimodal/llava_model.py for their meaning.
    """

    def __init__(
        self,
        language_transformer_config: GpatchTransformerConfig,
        language_transformer_layer_spec: ModuleSpec,
        language_vocab_size: int,
        language_max_sequence_length: int,
        vision_transformer_config: GpatchTransformerConfig,
        vision_transformer_layer_spec: ModuleSpec,
        vision_projection_config: GpatchTransformerConfig,
        vision_projection_layer_spec: ModuleSpec,
        vision_projection_type: str = "mlp",
        parallel_output: bool = True,
        language_position_embedding_type: str = 'learned_absolute',
        language_rotary_percent: float = 1.0,
        pre_process: bool = True,
        post_process: bool = True,
        add_encoder: bool = True,
        add_decoder: bool = True,
        language_rotary_base: int = 10000,
        share_embeddings_and_output_weights: bool = False,
        llava_model_class: type = LLaVAModel,

        # DPO hyper-parameters
        beta: float = 0.1,
        label_smoothing: float = 0,
        loss_type: str = 'sigmoid',
        ftx_gamma: float = 0.,

        # any remaining keyword arguments are forwarded to llava_model_class
        **kwargs,
    ) -> None:
        """Build the policy model and the reference model.

        Both models are instances of ``llava_model_class`` created with exactly
        the same arguments. The reference model has ``requires_grad`` disabled
        on all of its parameters.

        Args:
            llava_model_class: class used to build both models.
            beta: scale applied to the policy/reference log-prob differences.
            label_smoothing: weight of the flipped-preference term in the loss.
            loss_type: DPO loss variant; only ``'sigmoid'`` is implemented.
            ftx_gamma: weight of the extra chosen-sample log-likelihood term
                (applied only when greater than 1e-6).
            **kwargs: forwarded unchanged to ``llava_model_class``.

            The remaining arguments are passed straight through to
            ``llava_model_class``; see
            megatron/core/models/multimodal/llava_model.py.
        """
        super().__init__(config=language_transformer_config)

        # Pipeline-stage flags and weight-sharing flag.
        self.pre_process = pre_process
        self.post_process = post_process
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

        # DPO hyper-parameters.
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.loss_type = loss_type
        self.ftx_gamma = ftx_gamma

        # One keyword set shared by the two model constructions.
        model_kwargs = dict(
            language_transformer_config=language_transformer_config,
            language_transformer_layer_spec=language_transformer_layer_spec,
            language_vocab_size=language_vocab_size,
            language_max_sequence_length=language_max_sequence_length,
            vision_transformer_config=vision_transformer_config,
            vision_transformer_layer_spec=vision_transformer_layer_spec,
            vision_projection_config=vision_projection_config,
            vision_projection_layer_spec=vision_projection_layer_spec,
            vision_projection_type=vision_projection_type,
            parallel_output=parallel_output,
            language_rotary_base=language_rotary_base,
            share_embeddings_and_output_weights=share_embeddings_and_output_weights,
            language_position_embedding_type=language_position_embedding_type,
            language_rotary_percent=language_rotary_percent,
            pre_process=pre_process,
            post_process=post_process,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            **kwargs,
        )

        self.policy_model: LLaVAModel = llava_model_class(**model_kwargs)
        self.ref_model: LLaVAModel = llava_model_class(**model_kwargs)
        # The reference model is never trained.
        self.ref_model.requires_grad_(False)

    def shared_embedding_or_output_weight(self):
        """Return the shared embedding/output weight of the policy language model.

        Only the policy (trained) model is consulted; the reference model is
        not involved.
        """
        policy_language_model = self.policy_model.language_model
        return policy_language_model.shared_embedding_or_output_weight()

    def set_input_tensor(self, input_tensor) -> None:
        """Set model chunk input tensor."""
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        assert len(input_tensor) == 1
        if input_tensor[0] is None:
            return
                                          
        self.policy_model.set_input_tensor(input_tensor[0][0])
        self.ref_model.set_input_tensor(input_tensor[0][1])

    def freeze(
        self,
        freeze_language_model: bool,
        freeze_vision_model: bool,
        freeze_vision_projection: bool,
    ):
        """Forward the three freeze flags to the policy model's ``freeze``.

        The reference model is not touched here; its parameters already have
        ``requires_grad`` disabled in ``__init__``.
        """
        flags = (freeze_language_model, freeze_vision_model, freeze_vision_projection)
        self.policy_model.freeze(*flags)

    def forward(
        self,
        images: torch.Tensor,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        inference_params: Optional[InferenceParams] = None,
        num_image_tiles: Optional[List[int]] = None,
        image_token_index: Optional[int] = None,
        runtime_gather_output: Optional[bool] = None,
        packed_seq_params: Optional[PackedSeqParams] = None,
    ) -> torch.Tensor:
        """Run the policy and reference models and compute the DPO loss.

        All arguments except ``labels`` are passed unchanged to both wrapped
        models (see megatron/core/models/multimodal/llava_model.py for their
        meaning). Both models are always called with ``labels=None``; the
        reference model runs under ``torch.no_grad()``.

        Returns:
            When ``self.post_process`` is false: a tuple of the stacked
            ``[policy, reference]`` outputs and an empty dict.
            Otherwise: the ``(losses, metrics)`` tuple produced by ``dpo``.
        """
        # Identical inputs for both models; labels are withheld so the models
        # return their outputs rather than a loss.
        model_inputs = dict(
            images=images,
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            labels=None,
            loss_mask=loss_mask,
            inference_params=inference_params,
            num_image_tiles=num_image_tiles,
            image_token_index=image_token_index,
            runtime_gather_output=runtime_gather_output,
            packed_seq_params=packed_seq_params,
        )

        # Each model returns a pair; only the first element is used here.
        policy_out, _ = self.policy_model(**model_inputs)
        with torch.no_grad():
            ref_out, _ = self.ref_model(**model_inputs)

        if not self.post_process:
            # Stack policy and reference outputs so they travel together;
            # set_input_tensor unpacks them by index on the receiving side.
            return torch.stack([policy_out, ref_out]), {}

        assert labels is not None, 'dpo model need labels'
        # Cast the model outputs to float32 before computing log-probs.
        return self.dpo(labels, policy_out.float(), ref_out.float())

    def get_batch_logps(self, logits, labels):
        """Sum the log-probabilities of ``labels`` over the last dimension.

        Per-token log-probs are the negated vocab-parallel cross entropy.
        Positions whose label is -100 contribute zero. When the config's
        ``context_parallel_size`` is greater than 1 the sums are reduced
        across the context-parallel region.
        """
        token_logps = tensor_parallel.vocab_parallel_cross_entropy(logits, labels).neg()
        keep = labels.ne(-100)
        summed = (token_logps * keep).sum(dim=-1)
        if self.config.context_parallel_size <= 1:
            return summed
        return reduce_from_context_parallel_region(summed)

    def dpo_loss(self, policy_chosen_logps, policy_rejected_logps, ref_chosen_logps,
                 ref_rejected_logps):
        chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
        rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)
        logits = chosen_rewards - rejected_rewards

                 
        if self.loss_type == 'sigmoid':
            losses = (-F.logsigmoid(logits) * (1 - self.label_smoothing) -
                      F.logsigmoid(-logits) * self.label_smoothing)
        else:
            raise ValueError(f'unknown loss type {self.loss_type}')

        return losses, chosen_rewards.detach(), rejected_rewards.detach()

    def dpo(self, labels, logits, ref_logits):
        assert logits.shape[0] % 2 == 0, "mbs must be 2*n"
        rbs = logits.shape[0] // 2

        policy_chosen_logits, policy_rejected_logits = logits.split(rbs, dim=0)
        policy_chosen_logits_mean = policy_chosen_logits.detach().mean().float()
        policy_rejected_logits_mean = policy_rejected_logits.detach().mean().float()

        logps = self.get_batch_logps(logits, labels)
        ref_logps = self.get_batch_logps(ref_logits, labels)

        policy_chosen_logps, policy_rejected_logps = logps.split(rbs, dim=0)
        ref_chosen_logps, ref_rejected_logps = ref_logps.split(rbs, dim=0)

        losses, chosen_rewards, rejected_rewards = self.dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            ref_chosen_logps,
            ref_rejected_logps,
        )
        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        if self.ftx_gamma > 1e-6:
            chosen_labels = labels.split(rbs, dim=0)[0]
            loss_mask_sum = (chosen_labels != -100).sum(-1)
            if self.config.context_parallel_size > 1:
                loss_mask_sum = reduce_from_context_parallel_region(loss_mask_sum)
            losses -= self.ftx_gamma * policy_chosen_logps / loss_mask_sum

        metrics = {
            "dpo-metrics/rewards-accuracies": reward_accuracies.mean().float(),
            "dpo-metrics/rewards-chosen": chosen_rewards.mean().float(),
            "dpo-metrics/rewards-rejected": rejected_rewards.mean().float(),
            "dpo-metrics/rewards-margins": (chosen_rewards - rejected_rewards).mean().float(),
            "dpo-metrics/logps-rejected": policy_rejected_logps.detach().mean().float(),
            "dpo-metrics/logps-chosen": policy_chosen_logps.detach().mean().float(),
            "dpo-metrics/logits-rejected": policy_rejected_logits_mean.detach().mean().float(),
            "dpo-metrics/logits-chosen": policy_chosen_logits_mean.detach().mean().float(),
        }
        return losses, metrics

    def sharded_state_dict(self,
                           prefix: str = '',
                           sharded_offsets: tuple = (),
                           metadata: Optional[Dict] = None) -> ShardedStateDict:
        """Merge the sharded state dicts of the policy and reference models.

        Each sub-model is queried with ``prefix`` extended by its attribute
        name (``policy_model.`` / ``ref_model.``); ``sharded_offsets`` and
        ``metadata`` are passed through unchanged.
        """
        merged = {}
        sub_models = (
            ('policy_model', self.policy_model),
            ('ref_model', self.ref_model),
        )
        for attr_name, sub_model in sub_models:
            merged.update(
                sub_model.sharded_state_dict(f'{prefix}{attr_name}.', sharded_offsets, metadata))
        return merged


class Qwen2VLModelDPO(LLaVAModelDPO):
    """DPO wrapper whose ``forward`` takes Qwen2-VL style inputs.

    Everything except ``forward`` is inherited from ``LLaVAModelDPO``.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        vision_data: torch.Tensor = None,
        vision_grid_thw: torch.Tensor = None,
        video_start_index: int = -1,
        image_input_mask: torch.Tensor = None,
        video_input_mask: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        image_padded: bool = False,
        extra_block_kwargs: dict = None,
    ) -> torch.Tensor:
        """Run the policy and reference models and compute the DPO loss.

        All arguments except ``labels`` are passed unchanged to both wrapped
        models (see gpatch/core/models/multimodal/qwen2vl_model.py for their
        meaning); the entries of ``extra_block_kwargs`` are expanded into
        additional keyword arguments. Both models are always called with
        ``labels=None``; the reference model runs under ``torch.no_grad()``.

        Returns:
            When ``self.post_process`` is false: a tuple of the stacked
            ``[policy, reference]`` outputs and an empty dict.
            Otherwise: the ``(losses, metrics)`` tuple produced by ``dpo``.
        """
        # Identical inputs for both models; labels are withheld so the models
        # return their outputs rather than a loss.
        model_inputs = dict(
            input_ids=input_ids,
            position_ids=position_ids,
            vision_data=vision_data,
            vision_grid_thw=vision_grid_thw,
            video_start_index=video_start_index,
            image_input_mask=image_input_mask,
            video_input_mask=video_input_mask,
            attention_mask=attention_mask,
            labels=None,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            image_padded=image_padded,
        )

        # Two separate ** expansions keep the duplicate-keyword TypeError if
        # extra_block_kwargs repeats one of the explicit argument names.
        policy_out = self.policy_model(**model_inputs, **(extra_block_kwargs or {}))
        with torch.no_grad():
            ref_out = self.ref_model(**model_inputs, **(extra_block_kwargs or {}))

        if not self.post_process:
            # Stack policy and reference outputs so they travel together.
            return torch.stack([policy_out, ref_out]), {}

        assert labels is not None, 'dpo model need labels'
        # Cast the model outputs to float32 before computing log-probs.
        return self.dpo(labels, policy_out.float(), ref_out.float())
