import os
import random
import signal
import subprocess
import time
from subprocess import TimeoutExpired

# ================== NEEDS TO BE MODIFIED BEFORE USE ==================

# Add the *absolute* path to the Scorpion, MyND, and oink tools, either by
# editing the defaults below or by setting the SCORPION_PATH, MYND_PATH and
# OINK_PATH environment variables (which take precedence over the defaults).
# Scorpion is used for the exists-exists (EE) runs, MyND (passed to `java -cp`)
# for the forall-exists (FE) runs, and oink for the parity-game baseline.
scorpion_path = os.environ.get('SCORPION_PATH', '')
mynd_path = os.environ.get('MYND_PATH', '')
oink_path = os.environ.get('OINK_PATH', '')

# ================== END - NEEDS TO BE MODIFIED BEFORE USE ==================

# Benchmark systems: each entry names a model file ('path', relative to the
# working directory; loaded with HyPlan's --nusmv flag) and the atomic
# propositions ('aps') that randomly generated formulas for it may mention.
instances = [
    {'path': model_path, 'aps': list(props)}
    for model_path, props in (
        ('./systems/bakery_3procs.smv',
         ('STARTED', 'p1-TOKEN', 'p2-TOKEN', 'p3-TOKEN')),
        ('./systems/bakery_5procs.smv',
         ('p1-TOKEN', 'p2-TOKEN', 'p3-TOKEN', 'p4-TOKEN', 'p5-TOKEN')),
        ('./systems/mutation_testing.smv',
         ('mutation', 'NO_water', 'NO_output')),
        ('./systems/NI_correct.smv',
         ('trigger_alpha', 'trigger_beta', 'halt')),
        ('./systems/NI_incorrect.smv',
         ('trigger_alpha', 'trigger_beta', 'halt')),
        ('./systems/NRP_correct.smv',
         ('sender_actions = 0', 'sender_actions = 1', 'sender_actions = 3', 'sender_actions = 4')),
        ('./systems/NRP_incorrect.smv',
         ('sender_actions = 0', 'sender_actions = 1', 'sender_actions = 3', 'sender_actions = 4')),
        ('./systems/snark1_M1_concurrent.smv',
         ('popRightFAIL', 'BOTH_MODIFYING', 'FAIL')),
        ('./systems/snark1_M2_sequential.smv',
         ('popRightFAIL', 'BOTH_MODIFYING')),
    )
]

# Fix the random seed so that the generated formulas are reproducible across runs.
random.seed(a=0)


def system_call(cmd : str, timeout_sec=None):
    """Run ``cmd`` through the shell and capture its output.

    Args:
        cmd: shell command line, executed with ``shell=True``. Only pass
            trusted strings; this script builds them from its own constants.
        timeout_sec: wall-clock limit in seconds, or None for no limit.

    Returns:
        ``(returncode, stdout, stderr)`` with both streams decoded as UTF-8
        (undecodable bytes replaced) and stripped, or ``(None, "", "")`` if
        the command exceeded ``timeout_sec``.
    """
    # On POSIX, start the shell in its own session/process group so that a
    # timeout can kill the whole process tree. proc.kill() alone only signals
    # the shell, and a tool it spawned (e.g. the Java planner) could keep
    # running in the background and distort later measurements.
    own_group = os.name == 'posix'
    proc = subprocess.Popen(cmd, shell=True, stderr=subprocess.PIPE,
                            stdout=subprocess.PIPE, start_new_session=own_group)

    def kill():
        if own_group:
            try:
                os.killpg(proc.pid, signal.SIGKILL)
            except ProcessLookupError:
                pass  # the whole group has already exited
        else:
            proc.kill()

    try:
        stdout, stderr = proc.communicate(timeout=timeout_sec)
    except TimeoutExpired:
        kill()
        # Reap the child and drain its pipes (the idiom recommended by the
        # subprocess docs) so timed-out runs do not leave zombies behind.
        proc.communicate()
        return None, "", ""
    except BaseException:
        # E.g. Ctrl-C: the child lives in its own session and does not receive
        # the terminal's SIGINT, so make sure it does not outlive this script.
        kill()
        proc.wait()
        raise

    return (proc.returncode,
            stdout.decode("utf-8", errors="replace").strip(),
            stderr.decode("utf-8", errors="replace").strip())

