import numpy as np
from numpy.ctypeslib import as_array
from mas_sat.env.kissat.data_structure import KissatState

def _buffer_to_array(buffer, count, empty_dtype):
    """Return the first ``count`` elements of ``buffer`` as a NumPy array.

    ``buffer.begin`` is read with ``numpy.ctypeslib.as_array``, so the result
    shares memory with the underlying buffer instead of copying it. When
    ``count`` is 0 the pointer is never touched and an empty ``empty_dtype``
    array is returned: a zero-length buffer may hold a NULL pointer, and
    dereferencing a NULL ctypes pointer raises ``ValueError``.
    """
    if count == 0:
        return np.zeros(0, dtype=empty_dtype)
    return as_array(buffer.begin, shape=[count])


def state_to_observation(state: "KissatState", original=False) -> dict:
    """Build an observation dict from a Kissat solver state.

    Args:
        state: Solver state exposing the counts ``literal_num``,
            ``clause_num`` and ``edge_num``, plus buffers whose ``.begin``
            pointers are read with ``numpy.ctypeslib.as_array``.
        original: If True, the per-literal and per-clause buffers of
            ``state`` are not read; placeholders are used instead (every
            literal is a candidate, literal values are 0, clause glues and
            refs are 0). Edge indices are read from ``state`` in both modes.

    Returns:
        Dict with keys ``literal_candidates``, ``literal_values``,
        ``clause_refs``, ``clause_glues``, ``literal_indices`` and
        ``clause_indices``. Each value has a trailing axis added by
        ``np.expand_dims``, i.e. shape ``(count, 1)``.

    NOTE(review): arrays read from ``state`` are views into the solver's
    buffers (``as_array`` does not copy). If the solver later mutates or
    frees those buffers, callers that keep observations should copy them.
    """
    # decode numbers
    n_literal = state.literal_num
    n_clause = state.clause_num
    n_edge = state.edge_num

    # literal and clause features
    if original:
        # Placeholder features; the solver's literal/clause buffers are unused.
        literal_candidates = np.ones(n_literal, dtype=np.bool_)
        literal_values = np.zeros(n_literal, dtype=np.int8)
        clause_glues = np.zeros(n_clause, dtype=np.uint32)
        clause_refs = np.zeros(n_clause, dtype=np.uint32)
    else:
        literal_candidates = _buffer_to_array(state.literal_candidates, n_literal, np.bool_)
        literal_values = _buffer_to_array(state.literal_values, n_literal, np.int8)
        clause_glues = _buffer_to_array(state.clause_glues, n_clause, np.uint32)
        clause_refs = _buffer_to_array(state.clause_refs, n_clause, np.uint32)

    # edges: with no clauses the edge buffers are not read at all; otherwise
    # edge_num decides, so a zero-length edge buffer is never dereferenced.
    n_edge_read = 0 if n_clause == 0 else n_edge
    literal_indices = _buffer_to_array(state.literal_indices, n_edge_read, np.uint32)
    clause_indices = _buffer_to_array(state.clause_indices, n_edge_read, np.uint32)

    observation = {
        "literal_candidates": np.expand_dims(literal_candidates, 1),
        "literal_values": np.expand_dims(literal_values, 1),
        "clause_refs": np.expand_dims(clause_refs, 1),
        "clause_glues": np.expand_dims(clause_glues, 1),
        "literal_indices": np.expand_dims(literal_indices, 1),
        "clause_indices": np.expand_dims(clause_indices, 1),
    }
    return observation
