"""Tracing a program."""

import uuid
from typing import Any, Callable, Dict, List, Optional, Union

from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
from sglang.lang.ir import (
    SglArgument,
    SglCommitLazy,
    SglConcateAndAppend,
    SglConstantText,
    SglExpr,
    SglExprList,
    SglFork,
    SglFunction,
    SglGen,
    SglGetForkItem,
    SglRoleBegin,
    SglRoleEnd,
    SglSelect,
    SglVariable,
    SglVarScopeBegin,
    SglVarScopeEnd,
)


class StopTracing(Exception):
    """Signal raised by the tracer to abort tracing of a program early."""


def extract_prefix_by_tracing(program, backend):
    """Trace ``program`` in prefix-only mode and return its leading constant text.

    Every name in ``program.arg_names`` is given a placeholder ``SglArgument``;
    entries from ``program.bind_arguments`` then override those placeholders.
    The program function is run against a ``TracerProgramState`` created with
    ``only_trace_prefix=True``.

    Returns:
        The concatenated ``value`` of the ``SglConstantText`` nodes recorded
        before the first node of any other type (``""`` if there are none).
    """
    # Placeholder for each declared argument; bound arguments take precedence.
    arguments = {arg: SglArgument(arg, None) for arg in program.arg_names}
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=True)
    try:
        with TracingScope(tracer):
            tracer.ret_value = program.func(tracer, **arguments)
    except (StopTracing, TypeError, AttributeError):
        # Tracing stopped early; the nodes recorded so far are still usable.
        # Other exception types are not caught here and propagate.
        pass

    # Collect constant text up to the first non-constant node.
    pieces = []
    for node in tracer.flatten_nodes():
        if not isinstance(node, SglConstantText):
            break
        pieces.append(node.value)
    return "".join(pieces)


def trace_program(program, arguments, backend):
    """Trace the whole of ``program`` and return the tracer state.

    Args:
        program: Object providing ``arg_names``, ``bind_arguments`` and ``func``.
        arguments: Mapping of argument name to value. It is updated in place:
            a placeholder ``SglArgument`` is added for each declared argument
            that is missing, then ``program.bind_arguments`` is merged in.
        backend: Backend to trace against; ``None`` selects a fresh
            ``BaseBackend``.

    Returns:
        The ``TracerProgramState`` used for tracing, with ``ret_value`` set to
        whatever ``program.func`` returned.
    """
    backend = BaseBackend() if backend is None else backend

    # Build placeholders for declared arguments the caller did not supply.
    placeholders = {}
    for arg in program.arg_names:
        if arg not in arguments:
            placeholders[arg] = SglArgument(arg, None)
    arguments.update(placeholders)
    arguments.update(program.bind_arguments)

    tracer = TracerProgramState(backend, arguments, only_trace_prefix=False)
    with TracingScope(tracer):
        tracer.ret_value = program.func(tracer, **arguments)
    return tracer


