from abc import abstractmethod
from cnfgen import RandomKCNF, GraphColoringFormula, CliqueFormula
import math
import networkx as nx
from shutil import copy2
import os
from tqdm import tqdm

from mas_sat.dataset.cnf import CNFDataset

class CNFGenDataset(CNFDataset):
    """
    CNF Dataset generated by CNFGen.

    Subclasses provide :meth:`generate_instance`; :meth:`generate` then
    writes the instances to disk, solves them via ``self.solve()`` and
    sorts the solved instances into satisfiable / unsatisfiable
    directories.
    """

    def __init__(self, args):
        """
        Set up the per-subset instance counts and CNF file lists.

        Args:
            args: parsed options. This method reads ``num_train``,
                ``num_valid``, ``num_test``, ``command`` and ``split``;
                the base class may read more.
        """
        super().__init__(args)
        self.len_dict = {
            "train": args.num_train,
            "valid": args.num_valid,
            "test": args.num_test
        }
        if args.command == "generate":
            # generate 3x to make sure have enough sat/unsat instances
            self.len_dict = {k: v*3 for k, v in self.len_dict.items()}
        # File names are stored relative to the dataset root
        # (generate() joins them with self.root).
        self.cnf_dict = {
            k: [f"cnf_{args.split}/{k}/{i}.cnf" for i in range(v)]
            for k, v in self.len_dict.items()
        }

    @abstractmethod
    def generate_instance(self):
        """
        Return one freshly generated formula.

        The returned object must provide ``to_dimacs()``, which
        :meth:`generate` uses to serialise it.
        """
        pass

    def generate(self):
        """
        Generate, solve and sort the instances of every subset.

        Step 1 writes ``len_dict[subset]`` instances to
        ``<root>/cnf_both/<subset>/<idx>.cnf`` and resets
        ``cnf_dict[subset]`` to those (root-relative) paths.
        Step 2 calls ``self.solve()``.
        Step 3 copies every solved instance, together with its sibling
        ``.log`` file, into ``<root>/cnf_sat/<subset>`` or
        ``<root>/cnf_unsat/<subset>``, renumbering from 0 in each
        directory.
        """
        # step 1: generate
        print("="*20, "Generate", "="*20)
        for subset in ["train", "valid", "test"]:
            self.cnf_dict[subset] = []
            num = self.len_dict[subset]
            cnf_dir = os.path.join(self.root, "cnf_both", subset)
            # exist_ok avoids the check-then-create race of an isdir() test
            os.makedirs(cnf_dir, exist_ok=True)
            for idx in tqdm(range(num), desc=f"Generating {subset}"):
                formula = self.generate_instance()
                cnf_fname = os.path.join(cnf_dir, f"{idx}.cnf")
                self.cnf_dict[subset].append(f"cnf_both/{subset}/{idx}.cnf")
                with open(cnf_fname, "w") as f:
                    f.write(formula.to_dimacs())

        # step 2: solve
        # NOTE(review): solve() is defined elsewhere; it is used here as a
        # mapping subset -> sequence of results aligned with
        # cnf_dict[subset], and is presumably what writes the ".log"
        # file next to each ".cnf" -- confirm against its implementation.
        results_dict = self.solve()

        # step 3: copy (separate the sat/unsat instances)
        print("="*20, "Copy", "="*20)
        for subset, results in results_dict.items():
            fnames = [os.path.join(self.root, f)
                      for f in self.cnf_dict[subset]]
            # Destination directory and running output index, keyed by
            # whether the instance was reported satisfiable.
            out_dirs = {
                True: os.path.join(self.root, "cnf_sat", subset),
                False: os.path.join(self.root, "cnf_unsat", subset),
            }
            for out_dir in out_dirs.values():
                os.makedirs(out_dir, exist_ok=True)
            counts = {True: 0, False: 0}
            for result, fname in tqdm(zip(results, fnames),
                                      total=len(fnames),
                                      desc=f"Copying {subset}"):
                if result is None:
                    # no verdict for this instance: leave it in cnf_both only
                    continue
                is_sat = bool(result)
                dst_fname = os.path.join(
                    out_dirs[is_sat], f"{counts[is_sat]}.cnf")
                copy2(fname, dst_fname)
                # The log shares the instance's base name; splitext keeps
                # this independent of the extension's length.
                copy2(os.path.splitext(fname)[0] + ".log",
                      os.path.splitext(dst_fname)[0] + ".log")
                counts[is_sat] += 1

class RandomSATDataset(CNFGenDataset):
    """
    Dataset of random k-CNF formulas built with ``RandomKCNF``.

    ``args.set`` is a dash-separated triple of integers that is parsed
    into ``k``, ``n_variable`` and ``n_clause`` (in that order).
    """

    def __init__(self, args):
        """Parse ``args.set`` into the three parameters, then run the base init."""
        k, n_variable, n_clause = map(int, args.set.split("-"))
        # Attributes are assigned before the base-class initialiser runs,
        # keeping the original ordering.
        self.k = k
        self.n_variable = n_variable
        self.n_clause = n_clause
        super().__init__(args)

    def generate_instance(self):
        """Return a new ``RandomKCNF`` formula for the configured parameters."""
        formula = RandomKCNF(self.k, self.n_variable, self.n_clause)
        return formula
