import sklearn.metrics
import torch
import numpy as np

from sklearn.cross_decomposition import CCA
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import LinearRegression

from crc.baselines.contrastive_crl.src.mcc import mean_corr_coef, mean_corr_coef_out_of_sample

def check_recovery(dataset, model, device):
    """Measure how well ``model`` recovers the ground-truth latents of ``dataset``.

    Draws latents ``z_gt`` with ``dataset.sample`` (truncated to 5000 rows), maps
    them to observations with ``dataset.f`` and encodes those with ``model.get_z``.

    Reported metrics:

    * ``R2_Z``: R^2 of a linear regression from predicted to true latents.
    * ``R2_IN``: R^2 of a linear regression from the flattened observations to the
      true latents (only when observations have fewer than 1000 features).
    * The ``compute_mccs`` scores on the raw predictions, on sign-matched
      predictions (suffix ``_sign_matched``) and on absolute values (suffix
      ``_abs``). These are skipped with a message if the predictions contain NaNs.

    The model is switched to eval mode for the evaluation. If it was in training
    mode before the call, it is put back into training mode afterwards.

    :param dataset: object providing ``sample(n, repeat_obs_samples=...)`` and ``f``
    :param model: model providing ``get_z``
    :param device: torch device used for the forward pass
    :return: dict mapping metric name to value
    """
    results = dict()
    samples_for_evaluation = 5000
    was_training = getattr(model, 'training', True)
    model.eval()
    model.to(device)
    z_gt, _, _ = dataset.sample(samples_for_evaluation, repeat_obs_samples=False)
    z_gt = z_gt[: samples_for_evaluation]

    # Pure inference: no autograd graph is needed for this large forward pass.
    with torch.no_grad():
        x_gt = dataset.f(torch.tensor(z_gt, dtype=torch.float)).to(device)
        z_pred = model.get_z(x_gt).cpu().detach().numpy()
    x_gt = x_gt.reshape(x_gt.size(0), -1).cpu().detach().numpy()
    regressor = LinearRegression().fit(z_pred, z_gt)
    results['R2_Z'] = regressor.score(z_pred, z_gt)

    # Regressing directly from observations is only done for low-dimensional inputs.
    if x_gt.size / x_gt.shape[0] < 1000:
        regressor = LinearRegression().fit(x_gt, z_gt)
        results['R2_IN'] = regressor.score(x_gt, z_gt)
    # Multiply each sample's whole prediction vector by sign(z_pred[:, 0]) * sign(z_gt[:, 0]),
    # so its first coordinate takes the sign of the first true coordinate
    # (rows where either first coordinate is exactly 0 become all zeros).
    z_pred_sign_matched = z_pred * np.sign(z_pred)[:, 0:1] * np.sign(z_gt)[:, 0:1]
    if not np.isnan(z_pred).any():
        mccs = compute_mccs(z_gt, z_pred)
        mccs_sign_matched = compute_mccs(z_gt, z_pred_sign_matched)
        mccs_abs = compute_mccs(np.abs(z_gt), np.abs(z_pred))
        for k in mccs:
            results[k] = mccs[k]
            results[k + '_sign_matched'] = mccs_sign_matched[k]
            results[k + '_abs'] = mccs_abs[k]
    else:
        print('Model predicted NANs, skipping mccs evaluation')

    # Restore the caller's mode instead of unconditionally switching to training mode.
    if was_training:
        model.train()
    return results


def get_R2_values(x_obs, pred_obs):
    """Return the per-column R^2 of explaining ``x_obs`` affinely from ``pred_obs``.

    Both inputs are centered column-wise. For every column, ``x`` is regressed on
    ``pred`` with a single least-squares slope. The result is
    ``1 - mean(residual^2) / mean(x_centered^2)``.

    :param x_obs: tensor of shape (n_samples, d)
    :param pred_obs: tensor broadcastable to ``x_obs`` (same rows)
    :return: tensor of shape (d,) with one R^2 value per column
    """
    centered_pred = pred_obs - pred_obs.mean(dim=0, keepdim=True)
    centered_x = x_obs - x_obs.mean(dim=0, keepdim=True)

    numerator = (centered_x * centered_pred).sum(dim=0, keepdim=True)
    denominator = (centered_pred * centered_pred).sum(dim=0, keepdim=True)
    slope = numerator / denominator

    residual = centered_x - centered_pred * slope
    unexplained = (residual ** 2).mean(dim=0)
    total = (centered_x ** 2).mean(dim=0)
    return 1 - unexplained / total

