from typing import Tuple, Optional
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
import optax
from omegaconf import OmegaConf as oc
from jaxmodels_nnx import build_model

from reps.evalfactors import factor_predictor_train, factor_predictor_eval


def _compute_bin_indices(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Discretise *values* into *n_bins* equal‐frequency bins (quantiles).

    Parameters
    ----------
    values : np.ndarray, shape (N,)
        1-D array of continuous values.
    n_bins : int
        Number of quantile bins.

    Returns
    -------
    np.ndarray, shape (N,)
        Integer bin index in [0, n_bins-1] for each sample.
    """
    # Percentile edges (0, 100) inclusive.
    q_edges = np.linspace(0.0, 100.0, n_bins + 1, dtype=np.float32)
    try:
        bin_edges = np.percentile(values, q_edges)
    except IndexError:
        # Degenerate case – not enough points.  Map all to bin 0.
        return np.zeros_like(values, dtype=np.int32)

    # Make sure edges are strictly increasing to avoid *digitize* issues.
    # If not, fall back to a single-bin mapping for this dimension.
    if np.unique(bin_edges).size < 2:
        return np.zeros_like(values, dtype=np.int32)

    # Exclude right edge so last bin is [edge_{n-1}, edge_n]
    return np.digitize(values, bin_edges[1:-1], right=False).astype(np.int32)


def stratified_uniform_sample(
    observations: jnp.ndarray,
    ground_truth: jnp.ndarray,
    *,
    discrete: bool = False,
    n_bins: int = 10,
    rng: jax.random.PRNGKey = jax.random.PRNGKey(0),
    max_samples_per_bin: Optional[int] = None,
    plot_path: Optional[str] = 'stratum_distribution.png',
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Return a subset of *(observations, ground_truth)* with (approx.) uniform GT distribution.

    Each ground-truth dimension is discretised: continuous factors go into
    *n_bins* equal-frequency bins, and with ``discrete=True`` the integer
    label itself is the bin.  Every distinct tuple of per-dimension bins is one
    stratum.  The same number of samples is then drawn from every occupied
    stratum: the size of the smallest stratum, further capped by
    *max_samples_per_bin* if given.

    This aims at removing biases caused by the over-representation of certain
    states/factors when training/evaluating the predictor networks.

    Parameters
    ----------
    observations : array, shape (N, ...)
        Observations, indexed along axis 0 in lockstep with *ground_truth*.
    ground_truth : array, shape (N, D)
        Ground-truth factors that define the strata.
    discrete : bool
        If True, each GT column is cast to int32 and used directly as its bin.
    n_bins : int
        Number of quantile bins per continuous GT dimension.
    rng : PRNG key
        Seeds the NumPy generator used for the within-stratum draws.
    max_samples_per_bin : int, optional
        Upper bound on the number of samples kept per stratum.
    plot_path : str or None
        File to save a bar chart of per-stratum counts to.  ``None`` skips
        the plot (and the matplotlib import) entirely.

    Returns
    -------
    (obs_subset, gt_subset) : tuple of jnp.ndarray
        The selected rows, grouped by stratum in order of first appearance.

    Raises
    ------
    ValueError
        If there are no samples to stratify.
    """
    obs_np = np.asarray(jax.device_get(observations))
    gt_np = np.asarray(jax.device_get(ground_truth))

    n_samples, n_dims = gt_np.shape
    if n_samples == 0:
        raise ValueError("All strata are empty after binning – check the data.")

    # -------------------------------------------------------------
    # 1. Discretise each factor into bins
    #    - if discrete, treat each unique value as its own bin
    #    - else perform quantile-binning into n_bins
    # -------------------------------------------------------------
    if discrete:
        bin_idx_per_dim = [gt_np[:, d].astype(np.int32) for d in range(n_dims)]
    else:
        bin_idx_per_dim = [
            _compute_bin_indices(gt_np[:, d], n_bins) for d in range(n_dims)
        ]
    bin_idx = np.stack(bin_idx_per_dim, axis=1)  # (N, D)

    # -------------------------------------------------------------
    # 2. Map each sample to a unique stratum ID (joint bins)
    # -------------------------------------------------------------
    # Row-wise unique gives a collision-free ID for every distinct bin tuple.
    # A base-n_bins encoding is only injective when every bin index lies in
    # [0, n_bins), which does not hold for raw discrete labels.
    _, strata_ids = np.unique(bin_idx, axis=0, return_inverse=True)
    strata_ids = np.asarray(strata_ids).reshape(-1)

    # Group sample indices by stratum.  Insertion order is the order of first
    # appearance, which fixes the order in which the RNG is consumed below.
    strata_to_indices: dict[int, list[int]] = {}
    for idx, sid in enumerate(strata_ids):
        strata_to_indices.setdefault(int(sid), []).append(idx)

    # -------------------------------------------------------------
    # 3. Decide how many samples per stratum
    # -------------------------------------------------------------
    counts = np.array([len(idxs) for idxs in strata_to_indices.values()])
    probs = counts / counts.sum()
    print(f'min: {min(counts)}, max: {max(counts)}, max_prob: {max(probs)}, min_prob: {min(probs)}')

    min_count = min(counts)
    sample_per_stratum = min_count if max_samples_per_bin is None else min(
        max_samples_per_bin, min_count
    )

    if plot_path is not None:
        # Imported lazily so callers that skip the plot never need matplotlib.
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 6))
        plt.bar(range(len(counts)), counts, color='steelblue', edgecolor='black')
        plt.xlabel('Stratum Index')
        plt.ylabel('Number of Samples')
        plt.title('Distribution of Samples across Strata')
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.savefig(plot_path)
        plt.close()

    # -------------------------------------------------------------
    # 4. Uniform sampling across strata
    # -------------------------------------------------------------
    rng_int = int(jax.random.randint(rng, (), 0, 2**31 - 1))
    rng_np = np.random.RandomState(rng_int)

    selected: list[int] = []
    for idxs in strata_to_indices.values():
        if len(idxs) <= sample_per_stratum:
            selected.extend(idxs)
        else:
            chosen = rng_np.choice(idxs, sample_per_stratum, replace=False)
            selected.extend(chosen.tolist())

    selected_indices = np.asarray(selected, dtype=np.int64)

    # -------------------------------------------------------------
    # 5. Return the subset as JAX arrays (placed on the default JAX device)
    # -------------------------------------------------------------
    obs_out = jnp.asarray(obs_np[selected_indices])
    gt_out = jnp.asarray(gt_np[selected_indices])
    return obs_out, gt_out


