from typing import Iterable, List, Optional, Tuple, Type, Union
from vllm.outputs import RequestOutput
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)

def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                    outputs: SequenceGroupOutput) -> None:
    """Apply one step of sampler output to ``seq_group``.

    Replacement for ``LLMEngine._process_sequence_group_outputs``. It is
    installed on ``LLMEngine`` at the bottom of this module, so ``self`` is
    the engine instance. Besides the token and logprob update, this version
    propagates ``outputs.logits`` in three places:

    - it is stored on the sequence group;
    - it is attached to every sample;
    - it is passed as the third argument of ``append_token_id``.

    Overview:
      1. If prompt logprobs are present, detokenize and store them. Then
         store the logits on the group.
      2. Bucket the samples by ``parent_seq_id``. A running parent with no
         samples is marked FINISHED_ABORTED, removed from the group and
         freed in the scheduler.
      3. For each remaining parent, fork a new child sequence for every
         sample except the last. Append the last sample's token to the
         parent itself.
      4. Detokenize every resulting sequence and run the stop check on it.
      5. Without beam search: add new children to the group and, if they
         are not finished, fork them in the scheduler. Then free finished
         parents that were continued in place.
      6. With beam search:
         - keep the top ``best_of`` finished sequences by beam-search score;
         - decide whether the search can stop;
         - if not, keep the top ``best_of`` running sequences;
         - remove and free the reused parents that were not selected.

    Args:
        seq_group: The sequence group being advanced; mutated in place.
        outputs: Sampler output for this group. Only ``prompt_logprobs``,
            ``logits`` and ``samples`` are read.

    Returns:
        None. Side effects:
        - mutates ``seq_group`` and its sequences;
        - calls ``self.scheduler.fork_seq`` / ``self.scheduler.free_seq``;
        - draws new sequence ids from ``self.seq_counter``.
    """

    # Process prompt logprobs
    prompt_logprobs = outputs.prompt_logprobs
    logits = outputs.logits
    if prompt_logprobs is not None:
        self.detokenizer.decode_prompt_logprobs_inplace(
            seq_group, prompt_logprobs)
        seq_group.prompt_logprobs = prompt_logprobs
    # Stored unconditionally (no None check), so every step overwrites the
    # group's logits with whatever this step's output carries.
    seq_group.logits = logits
    # Process samples
    samples = outputs.samples
    parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
    # Snapshot taken before this step's stop checks run (further below).
    # Only the beam-search branch uses it.
    existing_finished_seqs = seq_group.get_finished_seqs()
    # Bucket samples by parent. Every running parent gets an entry, so a
    # parent with no samples ends up with an empty list. A sample whose
    # parent is not running would raise KeyError here.
    parent_child_dict = {
        parent_seq.seq_id: []
        for parent_seq in parent_seqs
    }
    for sample in samples:
        # NOTE(review): every sample receives the same group-level ``logits``
        # object, overwriting any logits attribute it already had. Confirm
        # this is intended when a group produces several samples.
        sample.logits = logits
        parent_child_dict[sample.parent_seq_id].append(sample)
    # List of (child, parent)
    child_seqs: List[Tuple[Sequence, Sequence]] = []

    # Process the child samples for each parent sequence
    for parent in parent_seqs:
        child_samples: List[SequenceOutput] = parent_child_dict[
            parent.seq_id]
        if len(child_samples) == 0:
            # This parent sequence has no children samples. Remove
            # the parent sequence from the sequence group since it will
            # not be used in the future iterations.
            parent.status = SequenceStatus.FINISHED_ABORTED
            seq_group.remove(parent.seq_id)
            self.scheduler.free_seq(parent)
            continue
        # Fork the parent sequence if there are multiple child samples.
        # Each fork gets a fresh id from the engine's ``seq_counter``.
        # The forks are taken before the parent itself is extended below.
        # Assuming ``fork`` copies the parent, each child therefore starts
        # from the parent's pre-step state.
        for child_sample in child_samples[:-1]:
            new_child_seq_id = next(self.seq_counter)
            child = parent.fork(new_child_seq_id)
            # NOTE(review): this three-argument call (token, logprobs,
            # logits) assumes ``Sequence.append_token_id`` has been extended
            # to accept logits. Confirm against the Sequence implementation
            # in use.
            child.append_token_id(child_sample.output_token,
                                    child_sample.logprobs,
                                    child_sample.logits)
            child_seqs.append((child, parent))
        # Continue the parent sequence for the last child sample.
        # We reuse the parent sequence here to reduce redundant memory
        # copies, especially when using non-beam search sampling methods.
        last_child_sample = child_samples[-1]
        parent.append_token_id(last_child_sample.output_token,
                                last_child_sample.logprobs,
                                last_child_sample.logits,)
        # A (parent, parent) pair means "parent continued in place". Later
        # code tells it apart from a new fork with ``seq is parent``.
        child_seqs.append((parent, parent))

    # Detokenize each sequence's new token and run the stop check. The
    # ``is_finished()`` checks below rely on the status ``_check_stop``
    # presumably sets here.
    for seq, _ in child_seqs:
        self.detokenizer.decode_sequence_inplace(seq,
                                                    seq_group.sampling_params)
        self._check_stop(seq, seq_group.sampling_params)

    # Non-beam search case
    if not seq_group.sampling_params.use_beam_search:
        # For newly created child sequences, add them to the sequence group
        # and fork them in block manager if they are not finished.
        for seq, parent in child_seqs:
            if seq is not parent:
                seq_group.add(seq)
                if not seq.is_finished():
                    self.scheduler.fork_seq(parent, seq)

        # Free the finished and selected parent sequences' memory in block
        # manager. Keep them in the sequence group as candidate output.
        # NOTE: we need to fork the new sequences before freeing the
        # old sequences.
        for seq, parent in child_seqs:
            if seq is parent and seq.is_finished():
                self.scheduler.free_seq(seq)
        return

    # Beam search case
    # Select the child sequences to keep in the sequence group.
    selected_child_seqs: List[Tuple[Sequence, Sequence]] = []
    unselected_child_seqs: List[Tuple[Sequence, Sequence]] = []
    # ``best_of`` is used as the beam width.
    beam_width = seq_group.sampling_params.best_of
    length_penalty = seq_group.sampling_params.length_penalty

    # Select the newly finished sequences with the highest scores
    # to replace existing finished sequences.
    # Tuple of (seq, parent, is_new). Note that ``existing_finished_seqs``
    # is rebound here, from a list of sequences to a list of such tuples.
    existing_finished_seqs = [(seq, None, False)
                                for seq in existing_finished_seqs]
    new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
                            if seq.is_finished()]
    all_finished_seqs = existing_finished_seqs + new_finished_seqs
    # Sort the finished sequences by their scores, best first.
    all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
        length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                            reverse=True)
    # The top ``beam_width`` finished sequences are kept. Existing ones are
    # already in the group, so only the new ones need recording.
    for seq, parent, is_new in all_finished_seqs[:beam_width]:
        if is_new:
            # A newly generated child sequence finishes and has a high
            # score, so we will add it into the sequence group.
            selected_child_seqs.append((seq, parent))
    for seq, parent, is_new in all_finished_seqs[beam_width:]:
        if is_new:
            # A newly generated child sequence finishes but has a low
            # score, so we will not add it into the sequence group.
            # Additionally, if this sequence is a continuation of a
            # parent sequence, we will need to remove the parent sequence
            # from the sequence group.
            unselected_child_seqs.append((seq, parent))
        else:
            # An existing finished sequence has a low score, so we will
            # remove it from the sequence group.
            seq_group.remove(seq.seq_id)

    # select the top beam_width sequences from the running
    # sequences for the next iteration to continue the beam
    # search.
    running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                            if not seq.is_finished()]
    # Sort the running sequences by their scores, best first.
    running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
        length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
                            reverse=True)

    # Check if we can stop the beam search.
    if len(running_child_seqs) == 0:
        # No running sequences, stop the beam search.
        stop_beam_search = True
    elif len(all_finished_seqs) < beam_width:
        # Not enough finished sequences, continue the beam search.
        stop_beam_search = False
    else:
        # Check the early stopping criteria.
        # Both lists are sorted best-first, so [0] is the best running
        # sequence and [beam_width - 1] is the lowest-scoring finished
        # sequence that is still kept.
        best_running_seq = running_child_seqs[0][0]
        current_worst_seq = all_finished_seqs[beam_width - 1][0]
        stop_beam_search = self._check_beam_search_early_stopping(
            seq_group.sampling_params.early_stopping,
            seq_group.sampling_params, best_running_seq, current_worst_seq)

    if stop_beam_search:
        # Stop the beam search and remove all the running sequences from
        # the sequence group.
        unselected_child_seqs.extend(running_child_seqs)
    else:
        # Continue the beam search and select the top beam_width sequences
        # to continue the beam search.
        selected_child_seqs.extend(running_child_seqs[:beam_width])
        # The remaining running sequences will not be used in the next
        # iteration. Again, if these sequences are continuations of
        # parent sequences, we will need to remove the parent sequences
        # from the sequence group.
        unselected_child_seqs.extend(running_child_seqs[beam_width:])

    # For newly created child sequences, add them to the sequence group
    # and fork them in block manager if they are not finished.
    for seq, parent in selected_child_seqs:
        if seq is not parent:
            seq_group.add(seq)
            if not seq.is_finished():
                self.scheduler.fork_seq(parent, seq)

    # Free the finished and selected parent sequences' memory in block
    # manager. Keep them in the sequence group as candidate output.
    # As in the non-beam branch, this runs after the forks above.
    for seq, parent in selected_child_seqs:
        if seq is parent and seq.is_finished():
            self.scheduler.free_seq(seq)

    # Remove the unselected parent sequences from the sequence group and
    # free their memory in block manager. Unselected new forks need no
    # cleanup: they were never added to the group or forked in the
    # scheduler (that only happens for selected forks above).
    for seq, parent in unselected_child_seqs:
        if seq is parent:
            # Remove the parent sequence if it is not selected for next
            # iteration
            seq_group.remove(seq.seq_id)
            self.scheduler.free_seq(seq)

from vllm.engine.llm_engine import LLMEngine

# Monkey-patch: make the logits-propagating implementation defined in this
# module LLMEngine's ``_process_sequence_group_outputs`` method. Because it is
# bound on the class, instances resolve the method to this function unless
# they shadow the name themselves.
# NOTE(review): the import sits here rather than at the top of the file,
# presumably to avoid an import cycle. Confirm before moving it.
LLMEngine._process_sequence_group_outputs = _process_sequence_group_outputs