from abc import abstractmethod
import os
import random

from mas_sat.env.kissat.utils import parse_kissat_log, solve_kissat_mp

class CNFDataset(object):
    def __init__(self, args):
        self.root = os.path.join(args.data_dir, args.dataset, args.set)
        self.test()
        self.suffix = ".cnf"

    # basic get/set APIs
    def reset(self):
        self.counter = 0
        self.finish = False

    def train(self):
        self.mode = "train"

    def valid(self):
        self.mode = "valid"
        self.reset()

    def test(self):
        self.mode = "test"
        self.reset()

    def len(self):
        return self.len_dict[self.mode]
    
    def is_finish(self):
        return self.finish
    
    # main APIs
    @abstractmethod
    def generate(self):
        pass

    def solve(self):
        results_dict = {}
        print("="*20, "Solve", "="*20)
        for subset in ["train", "valid", "test"]:
            print("-"*20, f"Solve for {subset}", "-"*20)
            fnames = [os.path.join(self.root, f)
                      for f in self.cnf_dict[subset]]
            results = solve_kissat_mp(fnames, 32)
            results_dict[subset] = results
            n_sat = 0
            n_unsat = 0
            for result in results:
                if result is None:
                    continue
                elif result:
                    n_sat += 1
                else:
                    n_unsat += 1
            print("summary: {}/{} sat/unsat instances".format(n_sat, n_unsat))
        return results_dict

    def get(self, idx: int|None=None):
        if self.mode == "train":
            # by default, randomly sample train instances
            if idx is None:
                idx = random.randint(0, self.len()-1)
        else:
            # by default, sequentially sample valid/test instances
            if idx is None:
                idx = self.counter
                self.counter += 1
                if self.counter >= self.len():
                    self.counter = 0
                    self.finish = True
        cnf_fname = self.cnf_dict[self.mode][idx]
        fname = os.path.join(self.root, cnf_fname)
        return fname