import torch
import torch.distributed
from torch.distributed._tensor.placement_types import Replicate

from ..communication import Communication
from ..distributed import tensor_to_dtensor
from ..hf_models import get_autoregressive_language_modeling_loss, get_aux_loss
from ..utils import MetricsTrackingDict, ProcessGroupManager
from .base import ModelWrapper


class ModelWrapperForFinetuning(ModelWrapper):
    """Model wrapper that computes the finetuning (autoregressive LM + optional router aux) loss."""

    def forward(self, batch: dict, lm_loss_multiplier: float = 1) -> MetricsTrackingDict:
        """forward function for a batch

        When tensor parallelism is enabled, the batch held by the first tensor-parallel rank is first
        broadcast to every rank of the tensor-parallel group.

        Note: ``"labels"`` is popped from ``batch``, so the dict passed in is mutated (on the first
        tensor-parallel rank, or always when tensor parallelism is disabled).

        Args:
            batch (dict): a dict of key, value pairs for a batch; must contain ``"labels"``, all other
                entries are passed to the model as keyword arguments
            lm_loss_multiplier (float, optional): scale applied to the summed LM loss. Defaults to 1.

        Returns:
            MetricsTrackingDict: the loss dict produced by ``get_loss``
        """

        if ProcessGroupManager.is_tensor_parallel_enabled():
            batch = self._broadcast_inputs_for_tensor_parallel(batch)

        # labels are consumed by the loss, not by the model, so they are removed before the model call
        labels = batch.pop("labels")

        model_outputs = self.model(**batch)

        return self.get_loss(
            model_outputs=model_outputs,
            labels=labels,
            cu_seqlens=batch.get("cu_seqlens", None),
            lm_loss_multiplier=lm_loss_multiplier,
        )

    def get_loss(
        self, model_outputs, labels: torch.Tensor, cu_seqlens: torch.Tensor | None, lm_loss_multiplier: float = 1
    ) -> dict:
        """Compute the training loss from the model outputs.

        The LM loss is computed with ``reduction="sum"`` and then scaled by ``lm_loss_multiplier``.
        If ``get_aux_loss()`` reports a non-zero auxiliary loss, it is added to the LM loss weighted by
        ``self.router_aux_loss_coef``.

        Args:
            model_outputs: model output object exposing a ``logits`` attribute
            labels (torch.Tensor): target labels for the autoregressive LM loss
            cu_seqlens (torch.Tensor | None): cumulative sequence lengths (padding-free mode), or None
            lm_loss_multiplier (float, optional): scale applied to the summed LM loss. Defaults to 1.

        Returns:
            dict: ``{"loss": ...}`` when the aux loss is zero, otherwise
            ``{"loss": ..., "lm_loss": ..., "aux_loss": ...}``
        """
        logits: torch.Tensor = model_outputs.logits
        # NOTE(review): get_aux_loss() is opaque here; presumably it returns the router aux loss
        # accumulated during the forward pass (0 when no router loss was produced) — confirm.
        aux_loss = get_aux_loss()

        lm_loss = get_autoregressive_language_modeling_loss(
            lm_logits=logits,
            labels=labels,
            upcast_logits_for_loss=self.upcast_logits_for_loss,
            cu_seqlens=cu_seqlens,
            use_padding_free_transformer=self.use_padding_free_transformer,
            reduction="sum",
        )

        lm_loss = lm_loss * lm_loss_multiplier

        if aux_loss == 0:
            loss = lm_loss
            output = {"loss": loss}
        else:
            if ProcessGroupManager.is_tensor_parallel_enabled():
                # wrap the local aux loss as a replicated DTensor over the TP mesh, presumably so it can be
                # combined with lm_loss (assumed to be a DTensor under tensor parallelism — confirm)
                aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate())

            loss = lm_loss + self.router_aux_loss_coef * aux_loss
            output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss}

        return output

    def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict:
        """Broadcast the batch from the first tensor-parallel rank to the rest of the TP group.

        Only the first TP rank's ``batch`` is read. The other ranks allocate empty receive buffers of the
        broadcast size and receive the data in place via ``torch.distributed.broadcast``.

        Args:
            batch (dict): the batch on this rank (its contents are ignored on non-first TP ranks)

        Returns:
            dict: the batch holding the broadcast tensors (the same dict object on the first TP rank)
        """
        device = torch.cuda.current_device()

        is_tp_first_rank = ProcessGroupManager.is_tensor_parallel_first_rank()
        tp_source_rank = ProcessGroupManager.get_tensor_parallel_first_rank()
        tp_group = ProcessGroupManager.get_tensor_parallel_group()

        if self.use_padding_free_transformer:
            keys = ["input_ids", "position_ids", "labels", "cu_seqlens", "max_seqlen"]

            # receiving ranks need the element counts to allocate buffers before the data broadcast
            if is_tp_first_rank:
                metadata = torch.tensor(
                    [batch["cu_seqlens"].numel(), batch["input_ids"].numel()], dtype=torch.long, device=device
                )
            else:
                metadata = torch.empty(2, dtype=torch.long, device=device)

            torch.distributed.broadcast(metadata, src=tp_source_rank, group=tp_group)

            if not is_tp_first_rank:
                # convert to Python ints once: a single device->host copy, instead of implicitly indexing
                # (and synchronizing on) a 0-d device tensor for every buffer allocation below
                cu_seqlens_num_elements, input_ids_num_elements = metadata.tolist()

                batch = {
                    "input_ids": torch.empty(input_ids_num_elements, dtype=torch.long, device=device),
                    "position_ids": torch.empty(input_ids_num_elements, dtype=torch.long, device=device),
                    "labels": torch.empty(input_ids_num_elements, dtype=torch.long, device=device),
                    "cu_seqlens": torch.empty(cu_seqlens_num_elements, dtype=torch.int32, device=device),
                    "max_seqlen": torch.empty(1, dtype=torch.long, device=device),
                }
        else:
            keys = ["input_ids", "attention_mask", "labels"]

            # all padded tensors share input_ids' shape, so only that shape is sent ahead of the data
            batch_shape = batch["input_ids"].shape if is_tp_first_rank else None
            batch_shape = Communication.broadcast_object(batch_shape, src=tp_source_rank, group=tp_group)

            if not is_tp_first_rank:
                batch = {key: torch.empty(batch_shape, dtype=torch.long, device=device) for key in keys}

        for key in keys:
            torch.distributed.broadcast(batch[key], src=tp_source_rank, group=tp_group)

        return batch