# ===============================================
# Convenience wrappers around existing predictors
# ===============================================

def factor_predictor_train_balanced(
    test_data: jnp.ndarray,
    ground_truth: jnp.ndarray,
    eval_config: str,
    model,
    rng: jax.random.PRNGKey,
    *,
    discrete: bool = False,
    gt_n_classes: Optional[int] = None,
    n_bins: int = 10,
    noise_std: float = 0.0,
    debug: bool = False,
):
    """Balanced counterpart to *factor_predictor_train*.

    Before delegating to ``factor_predictor_train``, the data is
    stratified-uniformly subsampled via ``stratified_uniform_sample`` so that
    heavily over-represented regions of the ground-truth space do not dominate
    training.  The balanced GT is then min-max scaled per column to [0, 1]
    (always, regardless of *discrete*) and perturbed with Gaussian noise of
    standard deviation *noise_std*.

    *discrete* and *n_bins* control the stratification only.  *gt_n_classes*
    is accepted for signature parity but is not used.
    """
    key_sample, key_noise, key_train = jax.random.split(rng, 3)

    # Step 1: subsample so every occupied GT stratum contributes equally.
    balanced_data, balanced_gt = stratified_uniform_sample(
        test_data, ground_truth, discrete=discrete, n_bins=n_bins, rng=key_sample
    )

    # Step 2: per-column min-max scaling; the epsilon guards constant columns.
    gt_lo = jnp.min(balanced_gt, axis=0, keepdims=True)
    gt_span = jnp.max(balanced_gt, axis=0, keepdims=True) - gt_lo
    scaled_gt = (balanced_gt - gt_lo) / (gt_span + 1e-8)
    target_gt = scaled_gt + noise_std * jax.random.normal(key_noise, scaled_gt.shape)

    print(jax.tree.map(lambda x: x.shape, (balanced_data, target_gt)))

    # Step 3: hand off to the standard (unbalanced) trainer.
    return factor_predictor_train(
        balanced_data,
        target_gt,
        eval_config,
        model,
        key_train,
        debug=debug,
    )