def compute_mccs(x, y, return_assignments=False):
    """Compute in-sample and out-of-sample MCC scores between ``x`` and ``y``.

    The rows are split into two halves of ``len(x) // 2`` samples each; a trailing
    odd sample is ignored. The first half is used for fitting (``*_in`` scores)
    and the second half for held-out evaluation (``*_out`` scores).

    * ``mcc_s_*``: ``mean_corr_coef`` / ``mean_corr_coef_out_of_sample`` applied
      directly to ``x`` and ``y``.
    * ``mcc_w_*``: ``mean_corr_coef`` applied to the canonical variates of a CCA
      fitted on the first half.

    :param x: array of shape (n_samples, d_x), e.g. ground-truth latents
    :param y: array of shape (n_samples, d_y), e.g. predicted latents
    :param return_assignments: also return the matchings reported by the MCC
        helpers (``assignment_*`` keys) and the fitted CCA model (``cca_model``)
    :return: dict with keys ``mcc_s_in``, ``mcc_s_out``, ``mcc_w_in``, ``mcc_w_out``
        (plus the assignment entries when requested)
    """
    cutoff = len(x) // 2
    ii, iinot = np.arange(cutoff), np.arange(cutoff, 2 * cutoff)

    if return_assignments:
        mcc_s_in, assignment_s_in = mean_corr_coef(x=x[ii], y=y[ii], return_assignment=True)
        mcc_s_out, assignment_s_out = mean_corr_coef_out_of_sample(
            x=x[ii], y=y[ii], x_test=x[iinot], y_test=y[iinot], return_assignment=True)
    else:
        mcc_s_in = mean_corr_coef(x=x[ii], y=y[ii])
        mcc_s_out = mean_corr_coef_out_of_sample(x=x[ii], y=y[ii], x_test=x[iinot], y_test=y[iinot])

    # sklearn's CCA rejects n_components larger than min(n_samples, d_x, d_y).
    # Bound the (at most 5) components by both feature dimensions and the fit size.
    cca_dim = min(5, x.shape[1], y.shape[1], cutoff)
    cca = CCA(n_components=cca_dim, max_iter=5000)
    cca.fit(x[ii], y[ii])
    res_out = cca.transform(x[iinot], y[iinot])
    res_in = cca.transform(x[ii], y[ii])

    if return_assignments:
        mcc_w_out, assignment_w_out = mean_corr_coef(res_out[0], res_out[1], return_assignment=True)
        mcc_w_in, assignment_w_in = mean_corr_coef(res_in[0], res_in[1], return_assignment=True)
    else:
        mcc_w_out = mean_corr_coef(res_out[0], res_out[1])
        mcc_w_in = mean_corr_coef(res_in[0], res_in[1])

    result = {
        "mcc_s_in": mcc_s_in,
        "mcc_s_out": mcc_s_out,
        "mcc_w_in": mcc_w_in,
        "mcc_w_out": mcc_w_out
    }

    if return_assignments:
        result.update({
            "assignment_s_in": assignment_s_in,
            "assignment_s_out": assignment_s_out,
            "assignment_w_in": assignment_w_in,
            "assignment_w_out": assignment_w_out,
            "cca_model": cca
        })

    return result


def print_cor_coef(x_obs, pred_obs):
    """Print the Pearson correlation matrix between the columns of two sample matrices.

    Entry ``[i, j]`` is the correlation between column ``i`` of ``x_obs`` and
    column ``j`` of ``pred_obs``. The inputs may have different numbers of
    columns. The matrix is printed and also returned.

    :param x_obs: tensor of shape (n_samples, d_x)
    :param pred_obs: tensor of shape (n_samples, d_pred)
    :return: tensor of shape (d_x, d_pred) holding the correlations
    """
    with torch.no_grad():
        x_c = x_obs - torch.mean(x_obs, dim=0, keepdim=True)
        p_c = pred_obs - torch.mean(pred_obs, dim=0, keepdim=True)
        n_samples = x_c.shape[0]
        # Covariance and standard deviations use the same 1/n normalization,
        # so a column correlated with itself gives exactly 1.
        cov = x_c.transpose(0, 1) @ p_c / n_samples
        std_x = torch.sqrt(torch.mean(x_c ** 2, dim=0))
        std_p = torch.sqrt(torch.mean(p_c ** 2, dim=0))
        # Rows are indexed by x columns, columns by pred columns.
        corr = cov / (std_x.view(-1, 1) * std_p.view(1, -1))
    print(corr)
    return corr


