from ctypes import c_bool
import inspect
import sys
import os
import time
import itertools

from multiprocess.queues import Empty
import numpy as np

import multiprocess as mp

from tqdm import tqdm

# Acknowledgement token a worker sends back over its pipe once it has swapped
# in a new batch callback; anything else received on the pipe is treated as an
# exception object and re-raised in the main process.
CHANGE_DONE = "DONE"

class InferenceProcess(mp.Process):
    '''
    Daemon worker process that runs a batch callback on items from a shared queue.

    The process starts itself at the end of __init__. Inside the worker it
    calls `load_static()` once, then loops: take an `(index, batch)` tuple
    from `qi`, call `on_batch_received(batch, **statics)` and put
    `(index, result)` on `qo`. A pipe is used as a control channel: the main
    process sends a new callback through it, the worker answers with
    CHANGE_DONE or sends back exception objects.
    '''

    def __init__(self, load_static, on_batch_received, qi, qo):
        '''
        load_static: callable without arguments, called once in the worker; its
            return value is unpacked as keyword arguments into every callback call
        on_batch_received: callable invoked as on_batch_received(batch, **statics)
        qi: queue the worker reads (index, batch) tuples from
        qo: queue the worker writes (index, result) tuples to
        '''
        super().__init__(daemon=True)

        self.ls = load_static
        self.obr = on_batch_received

        # qi: send from main, qo: send from subprocesses
        self.qi = qi
        self.qo = qo

        # pi: end used by main, po: end used by the subprocess
        self.pi, self.po = mp.Pipe()

        # shared flag; the worker loop keeps running while it is True
        self._do_listen = mp.Value(c_bool)
        self._do_listen.value = True

        self.start()

    def set_on_batch_received(self, on_batch_received):
        '''
        Replace the batch callback inside the running worker.

        The callback is sent together with the global variables it references;
        the worker copies those onto its `__main__` module. Blocks until the
        worker answers. Does nothing if the process is not alive.

        Raises: whatever exception object the worker sent back instead of
        CHANGE_DONE, or RuntimeError if the worker exits without answering.
        '''
        # find out whether on_batch_received depends on any state outside its scope
        global_vars = inspect.getclosurevars(on_batch_received).globals
        if not self.is_alive():
            return

        self.pi.send((on_batch_received, global_vars))

        # wait for the answer, but do not wait forever on a worker that died
        while not self.pi.poll(0.1):
            if not self.is_alive() and not self.pi.poll():
                raise RuntimeError(
                    "Inference process exited before acknowledging the new callback"
                )

        e = self.pi.recv()
        if e != CHANGE_DONE:
            raise e

    def stop(self):
        '''Ask the worker loop to end, close this side's pipe ends and join the process.'''
        self._do_listen.value = False
        self.pi.close()
        self.po.close()
        self.join()

    def _run(self):
        '''Worker body: load the statics once, then serve batches and control messages.'''
        start = time.time()

        statics = self.ls()

        diff = time.time() - start
        print(f"Loading statics took: {diff:.2f}s", flush=True)

        while self._do_listen.value:
            # 1) look for a batch and run it
            try:
                # Blocking get with a timeout. With block=False the timeout is
                # ignored, which turned this loop into a busy spin while idle.
                i, batch = self.qi.get(timeout=0.1)

                # run inference
                result = self.obr(batch, **statics)
                # send back result
                self.qo.put((i, result))

            except Empty:
                pass
            except Exception as e:
                # report the failure to the main process; no result is queued
                self.po.send(e)

            # 2) look for a control message (new callback) in the pipe
            try:
                if self.po.poll():
                    obr, global_vars = self.po.recv()
                    self.obr = obr
                    # make the globals the callback references visible in this process
                    module = sys.modules["__main__"]
                    for name, value in global_vars.items():
                        setattr(module, name, value)

                    # notify main about change done
                    self.po.send(CHANGE_DONE)

            except Exception as e:
                self.po.send(e)

    def run(self):
        '''Process entry point; forwards any exception escaping _run() to the main process.'''
        try:
            self._run()
        except Exception as e:
            self.po.send(e)

    def get_error(self):
        '''Raise, in the calling process, a pending object the worker sent over the pipe (if any).'''
        if not self.pi.closed and self.pi.poll():
            e = self.pi.recv()
            raise e