def generate_random_boolean_formula(aps : list[str], max_depth : int):
    """Return a random Boolean formula over ``aps`` as a string.

    The formula is built top-down with the module-level ``random`` generator.
    At depth ``d < max_depth`` a node is ``&``, ``|`` or ``<->`` (10% each),
    a negation ``!`` (20%), or otherwise (50%) an atomic leaf. At depth
    ``d >= max_depth`` a leaf is always produced: the constant ``'1'`` (10%),
    the constant ``'0'`` (10%), or a uniformly chosen element of ``aps``
    (80%). Every non-atomic subformula is wrapped in parentheses.

    Args:
        aps: atomic propositions to draw leaves from. An empty list raises
            ValueError (from ``random.randint``) whenever a proposition leaf
            is selected.
        max_depth: maximum operator nesting depth.
    """
    def gen_rec(d):
        if d >= max_depth:
            r = random.random()
            if r < 0.1:
                return '1'
            elif r < 0.2:
                # Must be a separate 10% band; `r < 0.1` here was unreachable
                # and the constant '0' could never be generated.
                return '0'
            else:
                rnd_index = random.randint(0, len(aps) - 1)
                return aps[rnd_index]
        else:
            r = random.random()

            if r < 0.1:
                return ('(' + gen_rec(d+1) + ' & ' + gen_rec(d+1) + ')')
            elif r < 0.2:
                return ('(' + gen_rec(d+1) + ' | ' + gen_rec(d+1) + ')')
            elif r < 0.3:
                return ('(' + gen_rec(d+1) + ' <-> ' + gen_rec(d+1) + ')')
            elif r < 0.5:
                return ('(!' + gen_rec(d+1) + ')')
            else:
                # Return an atomic formula
                return gen_rec(max_depth)

    return gen_rec(0)


def generate_random_formula_FE(instance : dict):
    """Return a random ``forall pi. exists pii.`` formula for ``instance``.

    Each atomic proposition in ``instance['aps']`` is indexed with both trace
    variables as ``{ap}_pi`` and ``{ap}_pii`` (in that order, per proposition).
    A random Boolean body with maximum depth 3 is then generated over these
    indexed propositions.
    """
    indexed_props = []
    for prop in instance['aps']:
        for trace_var in ('pi', 'pii'):
            indexed_props.append(f'{{{prop}}}_{trace_var}')

    return 'forall pi. exists pii. ' + generate_random_boolean_formula(indexed_props, 3)


def generate_random_formula_EE(instance : dict):
    """Return a random ``exists pi. exists pii.`` formula for ``instance``.

    Each atomic proposition in ``instance['aps']`` is indexed with both trace
    variables as ``{ap}_pi`` and ``{ap}_pii`` (in that order, per proposition).
    A random Boolean body with maximum depth 3 is then generated over these
    indexed propositions.
    """
    depth_limit = 3
    indexed_props = [f'{{{prop}}}_{trace_var}'
                     for prop in instance['aps']
                     for trace_var in ('pi', 'pii')]
    formula_body = generate_random_boolean_formula(indexed_props, depth_limit)
    return 'exists pi. exists pii. ' + formula_body


def convert_formula_to_hyperqb_format(f : str):
    """Rewrite a generated formula into the ``ap[trace]`` indexing syntax.

    ``{ap}_pi`` becomes ``ap[pi]`` and ``{ap}_pii`` becomes ``ap[pii]``. All
    braces are removed first, then the trace suffixes are turned into
    brackets. ``_pii`` is rewritten before ``_pi`` because ``_pi`` is a prefix
    of it.

    NOTE(review): the suffix rewrite is purely textual, so a proposition whose
    own name contains ``_pi`` (e.g. ``x_pin``) would be rewritten as well.
    """
    converted = f.translate(str.maketrans('', '', '{}'))
    for suffix, bracketed in (('_pii', '[pii]'), ('_pi', '[pi]')):
        converted = converted.replace(suffix, bracketed)
    return converted