def evaluate_graph_metrics(W_true, W, thresh=.3, nr_edges=1):
    """Compare an estimated weighted adjacency matrix against the ground truth.

    Self-loops are ignored: the diagonal of a *copy* of ``W`` is zeroed, so the
    caller's array is left untouched. ``W`` is binarized in three ways and each
    result is scored with ``evaluate_graph_metrics_for_thresholded_graphs``:

    * fixed cut-off ``thresh`` -> keys ``SHD``, ``TPR``, ``FPR``, ``FDR``;
    * SHD-minimizing cut-off from ``get_opt_thresh`` -> ``*_opt`` and ``opt_thresh``;
    * cut-off from ``get_edge_threshold(W, nr_edges)`` -> ``*_edge_matched`` and
      ``edge_thresh``.

    It also stores ``auroc`` from ``get_auroc`` and the results of
    ``evaluate_graph_metrics_with_threshold_optimization`` (``*_optimized``,
    ``best_threshold``, ``best_permutation``, ``optimization_method``). If that
    optimization raises, placeholder values are stored and a warning is emitted.

    :param W_true: true weighted adjacency matrix (edges are ``|w| > 0.01``)
    :param W: estimated weighted adjacency matrix
    :param thresh: fixed cut-off for the base metrics
    :param nr_edges: number of entries kept by the edge-matched cut-off
    :return: dict of metrics
    """
    import warnings

    B_true = np.where(np.abs(W_true) > .01, 1, 0)
    # Work on a copy so that zeroing the diagonal does not mutate the caller's matrix.
    W = np.array(W, copy=True)
    np.fill_diagonal(W, 0.)

    def _metrics_at(cutoff):
        # Binarize |W| > cutoff (no self-loops) and score it against B_true.
        B = np.where(np.abs(W) > cutoff, 1, 0)
        np.fill_diagonal(B, 0)
        return evaluate_graph_metrics_for_thresholded_graphs(B_true, B)

    loss_dict = _metrics_at(thresh)

    opt_threshold = get_opt_thresh(B_true, W)
    loss_dict_opt = _metrics_at(opt_threshold)
    for key in ('SHD', 'FPR', 'FDR', 'TPR'):
        loss_dict[key + '_opt'] = loss_dict_opt[key]
    loss_dict['opt_thresh'] = opt_threshold

    edge_threshold = get_edge_threshold(W, nr_edges)
    loss_dict_edge = _metrics_at(edge_threshold)
    for key in ('SHD', 'FPR', 'FDR', 'TPR'):
        loss_dict[key + '_edge_matched'] = loss_dict_edge[key]
    # Report the cut-off actually used for the edge-matched metrics.
    loss_dict['edge_thresh'] = edge_threshold
    loss_dict['auroc'] = get_auroc(B_true, W)

    optimized_keys = (
        'SHD_optimized', 'FPR_optimized', 'FDR_optimized', 'TPR_optimized',
        'AUROC_optimized', 'best_threshold', 'best_permutation', 'optimization_method',
    )
    try:
        optimization_results = evaluate_graph_metrics_with_threshold_optimization(
            W_true, W, max_permutations=1000
        )
        optimized = {key: optimization_results[key] for key in optimized_keys}
    except Exception as e:
        # Best-effort extra metrics: keep the base metrics, but make the failure visible.
        warnings.warn(
            f"Threshold/permutation optimization failed ({type(e).__name__}: {e}); "
            "storing placeholder metrics"
        )
        optimized = {
            'SHD_optimized': float('inf'),
            'FPR_optimized': 1.0,
            'FDR_optimized': 1.0,
            'TPR_optimized': 0.0,
            'AUROC_optimized': 0.0,
            'best_threshold': -1,
            'best_permutation': [],
            'optimization_method': 'failed'
        }
    loss_dict.update(optimized)

    return loss_dict


def get_edge_threshold(W, nr_edges):
    """Return a cut-off such that ``np.abs(W) > cutoff`` keeps the ``nr_edges`` largest entries.

    The cut-off is the ``(nr_edges + 1)``-th largest absolute entry of ``W``
    (all entries, diagonal included). Because the comparison is strict, ties at
    the cut-off mean that fewer than ``nr_edges`` entries may survive. If
    ``nr_edges`` is at least the number of entries, ``-inf`` is returned so
    that every entry is kept.

    :param W: weighted adjacency matrix
    :param nr_edges: number of entries to keep (non-negative)
    :return: the cut-off value
    :raises ValueError: if ``nr_edges`` is negative
    """
    if nr_edges < 0:
        raise ValueError(f"nr_edges must be non-negative, got {nr_edges}")
    W_abs = np.abs(W).reshape(-1)
    if nr_edges >= W_abs.size:
        return -np.inf
    # Index of the (nr_edges + 1)-th largest value in ascending order; partition is O(n).
    kth = W_abs.size - nr_edges - 1
    return np.partition(W_abs, kth)[kth]