class InferenceContext():
    '''
    Manager for a pool of InferenceProcess workers sharing one input and one output queue.

    Typical use: start(...) to spawn the workers, run_inference(...) any number
    of times, stop() to shut the workers down.
    '''

    def __init__(self, num_gpus=9, gpu_list=None, verbose=0):
        '''
        Use num_gpus to specify number of gpus to automatically allocate or
        gpu_list = [0,1,5, ...] to give a list of gpu indices to use

        verbose: if truthy, set_on_batch_received prints the global variables
            the new callback depends on

        Raises ValueError if both num_gpus and gpu_list are None.
        '''
        if (num_gpus is None and gpu_list is None):
            raise ValueError("Need either num_gpus or gpu_list as not None arguments")

        self.num_gpus = num_gpus if gpu_list is None else len(gpu_list)
        self.gpu_list = list(range(num_gpus)) if gpu_list is None else gpu_list
        self.verbose = verbose
        # manager state
        self.is_initialized = False

        # state of current initialization
        # process context of last started processes
        self.processes = None
        self.groups = None

        # in and out queues
        self.qi, self.qo = mp.Queue(), mp.Queue()

        # counter used to tag batches, so results of an interrupted earlier
        # run_inference call are not mistaken for results of the current one
        self._run_id = 0

    def set_on_batch_received(self, on_batch_received):
        '''
        Replace the batch callback in every worker process.

        Raises ValueError if start() has not been called yet; errors reported
        by a worker are re-raised by the worker's own set_on_batch_received.
        '''
        if not self.is_initialized:
            raise ValueError("Need to call start() first")

        if self.verbose:
            # check dependencies
            global_vars = inspect.getclosurevars(on_batch_received).globals
            if len(global_vars) > 0:
                print(f"Your `on_batch_received` method relies on global variables. Their current state will be copied into the new processes,"
                      f" but, future changes will not be propagated automatically (yet).\n"
                      f"Here is the list of global variables: {global_vars}")
        for p in self.processes:
            p.set_on_batch_received(on_batch_received)

    def start(self, load_static, on_new_batch, gpus_per_proc=1):
        '''
        Spawn one worker process per group of `gpus_per_proc` GPUs.

        load_static: callable run once in every worker (see InferenceProcess)
        on_new_batch: initial batch callback of every worker
        gpus_per_proc: number of GPU indices made visible to each worker via
            CUDA_VISIBLE_DEVICES; GPUs left over after the integer division
            stay unused

        Raises ValueError if already started, or if gpus_per_proc is smaller
        than 1 or larger than the number of GPUs.
        '''
        # todo: auto update global state in subprocesses?
        # todo: control on which gpu a process starts! idea: is it just the local rank?
        # todo: gracefully handle interrupts during inference!
        #       - clean up queues and pipes

        if self.is_initialized:
            raise ValueError("I am already running! Call stop first.")

        # 0 would divide by zero below, a negative value would silently start no worker
        if gpus_per_proc < 1:
            raise ValueError(f"gpus_per_proc must be at least 1, got {gpus_per_proc}")

        if gpus_per_proc > self.num_gpus:
            raise ValueError(f"Too many GPUs requests: {gpus_per_proc} > {self.num_gpus}")

        self.processes = []

        # naive placement group computation...
        num_groups = self.num_gpus // gpus_per_proc
        self.groups = [
            self.gpu_list[i * gpus_per_proc:(i + 1) * gpus_per_proc]
            for i in range(num_groups)
        ]

        # The environment is changed only while the processes are created and
        # is restored afterwards, whether or not process creation succeeds.
        old_envs = dict(os.environ)
        try:
            for gpus in self.groups:
                os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(x) for x in gpus])
                os.environ["LOCAL_RANK"] = "0"
                p = InferenceProcess(load_static, on_new_batch, self.qi, self.qo)
                self.processes.append(p)
        finally:
            os.environ.clear()
            # restore old env vars
            os.environ.update(old_envs)

        self.is_initialized = True

    def _make_batches(self, data, max_batch_size):
        '''Split the list `data` into a list of batches according to `max_batch_size`.'''
        if max_batch_size == "distribute":
            index_groups = np.array_split(range(len(data)), len(self.processes))
            return [[data[i] for i in group] for group in index_groups if len(group) > 0]

        is_int = isinstance(max_batch_size, (int, np.integer)) and not isinstance(max_batch_size, bool)
        if not is_int or max_batch_size < 1:
            raise ValueError(
                f'max_batch_size must be "distribute" or a positive int, got {max_batch_size!r}'
            )
        return [data[lo:lo + max_batch_size] for lo in range(0, len(data), max_batch_size)]

    def _store_result(self, item, run_id, results):
        '''Append `item` to `results` if it belongs to run `run_id`; return whether it did.'''
        (item_run_id, index), result = item
        if item_run_id != run_id:
            # late result of an earlier, interrupted run: drop it
            return False
        results.append((index, result))
        return True

    def _wait_for_result(self, poll_interval=0.1):
        '''
        Block until an item arrives on the output queue and return it.

        While waiting, errors reported by workers are re-raised here, and a
        RuntimeError is raised once no worker is alive any more; otherwise a
        failed batch would make the caller wait forever.
        '''
        while True:
            try:
                return self.qo.get(timeout=poll_interval)
            except Empty:
                pass

            for p in self.processes:
                p.get_error()

            if not any(p.is_alive() for p in self.processes):
                # give a result that was queued right before the exit one last chance
                try:
                    return self.qo.get(timeout=poll_interval)
                except Empty:
                    raise RuntimeError(
                        "All inference processes exited before returning every batch"
                    ) from None

    def _drain_queues(self):
        '''Best-effort removal of everything currently readable from both queues.'''
        for q in (self.qi, self.qo):
            while True:
                try:
                    # non-blocking: a worker may take the item between a check and a get
                    q.get_nowait()
                except Empty:
                    break

    def run_inference(self, data, max_batch_size="distribute"):
        '''
        Run the workers' batch callback over `data` and return the concatenated results.

        data: str (treated as a single sample) or list of samples
        max_batch_size: str or int. If "distribute": give each running process
            similarly many samples. If int: batches of at most that many samples.

        Returns a flat list with the per-batch results concatenated in input
        order, or None if interrupted by KeyboardInterrupt.

        Raises ValueError on invalid arguments or if start() was not called;
        re-raises exceptions reported by workers; raises RuntimeError if all
        workers exited before every batch was returned.
        '''
        if not self.is_initialized:
            raise ValueError("Need to run start first")

        if isinstance(data, str):
            data = [data]

        if not isinstance(data, list):
            raise ValueError("Input must be str or list of samples")

        batches = self._make_batches(data, max_batch_size)

        self._run_id += 1
        run_id = self._run_id
        results = []

        try:
            with tqdm(total=len(batches)) as pbar:
                for i, batch in enumerate(batches):
                    # workers echo the first tuple element back unchanged
                    self.qi.put(((run_id, i), batch))

                    # collect whatever is already finished while still submitting
                    while True:
                        try:
                            item = self.qo.get_nowait()
                        except Empty:
                            break
                        if self._store_result(item, run_id, results):
                            pbar.update(1)

                # read until all batches are returned
                while len(results) != len(batches):
                    item = self._wait_for_result()
                    if self._store_result(item, run_id, results):
                        pbar.update(1)

        except KeyboardInterrupt:
            # clear up queues
            # issue: if we just submitted a long lasting batch to a subprocess
            # we can not make sure that subprocesses are still running on them;
            # their late results are dropped by the run id check of the next call
            self._drain_queues()
            return
        except Exception:
            # a worker failed: discard the batches that are still pending
            self._drain_queues()
            raise

        # sort results, in case batches got mixed up
        results.sort(key=lambda pair: pair[0])
        # remove batch index, concat to list
        return list(itertools.chain.from_iterable(result for _, result in results))

    def stop(self):
        '''Stop and join all worker processes; does nothing if not started.'''
        if self.is_initialized:
            # for future: also close self.qi / self.qo here

            # flip every flag first so all workers wind down at the same time,
            # then join them one after the other
            for p in self.processes:
                p._do_listen.value = False
            for p in self.processes:
                p.stop()
            self.is_initialized = False

    def status(self):
        '''Print whether the context is started and, if so, the worker processes.'''
        print("Initialized?", self.is_initialized)
        if self.is_initialized:
            print(self.processes)



