import pickle
from itertools import combinations, islice
import torch
import os
from formal_pruning import formal_patch_query
from pysat.examples.hitman import Hitman
import logging
from models.model import neuron_idx_to_neuron_name

# Global lookup from neuron index to neuron name, built once at import time.
# Code in this module indexes it with ints, calls .get() and len() on it and
# iterates range(len(...)) over it — presumably a dict keyed 0..N-1 covering
# every neuron (TODO confirm against models.model.neuron_idx_to_neuron_name).
neuron_idx_to_name = neuron_idx_to_neuron_name()


class ContrastiveTimeoutError(Exception):
    """Signal that a contrastive / MHS computation ran out of time.

    The reason is stored on ``.message`` and also handed to ``Exception``,
    so ``str(err)`` and ``err.args`` both carry it.
    """

    def __init__(self, message="MHS computation timed out."):
        super().__init__(message)
        self.message = message

def get_neurons_by_names(neurons_subsets, idx_to_name=None):
    """
    Convert neuron indices (or subsets of indices) to their global names.

    Parameters:
    - neurons_subsets: iterable whose elements are either a single integer index
      (Python ``int`` or any ``numbers.Integral``, e.g. ``numpy.int64``) or an
      iterable of integer indices.
    - idx_to_name: optional index -> name mapping; defaults to the module-level
      ``neuron_idx_to_name``.

    Returns:
        A list with one entry per element of ``neurons_subsets``: the name for a
        scalar index, or a tuple of names for a subset.

    Raises:
        KeyError: if an index has no entry in the mapping.
    """
    import numbers  # local import: only needed for the scalar check below

    mapping = neuron_idx_to_name if idx_to_name is None else idx_to_name
    result = []
    for subset in neurons_subsets:
        # numbers.Integral also covers numpy integer scalars, which are not
        # iterable and would otherwise crash in the subset branch.
        if isinstance(subset, numbers.Integral):
            result.append(mapping[subset])
        else:
            result.append(tuple(mapping[idx] for idx in subset))
    return result


def load_pickle(path):
    """Deserialize and return the single object stored in the pickle file at *path*.

    NOTE: ``pickle.load`` can execute arbitrary code; only call this on files
    produced by this project or another trusted source.
    """
    with open(path, mode='rb') as handle:
        loaded = pickle.load(handle)
    return loaded


def check_contrastive(dataset, pruned_net, x, full_net_path, saved_dup_patch_net_path, Z_mask, device, metric, epsilon, delta, patch_eps, exp_paths, adv_x_path, query_timeout, neurons_prefix):
    """
    Run one formal patching query for ``Z_mask`` and classify the result.

    The neuron set is contrastive when ``formal_patch_query`` reports at least one
    unsafe result. An unsafe result is a concrete witness, so it wins even if other
    sub-queries timed out. If the only non-safe outcome is a timeout, the set is
    treated as non-contrastive (and ``had_timeout`` is reported).

    Parameters are forwarded unchanged to ``formal_patch_query``;
    ``neurons_prefix`` is only used in log messages.

    Returns:
        (is_contrastive, had_timeout)
    """
    is_safe, verification_res = formal_patch_query(dataset, pruned_net, x, full_net_path, saved_dup_patch_net_path, Z_mask, device, metric, epsilon, delta, patch_eps, exp_paths, adv_x_path, verify_patching_only=False, query_timeout=query_timeout, verbose=False)
    is_contrastive = verification_res['unsafe'] > 0
    had_timeout = verification_res['timeout'] > 0

    if had_timeout and not is_contrastive:
        logging.warning(f"Timeout occurred during verification for neurons {neurons_prefix}. Treating as non-contrastive.")
        return False, True  # is_contrastive, had_timeout
    if had_timeout:
        # A counterexample was found despite the timeout; it is a definitive witness.
        logging.warning(f"Timeout occurred during verification for neurons {neurons_prefix}, but an unsafe result was found. Treating as contrastive.")

    logging.info(f"neurons indices {neurons_prefix} | is_contrastive: {is_contrastive} -----  is_safe: {is_safe}, verification_res: {verification_res}")
    return is_contrastive, had_timeout