# ================================================================================================================
# ================================================================================================================

# Wall-clock limit in seconds, applied separately to every external tool
# invocation made by the run_* functions below.
timeout_sec: int = 60

def run_hyplan_FE(system_content: str, formula_content : str):
    """Check a forall-exists (FE) formula with HyPlan and the MyND planner.

    Pipeline (all files live in the current working directory):
      1. write the system to ``system.txt`` and the formula to ``formula.txt``;
      2. run ``../app/HyPlan --nusmv`` on them and read the PDDL size
         statistics it prints;
      3. run the FOND translator on ``dom.pddl``/``prob.pddl`` (presumably
         written by step 2);
      4. run MyND on ``output.sas`` (presumably written by step 3).
    Each external step is limited by the module-level ``timeout_sec``.

    Returns:
        None if a step timed out or the translator wrote anything to stderr.
        Otherwise a dict with ``time_enc`` (HyPlan time), ``time_solve``
        (translator + MyND time), ``number_predicates`` / ``number_actions`` /
        ``number_objects`` (0 when HyPlan did not print the statistic) and
        ``res`` (True iff MyND reports 'INITIAL IS PROVEN').

    Raises:
        RuntimeError: if MyND's output contains neither verdict.
    """
    with open('system.txt', 'w+') as file:
        file.write(system_content)

    with open('formula.txt', 'w+') as file:
        file.write(formula_content)

    cmd = ' '.join(['../app/HyPlan', '--nusmv', './system.txt', './formula.txt'])

    # perf_counter is monotonic and meant for measuring durations.
    start = time.perf_counter()
    c, stdout, stderr = system_call(cmd=cmd, timeout_sec=timeout_sec)
    encoding_time = time.perf_counter() - start
    if c is None:
        return None

    # Defaults avoid an UnboundLocalError if HyPlan omits a statistics line.
    number_predicates = number_actions = number_objects = 0
    for line in stdout.splitlines():
        if line.startswith('|Predicates|:'):
            number_predicates = int(line.split()[1])
        elif line.startswith('|Actions|:'):
            number_actions = int(line.split()[1])
        elif line.startswith('|Objects|:'):
            number_objects = int(line.split()[1])

    start = time.perf_counter()
    c, stdout, stderr = system_call(cmd='python ../translator-fond/translate.py ./dom.pddl ./prob.pddl', timeout_sec=timeout_sec)
    sas_time = time.perf_counter() - start
    if c is None or stderr.strip() != "":
        return None

    start = time.perf_counter()
    c, stdout, stderr = system_call(cmd='java -Xmx4g -cp ' + mynd_path + ' mynd.MyNDPlanner ./output.sas', timeout_sec=timeout_sec)
    planner_time = time.perf_counter() - start
    if c is None:
        return None

    if 'INITIAL IS PROVEN' in stdout:
        final_res = True
    elif 'INITIAL IS DISPROVEN' in stdout:
        final_res = False
    else:
        # An explicit error instead of `assert`, which is stripped under -O.
        raise RuntimeError('MyND produced no verdict (exit code {}).\nstdout:\n{}\nstderr:\n{}'.format(c, stdout, stderr))

    return {'time_enc': encoding_time, 'time_solve': sas_time + planner_time, 'number_predicates': number_predicates, 'number_actions': number_actions, 'number_objects': number_objects, 'res': final_res}


