import random
import torch
import sympy


def generate_random_list_clauses(num_clauses, num_literals, clauses_length_range, negated=0.5):
    """
    Generates a random list of clauses for a CNF/DNF formula
    Parameters:
        num_clauses (int): Number of clauses in the formula
        num_literals (int): Number of distinct literals (variables are numbered 1..num_literals)
        clauses_length_range (Tuple[int, int]): Inclusive min and max number of literals per clause;
            the max must not exceed num_literals because literals within a clause are distinct
        negated (float): Probability of negating a literal (between 0 and 1)
    Returns:
        List[List[int]]: A list of clauses, where each clause is a list of integers representing literals
            (k for variable k, -k for its negation)
    """
    formula = []
    for _ in range(num_clauses):
        # Random clause length within the inclusive range.
        clause_length = random.randint(clauses_length_range[0], clauses_length_range[1])
        # Distinct variables per clause; random.sample accepts the range directly.
        clause = random.sample(range(1, num_literals + 1), clause_length)
        # Negate each literal with probability `negated` (the original test was inverted
        # and negated with probability 1 - negated).
        clause = [-lit if random.random() < negated else lit for lit in clause]
        formula.append(clause)
    return formula

def assign_proposition_names(list_clauses, proposition_names=None, cnf=True):
    """
    Assigns names to the propositions in the set of clauses
    Parameters:
        list_clauses (List[List[int]]): The CNF/DNF formula; literal k > 0 is proposition k, -k is its negation
        proposition_names (List[str]): List of proposition names, where index k-1 names proposition k.
            Defaults to P_1..P_n, where n is the largest absolute literal appearing in list_clauses
        cnf (bool): If True, the formula is CNF (literals joined by '|', clauses by '&');
            if False, the formula is DNF (literals joined by '&', clauses by '|')
    Returns:
        List[List[str]]: The formula as a list of lists of named literals
        List[str]: The formula as a list of parenthesised clause strings
        str: The whole formula as a string ("" for an empty clause list)
    Raises:
        ValueError: If a clause contains the literal 0 (literals are 1-based)
    """
    if proposition_names is None:
        # Size the default names by the largest variable index actually used; the
        # length of the first clause is unrelated and can be too small (or absent).
        num_props = max((abs(lit) for clause in list_clauses for lit in clause), default=0)
        proposition_names = [f"P_{i+1}" for i in range(num_props)]

    inner_op, outer_op = (" | ", " & ") if cnf else (" & ", " | ")
    named_lists = []
    named_strs = []
    for clause in list_clauses:
        named_clause = []
        for literal in clause:
            if literal == 0:
                # 0 would otherwise silently index proposition_names[-1].
                raise ValueError("Literal 0 is invalid; literals are 1-based (k or -k)")
            name = proposition_names[abs(literal) - 1]
            named_clause.append(name if literal > 0 else f"~{name}")
        named_lists.append(named_clause)
        named_strs.append("(" + inner_op.join(named_clause) + ")")
    return named_lists, named_strs, outer_op.join(named_strs)


def generate_all_assignments(num_literals):
    """
    Generates all possible assignments of truth values to a set of literals
    Parameters:
        num_literals (int): Number of distinct literals (must be >= 0)
    Returns:
        torch.Tensor: Boolean tensor of shape (2^num_literals, num_literals) representing all assignments.
            Row i is the binary representation of i with the most significant bit in column 0,
            e.g. for 3 literals row 1 is [False, False, True]
    Raises:
        ValueError: If num_literals is negative
    """
    if num_literals < 0:
        raise ValueError(f"num_literals must be non-negative, got {num_literals}")
    # Vectorised bit extraction: shift each row index right by (M-1, ..., 0) and keep
    # the low bit. This replaces a per-row string-format + tensor-build Python loop and
    # also yields the correct (1, 0) table for num_literals == 0.
    row_ids = torch.arange(2 ** num_literals, dtype=torch.int64).unsqueeze(1)
    shifts = torch.arange(num_literals - 1, -1, -1, dtype=torch.int64)
    return ((row_ids >> shifts) & 1).to(torch.bool)