def get_auroc(B_true, W):
    """Return the AUROC of ``|W|`` as a score for the true edges ``B_true``.

    Diagonal entries (self-loops) are excluded from both inputs. The AUROC is
    undefined when the off-diagonal labels contain a single class (no true edges,
    or every off-diagonal entry is an edge). In that case ``-1`` is returned.

    :param B_true: binary true adjacency matrix
    :param W: estimated weighted adjacency matrix of the same shape
    :return: AUROC, or -1 if undefined
    """
    labels = B_true[~np.eye(B_true.shape[0], dtype=bool)].reshape(-1)
    scores = np.abs(W)[~np.eye(W.shape[0], dtype=bool)].reshape(-1)
    # roc_auc_score raises for single-class labels; report -1 for both degenerate cases.
    if np.sum(labels) < 1 or np.unique(labels).size < 2:
        return -1
    return roc_auc_score(labels, scores)


def evaluate_graph_metrics_for_thresholded_graphs(B_true, B):
    """Compare two binary adjacency matrices.

    :param B_true: binary true adjacency matrix (assumed to have a zero diagonal;
        the FPR denominator subtracts the ``d`` diagonal entries)
    :param B: binary estimated adjacency matrix of the same shape
    :return: dict with ``SHD`` (number of differing entries), ``TPR`` (1 if there
        are no true edges), ``FPR`` (0 if there are no true non-edges) and
        ``FDR`` (0 if ``B`` has no edges)
    """
    loss_dict = dict()
    loss_dict['SHD'] = np.sum(np.abs(B_true - B))
    if np.sum(B_true) < 1:
        loss_dict['TPR'] = 1
    else:
        loss_dict['TPR'] = np.sum(B_true * B) / np.sum(B_true)
    # Off-diagonal non-edges of the true graph. This count is zero for a complete
    # graph (or d == 1); guard the division instead of returning NaN.
    n_negatives = np.sum(1 - B_true) - B_true.shape[0]
    if n_negatives < 1:
        loss_dict['FPR'] = 0
    else:
        loss_dict['FPR'] = np.sum((1 - B_true) * B) / n_negatives
    if np.sum(B) < 1:
        loss_dict['FDR'] = 0
    else:
        loss_dict['FDR'] = np.sum((1 - B_true) * B) / np.sum(B)
    return loss_dict


def get_opt_thresh(B_true, W):
    """Return the cut-off in ``[0, 2)`` (step 0.005) that minimizes the SHD.

    For each candidate ``t``, ``W`` is binarized as ``|W| > t`` and compared with
    ``B_true`` entry-wise. When several candidates reach the minimum SHD, the
    smallest one is returned.

    :param B_true: binary true adjacency matrix
    :param W: estimated weighted adjacency matrix of the same shape
    :return: the SHD-minimizing cut-off
    """
    candidates = np.arange(0., 2., .005)
    magnitudes = np.abs(W)
    shd_per_candidate = [
        np.sum(np.abs(B_true - np.where(magnitudes > cutoff, 1, 0)))
        for cutoff in candidates
    ]
    # argmin returns the first minimum, i.e. the smallest cut-off among ties.
    return candidates[int(np.argmin(shd_per_candidate))]





class LossCollection:
    """Accumulate batch-size-weighted losses and report their running means."""

    def __init__(self):
        self.loss_dict = dict()  # loss name -> running sum of bs * value
        self.steps = 0  # total number of samples added (sum of bs)

    def add_loss(self, loss_updates, bs):
        """Add ``bs * value`` for each ``(name, value)`` in ``loss_updates`` and count ``bs`` samples."""
        for name, batch_value in loss_updates.items():
            weighted = bs * batch_value
            if name not in self.loss_dict:
                self.loss_dict[name] = weighted
            else:
                self.loss_dict[name] = self.loss_dict[name] + weighted
        self.steps += bs

    def reset(self):
        """Forget all accumulated losses."""
        self.loss_dict = dict()
        self.steps = 0

    def get_mean_loss(self):
        """Return a dict mapping every loss name to its sample-weighted mean."""
        return {name: total / self.steps for name, total in self.loss_dict.items()}

    def print_mean_loss(self):
        """Print the current mean losses."""
        print("Current mean losses are:")
        print(self.get_mean_loss())


