# SPDX-License-Identifier: Apache-2.0
"""Minimal overlap-enabled wrapper with a vLLM-like generate interface."""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Iterable, Optional, Sequence

from vllm import LLM, SamplingParams

from .vllm_overlap_patch import (
    apply_distiller_overlap_patch,
    export_distiller_overlap_cudagraph_stats,
    export_distiller_overlap_log,
    export_distiller_overlap_loss_log,
    install_distiller_overlap_hooks,
    update_distiller_overlap_state,
)
from .vllm_overlap_patch_v1 import (
    apply_distiller_overlap_patch_v1,
    export_distiller_overlap_cudagraph_stats_v1,
    export_distiller_overlap_log_v1,
    export_distiller_overlap_loss_log_v1,
)

# Default VLLM_USE_V1 to "0" so OverlapLLM takes its non-V1 patch path unless
# the caller exported a value beforehand; setdefault never overrides an
# existing setting.
# NOTE(review): this runs after `vllm` has already been imported above --
# confirm vLLM reads the variable lazily rather than at import time.
os.environ.setdefault("VLLM_USE_V1", "0")


@dataclass(frozen=True)
class OverlapConfig:
    """Immutable bundle of overlap distiller settings.

    ``OverlapLLM.enable_overlap`` forwards every field, by name, as a keyword
    argument to ``update_distiller_overlap_state``. The precise meaning and
    valid range of each value are defined by that function, which lives in
    the patch module and is not visible from this file.

    The dataclass is frozen, so attributes cannot be reassigned after
    construction.
    """

    # OverlapLLM.__init__ only turns overlap on at construction time when
    # this flag is true.
    enabled: bool = False
    # Remaining fields are passed through unchanged; see the patch module for
    # their semantics (TODO confirm and document units/ranges there).
    beta: float = 0.5
    topk: int = 8
    buffer_slots: int = 2
    first_layer: int = 0
    mix_mode: str = "logits"
    train_enabled: bool = False
    train_sync_interval: int = 1
    normalize_inputs: bool = True
    reward_scale: bool = False
    logits_rescale: bool = False
    logit_loss_weight: float = 0.0
    curiosity_only: bool = False


