from ctypes import c_uint, c_int
from gymnasium import spaces
import numpy as np
from typing import Any, SupportsFloat

from mas_sat.env.kissat.data_structure import LP_c_uint, LP_KissatState
from mas_sat.env.kissat.base import KissatBaseEnv

class KissatDecideEnv(KissatBaseEnv):
    """Kissat environment in which each step forwards an action to ``main_decide_step``.

    The native library (``self._lib``, set up by ``KissatBaseEnv``) returns a
    pointer to a ``KissatState``, which is handed to ``self._post_step`` to
    build the step result.
    """

    def __init__(self, dataset, args) -> None:
        """Configure the decide mode/interval and the ctypes signatures.

        Args:
            dataset: Forwarded unchanged to ``KissatBaseEnv.__init__``.
            args: Run configuration; this class reads ``args.agent`` and
                ``args.decide_interval``.
        """
        super().__init__(dataset, args)
        # Mode passed to main_decide_step after the first step: 1 for the
        # "solver" agent, 2 for every other agent.
        # NOTE(review): "model_influence" used to be listed alongside "solver"
        # (it was commented out); confirm the intended mapping.
        if args.agent in ["solver"]:
            self.decide_mode = 1
        else:
            self.decide_mode = 2
        self.decide_interval = args.decide_interval

        # specify the argument types
        self._lib.main_influence.argtypes = [LP_c_uint, c_uint]
        self._lib.main_influence.restype = None
        self._lib.main_decide_step.argtypes = [c_uint, c_int, c_int]
        self._lib.main_decide_step.restype = LP_KissatState

        # spaces
        # Scalar bounds without an explicit shape: gymnasium infers shape (1,)
        # here, so sampled actions are size-1 arrays; step() unwraps them.
        self.action_space = spaces.Box(low=0, high=2**32-1, dtype=np.uint32)

    @staticmethod
    def _to_c_action(action) -> int:
        """Return ``action`` as a plain Python int for the ``c_uint`` argument.

        Accepts Python ints, NumPy integer scalars and 0-d or size-1 integer
        arrays (such as samples from ``action_space``).

        Raises:
            ValueError: if ``action`` holds more than one element.
            TypeError: if the unwrapped value is not an integer.
        """
        # .item() unwraps NumPy scalars and size-1 arrays into Python scalars
        # and raises ValueError for arrays with more than one element.
        value = np.asarray(action).item()
        if not isinstance(value, int):
            raise TypeError(
                f"action must be an integer, got {type(value).__name__}: {value!r}"
            )
        return value

    def step(self, action) -> tuple[dict, float, bool, bool, dict]:
        """Advance the native solver by one decide step.

        The first call (``self._steps == 0``) is a dummy call that uses mode 0
        and interval 1. Later calls use ``self.decide_mode``. The interval is
        ``self.decide_interval`` when ``self._has_budget`` is truthy and -1
        otherwise.

        Args:
            action: Integer action, or a size-1 integer array as sampled from
                ``self.action_space``.

        Returns:
            The result of ``self._post_step`` for the returned state.

        Raises:
            TypeError: if ``action`` is not integral.
            ValueError: if ``action`` has more than one element, or the native
                call returns a NULL state pointer.
        """
        if self._steps == 0:
            mode = 0  # dummy call
            interval = 1
        else:
            mode = self.decide_mode
            if self._has_budget:
                interval = self.decide_interval
            else:
                interval = -1

        # Convert before calling into C: ctypes rejects ndarray arguments for
        # a c_uint parameter with an opaque ArgumentError.
        c_action = self._to_c_action(action)
        with self._redirect_output():
            state_ptr = self._lib.main_decide_step(c_action, mode, interval)
        # ctypes NULL pointers are falsy; fail with a clear message instead of
        # a bare "NULL pointer access" from .contents.
        if not state_ptr:
            raise ValueError("main_decide_step returned a NULL KissatState pointer")
        return self._post_step(state_ptr.contents)

    def influence(self, indices, num):
        """Forward ``indices`` and ``num`` to the native ``main_influence`` routine.

        ``indices`` must be accepted by ctypes for the ``LP_c_uint`` argument
        (presumably a ``c_uint`` pointer or array; TODO confirm). ``num`` is
        passed as ``c_uint`` and is presumably the number of entries in
        ``indices``; verify against the C side.
        """
        self._lib.main_influence(indices, num)

    def get_dummy_action(self):
        """Return the placeholder action 0."""
        return 0

    def _is_truncated(self) -> bool:
        """Return True once a positive propagation or step limit is exceeded.

        A limit that is <= 0 disables the corresponding check.
        """
        if self._prop_limit > 0 and self._propagations > self._prop_limit:
            return True
        if self._step_limit > 0 and self._steps > self._step_limit:
            return True
        return False

    # def compare(self, solver_step, propagation, model_step):
    #     if "decisions" in self.metadata and solver_step is not None:
    #         solver_step_reduction = self.metadata["decisions"] / solver_step
    #     else:
    #         solver_step_reduction = 1
        
    #     if "propagations" in self.metadata and propagation is not None:
    #         propagation_reduction = self.metadata["propagations"] / propagation
    #     else:
    #         propagation_reduction = 1

    #     if "model_step" in self.metadata and model_step is not None:
    #         model_step_reduction = self.metadata["model_step"] / model_step
    #     else:
    #         model_step_reduction = 1

    #     return solver_step_reduction, propagation_reduction, model_step_reduction