def check_circuit_contrastives(circuit_Z_mask, net, x, full_net_path, device, metric, epsilon,
                               delta, patch_eps, exp_paths, adv_x_path, query_timeout, check_within_circuit=True, verbose=False,
                               dataset="mnist"):
    """
    Check, one neuron at a time, whether each neuron with ``circuit_Z_mask[i] == 0`` is contrastive.

    For each such index ``i`` this sets ``Z_mask[i] = 1.0``, calls ``check_contrastive``
    with that mask, and then resets ``Z_mask[i]`` to 0.0 before moving on.

    Parameters:
    - circuit_Z_mask: tensor indexed by global neuron index; entries equal to 0 select
      the neurons to examine.
    - net: the network passed to the formal query (e.g. pruned_patch_net or full_net).
    - x: input tensor.
    - full_net_path: path forwarded to the formal query.
    - device: device forwarded to the formal query.
    - metric: contrastive metric.
    - epsilon, delta, patch_eps: verification parameters, forwarded unchanged.
    - exp_paths: dict of experiment paths; ``'saved_dup_patch_net_path_mnist'`` is read.
    - adv_x_path: forwarded to the formal query.
    - query_timeout: per-query timeout forwarded to the formal query.
    - check_within_circuit: if True the base mask is a copy of ``circuit_Z_mask``
      (each neuron is examined on top of the circuit); if False the base mask is all zeros.
    - verbose: log each neuron as it is examined.
    - dataset: dataset name forwarded to ``check_contrastive`` (defaults to the
      previously hard-coded ``"mnist"``).

    Returns:
        (contrastive_neurons, noncontrastive_neurons): sets of neuron names (the raw
        index is used when no name is registered). Neurons whose query timed out land
        in whichever set ``check_contrastive`` assigns them to.
    """
    contrastive_neurons, noncontrastive_neurons = set(), set()
    total_neurons = len(neuron_idx_to_name)
    Z_mask = circuit_Z_mask.clone().detach() if check_within_circuit else torch.zeros(total_neurons)
    dup_patch_net_path = exp_paths['saved_dup_patch_net_path_mnist']

    for i in range(total_neurons):
        if circuit_Z_mask[i] != 0:
            continue
        neuron_name = neuron_idx_to_name.get(i, i)
        if verbose:
            logging.info(f"Examining Neuron: {neuron_name}")

        Z_mask[i] = 1.0
        is_contrastive, _had_timeout = check_contrastive(dataset, net, x, full_net_path, dup_patch_net_path,
                                                         Z_mask, device, metric, epsilon, delta, patch_eps,
                                                         exp_paths, adv_x_path, query_timeout, neuron_name)
        (contrastive_neurons if is_contrastive else noncontrastive_neurons).add(neuron_name)
        Z_mask[i] = 0.0  # Reset the mask so the next neuron is examined in isolation

    return contrastive_neurons, noncontrastive_neurons


def check_contrastive_subset(neuron_indices, dataset, net, x, full_net_path, device, metric, epsilon, delta,
                             patch_eps, exp_paths, adv_x_path, query_timeout, verbose=False):
    """
    Check whether a set of neurons is jointly contrastive.

    Builds a mask over all neurons that is 1.0 at every index in ``neuron_indices``
    and 0 elsewhere, then delegates to ``check_contrastive`` using the
    ``'saved_dup_patch_net_path_mnist'`` entry of ``exp_paths``.

    Returns:
        (is_contrastive, had_timeout) as produced by ``check_contrastive``.
    """
    neuron_names = [neuron_idx_to_name.get(neuron_idx, neuron_idx) for neuron_idx in neuron_indices]
    if verbose:
        logging.info(f"Examining Neurons: {neuron_names}")

    subset_mask = torch.zeros(len(neuron_idx_to_name))
    for neuron_idx in neuron_indices:
        subset_mask[neuron_idx] = 1.0

    dup_patch_net_path = exp_paths['saved_dup_patch_net_path_mnist']
    is_contrastive, had_timeout = check_contrastive(
        dataset, net, x, full_net_path, dup_patch_net_path, subset_mask, device,
        metric, epsilon, delta, patch_eps, exp_paths, adv_x_path, query_timeout, neuron_names,
    )

    if verbose:
        logging.info(f"Neurons {neuron_names} contrastive?: {is_contrastive}")
    return is_contrastive, had_timeout


