import time
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import IntEnum
from typing import ClassVar, Dict, List, Optional, Set

import msgspec
from msgspec import field as msgspec_field

from vllm.sampling_params import SamplingParams


class RequestStatsUpdate(
        msgspec.Struct,  # type: ignore
        array_like=True,
        omit_defaults=True,
        gc=False):
    """
    An update to the request stats.

    This represents a stats update at a specific timestamp with metadata
    associated with the update.

    NOTE: since there might be multiple processes generating updates at
    different parts of the engine (e.g. input processor, scheduler, engine core,
    etc.), we use the monotonic timestamp to record the update so that
    intervals between updates can be computed. Wall-clock timestamps, when
    needed, should be recorded explicitly in dedicated fields.

    NOTE: the struct is declared with `array_like=True`, so fields are
    serialized positionally. Do not reorder the fields below.

    WARNING: This assumes stats are generated in a single machine. If there are
    potentially multiple machines, one should always generate the stats updates
    on one single machine or use something else.
    """

    class Type(IntEnum):
        """See `RequestStats` for the lifecycle of a request."""

        # Request arrived at the engine frontend.
        ARRIVED = 0
        # Input processed by the input processor.
        INPUT_PROCESSED = 1
        # Queued on the engine core.
        QUEUED = 2
        # Scheduled running prefill by the scheduler.
        # A request could be running a new prefill on the prompt tokens or
        # a resumed prefill on the original prefill tokens + generated output
        # tokens before preemption.
        PREFILLING = 3
        # Preempted by the scheduler.
        PREEMPTED = 4
        # Output token is generated by the engine core.
        DECODING = 5
        # Token detokenized by the detokenizer.
        # We will record the timestamp for each output token, as well as the
        # finish reason.
        DETOKENIZED = 6
        # Request finishes (or aborts).
        FINISHED = 7

    # Valid state updates (every state except FINISHED may also go directly
    # to FINISHED):
    #
    # ARRIVED
    # │
    # ├──────► INPUT_PROCESSED ──────► QUEUED ──────► PREFILLING ◄────┐
    # │              │                   │              │             │
    # │              │                   │              ▼             │
    # │              │                   │       ┌──► DECODING        │
    # │              │                   │       │      │             │
    # │              │                   │       │      ▼             │
    # │              │                   │       └─ DETOKENIZED       │
    # │              │                   │              │             │
    # │              │                   │              ▼             │
    # │              ▼                   ▼           PREEMPTED ◄──────┘
    # │              │                   │              │
    # └──────────────┴───────────────────┴──────────────┴
    #                             │
    #                             ▼
    #             FINISHED (All could go to FINISHED)
    _VALID_TRANSITIONS: ClassVar[Dict[Type, Set[Type]]] = {
        Type.ARRIVED: {
            Type.INPUT_PROCESSED,
            Type.FINISHED,
        },
        Type.INPUT_PROCESSED: {
            Type.QUEUED,
            Type.FINISHED,
        },
        Type.QUEUED: {
            Type.PREFILLING,
            Type.FINISHED,
        },
        Type.PREFILLING: {
            Type.DECODING,
            Type.PREEMPTED,
            Type.FINISHED,
        },
        Type.DECODING: {
            Type.DETOKENIZED,
            Type.FINISHED,
        },
        Type.DETOKENIZED: {
            Type.DECODING,
            Type.PREEMPTED,
            Type.FINISHED,
        },
        Type.PREEMPTED: {Type.PREFILLING, Type.FINISHED},
        Type.FINISHED: set(),
    }

    request_id: str

    type: Type

    # Timestamp when the update is recorded. This is used to record time
    # intervals between events rather than wall clock time.
    # NOTE: the lambda is deliberate: it looks up `time.monotonic` at call
    # time, so patching `time.monotonic` (e.g. in tests) takes effect.
    monotonic_ts_s: float = msgspec_field(
        default_factory=lambda: time.monotonic())

    ############################################################
    # Metadata associated with the update.
    ############################################################
    # For INPUT_PROCESSED. Metadata needed for stats logging.
    num_prompt_tokens: Optional[int] = None
    sampling_params: Optional[SamplingParams] = None

    # For PREFILLING.
    # Number of tokens computed when scheduled to run.
    num_computed_tokens: Optional[int] = None
    # Number of cached tokens when scheduled to run.
    num_cached_tokens: Optional[int] = None

    # For DETOKENIZED.
    # The number of new output tokens generated.
    num_new_tokens: Optional[int] = None

    # For FINISHED.
    # Finish reason.
    finish_reason: Optional[str] = None

    # Non-optional fields for each update type.
    _REQUIRED_FIELDS: ClassVar[Dict[Type, List[str]]] = {
        Type.INPUT_PROCESSED: ["num_prompt_tokens", "sampling_params"],
        Type.PREFILLING: ["num_computed_tokens", "num_cached_tokens"],
        Type.DETOKENIZED: ["num_new_tokens"],
        Type.FINISHED: ["finish_reason"],
    }

    def __post_init__(self) -> None:
        """Check that every field required by `self.type` is set.

        Raises:
            ValueError: if a field listed in `_REQUIRED_FIELDS` for this
                update type is None.
        """
        for name in self._REQUIRED_FIELDS.get(self.type, []):
            if getattr(self, name) is None:
                raise ValueError(
                    f"Field {name} is required for update type {self.type}.")

    @staticmethod
    def check_valid_update(
        update: "RequestStatsUpdate",
        last_update_type: Optional[Type],
        last_updated_ts_s: Optional[float],
    ) -> None:
        """Assert that `update` may follow the previously applied update.

        Args:
            update: The incoming update.
            last_update_type: Type of the last applied update, or None if
                `update` is the first one (which must then be ARRIVED).
            last_updated_ts_s: Monotonic timestamp of the last applied
                update, or None if there was none.

        Raises:
            AssertionError: if the transition is not listed in
                `_VALID_TRANSITIONS`, or if the timestamp goes backwards.
        """
        if last_update_type is None:
            assert update.type == RequestStatsUpdate.Type.ARRIVED, (
                f"The first update must be ARRIVED, but got {update.type}.")
        else:
            valid_cur_update_types = RequestStatsUpdate._VALID_TRANSITIONS[
                last_update_type]
            assert update.type in valid_cur_update_types, (
                f"Invalid update type: {update.type} for last_update_type: "
                f"{last_update_type}.")

        if last_updated_ts_s is not None:
            assert update.monotonic_ts_s >= last_updated_ts_s, (
                "Update timestamp must be monotonically increasing, but "
                f"last_updated_ts_s={last_updated_ts_s} and "
                f"update.monotonic_ts_s={update.monotonic_ts_s}.")


