from typing import Callable, Generic, Optional, Sequence, TypedDict, TypeVar

import numpy as np


class CamlRng:
    """
    Opaque random-number-generator handle.

    Instances are taken as the `rng` argument by `init_teacher`,
    `init_teacher_with_spec` and `pretraining_tasks_sampler`.
    Presumably wraps a random state living on the OCaml side
    (going by the name) -- TODO confirm.
    """
    # `seed` is keyword-only and optional. What happens when it is omitted
    # (e.g. self-seeding vs. a fixed default) is not visible from this stub.
    def __init__(self, *, seed: Optional[list[int]] = None) -> None: ...

class Prog:
    """
    Opaque program object.

    Used as the input of `init_solver` and as the field type of
    `TeacherResult`.
    """
    # Builds a program from a string; presumably the program's source
    # text, which would then be parsed -- TODO confirm accepted syntax.
    def __init__(self, prog: str) -> None: ...
    def __repr__(self) -> str: ...
    def __str__(self) -> str: ...
    # Returns a `Prog`; whether this is a fresh object or `self` and what
    # "normalizing a task" entails are not visible from this stub.
    def normalize_task(self) -> 'Prog': ...

class TokenizerConfig(TypedDict):
    """
    Tokenizer options, passed as a plain dict to `Graphable.tensorize`.
    """
    # Presumably toggles edges that encode numerical information
    # -- TODO confirm against the tokenizer implementation.
    enable_numerical_edges: bool
    # Presumably adds a reversed copy of each edge -- TODO confirm.
    add_reverse_edges: bool

class TensorizerConfig(TypedDict):
    """
    Size parameters for tensorization.

    Passed as a plain dict to `Graphable.tensorize`,
    `token_encoding_size` and `uid_encoding_offset`.
    """
    # Presumably the model width; the `GraphTensors` docstring gives
    # `pos_emb` the shape (num_toks, d_model) -- TODO confirm they match.
    d_model: int
    # Presumably the size of the positional encoding -- TODO confirm.
    pos_enc_size: int
    # Presumably the size of the uid embedding -- TODO confirm.
    uid_emb_size: int
    # Presumably the size of the constant embedding -- TODO confirm.
    const_emb_size: int

class GraphTensors(TypedDict):
    """
    A representation of a Graphable, as exported by the OCaml
    Looprl library. In this stub the tensors are annotated as
    NumPy arrays (np.ndarray).

    NOTE(review): this docstring used to say the tensors are of type
    torch.Tensor instead of np.array, which contradicts the annotations
    below. Presumably that wording was copied from a torch-based
    counterpart of this type defined elsewhere -- TODO confirm.

    Tensor fields:
        - nodes: shape (num_toks, num_features)
        - edges: shape (num_edges, 3), contains (src, dst, typ) triples
        - pos_emb: shape (num_toks, d_model)

    We use a dictionary representation to avoid bad surprises
    with ray serialization.
    """
    nodes: np.ndarray
    edges: np.ndarray
    pos_emb: np.ndarray

class UidMap:
    """
    Opaque uid mapping threaded through `Graphable.tensorize`.

    `tensorize` takes a `UidMap` and returns one alongside the tensors.
    What the map contains is not visible from this stub.
    """
    # Takes no arguments; presumably yields an empty map -- TODO confirm.
    def __init__(self) -> None: ...
    def __repr__(self) -> str: ...

class Graphable:
    """
    Opaque object that can be converted to a graph / tensor form.

    Returned by `SearchTree.probe`, `SearchTree.choices`, the
    `unserialize_*_probe` / `unserialize_*_action` / `unserialize_formula`
    functions and the sampler built by `pretraining_tasks_sampler`.
    """
    def __repr__(self) -> str: ...
    # String-to-string metadata; the set of keys is not visible here.
    def meta(self) -> dict[str, str]: ...
    # Textual rendering of the graph; the format is not visible here.
    def graph(self) -> str: ...
    # Serialized form as a string; presumably an s-expression, by analogy
    # with the `sexp` parameters of the `unserialize_*` functions
    # -- TODO confirm.
    def serialize(self) -> str: ...
    # Converts this object to `GraphTensors` under the given configs.
    # Takes a `UidMap` and returns one next to the tensors; whether the
    # input map is mutated or a new one is returned is not visible here.
    def tensorize(
        self,
        tensorizer_config: TensorizerConfig,
        tokenizer_config: TokenizerConfig,
        uids: UidMap
    ) -> tuple[GraphTensors, UidMap]: ...