def run_hyplan_EE(system_content: str, formula_content : str):
    """Check an exists-exists (EE) formula with HyPlan and the Scorpion planner.

    Writes the system to ``system.txt`` and the formula to ``formula.txt``,
    runs ``../app/HyPlan --reach --nusmv`` on them, then runs Scorpion with
    ``astar(cegar())`` on ``dom.pddl``/``prob.pddl`` (presumably written by
    HyPlan). Both steps are limited by the module-level ``timeout_sec``.

    Returns:
        None if either step timed out. Otherwise a dict with ``time_enc``
        (HyPlan time), ``time_solve`` (Scorpion time), ``number_predicates`` /
        ``number_actions`` / ``number_objects`` (0 when HyPlan did not print
        the statistic) and ``res`` (True iff Scorpion's output contains
        'Solution found').
    """
    with open('system.txt', 'w+') as file:
        file.write(system_content)

    with open('formula.txt', 'w+') as file:
        file.write(formula_content)

    cmd = ' '.join(['../app/HyPlan', '--reach', '--nusmv', './system.txt', './formula.txt'])

    # perf_counter is monotonic and meant for measuring durations.
    start = time.perf_counter()
    c, stdout, stderr = system_call(cmd=cmd, timeout_sec=timeout_sec)
    encoding_time = time.perf_counter() - start
    if c is None:
        return None

    # Defaults avoid an UnboundLocalError if HyPlan omits a statistics line.
    number_predicates = number_actions = number_objects = 0
    for line in stdout.splitlines():
        if line.startswith('|Predicates|:'):
            number_predicates = int(line.split()[1])
        elif line.startswith('|Actions|:'):
            number_actions = int(line.split()[1])
        elif line.startswith('|Objects|:'):
            number_objects = int(line.split()[1])

    start = time.perf_counter()
    c, stdout, stderr = system_call(cmd=scorpion_path + ' ./dom.pddl ./prob.pddl --search "astar(cegar())"', timeout_sec=timeout_sec)
    solve_time = time.perf_counter() - start
    if c is None:
        return None

    final_res = 'Solution found' in stdout

    return {'time_enc': encoding_time, 'time_solve': solve_time, 'number_predicates': number_predicates, 'number_actions': number_actions, 'number_objects': number_objects, 'res': final_res}


def run_pg(system_content: str, formula_content : str):
    """Run the parity-game baseline: HyPlan's ``--pg`` encoding solved by oink.

    Writes the system to ``system.txt`` and the formula to ``formula.txt``,
    runs ``../app/HyPlan --pg --nusmv`` on them, then runs oink on
    ``game.pg`` (presumably written by HyPlan) with output file ``sol.pg``.
    Both steps are limited by the module-level ``timeout_sec``. oink's
    output is not inspected; only timings and the game size are reported.

    Returns:
        None if either step timed out. Otherwise a dict with ``time_enc``
        (HyPlan time), ``time_solve`` (oink time) and ``number_states`` (from
        HyPlan's ``|S|:`` line, 0 if it was not printed).
    """
    with open('system.txt', 'w+') as file:
        file.write(system_content)

    with open('formula.txt', 'w+') as file:
        file.write(formula_content)

    cmd = ' '.join(['../app/HyPlan', '--pg', '--nusmv', './system.txt', './formula.txt'])

    # perf_counter is monotonic and meant for measuring durations.
    start = time.perf_counter()
    c, stdout, stderr = system_call(cmd=cmd, timeout_sec=timeout_sec)
    encoding_time = time.perf_counter() - start
    if c is None:
        return None

    # Default avoids an UnboundLocalError if HyPlan omits the |S| line.
    number_states = 0
    for line in stdout.splitlines():
        if line.startswith('|S|:'):
            number_states = int(line.split()[1])

    start = time.perf_counter()
    c, stdout, stderr = system_call(cmd=oink_path + ' ./game.pg ./sol.pg', timeout_sec=timeout_sec)
    solve_time = time.perf_counter() - start
    if c is None:
        return None

    return {'time_enc': encoding_time, 'time_solve': solve_time, 'number_states': number_states}

# ================================================================================================================
# ================================================================================================================


