from vllm.sequence import SequenceData, SequenceGroupMetadata, get_all_seq_ids
from vllm.spec_decode.interfaces import (
    SpeculativeProposals,
    SpeculativeScorer,
    SpeculativeScores,
)
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.spec_decode.util import nvtx_range

from minedraft.patching import MinePatch
from minedraft.plugin.sequence import MineExecuteModelRequest, MineSequenceGroupMetadata
from minedraft.plugin.spec_decode.interfaces import ParallelSpeculativeScorer


class MQAScorerPatch(MinePatch[MQAScorer], SpeculativeScorer):
    """Patched ``MQAScorer`` that scores every proposal in one target forward
    pass and also gathers hidden states when proposal lengths differ across
    the batch.

    NOTE(review): how ``MinePatch`` installs these members onto ``MQAScorer``
    is defined outside this file -- confirm before adding new members here.
    """

    # [Parallel SD] Fix varlen hidden states handling in MQA scorer.
    def score_proposals(
        self,
        execute_model_req: MineExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Score draft proposals with the target model.

        Every sequence with at least one proposal token is rewritten into a
        new single-sequence group (fresh seq id, same block table) whose
        output tokens are its existing output tokens followed by its proposal
        tokens. Sequences without proposals (e.g. prompt chunks) are passed
        to the target unchanged.

        Args:
            execute_model_req: The batch being decoded.
            proposals: Draft proposals. ``proposal_token_ids`` is indexed as
                ``[batch_index][position]`` and ``proposal_lens`` holds the
                number of valid proposal tokens per sequence.

        Returns:
            SpeculativeScores with ``token_ids`` of shape ``(bs, k + 1)``,
            ``probs``/``logprobs`` of shape ``(bs, k + 1, vocab_size)`` and,
            when the target returns hidden states, ``hidden_states`` of shape
            ``(bs, k + 1, hidden_dim)``. In the variable-length path, unused
            positions hold token -1, zero probs, -inf logprobs and zero
            hidden states. ``prompt_logprobs`` is only populated in the
            variable-length path when prompt logprobs were requested.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        target_seq_group_metadata_list = []
        # New seq ids start above every id already present in the batch.
        target_seq_id_start = max(
            get_all_seq_ids(seq_group_metadata_list)) + 1
        all_proposal_tokens = proposals.proposal_token_ids.tolist()
        all_proposal_lengths = proposals.proposal_lens.tolist()
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            if all_proposal_lengths[i] == 0:
                # Keep prompt seqs untouched (keep computed_tokens for chunks).
                target_seq_group_metadata_list.append(seq_group_metadata)
                continue

            seq_data_dict = seq_group_metadata.seq_data
            assert len(seq_data_dict) == 1
            seq_id = next(iter(seq_data_dict.keys()))

            seq_data: SequenceData = seq_data_dict[seq_id]
            prompt_token_ids = seq_data.get_prompt_token_ids()
            output_token_ids = seq_data.get_output_token_ids()
            # The rewritten decode sequence must re-feed its last output
            # token (see the computed-token count below), so it needs one.
            assert len(output_token_ids) >= 1
            proposal_token_ids = all_proposal_tokens[
                i][:all_proposal_lengths[i]]
            new_output_token_ids = [*output_token_ids, *proposal_token_ids]

            target_seq_id = target_seq_id_start + i
            new_seq_data = SequenceData.from_seqs(
                prompt_token_ids=prompt_token_ids,
                output_token_ids=new_output_token_ids,
            )
            # Everything before the last original output token counts as
            # computed, so the target runs on that token plus the proposals:
            # proposal_len + 1 positions, matching the per-sequence row count
            # consumed in the decode loop below.
            new_seq_data.update_num_computed_tokens(
                len(prompt_token_ids) + len(output_token_ids) - 1)

            new_seq_group_metadata = MineSequenceGroupMetadata(
                request_id=seq_group_metadata.request_id,
                is_prompt=seq_group_metadata.is_prompt,
                seq_data={target_seq_id: new_seq_data},
                sampling_params=seq_group_metadata.sampling_params,
                # Reuse the original sequence's block table under the new id.
                block_tables={
                    target_seq_id: seq_group_metadata.block_tables[seq_id],
                },
                lora_request=None,
            )
            target_seq_group_metadata_list.append(new_seq_group_metadata)

        target_sampler_output = self._scorer_worker.execute_model(
            execute_model_req=execute_model_req.clone(
                seq_group_metadata_list=target_seq_group_metadata_list))
        target_sampler_output = target_sampler_output[0]

        k = execute_model_req.num_lookahead_slots
        bs = len(seq_group_metadata_list)
        target_token_ids = target_sampler_output.sampled_token_ids
        target_probs = target_sampler_output.sampled_token_probs
        target_logprobs = target_sampler_output.logprobs
        prompt_logprobs = None

        # If all requests have the same number of query tokens, we can avoid
        # the for loop to build output for better performance.
        if min(all_proposal_lengths) == k:
            # Regular decodes only.
            assert not any(sg.is_prompt
                           for sg in target_seq_group_metadata_list)
            bs, _ = proposals.proposal_token_ids.shape
            all_tokens = target_token_ids.reshape(bs, k + 1)
            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
            hidden_states = None
            if target_sampler_output.hidden_states is not None:
                hidden_states = target_sampler_output.hidden_states.reshape(
                    bs, (k + 1), -1)
        else:
            # We either have decodes with different lens or prefill+decodes.
            # Pre-fill padded outputs; unused slots keep these sentinels.
            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
                                                   fill_value=-1)
            all_probs = target_probs.new_zeros(*all_tokens.shape,
                                               self._vocab_size)
            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                    fill_value=-float("inf"))
            hidden_states = target_hidden_states = None
            if target_sampler_output.hidden_states is not None:
                target_hidden_states = target_sampler_output.hidden_states
                hidden_states = target_hidden_states.new_zeros(
                    *all_tokens.shape, target_hidden_states.shape[-1])
            target_token_ids = target_token_ids.flatten()

            # When prompt logprobs is enabled, lens of returned tensors go from
            # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
            # We adjust stride accordingly to get the generated tokens and
            # their probs, but pass on prompt_logprobs as is, since it may be
            # that n_prompts >> K.
            has_prompt_log = any((sg.sampling_params.prompt_logprobs
                                  and sg.sampling_params.prompt_logprobs > 0)
                                 for sg in target_seq_group_metadata_list)
            # TODO (NickLucche) we should surface `disable_logprobs` as to not
            # break abstraction to get its value.
            if (not self._scorer_worker.model_runner.disable_logprobs
                    and has_prompt_log):
                prompt_logprobs = [
                    o.prompt_logprobs for o in target_sampler_output.outputs
                ]

            num_seqs = len(target_seq_group_metadata_list)
            # Split loop into prefill|decode for readability. start_loc and
            # end_loc walk the flat rows of the target output tensors.
            start_loc, i = 0, 0
            while (i < num_seqs
                   and target_seq_group_metadata_list[i].is_prompt):
                seq_meta = target_seq_group_metadata_list[i]
                end_loc = start_loc
                if has_prompt_log:
                    end_loc += seq_meta.token_chunk_size
                elif seq_meta.do_sample:
                    end_loc += 1

                # Chunks that do not sample produce no output row.
                if seq_meta.do_sample:
                    # Get sampled token (last position in chunk) and its prob.
                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
                    all_probs[i, 0] = target_probs[end_loc - 1]
                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]
                    if hidden_states is not None:
                        hidden_states[i, 0] = target_hidden_states[end_loc - 1]

                i += 1
                start_loc = end_loc
            # Decodes: each contributes proposal_len + 1 consecutive rows.
            while i < num_seqs:
                output_len = all_proposal_lengths[i] + 1
                end_loc = start_loc + output_len
                all_tokens[
                    i, :output_len] = target_token_ids[start_loc:end_loc]
                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
                all_logprobs[
                    i, :output_len] = target_logprobs[start_loc:end_loc]
                if hidden_states is not None:
                    hidden_states[i, :output_len] = target_hidden_states[
                        start_loc:end_loc]
                start_loc = end_loc
                i += 1

        return SpeculativeScores(probs=all_probs,
                                 token_ids=all_tokens,
                                 logprobs=all_logprobs,
                                 hidden_states=hidden_states,
                                 prompt_logprobs=prompt_logprobs)

