from copy import deepcopy
from typing import Optional


class DPLLSolver:
    def __init__(self, cnf):
        """
        initializes the DPLL Solver with a given CNF (Conjunctive Normal Form) input.

        :param cnf: a string representing the CNF, where each clause is on a new line,
                    literals are separated by spaces, negation is denoted by '!', and
                    variables are single characters.
        """
        self.assign_true = set()  # set of literals assigned True
        self.assign_false = set()  # set of literals assigned False
        self.n_props = 0          # count of propositions made
        self.n_splits = 0         # count of splits (decisions) made
        self.cnf = cnf            # the CNF input

    def print_cnf(self):
        """
        prints the CNF in a more readable format, where clauses are enclosed in parentheses
        and literals are separated by '+'.
        """
        s = ''
        for i in self.cnf:
            if len(i) > 0:
                s += '(' + i.replace(' ', '+') + ')'
        print(s)

    def solve(self, cnf, literals):
        """
        recursively solves the CNF using the DPLL algorithm.

        :param cnf: the CNF in its current state (as clauses get simplified).
        :param literals: list of literals that haven't been assigned yet.
        :return: True if the CNF is satisfiable, False otherwise.
        """
        new_true = []  # literals assigned True in this decision level
        new_false = []  # literals assigned False in this decision level
        self.n_splits += 1
        cnf = list(set(cnf))  # remove duplicate clauses
        units = [i for i in cnf if len(i) < 3]  # unit clauses
        units = list(set(units))  # remove duplicate units

        # Unit Propagation
        if len(units):
            for unit in units:
                self.n_props += 1
                if '!' in unit:
                    self.assign_false.add(unit[-1])
                    new_false.append(unit[-1])
                    # simplify the CNF by removing clauses and literals
                    i = 0
                    while True:
                        if unit in cnf[i]:
                            cnf.remove(cnf[i])
                            i -= 1
                        elif unit[-1] in cnf[i]:
                            cnf[i] = cnf[i].replace(unit[-1], '').strip()
                        i += 1
                        if i >= len(cnf):
                            break
                else:
                    self.assign_true.add(unit)
                    new_true.append(unit)
                    i = 0
                    while True:
                        if '!'+unit in cnf[i]:
                            cnf[i] = cnf[i].replace('!'+unit, '').strip()
                        elif unit in cnf[i]:
                            cnf.remove(cnf[i])
                            i -= 1
                        i += 1
                        if i >= len(cnf):
                            break

        # check if CNF is solved
        if len(cnf) == 0:
            return True

        # check for an empty clause (unsatisfiable)
        if sum(len(clause) == 0 for clause in cnf):
            # Undo assignments made in this decision level
            for i in new_true:
                self.assign_true.remove(i)
            for i in new_false:
                self.assign_false.remove(i)
            return False

        # splitting (choose a literal and try both True and False)
        literals = [k for k in list(set(''.join(cnf))) if k.isalpha()]
        x = literals[0]
        if self.solve(deepcopy(cnf) + [x], deepcopy(literals)):
            return True
        elif self.solve(deepcopy(cnf) + ['!' + x], deepcopy(literals)):
            return True
        else:
            # undo assignments made in this decision level
            for i in new_false:
                self.assign_false.remove(i)
            return False

    def dpll(self) -> Optional[dict]:
        """
        public method to solve the CNF using the DPLL algorithm.

        :return: a dictionary mapping each literal to its boolean value if the CNF is satisfiable,
                 or None if the CNF is unsatisfiable.
        """
        literals = [i for i in list(set(self.cnf)) if i.isalpha()]
        cnf = self.cnf.splitlines()
        res = {}
        if self.solve(cnf, literals):
            # assign values to literals based on the assignments made during solving
            for i in self.assign_true:
                res[i] = True
            for i in self.assign_false:
                res[i] = False

            return res
        else:
            return None  # unsat!