# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.distributed as dist
from torch.profiler import record_function

from lmdeploy.pytorch import consts
from lmdeploy.pytorch.config import DLLMConfig
from lmdeploy.pytorch.distributed import DistContext
from lmdeploy.pytorch.engine.logits_process import SamplingInputs
from lmdeploy.pytorch.messages import SchedulerSequence
from lmdeploy.pytorch.model_inputs import ModelInputs

from ..base.model_agent import ExtraInputs, ExtraOutputs, ModelAgentStrategy, StoppingCriteria
from .unmasking import UnmaskingProcessor

SeqList = List[SchedulerSequence]


@dataclass
class DLLMExtraInputs(ExtraInputs):
    """DLLM extra inputs."""
    dllm_mask: torch.Tensor

    def broadcast(self, src: int, group, async_op=False):
        return dist.broadcast(self.dllm_mask, src=src, group=group, async_op=async_op)


@dataclass
class DLLMExtraOutputs(ExtraOutputs):
    """Ar extra outputs."""
    dllm_mask: torch.Tensor


def _check_stopwords_dllm(token_ids: torch.Tensor, stop_words: torch.Tensor, is_unmasked: torch.Tensor,
                          stopped: torch.Tensor, stop_pos: torch.Tensor, num_appendable_ids: torch.Tensor,
                          output_start_pos: torch.Tensor, inputs: ModelInputs):
    num_tokens = token_ids.size(0)
    batch_size = num_appendable_ids.size(0)
    block_size = num_tokens // batch_size

    # blocks might contain stop words in prev-round chat
    # these stop words should be ignored
    kv_seqlens = inputs.history_lengths + inputs.seq_length
    ignore_pos = (output_start_pos - (kv_seqlens - block_size)).clamp_min(0)
    ignore_range = torch.arange(0, block_size, dtype=ignore_pos.dtype, device=ignore_pos.device)
    ignore_mask = (ignore_range[None, :] < ignore_pos[:, None]).flatten()
    token_ids = token_ids.clone()
    token_ids[ignore_mask] = -1

    # find stop words
    sw_stopped = (token_ids[:, None] == stop_words).any(1)
    sw_stopped = sw_stopped.view(batch_size, block_size)
    sw_stop_pos = sw_stopped.int().argmax(1)

    stop_pos = torch.where(stopped, stop_pos, sw_stop_pos)
    sw_stopped = sw_stopped.any(dim=1)
    sw_stopped = sw_stopped & is_unmasked
    stopped = stopped | sw_stopped

    # update num_appendable_ids
    one_ids = torch.clamp_max(num_appendable_ids, 0)
    num_appendable_ids = torch.where(sw_stopped, one_ids, num_appendable_ids)

    return stopped, stop_pos, num_appendable_ids


@dataclass
class DLLMStoppingCriteria(StoppingCriteria):
    num_appendable_ids: torch.Tensor
    output_start_pos: torch.Tensor

    @record_function('stopping_criteria')
    def step(self,
             token_ids: torch.Tensor,
             stop_words: torch.Tensor,
             inputs: Optional[ModelInputs] = None,
             extra_inputs: Optional[DLLMExtraInputs] = None):
        """Check whether to stop generation."""
        num_appendable_ids = self.num_appendable_ids
        output_start_pos = self.output_start_pos
        num_tokens = token_ids.size(0)
        batch_size = num_appendable_ids.size(0)
        block_size = num_tokens // batch_size

        dllm_mask = extra_inputs.dllm_mask
        dllm_mask = dllm_mask.view(batch_size, block_size)
        is_unmasked = (dllm_mask == consts.DLLM_UNMASKED).all(dim=1)

        # check stop by num_new_tokens
        num_appendable_ids -= is_unmasked * block_size
        stopped = num_appendable_ids <= 0
        stop_pos = block_size - 1 + num_appendable_ids

        # check stop words
        if stop_words is not None:
            stopped, stop_pos, num_appendable_ids = _check_stopwords_dllm(token_ids,
                                                                          stop_words,
                                                                          is_unmasked,
                                                                          stopped,
                                                                          stop_pos,
                                                                          num_appendable_ids,
                                                                          output_start_pos=output_start_pos,
                                                                          inputs=inputs)

        new_stopping = DLLMStoppingCriteria(num_appendable_ids=num_appendable_ids, output_start_pos=output_start_pos)
        return stopped, stop_pos, new_stopping


