# Generate random SAT and UNSAT CNF problems with a configurable number of clauses and literals, as training data.
import random
import os
import tempfile
import subprocess
import time 

class CNFProblem:
    """Random CNF problem generator backed by the MiniSat solver.

    On construction two problems with ``num_clauses`` clauses over variables
    ``1..num_literals`` are generated: one from get_unsat_problem() and one
    from get_sat_problem(). A clause is a list of non-zero ints where ``-v``
    denotes the negation of variable ``v``.

    Satisfiability is decided by running ``./minisat`` with working directory
    ``<cwd>/minisat/core/``; that executable must exist before use.
    """

    # Side length of the square matrices produced by the *_as_adj and label
    # methods; missing rows/columns are filled with the value 2.
    PAD_SIZE = 30

    def __init__(self, num_clauses: int, num_literals: int):
        """Generate the problem pair.

        Args:
            num_clauses: number of clauses in each generated problem.
            num_literals: number of distinct variables, numbered from 1.
        """
        self.num_clauses = num_clauses
        self.num_literals = num_literals
        # Index 0: result of get_unsat_problem(); index 1: get_sat_problem().
        self.problem = [self.get_unsat_problem(), self.get_sat_problem()]

    def get_unsat_problem(self):
        """Return ``num_clauses`` random clauses intended to be unsatisfiable.

        First draws purely random clauses. If no prefix becomes UNSAT, it
        falls back to a SAT problem. It then adds one unit clause ``[-l]``
        for every literal ``l`` of that problem's shortest clause and keeps
        the ``num_clauses`` shortest clauses.
        """
        problem = []
        is_sat = True
        for _ in range(self.num_clauses):
            new_clause = self.get_clause()
            # A superset of an UNSAT formula is UNSAT too, so once one prefix
            # fails there is no need to run MiniSat again.
            if is_sat and not self.check_valid_problem(problem + [new_clause]):
                is_sat = False
            problem.append(new_clause)

        if not is_sat:
            return problem

        # Make the shortest clause unsatisfiable by negating each of its
        # literals with a unit clause.
        problem = self.get_sat_problem()
        shortest = min(problem, key=len)
        problem.extend([-literal] for literal in shortest)
        # Stable sort by length, then keep the num_clauses shortest clauses.
        # NOTE(review): truncation can drop the contradicting clauses (e.g.
        # when the shortest clause has >= num_clauses literals), in which
        # case the returned problem is satisfiable.
        problem.sort(key=len)
        del problem[self.num_clauses:]
        return problem

    def get_sat_problem(self):
        """Return ``num_clauses`` distinct random clauses that stay satisfiable.

        Each candidate clause is redrawn until it is not already present and
        MiniSat reports the extended problem as SAT.
        NOTE(review): this loops forever if no such clause exists (e.g. too
        many clauses requested for very few literals).
        """
        problem = []
        for _ in range(self.num_clauses):
            new_clause = self.get_clause()
            # Test the cheap duplicate check first to avoid a solver call.
            while new_clause in problem or not self.check_valid_problem(problem + [new_clause]):
                new_clause = self.get_clause()
            problem.append(new_clause)
        return problem

    def get_clause(self):
        """Return one random clause over variables ``1..num_literals``.

        The clause length is drawn uniformly from ``1..num_literals``; each
        variable appears at most once, with a random sign.
        """
        # Kept in ascending order, so choices match drawing from the original
        # set of the same integers.
        available = list(range(1, self.num_literals + 1))
        length = random.randint(1, self.num_literals)
        clause = []
        for _ in range(length):
            sign = random.choice((-1, 1))
            variable = random.choice(available)
            available.remove(variable)
            clause.append(sign * variable)
        return clause

    def check_valid_problem(self, problem):
        """Return True if MiniSat reports ``problem`` as satisfiable.

        Raises:
            RuntimeError: if MiniSat produced no output at all (e.g. it
                crashed or could not write its result file).
        """
        minisat_output = self.get_minisat_output(problem)
        if not minisat_output:
            raise RuntimeError('MiniSat produced no output; is minisat/core/minisat built?')
        return minisat_output[0] == 'SAT'

    def get_minisat_output(self, problem):
        """Solve ``problem`` with MiniSat and return its result file as lines.

        The DIMACS input and MiniSat's result file are temporary files in
        ``<cwd>/minisat/core/``, which is also the solver's working directory.
        The first returned line is the verdict (compared with 'SAT' by
        check_valid_problem). For a satisfiable problem,
        get_label_assignment parses the second line as space-separated
        signed variables.
        """
        directory = os.path.join(os.getcwd(), 'minisat', 'core')
        minisat_input = self.__problem_to_minisat_input(problem)
        with tempfile.NamedTemporaryFile(dir=directory, suffix='.txt', mode='w+') as infile, \
                tempfile.NamedTemporaryFile(dir=directory, suffix='.txt', mode='w+') as outfile:
            infile.write(minisat_input)
            # Flush explicitly so the child process sees the complete input.
            infile.flush()
            # The exit status is not checked; the verdict is read from the
            # result file instead.
            subprocess.run(
                ['./minisat', infile.name, '-no-luby', '-rinc=1.5',
                 '-phase-saving=0', '-rnd-freq=0.02', outfile.name],
                cwd=directory,
                check=False,
            )
            outfile.seek(0)
            return [line.rstrip('\n') for line in outfile]

    def get_sat_problem_as_adj(self):
        """Return the SAT problem rendered by __as_adj."""
        return self.__as_adj(self.problem[1])

    def get_unsat_problem_as_adj(self):
        """Return the (intended) UNSAT problem rendered by __as_adj."""
        return self.__as_adj(self.problem[0])

    def get_unsat_sat_pair(self):
        """Return ``[unsat_problem, sat_problem]`` as lists of clauses."""
        return self.problem

    def get_label_assignment(self):
        """Return the SAT problem's clause/literal matrix masked by a model.

        Cell ``[c][col]`` is 1 when the literal for column ``col`` occurs in
        clause ``c`` and is true in the model MiniSat returns, else 0. The
        matrix is padded to ``PAD_SIZE`` x ``PAD_SIZE`` with 2s (rows longer
        than ``PAD_SIZE`` are not truncated).

        Raises:
            RuntimeError: if MiniSat does not return a SAT verdict and model.
        """
        sat_problem = self.problem[1]
        minisat_output = self.get_minisat_output(sat_problem)
        if len(minisat_output) < 2 or minisat_output[0] != 'SAT':
            raise RuntimeError('MiniSat did not return a model for the SAT problem')
        # split() (not split(' ')) tolerates repeated or trailing whitespace.
        solution = {int(token) for token in minisat_output[1].split()}
        label = [[0] * (2 * self.num_literals) for _ in range(self.num_clauses)]
        for row, clause in zip(label, sat_problem):
            for literal in clause:
                if literal in solution:
                    row[self._literal_column(literal)] = 1
        return self.pad_2d_list(label, self.PAD_SIZE, 2)

    def get_label_as_adj(self):
        """Return get_label_assignment() as space-separated text lines."""
        return self._matrix_to_text(self.get_label_assignment())

    def __problem_to_minisat_input(self, problem):
        """Serialize ``problem`` as DIMACS CNF text.

        The header is ``p cnf <num_variables> <num_clauses>``; each clause is
        written on its own line terminated by ``0``.
        """
        lines = [f'p cnf {self.num_literals} {len(problem)}']
        lines.extend(' '.join(map(str, clause)) + ' 0' for clause in problem)
        return '\n'.join(lines) + '\n'

    def pad_2d_list(self, lst, target_size, padding_value):
        """Return a padded copy of ``lst`` with ``target_size`` rows/columns.

        Each row is extended with ``padding_value`` up to ``target_size``
        entries (longer rows are kept as-is). Rows made entirely of
        ``padding_value`` are appended until there are ``target_size`` rows.
        """
        padded = [row + [padding_value] * (target_size - len(row)) for row in lst]
        padded.extend([padding_value] * target_size for _ in range(target_size - len(padded)))
        return padded

    def _literal_column(self, literal):
        """Map a signed literal to its matrix column.

        Variable ``v`` maps to column ``v - 1``; its negation ``-v`` maps to
        column ``num_literals + v - 1``.
        """
        if literal > 0:
            return literal - 1
        return self.num_literals + abs(literal) - 1

    @staticmethod
    def _matrix_to_text(matrix):
        """Render rows as space-separated lines followed by one blank line."""
        return ''.join(' '.join(map(str, row)) + '\n' for row in matrix) + '\n'

    def __as_adj(self, problem):
        """Return ``problem`` as a padded clause/literal incidence matrix text.

        Row ``c`` has a 1 in the column of every literal of clause ``c`` and 0
        elsewhere. Only the first ``num_clauses`` clauses are used. The
        matrix is padded to ``PAD_SIZE`` with 2s before rendering.
        """
        matrix = [[0] * (2 * self.num_literals) for _ in range(self.num_clauses)]
        for index, row in enumerate(matrix):
            for literal in problem[index]:
                row[self._literal_column(literal)] = 1
        return self._matrix_to_text(self.pad_2d_list(matrix, self.PAD_SIZE, 2))