def _job_slice_bounds(total_work, num_jobs, job_id):
    """Return the half-open ``(start, end)`` candidate index range owned by ``job_id``.

    ``total_work`` items are split across ``num_jobs`` jobs as evenly as possible;
    the first ``total_work % num_jobs`` jobs each receive one extra item.
    """
    base_chunk_size, remainder = divmod(total_work, num_jobs)
    if job_id < remainder:
        start_index = job_id * (base_chunk_size + 1)
        end_index = start_index + base_chunk_size + 1
    else:
        start_index = remainder * (base_chunk_size + 1) + (job_id - remainder) * base_chunk_size
        end_index = start_index + base_chunk_size
    return start_index, end_index


def check_all_contrastive_subsets(subset_size, dataset, net, x, full_net_path,
                                  noncontrastive_subsets_previous_sizes, device, metric, epsilon,
                                  delta, patch_eps, exp_paths, adv_x_path, query_timeout, verbose=False,
                                  job_id=None, num_jobs=None, stop_signal_file=None):
    """
    Verify every minimal candidate neuron subset of size ``subset_size``.

    A subset (an index tuple from ``itertools.combinations``) is a candidate only if
    every non-empty proper sub-subset (sizes 1..subset_size-1) appears in
    ``noncontrastive_subsets_previous_sizes``. Each candidate is checked with
    ``check_contrastive_subset``.

    When ``job_id`` and ``num_jobs`` (> 1) are given, the candidate sequence is split
    into contiguous, near-equal chunks and only this job's chunk is processed. If
    ``stop_signal_file`` exists, processing stops before the next candidate and the
    partial results collected so far are returned.

    Returns:
        (contrastive_subsets, noncontrastive_subsets, timeout_subsets): sets of index
        tuples. A subset whose query timed out is recorded in ``timeout_subsets`` and
        also in one of the first two sets, according to the verdict it received.
    """
    logging.info(f"[checking contrastive subsets]: parameters: subset_size={subset_size}, full_net_path={full_net_path},  device={device}, metric={metric}, epsilon={epsilon}, delta={delta}, patch_eps={patch_eps}, exp_paths={exp_paths}, adv_x_path={adv_x_path}, query_timeout={query_timeout}")
    logging.info(f"[checking contrastive subsets]: noncontrastive_subsets_previous_sizes={noncontrastive_subsets_previous_sizes}")
    multiple_jobs = job_id is not None and num_jobs is not None and num_jobs > 1
    contrastive_subsets, noncontrastive_subsets, timeout_subsets = set(), set(), set()
    total_neurons = len(neuron_idx_to_name)

    noncontrastive = set(noncontrastive_subsets_previous_sizes)

    # Generator for minimal (prefiltered) subsets: skip any subset that contains a
    # smaller subset not known to be non-contrastive.
    def candidate_generator():
        for subset in combinations(range(total_neurons), subset_size):
            if all(
                smaller in noncontrastive
                for r in range(1, subset_size)
                for smaller in combinations(subset, r)
            ):
                yield subset

    if multiple_jobs:
        # Count total candidates without storing them (the generator is re-run below).
        total_work = sum(1 for _ in candidate_generator())
        start_index, end_index = _job_slice_bounds(total_work, num_jobs, job_id)
        combinations_to_check = islice(candidate_generator(), start_index, end_index)
        logging.info(f"Job {job_id}/{num_jobs}: processing {end_index - start_index} combinations from index {start_index} to {end_index} (out of {total_work} candidates).")
    else:  # Single worker: process all
        combinations_to_check = candidate_generator()
        if job_id is not None:
            total_work = sum(1 for _ in candidate_generator())
            logging.info(f"Job {job_id}/{num_jobs}: processing all {total_work} combinations.")

    for subset in combinations_to_check:
        # Check for the stop signal file before every combination.
        if stop_signal_file and os.path.exists(stop_signal_file):
            logging.warning(f"Stop signal detected by job {job_id}. Halting subset checks for this stage.")
            break

        is_contrastive, had_timeout = check_contrastive_subset(subset, dataset, net, x, full_net_path, device, metric,
                                                               epsilon, delta, patch_eps, exp_paths, adv_x_path,
                                                               query_timeout, verbose=verbose)
        if is_contrastive:
            contrastive_subsets.add(subset)
        else:
            noncontrastive_subsets.add(subset)
        if had_timeout:
            timeout_subsets.add(subset)

        if verbose:
            # .get() keeps logging from crashing on an unnamed index, matching the rest of the module.
            subset_names = tuple(neuron_idx_to_name.get(idx, idx) for idx in subset)
            logging.info(f"Subset {subset_names} contrastive?: {is_contrastive} (timeout: {had_timeout})")

    return contrastive_subsets, noncontrastive_subsets, timeout_subsets