def evaluate_denoising_performance(dataset, model, device, noise_level=0.1, noise_type='gaussian'):
    """Evaluate denoising performance of the model

    Draws ground-truth latents from ``dataset`` and renders them with
    ``dataset.f``. The observations are then corrupted with additive noise. The
    model's latents for the noisy inputs (``z_denoised``) are compared with its
    latents for the clean inputs and with the ground truth.

    :param dataset: object providing ``obs``, ``sample(n, repeat_obs_samples=...)`` and ``f``
    :param model: model providing ``get_z`` (optionally with a ``denoise`` flag),
        or a callable used directly when ``get_z`` is missing
    :param device: torch device for the forward passes
    :param noise_level: noise scale (std for ``'gaussian'``, half-width for ``'uniform'``)
    :param noise_type: ``'gaussian'`` or ``'uniform'``; any other value adds no noise
    :return: dict with ``denoising_mse``, ``denoising_mae``, ``denoising_correlation``,
        ``snr_improvement`` (only if the model has a ``decoder``), ``R2_denoised``
        and the ``compute_mccs`` scores with suffix ``_denoised``
    """
    from sklearn.metrics import mean_squared_error, mean_absolute_error

    results = dict()
    samples_for_evaluation = min(5000, len(dataset.obs))

    was_training = getattr(model, 'training', True)
    model.eval()
    model.to(device)

    # Get clean ground truth
    z_gt, _, _ = dataset.sample(samples_for_evaluation, repeat_obs_samples=False)
    z_gt = z_gt[:samples_for_evaluation]

    has_get_z = hasattr(model, 'get_z')
    supports_denoise = hasattr(model, 'denoise') or 'denoising' in model.__class__.__name__.lower()
    has_decoder = hasattr(model, 'decoder')

    # All tensor work is inference only; no autograd graph is needed.
    with torch.no_grad():
        x_clean = dataset.f(torch.tensor(z_gt, dtype=torch.float)).to(device)

        if noise_type == 'gaussian':
            noise = torch.randn_like(x_clean) * noise_level
        elif noise_type == 'uniform':
            noise = (torch.rand_like(x_clean) - 0.5) * 2 * noise_level
        else:
            noise = torch.zeros_like(x_clean)
        x_noisy = x_clean + noise

        # Latents of the noisy input; denoising-aware models get denoise=True.
        if not has_get_z:
            z_denoised = model(x_noisy)
        elif supports_denoise:
            z_denoised = model.get_z(x_noisy, denoise=True)
        else:
            z_denoised = model.get_z(x_noisy)
        z_clean_pred = model.get_z(x_clean) if has_get_z else model(x_clean)

        noise_before = torch.norm(x_noisy - x_clean).item()
        if has_decoder:
            x_reconstructed = model.decoder(z_denoised)
            noise_after = torch.norm(x_reconstructed - x_clean).item()

    z_denoised_np = z_denoised.cpu().detach().numpy()
    z_clean_pred_np = z_clean_pred.cpu().detach().numpy()

    # Distance between latents of the noisy and the clean inputs
    results['denoising_mse'] = mean_squared_error(z_clean_pred_np, z_denoised_np)
    results['denoising_mae'] = mean_absolute_error(z_clean_pred_np, z_denoised_np)

    # Mean per-dimension correlation between noisy-input and clean-input latents
    # (dimensions with an undefined correlation are skipped)
    correlations = []
    for i in range(z_clean_pred_np.shape[1]):
        corr = np.corrcoef(z_clean_pred_np[:, i], z_denoised_np[:, i])[0, 1]
        if not np.isnan(corr):
            correlations.append(corr)
    results['denoising_correlation'] = np.mean(correlations) if correlations else 0.0

    # Ratio of input noise norm to reconstruction error norm
    if has_decoder:
        results['snr_improvement'] = noise_before / (noise_after + 1e-8)

    # Recovery quality compared to ground truth
    regressor = LinearRegression().fit(z_denoised_np, z_gt)
    results['R2_denoised'] = regressor.score(z_denoised_np, z_gt)

    # MCC for denoised representations
    if not np.isnan(z_denoised_np).any():
        mccs_denoised = compute_mccs(z_gt, z_denoised_np)
        for k, v in mccs_denoised.items():
            results[f'{k}_denoised'] = v

    # Restore the caller's mode instead of unconditionally switching to training mode.
    if was_training:
        model.train()
    return results


def check_recovery_with_denoising(dataset, model, device):
    """Run ``check_recovery`` and, where applicable, add denoising metrics.

    * If ``dataset.add_observation_noise`` exists and is truthy, the metrics of
      ``evaluate_denoising_performance`` at the dataset's own ``noise_level`` and
      ``noise_type`` are merged in.
    * If the model looks denoising-capable (it has a ``denoise`` attribute or
      ``'denoising'`` appears in its class name), ``R2_denoised`` is measured at
      noise levels 0.05, 0.1, 0.2 and 0.3 with the default noise type. The
      scores are summarized as ``robustness_mean`` and ``robustness_std``.

    :return: dict of metrics
    """
    results = check_recovery(dataset, model, device)

    if getattr(dataset, 'add_observation_noise', False):
        results.update(evaluate_denoising_performance(
            dataset, model, device,
            noise_level=dataset.noise_level,
            noise_type=dataset.noise_type,
        ))

    model_can_denoise = hasattr(model, 'denoise') or 'denoising' in model.__class__.__name__.lower()
    if not model_can_denoise:
        return results

    robustness_scores = [
        evaluate_denoising_performance(dataset, model, device, noise_level=level).get('R2_denoised', 0)
        for level in (0.05, 0.1, 0.2, 0.3)
    ]
    results['robustness_mean'] = np.mean(robustness_scores)
    results['robustness_std'] = np.std(robustness_scores)
    return results


