from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceGroup, SequenceStatus, PromptLogprobs, RequestMetrics, SampleLogprobs
from vllm.lora.request import LoRARequest
import time
from typing import List, Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
# Monkey-patch: give SequenceGroupOutput a class-level ``logits`` attribute
# that defaults to None. The matching patch on SamplerOutput is disabled.
SequenceGroupOutput.logits = None

def __init__(
    self,
    index: int,
    text: str,
    token_ids: List[int],
    cumulative_logprob: float,
    logprobs: Optional[SampleLogprobs],
    finish_reason: Optional[str] = None,
    stop_reason: Union[int, str, None] = None,
    lora_request: Optional[LoRARequest] = None,
    logits_list=None,
) -> None:
    """Initializer installed on ``CompletionOutput`` that also keeps logits.

    Every argument is stored unchanged on the instance under its own name.
    The descriptions below reflect what ``from_seq_group`` in this file
    passes in; other callers may supply different values.

    Args:
        index: Position of the sequence within the group's sequence list.
        text: The sequence's ``output_text``.
        token_ids: Result of the sequence's ``get_output_token_ids()``.
        cumulative_logprob: Result of ``get_cumulative_logprob()``.
        logprobs: The sequence's ``output_logprobs``, or None when logprobs
            were not requested.
        finish_reason: Result of ``SequenceStatus.get_finished_reason``.
        stop_reason: The sequence's ``stop_reason``.
        lora_request: Optional LoRA request to associate with the output.
        logits_list: Result of the sequence's ``get_logits_list()``.
    """
    # Which sequence this is and what it produced.
    self.index, self.text, self.token_ids = index, text, token_ids

    # Log-probability information.
    self.cumulative_logprob, self.logprobs = cumulative_logprob, logprobs

    # Why / how generation ended.
    self.finish_reason, self.stop_reason = finish_reason, stop_reason

    # Extras: LoRA request and the logits added by this patch.
    self.lora_request, self.logits_list = lora_request, logits_list

# Monkey-patch: replace CompletionOutput's initializer with the one above so
# instances accept and store ``logits_list``.
CompletionOutput.__init__ = __init__
        
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
    """Build a request-level output from the current state of ``seq_group``.

    When the group holds a single sequence it is reported as-is. Otherwise
    the sequences are ranked (by beam-search score when beam search is
    enabled, else by cumulative logprob) and the best ``sampling_params.n``
    are kept. Each reported sequence becomes a ``CompletionOutput`` that
    also carries ``seq.get_logits_list()``.

    As a side effect, the group's finished time is set to the current time
    when the group is finished, and to None otherwise.

    Args:
        seq_group: Sequence group to summarise.

    Returns:
        A new ``cls`` instance for the group's request.
    """
    all_seqs = seq_group.get_seqs()
    sampling_params = seq_group.sampling_params

    if len(all_seqs) == 1:
        selected = all_seqs
    else:
        # Keep only the top-n sequences, best first.
        limit = sampling_params.n
        if sampling_params.use_beam_search:

            def rank(seq):
                return seq.get_beam_search_score(
                    sampling_params.length_penalty)
        else:

            def rank(seq):
                return seq.get_cumulative_logprob()

        selected = sorted(all_seqs, key=rank, reverse=True)[:limit]

    # NOTE: Logprobs must be omitted explicitly here: a sequence always
    # holds the logprobs of its sampled tokens, even when the request did
    # not ask for them.
    want_logprobs = sampling_params.logprobs is not None

    def to_completion(seq):
        # The reported index is the position in the unsorted sequence list.
        return CompletionOutput(
            all_seqs.index(seq),
            seq.output_text,
            seq.get_output_token_ids(),
            seq.get_cumulative_logprob(),
            seq.output_logprobs if want_logprobs else None,
            SequenceStatus.get_finished_reason(seq.status),
            seq.stop_reason,
            logits_list=seq.get_logits_list(),
        )

    completions = [to_completion(seq) for seq in selected]

    is_done = seq_group.is_finished()
    seq_group.set_finished_time(time.time() if is_done else None)

    # Every sequence in the group should share the same prompt, so the
    # prompt fields are read from the group itself.
    return cls(
        seq_group.request_id,
        seq_group.prompt,
        seq_group.prompt_token_ids,
        seq_group.prompt_logprobs,
        completions,
        is_done,
        seq_group.metrics,
        lora_request=seq_group.lora_request,
    )

# Monkey-patch: install the classmethod above as RequestOutput.from_seq_group.
RequestOutput.from_seq_group = from_seq_group