# Copyright (c) Alibaba, Inc. and its affiliates.
from dataclasses import dataclass, field
from typing import List, Optional

from swift.utils import get_logger

# Module-level logger obtained from swift.utils.get_logger.
logger = get_logger()


@dataclass
class GenerationArguments:
    """
    GenerationArguments class is a dataclass that holds various arguments related to text generation.

    Args:
        max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Default is None (unlimited,
            constrained by max_model_len).
        temperature (Optional[float]): Sampling temperature. Default is None.
        top_k (Optional[int]): Top-k sampling parameter. Default is None.
        top_p (Optional[float]): Top-p (nucleus) sampling parameter. Default is None.
        repetition_penalty (Optional[float]): Penalty for repeated tokens. Default is None.
        num_beams (int): Number of beams for beam search. Default is 1.
        stream (Optional[bool]): Flag to indicate if streaming output should be enabled. Default is None;
            `_init_stream` resolves None to False.
        stop_words (List[str]): List of stop words to end generation. Default is an empty list.
        logprobs (bool): Forwarded to RequestConfig as `logprobs`. Default is False.
        top_logprobs (Optional[int]): Forwarded to RequestConfig as `top_logprobs`. Default is None.
    """

    # generation config
    max_new_tokens: Optional[int] = None  # Unlimited, constrained by max_model_len.
    # If a sampling parameter below is None, the value from generation_config is used.
    temperature: Optional[float] = None  # Setting it to 0 means do_sample is False.
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    repetition_penalty: Optional[float] = None
    num_beams: int = 1

    stream: Optional[bool] = None
    stop_words: List[str] = field(default_factory=list)
    logprobs: bool = False
    top_logprobs: Optional[int] = None

    def _init_stream(self) -> None:
        """Resolve an unset (None) `stream` to False; an explicit True/False is left untouched."""
        if self.stream is None:
            self.stream = False

    def get_request_config(self):
        """Build a RequestConfig from these generation arguments.

        `task_type` is not declared on this class; it is read from the instance
        (NOTE(review): presumably supplied by a class this one is combined with -- confirm).
        When the attribute is absent it is treated as 'causal_lm', so a standalone
        GenerationArguments does not raise AttributeError.

        Returns:
            A `swift.llm.RequestConfig` populated from this instance's fields, or None when
            `task_type` is anything other than 'causal_lm'.
        """
        if getattr(self, 'task_type', 'causal_lm') != 'causal_lm':
            return None
        # Imported inside the method rather than at module top
        # (NOTE(review): presumably to avoid a circular import with swift.llm -- confirm).
        from swift.llm import RequestConfig

        return RequestConfig(
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            num_beams=self.num_beams,
            stop=self.stop_words,
            stream=self.stream,
            repetition_penalty=self.repetition_penalty,
            logprobs=self.logprobs,
            top_logprobs=self.top_logprobs)