class ParallelMQAScorer(ParallelSpeculativeScorer):
    """MQA scorer whose target forward pass is launched and collected in two
    separate steps (presumably so other work can overlap with scoring).

    ``start_score_proposals`` builds the target batch, launches it and stores
    the resulting async handle on the request. ``score_proposals`` waits on
    that handle, receives the sampler output and reshapes it into
    ``SpeculativeScores``.
    """

    @staticmethod
    def _build_target_seq_group_metadata_list(
        execute_model_req: MineExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> list:
        """Return the list of sequence-group metadata the target should score.

        Sequences without proposals are passed through unchanged. Every other
        sequence is replaced by a single-sequence ``SequenceGroupMetadata``
        with a fresh seq id whose output tokens are extended by its proposal
        tokens.
        """
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        # New seq ids start above every id already present in the batch.
        target_seq_id_start = max(
            get_all_seq_ids(seq_group_metadata_list)) + 1
        all_proposal_tokens = proposals.proposal_token_ids.tolist()
        all_proposal_lengths = proposals.proposal_lens.tolist()

        target_seq_group_metadata_list = []
        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
            if all_proposal_lengths[i] == 0:
                # Keep prompt seqs untouched (keep computed_tokens for chunks).
                target_seq_group_metadata_list.append(seq_group_metadata)
                continue

            seq_data_dict = seq_group_metadata.seq_data
            assert len(seq_data_dict) == 1
            seq_id = next(iter(seq_data_dict.keys()))

            seq_data: SequenceData = seq_data_dict[seq_id]
            prompt_token_ids = seq_data.get_prompt_token_ids()
            output_token_ids = seq_data.get_output_token_ids()
            # The rewritten decode sequence must re-feed its last output
            # token (see the computed-token count below), so it needs one.
            assert len(output_token_ids) >= 1
            proposal_token_ids = all_proposal_tokens[
                i][:all_proposal_lengths[i]]

            target_seq_id = target_seq_id_start + i
            new_seq_data = SequenceData.from_seqs(
                prompt_token_ids=prompt_token_ids,
                output_token_ids=[*output_token_ids, *proposal_token_ids],
            )
            # Everything before the last original output token counts as
            # computed, so the target runs on that token plus the proposals:
            # proposal_len + 1 positions, matching the rows consumed per
            # decode sequence in score_proposals.
            new_seq_data.update_num_computed_tokens(
                len(prompt_token_ids) + len(output_token_ids) - 1)

            # NOTE(review): MQAScorerPatch builds MineSequenceGroupMetadata
            # at this point; confirm the plain SequenceGroupMetadata is
            # intended for the parallel path.
            target_seq_group_metadata_list.append(
                SequenceGroupMetadata(
                    request_id=seq_group_metadata.request_id,
                    is_prompt=seq_group_metadata.is_prompt,
                    seq_data={target_seq_id: new_seq_data},
                    sampling_params=seq_group_metadata.sampling_params,
                    # Reuse the original sequence's block table.
                    block_tables={
                        target_seq_id: seq_group_metadata.block_tables[seq_id],
                    },
                    lora_request=None,
                ))
        return target_seq_group_metadata_list

    @nvtx_range("ParallelMQAScorer.start_score_proposals")
    def start_score_proposals(
        self,
        execute_model_req: MineExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> None:
        """Launch target-model scoring of ``proposals`` without waiting.

        On return, ``execute_model_req.scoring_async_handle`` holds the handle
        that the target worker left on the cloned request. The same
        ``execute_model_req`` must later be passed to ``score_proposals``.
        """
        target_execute_model_req = execute_model_req.clone(
            seq_group_metadata_list=self._build_target_seq_group_metadata_list(
                execute_model_req, proposals))
        self._scorer_worker.execute_model(
            execute_model_req=target_execute_model_req)
        # The worker records the handle on the clone; copy it back to the
        # caller's request so score_proposals can find it.
        execute_model_req.scoring_async_handle = (
            target_execute_model_req.scoring_async_handle)

    @nvtx_range("ParallelMQAScorer.score_proposals")
    def score_proposals(
        self,
        execute_model_req: MineExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        """Wait for scoring started by ``start_score_proposals`` and assemble
        the results.

        Args:
            execute_model_req: The same request previously passed to
                ``start_score_proposals``. Its async handle is consumed.
            proposals: The proposals that were scored.

        Returns:
            SpeculativeScores with ``token_ids`` of shape ``(bs, k + 1)``,
            ``probs``/``logprobs`` of shape ``(bs, k + 1, vocab_size)`` and,
            when present, ``hidden_states`` of shape
            ``(bs, k + 1, hidden_dim)``. In the variable-length path, unused
            positions hold token -1, zero probs, -inf logprobs and zero
            hidden states.

        Raises:
            RuntimeError: If the request carries no scoring handle (e.g.
                ``start_score_proposals`` was not called for it).
        """
        work = execute_model_req.scoring_async_handle
        # Consume the handle so a stale one is never waited on again.
        execute_model_req.scoring_async_handle = None
        # Explicit check instead of `assert`, so it still fires under -O.
        if work is None:
            raise RuntimeError("No async handle found for scoring")
        work.wait()

        target_sampler_output = self._recv_sampler_output_from_repr_scorer()

        # NOTE: the original (not rewritten) metadata is read here. This
        # assumes prompt sequences carry no proposals and were therefore
        # passed to the target unchanged.
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        all_proposal_lengths = proposals.proposal_lens.tolist()

        k = execute_model_req.num_lookahead_slots
        bs = len(seq_group_metadata_list)
        target_token_ids = target_sampler_output.sampled_token_ids
        target_probs = target_sampler_output.sampled_token_probs
        target_logprobs = target_sampler_output.logprobs
        prompt_logprobs = None

        # If all requests have the same number of query tokens, we can avoid
        # the for loop to build output for better performance.
        if min(all_proposal_lengths) == k:
            # Regular decodes only.
            assert not any(sg.is_prompt for sg in seq_group_metadata_list)
            bs, _ = proposals.proposal_token_ids.shape
            all_tokens = target_token_ids.reshape(bs, k + 1)
            all_probs = target_probs.reshape(bs, k + 1, self._vocab_size)
            all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size)
            hidden_states = None
            if target_sampler_output.hidden_states is not None:
                hidden_states = target_sampler_output.hidden_states.reshape(
                    bs, (k + 1), -1)
        else:
            # We either have decodes with different lens or prefill+decodes.
            # Pre-fill padded outputs; unused slots keep these sentinels.
            all_tokens = target_token_ids.new_full(size=(bs, k + 1),
                                                   fill_value=-1)
            all_probs = target_probs.new_zeros(*all_tokens.shape,
                                               self._vocab_size)
            all_logprobs = target_logprobs.new_full(size=all_probs.shape,
                                                    fill_value=-float("inf"))
            hidden_states = target_hidden_states = None
            if target_sampler_output.hidden_states is not None:
                target_hidden_states = target_sampler_output.hidden_states
                hidden_states = target_hidden_states.new_zeros(
                    *all_tokens.shape, target_hidden_states.shape[-1])
            target_token_ids = target_token_ids.flatten()

            # When prompt logprobs is enabled, lens of returned tensors go from
            # n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
            # We adjust stride accordingly to get the generated tokens and
            # their probs, but pass on prompt_logprobs as is, since it may be
            # that n_prompts >> K.
            has_prompt_log = any((sg.sampling_params.prompt_logprobs
                                  and sg.sampling_params.prompt_logprobs > 0)
                                 for sg in seq_group_metadata_list)
            # TODO (NickLucche) we should surface `disable_logprobs` as to not
            # break abstraction to get its value.
            if (not self._scorer_worker.model_runner.disable_logprobs
                    and has_prompt_log):
                prompt_logprobs = [
                    o.prompt_logprobs for o in target_sampler_output.outputs
                ]

            num_seqs = len(seq_group_metadata_list)
            # Split loop into prefill|decode for readability. start_loc and
            # end_loc walk the flat rows of the target output tensors.
            start_loc, i = 0, 0
            while i < num_seqs and seq_group_metadata_list[i].is_prompt:
                seq_meta = seq_group_metadata_list[i]
                end_loc = start_loc
                if has_prompt_log:
                    end_loc += seq_meta.token_chunk_size
                elif seq_meta.do_sample:
                    end_loc += 1

                # Chunks that do not sample produce no output row.
                if seq_meta.do_sample:
                    # Get sampled token (last position in chunk) and its prob.
                    all_tokens[i, 0] = target_token_ids[end_loc - 1]
                    all_probs[i, 0] = target_probs[end_loc - 1]
                    all_logprobs[i, 0] = target_logprobs[end_loc - 1]
                    if hidden_states is not None:
                        hidden_states[i, 0] = target_hidden_states[end_loc - 1]

                i += 1
                start_loc = end_loc
            # Decodes: each contributes proposal_len + 1 consecutive rows.
            while i < num_seqs:
                output_len = all_proposal_lengths[i] + 1
                end_loc = start_loc + output_len
                all_tokens[
                    i, :output_len] = target_token_ids[start_loc:end_loc]
                all_probs[i, :output_len] = target_probs[start_loc:end_loc]
                all_logprobs[
                    i, :output_len] = target_logprobs[start_loc:end_loc]
                if hidden_states is not None:
                    hidden_states[i, :output_len] = target_hidden_states[
                        start_loc:end_loc]
                start_loc = end_loc
                i += 1

        return SpeculativeScores(probs=all_probs,
                                 token_ids=all_tokens,
                                 logprobs=all_logprobs,
                                 hidden_states=hidden_states,
                                 prompt_logprobs=prompt_logprobs)