def factor_predictor_eval_balanced(
    states: jnp.ndarray,
    gt_states: jnp.ndarray,
    gt_predictor,
    factor_predictor,
    model,
    rng: jax.random.PRNGKey,
    *,
    discrete: bool = False,
    gt_n_classes: Optional[int] = None,
    n_bins: int = 10,
    noise_std: float = 0.0,
    debug: bool = False,
):
    """Evaluate predictors on a *balanced* subset of the dataset.

    Applies the same stratified-uniform subsampling as the balanced trainer,
    so scores reflect performance across the factor space rather than being
    dominated by majority regions.  The GT of the subset is min-max scaled per
    column to [0, 1] (always; the scaling uses this subset's own min/max) and
    perturbed with Gaussian noise of standard deviation *noise_std* before
    delegating to ``factor_predictor_eval``.

    *gt_n_classes* is accepted for signature parity but is not used.
    """
    key_sample, key_noise = jax.random.split(rng, 2)

    # Step 1: stratified subsampling of (states, GT).
    eval_states, eval_gt = stratified_uniform_sample(
        states, gt_states, discrete=discrete, n_bins=n_bins, rng=key_sample
    )

    # Step 2: per-column min-max scaling plus additive Gaussian noise.
    gt_lo = jnp.min(eval_gt, axis=0, keepdims=True)
    gt_span = jnp.max(eval_gt, axis=0, keepdims=True) - gt_lo
    scaled_gt = (eval_gt - gt_lo) / (gt_span + 1e-8)
    target_gt = scaled_gt + noise_std * jax.random.normal(key_noise, scaled_gt.shape)

    # Step 3: hand off to the standard evaluation routine.
    return factor_predictor_eval(
        eval_states,
        target_gt,
        gt_predictor,
        factor_predictor,
        model,
        debug=debug,
    )


def compute_importance_weights(
    ground_truth: jnp.ndarray,
    *,
    discrete: bool = False,
    n_bins: int = 10,
    max_weight: float = 100.0,
) -> jnp.ndarray:
    """Compute importance weights inversely proportional to GT stratum occupancy.

    Samples are binned the same way as in ``stratified_uniform_sample``:
    quantile bins per continuous dimension, or the raw integer label per
    dimension when *discrete*.  Every distinct tuple of per-dimension bins is
    one stratum.  Sample ``i`` receives weight ``N / count(stratum(i))``.
    Every stratum therefore carries the same total weight ``N``, and the
    weights sum to ``N * n_strata``.

    Parameters
    ----------
    ground_truth : jnp.ndarray, shape (N, D)
        Ground truth factors.
    discrete : bool, default False
        Whether ground truth contains discrete (integer) factors.  Labels may
        take any integer value, including values >= *n_bins* or negative.
    n_bins : int, default 10
        Number of quantile bins for continuous factors.
    max_weight : float, default 100.0
        Intended upper bound on the weights.  NOTE(review): clipping is
        currently disabled (see the commented-out line below), so this value
        is accepted for interface compatibility but has no effect.  Confirm
        whether clipping should be re-enabled.

    Returns
    -------
    jnp.ndarray, shape (N,)
        Importance weight of each sample.
    """
    gt_np = np.asarray(jax.device_get(ground_truth))
    n_samples, n_dims = gt_np.shape
    if n_samples == 0:
        # Same result the per-sample construction produced for empty input.
        return jnp.array(np.zeros(0, dtype=np.float64))

    # 1) Bin each dimension exactly as in stratified sampling.
    if discrete:
        bin_idx_per_dim = [gt_np[:, d].astype(np.int32) for d in range(n_dims)]
    else:
        bin_idx_per_dim = [
            _compute_bin_indices(gt_np[:, d], n_bins) for d in range(n_dims)
        ]
    bin_idx = np.stack(bin_idx_per_dim, axis=1)  # (N, D)

    # 2) + 3) Map each distinct bin tuple to its own stratum and count it.
    # Row-wise unique avoids the collisions a base-n_bins encoding produces
    # when discrete labels fall outside [0, n_bins).
    _, inverse, counts = np.unique(
        bin_idx, axis=0, return_inverse=True, return_counts=True
    )
    inverse = np.asarray(inverse).reshape(-1)

    # 4) Inverse-frequency weights, vectorised.
    weights = n_samples / counts[inverse]

    # 5) Clip extreme weights to prevent optimization instability (disabled):
    # weights = np.minimum(weights, max_weight)

    return jnp.array(weights)