# Interactive usage example:
#   obj = CNFProblem(50, 10)
#   print(obj.get_sat_problem_as_adj())
#   unsat, sat = obj.get_unsat_sat_pair()
#   print(obj.get_minisat_output(unsat), obj.get_minisat_output(sat))

# Dataset generation settings.
NUM_PROBLEMS = 10000
NUM_CLAUSES = 13    # clauses per generated problem
NUM_LITERALS = 13   # distinct variables per generated problem
PROBLEMS_PATH = 'test9.txt'
LABELS_PATH = 'test_labels9.txt'

# Append NUM_PROBLEMS SAT problems (padded adjacency matrices) and their
# assignment labels. Both files are written in lockstep, so entry k of one
# file corresponds to entry k of the other. UNSAT examples can be produced
# the same way via get_unsat_problem_as_adj().
with open(PROBLEMS_PATH, 'a', encoding='utf-8') as problems_file, \
        open(LABELS_PATH, 'a', encoding='utf-8') as labels_file:
    for _ in range(NUM_PROBLEMS):
        problem = CNFProblem(NUM_CLAUSES, NUM_LITERALS)
        sat_clauses = problem.get_unsat_sat_pair()[1]
        # get_sat_problem() only keeps satisfiable prefixes, so this should
        # always pass. Check before writing so a miss cannot leave a problem
        # without a matching label.
        if not problem.check_valid_problem(sat_clauses):
            print('Miss')
            break
        problems_file.write(problem.get_sat_problem_as_adj())
        labels_file.write(problem.get_label_as_adj())