import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import List, Union

from sglang.global_config import global_config
from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


def compile_func(function, backend):
    """Trace ``function`` on ``backend`` and wrap the trace in a CompiledFunction.

    Args:
        function: Object exposing ``trace(backend=...)``; it is also stored on
            the returned ``CompiledFunction``.
        backend: Backend forwarded to ``function.trace``.

    Returns:
        The ``CompiledFunction`` built from the trace.
    """
    return CompiledFunction(function.trace(backend=backend), function)


class CompiledFunction:
    """Executable dataflow graph built from a traced program.

    The trace is walked backwards from its last expression to build a graph
    of ``CompGraphNode`` objects, which is then put in topological order.
    Running the compiled function submits every node's expression, in that
    order, to the ``StreamExecutor`` selected by the expression's ``pid``.
    """

    def __init__(self, tracer, function):
        """Build and order the graph for ``function`` from ``tracer``.

        Args:
            tracer: Trace result; its ``last_node`` and ``variables``
                attributes are read.
            function: The traced function. Its ``bind_arguments`` are merged
                into the call arguments by ``run`` and ``run_batch``.
        """
        self.function = function

        self.last_node = CompGraphNode(tracer.last_node)
        self.expr_to_node = {}
        self.build_graph(tracer)
        self.topological_sort()

    def _get_or_add_node(self, expr):
        """Return the graph node for ``expr``, creating and queueing it if new."""
        node = self.expr_to_node.get(expr)
        if node is None:
            node = CompGraphNode(expr)
            self.nodes.append(node)
            self.expr_to_node[expr] = node
        return node

    def build_graph(self, tracer):
        """Collect every expression reachable from ``tracer.last_node``.

        Breadth-first walk over ``prev_node`` links and, for ``SglVariable``
        expressions, over the variable's source expression. ``self.nodes`` is
        both the work queue and the resulting node list.

        Side effect: the ``pid`` of every visited expression is rewritten in
        place to a dense id (0, 1, 2, ...) assigned in order of first visit.
        """
        self.nodes = [self.last_node]
        self.expr_to_node[tracer.last_node] = self.last_node

        # Maps original pid -> dense pid.
        rename_pid = {}

        head = 0
        while head < len(self.nodes):
            cur_node = self.nodes[head]
            head += 1
            expr = cur_node.expr

            # Link to the expression this one is chained after.
            prev_expr = expr.prev_node
            if prev_expr is not None:
                cur_node.prev_node = self._get_or_add_node(prev_expr)
                cur_node.prev_node.add_next_node(cur_node)

            # Link a variable to the expression that produces its value. The
            # tracer's variable table takes precedence over the expression's
            # own ``source`` attribute.
            if isinstance(expr, SglVariable):
                if expr.name in tracer.variables:
                    source = tracer.variables[expr.name].source
                else:
                    source = expr.source
                cur_node.source_node = self._get_or_add_node(source)
                cur_node.source_node.add_next_node(cur_node)

            # ``len(rename_pid)`` is evaluated before a new key is inserted,
            # so the first unseen pid becomes 0, the next 1, and so on.
            expr.pid = rename_pid.setdefault(expr.pid, len(rename_pid))

    def topological_sort(self):
        """Reorder ``self.nodes`` so each node comes after its dependencies.

        A node depends on its ``prev_node`` and its ``source_node`` (when
        set). Nodes with no dependencies are emitted first, in their current
        list order; a node is emitted once all its dependencies have been.
        """
        # Number of not-yet-emitted dependencies per node.
        pending = {}
        ready = Queue()
        for node in self.nodes:
            pending[node] = (node.prev_node is not None) + (
                node.source_node is not None
            )
            if pending[node] == 0:
                ready.put(node)

        ordered = []
        while not ready.empty():
            node = ready.get()
            ordered.append(node)
            for successor in node.next_nodes:
                pending[successor] -= 1
                if pending[successor] == 0:
                    ready.put(successor)
        self.nodes = ordered

    def print_graph(
        self,
    ):
        """Print one line per node, in the current node order."""
        for node in self.nodes:
            print(node)

    def run_internal(
        self,
        backend,
        kwargs,
        default_sampling_para,
    ):
        """Execute the graph once and return the final program state.

        One ``StreamExecutor`` is created per distinct pid in the graph. Only
        the executor that owns the last node receives ``kwargs`` as its
        arguments; every other executor gets an empty dict.

        Args:
            backend: Backend passed to each ``StreamExecutor``.
            kwargs: Mapping of argument name to value, used to substitute
                ``SglArgument`` nodes.
            default_sampling_para: Sampling parameters passed to each
                ``StreamExecutor``.

        Returns:
            A ``ProgramState`` wrapping the executor that owns the last node.

        Raises:
            KeyError: If an ``SglArgument`` in the graph names a key that is
                missing from a dict ``kwargs``.
        """
        last_pid = self.last_node.expr.pid
        stream_executors = {
            pid: StreamExecutor(
                backend,
                kwargs if pid == last_pid else {},
                default_sampling_para,
                None,
                False,
            )
            for pid in {node.expr.pid for node in self.nodes}
        }

        for node in self.nodes:
            expr = node.expr
            target = stream_executors[expr.pid]
            if isinstance(expr, SglVariable):
                # Submit a fresh copy so the shared graph expression is not
                # tied to this run's source executor.
                expr = SglVariable(expr.name, expr.source)
                expr.source_stream_executor = stream_executors[
                    node.source_node.expr.pid
                ]
            elif isinstance(expr, SglArgument):
                # Replace the placeholder with the caller-supplied value.
                expr = kwargs[expr.name]
            target.submit(expr)

        for stream_executor in stream_executors.values():
            stream_executor.end()
        return ProgramState(stream_executors[last_pid])

    def run(
        self,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        **kwargs,
    ):
        """Run the compiled function once.

        Args:
            max_new_tokens, stop, temperature, top_p, top_k, min_p,
            frequency_penalty, presence_penalty: Forwarded unchanged to
                ``SglSamplingParams`` as the default sampling parameters.
            backend: Backend to run on; when falsy,
                ``global_config.default_backend`` is used.
            **kwargs: Program arguments. Entries of
                ``self.function.bind_arguments`` are merged in and override
                same-named keys.

        Returns:
            The ``ProgramState`` produced by ``run_internal``.
        """
        backend = backend or global_config.default_backend

        kwargs.update(self.function.bind_arguments)

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        return self.run_internal(backend, kwargs, default_sampling_para)

    def run_batch(
        self,
        batch_kwargs,
        *,
        max_new_tokens: int = 128,
        stop: Union[str, List[str]] = (),
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        min_p: float = 0.0,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        backend=None,
        num_threads: Union[str, int] = "auto",
    ):
        """Run the compiled function once per entry of ``batch_kwargs``.

        Args:
            batch_kwargs: List or tuple of argument dicts, one per run. The
                caller's dicts are never mutated.
            max_new_tokens, stop, temperature, top_p, top_k, min_p,
            frequency_penalty, presence_penalty: Forwarded unchanged to
                ``SglSamplingParams`` as the default sampling parameters.
            backend: Backend to run on; when falsy,
                ``global_config.default_backend`` is used.
            num_threads: Worker thread count. ``"auto"`` uses
                ``multiprocessing.cpu_count()``. The value is capped at the
                batch size; a resulting value of 1 runs sequentially in the
                calling thread.

        Returns:
            A list with one ``run_internal`` result per entry, in input
            order; an empty list for an empty batch.
        """
        assert isinstance(batch_kwargs, (list, tuple))
        if len(batch_kwargs) == 0:
            return []
        assert isinstance(batch_kwargs[0], dict)

        backend = backend or global_config.default_backend

        # Apply bound arguments exactly like ``run`` does (bound values
        # override same-named keys). Build fresh dicts so the caller's inputs
        # stay untouched; when nothing is bound the original dicts are used
        # as-is.
        bind_arguments = self.function.bind_arguments
        if bind_arguments:
            batch_kwargs = [
                {**arguments, **bind_arguments} for arguments in batch_kwargs
            ]

        default_sampling_para = SglSamplingParams(
            max_new_tokens=max_new_tokens,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Extract prefix by tracing and cache it
        if len(batch_kwargs) > 1:
            cache_program(self.function, backend)

        # Run all programs
        if num_threads == "auto":
            num_threads = multiprocessing.cpu_count()
        num_threads = min(num_threads, len(batch_kwargs))

        if num_threads == 1:
            rets = [
                self.run_internal(backend, arguments, default_sampling_para)
                for arguments in batch_kwargs
            ]
        else:
            with ThreadPoolExecutor(num_threads) as executor:
                futures = [
                    executor.submit(
                        self.run_internal, backend, arguments, default_sampling_para
                    )
                    for arguments in batch_kwargs
                ]
                # Collect in submission order so results line up with inputs.
                rets = [f.result() for f in futures]
            # Only the threaded path calls ``sync`` on the final state.
            rets[-1].sync()

        return rets


class CompGraphNode:
    """A single expression in the compiled graph plus its graph links.

    Attributes:
        expr: The wrapped expression.
        prev_node: Node this one is chained after, or ``None``.
        source_node: Node linked as this one's source, or ``None``.
        next_nodes: Nodes registered through ``add_next_node``.
    """

    def __init__(
        self, expr: SglExpr, prev_node=None, next_nodes=None, source_node=None
    ):
        """Store ``expr`` and its links; a falsy ``next_nodes`` becomes a new list."""
        self.expr = expr
        self.prev_node = prev_node
        self.source_node = source_node
        self.next_nodes = next_nodes if next_nodes else []

    def add_next_node(self, other):
        """Append ``other`` to this node's ``next_nodes``."""
        self.next_nodes.append(other)

    def __repr__(self):
        """Render as ``stream <pid>: %<id> = [%<prev id> + ]<expr repr>``."""
        prefix = f"stream {self.expr.pid:2d}: %{self.expr.node_id} = "
        if self.prev_node is None:
            return prefix + repr(self.expr)
        return f"{prefix}%{self.prev_node.expr.node_id} + {self.expr!r}"