def run_all(number_of_formulas : int, prefix : str):
    """Compare HyPlan with the parity-game (PG) pipeline on every instance.

    For each entry of ``instances``, ``number_of_formulas`` random formulas
    with quantifier prefix ``prefix`` ('EE' = exists-exists, 'FE' =
    forall-exists) are generated. Each formula is run through HyPlan and
    through ``run_pg``. A progress marker is printed per run ('|' = result,
    'TO' / 'T' = no result from HyPlan / PG), followed by per-instance
    averages of the total time and of the encoding sizes.

    Averages are taken only over runs that returned a result. Runs without a
    result (timeouts, failed translations) are reported as a separate count,
    because averaging them in as zero seconds would make hard instances look
    fast. An average over zero runs is printed as ``nan``.

    Raises:
        ValueError: if ``prefix`` is neither 'FE' nor 'EE'.
    """
    if prefix == 'FE':
        generate_formula, run_hyplan = generate_random_formula_FE, run_hyplan_FE
    elif prefix == 'EE':
        generate_formula, run_hyplan = generate_random_formula_EE, run_hyplan_EE
    else:
        # Fail loudly, before any work is done, instead of exiting with status 0.
        raise ValueError('Unsupported Prefix {}'.format(prefix))

    def average(total, count):
        return total / count if count else float('nan')

    for instance in instances:
        print('=================', instance['path'], '=================')

        with open(instance['path']) as file:
            system_content = file.read()

        # Generate all formulas up front (same order of random draws as before).
        formulas = [generate_formula(instance) for _ in range(number_of_formulas)]

        time_hyplan = 0
        number_predicates = 0
        number_actions = 0
        number_objects = 0
        completed_hyplan = 0

        time_pg = 0
        number_states = 0
        completed_pg = 0

        for formula in formulas:

            # ============== Run HyPlan ==============
            res_hyplan = run_hyplan(system_content=system_content, formula_content=formula)

            if res_hyplan is not None:
                print('|', end='', flush=True)
                completed_hyplan += 1
                time_hyplan += res_hyplan['time_enc'] + res_hyplan['time_solve']
                number_predicates += res_hyplan['number_predicates']
                number_actions += res_hyplan['number_actions']
                number_objects += res_hyplan['number_objects']
            else:
                print('TO', end='', flush=True)

            # ============== Run PG ==============
            res_pg = run_pg(system_content=system_content, formula_content=formula)

            if res_pg is not None:
                print('|', end='', flush=True)
                completed_pg += 1
                time_pg += res_pg['time_enc'] + res_pg['time_solve']
                number_states += res_pg['number_states']
            else:
                print('T', end='', flush=True)

            print(' # ', end='', flush=True)

        print('')

        print('HyPlan time: {}'.format(average(time_hyplan, completed_hyplan)))
        print('PDDL Size: {}/{}/{}'.format(average(number_predicates, completed_hyplan),
                                           average(number_actions, completed_hyplan),
                                           average(number_objects, completed_hyplan)))
        print('HyPlan runs without result: {}/{}'.format(len(formulas) - completed_hyplan, len(formulas)))
        print('')

        print('PG time: {}'.format(average(time_pg, completed_pg)))
        print('PG Size: {}'.format(average(number_states, completed_pg)))
        print('PG runs without result: {}/{}'.format(len(formulas) - completed_pg, len(formulas)))

if __name__ == '__main__':
    # Fail fast if the tool paths were left unconfigured. The run_* helpers
    # prepend them to shell commands, so an empty path would make the shell
    # try to execute e.g. './dom.pddl' or './game.pg' and produce
    # meaningless results instead of a clear error.
    unset_paths = [name for name, value in (('scorpion_path', scorpion_path),
                                            ('mynd_path', mynd_path),
                                            ('oink_path', oink_path)) if not value]
    if unset_paths:
        raise SystemExit('Please set {} (see the tool-path settings near the top of this script).'.format(', '.join(unset_paths)))

    print('============ Running \\exists\\exists instances ================')

    run_all(number_of_formulas=10, prefix='EE')

    print('\n\n\n============ Running \\forall\\exists instances ================')

    run_all(number_of_formulas=10, prefix='FE')

