from abc import abstractmethod
from contextlib import contextmanager
from ctypes import CDLL, c_int, create_string_buffer
import gymnasium as gym
import os
from os.path import dirname
import sys

from mas_sat.env.kissat.data_structure import \
    LP_c_char, LP_LP_c_char, \
    KissatState, LP_KissatState
from mas_sat.env.kissat.observation import state_to_observation
from mas_sat.env.kissat.space import KissatSpace
from mas_sat.env.kissat.utils import parse_kissat_log

class KissatBaseEnv(gym.Env):
    """Base Gymnasium environment driving the kissat shared library via ctypes.

    The class loads ``libkissat.so``, starts a solver run on a CNF file taken
    from ``dataset`` in :meth:`reset`, and converts the solver state returned
    by the library into observations.  While library entry points run, the
    process-level stdout/stderr descriptors are redirected into a per-process
    log file, which :meth:`close` hands to ``parse_kissat_log``.

    Subclasses are expected to provide :meth:`get_dummy_action` and
    :meth:`_is_truncated`.
    NOTE(review): ``@abstractmethod`` is only enforced when the class uses an
    ``ABCMeta``-derived metaclass; confirm whether ``gym.Env`` provides one.
    """

    def __init__(self, dataset, args):
        """Store hyperparameters, load the solver library and save stdio fds.

        Args:
            dataset: Object exposing ``get(idx)`` (path of a CNF file) and
                ``len()``.
            args: Namespace-like object with the attributes ``prop_limit``,
                ``step_limit``, ``prop_weight``, ``step_weight`` and
                ``penalty``.
        """
        super().__init__()

        # components
        self._dataset = dataset

        # hyperparameters
        self._prop_limit = args.prop_limit
        self._step_limit = args.step_limit
        self._prop_weight = args.prop_weight
        self._step_weight = args.step_weight
        self._penalty = args.penalty

        # load the kissat shared library, located four directory levels above
        # this file under kissat/build/
        pypath = os.path.abspath(__file__)
        libpath = os.path.join(
            dirname(dirname(dirname(dirname(pypath)))),
            "kissat", "build", "libkissat.so"
        )
        self._lib = CDLL(libpath)

        # specify the argument types
        self._lib.main_reset.argtypes = [c_int, LP_LP_c_char]
        self._lib.main_reset.restype = LP_KissatState
        self._lib.main_close.argtypes = []
        self._lib.main_close.restype = c_int

        # spaces
        self.observation_space = KissatSpace()

        # file descriptors: keep duplicates of the real stdout/stderr so they
        # can be restored after each redirected library call
        self._log_name = "logs/{}.log".format(os.getpid())
        self._stdout_fd = sys.stdout.fileno()
        self._stderr_fd = sys.stderr.fileno()
        self._stdout_dup_fd = os.dup(self._stdout_fd)
        self._stderr_dup_fd = os.dup(self._stderr_fd)

    def __del__(self):
        """Release the duplicated descriptors and delete the log file.

        Every attribute is looked up defensively because the finalizer also
        runs when ``__init__`` raised before the attribute was assigned
        (for example when the shared library could not be loaded).
        """
        self._close_log_file()
        for attr in ("_stdout_dup_fd", "_stderr_dup_fd"):
            fd = getattr(self, attr, None)
            if fd is not None:
                os.close(fd)
        log_name = getattr(self, "_log_name", None)
        if log_name is not None and os.path.isfile(log_name):
            os.remove(log_name)

    def _close_log_file(self):
        """Close the currently open log file object, if there is one."""
        log_f = getattr(self, "_log_f", None)
        if log_f is not None and not log_f.closed:
            log_f.close()

    @contextmanager
    def _redirect_output(self):
        """Send fd-level stdout/stderr to the log file inside the block.

        The original descriptors are restored on exit, even if the body
        raises.  Requires :meth:`reset` to have opened the log file.
        """
        try:
            # Flush Python's own buffers first; otherwise text written before
            # the redirect would be emitted later into the solver log.
            sys.stdout.flush()
            sys.stderr.flush()
            # redirect stdout and stderr to log file
            os.dup2(self._log_fd, self._stdout_fd)
            os.dup2(self._log_fd, self._stderr_fd)
            yield
        finally:
            # restore stdout and stderr
            os.dup2(self._stdout_dup_fd, self._stdout_fd)
            os.dup2(self._stderr_dup_fd, self._stderr_fd)

    def reset(self, idx=None, seed=None, options=None) -> tuple[dict, dict]:
        """Start a new solver run and return the initial observation.

        Args:
            idx: Forwarded to ``dataset.get`` to select the CNF file.
            seed: Forwarded to ``gym.Env.reset``.
            options: Forwarded to ``gym.Env.reset``.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        super().reset(seed=seed, options=options)

        # open the log file; close a handle left over from a previous episode
        # so repeated resets do not leak file objects, and make sure the
        # target directory exists
        self._close_log_file()
        os.makedirs(dirname(self._log_name), exist_ok=True)
        self._log_f = open(self._log_name, "w")
        self._log_fd = self._log_f.fileno()

        # call reset with a C-style argv; the extra slot stays NULL as the
        # terminating entry
        cnf_path = self._dataset.get(idx)
        args = ["kissat", "--compact=false", "--profile=0", cnf_path]
        argc = len(args)
        argv = (LP_c_char * (argc + 1))()
        for i, arg in enumerate(args):
            argv[i] = create_string_buffer(arg.encode("utf-8"))
        with self._redirect_output():
            ret = self._lib.main_reset(argc, argv).contents
        observation = state_to_observation(ret, original=True)

        # set internal states
        self._propagations = ret.propagations
        self._steps = 0
        self._has_budget = True

        return observation, {}

    def _post_step(self, state: KissatState) -> tuple[dict, float, bool, bool, dict]:
        """Turn the solver state after a step into a Gymnasium step result.

        The reward is the negated sum of the per-step weight and the weighted
        number of propagations since the previous call; when the episode is
        truncated (and not terminated) ``penalty`` is subtracted as well.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.  On
            termination the observation is an empty dict.
        """
        reward = -(
            self._step_weight
            + self._prop_weight * (state.propagations - self._propagations)
        )
        self._propagations = state.propagations
        self._steps += 1

        # test if terminated
        terminated = self._is_terminal(state)
        if terminated:
            return {}, reward, terminated, False, {}

        # test if truncated
        truncated = self._is_truncated()
        if truncated:
            reward -= self._penalty

        # post processing
        observation = state_to_observation(state)
        return observation, reward, terminated, truncated, {}

    def close(self) -> tuple[bool|None, dict]:
        """Shut the solver down and return the parsed contents of its log.

        Returns:
            Whatever ``parse_kissat_log`` returns for this process's log file.
        """
        # close the solver and log file; the library's return code is not used
        with self._redirect_output():
            self._lib.main_close()
        self._close_log_file()
        return parse_kissat_log(self._log_name)

    def len(self) -> int:
        """Return ``dataset.len()``."""
        return self._dataset.len()

    @abstractmethod
    def get_dummy_action(self):
        """Return a placeholder action; to be implemented by subclasses."""
        pass

    # endpoint test
    def _is_terminal(self, state: KissatState) -> bool:
        """Return True when ``state.literal_num`` is zero or negative."""
        return state.literal_num <= 0

    @abstractmethod
    def _is_truncated(self) -> bool:
        """Return whether the episode must be cut short; subclass hook."""
        pass

    def out_of_budget(self):
        """Clear the ``_has_budget`` flag that :meth:`reset` sets to True.

        NOTE(review): the flag is not read in this class; presumably subclass
        ``_is_truncated`` implementations consult it.
        """
        self._has_budget = False
