import os
import re
import json
from collections import deque
from sympy import symbols, to_dnf # type: ignore
from ltlf2dfa.parser.ltlf import LTLfParser # type: ignore

def prune_and_conditions(dfa_text):
    """Drop DNF conjuncts that need more than one non-negated proposition.

    The text is stripped and processed line by line. For a line containing an
    arc label ``N -> M [label="..."];`` the label is split on `` | `` into
    conjuncts and each conjunct on `` & `` into literals. A conjunct survives
    when it has at most one distinct literal without a ``~``.

    * Arc lines with surviving conjuncts are rebuilt from the matched
      ``N -> M [label="`` prefix, the surviving conjuncts and the ``"];``
      suffix (text outside the match on that line is discarded).
    * Arc lines with no surviving conjunct are removed.
    * All other lines pass through unchanged.
    """
    arc_label = re.compile(r'(\d+ -> \d+ \[label=")(.+?)("\];)')

    def _has_single_positive(conjunct):
        # split() always yields at least one literal, so "no positive literal"
        # already implies "at least one negated literal".
        positives = {lit for lit in conjunct.split(" & ") if "~" not in lit}
        return len(positives) <= 1

    kept_lines = []
    for raw_line in dfa_text.strip().split("\n"):
        found = arc_label.search(raw_line)
        if found is None:
            kept_lines.append(raw_line)
            continue

        head, label, tail = found.groups()
        surviving = [c for c in label.split(" | ") if _has_single_positive(c)]
        if surviving:
            kept_lines.append(head + " | ".join(surviving) + tail)

    return "\n".join(kept_lines)

def remove_useless_rows(input_string):
    """Keep only the lines of *input_string* that contain a numeric arc ``N -> M``.

    Lines whose left-hand side is not a number (for example ``init -> 1;``),
    as well as headers and node declarations, are dropped. Surviving lines
    keep their original text and order.
    """
    arc_pattern = re.compile(r'\b\d+\s*->\s*\d+\b')
    return '\n'.join(filter(arc_pattern.search, input_string.split('\n')))

def extract_accepting_states(dot_code):
    """Return the state ids declared after ``node [shape = doublecircle];``.

    The ids are read from the run of digits, ``;`` and whitespace that
    immediately follows the declaration. Returns ``[]`` when the declaration
    is absent.
    """
    header = re.search(r'node \[shape = doublecircle\];\s*([\d;\s]+)', dot_code)
    if header is None:
        return []
    return list(map(int, re.findall(r'\d+', header.group(1))))

def clean_input_string(input_string):
    """Strip DOT label syntax and parentheses from every line.

    The following substitutions are applied to each line, in this order:

    1. ``[label="`` becomes a single space;
    2. ``"]`` (with an optional trailing ``;``) is removed;
    3. every ``(`` and ``)`` is removed.

    Each line is then stripped of surrounding whitespace. For example,
    ``1 -> 2 [label="a & (b)"];`` becomes ``1 -> 2  a & b``.
    """
    substitutions = (
        (r'\[label="', ' '),
        (r'"\];?', ''),
        (r'[()]', ''),
    )

    def _clean(raw_line):
        # Order matters: later patterns must see the output of earlier ones.
        for pattern, replacement in substitutions:
            raw_line = re.sub(pattern, replacement, raw_line)
        return raw_line.strip()

    return "\n".join(_clean(line) for line in input_string.splitlines())

def extract_matrix(input_str, labels):
    """Encode each cleaned arc line as one label vector per DNF conjunct.

    Each non-blank line of *input_str* (as produced by clean_input_string,
    e.g. ``1 -> 2  a & ~b | c``) is split on `` | `` into conjuncts. For
    each conjunct a vector with one entry per label is built:

    * ``0`` if ``~label`` appears as a literal;
    * ``1`` if ``label`` appears as a literal;
    * otherwise unspecified. Unspecified entries become ``1`` when the
      conjunct has no positive label, and ``0`` when it has one.

    Literals are matched as whole tokens (split on whitespace and ``&``), so
    a label that is a substring or suffix of another label (``a`` vs ``ba``)
    is never mistaken for it.

    Args:
        input_str: cleaned arc lines, one per line.
        labels: proposition names; their order fixes the column order.

    Returns:
        ``output[line][conjunct][label_index]`` with 0/1 values. Blank lines
        are skipped, so rows stay aligned with extract_arcs() on the same text.
    """
    output = []

    for line in input_str.strip().split("\n"):
        if not line.strip():
            # A blank line carries no arc; emitting a row for it would shift
            # every later row relative to the arc list.
            continue

        row_result = []
        for part in line.split(" | "):
            # Whole-token literals; the "N -> M" prefix tokens never equal a label.
            tokens = set(part.replace("&", " ").split())

            result = []
            for label in labels:
                if f"~{label}" in tokens:
                    result.append(0)
                elif label in tokens:
                    result.append(1)
                else:
                    result.append(2)  # unspecified, resolved below

            fill = 1 if 1 not in result else 0
            row_result.append([fill if x == 2 else x for x in result])

        output.append(row_result)

    return output