@dataclass
class RequestStats:
    """Stats associated with a request (`Request`).

    Built incrementally by feeding `RequestStatsUpdate`s to `update_from`,
    which first validates each update against the previously applied one via
    `RequestStatsUpdate.check_valid_update`. Every `*_ts_s` value is the
    `monotonic_ts_s` of the update that set it, so only differences between
    them (latencies/durations) are meaningful.
    """

    ############################################################
    # Metadata
    ############################################################
    request_id: str
    sampling_params: Optional[SamplingParams] = None
    num_prompt_tokens: Optional[int] = None

    ############################################################
    # Metrics and Stats
    ############################################################
    # Timestamp when the request was last updated.
    last_updated_ts_s: Optional[float] = None

    # Last update stats type.
    last_update_type: Optional[RequestStatsUpdate.Type] = None

    # Timestamp when the request arrived at the llm engine.
    arrival_ts_s: Optional[float] = None

    # Number of tokens cached. When part of the request prefix is cached,
    # this will be set.
    num_cached_tokens: int = 0

    # Number of tokens computed.
    num_computed_tokens: int = 0

    # The timestamp when the request started waiting in the queue.
    queued_ts_s: Optional[float] = None

    # When the input processor is completed.
    input_processor_end_ts_s: Optional[float] = None

    # A sorted list of timestamps when the request was scheduled to prefill.
    # This could be when:
    # 1. the request is newly scheduled, so it's a new prefill.
    # 2. the request was preempted and resumed. It is equivalent to running
    #    a prefill of the original prefill tokens + generated output tokens
    #    before preemption.
    prefill_start_ts_s_lst: List[float] = dataclass_field(default_factory=list)

    # A list of timestamps when a token is decoded by the engine core.
    decoding_ts_s_lst: List[float] = dataclass_field(default_factory=list)

    # A sorted list of timestamps for each output token.
    output_token_ts_s_lst: List[float] = dataclass_field(default_factory=list)

    # Timestamp of the first DETOKENIZED update that produced at least one
    # output token.
    first_token_ts_s: Optional[float] = None

    # TODO(rickyx): we need model runner to surface these.
    model_forward_duration_s: float = 0.0
    # Includes model forward, block/sync across workers, cpu-gpu sync time
    # and sampling time.
    model_execute_duration_s: float = 0.0

    # A sorted list of timestamps when the request was preempted at the
    # scheduler.
    # TODO(rickyx): right now, we don't actually have a good high-level
    # metric to measure the impact of preemption other than observation of
    # large P99 TPOT. Ideally we could quantify the impact of preemption by
    # measuring the number of tokens re-computed due to preemption.
    preempted_ts_s_lst: List[float] = dataclass_field(default_factory=list)

    # Timestamp when the request was finished at the engine core.
    finished_ts_s: Optional[float] = None

    # Finish reason.
    finish_reason: Optional[str] = None

    ############################################################
    # Derived properties.
    ############################################################
    @property
    def prefill_ts_s(self) -> Optional[float]:
        """The timestamp when the request started prefilling.
        Since a request could be preempted in decoding and later resumed
        to prefill the decoded tokens, we use the first prefill start timestamp.
        """
        return (self.prefill_start_ts_s_lst[0]
                if self.prefill_start_ts_s_lst else None)

    @property
    def e2e_latency_s(self) -> Optional[float]:
        """Time from arrival to finish, or None if either is unknown."""
        if self.finished_ts_s is None or self.arrival_ts_s is None:
            return None
        assert self.finished_ts_s >= self.arrival_ts_s
        return self.finished_ts_s - self.arrival_ts_s

    @property
    def queue_duration_s(self) -> Optional[float]:
        """How long the request was waiting to run."""
        if self.queued_ts_s is None or self.prefill_ts_s is None:
            # Either not queued or not running yet.
            return None
        assert self.queued_ts_s <= self.prefill_ts_s
        return self.prefill_ts_s - self.queued_ts_s

    @property
    def inference_latency_s(self) -> Optional[float]:
        """How long the request was running inference
        (prefill and decode)."""
        if self.finished_ts_s is None or self.prefill_ts_s is None:
            return None
        assert self.finished_ts_s >= self.prefill_ts_s
        return self.finished_ts_s - self.prefill_ts_s

    @property
    def first_token_latency_s(self) -> Optional[float]:
        """Time from arrival to the first output token, or None."""
        if self.first_token_ts_s is None or self.arrival_ts_s is None:
            return None
        assert self.first_token_ts_s >= self.arrival_ts_s
        return self.first_token_ts_s - self.arrival_ts_s

    @property
    def prefill_latency_s(self) -> Optional[float]:
        """Time from the first prefill start to the first output token."""
        if self.first_token_ts_s is None or self.prefill_ts_s is None:
            return None
        assert self.first_token_ts_s >= self.prefill_ts_s
        return self.first_token_ts_s - self.prefill_ts_s

    @property
    def decode_latency_s(self) -> Optional[float]:
        """Time from the first output token to the finish, or None."""
        if self.e2e_latency_s is None or self.first_token_latency_s is None:
            return None
        assert self.e2e_latency_s >= self.first_token_latency_s
        return self.e2e_latency_s - self.first_token_latency_s

    @property
    def output_token_latency_s_lst(self) -> List[float]:
        """Intervals between consecutive output tokens.

        Tokens recorded by the same DETOKENIZED update share a timestamp and
        therefore contribute zero-length intervals. The result has one fewer
        element than `num_output_tokens` (empty when there are < 2 tokens).
        """
        ts_lst = self.output_token_ts_s_lst
        latency_s_lst = []
        for prev_ts, cur_ts in zip(ts_lst, ts_lst[1:]):
            assert cur_ts >= prev_ts
            latency_s_lst.append(cur_ts - prev_ts)
        return latency_s_lst

    @property
    def num_output_tokens(self) -> int:
        """Total number of output tokens recorded so far."""
        return len(self.output_token_ts_s_lst)

    @property
    def is_finished(self) -> bool:
        """Whether a FINISHED update has been applied."""
        return self.finished_ts_s is not None

    def update_from(self, update: "RequestStatsUpdate"):
        """Apply a single stats update to this request's stats.

        Args:
            update: The update to apply. It must be a valid successor of the
                last applied update (checked by
                `RequestStatsUpdate.check_valid_update`, which raises
                AssertionError otherwise).

        Raises:
            ValueError: if `update.type` is not a known update type.
        """
        RequestStatsUpdate.check_valid_update(update, self.last_update_type,
                                              self.last_updated_ts_s)
        ts = update.monotonic_ts_s
        self.last_updated_ts_s = ts
        self.last_update_type = update.type
        if update.type == RequestStatsUpdate.Type.ARRIVED:
            self.arrival_ts_s = ts
        elif update.type == RequestStatsUpdate.Type.INPUT_PROCESSED:
            self.input_processor_end_ts_s = ts
            self.sampling_params = update.sampling_params
            self.num_prompt_tokens = update.num_prompt_tokens
        elif update.type == RequestStatsUpdate.Type.QUEUED:
            self.queued_ts_s = ts
        elif update.type == RequestStatsUpdate.Type.PREFILLING:
            self.prefill_start_ts_s_lst.append(ts)
            self.num_cached_tokens = update.num_cached_tokens or 0
            self.num_computed_tokens = update.num_computed_tokens or 0
        elif update.type == RequestStatsUpdate.Type.PREEMPTED:
            self._reset_for_preemption(ts)
        elif update.type == RequestStatsUpdate.Type.DECODING:
            self.decoding_ts_s_lst.append(ts)
        elif update.type == RequestStatsUpdate.Type.DETOKENIZED:
            self._record_detokenized_output(
                ts,
                update.num_new_tokens or 0,
            )
        elif update.type == RequestStatsUpdate.Type.FINISHED:
            self.finished_ts_s = ts
            self.finish_reason = update.finish_reason
        else:
            raise ValueError(f"Unknown update type: {update.type}")

    def _record_detokenized_output(
        self,
        ts_s: float,
        num_new_tokens: int,
    ):
        """Record `num_new_tokens` output tokens produced at `ts_s`.

        `first_token_ts_s` is only set by the first update that actually
        yields tokens, so an update with zero new tokens is not mistaken for
        the first token.
        """
        if not self.output_token_ts_s_lst:
            # Validate before mutating any state.
            assert (
                self.prefill_ts_s is not None
            ), "Request must be running before generating output tokens."
            if num_new_tokens > 0:
                self.first_token_ts_s = ts_s

        # Some X new tokens were generated at the ts.
        self.output_token_ts_s_lst.extend([ts_s] * num_new_tokens)

    def _reset_for_preemption(self, ts_s: float):
        """Record a preemption at `ts_s` and reset per-run token counts."""
        self.preempted_ts_s_lst.append(ts_s)
        # Reset the computed tokens since it might restart the prefill.
        self.num_computed_tokens = 0
        # Cached token count might also change when resumed.
        self.num_cached_tokens = 0
        # These stats don't change since they happen before request running.
        # - arrival_ts_s
        # - input_processor_end_ts_s
        # - sampling_params
        # - num_prompt_tokens
        # - first_token_ts_s
        #
        # These stats are accumulated over preemptions:
        # - output_token_ts_s_lst
        # - prefill_start_ts_s_lst (after preemption, it will prefill the
        #   original prefill tokens and any output tokens generated before
        #   preemption.)


