import itertools
import numpy as np
from joblib import Parallel, delayed

from multiprocessing import Value, Lock
import multiprocessing as mp
from tqdm import tqdm
from .GraphUtils import generate_DMG, calculate_E_separation_triple
from collections import defaultdict

# Process-shared primitives created at import time: a plain lock and an
# integer counter ('i' typecode) starting at 0. Neither is used by the code
# visible in this part of the module.
lock = Lock()
counter = Value('i', 0)

def parallel_process_with_joblib(directed_matrices, symmetric_matrices, verbose = False, num_workers=4):
    """Compute independence triples for every (directed, symmetric) matrix pair in parallel.

    Args:
        directed_matrices: iterable of directed matrices (may be a one-shot
            generator such as ``generate_directed_matrices``).
        symmetric_matrices: iterable of symmetric matrices (may also be a
            one-shot generator); every directed matrix is paired with all of them.
        verbose: when truthy, print the number of pairs before starting.
        num_workers: forwarded to joblib ``Parallel`` as ``n_jobs``.

    Returns:
        A list with one ``process_single_matrix_pair`` result per pair, in
        cartesian-product order (directed matrices outer, symmetric inner).
    """
    # itertools.product materializes both inputs before iterating. The previous
    # nested comprehension re-iterated `symmetric_matrices` for each directed
    # matrix, so a generator was exhausted after the first one and all later
    # pairs were silently dropped.
    matrix_pairs = list(itertools.product(directed_matrices, symmetric_matrices))
    total_tasks = len(matrix_pairs)
    if verbose:
        print(f"Processing {total_tasks} matrix pairs")

    # Run the work exactly once (it used to be computed twice). The progress bar
    # wraps the task iterable, so it advances as joblib pulls tasks; joblib may
    # pre-dispatch ahead of completion, so this tracks dispatch, not exact
    # completion. The context manager guarantees the bar is closed.
    with tqdm(matrix_pairs, total=total_tasks, desc="Processing matrix pairs") as progress:
        configurations = Parallel(n_jobs=num_workers)(
            delayed(process_single_matrix_pair)(directed_matrix, symmetric_matrix)
            for directed_matrix, symmetric_matrix in progress
        )
    return configurations

def convert_independence_triples(independence_triple):
    """Normalize independence triples into a canonical, sorted list.

    Each input entry is reduced to ``(entry[0], entry[1], tuple(entry[2]))``
    (the third component becomes a tuple so the entry can be hashed, assuming
    the first two components are hashable), and the list is sorted so that
    equal sets of triples compare equal. Callers wrap the result in ``tuple()``
    when they need a dictionary key.
    """
    normalized = [
        (entry[0], entry[1], tuple(entry[2]))
        for entry in independence_triple
    ]
    normalized.sort()
    return normalized


def process_single_matrix_pair(directed_matrix, symmetric_matrix):
    """Build the DMG for one matrix pair and record its independence triples.

    Returns a dict containing the two input matrices (unchanged) under
    'directed_matrix' / 'symmetric_matrix', and the graph's E-separation
    triples normalized by ``convert_independence_triples`` under
    'independence_triples'.
    """
    dmg = generate_DMG(directed_matrix, symmetric_matrix, graph_type="DMG")
    raw_triples = calculate_E_separation_triple(dmg)
    result = dict(
        directed_matrix=directed_matrix,
        symmetric_matrix=symmetric_matrix,
        independence_triples=convert_independence_triples(raw_triples),
    )
    return result

# Check
def generate_directed_matrices(n: int):
    """Yield every n x n matrix whose entries are 0 or 1 (2**(n*n) in total).

    Diagonal entries are enumerated too. Matrices come out in
    ``itertools.product`` order over the row-major flattened entries, so the
    first is all zeros and the last is all ones.
    """
    shape = (n, n)
    cell_count = n * n
    for flat_bits in itertools.product([0, 1], repeat=cell_count):
        grid = np.array(flat_bits)
        yield grid.reshape(shape)