def run_constrastives_within_circuit_experiment(circuit_Z_mask, full_net, x, device, metric, epsilon,
                                                delta, patch_eps, exp_paths, adv_x_path, query_timeout,
                                                full_net_path=None):
    """
    Log per-neuron contrastiveness for neurons with ``circuit_Z_mask == 0``, twice.

    First the neurons are checked on top of the circuit mask
    (``check_within_circuit=True``), then against an all-zero base mask
    (``check_within_circuit=False``). Results are only logged.

    ``full_net_path`` is required by ``check_circuit_contrastives``. The previous
    version never passed it, which shifted every positional argument and made the
    call fail. It is accepted as a trailing keyword to stay backward-compatible
    with the existing signature.

    Raises:
        ValueError: if ``full_net_path`` is not provided.
    """
    if full_net_path is None:
        raise ValueError("run_constrastives_within_circuit_experiment requires full_net_path "
                         "(it is forwarded to check_circuit_contrastives).")

    shared_args = (full_net_path, device, metric, epsilon, delta, patch_eps,
                   exp_paths, adv_x_path, query_timeout)

    contrastive_neurons, noncontrastive_neurons = check_circuit_contrastives(circuit_Z_mask, full_net, x, *shared_args,
                                                                             check_within_circuit=True)
    logging.info(f"Contrastive circuit neurons (checking withing circuit): {contrastive_neurons}")
    logging.info(f"Non Contrastive circuit neurons (checking withing circuit): {noncontrastive_neurons}")

    # Same neurons, but checked against an all-zero base mask instead of the circuit mask
    contrastive_neurons, noncontrastive_neurons = check_circuit_contrastives(circuit_Z_mask, full_net, x, *shared_args,
                                                                             check_within_circuit=False)
    logging.info(f"Contrastive circuit neurons (checking withing full net): {contrastive_neurons}")
    logging.info(f"Non Contrastive circuit neurons (checking withing full net) {noncontrastive_neurons}")



def get_contrastives_mhs(groups):
    """
    Compute a minimum hitting set of ``groups`` with pysat's ``Hitman``.

    Each element of ``groups`` is passed to ``Hitman.hit`` as one set that must be hit.
    The return value is whatever ``Hitman.get()`` yields (a list of elements; possibly
    None when no hitting set exists — confirm against the pysat docs).
    The underlying solver is always released, even if ``hit``/``get`` raises.
    """
    hitman = Hitman(solver='m22', htype='rc2')
    try:
        for group in groups:
            hitman.hit(group)
        return hitman.get()
    finally:
        hitman.delete()