def extract_arcs(input_str):
    """Return every ``N -> M`` arc in *input_str* as an ``(int, int)`` tuple, in order."""
    return [
        (int(source), int(target))
        for source, target in re.findall(r"(\d+) -> (\d+)", input_str)
    ]

def or_internal_states(matrix):
    """Collapse each arc's per-conjunct vectors into one vector by OR-ing 1s.

    Args:
        matrix: ``matrix[arc][conjunct][label]`` with 0/1 values, as
            returned by extract_matrix(). Every arc must have at least one
            conjunct.

    Returns:
        A new list with one row per arc. Column ``j`` is 1 if any conjunct
        has a 1 there; otherwise it keeps the first conjunct's value. The
        input rows are copied, never modified.
    """
    merged = []
    for conjunct_rows in matrix:
        # Copy: the original aliased conjunct_rows[0] and mutated the caller's data.
        row = list(conjunct_rows[0])
        for other in conjunct_rows[1:]:
            for j, value in enumerate(other):
                if value == 1:
                    row[j] = 1
        merged.append(row)
    return merged

def attach_arcs_to_matrix(arcs, matrix):
    """Tag every enabled cell with the target state of its arc.

    For each arc index ``i``, every cell of ``matrix[i]`` equal to ``1`` is
    replaced by ``(1, arcs[i][1])``; other cells are left as they are.
    *matrix* is modified in place and also returned. It must have at least
    ``len(arcs)`` rows.
    """
    for row_index, arc in enumerate(arcs):
        row = matrix[row_index]
        for col_index, cell in enumerate(row):
            if cell == 1:
                row[col_index] = (1, arc[1])
    return matrix

def or_for_states(matrix, arcs):
    """Merge consecutive rows that belong to the same source state.

    Args:
        matrix: one row per arc (as produced by attach_arcs_to_matrix), where
            each cell is ``0`` or a ``(1, target)`` tuple.
        arcs: the source state of each row, in the same order as *matrix*.
            Rows are grouped only while consecutive entries are equal.

    Returns:
        A new list with one row per run of equal sources, in order of first
        appearance. Within a run, a non-zero cell from a later row overwrites
        the value accumulated so far. The caller's rows are not modified.
    """
    merged = []
    current_row = None

    for i, state in enumerate(arcs):
        row = matrix[i]
        if i > 0 and state == arcs[i - 1]:
            # Same source as the previous row: fold this row into the group.
            for j, value in enumerate(row):
                if value != 0:
                    current_row[j] = value
        else:
            if i > 0:
                merged.append(current_row)
            # Copy so merging never writes into the caller's matrix.
            current_row = list(row)

    if current_row is not None:
        merged.append(current_row)
    return merged

def de_morgan(dfa_text, symbols_dict):
    """Rewrite selected arc labels into disjunctive normal form with sympy.

    Despite the name, this does more than apply De Morgan's laws. On every
    line containing the substring ``"& ("``, the first double-quoted string
    is evaluated as a boolean expression over *symbols_dict*. That string is
    then replaced by ``str(to_dnf(expr, simplify=True))``. All other lines
    are returned unchanged.

    Args:
        dfa_text: newline-separated DOT arc lines, e.g.
            ``1 -> 2 [label="a & (b | c)"];``.
        symbols_dict: maps every proposition name that may appear in a label
            to the object used for it (sympy symbols in create_automata).

    Raises:
        NameError: if a rewritten label references a name missing from
            *symbols_dict*. Python builtins are deliberately unavailable.
    """
    # NOTE(review): eval() on label text. The text comes from ltlf2dfa's DOT
    # output for a caller-supplied formula. An empty globals dict would still
    # expose every builtin (Python injects __builtins__), so they are removed
    # explicitly. This is hardening, not a sandbox: do not feed it untrusted text.
    restricted_globals = {"__builtins__": {}}

    new_lines = []
    for line in dfa_text.split("\n"):
        if "& (" in line:
            matches = re.findall(r'"([^"]+)"', line)
            parsed_expr = eval(matches[0], restricted_globals, symbols_dict)
            new_condition = str(to_dnf(parsed_expr, simplify=True))
            new_lines.append(line.replace(matches[0], new_condition))
        else:
            new_lines.append(line)

    return "\n".join(new_lines)

def min_cost_to_accepting(edges, accepting_states):
    """Return, for each state, the fewest arcs needed to reach an accepting state.

    A breadth-first search runs backwards from *accepting_states* over the
    reversed *edges*. Accepting states always map to 0, even if they appear
    in no edge. States that appear in *edges* but cannot reach an accepting
    state map to ``float('inf')``.

    Args:
        edges: iterable of ``(source, target)`` pairs.
        accepting_states: iterable of accepting state ids.
    """
    predecessors = {}
    seen_states = set()
    for src, dst in edges:
        seen_states.update([src, dst])
        predecessors.setdefault(dst, []).append(src)

    distances = {state: 0 for state in accepting_states}
    frontier = deque(accepting_states)
    while frontier:
        current = frontier.popleft()
        next_cost = distances[current] + 1
        for pred in predecessors.get(current, ()):
            if pred not in distances:
                distances[pred] = next_cost
                frontier.append(pred)

    for state in seen_states:
        distances.setdefault(state, float('inf'))

    return distances