# Type of the value carried by a successful `SearchTree`. Declared
# covariant; in `SearchTree` it only appears in return positions and in
# explicit `self` annotations.
T = TypeVar("T", covariant=True)

class SearchTree(Generic[T]):
    """
    Opaque search-tree node, generic in the type T of its success value.

    Built by `init_solver`, `init_teacher`, `init_teacher_with_spec` and
    the `unserialize_teacher` / `unserialize_solver` functions.

    NOTE(review): presumably each node is exactly one of the kinds tested
    by the `is_*` predicates, and each accessor is only valid on the
    matching kind -- TODO confirm, and confirm what happens otherwise.
    """
    def __repr__(self) -> str: ...
    # Node-kind predicates.
    def is_choice(self) -> bool: ...
    def is_chance(self) -> bool: ...
    def is_failure(self) -> bool: ...
    def is_success(self) -> bool: ...
    def is_message(self) -> bool: ...
    def is_event(self) -> bool: ...
    # A single Graphable attached to the node; presumably the state
    # shown to the agent at a choice node -- TODO confirm.
    def probe(self) -> Graphable: ...
    # Candidate actions as Graphables; presumably index-aligned with the
    # `i` accepted by `select` -- TODO confirm.
    def choices(self) -> Sequence[Graphable]: ...
    # Sequence of floats; presumably the outcome weights of a chance
    # node -- TODO confirm, including whether they are normalized.
    def weights(self) -> Sequence[float]: ...
    # Returns the subtree for index `i`, with the same success type T.
    def select(self: 'SearchTree[T]', i: int) -> 'SearchTree[T]': ...
    # The value of type T carried by the node.
    def success_value(self) -> T: ...
    # Integer code; presumably an index into `AgentSpec.event_names`
    # -- TODO confirm.
    def event_code(self) -> int: ...
    def failure_message(self) -> str: ...
    # Integer code; presumably an index into `AgentSpec.outcome_names`
    # -- TODO confirm.
    def failure_code(self) -> int: ...
    def message(self) -> str: ...
    # Returns a tree with the same success type T; presumably the
    # successor of a message or event node -- TODO confirm.
    def next(self: 'SearchTree[T]') -> 'SearchTree[T]': ...
    # Serialized form as a string; presumably the s-expression accepted
    # by `unserialize_teacher` / `unserialize_solver` -- TODO confirm.
    def serialize(self) -> str: ...

class AgentSpec(TypedDict):
    """
    Description of an agent's events, outcomes and rewards.

    `teacher_spec` and `solver_spec` are declared with this type.

    NOTE(review): the `event_*` lists are presumably index-aligned with
    each other, and likewise the `outcome_*` lists; the `*_code` fields
    are presumably indices into `outcome_names` -- TODO confirm.
    """
    event_names: list[str]
    outcome_names: list[str]
    event_rewards: list[float]
    outcome_rewards: list[float]
    # Key is spelled "occurences" (sic). It is part of the dict schema,
    # so it must not be renamed here without changing the producer.
    event_max_occurences: list[int]
    success_code: int
    default_failure_code: int
    size_limit_exceeded_code: int
    min_success_reward: float

# Module-level `AgentSpec` values, declared here without an assigned value;
# presumably one describes the teacher agent and one the solver agent
# (going by the names) -- TODO confirm.
teacher_spec: AgentSpec
solver_spec: AgentSpec