from vllm.entrypoints.openai.protocol import ChatCompletionResponse
from vllm import LLM, SamplingParams

def to_chat_completion_response(r, model_name):
    '''
    Convert a vllm RequestOutput (or a list of them) into ChatCompletionResponse objects.

    r: a single request output, or a list that is converted element by element
       (nested lists are handled recursively)
    model_name: value stored in the `model` field of every response
    '''
    if isinstance(r, list):
        converted = []
        for request_output in r:
            converted.append(to_chat_completion_response(request_output, model_name))
        return converted

    # one assistant message per generated completion
    choices = []
    for completion in r.outputs:
        choices.append({
            "index": completion.index,
            "message": {"role": "assistant", "content": completion.text},
        })

    # token accounting over all completions of this request
    completion_tokens = 0
    for completion in r.outputs:
        completion_tokens += len(completion.token_ids)
    prompt_tokens = len(r.prompt_token_ids)

    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": completion_tokens + prompt_tokens,
    }
    return ChatCompletionResponse(model=model_name, choices=choices, usage=usage)

def generate_assistant_turn(sampling_params=None, return_format="list"):
    '''
    Build a batch callback that generates one assistant turn per conversation.

    sampling_params: passed on to model.chat; a default SamplingParams() is
        created if None
    return_format can be "list", i.e. list of conversations or "openai", i.e.
    list of ChatCompletionResponses in OpenAi API compatible format

    Returns a function (batch, model) -> results.
    Raises ValueError for any other return_format. Without this check the
    returned callback would silently produce None for every batch.
    '''
    if return_format not in ("list", "openai"):
        raise ValueError(
            f"Unknown return_format {return_format!r}; expected 'list' or 'openai'"
        )

    if sampling_params is None:
        sampling_params = SamplingParams()

    def on_batch_received(batch, model):
        '''
        batch: list of lists containing the chat messages, e.g.:
            [
                [
                    {"role": "user", "content": f"Do both X and Y express the same moral guideline?"},
                ]
                ...
            ]

        Returns: list of list of list of dicts.
        Explanation: 1 conversation is encoded as list of dicts.
        Each conversation can lead to n answers, depending on sampling params, so we return a list of convos per input convo in batch
        Finally, wrap those above in a list.
        Sizes: batch_size x sampling_params.n x num_turns in each conversation.
        '''
        request_outputs = model.chat(batch, sampling_params, use_tqdm=False)
        if return_format == "openai":
            return to_chat_completion_response(request_outputs, "no_name")

        # return_format == "list": extend each input conversation by one
        # assistant message per generated completion
        return [
            [chat + [{"role": "assistant", "content": completion.text}]
             for completion in request_output.outputs]
            for chat, request_output in zip(batch, request_outputs)
        ]

    return on_batch_received