def verify_mhs_neurons(dataset, mhs, net, x, full_net_path, device,
                       metric, epsilon, delta, patch_eps, exp_paths, adv_x_path,
                       query_timeout, verbose=True):
    """
    Check an MHS for sufficiency and, if sufficient, each MHS neuron for contrastiveness.

    Step 1 calls ``verify_mhs_sufficiency``. Any status other than ``'SUFFICIENT'``
    is returned immediately as ``(status, None, None)``.
    Step 2 builds a mask of ones with the MHS indices set to 0 and calls
    ``check_circuit_contrastives(..., check_within_circuit=True)``. That function
    examines exactly the zero entries, i.e. the MHS neurons.

    NOTE(review): ``dataset`` is forwarded only to the sufficiency check;
    ``check_circuit_contrastives`` is called without it, so it uses its own dataset setting.

    Returns:
        (sufficiency_status, contrastive_neurons, noncontrastive_neurons)
    """
    sufficiency_status = verify_mhs_sufficiency(adv_x_path, dataset, delta, device, epsilon, exp_paths,
                                                full_net_path, metric, mhs, net, patch_eps, query_timeout,
                                                x, verbose)
    if sufficiency_status != 'SUFFICIENT':
        return sufficiency_status, None, None

    logging.info("[verify_mhs_neurons] Verification passed. Checking contrastiveness within circuit.")

    mhs_excluded_mask = torch.ones(len(neuron_idx_to_name))
    for idx in mhs:
        mhs_excluded_mask[idx] = 0

    contrastive_neurons, noncontrastive_neurons = check_circuit_contrastives(
        mhs_excluded_mask, net, x, full_net_path, device, metric, epsilon, delta, patch_eps,
        exp_paths, adv_x_path, query_timeout, check_within_circuit=True, verbose=verbose,
    )
    logging.info(f"[verify_mhs_neurons] Contrastive neurons: {contrastive_neurons}")
    logging.info(f"[verify_mhs_neurons] Non-contrastive neurons: {noncontrastive_neurons}")
    return sufficiency_status, contrastive_neurons, noncontrastive_neurons

# TODO: verify timeout handling — verify_mhs_sufficiency reports 'TIMEOUT' only when no query returned an unsafe result.
def verify_mhs_sufficiency(adv_x_path, dataset, delta, device, epsilon, exp_paths, full_net_path, metric, mhs, net,
                           patch_eps, query_timeout, x, verbose):
    """
    Run one formal patching query with every neuron's mask set to 1 except the MHS neurons (set to 0).

    Returns:
        ``'INSUFFICIENT'`` if the query reports any unsafe result; otherwise
        ``'TIMEOUT'`` if it reports any timeout; otherwise ``'SUFFICIENT'``.
        An unsafe result therefore takes precedence over a timeout.
    """
    Z_mask = torch.ones(len(neuron_idx_to_name))
    for idx in mhs:
        Z_mask[idx] = 0
    if verbose:
        logging.info(f"[verify_mhs_sufficiency] Z_mask: {Z_mask}")

    dup_patch_net_path = exp_paths['saved_dup_patch_net_path_mnist']
    is_safe, verification_res = formal_patch_query(
        dataset, net, x, full_net_path, dup_patch_net_path, Z_mask,
        device, metric, epsilon, delta, patch_eps, exp_paths, adv_x_path,
        verify_patching_only=False, query_timeout=query_timeout, verbose=verbose,
    )
    logging.info(f"[verify_mhs_sufficiency] Is safe: {is_safe}")
    logging.info(f"[verify_mhs_sufficiency] Verification result: {verification_res}")

    if verification_res['unsafe'] > 0:
        status = 'INSUFFICIENT'
    elif verification_res['timeout'] > 0:
        status = 'TIMEOUT'
    else:
        status = 'SUFFICIENT'
    return status