import os
import numpy as np
import json
from absl import app, flags, logging
from time import time

from joblib import Parallel, delayed
from multiprocessing import Value, Lock
from tqdm import tqdm

from graph_modelling.Utils.GraphUtils import generate_DMG, calculate_E_separation_triple
from graph_modelling.Utils.Utils import generate_directed_matrices, generate_symmetric_matrices, parallel_process_with_joblib, parallel_group_configurations, find_maximal_matrix_pair

# Command-line configuration for the graph-enumeration experiment.
# `n_nodes` sets the size of every generated (n_nodes x n_nodes) matrix.
flags.DEFINE_integer('n_nodes', 4, 'Number of nodes per graph (size of the generated square matrices)')
flags.DEFINE_enum("graph_type", "DMG", ('DG', 'DMG'), "Graph type to enumerate ('DG' or 'DMG').")
flags.DEFINE_boolean('verbose', True, 'Whether to print detailed per-group output.')
flags.DEFINE_boolean('print_all_configurations', True, 'Whether to print every configuration belonging to each independence model (only used when --verbose is set).')
flags.DEFINE_integer('num_workers', 64, 'Number of parallel workers used for processing and grouping.')

FLAGS = flags.FLAGS

def main(_):
    """Enumerate graphs, group them by independence model, and report maximal pairs.

    Pipeline (all matrix sizes come from ``FLAGS.n_nodes``):
      1. Build the directed matrices from ``generate_directed_matrices`` and,
         for ``DMG``, the symmetric matrices from ``generate_symmetric_matrices``
         (``DG`` uses a single all-zero symmetric matrix instead). Then compute
         the independence model of the combinations in parallel.
      2. Group the resulting configurations by their independence triples.
      3. For each group, look for a maximal (directed, symmetric) matrix pair
         and collect the configurations of groups that have none.

    Args:
        _: Leftover positional command-line arguments passed by ``app.run``;
            unused.
    """
    from time import perf_counter  # monotonic clock, suited to measuring durations

    print("Running main")
    n_nodes = FLAGS.n_nodes
    exp_name = f"{FLAGS.graph_type}_{n_nodes}"
    logging.info("Running experiment: %s", exp_name)

    start_time = perf_counter()
    directed_matrices = list(generate_directed_matrices(n_nodes))
    if FLAGS.graph_type == "DG":
        # DG uses only a single all-zero symmetric matrix, so skip enumerating
        # symmetric matrices entirely; the original discarded them anyway.
        # NOTE(review): np.zeros is float64 -- confirm this matches the dtype
        # the downstream helpers expect from generate_symmetric_matrices.
        symmetric_matrices = [np.zeros((n_nodes, n_nodes))]
    else:
        symmetric_matrices = list(generate_symmetric_matrices(n_nodes))

    print("Step 1: Calculate all independence models")
    configurations = parallel_process_with_joblib(
        directed_matrices, symmetric_matrices, num_workers=FLAGS.num_workers
    )

    print("Step 2: Group graphs to independence models")
    grouped_configurations = parallel_group_configurations(
        configurations, num_workers=FLAGS.num_workers
    )

    print("Step 3: Find maximal matrix pairs among grouped independence-model-graphs")
    verbose = FLAGS.verbose
    print_all = FLAGS.print_all_configurations
    configs_without_maximal_pairs = []
    for triple_key, configs in grouped_configurations.items():
        if verbose:
            print(f"Independence triples: {triple_key}")
            print(f"Configurations with these triples ({len(configs)} configurations):")
            if print_all:
                for config in configs:
                    print(f"Directed matrix:\n{config['directed_matrix']}")
                    print(f"Symmetric matrix:\n{config['symmetric_matrix']}")

        maximal_pair = find_maximal_matrix_pair(configs)
        if maximal_pair is None:
            # Missing maximal pairs are always reported, even without --verbose.
            print("No maximal matrix pair found")
            configs_without_maximal_pairs.extend(configs)
        elif verbose:
            print(f"Maximal matrix pair found:\nDirected matrix:\n{maximal_pair[0]}\nSymmetric matrix:\n{maximal_pair[1]}")

    elapsed = perf_counter() - start_time
    print(f"Time taken: {elapsed}")
    logging.info(
        "%d independence-model groups; %d configurations in groups without a maximal pair",
        len(grouped_configurations),
        len(configs_without_maximal_pairs),
    )
    # TODO: persist configs_without_maximal_pairs (e.g. to f"results/{exp_name}.json").
    # The configurations may contain numpy values, so json.dump would need a
    # `default=` hook that converts them (np.integer -> int, np.ndarray -> list)
    # and raises TypeError with a message for anything else.

if __name__ == "__main__":
    # Hand control to absl, which parses the defined flags before calling main().
    app.run(main)