import pycosat
import random
import numpy as np
import argparse
import os
import string
import json

# Command-line configuration. Parsed at import time; the functions below read
# the resulting module-level ``args``.
parser = argparse.ArgumentParser(description='Synthetic Transformer Data Generator')
parser.add_argument('--file-path', type=str, default="data/plain", help="Directory the dataset file is written to")
parser.add_argument('--file-name', type=str, default="data.json", help="Name of the JSON dataset file")
parser.add_argument('--dataset-size', type=int, default=65536, help="Number of data points to generate")
parser.add_argument('--num-vars', type=int, default=4, help="Number of variables in 3-SAT clauses")
parser.add_argument('--num-clauses', type=int, default=20, help="Number of clauses in 3-SAT problem")
parser.add_argument('--target-length', type=int, default=5, help="Length of the data string")
parser.add_argument('--with-feedback', action='store_true', help="Generate data for the feedback setting")
parser.add_argument('--sat-ratio', type=float, default=0.9,
                    help="Fraction of data points whose expression is satisfiable (between 0 and 1)")


args = parser.parse_args()

# Reject values the generators cannot work with; parser.error prints the
# usage message and exits instead of failing deep inside generation.
if args.num_vars < 3:
    # Each clause samples 3 distinct variables, so fewer than 3 cannot work.
    parser.error("--num-vars must be at least 3 (each clause uses 3 distinct variables)")
if args.dataset_size < 0:
    parser.error("--dataset-size must be non-negative")
if args.target_length < 0:
    parser.error("--target-length must be non-negative")
if not 0.0 <= args.sat_ratio <= 1.0:
    # Also rejects NaN, since every comparison with NaN is False.
    parser.error("--sat-ratio must be between 0 and 1")

# Fixed seeds for both RNGs in use, so generated datasets are reproducible.
np.random.seed(2024)
random.seed(2024)

# 3-SAT generation functions
def generate_clause(num_vars):
    """Return one random 3-literal clause over variables ``1..num_vars``.

    Three distinct variables are drawn without replacement. Each one is then
    kept positive or negated with equal probability. A literal ``k`` stands
    for ``xk`` and ``-k`` for ``~xk``.

    Raises:
        ValueError: from ``random.sample`` when ``num_vars`` < 3.
    """
    variables = random.sample(range(1, num_vars + 1), 3)
    clause = []
    for var in variables:
        clause.append(random.choice([var, -var]))
    return clause

def generate_expression(num_vars, num_clauses):
    """Return a random 3-SAT expression as a list of ``num_clauses`` clauses.

    Each clause comes from ``generate_clause(num_vars)``. A non-positive
    ``num_clauses`` gives an empty expression.
    """
    expression = []
    for _ in range(num_clauses):
        expression.append(generate_clause(num_vars))
    return expression

def check_satisfiability(expression):
    """Return True when ``expression`` (a list of integer-literal clauses) is satisfiable.

    ``pycosat.solve`` returns the string ``"UNSAT"`` when no satisfying
    assignment exists; every other result is treated as satisfiable.
    """
    solution = pycosat.solve(expression)
    is_unsat = solution == "UNSAT"
    return not is_unsat

def clause_to_string(clause):
    """Render one clause as a parenthesised disjunction.

    Negative literals become ``~x<n>`` and the others ``x<n>``, for example
    ``[1, -2, 3]`` becomes ``"(x1|~x2|x3)"``.
    """
    literals = [f"~x{-lit}" if lit < 0 else f"x{lit}" for lit in clause]
    return "(" + "|".join(literals) + ")"

def expression_to_string(expression):
    """Render a whole expression as its clause strings joined by ``&``."""
    pieces = []
    for clause in expression:
        pieces.append(clause_to_string(clause))
    return "&".join(pieces)

def generate_data_string(length, alphabet=('a', 'b')):
    """Return a random string of ``length`` symbols drawn from ``alphabet``.

    Each symbol is drawn independently and uniformly with ``np.random.choice``.
    The result is therefore reproducible under ``np.random.seed``.

    One can also consider ``alphabet=string.ascii_lowercase``, but a larger
    alphabet can be more challenging for the model to learn.

    Args:
        length: number of symbols; values <= 0 give an empty string.
        alphabet: iterable of string symbols, such as a list, a tuple or a
            plain string. Defaults to an immutable tuple, so the default
            cannot be mutated and shared between calls.

    Raises:
        ValueError: from numpy when ``alphabet`` is empty and ``length`` > 0.
    """
    # Materialise once so a plain string is split into its characters.
    # np.random.choice treats a bare str as 0-d and rejects it.
    symbols = list(alphabet)
    # Draw one symbol per call, as before, so the sequence of random numbers
    # consumed (and hence seeded output) is unchanged.
    return ''.join(np.random.choice(symbols) for _ in range(length))

