import dataclasses
from typing import Union, Optional, Sequence
from openai import OpenAI
import os


# The API key comes from the DEEPSEEK_API_KEY environment variable; os.getenv
# returns None when the variable is not set.
api_key = os.getenv("DEEPSEEK_API_KEY")
# OpenAI SDK client created once at import time and pointed at the DeepSeek
# endpoint through base_url.
client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com/v1")


# Module-level counters, all initialised to zero. TOKENS_NUMBER is the value
# returned by get_tokens_number(). Presumably PROMPT_NUMBER and
# COMPLETION_NUMBER are the prompt/completion token counts -- TODO confirm.
TOKENS_NUMBER = PROMPT_NUMBER = COMPLETION_NUMBER = 0


def get_tokens_number() -> int:
    """Return the current value of the module-level ``TOKENS_NUMBER`` counter."""
    current_total = TOKENS_NUMBER
    return current_total

@dataclasses.dataclass
class DecodingArguments:
    """Sampling options for a single chat-completion request.

    Each field is forwarded by ``generate_response`` to
    ``client.chat.completions.create`` as the keyword argument of the same
    name. The defaults below are used when the caller supplies no
    ``DecodingArguments`` instance.
    """

    # Upper bound on the number of tokens generated per completion.
    max_tokens: int = 2048

    # Sampling controls.
    temperature: float = 1.0
    top_p: float = 1.0

    # Number of completions requested for the prompt.
    n: int = 1

    # Optional stop sequences; ``None`` means none are sent.
    stop: Optional[Sequence[str]] = None

    # Repetition penalties.
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0


def generate_response(
    prompt: Union[str, Sequence[Union[str, dict]]],
    decoding_args: Optional[DecodingArguments] = None,
    return_list=False
):
    """Send one user prompt to the ``deepseek-chat`` model and return its reply.

    Args:
        prompt: Either the prompt text itself, or a list/tuple whose last
            element supplies the prompt. That last element may be a plain
            string or a chat-message mapping with a ``"content"`` key. Only
            the last element is sent; earlier elements are ignored.
        decoding_args: Sampling options. A default ``DecodingArguments()`` is
            used when omitted.
        return_list: When true, always return a list of completions, even if
            only one was requested.

    Returns:
        The completion text as a single string when ``decoding_args.n == 1``
        and ``return_list`` is false; otherwise a list with one string per
        returned choice.

    Raises:
        IndexError: If ``prompt`` is an empty list/tuple.
        KeyError: If the last message mapping has no ``"content"`` key.
        TypeError: If the last element is neither a string nor subscriptable
            by ``"content"``.

    Side effects:
        Adds the token usage reported by the API (if any) to the module-level
        ``TOKENS_NUMBER``, ``PROMPT_NUMBER`` and ``COMPLETION_NUMBER``
        counters so that ``get_tokens_number()`` reflects real consumption.
    """
    global TOKENS_NUMBER, PROMPT_NUMBER, COMPLETION_NUMBER

    if decoding_args is None:
        decoding_args = DecodingArguments()

    # Accept tuples as well as lists, and plain-string elements as well as
    # message dicts: previously a list of strings crashed on ["content"] and
    # a tuple was passed through unchanged as the message content.
    if isinstance(prompt, (list, tuple)):
        last_item = prompt[-1]
        prompt = last_item if isinstance(last_item, str) else last_item["content"]

    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": prompt},
    ]

    response = client.chat.completions.create(
        model="deepseek-chat",
        messages=messages,
        max_tokens=decoding_args.max_tokens,
        temperature=decoding_args.temperature,
        top_p=decoding_args.top_p,
        n=decoding_args.n,
        stop=decoding_args.stop,
        presence_penalty=decoding_args.presence_penalty,
        frequency_penalty=decoding_args.frequency_penalty,
        stream=False
    )

    # The counters were declared at module level but never updated, so the
    # getter always reported 0. ``usage`` may be absent, hence the guards.
    usage = getattr(response, "usage", None)
    if usage is not None:
        PROMPT_NUMBER += getattr(usage, "prompt_tokens", 0) or 0
        COMPLETION_NUMBER += getattr(usage, "completion_tokens", 0) or 0
        TOKENS_NUMBER += getattr(usage, "total_tokens", 0) or 0

    responses = [choice.message.content for choice in response.choices]
    if decoding_args.n == 1 and not return_list:
        return responses[0]
    return responses


class ChatBot:
    """Class-level facade for talking to the chat model.

    All state lives on the class itself and every method is a classmethod, so
    the class is used directly rather than instantiated.
    """

    # Backend handles / configuration; none of these are assigned by the
    # methods of this class.
    bot = None
    model_name = None
    port = None
    tokenizer = None
    headless = False

    # Bookkeeping shared by every user of the class. The lists are
    # class-level objects, so they are shared rather than per-instance.
    chat_cnt = 0
    prompt_length = []
    output_length = []
    token_length = []
    use_api = False

    # Short model key -> label string (presumably display names -- confirm).
    MODEL_TYPE = {
        "browsing": "Browsing\nALPHA",
        "gpt3": "Default (GPT-3.5)",
        "gpt3-old": "Legacy (GPT-3.5)",
        "gpt4": "GPT-4",
        "deepseek": "Deepseek-R1",
    }

    @classmethod
    def init(cls):
        """Reset the chat counter to zero and print a status banner."""
        cls.chat_cnt = 0
        print("DeepSeek R1 is reasoning!")

    @classmethod
    def call_chat_deepseek(cls, prompt, **kwargs):
        """Forward ``prompt`` to ``generate_response``.

        Only the ``decoding_args`` and ``return_list`` keyword arguments are
        honoured (defaulting to ``None`` and ``False``); any other keyword
        argument is ignored.
        """
        return generate_response(
            prompt,
            kwargs.get("decoding_args", None),
            kwargs.get("return_list", False),
        )