class TeacherResult(TypedDict):
    """
    Success value of the teacher search trees.

    This is the type parameter of the trees returned by `init_teacher`
    and `init_teacher_with_spec`.
    """
    # Presumably the final problem produced by the teacher -- TODO confirm.
    problem: Prog
    # Presumably the same problem before some processing step; what that
    # step is cannot be told from this stub -- TODO confirm.
    nonprocessed: Prog

# Builds a search tree from `prog` whose success value is a `Prog`;
# presumably the root of the solver's search on that program
# -- TODO confirm.
def init_solver(prog: Prog) -> SearchTree[Prog]: ...
# Builds a search tree from `rng` whose success value is a
# `TeacherResult`; presumably the root of the teacher's search
# -- TODO confirm.
def init_teacher(rng: CamlRng) -> SearchTree[TeacherResult]: ...
# Same return type as `init_teacher`, with an extra `spec_sexp` string;
# presumably an s-expression describing the problem specification to use
# -- TODO confirm the expected format.
def init_teacher_with_spec(
    rng: CamlRng, spec_sexp: str) -> SearchTree[TeacherResult]: ...

# Builds a `SearchTree` from the string `sexp`; presumably the inverse of
# `SearchTree.serialize` for teacher trees -- TODO confirm.
# NOTE(review): the return type is unparameterized (implicitly
# `SearchTree[Any]`); presumably it should be `SearchTree[TeacherResult]`
# to mirror `init_teacher` -- confirm before tightening.
def unserialize_teacher(sexp: str) -> SearchTree: ...
# Builds a `SearchTree` from the string `sexp`; presumably the inverse of
# `SearchTree.serialize` for solver trees -- TODO confirm.
# NOTE(review): the return type is unparameterized (implicitly
# `SearchTree[Any]`); presumably it should be `SearchTree[Prog]` to mirror
# `init_solver` -- confirm before tightening.
def unserialize_solver(sexp: str) -> SearchTree: ...
# Builds a `Graphable` from the string `sexp`; presumably a serialized
# teacher probe as produced by `Graphable.serialize` -- TODO confirm.
def unserialize_teacher_probe(sexp: str) -> Graphable: ...
# Builds a `Graphable` from the string `sexp`; presumably a serialized
# teacher action as produced by `Graphable.serialize` -- TODO confirm.
def unserialize_teacher_action(sexp: str) -> Graphable: ...
# Builds a `Graphable` from the string `sexp`; presumably a serialized
# solver probe as produced by `Graphable.serialize` -- TODO confirm.
def unserialize_solver_probe(sexp: str) -> Graphable: ...
# Builds a `Graphable` from the string `sexp`; presumably a serialized
# solver action as produced by `Graphable.serialize` -- TODO confirm.
def unserialize_solver_action(sexp: str) -> Graphable: ...
# Builds a `Graphable` from the string `sexp`; presumably a serialized
# logical formula -- TODO confirm the expected format.
def unserialize_formula(sexp: str) -> Graphable: ...

# Integer derived from the tensorizer config; presumably the per-token
# feature width, i.e. `num_features` in the `GraphTensors.nodes` shape
# -- TODO confirm.
def token_encoding_size(config: TensorizerConfig) -> int: ...
# Integer derived from the tensorizer config; presumably the index within
# a token's feature vector where the uid encoding starts -- TODO confirm.
def uid_encoding_offset(config: TensorizerConfig) -> int: ...
# Takes no arguments and returns an integer; presumably the number of
# distinct `typ` values that can appear in `GraphTensors.edges`
# -- TODO confirm (e.g. whether reverse edges are counted).
def num_edge_types() -> int: ...

# Returns a zero-argument callable; each call yields a triple of
# `Graphable`s. The role of each element of the triple and the meaning of
# the `true_false` flag are not visible from this stub -- TODO document
# once confirmed against the OCaml side.
def pretraining_tasks_sampler(
    rng: CamlRng,
    true_false: bool
) -> Callable[[], tuple[Graphable, Graphable, Graphable]]: ...