def evaluate_formula(tensor, clauses, formula_type='cnf', proposition_names=None):
    """
    Evaluates a list of clause over all rows of a boolean tensor in PyTorch
    Parameters:
        tensor (torch.Tensor): Boolean PyTorch tensor of shape (N, M), where each row is a boolean assignment
        clauses (List[List[int]] or str): formula given as clauses for CNF/DNF, otherwise sympy string
        formula_type (str): 'cnf', 'dnf', or 'any' to evaluate the formula type
            (any value other than 'cnf'/'dnf' takes the SymPy path)
        proposition_names (List[str]): List of proposition names when formula_type is 'any';
            column i of `tensor` is bound to proposition_names[i]
    Returns:
        torch.Tensor: Boolean tensor of shape (N,) indicating whether each row satisfies the formula
    Raises:
        ValueError: On the SymPy path, if proposition_names is missing or the formula does not
            reduce to True/False after substitution (e.g. it uses a name not in proposition_names)
    """

    def evaluate_cnf(tensor, list_clauses):
        # AND over clauses of (OR over literals); an empty clause list is True.
        num_rows = tensor.shape[0]
        results = torch.ones(num_rows, dtype=torch.bool, device=tensor.device)
        for clause in list_clauses:
            clause_results = torch.zeros(num_rows, dtype=torch.bool, device=tensor.device)
            for literal in clause:
                column = tensor[:, abs(literal) - 1]
                clause_results |= column if literal > 0 else ~column
            results &= clause_results
        return results

    def evaluate_dnf(tensor, list_clauses):
        # OR over clauses of (AND over literals); an empty clause list is False.
        num_rows = tensor.shape[0]
        results = torch.zeros(num_rows, dtype=torch.bool, device=tensor.device)
        for clause in list_clauses:
            clause_results = torch.ones(num_rows, dtype=torch.bool, device=tensor.device)
            for literal in clause:
                column = tensor[:, abs(literal) - 1]
                clause_results &= column if literal > 0 else ~column
            results |= clause_results
        return results

    def evaluate_any(tensor, formula, proposition_names):
        if proposition_names is None:
            raise ValueError("proposition_names is required when formula_type is not 'cnf' or 'dnf'")
        symbols = [sympy.Symbol(name) for name in proposition_names]
        # Bind every proposition name to a plain Symbol; without local_dict, parse_expr
        # resolves names such as E, I, S, N, Q, O to SymPy built-ins instead.
        # NOTE(security): parse_expr uses eval(); never pass an untrusted formula string.
        expr = sympy.sympify(
            sympy.parse_expr(formula, local_dict=dict(zip(proposition_names, symbols)))
        )
        predictions = []
        for row in tensor.tolist():
            value = expr.subs(dict(zip(symbols, row)))
            # bool() on an unevaluated expression silently returns True, so require
            # the substitution to reduce to a SymPy boolean constant.
            if value is sympy.true:
                predictions.append(True)
            elif value is sympy.false:
                predictions.append(False)
            else:
                raise ValueError(
                    f"Formula did not reduce to True/False; unassigned symbols: {value.free_symbols}"
                )
        return torch.tensor(predictions, dtype=torch.bool, device=tensor.device)

    if formula_type == 'cnf':
        results = evaluate_cnf(tensor, clauses)
    elif formula_type == 'dnf':
        results = evaluate_dnf(tensor, clauses)
    else:
        results = evaluate_any(tensor, clauses, proposition_names)
    return results


def create_dataset(num_literals, clauses, formula_type='cnf', proposition_names=None):
    """
    Builds the full truth table of a formula: every assignment and its evaluation
    Parameters:
        num_literals (int): Number of distinct literals (columns of the truth table)
        clauses (List[List[int]] or str): formula given as clauses for CNF/DNF, otherwise sympy string
        formula_type (str): 'cnf', 'dnf', or 'any'; forwarded to evaluate_formula
        proposition_names (List[str]): List of proposition names when formula_type is 'any'
    Returns:
        torch.Tensor: every assignment, as produced by generate_all_assignments(num_literals)
        torch.Tensor: per-assignment result, as produced by evaluate_formula
    """
    assignments = generate_all_assignments(num_literals)
    labels = evaluate_formula(
        assignments,
        clauses,
        formula_type=formula_type,
        proposition_names=proposition_names,
    )
    return assignments, labels