def permute_matrix_by_permutation(W, permutation):
    """
    Apply the same permutation to the rows and columns of a square matrix.

    Entry ``[i, j]`` of the result is ``W[permutation[i], permutation[j]]``. This
    equals ``P @ W @ P.T`` with ``P[i, permutation[i]] = 1``, computed by
    indexing. Indexing avoids the O(d^3) products and keeps NaN/inf entries in
    place (``0 * nan`` in a matrix product would contaminate whole rows and
    columns).

    :param W: square adjacency matrix
    :param permutation: permutation of ``range(d)``, e.g. [1, 0, 2, 3, 4]
    :return: permuted matrix (float dtype, like the matrix-product formulation)
    :raises ValueError: if ``W`` is not square or ``permutation`` is not a
        permutation of ``range(d)``
    """
    W = np.asarray(W)
    d = W.shape[0]
    if len(permutation) != d:
        raise ValueError(f"Permutation length ({len(permutation)}) must match matrix dimension ({d})")
    if W.ndim != 2 or W.shape[1] != d:
        raise ValueError(f"Expected a square matrix, got shape {W.shape}")

    perm = np.asarray(permutation)
    # Reject duplicates / out-of-range indices instead of silently dropping rows.
    if not np.array_equal(np.sort(perm), np.arange(d)):
        raise ValueError(f"{list(permutation)} is not a permutation of range({d})")

    W_permuted = W[np.ix_(perm, perm)]
    # The former P @ W @ P.T with a float64 P always produced at least float64.
    return W_permuted.astype(np.result_type(W.dtype, np.float64), copy=False)


def evaluate_graph_metrics_for_permutation(W_true, W, thresh=0.01):
    """
    Calculate graph structure metrics for a single (already permuted) estimated graph.

    The true graph is binarized as ``|W_true| > 0.01`` and the estimate as
    ``|W| > thresh``. Self-loops are removed from both before comparison.

    :param W_true: true adjacency matrix
    :param W: estimated adjacency matrix, already aligned to ``W_true``
    :param thresh: cut-off applied to ``|W|``
    :return: dictionary containing SHD, TPR, FPR, FDR, AUROC. TPR is 1.0 when
        the true graph has no edges, FPR is 0.0 when it has no off-diagonal
        non-edges, and FDR is 0.0 when the estimate has no edges. AUROC (of
        ``|W|`` against the off-diagonal true edges) is -1 when it is undefined.
    """
    # Binary true graph without self-loops
    B_true = np.where(np.abs(W_true) > 0.01, 1, 0)
    np.fill_diagonal(B_true, 0)

    # Binary estimated graph without self-loops
    B = np.where(np.abs(W) > thresh, 1, 0)
    np.fill_diagonal(B, 0)

    SHD = np.sum(np.abs(B_true - B))

    if np.sum(B_true) < 1:
        TPR = 1.0
    else:
        TPR = np.sum(B_true * B) / np.sum(B_true)

    # Off-diagonal non-edges (the diagonal of 1 - B_true is all ones); zero for a
    # complete graph or d == 1, where FPR would otherwise be 0/0.
    n_negatives = np.sum(1 - B_true) - B_true.shape[0]
    if n_negatives < 1:
        FPR = 0.0
    else:
        FPR = np.sum((1 - B_true) * B) / n_negatives

    if np.sum(B) < 1:
        FDR = 0.0
    else:
        FDR = np.sum((1 - B_true) * B) / np.sum(B)

    # AUROC of |W| over off-diagonal entries, same convention as get_auroc
    try:
        W_abs = np.abs(W)
        B_true_del = B_true[~np.eye(B_true.shape[0], dtype=bool)].reshape(B_true.shape[0], -1)
        W_abs_del = W_abs[~np.eye(W_abs.shape[0], dtype=bool)].reshape(W_abs.shape[0], -1)

        if np.sum(B_true_del) < 1:
            AUROC = -1
        else:
            AUROC = roc_auc_score(B_true_del.reshape(-1), W_abs_del.reshape(-1))
    except (ValueError, IndexError):
        # Undefined AUROC (e.g. single-class labels, NaN scores) or mismatched shapes
        AUROC = -1

    return {
        'SHD': SHD,
        'TPR': TPR,
        'FPR': FPR,
        'FDR': FDR,
        'AUROC': AUROC
    }