def factor_predictor_train_weighted(
    test_data: jnp.ndarray,
    ground_truth: jnp.ndarray,
    eval_config: str,
    model,
    rng: jax.random.PRNGKey,
    *,
    discrete: bool = False,
    gt_n_classes: Optional[int] = None,
    n_bins: int = 10,
    noise_std: float = 0.0,
    max_weight: float = 100.0,
    debug: bool = False,
):
    """Importance-weighted version of factor predictor training.

    Instead of uniform subsampling, this computes per-sample weights via
    ``compute_importance_weights``, which are inversely proportional to how
    many samples share each sample's GT stratum.  It then trains the GT and
    factor predictors with a loss in which every sample's term is multiplied
    by its weight.  All data points are retained; samples from sparsely
    populated strata are upweighted.

    Before training, the GT is min-max scaled per column to [0, 1] (using the
    full dataset) and perturbed with Gaussian noise of std *noise_std*.

    Parameters
    ----------
    test_data : jnp.ndarray
        Observations, passed to ``model.encode`` in batches of 128.
    ground_truth : jnp.ndarray, shape (N, D)
        Ground-truth factors; used both for the weights and as targets.
    eval_config : str
        Path loaded with ``OmegaConf.load``.  Reads ``predictor``, ``lr``,
        ``batch_size``, ``n_epochs`` and, optionally, a ``debug`` section.
    model
        Representation model.  Must provide ``encode``, ``n_factors`` and
        ``config.vars_per_factor``.
    rng : PRNG key
        Split into a noise key and a training key.
    discrete : bool
        Only used to choose the binning in ``compute_importance_weights``.
        It is NOT forwarded to the inner trainer (see NOTE at the call site).
    gt_n_classes : int, optional
        Currently unused; it is not forwarded to the inner trainer.
    n_bins : int
        Bins per continuous GT dimension for the weight computation.
    noise_std : float
        Standard deviation of the Gaussian noise added to the scaled GT.
    max_weight : float, default 100.0
        Forwarded to ``compute_importance_weights``.  NOTE(review): in this
        file that function's clipping line is commented out, so this currently
        has no effect.  Confirm intent.
    debug : bool
        Prints weight diagnostics.  If the config has a ``debug`` section,
        its ``batch_size``/``n_epochs`` are also used.

    Returns
    -------
    (gt_predictor, factor_predictor)
        The trained predictor stacks, each built with ``nnx.vmap``.
    """
    
    # Split RNG for noise and inner training
    rng_noise, rng_inner = jax.random.split(rng)
    
    # 1) Compute importance weights - no sampling, keep all data
    weights = compute_importance_weights(
        ground_truth, discrete=discrete, n_bins=n_bins, max_weight=max_weight
    )
    
    # Print diagnostic information
    if debug:
        w_min, w_max = float(weights.min()), float(weights.max())
        n_unique = int(len(np.unique(jax.device_get(ground_truth), axis=0)))
        print(f"Weight range: [{w_min:.3f}, {w_max:.3f}], {n_unique} unique GT points")
    
    # 2) Per-column min-max scaling of the GT to [0, 1] plus Gaussian noise
    #    (always applied).  The 1e-8 guards against constant columns.
    min_gt = jnp.min(ground_truth, axis=0, keepdims=True)
    max_gt = jnp.max(ground_truth, axis=0, keepdims=True)
    gt_norm = (ground_truth - min_gt) / (max_gt - min_gt + 1e-8)
    noise = noise_std * jax.random.normal(rng_noise, gt_norm.shape)
    gt_noisy = gt_norm + noise
    
    
    # 3) Inner trainer that multiplies every per-sample loss term by its weight
    def weighted_predictor_train(
        test_data, ground_truth, eval_config, model, rng, 
        *, discrete=False, gt_n_classes=None, debug=False, weights=None
    ):
        """Train GT and factor predictors with a per-sample weighted loss.

        *weights* must be an array with one entry per sample.  It is indexed
        together with the latents and GT in ``_sample_data``; ``None`` is not
        handled.  With *gt_n_classes* set, the GT predictor is trained with
        cross-entropy; otherwise with L2.  With *discrete* set, the factor
        predictor is trained with cross-entropy on argmax latent labels;
        otherwise with L2 on the latents.
        """
        cfg = oc.load(eval_config)
        ground_truth_dim = ground_truth.shape[-1]
        
        # Use debug parameters if debug mode is enabled
        if debug and hasattr(cfg, 'debug'):
            cfg.batch_size = cfg.debug.batch_size
            cfg.n_epochs = cfg.debug.n_epochs
        
        # GT predictors: one per latent factor (vmapped over model.n_factors),
        # each mapping that factor's latent block to the full GT vector.
        if gt_discrete := (gt_n_classes is not None):
            cfg.predictor.input_dim = model.config.get('vars_per_factor', 1)
            cfg.predictor.output_dim = ground_truth_dim * gt_n_classes
        else:
            cfg.predictor.input_dim = model.config.get('vars_per_factor', 1)
            cfg.predictor.output_dim = ground_truth_dim
        rng, rng_gt, rng_factor = jax.random.split(rng, 3)
        rngs = jax.random.split(rng_gt, model.n_factors)
        gt_predictor = nnx.vmap(
            lambda rng: build_model(cfg.predictor, nnx.Rngs(rng)),
            in_axes=(0,)
        )(rngs)

        # Factor predictors: one per GT dimension, each mapping a scalar GT
        # value to all latent factors.
        if discrete:
            n_classes = model.config.vars_per_factor
            cfg.predictor.input_dim = 1
            cfg.predictor.output_dim = model.n_factors * n_classes
        else:
            cfg.predictor.input_dim = 1
            cfg.predictor.output_dim = model.n_factors * model.config.vars_per_factor
        rngs = jax.random.split(rng_factor, ground_truth_dim)
        factor_predictor = nnx.vmap(
            lambda rng: build_model(cfg.predictor, nnx.Rngs(rng)),
            in_axes=(0,)
        )(rngs)

        # One AdamW optimizer state over the Params of both predictor stacks
        tx = optax.adamw(cfg.lr)
        optstate = tx.init(
            (
                nnx.state(gt_predictor, nnx.Param),
                nnx.state(factor_predictor, nnx.Param)
            )
        )
        
        # compute latents in batches to avoid memory issues
        batch_size = 128
        n_samples = test_data.shape[0]
        latents_list = []
        for i in range(0, n_samples, batch_size):
            batch = test_data[i:i+batch_size]
            batch_latents = model.encode(batch, states=ground_truth[i:i+batch_size])
            latents_list.append(batch_latents)
        latents = jnp.concatenate(latents_list, axis=0).reshape(
            -1,
            model.n_factors,
            model.config.vars_per_factor
        )
        
        # Transform latents with the shared helper.
        # NOTE(review): the module-level import uses `reps.evalfactors`; this
        # one uses `src.reps.evalfactors`.  Confirm both resolve to the same module.
        from src.reps.evalfactors import transform_latents_to_range
        latents = transform_latents_to_range(latents)
        
        N = latents.shape[0]

        # One optimizer step.  The carry holds (graphdef, state) pairs from
        # nnx.split so it can pass through lax.scan.
        def _update_step(
            training_state,
            rng
        ):
            gt_predictor, factor_predictor, optstate = training_state
            gt_predictor = nnx.merge(*gt_predictor)
            factor_predictor = nnx.merge(*factor_predictor)

            # NOTE(review): rng_train is never used below.
            rng_data, rng_train = jax.random.split(rng)

            # Draw cfg.batch_size indices uniformly (with replacement) and
            # apply the same indices to every array in `dataset`.
            def _sample_data(rng, dataset):
                idx = jax.random.randint(
                    rng,
                    (cfg.batch_size,),
                    0,
                    dataset[0].shape[0]
                )
                return jax.tree.map(lambda x: x[idx], dataset)

            # Sum of GT-predictor and factor-predictor losses.  Each is a plain
            # mean over all elements after multiplying by the sample weight
            # (the weights are not renormalised per batch).
            def _weighted_loss(gt_predictor, factor_predictor, latents, ground_truth, sample_weights):
                gt_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(gt_predictor, latents).swapaxes(0,1)
                if gt_discrete:
                    # Reshape gt_pred to (batch, n_predictors, n_classes)
                    gt_pred = gt_pred.reshape(*gt_pred.shape[:-1], ground_truth_dim, gt_n_classes)
                    # Integer class labels, repeated for every predictor
                    gt_labels = ground_truth[:, None].repeat(gt_pred.shape[1], axis=1).astype(jnp.int32)
                    gt_loss = optax.softmax_cross_entropy_with_integer_labels(gt_pred, gt_labels)
                    # Apply weights
                    gt_loss = (gt_loss * sample_weights[:, None, None]).mean()
                else:
                    unweighted_loss = optax.l2_loss(gt_pred, ground_truth[:, None].repeat(gt_pred.shape[1], axis=1))
                    # Apply weights
                    gt_loss = (unweighted_loss * sample_weights[:, None, None]).mean()

                if discrete:
                    # factor predictor yields logits for classes.
                    n_classes = model.config.vars_per_factor
                    factor_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(factor_predictor, ground_truth[..., None]).swapaxes(0,1)
                    # reshape to (batch, predictors, n_factors, n_classes)
                    factor_pred = factor_pred.reshape(*factor_pred.shape[0:2], model.n_factors, n_classes)
                    # Class labels are the argmax over each latent factor's block
                    labels = latents.reshape(factor_pred.shape[0], model.n_factors, n_classes).argmax(-1)
                    factor_loss = optax.softmax_cross_entropy_with_integer_labels(factor_pred, labels[:, None].repeat(factor_pred.shape[1], axis=1))
                    # Apply weights
                    factor_loss = (factor_loss * sample_weights[:, None, None, None]).mean()
                else:
                    factor_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0,1))(factor_predictor, ground_truth[..., None]).swapaxes(0,1)
                    factor_pred = factor_pred.reshape(*factor_pred.shape[0:2], model.n_factors, model.config.vars_per_factor)
                    unweighted_loss = optax.l2_loss(factor_pred, latents[:, None].repeat(factor_pred.shape[1], axis=1))
                    # Apply weights
                    factor_loss = (unweighted_loss * sample_weights[:, None, None, None]).mean()
                
                loss = gt_loss + factor_loss
                return loss
                
            # `weights` here is weighted_predictor_train's argument (closure).
            z, s, w = _sample_data(rng_data, (latents, ground_truth, weights))
            loss, grads = nnx.value_and_grad(_weighted_loss, argnums=(0,1))(
                gt_predictor, factor_predictor,
                z, s, w
            )

            model_params = (nnx.state(gt_predictor, nnx.Param), nnx.state(factor_predictor, nnx.Param))
            updates, optstate = tx.update(grads, optstate, model_params)
            gt_predictor_params, factor_predictor_params = optax.apply_updates(model_params, updates)

            nnx.update(gt_predictor, gt_predictor_params)
            nnx.update(factor_predictor, factor_predictor_params)

            return (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate), loss
        
        # Run int(N / batch_size * n_epochs) steps with lax.scan
        n_training_steps = int(N / cfg.batch_size * cfg.n_epochs)
        rng, rng_train = jax.random.split(rng)
        (gt_predictor, factor_predictor, optstate), losses = jax.lax.scan(
            _update_step,
            (nnx.split(gt_predictor), nnx.split(factor_predictor), optstate),
            jax.random.split(rng_train, n_training_steps)
        )

        gt_predictor = nnx.merge(*gt_predictor)
        factor_predictor = nnx.merge(*factor_predictor)
        return gt_predictor, factor_predictor
        
    
    # 4) Call our weighted version on the scaled, noisy GT.
    # NOTE(review): `discrete` and `gt_n_classes` are not forwarded, so the
    # inner trainer always uses its defaults (regression for both predictors).
    # Forwarding gt_n_classes would also clash with the [0, 1] GT scaling
    # above.  Confirm this is intended.
    return weighted_predictor_train(
        test_data,
        gt_noisy,
        eval_config,
        model,
        rng_inner,
        debug=debug,
        weights=weights
    )


