from joblib import Parallel, delayed, parallel_config
import json
import networkx as nx
import utils.isa_utils as isa_utils
import utils.all_exceptions as E
import os
import itertools  
from tqdm import tqdm  


def _record_tactics(msg, memory_heuristics):
    """Append each tactic named in *msg* to *memory_heuristics* (mutated in place)."""
    marker = "using tactic:"
    for line in msg.split('\n'):
        # Require the full marker including the colon: a line that merely
        # mentions "using tactic" has no second split field and would
        # otherwise raise IndexError.
        if marker in line:
            memory_heuristics.append(line.split(marker)[1].strip())


def _run_check(check, stmt_a, stmt_b, checker, memory_heuristics):
    """Run one equivalence check and return ``(flag, msg)``.

    The three checker exception types handled here are reported as a failed
    check whose message is the exception text; anything else propagates.
    """
    try:
        flag, msg = check(stmt_a, stmt_b, checker, memory_heuristics)
        _record_tactics(msg, memory_heuristics)
    except (E.ThmFormatException, E.SimplifyException, E.ConcException) as e:
        flag = False
        msg = str(e)
    return flag, msg


def _cached_proof(proofs, formal_statements, i, j):
    """Return the stored ``(flag, msg)`` for statements ``i``/``j``, or None.

    A stored result matches when its two statements compare equal to the
    requested pair, in either order.
    """
    stmt_i, stmt_j = formal_statements[i], formal_statements[j]
    for (u, v), result in proofs.items():
        stmt_u, stmt_v = formal_statements[u], formal_statements[v]
        if (stmt_u == stmt_i and stmt_v == stmt_j) or (stmt_u == stmt_j and stmt_v == stmt_i):
            # A pair is only stored when no equal pair was cached before, so
            # at most one key can match and stopping here is safe.
            return result
    return None


def largest_connected_component(formal_statements, checker):
    """Group statement indices into classes of mutually equivalent statements.

    Every unordered index pair is visited. Pairs already connected in the
    graph get an edge without calling the checker. Otherwise a previously
    stored result for an equal pair of statements is reused, and failing
    that ``isa_utils.check_equivalence_simplify`` is tried, followed by
    ``isa_utils.check_equivalence`` if the first check did not succeed.
    Progress and checker messages are printed to stdout.

    Args:
        formal_statements: Sequence of statements; items are compared with
            ``==`` and passed unchanged to the ``isa_utils`` checks.
        checker: Handle forwarded to the ``isa_utils`` checks.

    Returns:
        A list of lists of indices into *formal_statements*, one per
        connected component, ordered from largest to smallest. Despite the
        function name, all components are returned, not only the largest.
    """
    n = len(formal_statements)
    G = nx.Graph()
    G.add_nodes_from(range(n))
    # Materialised as a list so tqdm can show the total number of pairs.
    pairs = list(itertools.combinations(range(n), 2))
    proofs = {}  # (i, j) -> (flag, msg) of the last check run for that pair
    memory_heuristics = []  # tactics reported by earlier checks, fed to later ones
    for i, j in tqdm(pairs):
        if nx.has_path(G, i, j):
            # Already in the same component through other statements.
            G.add_edge(i, j)
            continue

        print("="*100, flush=True)
        print(formal_statements[i], '\n', formal_statements[j])

        cached = _cached_proof(proofs, formal_statements, i, j)
        if cached is not None:
            print((i, j), 'using existing proof')
            if cached[0]:
                G.add_edge(i, j)
            continue

        # Try the simplification-based check first; fall back to the second
        # check only when the first one does not establish equivalence.
        for check in (isa_utils.check_equivalence_simplify, isa_utils.check_equivalence):
            flag, msg = _run_check(
                check, formal_statements[i], formal_statements[j], checker, memory_heuristics
            )
            proofs[(i, j)] = (flag, msg)
            print((i, j), msg)
            if flag:
                G.add_edge(i, j)
                break

    connected_components = sorted(nx.connected_components(G), key=len, reverse=True)
    return [list(c) for c in connected_components]

def process_file(idx):
    """Compute and store equivalence components for every file in one work slice.

    Starts a checker on port ``4050 + idx``, then for each path in
    ``divided_data[idx]`` loads the JSON, reads the ``k`` entries
    ``a_0 .. a_{k-1}``, and, when at least 10 of them carry both a
    ``label`` and a ``syntax`` field, writes the components returned by
    ``largest_connected_component`` back into the same file under the
    ``"prediction"`` key. Files with fewer labelled entries are left untouched.

    Args:
        idx: Index into the module-level ``divided_data`` list; also used as
            the port offset for the checker.
    """
    checker = isa_utils.start_isa(port=4050+idx)
    try:
        for file_path in divided_data[idx]:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            formal_statements = []
            labels = []
            for i in range(k):
                entry = data[f"a_{i}"]
                formal_statements.append(entry['formal problem'])
                if "label" in entry and "syntax" in entry:
                    labels.append(int(entry['label']) * int(entry['syntax']))
            # Only the number of labelled entries is used below.
            # NOTE(review): the threshold 10 equals the current value of k;
            # confirm whether it is meant to track k.
            if len(labels) < 10:
                continue  # Skip files with less than 10 labels
            largest_components = largest_connected_component(formal_statements, checker)
            print(f"File {file_path} has {len(largest_components)} largest connected components.", largest_components)
            # Normalise to plain lists so the result is JSON-serialisable.
            data["prediction"] = [list(component) for component in largest_components]
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=4)
    finally:
        # Always release the checker, even if a file fails part-way through.
        checker.exit()
    
# Number of entries (a_0 .. a_{k-1}) that process_file reads from each JSON file.
k = 10
# Root directory searched recursively for input files.
path = './batch/task_test_gpt-4/0'

# Every *.json file below `path`, in os.walk order.
json_file_paths = list(itertools.chain.from_iterable(
    (os.path.join(root, name) for name in names if name.endswith('.json'))
    for root, _subdirs, names in os.walk(path)
))

# One work slice per process_file index; everything goes into slice 0 here.
divided_data = [json_file_paths]
process_file(0)