class OverlapLLM:
    """Wraps vLLM LLM with optional overlap distiller support.

    The wrapper owns a ``vllm.LLM`` instance (``self.llm``) and exposes a
    small, vLLM-like surface (``generate``, ``get_tokenizer``) plus helpers
    to enable the overlap patch and export its logs.

    Whether the V1 or the non-V1 patch/export functions are used is decided
    on every call from the ``VLLM_USE_V1`` environment variable.
    """

    # Names of the OverlapConfig attributes that are forwarded, one-to-one,
    # as keyword arguments to ``update_distiller_overlap_state``. Keeping the
    # list in one place stops the in-process call and the per-model callback
    # from drifting apart.
    _OVERLAP_STATE_FIELDS = (
        "enabled",
        "beta",
        "topk",
        "buffer_slots",
        "first_layer",
        "mix_mode",
        "train_enabled",
        "train_sync_interval",
        "normalize_inputs",
        "reward_scale",
        "logits_rescale",
        "logit_loss_weight",
        "curiosity_only",
    )

    def __init__(
        self,
        model_name: str,
        *,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.5,
        max_model_len: int = 4096,
        enforce_eager: bool = True,
        overlap_config: Optional[OverlapConfig] = None,
        **kwargs,
    ) -> None:
        """Build the underlying ``LLM`` and optionally enable overlap.

        Args:
            model_name: Passed to ``LLM`` as ``model``.
            tensor_parallel_size: Passed through to ``LLM``.
            gpu_memory_utilization: Passed through to ``LLM``.
            max_model_len: Passed through to ``LLM``.
            enforce_eager: Passed through to ``LLM``.
            overlap_config: When given and its ``enabled`` flag is true,
                ``enable_overlap`` is called right after the engine is built.
            **kwargs: Any additional keyword arguments for ``LLM``.
        """
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            enforce_eager=enforce_eager,
            **kwargs,
        )
        # Stays None until enable_overlap() completes successfully.
        self._overlap_config: Optional[OverlapConfig] = None
        if overlap_config and overlap_config.enabled:
            self.enable_overlap(overlap_config)

    @staticmethod
    def _use_v1() -> bool:
        """Return True when ``VLLM_USE_V1`` is currently set to ``"1"``."""
        return os.getenv("VLLM_USE_V1", "0") == "1"

    @classmethod
    def _overlap_state_kwargs(cls, config: OverlapConfig) -> dict[str, object]:
        """Collect the config attributes forwarded to the overlap state."""
        return {
            name: getattr(config, name) for name in cls._OVERLAP_STATE_FIELDS
        }

    def enable_overlap(self, config: OverlapConfig) -> None:
        """Apply the overlap patch and push ``config`` into its state.

        With ``VLLM_USE_V1=1`` only the V1 patch is applied in this process.
        Otherwise the non-V1 patch is applied here, re-applied through the
        executor's ``apply_model`` callback, and the overlap hooks are
        installed on the executor.

        Args:
            config: Settings to forward to ``update_distiller_overlap_state``.
                It is stored on the instance once everything succeeded.
        """
        use_v1 = self._use_v1()
        # Snapshot once so both call sites below send identical values.
        state_kwargs = self._overlap_state_kwargs(config)

        if use_v1:
            apply_distiller_overlap_patch_v1()
        else:
            apply_distiller_overlap_patch()
        update_distiller_overlap_state(**state_kwargs)

        if not use_v1:
            def _apply(_model) -> None:
                # The imports are repeated locally, presumably so the callback
                # stays self-contained wherever the executor runs it
                # (NOTE(review): confirm against the executor in use).
                # The model argument supplied by apply_model is not needed.
                from .vllm_overlap_patch import (
                    apply_distiller_overlap_patch,
                    update_distiller_overlap_state,
                )

                apply_distiller_overlap_patch()
                update_distiller_overlap_state(**state_kwargs)

            executor = self.llm.llm_engine.model_executor
            executor.apply_model(_apply)
            install_distiller_overlap_hooks(executor)
        self._overlap_config = config

    def generate(
        self,
        prompts: Sequence[str] | str,
        sampling_params: SamplingParams,
    ):
        """Delegate to ``self.llm.generate`` and return its result."""
        return self.llm.generate(prompts, sampling_params)

    def get_tokenizer(self):
        """Return the tokenizer of the wrapped ``LLM``."""
        return self.llm.get_tokenizer()

    def export_overlap_log(self, file_path: str) -> Optional[Iterable[int]]:
        """Export the overlap log to ``file_path``.

        Returns:
            ``None`` when ``file_path`` is empty or ``None``; otherwise
            whatever the selected export helper returns.
        """
        if not file_path:
            return None
        if self._use_v1():
            return export_distiller_overlap_log_v1(file_path)
        executor = self.llm.llm_engine.model_executor
        return export_distiller_overlap_log(executor, file_path)

    def export_overlap_loss_log(self, file_path: str) -> Optional[Iterable[int]]:
        """Export the overlap loss log to ``file_path``.

        Returns:
            ``None`` when ``file_path`` is empty or ``None``; otherwise
            whatever the selected export helper returns.
        """
        if not file_path:
            return None
        if self._use_v1():
            return export_distiller_overlap_loss_log_v1(file_path)
        executor = self.llm.llm_engine.model_executor
        return export_distiller_overlap_loss_log(executor, file_path)

    def export_overlap_cudagraph_stats(self) -> dict:
        """Return the CUDA graph statistics reported by the active patch."""
        if self._use_v1():
            return export_distiller_overlap_cudagraph_stats_v1()
        return export_distiller_overlap_cudagraph_stats()
