
# pylint: disable = line-too-long

import sys, os
import signal
import random
import re

import spot
import spot.gen as sg

import deepltl.utils.utils as utils
import deepltl.data.ltl_parser as ltl_parser


# Background worker used by add_part to run spot_get_trace with a timeout;
# created in gen_dac_patterns and terminated there when generation stops.
SPOT_WORKER: utils.PersistentWorker = None
# Counters of why add_part stopped extending a sample, keyed by the
# termination reason it returns ('timeout', 'unsat', 'max_length', 'max_parts').
STATISTICS = {'timeout': 0, 'unsat': 0, 'max_length': 0, 'max_parts': 0}

def spot_get_trace(formula_str, simplify):
    """Ask spot for a satisfying trace of an LTL formula.

    The formula is parsed with ``spot.formula`` and translated to an
    automaton, whose edges are merged before an accepting run is requested.

    Args:
        formula_str: LTL formula in spot's string syntax.
        simplify: if true, call ``simplify()`` on the word before it is
            converted to a string.

    Returns:
        ``(False, None)`` if the automaton has no accepting run, otherwise
        ``(True, word)`` where *word* is ``str()`` of the ``spot.twa_word``
        built from the accepting run.
    """
    aut = spot.formula(formula_str).translate()
    aut.merge_edges()
    run = aut.accepting_run()
    if run is None:
        # Empty automaton language: the formula is unsatisfiable.
        return False, None
    word = spot.twa_word(run)
    if simplify:
        word.simplify()
    return True, str(word)


def _relabel_aps(pattern, aps):
    """Return *pattern* with its atomic propositions renamed to random distinct entries of *aps*.

    Atomic propositions are taken to be the lowercase ASCII letters of
    *pattern*; this assumes spot's default string syntax, where temporal
    operators are uppercase and constants are written ``0``/``1``.

    Raises:
        ValueError: if *pattern* uses more distinct propositions than *aps* provides.
    """
    # sorted(): iterating a set of str follows hash order, which varies with
    # PYTHONHASHSEED between interpreter runs, so random.seed() alone did not
    # make the relabelling reproducible.
    used_aps = ''.join(sorted(set(re.sub(r'[^a-z]', '', pattern))))
    if len(used_aps) > len(aps):
        raise ValueError(f'pattern {pattern!r} uses {len(used_aps)} atomic propositions, but only {len(aps)} are available')
    new_aps = list(aps)  # copy so the caller's list is not shuffled in place
    random.shuffle(new_aps)
    # maketrans substitutes all characters simultaneously, so swaps such as
    # a -> b, b -> a cannot chain into each other.
    translation = str.maketrans(used_aps, ''.join(new_aps[:len(used_aps)]))
    return pattern.translate(translation)


def add_part(formulas_so_far, solutions_so_far, patterns, aps, max_parts, max_length, timeout):
    """Grow a conjunction of randomly relabelled patterns until a limit is hit.

    Starting from the formulas and traces built so far, repeatedly picks a
    random pattern from *patterns*, renames its atomic propositions to a
    random subset of *aps*, and conjoins it (``ltl_parser.F_AND``) with the
    last accepted formula. The extended formula is accepted only if its
    ``size()`` is at most *max_length* and ``spot_get_trace`` (run through
    SPOT_WORKER with *timeout*) finds a satisfying trace.

    Args:
        formulas_so_far: accepted formulas; each one is the conjunction of
            all parts added up to that point.
        solutions_so_far: traces (parsed with ``ltl_parser.ltl_trace``)
            parallel to *formulas_so_far*.
        patterns: pattern strings in spot syntax to draw from.
        aps: atomic proposition names used for relabelling.
        max_parts: stop once this many formulas have been accepted.
        max_length: reject a candidate whose ``size()`` exceeds this.
        timeout: forwarded to ``SPOT_WORKER.call`` for the satisfiability check.

    Returns:
        ``(formulas, solutions, reason, rejected)``. *reason* is one of
        ``'max_parts'``, ``'max_length'``, ``'timeout'`` or ``'unsat'``, and
        the matching STATISTICS counter is incremented. *rejected* is the
        candidate formula that ended generation (None for ``'max_parts'``).
        The input lists are not modified.

    Raises:
        ValueError: if a drawn pattern uses more atomic propositions than *aps* has.
    """
    # Loop instead of the former tail recursion: same results, one stack frame.
    formulas = list(formulas_so_far)
    solutions = list(solutions_so_far)
    while True:
        if len(formulas) >= max_parts:
            STATISTICS['max_parts'] += 1
            return formulas, solutions, 'max_parts', None

        new_pattern = _relabel_aps(random.choice(patterns), aps)
        new_pattern_f = ltl_parser.ltl_formula(new_pattern, 'spot')
        new_formula = ltl_parser.F_AND(formulas[-1], new_pattern_f) if formulas else new_pattern_f

        if new_formula.size() > max_length:
            STATISTICS['max_length'] += 1
            return formulas, solutions, 'max_length', new_formula
        finished, res = SPOT_WORKER.call(spot_get_trace, (new_formula.to_str('spot'), False), timeout)
        if not finished:
            STATISTICS['timeout'] += 1
            return formulas, solutions, 'timeout', new_formula
        sat, trace = res
        if not sat:
            STATISTICS['unsat'] += 1
            return formulas, solutions, 'unsat', new_formula
        formulas.append(new_formula)
        solutions.append(ltl_parser.ltl_trace(trace, 'spot'))