@dataclass
class KVCacheStats:
    """GPU KV cache stats, embedded in the scheduler stats."""

    # GPU KV cache usage. The original comment describes this as a
    # percentage; NOTE(review): confirm the scale (0-1 vs 0-100) with the
    # producer.
    gpu_cache_usage_sys: float = 0.0
    # GPU prefix cache hit rate.
    gpu_prefix_cache_hit_rate: float = 0.0


@dataclass
class SchedulerStats:
    """Point-in-time scheduler stats: queue sizes plus KV cache stats."""

    # How many requests are running right now.
    num_running_reqs: int = 0
    # How many requests are waiting right now.
    num_waiting_reqs: int = 0

    # KV cache stats; each instance gets its own default `KVCacheStats()`.
    kv_cache_stats: KVCacheStats = dataclass_field(
        default_factory=KVCacheStats,
    )


@dataclass
class EngineCoreProcessStats:
    """Queue stats for the engine core process.

    Both queue sizes are None when the engine core is not running in
    multiprocess mode.
    """

    # Requests currently sitting in the input queue (None outside
    # multiprocess mode).
    input_queue_size: Optional[int] = None
    # Outputs currently sitting in the output queue (None outside
    # multiprocess mode).
    output_queue_size: Optional[int] = None


class EngineCoreStatsSnapshot(
        msgspec.Struct,  # type: ignore
        array_like=True,
        omit_defaults=True,
        gc=False):
    """
    Snapshot of the EngineCore's stats covering a period of time.

    Declared with `array_like=True`, so fields are serialized positionally;
    keep the field order below stable.
    """

    # Scheduler stats captured for this snapshot.
    scheduler_stats: SchedulerStats = msgspec.field(
        default_factory=SchedulerStats,
    )

    # Stats updates for individual requests.
    requests_stats_updates: List[RequestStatsUpdate] = msgspec.field(
        default_factory=list,
    )

    # Queue stats of the engine core process.
    engine_core_process_stats: EngineCoreProcessStats = msgspec.field(
        default_factory=EngineCoreProcessStats,
    )

    # TODO(rickyx): Add other components' stats,
    # e.g. model runner/worker and etc.