class TracerProgramState(ProgramState):
    """Program state that records expressions as nodes instead of running them.

    Expressions passed through ``+=`` / ``_execute`` are appended to ``nodes``
    and chained together through their ``prev_node`` attribute. When
    ``only_trace_prefix`` is true, tracing is aborted with ``StopTracing`` on
    ``fork`` and on any expression type that has no dedicated handler here.
    """

    def __init__(self, backend, arguments, only_trace_prefix):
        """Create a tracer state.

        Args:
            backend: Backend object. If it has an ``endpoint`` attribute, that
                endpoint is stored as ``self.backend`` instead.
            arguments: Mapping of argument name to value; consulted first by
                ``get_var``.
            only_trace_prefix: When true, stop tracing (``StopTracing``) at a
                fork or at an expression without a dedicated handler.
        """
        # NOTE(review): ProgramState.__init__ is not invoked here; this looks
        # deliberate, confirm against the base class.
        self.pid = uuid.uuid4().hex
        self.backend = backend
        self.arguments: Dict[str, Any] = arguments
        self.only_trace_prefix = only_trace_prefix

        if hasattr(backend, "endpoint"):
            self.backend = backend.endpoint

        self.nodes = []
        self.last_node = None
        self.variables = {}
        self.ret_value = None

        # For completion

        # For chat
        self.messages_ = []
        self.cur_role = None
        self.chat_template = self.backend.get_chat_template()

        # For multi states
        self.child_states = []

        # Register with the active tracing scope, if there is one.
        cur_scope = TracingScope.get_current_scope()
        if cur_scope is not None:
            cur_scope.add_child_state(self)

    ##################################
    ########### Public API ###########
    ##################################

    def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
        """Record a fork and return a group of ``size`` child tracer states.

        Each child starts from an ``SglGetForkItem(i)`` node whose predecessor
        is a shared ``SglFork`` node, and receives copies of this state's
        variables and messages plus its current role and chat template.

        Args:
            size: Number of child states to create; must be at least 1.
            position_ids_offset: Accepted but not used by the tracer.

        Returns:
            ``ProgramStateGroup(states, self)``.

        Raises:
            StopTracing: If this state only traces the prefix.
        """
        assert size >= 1

        if self.only_trace_prefix:
            raise StopTracing()

        fork_node = SglFork(size)
        fork_node.prev_node = self.last_node

        states = [
            TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
            for _ in range(size)
        ]

        for i, state in enumerate(states):
            node = SglGetForkItem(i)
            node.prev_node = fork_node
            state.last_node = node
            # Shallow copies so children do not share containers with the parent.
            state.variables = dict(self.variables)
            state.messages_ = list(self.messages_)
            state.cur_role = self.cur_role
            state.chat_template = self.chat_template

        return ProgramStateGroup(states, self)

    ##################################
    ########## Internal API ##########
    ##################################

    def _append_node(self, other: SglExpr):
        """Append ``other`` to ``nodes`` and link it after the current last node."""
        self.nodes.append(other)
        other.prev_node = self.last_node
        self.last_node = other

    def _execute(self, other: SglExpr):
        """Record ``other`` by dispatching on its expression type.

        Plain strings are wrapped in ``SglConstantText``. Expression types
        without a dedicated handler are appended as-is, unless only the prefix
        is being traced, in which case ``StopTracing`` is raised.

        Returns:
            ``self``.
        """
        if isinstance(other, str):
            other = SglConstantText(other)

        other.pid = self.pid

        if isinstance(other, SglConstantText):
            self._execute_fill(other)
        elif isinstance(other, SglGen):
            self._execute_gen(other)
        elif isinstance(other, SglSelect):
            self._execute_select(other)
        elif isinstance(other, SglExprList):
            for x in other.expr_list:
                self._execute(x)
        elif isinstance(other, SglRoleBegin):
            self._execute_role_begin(other)
        elif isinstance(other, SglRoleEnd):
            self._execute_role_end(other)
        elif isinstance(other, SglVarScopeBegin):
            # NOTE(review): `_execute_var_scope_begin` is not defined in this
            # class; presumably it must come from ProgramState. Confirm it
            # exists, otherwise this branch raises AttributeError.
            self._execute_var_scope_begin(other)
        elif isinstance(other, SglVarScopeEnd):
            self._execute_var_scope_end(other)
        else:
            if self.only_trace_prefix:
                raise StopTracing()
            else:
                self._append_node(other)

        return self

    def __iadd__(self, other):
        """Support ``state += expr`` by recording ``expr``; returns ``self``."""
        self._execute(other)
        return self

    def _execute_fill(self, expr: SglConstantText):
        """Record constant text; a plain ``str`` is wrapped in ``SglConstantText``."""
        if isinstance(expr, str):
            expr = SglConstantText(expr)
        self._append_node(expr)

    def _execute_gen(self, expr: SglGen):
        """Record a generation node and register a variable that refers to it.

        An unnamed expression gets the name ``gen_<number of variables so far>``.
        """
        name = expr.name if expr.name is not None else "gen_" + str(len(self.variables))
        self.variables[name] = SglVariable(name, source=expr)
        self._append_node(expr)

    def _execute_select(self, expr: SglSelect):
        """Record a select node and register a variable that refers to it.

        An unnamed expression gets the name ``select_<number of variables so far>``.
        """
        name = (
            expr.name if expr.name is not None else "select_" + str(len(self.variables))
        )
        self.variables[name] = SglVariable(name, source=expr)
        self._append_node(expr)

    def _execute_role_begin(self, expr: SglRoleBegin):
        """Open a chat role and record the chat template's prefix for it.

        If no message has been recorded yet and the role is not ``"system"``,
        the template's default system prompt (when non-empty) is recorded first
        as a complete system turn.
        """
        assert self.cur_role is None, "Nested roles are not allowed."

        if not self.messages_ and expr.role != "system":
            # Insert default system message
            default_system = self.chat_template.default_system_prompt
            if default_system:
                self._execute_role_begin(SglRoleBegin("system"))
                self._execute_fill(default_system)
                self._execute_role_end(SglRoleEnd("system"))

        self.cur_role = expr.role

        prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(prefix)

    def _execute_role_end(self, expr: SglRoleEnd):
        """Close the current chat role and record the chat template's suffix.

        A message with empty content is appended to ``messages_`` for the role.
        """
        _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

        self._execute_fill(suffix)

        self.messages_.append({"role": expr.role, "content": ""})

        self.cur_role = None

    def _execute_var_scope_end(self, expr: SglVarScopeEnd):
        """Register the scope's variable, sourced from the last recorded node."""
        # The variable name is carried by the expression, as in the gen and
        # select handlers; there is no local `name` to read it from.
        name = expr.name
        self.variables[name] = SglVariable(name, source=self.last_node)

    def get_var(self, name):
        """Look up ``name`` among the arguments first, then the traced variables.

        Returns:
            The argument value when ``arguments`` holds a non-``None`` entry
            for ``name``; otherwise a new ``SglVariable`` built from the
            recorded variable's name and source.

        Raises:
            KeyError: If ``name`` is neither a non-``None`` argument nor a
                recorded variable.
        """
        ret = self.arguments.get(name, None)
        if ret is not None:
            return ret

        v = self.variables[name]
        return SglVariable(v.name, v.source)

    def flatten_nodes(self):
        """Return the recorded nodes as a flat list, expanding ``SglExprList`` nodes.

        Nested expression lists are expanded depth-first, preserving order.
        """
        ret = []

        def traverse(cur):
            if isinstance(cur, SglExprList):
                for child in cur.expr_list:
                    traverse(child)
            else:
                ret.append(cur)

        for x in self.nodes:
            traverse(x)
        return ret

    def __del__(self):
        # Intentionally a no-op.
        # NOTE(review): presumably this suppresses cleanup done by
        # ProgramState.__del__; confirm against the base class.
        pass


class TracingScope:
    """Context manager that tracks the active tracer state.

    The innermost entered scope is kept in the class attribute ``cur_scope``.
    Each scope remembers, in ``last_scope``, the scope that was current when it
    was constructed, and restores it on exit.
    """

    cur_scope = None

    def __init__(self, tracer_state: TracerProgramState):
        # The enclosing scope is captured at construction time, not on enter.
        self.last_scope = TracingScope.cur_scope
        self.tracer_state = tracer_state

    def __enter__(self):
        """Make this scope the current one and return it."""
        TracingScope.cur_scope = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the scope captured at construction; exceptions propagate."""
        TracingScope.cur_scope = self.last_scope

    @staticmethod
    def get_current_scope():
        """Return the current scope, or ``None`` when no scope is active."""
        return TracingScope.cur_scope

    def _scope_chain(self):
        """Yield this scope followed by each enclosing scope, innermost first."""
        scope = self
        while scope is not None:
            yield scope
            scope = scope.last_scope

    def add_child_state(self, state: TracerProgramState):
        """Append ``state`` to ``child_states`` of this and every enclosing scope's tracer."""
        for scope in self._scope_chain():
            scope.tracer_state.child_states.append(state)