def gen_dac_patterns():
    """Generate conjunctions of spot's DAC specification patterns with satisfying traces.

    Output goes to ``data/<ds_name>/``, where *ds_name* is the first
    command-line argument (default ``'pattern'``); the directory is created
    if missing:

    * ``patterns.txt``: per sample, the final formula and its trace in
      ``network-polish`` notation, one per line.
    * ``timeouts.txt``: for samples that ended on a timeout after at least
      one accepted part, the timed-out candidate formula followed by a
      blank line.

    Generation loops until SIGINT sets a flag that is checked between
    samples; statistics are printed every 10000 attempts and at the end.
    The spot worker is terminated and the previous SIGINT handler restored
    even if generation raises.
    """
    ds_name = sys.argv[1] if len(sys.argv) > 1 else 'pattern'
    output_dir = os.path.join('data', ds_name)

    pattern_strings = [q.relabel(spot.Abc).to_str() for q in sg.ltl_patterns(sg.LTL_DAC_PATTERNS)]
    aps = ['a', 'b', 'c', 'd', 'e', 'f']
    max_parts = 8
    max_length = 126
    timeout = 1
    outfile_edge_name = 'patterns.txt'
    outfile_timeout_name = 'timeouts.txt'

    interrupted = False
    def signal_handler(signum, frame):
        # Parameter is not named `signal` so it does not shadow the module.
        nonlocal interrupted
        print(f"signal {signum:d} received, exiting")
        interrupted = True

    def print_statistics(total_count, flush=False):
        # Shared by the periodic and the final progress report.
        print("{total_count} samples generated: {max_length} max length, {max_parts} max parts, {unsat} unsat, {timeout} timeout".format(total_count=total_count, **STATISTICS), flush=flush)

    # Create the output directory rather than failing in open() below.
    os.makedirs(output_dir, exist_ok=True)

    previous_handler = signal.signal(signal.SIGINT, signal_handler)
    #signal.signal(signal.SIGTERM, signal_handler)
    print("Main process PID {:d}".format(os.getpid()))
    global SPOT_WORKER
    SPOT_WORKER = utils.PersistentWorker()
    try:
        with open(os.path.join(output_dir, outfile_edge_name), 'w') as outfile_edge, open(os.path.join(output_dir, outfile_timeout_name), 'w') as outfile_timeout:
            total_count = 0
            while not interrupted:
                total_count += 1
                all_formulas, all_traces, termreason, new_formula = add_part([], [], pattern_strings, aps, max_parts, max_length, timeout)
                if all_formulas:
                    outfile_edge.write(all_formulas[-1].to_str('network-polish') + '\n' + all_traces[-1].to_str('network-polish') + '\n')
                    if termreason == 'timeout':
                        outfile_timeout.write(new_formula.to_str('network-polish') + '\n\n')
                else:
                    print('not a single part was generated')
                # Checked on every attempt so an empty one no longer skips the report.
                if total_count % 10000 == 0:
                    print_statistics(total_count, flush=True)
            print_statistics(total_count)
    finally:
        # Never leak the worker process or leave our handler installed.
        SPOT_WORKER.terminate()
        if previous_handler is not None:  # None if the old handler was not set from Python
            signal.signal(signal.SIGINT, previous_handler)
    print('Done.')


# Run the generator only when executed as a script, not on import.
if __name__ == '__main__':
    gen_dac_patterns()