def get_optimal_threshold_and_permutation_metrics(W_true, W, threshold_range=None, max_permutations=1000,
                                                  exhaustive_max_dim=8):
    """
    Jointly search a threshold and a node permutation that minimize the SHD, then score that configuration.

    For every candidate threshold, entries with ``|W| <= threshold`` are zeroed and
    the best permutation is searched. The search is exhaustive when
    ``d <= exhaustive_max_dim`` and uses simulated annealing otherwise. The first
    (threshold, permutation) pair reaching the lowest SHD wins.

    :param W_true: true adjacency matrix
    :param W: estimated adjacency matrix
    :param threshold_range: candidate thresholds. When None, 20 values evenly spaced
        between the smallest non-zero and the largest ``|W|`` are used, together
        with 0.01, 0.05, ..., 0.5.
    :param max_permutations: number of annealing iterations for large matrices
    :param exhaustive_max_dim: largest dimension searched exhaustively (d! permutations)
    :return: dictionary with ``best_threshold``, ``best_permutation``, ``best_SHD``,
        ``final_metrics`` (from ``evaluate_graph_metrics_for_permutation``) and
        ``threshold_range``
    :raises ValueError: if ``threshold_range`` is empty
    """
    d = W.shape[0]
    W_abs = np.abs(W)

    if threshold_range is None:
        min_val = np.min(W_abs[W_abs > 0]) if np.any(W_abs > 0) else 0.01
        max_val = np.max(W_abs)
        threshold_range = np.concatenate([
            [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
            np.linspace(min_val, max_val, 20),
        ])
        threshold_range = np.unique(threshold_range)

    if len(threshold_range) == 0:
        raise ValueError("threshold_range must contain at least one threshold")

    best_shd = float('inf')
    best_threshold = None
    best_permutation = None

    for threshold in threshold_range:
        W_thresholded = np.where(W_abs > threshold, W, 0)

        if d <= exhaustive_max_dim:
            current_best_shd, current_best_perm, _ = _find_best_permutation_exhaustive(
                W_true, W_thresholded, threshold
            )
        else:
            current_best_shd, current_best_perm, _ = _find_best_permutation_heuristic(
                W_true, W_thresholded, threshold, max_permutations
            )

        if current_best_shd < best_shd:
            best_shd = current_best_shd
            best_threshold = threshold
            best_permutation = current_best_perm

    # Recompute the metrics for the winning configuration
    W_final = np.where(W_abs > best_threshold, W, 0)
    W_final_permuted = permute_matrix_by_permutation(W_final, best_permutation)
    final_metrics = evaluate_graph_metrics_for_permutation(W_true, W_final_permuted, best_threshold)

    return {
        'best_threshold': best_threshold,
        'best_permutation': best_permutation,
        'best_SHD': best_shd,
        'final_metrics': final_metrics,
        'threshold_range': threshold_range.tolist() if isinstance(threshold_range, np.ndarray) else threshold_range
    }


def _find_best_permutation_exhaustive(W_true, W, threshold):
    """
    Use exhaustive search to find best permutation (for small matrices)

    Every permutation of ``range(d)`` is scored by SHD alone, using the same
    binarization as ``evaluate_graph_metrics_for_permutation``. The full metric
    dictionary (including the comparatively expensive AUROC) is computed once,
    for the winner. Ties go to the permutation that ``itertools.permutations``
    yields first.

    :param W_true: true adjacency matrix
    :param W: thresholded adjacency matrix
    :param threshold: current threshold
    :return: (best SHD, best permutation, best metrics)
    """
    from itertools import permutations

    d = W.shape[0]

    # True graph binarized once, without self-loops
    B_true = np.where(np.abs(W_true) > 0.01, 1, 0)
    np.fill_diagonal(B_true, 0)

    best_shd = float('inf')
    best_permutation = None

    # Iterate lazily instead of materializing all d! permutations
    for perm in permutations(range(d)):
        B = np.where(np.abs(permute_matrix_by_permutation(W, perm)) > threshold, 1, 0)
        np.fill_diagonal(B, 0)
        shd = np.sum(np.abs(B_true - B))
        if shd < best_shd:
            best_shd = shd
            best_permutation = list(perm)

    best_metrics = evaluate_graph_metrics_for_permutation(
        W_true, permute_matrix_by_permutation(W, best_permutation), threshold
    )
    return best_metrics['SHD'], best_permutation, best_metrics


def _find_best_permutation_heuristic(W_true, W, threshold, max_iterations=1000):
    """
    Use simulated annealing to find best permutation (for large matrices)

    Starting from the identity, each iteration swaps two random positions. The
    swap is accepted if it lowers the SHD, or otherwise with probability
    ``exp(-delta / temperature)``. The temperature starts at 1.0 and is
    multiplied by 0.95 per iteration, with a floor of 0.01. Candidates are
    scored by SHD alone; the full metric dictionary is computed once, for the
    best permutation found.

    A private ``RandomState(42)`` makes every call reproducible without
    reseeding numpy's global random generator.

    :param W_true: true adjacency matrix
    :param W: thresholded adjacency matrix
    :param threshold: current threshold
    :param max_iterations: maximum number of iterations
    :return: (best SHD, best permutation, best metrics)
    """
    d = W.shape[0]
    rng = np.random.RandomState(42)

    # True graph binarized once, without self-loops
    B_true = np.where(np.abs(W_true) > 0.01, 1, 0)
    np.fill_diagonal(B_true, 0)

    def _shd(perm):
        # Same SHD as evaluate_graph_metrics_for_permutation, without the other metrics.
        B = np.where(np.abs(permute_matrix_by_permutation(W, perm)) > threshold, 1, 0)
        np.fill_diagonal(B, 0)
        return np.sum(np.abs(B_true - B))

    current_perm = list(range(d))
    current_shd = _shd(current_perm)

    best_perm = current_perm.copy()
    best_shd = current_shd

    temperature = 1.0
    cooling_rate = 0.95
    min_temperature = 0.01

    # A swap needs two distinct positions; for d < 2 the identity is the only permutation.
    n_iterations = max_iterations if d >= 2 else 0
    for _ in range(n_iterations):
        new_perm = current_perm.copy()
        i, j = rng.choice(d, 2, replace=False)
        new_perm[i], new_perm[j] = new_perm[j], new_perm[i]

        new_shd = _shd(new_perm)
        delta_shd = new_shd - current_shd

        if delta_shd < 0 or rng.random_sample() < np.exp(-delta_shd / temperature):
            current_perm = new_perm
            current_shd = new_shd
            if current_shd < best_shd:
                best_perm = current_perm.copy()
                best_shd = current_shd

        temperature = max(temperature * cooling_rate, min_temperature)

    best_metrics = evaluate_graph_metrics_for_permutation(
        W_true, permute_matrix_by_permutation(W, best_perm), threshold
    )
    return best_metrics['SHD'], best_perm, best_metrics


def evaluate_graph_metrics_with_threshold_optimization(W_true, W, max_permutations=1000):
    """
    Score ``W`` after jointly optimizing the threshold and node permutation.

    This is a thin wrapper around ``get_optimal_threshold_and_permutation_metrics``.
    It flattens the winning configuration's metrics into ``*_optimized`` keys.

    :param W_true: true adjacency matrix
    :param W: estimated adjacency matrix
    :param max_permutations: maximum number of permutations
    :return: dict with SHD/TPR/FPR/FDR/AUROC ``_optimized`` values, ``best_threshold``,
        ``best_permutation`` and ``optimization_method``
    """
    search = get_optimal_threshold_and_permutation_metrics(
        W_true, W, max_permutations=max_permutations
    )
    metrics = search['final_metrics']

    key_map = (
        ('SHD_optimized', 'SHD'),
        ('TPR_optimized', 'TPR'),
        ('FPR_optimized', 'FPR'),
        ('FDR_optimized', 'FDR'),
        ('AUROC_optimized', 'AUROC'),
    )
    results = {out_key: metrics[src_key] for out_key, src_key in key_map}
    results['best_threshold'] = search['best_threshold']
    results['best_permutation'] = search['best_permutation']
    results['optimization_method'] = 'threshold_and_permutation_optimization'
    return results


def permute_adjacency_matrix_by_assignment(W_hat, assignment):
    """
    Reorder the rows and columns of ``W_hat`` according to an MCC variable matching.

    :param W_hat: estimated adjacency matrix
    :param assignment: variable matching determined by MCC, one index per row of ``W_hat``
    :return: matrix permuted with ``permute_matrix_by_permutation``
    :raises ValueError: if ``len(assignment)`` differs from the matrix dimension
    """
    n_vars = W_hat.shape[0]
    if len(assignment) == n_vars:
        return permute_matrix_by_permutation(W_hat, assignment)
    raise ValueError(f"Assignment length ({len(assignment)}) must match matrix dimension ({n_vars})")


def evaluate_graph_metrics_with_permutation(W_true, W_hat, assignment, thresh=.3, nr_edges=1):
    """
    Align ``W_hat`` to ``W_true`` with an MCC assignment, then score it.

    The assignment is printed, applied to ``W_hat`` with
    ``permute_adjacency_matrix_by_assignment``, and the aligned matrix is passed
    to ``evaluate_graph_metrics``.

    :param W_true: true adjacency matrix
    :param W_hat: estimated adjacency matrix
    :param assignment: variable matching determined by MCC
    :param thresh: threshold
    :param nr_edges: edge count matching parameter
    :return: the metrics dictionary of ``evaluate_graph_metrics``
    """
    print(f"Using MCC assignment: {assignment}")
    aligned = permute_adjacency_matrix_by_assignment(W_hat, assignment)
    return evaluate_graph_metrics(W_true, aligned, thresh=thresh, nr_edges=nr_edges)