import base64
import time
from typing import Any

# import openai
# from openai import OpenAI
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from ..types import MessageList, SamplerBase
import os
import openai
from openai import OpenAI
import subprocess
import time
import torch,gc
# System prompt used for plain API-style conversations.
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."

# System prompt imitating the ChatGPT persona; it carries a fixed
# knowledge-cutoff line and a fixed current-date line.
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    "\nKnowledge cutoff: 2023-12"
    "\nCurrent date: 2024-04-01"
)
import numpy as np


class ServerSampler(SamplerBase):
    """
    Sample chat completions from locally launched vLLM OpenAI-compatible servers.

    On construction, ``num_processes`` server subprocesses are started with
    ``python -m vllm.entrypoints.openai.api_server``. Server ``i`` listens on
    ``port + i`` and is restricted to GPU ``device + i`` through
    ``CUDA_VISIBLE_DEVICES``. Each call is sent to one randomly chosen server
    through an ``OpenAI`` client pointed at ``http://localhost:<port>/v1``.
    """

    # Upper bound, in seconds, on the delay between retries. Without a cap the
    # exponential backoff keeps doubling and a long server outage would leave
    # the caller sleeping for hours after the server is reachable again.
    _MAX_BACKOFF_SECONDS = 60

    def __init__(
        self,
        model,
        port=8021,
        device=0,
        temperature: float = 0.5,
        max_tokens: int = 1024,
        top_p=0.95,
        num_processes=8
    ):
        """
        Create one client per server and launch the servers.

        Args:
            model: Model identifier passed to the server's ``--model`` option
                and used as the ``model`` field of every request.
            port: Port of the first server; server ``i`` uses ``port + i``.
            device: GPU index of the first server; server ``i`` uses
                ``device + i``.
            temperature: Default sampling temperature for ``__call__``.
            max_tokens: Stored on the instance and in ``sampling_params``.
                Note that ``__call__`` takes its own ``max_tokens`` argument
                and does not read this value.
            top_p: Stored in ``sampling_params`` only.
            num_processes: Number of server subprocesses (and clients).
        """
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Not consulted by __call__, which goes through the HTTP API; kept as
        # an attribute in case external code reads it.
        self.sampling_params = SamplingParams(
            temperature=temperature, max_tokens=max_tokens, top_p=top_p
        )
        self.system_message = OPENAI_SYSTEM_MESSAGE_API
        self.num_processes = num_processes
        self.port = port
        self.device = device
        # The servers are launched without an --api-key option, so the key
        # below is only a placeholder required by the client constructor.
        self.clients = [
            OpenAI(api_key='12345', base_url='http://localhost:%d/v1' % (self.port + i))
            for i in range(self.num_processes)
        ]
        self.start_process(model)

    def start_process(self, model):
        """
        Launch ``self.num_processes`` server subprocesses serving ``model``.

        Sets ``self.model`` and replaces ``self.processes`` with the new
        ``Popen`` handles. Server output is discarded. After each launch the
        method sleeps 10 seconds; it does not check that the server is ready.
        """
        self.processes = []
        self.model = model
        for i in range(self.num_processes):
            command = [
                "python", "-m", "vllm.entrypoints.openai.api_server",
                "--model", model,
                "--port", '%d' % (self.port + i),
                "--uvicorn-log-level", "debug",
            ]
            # Restrict each server to its own GPU.
            env = {**dict(os.environ), 'CUDA_VISIBLE_DEVICES': '%d' % (self.device + i)}
            process = subprocess.Popen(
                command, env=env, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
            self.processes.append(process)
            # Fixed pause between launches (no readiness probe is performed).
            time.sleep(10)
        return

    def kill_process(self):
        """
        Stop every server subprocess and clear ``self.processes``.

        All servers are asked to terminate first, so they shut down
        concurrently. After a shared 5-second grace period each one is waited
        on for up to 5 more seconds; a server that is still alive is killed
        and then reaped so that no zombie process is left behind.
        """
        print('Killing process!')
        for proc in self.processes:
            proc.terminate()
        if self.processes:
            # Single grace period shared by all servers.
            time.sleep(5)
        for proc in self.processes:
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                # Reap the killed child; kill() alone leaves a zombie.
                proc.wait()
        self.processes = []
        # These only affect this (parent) process: drop unreachable objects
        # and release cached CUDA memory held by the local allocator.
        gc.collect()
        torch.cuda.empty_cache()

    def _pack_message(self, role: str, content: Any):
        """Return a chat message dict with ``role`` (as ``str``) and ``content``."""
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList, max_tokens=4096, temperature=None) -> "tuple[str, int, int]":
        """
        Request one chat completion from a randomly chosen server.

        The instance's system message, when non-empty, is prepended to
        ``message_list`` (the caller's list is not modified).

        Args:
            message_list: Conversation messages to send.
            max_tokens: Completion token limit for this request.
            temperature: Sampling temperature; ``self.temperature`` is used
                when this is ``None``.

        Returns:
            ``(content, completion_tokens, prompt_tokens)``. When the server
            rejects the request with ``openai.BadRequestError`` the result is
            ``("", 0, 0)`` so callers can always unpack three values.

        Any other exception is printed and the request is retried without
        limit, with exponential backoff capped at ``_MAX_BACKOFF_SECONDS``.
        """
        if self.system_message:
            message_list = [self._pack_message("system", self.system_message)] + message_list
        if temperature is None:
            temperature = self.temperature
        trial = 0
        while True:
            try:
                client_idx = np.random.randint(self.num_processes)
                response = self.clients[client_idx].chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                usage = response.usage
                return (
                    response.choices[0].message.content,
                    usage.completion_tokens,
                    usage.prompt_tokens,
                )
            except openai.BadRequestError as e:
                # The request itself was rejected, so resending the same
                # payload would not help; give up with an empty result that
                # has the same shape as a successful one.
                print("Bad Request Error", e)
                return "", 0, 0
            except Exception as e:
                # Exponential backoff, bounded so retries stay reasonably frequent.
                exception_backoff = min(2**trial, self._MAX_BACKOFF_SECONDS)
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1