def create_automata(formula_str, labels):
    """Build a label-indexed transition table for the DFA of an LTLf formula.

    Args:
        formula_str: LTLf formula text, parsed with ``LTLfParser``.
        labels: proposition names. Each becomes one column of the transition
            matrix, in this order.

    Returns:
        A dict with these keys:

        * ``"states"``: source states, one per matrix row, in row order.
        * ``"distances"``: state -> number of arcs to the nearest accepting
          state (``float('inf')`` if none is reachable).
        * ``"transition_matrix"``: one row per state. Entry ``j`` is the
          target state for column ``labels[j]``, or ``None`` if no arc from
          that state enables that column.
        * ``"accepting_states"``: accepting states read from the DOT output.
        * ``"deadlock_states"``: states whose distance is infinite.
    """
    parser = LTLfParser()
    formula = parser(formula_str)
    dfa_text = formula.to_dfa()
    symbols_dict = {name: symbols(name) for name in labels}

    # Must run before remove_useless_rows(), which drops the
    # "node [shape = doublecircle]" line listing the accepting states.
    accepting_states = extract_accepting_states(dfa_text)

    dfa_text = remove_useless_rows(dfa_text)
    dfa_text = de_morgan(dfa_text, symbols_dict)
    dfa_text = prune_and_conditions(dfa_text)
    dfa_text = clean_input_string(dfa_text)

    matrix = extract_matrix(dfa_text, labels)
    arcs = extract_arcs(dfa_text)
    distances = min_cost_to_accepting(arcs, accepting_states)

    matrix = or_internal_states(matrix)
    matrix = attach_arcs_to_matrix(arcs, matrix)
    sources = [src for src, _ in arcs]
    matrix = or_for_states(matrix, sources)

    # or_for_states emits one row per source in first-appearance order.
    # dict.fromkeys preserves that order; set() iteration order does not, which
    # could attach the wrong state id to a row.
    states = list(dict.fromkeys(sources))

    # Cells still 0 mean no arc enabled that column. The original indexed
    # them as tuples and raised TypeError; they are reported as None instead.
    matrix = [
        [cell[1] if isinstance(cell, tuple) else None for cell in row]
        for row in matrix
    ]
    deadlock_states = [k for k, v in distances.items() if v == float('inf')]

    return {
        "states": states,
        "distances": distances,
        "transition_matrix": matrix,
        "accepting_states": accepting_states,
        "deadlock_states": deadlock_states
    }

def print_automata(automata, labels):
    """Print the automaton as a table: one row per state, one column per label.

    Args:
        automata: dict shaped like the result of create_automata(), or the
            same data loaded back with import_automata(). Its distance keys
            may then be strings.
        labels: column headers, in the same order as the matrix columns.
    """
    # Header widths must match the row format below (16 / 10) so the "|"
    # separators line up; the original header used 15 for the state column.
    header = f"{'state':<16}" + f"{'distance':<10}| " + "".join(f"{label:<13}" for label in labels)
    print(header)
    print("-" * len(header))

    accepting = automata.get("accepting_states", [])
    deadlocks = automata.get("deadlock_states", [])
    distances = automata.get("distances")

    for state, row in zip(automata["states"], automata["transition_matrix"]):
        if state in accepting:
            state_info = "(Accepting)"
        elif state in deadlocks:
            state_info = "(Deadlock)"
        else:
            state_info = ""

        row_str = "".join(f"{val}".ljust(13) for val in row)

        if distances is None:
            distance = ""
        else:
            # Look up by state id, not row position (ids need not be 1..n).
            # JSON round-trips turn int keys into str, so try both forms.
            distance = distances.get(state, distances.get(str(state), ""))

        print(f"{str(state) + state_info:<16}{str(distance):<10}| {row_str}")

def save_automata(automata, filename='automata.json'):
    """Write *automata* to *filename* as JSON indented by 4 spaces.

    Missing parent directories are created first, and a confirmation line is
    printed. Standard ``json`` behavior applies: non-string dict keys (such
    as int state ids) are written as strings, and ``float('inf')`` is written
    as the non-standard token ``Infinity``.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, 'w') as out_file:
        json.dump(automata, out_file, indent=4)

    print(f"Automata saved to {filename}")

def import_automata(filename):
    """Load and return the JSON object stored in *filename*.

    Intended for files written by save_automata(). Dict keys that were ints
    when saved come back as strings (standard ``json`` behavior).
    """
    with open(filename, "r") as source:
        return json.load(source)