def pre_generate_expressions(
    num_vars,
    num_clauses,
    pool_size,
    satisfiable_pool_size=None,
    unsatisfiable_pool_size=None
):
    """Draw random 3-SAT expressions and split them by satisfiability.

    Sampling continues until all three quotas hold:
      * at least ``pool_size`` expressions in total;
      * at least ``satisfiable_pool_size`` satisfiable ones;
      * at least ``unsatisfiable_pool_size`` unsatisfiable ones.

    A per-class quota of ``None`` is not enforced. Every draw is kept, so
    one list may end up longer than its quota.

    NOTE(review): there is no attempt limit. If a quota cannot be reached for
    the given ``num_vars``/``num_clauses`` (e.g. unsatisfiable expressions
    never occur), this never returns.

    Returns:
        ``(satisfiable, unsatisfiable)``: two lists of expressions in the
        order they were generated.
    """
    # len() is never negative, so a quota of 0 is the same as "no quota".
    sat_quota = 0 if satisfiable_pool_size is None else satisfiable_pool_size
    unsat_quota = 0 if unsatisfiable_pool_size is None else unsatisfiable_pool_size

    satisfiable, unsatisfiable = [], []

    def quotas_met():
        return (
            len(satisfiable) + len(unsatisfiable) >= pool_size
            and len(satisfiable) >= sat_quota
            and len(unsatisfiable) >= unsat_quota
        )

    while not quotas_met():
        expression = generate_expression(num_vars, num_clauses)
        bucket = satisfiable if check_satisfiability(expression) else unsatisfiable
        bucket.append(expression)
    return satisfiable, unsatisfiable

def generate_data_point(expression, satisfiable, with_feedback=False):
    """Build one example for ``expression``.

    A random data string of length ``args.target_length`` is drawn. The
    expected response is that string unchanged when ``satisfiable`` is truthy,
    and reversed otherwise.

    With ``with_feedback``, each example has probability 1/2 of including a
    previous attempt in its input. The attempt is the opposite transformation
    of the data, followed by ``0`` (presumably flagging it as incorrect). In
    every other case the input is the data string followed by ``#``.

    Returns:
        dict with keys "instruction" (the expression rendered as a string),
        "input" and "response", in that order.
    """
    instruction = expression_to_string(expression)
    data = generate_data_string(args.target_length)
    target = data if satisfiable else data[::-1]

    prompt = f'{data}#'
    if with_feedback and not random.choice([True, False]):
        # target[::-1] is what the *other* label would have produced.
        # NOTE(review): when ``data`` is a palindrome this attempt equals
        # ``target`` yet is still followed by 0 — confirm that is intended.
        prompt = f'{data}#{target[::-1]}0'

    return {
        "instruction": instruction,
        "input": prompt,
        "response": target,
    }

def generate_dataset(file_path, size, satisfiable_pool, unsatisfiable_pool):
    """Turn the expression pools into data points and write them as JSON.

    The first ``int(size * args.sat_ratio)`` expressions of
    ``satisfiable_pool`` are used, plus the first ``size`` minus that many of
    ``unsatisfiable_pool``. Each becomes a point via ``generate_data_point``,
    and satisfiable points come first in the file. Every point gets a 1-based
    string ``"id"`` that is contiguous across the file.

    Args:
        file_path: destination JSON file; overwritten and written as UTF-8.
        size: number of data points requested.
        satisfiable_pool: list of satisfiable expressions.
        unsatisfiable_pool: list of unsatisfiable expressions.

    If a pool is too short, fewer than ``size`` points are written and a
    warning is printed.
    """
    # Clamp so an out-of-range size or ratio cannot produce a negative slice
    # bound, which would silently mean "all but the last k items".
    size = max(size, 0)
    num_satisfiable = min(max(int(size * args.sat_ratio), 0), size)
    num_unsatisfiable = size - num_satisfiable

    labelled = [(expr, True) for expr in satisfiable_pool[:num_satisfiable]]
    labelled += [(expr, False) for expr in unsatisfiable_pool[:num_unsatisfiable]]
    if len(labelled) < size:
        print(
            f"Warning: expression pools too small; writing {len(labelled)} "
            f"of {size} requested data points"
        )

    data = []
    for idx, (expr, is_satisfiable) in enumerate(labelled, start=1):
        data_point = generate_data_point(
            expr, satisfiable=is_satisfiable, with_feedback=args.with_feedback
        )
        data_point["id"] = str(idx)
        data.append(data_point)

    # ensure_ascii=False writes non-ASCII characters verbatim, so pin the
    # file encoding instead of relying on the platform default.
    with open(file_path, 'w', encoding='utf-8') as data_file:
        json.dump(data, data_file, indent=2, ensure_ascii=False)

    print(f"Dataset saved to {file_path}")

def main():
    """Generate the dataset described by the command-line ``args``.

    First pre-generates enough satisfiable and unsatisfiable expressions to
    fill the dataset. Then writes the dataset to ``<file_path>/<file_name>``.
    """
    num_satisfiable = int(args.dataset_size * args.sat_ratio)
    # Derive the unsatisfiable quota by subtraction, the same way
    # generate_dataset splits ``size``. Flooring int(size * (1 - ratio)) as
    # well can come out one lower: for 65536 * 0.9 the split is 58982 + 6554,
    # but int(65536 * 0.1) is 6553. The pool could then hold one unsatisfiable
    # expression too few, and the dataset would be one point short.
    num_unsatisfiable = args.dataset_size - num_satisfiable

    print("Pre-generating 3-SAT expressions...")
    satisfiable_pool, unsatisfiable_pool = pre_generate_expressions(
        args.num_vars,
        args.num_clauses,
        pool_size=args.dataset_size,
        satisfiable_pool_size=num_satisfiable,
        unsatisfiable_pool_size=num_unsatisfiable
    )

    print("Generating training data...")
    if args.file_path:
        # os.makedirs("") raises; an empty path means the current directory.
        os.makedirs(args.file_path, exist_ok=True)
    generate_dataset(
        os.path.join(args.file_path, args.file_name),
        args.dataset_size,
        satisfiable_pool,
        unsatisfiable_pool
    )

# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
