import os
import random
import re

from mas_sat.dataset.cnf import CNFDataset

# sort the cnf files according to numbers in them
def numerical_key(value):
    """Return every run of digits in *value*, in order, as a tuple of ints.

    Used as a sort key so that names are ordered by the numbers embedded in
    them rather than lexicographically. A string with no digits yields an
    empty tuple.
    """
    digit_runs = re.findall(r'\d+', value)
    return tuple(int(run) for run in digit_runs)

class SATLIBDataset(CNFDataset):
    """CNF files of one SATLIB domain/set, partitioned into train/valid/test.

    The files are read from ``<args.data_dir>/satlib/<args.domain>/<args.set>``
    and ordered by the numbers embedded in their names. Sets with more than
    ``2 * _HOLDOUT_SIZE`` files keep the last ``_HOLDOUT_SIZE`` files for
    "test", the ``_HOLDOUT_SIZE`` before them for "valid", and everything
    else for "train". Smaller sets are used entirely as "test".

    Attributes set here:
        root: directory the ``.cnf`` files are listed from.
        suffix: file extension of the instances (``".cnf"``).
        len_dict: number of files in each of "train", "valid", "test".
        cnf_dict: sorted file names in each of "train", "valid", "test".
    """

    # Number of instances held out for each of the valid and test splits.
    _HOLDOUT_SIZE = 100

    def __init__(self, args):
        """Collect and split the ``.cnf`` files of the requested SATLIB set.

        Args:
            args: object providing ``data_dir``, ``domain`` and ``set``
                attributes, which together locate the directory to read.
        """
        self.root = os.path.join(args.data_dir, "satlib", args.domain, args.set)
        # NOTE(review): test() is not defined in this file; presumably it is
        # inherited from CNFDataset. It is invoked after root is set and
        # before suffix is set, as in the original ordering.
        self.test()
        self.suffix = ".cnf"

        cnf_list = sorted(
            (name for name in os.listdir(self.root) if name.endswith(self.suffix)),
            key=numerical_key,
        )
        total = len(cnf_list)
        holdout = self._HOLDOUT_SIZE

        # Only split when both held-out splits can be filled and at least one
        # training file remains. A lower threshold would make the train count
        # negative, and the negative slice bounds would then yield overlapping
        # or empty splits that disagree with len_dict.
        if total > 2 * holdout:
            n_train, n_valid, n_test = total - 2 * holdout, holdout, holdout
        else:
            n_train, n_valid, n_test = 0, 0, total

        self.len_dict = {
            "train": n_train,
            "valid": n_valid,
            "test": n_test,
        }
        valid_end = n_train + n_valid
        self.cnf_dict = {
            "train": cnf_list[:n_train],
            "valid": cnf_list[n_train:valid_end],
            "test": cnf_list[valid_end:],
        }