def factor_predictor_eval_weighted(
    states: jnp.ndarray,
    gt_states: jnp.ndarray,
    gt_predictor,
    factor_predictor,
    model,
    rng,
    *,
    discrete: bool = False,
    gt_n_classes: Optional[int] = None,
    n_bins: int = 10,
    noise_std: float = 0.0,
    max_weight: float = 100.0,
    debug: bool = False,
):
    """Importance-weighted evaluation of factor predictors.

    Steps:

    * compute importance weights for *gt_states* via ``compute_importance_weights``;
    * min-max scale the GT per column to [0, 1] and add Gaussian noise drawn
      from a fixed key;
    * encode *states* with *model* in batches of 128;
    * score ``gt_predictor`` (latents -> GT) with R^2, or F1 when
      *gt_n_classes* is set;
    * score ``factor_predictor`` (GT -> latents) with R^2, or F1 when
      *discrete* is set.

    Both score matrices are clipped to [0, 1] and passed to
    ``find_best_permutation``.

    NOTE(review): despite the name, the weights never enter the per-sample
    metric computation.  Their only use is dividing the final score matrices
    by ``weights.sum()``.  With the inverse-frequency weights computed in this
    file, that sum equals N times the number of strata, so this is a global
    rescale rather than a weighted average.  Confirm intent.

    Parameters
    ----------
    states, gt_states
        Observations and their ground-truth factors (``gt_states`` is (N, D)).
    gt_predictor, factor_predictor
        Trained predictor stacks, applied with ``nnx.vmap``.
    model
        Representation model providing ``encode``, ``n_factors`` and
        ``config.vars_per_factor``.
    rng
        Ignored: it is overwritten with ``jax.random.PRNGKey(42)`` below.
    discrete, n_bins, max_weight
        Forwarded to ``compute_importance_weights``.  *discrete* also selects
        F1 scoring for the factor predictor.
    gt_n_classes : int, optional
        If set, the GT predictor is scored with F1 over this many classes.
    noise_std : float
        Standard deviation of the noise added to the scaled GT.
    debug : bool
        Prints weight diagnostics.

    Returns
    -------
    dict
        ``'gt_score_matrix'`` / ``'factor_score_matrix'`` hold the full
        permutation results.  The ``'<prefix>/permuted_matrix'``,
        ``'/assignment'``, ``'/diag_score'`` and ``'/off_diag_score'`` keys
        repeat their individual fields.
    """
    # Compute importance weights
    weights = compute_importance_weights(
        gt_states, discrete=discrete, n_bins=n_bins, max_weight=max_weight
    )
    
    # Print diagnostic information
    if debug:
        w_min, w_max = float(weights.min()), float(weights.max())
        n_unique = int(len(np.unique(jax.device_get(gt_states), axis=0)))
        print(f"Eval weight range: [{w_min:.3f}, {w_max:.3f}], {n_unique} unique GT points")
    
    # Per-column min-max scaling of the GT to [0, 1]; the 1e-8 guards constant columns
    min_gt = jnp.min(gt_states, axis=0, keepdims=True)
    max_gt = jnp.max(gt_states, axis=0, keepdims=True)
    gt_norm = (gt_states - min_gt) / (max_gt - min_gt + 1e-8)
    
    # Add noise with a fixed seed for reproducibility (the `rng` argument is discarded)
    rng = jax.random.PRNGKey(42)
    noise = noise_std * jax.random.normal(rng, gt_norm.shape)
    gt_noisy = gt_norm + noise
    
    # Compute latents in batches
    batch_size = 128
    n_samples = states.shape[0]
    latents_list = []
    for i in range(0, n_samples, batch_size):
        batch_states = states[i:i+batch_size]
        batch_latents = model.encode(batch_states, states=gt_noisy[i:i+batch_size])
        latents_list.append(batch_latents)
    latents = jnp.concatenate(latents_list, axis=0).reshape(
        -1, model.n_factors, model.config.vars_per_factor
    )
    
    # Transform latents with the shared helper (original comment: to [-1, 1]).
    # NOTE(review): the module-level import uses `reps.evalfactors`; this one
    # uses `src.reps.evalfactors`.  Confirm both resolve to the same module.
    from src.reps.evalfactors import transform_latents_to_range, r2_score, f1_score
    latents = transform_latents_to_range(latents)
    
    # Evaluate gt_predictor: one predictor per latent factor (vmapped over axis 1)
    gt_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0, 1))(gt_predictor, latents)
    gt_pred = gt_pred.swapaxes(0, 1)
    
    # Evaluate factor_predictor: one predictor per GT dimension.
    # Add a dummy last-dimension as expected
    factor_pred = nnx.vmap(lambda m, x: m(x), in_axes=(0, 1))(
        factor_predictor, gt_noisy[..., None]
    )
    factor_pred = factor_pred.swapaxes(0, 1)
    
    # Compute scores (NOTE(review): no per-sample weighting happens here)
    if gt_n_classes is not None:  # discrete GT
        # NOTE(review): optax is already imported at module level.  Verify that
        # `optax.one_hot` exists; `jax.nn.one_hot` is used for the factor labels
        # below.  The one-hot is built from the scaled, noisy GT rather than
        # the raw integer labels in gt_states.  Confirm.
        import optax
        gt_pred = gt_pred.reshape(-1, gt_pred.shape[1], gt_states.shape[-1], gt_n_classes)
        gt_states_labels = optax.one_hot(gt_noisy, gt_n_classes)
        gt_score = jax.vmap(lambda p, t: f1_score(p, t, reduce=False),
                           in_axes=(1, None))(gt_pred, gt_states_labels)
    else:
        # Per-predictor R^2 against the scaled, noisy GT
        gt_score = jax.vmap(lambda p, t: r2_score(p, t, reduce=False),
                           in_axes=(1, None))(gt_pred, gt_noisy)
    
    if discrete:  # discrete latents
        # Use f1 score for the classifier
        n_classes = model.config.vars_per_factor
        factor_pred = factor_pred.reshape(-1, gt_noisy.shape[-1], model.n_factors, n_classes)
        
        labels = latents.reshape(-1, model.n_factors, n_classes).argmax(-1)
        labels_one_hot = jax.nn.one_hot(labels, n_classes)
        factor_score = jax.vmap(lambda p, t: f1_score(p, t, reduce=False),
                               in_axes=(1, None))(factor_pred, labels_one_hot)
    else:
        factor_pred = factor_pred.reshape(
            -1, gt_noisy.shape[-1], model.n_factors, model.config.vars_per_factor
        )
        # NOTE(review): reduce=True here but reduce=False for the other three
        # scores.  Confirm the resulting shape is what find_best_permutation expects.
        factor_score = jax.vmap(lambda p, t: r2_score(p, t, reduce=True),
                               in_axes=(1, None))(factor_pred, latents)
    
    # Divide by the total weight (a global rescale; see NOTE in the docstring)
    gt_score = gt_score / weights.sum()
    factor_score = factor_score / weights.sum()
    
    # Use existing permutation optimizer
    from src.permutation_optimizer import find_best_permutation
    
    gt_score_results = dict(zip(
        ('permuted_matrix', 'assignment', 'diag_score', 'off_diag_score'), 
        find_best_permutation(jnp.clip(gt_score, 0, 1), axis=1)
    ))
    
    factor_score_results = dict(zip(
        ('permuted_matrix', 'assignment', 'diag_score', 'off_diag_score'),
        find_best_permutation(jnp.clip(factor_score, 0, 1))
    ))
    
    return {
        'gt_score_matrix': gt_score_results,
        'gt_score/permuted_matrix': gt_score_results['permuted_matrix'],
        'gt_score/assignment': gt_score_results['assignment'],
        'gt_score/diag_score': gt_score_results['diag_score'],
        'gt_score/off_diag_score': gt_score_results['off_diag_score'],
        'factor_score_matrix': factor_score_results,
        'factor_score/permuted_matrix': factor_score_results['permuted_matrix'],
        'factor_score/assignment': factor_score_results['assignment'],
        'factor_score/diag_score': factor_score_results['diag_score'],
        'factor_score/off_diag_score': factor_score_results['off_diag_score'],
    }