class DLLMModelAgentStrategy(ModelAgentStrategy):

    def __init__(self, dllm_config: DLLMConfig, dllm_mask_token: int):
        block_size = dllm_config.block_length
        self.block_size = block_size
        self.dllm_mask_token = dllm_mask_token

        self.unmasking_processor = UnmaskingProcessor(dllm_config=dllm_config)
        self._logits_buffer: torch.Tensor | None = None
        self._logits_buffer_capacity: int = 0
        self._logits_vocab_size: int = 0

    def _use_delayed_cache(self, inputs: ModelInputs) -> bool:
        cfg = self.unmasking_processor.dllm_config
        return (cfg.enable_delayed_cache and inputs.processing_indices is not None
                and inputs.processing_q_lens is not None)

    def _get_full_logits_buffer(self, template: torch.Tensor, batch_size: int,
                                vocab_size: int) -> torch.Tensor:
        """Return a view of a reusable logits buffer large enough for this batch."""
        need_new = (self._logits_buffer is None or self._logits_buffer.device != template.device
                    or self._logits_buffer.dtype != template.dtype
                    or self._logits_vocab_size != vocab_size
                    or self._logits_buffer_capacity < batch_size)
        if need_new:
            capacity = max(batch_size, self._logits_buffer_capacity)
            if capacity == 0:
                capacity = batch_size
            self._logits_buffer = template.new_empty(capacity, self.block_size, vocab_size)
            self._logits_buffer_capacity = capacity
            self._logits_vocab_size = vocab_size
        return self._logits_buffer[:batch_size]

    def _scatter_logits(self, flat_logits: torch.Tensor, inputs: ModelInputs) -> torch.Tensor:
        batch_size = inputs.seq_length.size(0)
        vocab_size = flat_logits.size(-1)
        block_size = self.block_size
        full_logits = self._get_full_logits_buffer(flat_logits, batch_size, vocab_size)
        device = flat_logits.device
        proc_indices = inputs.processing_indices.to(device=device, dtype=torch.long)
        seq_lengths_src = inputs.processing_q_lens if inputs.processing_q_lens is not None else inputs.seq_length
        num_proc = proc_indices.numel()
        seq_lengths = seq_lengths_src.to(device=device, dtype=torch.long)
        seq_ids = torch.arange(batch_size, device=device, dtype=torch.long)
        batch_indices = torch.repeat_interleave(seq_ids, seq_lengths, output_size=int(num_proc))
        dense_row_indices = batch_indices * block_size + proc_indices
        total_logits = flat_logits.size(0)
        # Kernels may output ragged logits (one row per processing index) or dense block
        # logits. Select the corresponding rows before scattering into the full buffer.
        if total_logits == num_proc:
            assign_logits = flat_logits
        elif total_logits == batch_size * block_size:
            assign_logits = flat_logits.index_select(0, dense_row_indices)
        else:
            raise RuntimeError(
                f'Unexpected logits rows: got {total_logits}, '
                f'but need either {num_proc} (sparse) or {batch_size * block_size} (dense).')

        flat_full = full_logits.view(-1, vocab_size)
        flat_full.index_copy_(0, dense_row_indices, assign_logits)
        return flat_full

    def _temporarily_cache_unprocessed(self, dllm_mask: torch.Tensor,
                                       inputs: ModelInputs) -> Optional[torch.Tensor]:
        """Mark unprocessed masked slots as cached before unmasking."""
        if not self._use_delayed_cache(inputs):
            return None
        proc_indices = inputs.processing_indices
        proc_q_lens = inputs.processing_q_lens
        block_size = self.block_size
        mask_view = dllm_mask.view(-1, block_size)
        device = dllm_mask.device
        proc_indices = proc_indices.to(device=device, dtype=torch.long)
        proc_q_lens = proc_q_lens.to(device=device, dtype=torch.long)

        total_seqs = mask_view.size(0)
        seq_count = proc_q_lens.numel()

        seq_mask_view = mask_view[:seq_count]
        processed_mask = torch.zeros(seq_count * block_size, dtype=torch.bool, device=device)
        if proc_indices.numel() > 0:
            seq_ids = torch.repeat_interleave(torch.arange(seq_count, device=device, dtype=torch.long),
                                              proc_q_lens,
                                              output_size=int(proc_indices.numel()))
            dense_offsets = seq_ids * block_size + proc_indices
            processed_mask.index_fill_(0, dense_offsets, True)
        processed_mask = processed_mask.view(seq_count, block_size)

        masked_slots = (seq_mask_view == consts.DLLM_MASKED)
        cached_mask = (~processed_mask) & masked_slots
        cached_mask &= (proc_q_lens < block_size).view(-1, 1)
        seq_mask_view.masked_fill_(cached_mask, consts.DLLM_CACHED)

        if seq_count == total_seqs:
            return cached_mask
        cached_mask_full = torch.zeros_like(mask_view, dtype=torch.bool)
        cached_mask_full[:seq_count].copy_(cached_mask)
        return cached_mask_full

    def _restore_cached_unprocessed(self, dllm_mask: torch.Tensor, cached_masks: Optional[torch.Tensor]):
        """Restore cached slots to the masked state after unmasking."""
        if cached_masks is None:
            return
        block_size = self.block_size
        mask_view = dllm_mask.view(-1, block_size)
        mask_view.masked_fill_(cached_masks, consts.DLLM_MASKED)

    def reshape_logits(self, logits: torch.Tensor, inputs: ModelInputs) -> torch.Tensor:
        if not self._use_delayed_cache(inputs):
            return logits
        vocab_size = logits.size(-1)
        flat_logits = logits.reshape(-1, vocab_size)
        restored = self._scatter_logits(flat_logits, inputs)
        return restored.unsqueeze(0)

    def _update_dllm(self, next_token_ids: torch.Tensor, dllm_mask: torch.Tensor, seqlens: torch.Tensor):
        """Update token_ids and dllm_mask."""
        dllm_mask_token = self.dllm_mask_token
        dllm_block_length = self.block_size

        # reshape to (batch, dllm_block_length)
        next_token_ids = next_token_ids.view(-1, dllm_block_length).clone()
        dllm_mask = dllm_mask.view(-1, dllm_block_length).clone()

        # flags
        is_cached = (dllm_mask == consts.DLLM_CACHED).all(dim=1)

        is_masked = (dllm_mask == consts.DLLM_MASKED)
        next_token_ids[is_cached[:, None] | is_masked] = dllm_mask_token
        dllm_mask[is_cached] = consts.DLLM_MASKED
        seqlens = torch.where(is_cached.view(-1), seqlens, seqlens.new_zeros((1, )))

        return next_token_ids.flatten(), dllm_mask.flatten(), seqlens

    def slice_outputs(self, inputs: torch.Tensor, seq_length: torch.LongTensor) -> torch.Tensor:
        """Slice outputs."""
        block_length = self.block_size
        # batch size = 1
        if len(seq_length) == 1:
            return inputs[-block_length:]

        if len(seq_length) * block_length == inputs.size(0):
            return inputs
        last_idx = seq_length.cumsum(0)
        block_range = torch.arange(-block_length, 0, device=last_idx.device)
        index = (last_idx[:, None] + block_range[None, :]).flatten()
        inputs = inputs[index]
        return inputs

    def slice_extra_inputs(self, extra_inputs: DLLMExtraInputs, model_inputs: ModelInputs,
                           model_outputs: Dict[str, torch.Tensor], **kwargs) -> DLLMExtraInputs:
        """Slice outputs."""
        dllm_mask = self.slice_outputs(extra_inputs.dllm_mask, model_inputs.seq_length)
        return DLLMExtraInputs(dllm_mask=dllm_mask)

    def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor,
                              dllm_mask: torch.Tensor, **kwargs):
        """step."""
        from lmdeploy.pytorch import consts
        dllm_block_size = self.block_size
        DLLM_UNMASKED = consts.DLLM_UNMASKED
        is_unmasked = (dllm_mask == DLLM_UNMASKED).view(-1, dllm_block_size).all(dim=1, keepdim=True)
        num_ignore_eos = sampling_inputs.num_ignore_eos.view(-1, dllm_block_size)
        num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)
        sampling_inputs.num_ignore_eos = num_ignore_eos.flatten()
        if sampling_inputs.random_offsets is not None:
            # random offset is used to generate random numbers for multinomial sampling
            # so we need to increase it by 1 at each step
            sampling_inputs.random_offsets += 1
        return sampling_inputs

    def make_stopping_criteria(self, seqs: SeqList) -> DLLMStoppingCriteria:
        """Create stopping criteria."""
        # num_appendable
        num_appendable = [seq.sampling_param.max_new_tokens - seq.num_new_tokens for seq in seqs]
        num_appendable = torch.tensor(num_appendable)
        block_size = self.block_size
        remain = [seq.num_valid_ids % block_size for seq in seqs]
        num_appendable += torch.tensor(remain)

        # output_start_pos
        pos = [seq.output_start_pos for seq in seqs]
        output_start_pos = torch.tensor(pos)

        return DLLMStoppingCriteria(num_appendable_ids=num_appendable, output_start_pos=output_start_pos)

    def make_extra_inputs(self, seqs: 'SeqList') -> ExtraInputs:
        """Create extra inputs."""
        dllm_masks = [seq.dllm_mask for seq in seqs]
        dllm_masks = torch.as_tensor(np.concatenate(dllm_masks))
        return DLLMExtraInputs(dllm_mask=dllm_masks)

    def make_extra_outputs(self, extra_inputs: DLLMExtraInputs, **kwargs) -> DLLMExtraOutputs:
        """Create extra outputs."""
        dllm_mask = extra_inputs.dllm_mask
        return DLLMExtraOutputs(dllm_mask=dllm_mask)

    def update_inputs_for_next_step(self, model_inputs: 'ModelInputs', sampling_inputs: 'SamplingInputs',
                                    next_token_ids: torch.Tensor, model_metas: Any, extra_inputs: DLLMExtraInputs,
                                    **kwargs):
        """Step next inputs."""
        model_inputs.model_metas = model_metas
        dllm_mask = extra_inputs.dllm_mask

        next_token_ids, dllm_mask, step_seqlens = self._update_dllm(next_token_ids, dllm_mask, model_inputs.seq_length)
        model_inputs.step(next_token_ids, step_seqlens)
        self._step_sampling_inputs(sampling_inputs, next_token_ids, dllm_mask=dllm_mask)

        extra_inputs = DLLMExtraInputs(dllm_mask=dllm_mask)
        return model_inputs, extra_inputs

    def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
                      extra_inputs: DLLMExtraInputs):
        """Post sampling."""
        dllm_mask = extra_inputs.dllm_mask
        input_ids = inputs.input_ids
        input_ids = self.slice_outputs(input_ids.flatten(), inputs.seq_length)

        cached_masks = self._temporarily_cache_unprocessed(dllm_mask, inputs)
        dllm_mask, next_token_ids = self.unmasking_processor(logits, input_ids, next_token_ids, dllm_mask)
        self._restore_cached_unprocessed(dllm_mask, cached_masks)

        extra_inputs.dllm_mask = dllm_mask
        return next_token_ids, extra_inputs

    def make_dummy_next_token(self, inputs: 'ModelInputs', logits: torch.Tensor, extra_inputs: DLLMExtraInputs):
        """Make dummy next token for broadcast."""
        with torch.inference_mode():
            next_token_ids = inputs.input_ids.new_zeros(logits.size(0))
        return next_token_ids, extra_inputs

    @contextmanager
    def broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: DLLMExtraInputs, dist_ctx: DistContext):
        """Broadcast next token ids and extra inputs."""
        tp_gpu_group = dist_ctx.attn_tp_group.gpu_group
        rank = dist.get_global_rank(tp_gpu_group, 0)
        dist.broadcast(next_token_ids, src=rank, group=tp_gpu_group, async_op=True)
        handle = extra_inputs.broadcast(src=rank, group=tp_gpu_group, async_op=True)
        yield
        handle.wait()