# Check
def generate_symmetric_matrices(n: int):
    """Yield every n x n symmetric 0/1 matrix with a zero diagonal.

    The strict upper triangle takes all 2**(n*(n-1)/2) assignments in
    ``itertools.product`` order (row-major over ``np.triu_indices(n, k=1)``),
    and each assignment is mirrored into the lower triangle. Every yielded
    array is a fresh ``int`` array, so callers may mutate it safely.
    """
    # Loop-invariant: the strict-upper-triangle coordinates depend only on n,
    # so compute them once instead of once per yielded matrix.
    upper_triangle_indices = np.triu_indices(n, k=1)
    for upper_triangle_entries in itertools.product([0, 1], repeat=n * (n - 1) // 2):
        upper = np.zeros((n, n), dtype=int)
        upper[upper_triangle_indices] = upper_triangle_entries
        # Mirror out of place instead of `upper += upper.T`, which would read and
        # write overlapping memory (relying on NumPy's overlap handling).
        yield upper + upper.T


# Helper function to group a subset of configurations
def group_configurations_chunk(config_chunk):
    """Group one chunk of configurations by their independence triples.

    Returns a ``defaultdict(list)`` mapping ``tuple(config['independence_triples'])``
    to a list of ``{'directed_matrix': ..., 'symmetric_matrix': ...}`` dicts,
    kept in the chunk's original order.
    """
    kept_fields = ('directed_matrix', 'symmetric_matrix')
    groups = defaultdict(list)
    for cfg in config_chunk:
        group_key = tuple(cfg['independence_triples'])
        matrices = {field: cfg[field] for field in kept_fields}
        groups[group_key].append(matrices)
    return groups

# Function to merge grouped configurations from different chunks
def merge_grouped_configurations(*partial_groupings):
    """Combine several ``key -> list`` groupings into one ``defaultdict(list)``.

    Lists sharing a key are concatenated in argument order. The input lists
    are copied into fresh lists, so the inputs are left unmodified.
    """
    combined = defaultdict(list)
    for partial in partial_groupings:
        for group_key, members in partial.items():
            combined[group_key] += members
    return combined

def parallel_group_configurations(configurations, num_workers=4):
    """Group configurations by their independence triples using joblib workers.

    Args:
        configurations: a sequence (supports ``len`` and slicing) of dicts with
            'directed_matrix', 'symmetric_matrix' and 'independence_triples' keys.
        num_workers: forwarded to joblib ``Parallel`` as ``n_jobs``. Non-positive
            or ``None`` values (joblib conventions such as -1) are accepted; the
            data is then split into ``os.cpu_count()`` chunks.

    Returns:
        A ``defaultdict(list)`` mapping ``tuple(independence_triples)`` to the
        list of ``{'directed_matrix', 'symmetric_matrix'}`` dicts in that group.
        An empty input yields an empty ``defaultdict(list)``.
    """
    import os  # only needed to size chunks for joblib-style worker counts

    total = len(configurations)
    if total == 0:
        return defaultdict(list)

    if isinstance(num_workers, int) and num_workers > 0:
        n_chunks = num_workers
    else:
        n_chunks = os.cpu_count() or 1

    # Ceiling division keeps chunk_size >= 1. Floor division produced 0 (a
    # range() step of 0 raises ValueError) whenever total < num_workers, and a
    # negative step (silently no chunks) for num_workers=-1.
    chunk_size = (total + n_chunks - 1) // n_chunks
    config_chunks = [configurations[i:i + chunk_size] for i in range(0, total, chunk_size)]

    # Use joblib to parallelize the grouping process
    grouped_chunks = Parallel(n_jobs=num_workers)(
        delayed(group_configurations_chunk)(chunk) for chunk in config_chunks
    )

    # Merge all the partial groupings into a single dictionary
    return merge_grouped_configurations(*grouped_chunks)

def is_maximal_matrix_pair(directed_matrix, symmetric_matrix, all_directed_matrices, all_symmetric_matrices, verbose=False):
    """Return True if the pair is entrywise >= every pair in the given lists.

    ``all_directed_matrices`` and ``all_symmetric_matrices`` are walked in
    lockstep (via ``zip``). The pair qualifies only if ``directed_matrix`` is
    >= each directed matrix AND ``symmetric_matrix`` is >= the paired symmetric
    matrix, elementwise. Empty lists trivially qualify. With ``verbose``, the
    qualifying pair is printed.
    """
    dominates_every_pair = all(
        np.all(directed_matrix >= rival_directed) and np.all(symmetric_matrix >= rival_symmetric)
        for rival_directed, rival_symmetric in zip(all_directed_matrices, all_symmetric_matrices)
    )
    if not dominates_every_pair:
        return False
    if verbose:
        print(f"Maximal matrix pair found:\nDirected matrix:\n{directed_matrix}\nSymmetric matrix:\n{symmetric_matrix}")
    return True


def find_maximal_matrix_pair(configs):
    """Return the (directed, symmetric) pair that dominates all configs, or None.

    A pair dominates when its directed matrix is entrywise >= every directed
    matrix and its symmetric matrix is entrywise >= every symmetric matrix in
    ``configs`` (the criterion of ``is_maximal_matrix_pair``). Such a pair must
    equal the entrywise maxima, so both maxima are computed once (O(n)) instead
    of comparing every pair with every other pair (O(n^2)). If several configs
    match, the first one in input order is returned. Returns None when
    ``configs`` is empty or no single pair dominates.

    Assumes all directed matrices share one shape, and likewise all symmetric ones.
    """
    configs = list(configs)  # iterated more than once below; accept generators too
    if not configs:
        return None
    all_directed_matrices = [config['directed_matrix'] for config in configs]
    all_symmetric_matrices = [config['symmetric_matrix'] for config in configs]
    max_directed = np.maximum.reduce(all_directed_matrices)
    max_symmetric = np.maximum.reduce(all_symmetric_matrices)
    for directed_matrix, symmetric_matrix in zip(all_directed_matrices, all_symmetric_matrices):
        if np.array_equal(directed_matrix, max_directed) and np.array_equal(symmetric_matrix, max_symmetric):
            return (directed_matrix, symmetric_matrix)
    return None