#!/usr/bin/env python3
"""
EF-based MI estimator tailored for Hierarchical VAEs with depth selection.

We approximate two quantities for a chosen depth l (1-indexed):

1) IF-based quadratic upper bound for I(\phi_{1:l}; U | X^n)
   - Uses a diagonal empirical Fisher approximation H^{-1} ≈ diag(1 / (EF + damping)).
   - Gradients are computed from a partial ELBO that includes reconstruction and KL
     terms only up to layer l. Upper layers (> l) are detached so their parameters
     and latents do not contribute.
   - Parameter scope mask selects encoder-related parameters that influence layers 1..l.

2) Maximum-entropy covariance upper bound for I(Z_{1:l, U}; U | \phi_{1:l}, X^n)
   - For the chosen l, we collect approximate posterior parameters either for z_l only
     (z_target=layer_l, the default) or for the concatenation z_1..z_l
     (z_target=concat_1_to_l).
   - We form the mixture covariance Cov[mu] + E[diag(var)] over the selected latents
     and apply a Gaussian entropy upper bound to the mixture.

Supported architectures: HierarchicalMLPVAE (arch=hmlp), HierarchicalConvVAE (arch=hcnn).

Output JSON contains per-split results and summaries.
"""

import os
import sys
import json
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
from torch import nn
from torch.utils.data import Subset, DataLoader

# Add project root to sys.path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from utils.model_utils import load_trained_model, load_experiment_data_splits
from utils.data_utils import get_mnist_data, get_fashion_mnist_data


@dataclass
class EstimatorConfig:
    """Run configuration for the hierarchical MI estimators (filled from the CLI by parse_args)."""

    # Experiment directory: splits and trained models are loaded from it and the
    # result JSON is written into it.
    experiment_dir: str
    # Requested device ("cpu" or "cuda"); main() falls back to CPU when CUDA is unavailable.
    device: str = "cpu"
    # Process at most this many splits, counting from split 0 (None: all splits).
    max_splits: Optional[int] = None
    # Number of training samples per split used for the IF estimator.
    max_samples_per_split: int = 200
    # Number of samples used to estimate the EF diagonal (-1: use all training samples).
    ef_max_train_samples: int = 1000
    # Added to the EF diagonal before taking its reciprocal, for stability.
    damping: float = 1e-3
    # Output JSON filename under experiment_dir; when left at this default, main()
    # substitutes a depth-specific name.
    save_filename: str = "mi_hierarchical_ef.json"
    # Depth l (1-indexed) of the hierarchy to focus on.
    depth_l: int = 1  # 1..L
    # Which estimators to compute.
    mode: str = "both"  # {both, if_params_u, zu_upper}
    # Batch size of the encoder forward pass in the Z bound.
    z_batch_size: int = 256
    # Maximum number of training samples for the Z bound (-1: use all).
    z_max_train_samples: int = -1
    # Initial jitter added to the covariance for a stable log-determinant.
    z_cov_jitter: float = 1e-6
    # Which latent to bound: only layer l, or the concatenation of layers 1..l.
    z_target: str = "layer_l"  # {layer_l, concat_1_to_l}


def parse_args() -> EstimatorConfig:
    """Parse the command line and return the corresponding EstimatorConfig.

    Every option name matches an EstimatorConfig field of the same name, so the
    parsed namespace is passed straight to the dataclass constructor.
    """
    import argparse

    # (flag, add_argument keyword options), in the order they appear in --help.
    option_specs = [
        ("--experiment_dir", dict(type=str, required=True,
                                  help="Path to experiment directory under results/experiments/.../")),
        ("--device", dict(type=str, default="cpu", choices=["cpu", "cuda"],
                          help="Computation device")),
        ("--max_splits", dict(type=int, default=None,
                              help="Process at most this many splits (from split_0 upward)")),
        ("--depth_l", dict(type=int, default=1,
                           help="Depth l (1-indexed) to focus on in hierarchical model")),
        ("--mode", dict(type=str, choices=["both", "if_params_u", "zu_upper"], default="both",
                        help="Which estimators to compute")),
        ("--max_samples_per_split", dict(type=int, default=200,
                                         help="Number of training samples per split used for IF estimator")),
        ("--ef_max_train_samples", dict(type=int, default=1000,
                                        help="Samples to estimate EF diagonal (-1: use all training samples)")),
        ("--damping", dict(type=float, default=1e-3,
                           help="EF diagonal damping added to denominator for stability")),
        ("--save_filename", dict(type=str, default="mi_hierarchical_ef.json",
                                 help="Output JSON filename under experiment_dir")),
        ("--z_batch_size", dict(type=int, default=256,
                                help="Batch size for encoder forward pass in Z_{1:l} bound")),
        ("--z_max_train_samples", dict(type=int, default=-1,
                                       help="Max train samples for Z bound (-1: use all)")),
        ("--z_cov_jitter", dict(type=float, default=1e-6,
                                help="Initial jitter added to covariance for stable logdet")),
        ("--z_target", dict(type=str, choices=["layer_l", "concat_1_to_l"], default="layer_l",
                            help="Which Z to bound: only layer l (default) or concatenation of 1..l")),
    ]

    parser = argparse.ArgumentParser(description="Hierarchical EF-based MI estimation with depth selection")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return EstimatorConfig(**vars(parser.parse_args()))


def _get_num_layers(model: nn.Module) -> int:
    """Return ``model.num_layers`` converted to int, or 1 when the attribute is absent."""
    # Models without a ``num_layers`` attribute are treated as single-layer.
    return int(getattr(model, "num_layers", 1))


def _infer_posterior_lists(model: nn.Module, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
    """Return per-layer posterior means, log-variances and samples for ``x``.

    If the model exposes ``_infer_posterior`` its result is returned unchanged.
    Otherwise the forward pass is unpacked as ``(recon, mu, logvar)`` and wrapped
    into length-1 lists, with the mean reused in the sample slot.
    """
    if not hasattr(model, "_infer_posterior"):
        # Single-latent fallback: no separate sample is available, so mu stands in for it.
        _, mean, log_variance = model(x)
        return [mean], [log_variance], [mean]
    return model._infer_posterior(x)  # type: ignore[attr-defined]


def _detach_above(q_mus: List[torch.Tensor], q_logvars: List[torch.Tensor], z_samples: List[torch.Tensor], l_index0: int) -> None:
    """Replace, in place, every list entry above index ``l_index0`` by a detached tensor.

    The three lists are modified directly; ``None`` entries are left untouched.
    The number of layers is taken from ``len(q_mus)``.
    """
    for layer in range(l_index0 + 1, len(q_mus)):
        for per_layer in (q_mus, q_logvars, z_samples):
            entry = per_layer[layer]
            if entry is not None:
                per_layer[layer] = entry.detach()


def compute_partial_elbo_grad(model: nn.Module, x: torch.Tensor, beta: float, l: int) -> torch.Tensor:
    """
    Compute the gradient of a depth-limited VAE loss w.r.t. all model parameters.

    The loss is ``BCE(x_recon, x) + beta * KL``, where the reconstruction is decoded
    from the first-layer sample and the KL sum covers layers 1..l only. After the
    posterior pass, the posterior means, log-variances and samples of layers above
    ``l`` are replaced by detached tensors, so later uses of them (e.g. as the parent
    input of a conditional prior) do not propagate gradient.

    NOTE(review): detaching happens after ``model._infer_posterior`` has run, so any
    dependence of lower-layer posteriors on upper layers inside that routine is
    presumably still tracked by autograd; confirm against the model implementation.

    Models without ``_infer_posterior`` are treated as a single-latent VAE whose
    forward pass returns ``(recon, mu, logvar)``; the loss is then BCE plus ``beta``
    times the mean standard-normal KL term.

    Args:
        model: VAE whose parameters are differentiated.
        x: Input batch; also used as the BCE target.
        beta: Weight of the KL term.
        l: Depth (1-indexed); clamped to ``[1, num_layers]``.

    Returns:
        1-D tensor concatenating the flattened gradient of every parameter in
        ``model.parameters()`` order (zeros for parameters without a gradient).
        No masking is applied here.

    Side effects:
        Clears and repopulates the ``.grad`` fields of the model parameters.
    """
    num_layers = _get_num_layers(model)
    l = max(1, min(int(l), num_layers))

    if not hasattr(model, "_infer_posterior"):
        # Non-hierarchical fallback
        recon, mu, logvar = model(x)
        loss = torch.nn.functional.binary_cross_entropy(recon, x, reduction='mean') + beta * (
            -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        )
        model.zero_grad(set_to_none=True)
        loss.backward()
        return flatten_grad(model)

    q_mus, q_logvars, z_samples = model._infer_posterior(x)  # type: ignore[attr-defined]

    # Detach above l
    _detach_above(q_mus, q_logvars, z_samples, l_index0=l - 1)

    # Reconstruction from z1
    if hasattr(model, "decoder"):
        if model.__class__.__name__ == "HierarchicalConvVAE":
            # The conv decoder takes the feature-map geometry as extra arguments.
            x_recon = model.decoder(z_samples[0], model.feature_channels, model.feature_size)  # type: ignore[attr-defined]
        else:
            x_recon = model.decoder(z_samples[0])
    else:
        # Fallback: use forward's recon
        x_recon, _, _ = model(x)

    recon_loss = torch.nn.functional.binary_cross_entropy(x_recon, x, reduction='mean')

    def gaussian_kl(q_mu: torch.Tensor, q_logvar: torch.Tensor, p_mu: torch.Tensor, p_logvar: torch.Tensor) -> torch.Tensor:
        # KL(q || p) between diagonal Gaussians: summed over the last dim, averaged
        # over the rest. Log-variances are clamped to [-10, 10] before exponentiation.
        q_logvar = torch.clamp(q_logvar, min=-10.0, max=10.0)
        p_logvar = torch.clamp(p_logvar, min=-10.0, max=10.0)
        var_ratio = torch.exp(q_logvar - p_logvar)
        diff = q_mu - p_mu
        inv_p_var = torch.exp(-p_logvar)
        kl_per_dim = 0.5 * (p_logvar - q_logvar + var_ratio + diff * diff * inv_p_var - 1.0)
        return kl_per_dim.sum(dim=-1).mean()

    top = num_layers - 1
    deepest = l - 1  # 0-indexed deepest layer whose KL is included
    total_kl = 0.0

    if deepest == top:
        # The top layer is compared against a standard normal prior.
        p_mu_top = torch.zeros_like(q_mus[top])
        p_logvar_top = torch.zeros_like(q_logvars[top])
        total_kl = total_kl + gaussian_kl(q_mus[top], q_logvars[top], p_mu_top, p_logvar_top)
        first_conditional = top - 1
    else:
        # l was clamped to num_layers, so deepest < top here and layer `deepest`
        # has a parent (already detached) to condition its prior on.
        first_conditional = deepest

    has_learned_prior = hasattr(model, "p_mu_layers") and hasattr(model, "p_logvar_layers")
    for li in range(first_conditional, -1, -1):
        if has_learned_prior:
            # Conditional prior p(z_li | z_parent) computed from the parent sample.
            parent_z = z_samples[li + 1]
            p_mu_l = model.p_mu_layers[li](parent_z)  # type: ignore[attr-defined]
            p_logvar_l = model.p_logvar_layers[li](parent_z)  # type: ignore[attr-defined]
        else:
            # Non-hierarchical prior
            p_mu_l = torch.zeros_like(q_mus[li])
            p_logvar_l = torch.zeros_like(q_logvars[li])
        total_kl = total_kl + gaussian_kl(q_mus[li], q_logvars[li], p_mu_l, p_logvar_l)

    total_loss = recon_loss + beta * total_kl
    model.zero_grad(set_to_none=True)
    total_loss.backward()

    return flatten_grad(model)


def flatten_grad(model: nn.Module) -> torch.Tensor:
    """Concatenate the detached gradients of all parameters into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order; a parameter whose
    ``.grad`` is ``None`` contributes zeros of its own size.
    """
    def _as_vector(param: torch.Tensor) -> torch.Tensor:
        grad = param.grad
        if grad is None:
            return torch.zeros_like(param).view(-1)
        return grad.detach().view(-1)

    return torch.cat([_as_vector(param) for param in model.parameters()])


def build_encoder_mask_upto_l(model: nn.Module, l: int) -> Optional[torch.Tensor]:
    """
    Build a binary mask over flattened parameters selecting encoder-side parameters
    that affect layers 1..l. We use name heuristics consistent with hierarchical_vae.py:
      - HMLP: encoder_trunk, u_mu_heads[:l], u_logvar_heads[:l]
      - HCNN: enc_cnn, enc_fc, u_mu_heads[:l], u_logvar_heads[:l]
    Decoder and prior (p_mu_layers, p_logvar_layers) are excluded to focus on encoder.

    Args:
        model: Model whose ``named_parameters()`` are inspected.
        l: Depth (1-indexed); clamped to ``[1, num_layers]``.

    Returns:
        A float32 1-D tensor of ones and zeros laid out in ``named_parameters()``
        order, placed on the device of the first parameter, or ``None`` when the
        model has no parameters.
    """
    named_params = list(model.named_parameters())
    if not named_params:
        # No parameters means nothing to mask; also avoids probing an empty iterator.
        return None

    device = named_params[0][1].device
    num_layers = _get_num_layers(model)
    l = max(1, min(int(l), num_layers))

    # Shared encoder trunks are always part of the scope.
    always_included = ("encoder.", "encoder_trunk.", "enc_cnn.", "enc_fc.")
    # Per-layer heads are named "<prefix>{idx}.weight/bias" and included for idx < l.
    per_layer_heads = ("u_mu_heads.", "u_logvar_heads.")

    def include_name(param_name: str) -> bool:
        if param_name.startswith(always_included):
            return True
        if param_name.startswith(per_layer_heads):
            try:
                return int(param_name.split(".")[1]) <= (l - 1)
            except (ValueError, IndexError):
                # Head name without a numeric layer index: leave it out.
                return False
        return False

    mask_parts: List[torch.Tensor] = []
    for name, p in named_params:
        make = torch.ones if include_name(name) else torch.zeros
        mask_parts.append(make(p.numel(), dtype=torch.float32, device=device))

    return torch.cat(mask_parts)


def compute_empirical_fisher_diagonal_partial(
    model: nn.Module,
    beta: float,
    dataset_subset: Subset,
    max_samples: int,
    device: torch.device,
    depth_l: int,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Estimate the diagonal empirical Fisher as the mean of squared per-sample gradients.

    The model is switched to eval mode. Samples are taken, in order, from the first
    ``max_samples`` indices of ``dataset_subset`` (all of them when ``max_samples``
    is ``None`` or negative), each processed as a batch of one. If ``mask`` is given,
    every gradient is multiplied by it before squaring.

    Returns:
        A 1-D tensor with one entry per model parameter element; a float32 zero
        vector when no sample is selected.
    """
    model.eval()

    sample_ids: List[int] = list(dataset_subset.indices)
    capped = max_samples is not None and max_samples >= 0
    if capped and len(sample_ids) > max_samples:
        sample_ids = sample_ids[:max_samples]

    if not sample_ids:
        # No samples
        total_params = sum(p.numel() for p in model.parameters())
        return torch.zeros(total_params, dtype=torch.float32, device=device)

    base_dataset = dataset_subset.dataset
    accumulated: Optional[torch.Tensor] = None
    for sample_id in sample_ids:
        sample, _ = base_dataset[sample_id]
        batch = sample.to(device).unsqueeze(0)
        grad = compute_partial_elbo_grad(model, batch, beta, l=depth_l)
        if mask is not None:
            grad = grad * mask
        squared = grad.pow(2)
        # The first squared gradient becomes the accumulator; later ones are added in place.
        accumulated = squared if accumulated is None else accumulated.add_(squared)

    accumulated /= len(sample_ids)
    return accumulated


def compute_if_upper_bound_for_split(
    model: nn.Module,
    beta: float,
    train_subset: Subset,
    ef_diag: torch.Tensor,
    damping: float,
    max_samples_per_split: int,
    device: torch.device,
    depth_l: int,
    mask: Optional[torch.Tensor] = None,
) -> Dict:
    """Compute the quadratic IF-based upper bound for one split.

    With ``D = diag(1 / (ef_diag + damping))`` (multiplied by ``mask`` when given) and
    per-sample gradients ``g_i`` of the first ``M`` selected training samples, the
    bound is ``(1 / (2 * n_train)) * mean_{i,j} (g_i - g_j)^T D (g_i - g_j)``.

    The model is switched to eval mode before gradients are taken, matching how the
    EF diagonal is estimated, so the result does not depend on the mode the caller
    left the model in.

    Args:
        model: Model whose per-sample gradients are evaluated.
        beta: KL weight passed to the gradient computation.
        train_subset: Training subset of the split; ``n_train`` is its full size.
        ef_diag: Diagonal empirical Fisher estimate (flattened parameter layout).
        damping: Added to ``ef_diag`` before taking the reciprocal.
        max_samples_per_split: Use at most this many samples (``None`` or negative: all).
        device: Device the inputs are moved to.
        depth_l: Depth (1-indexed) passed to the gradient computation.
        mask: Optional parameter mask applied to both ``D`` and the gradients.

    Returns:
        Dict with ``num_train_samples``, ``M`` and ``upper_bound``; when samples were
        processed it also carries ``sum_qi``, ``sumG2_w`` and ``mean_pair``.
        ``upper_bound`` is ``None`` when no sample is available.
    """
    n_train = len(train_subset.indices)
    if n_train == 0:
        return {"num_train_samples": 0, "M": 0, "upper_bound": None}

    indices: List[int] = list(train_subset.indices)
    if max_samples_per_split is not None and max_samples_per_split >= 0 and len(indices) > max_samples_per_split:
        indices = indices[:max_samples_per_split]
    M = len(indices)
    if M == 0:
        return {"num_train_samples": n_train, "M": 0, "upper_bound": None}

    model.eval()

    H_inv_diag = (ef_diag + damping).reciprocal().detach()
    if mask is not None:
        H_inv_diag = H_inv_diag * mask
    # Accumulate on CPU so only one gradient at a time has to live on the device.
    H_inv_diag = H_inv_diag.cpu()

    sum_qi = 0.0  # sum_i g_i^T D g_i
    sumG = torch.zeros_like(H_inv_diag)  # sum_i g_i

    for idx in indices:
        x, _ = train_subset.dataset[idx]
        x = x.to(device).unsqueeze(0)
        g_i = compute_partial_elbo_grad(model, x, beta, l=depth_l).detach()
        if mask is not None:
            g_i = g_i * mask
        g_i = g_i.cpu()
        sum_qi += torch.sum(H_inv_diag * (g_i * g_i)).item()
        sumG.add_(g_i)

    # sum_{i,j} (g_i - g_j)^T D (g_i - g_j) = 2 M sum_i q_i - 2 (sum_i g_i)^T D (sum_i g_i),
    # which avoids the explicit O(M^2) loop over pairs.
    sumG2_w = torch.sum(H_inv_diag * (sumG * sumG)).item()
    sum_pair = 2.0 * M * sum_qi - 2.0 * sumG2_w
    mean_pair = sum_pair / (M * M)
    upper_bound = (1.0 / (2.0 * n_train)) * mean_pair

    return {
        "num_train_samples": n_train,
        "M": M,
        "sum_qi": sum_qi,
        "sumG2_w": sumG2_w,
        "mean_pair": mean_pair,
        "upper_bound": upper_bound,
    }


def _stable_logdet_psd(cov: torch.Tensor, init_jitter: float = 1e-6, max_tries: int = 6) -> Tuple[float, float]:
    """Return ``(logdet, jitter)`` for a square matrix expected to be PSD.

    Up to ``max_tries`` Cholesky factorizations of ``cov + jitter * I`` are attempted,
    multiplying ``jitter`` by 10 after each failure. On success the log-determinant of
    the jittered matrix and the jitter used are returned. If every attempt fails, the
    log-determinant is computed from the eigenvalues of ``cov`` itself, clamped from
    below at ``max(jitter, 1e-12)``, and returned with the current jitter value.
    """
    assert cov.ndim == 2 and cov.shape[0] == cov.shape[1]
    identity = torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)

    jitter = init_jitter
    for _attempt in range(max_tries):
        try:
            chol = torch.linalg.cholesky(cov + jitter * identity)
            # log det(A) = 2 * sum(log diag(L)) for A = L L^T
            return 2.0 * torch.log(torch.diag(chol)).sum().item(), jitter
        except RuntimeError:
            jitter *= 10.0

    # Cholesky never succeeded: fall back to clamped eigenvalues.
    floor = max(jitter, 1e-12)
    eigenvalues = torch.linalg.eigvalsh(cov)
    fallback_logdet = torch.log(torch.clamp(eigenvalues, min=floor)).sum().item()
    return fallback_logdet, jitter


def compute_mi_z_upper_for_split(
    model: nn.Module,
    train_subset: Subset,
    device: torch.device,
    batch_size: int,
    max_train_samples: int,
    cov_jitter: float,
    depth_l: int,
    target: str,
) -> Dict:
    """Compute the Gaussian maximum-entropy MI upper bound for the selected latents.

    Posterior means and log-variances are collected over the training samples. With
    ``Sigma_mix = Cov[mu] + diag(E[var])`` the bound is
    ``0.5 * (logdet(Sigma_mix) - E[sum(logvar)])``.

    Args:
        model: Model providing ``_infer_posterior(x)``; otherwise its forward pass is
            unpacked as ``(recon, mu, logvar)``.
        train_subset: Training subset of the split.
        device: Device used for the forward passes.
        batch_size: Batch size of the forward passes (at least 1 is used).
        max_train_samples: Use at most this many samples (``None`` or negative: all).
        cov_jitter: Initial jitter for the log-determinant computation.
        depth_l: Depth (1-indexed); clamped to ``[1, number of posterior layers]`` so
            that non-positive values select layer 1, as the gradient code does.
        target: ``"layer_l"`` bounds layer ``depth_l`` only; any other value
            concatenates layers 1..``depth_l`` along dim 1.

    Returns:
        Dict with the bound and its ingredients, or a dict whose ``mi_upper_bound``
        is ``None`` when no sample is selected.
    """
    model.eval()
    indices: List[int] = list(train_subset.indices)
    if max_train_samples is not None and max_train_samples >= 0 and len(indices) > max_train_samples:
        indices = indices[:max_train_samples]
    if len(indices) == 0:
        return {"num_train_samples": 0, "mi_upper_bound": None}

    subset = Subset(train_subset.dataset, indices)
    loader = DataLoader(subset, batch_size=max(1, batch_size), shuffle=False)

    # Running sums (float64) of mu, mu mu^T and var over all samples.
    sum_mu: Optional[torch.Tensor] = None
    sum_outer_mu: Optional[torch.Tensor] = None
    sum_var_diag: Optional[torch.Tensor] = None
    sum_logvar_total: float = 0.0
    n_count: int = 0

    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            if hasattr(model, "_infer_posterior"):
                q_mus, q_logvars, _ = model._infer_posterior(x)  # type: ignore[attr-defined]
                # Clamp the depth so depth_l <= 0 cannot wrap around to the top layer.
                depth = max(1, min(int(depth_l), len(q_mus)))
                if target == "layer_l":
                    mu_cat = q_mus[depth - 1]
                    logvar_cat = q_logvars[depth - 1]
                else:
                    mu_cat = torch.cat(list(q_mus[:depth]), dim=1)
                    logvar_cat = torch.cat(list(q_logvars[:depth]), dim=1)
            else:
                recon, mu, logvar = model(x)
                mu_cat = mu
                logvar_cat = logvar

            var_cat = torch.exp(logvar_cat)
            if sum_mu is None:
                d = mu_cat.shape[1]
                sum_mu = torch.zeros(d, dtype=torch.float64, device=device)
                sum_outer_mu = torch.zeros(d, d, dtype=torch.float64, device=device)
                sum_var_diag = torch.zeros(d, dtype=torch.float64, device=device)
            mu64 = mu_cat.to(torch.float64)
            sum_mu = sum_mu + mu64.sum(dim=0)
            sum_outer_mu = sum_outer_mu + mu64.T @ mu64
            sum_var_diag = sum_var_diag + var_cat.to(torch.float64).sum(dim=0)
            sum_logvar_total += logvar_cat.sum().item()
            n_count += mu_cat.shape[0]

    d = sum_mu.shape[0]  # type: ignore[union-attr]
    denom = max(1, n_count)
    mu_mix = (sum_mu / denom).to(torch.float64)
    mean_outer_mu = (sum_outer_mu / denom).to(torch.float64)
    # Cov[mu] = E[mu mu^T] - E[mu] E[mu]^T, symmetrized against round-off.
    cov_mu = mean_outer_mu - torch.outer(mu_mix, mu_mix)
    cov_mu = 0.5 * (cov_mu + cov_mu.T)
    mean_var_diag = (sum_var_diag / denom)
    Sigma_mix = cov_mu + torch.diag(mean_var_diag)
    Sigma_mix_cpu = Sigma_mix.detach().cpu().to(torch.float64)
    logdet, jitter_used = _stable_logdet_psd(Sigma_mix_cpu, init_jitter=cov_jitter)
    avg_sum_logvar = sum_logvar_total / denom
    mi_upper = 0.5 * (logdet - float(avg_sum_logvar))

    return {
        "num_train_samples": n_count,
        "latent_dim": int(d),
        "logdet_sigma_mix": float(logdet),
        "avg_sum_logvar": float(avg_sum_logvar),
        "jitter_used": float(jitter_used),
        "mi_upper_bound": float(mi_upper),
        "z_target": target,
    }


def _summarize_upper_bounds(records: List[Dict], key: str) -> Optional[Dict]:
    """Return mean/std/count of the non-None ``key`` values in ``records``, or None if there are none."""
    vals = [rec.get(key) for rec in records if rec.get(key) is not None]
    if not vals:
        return None
    return {
        "mean_upper_bound": float(np.mean(vals)),
        "std_upper_bound": float(np.std(vals)),
        "num_splits": len(vals),
    }


def main() -> None:
    """Run the requested estimators on every split and write the results as JSON.

    For each split the trained model is loaded (splits whose model fails to load are
    skipped with a warning), the matching training subset is built, and the IF bound
    and/or the Z bound are computed according to ``cfg.mode``. Per-split records and
    mean/std summaries are written to ``<experiment_dir>/<save_filename>``; when the
    filename is left at its default, a depth-specific name is used instead.
    """
    cfg = parse_args()
    device = torch.device(cfg.device if (cfg.device == 'cuda' and torch.cuda.is_available()) else 'cpu')

    splits, metadata = load_experiment_data_splits(cfg.experiment_dir)
    num_splits = metadata['num_splits']
    if cfg.max_splits is not None:
        num_splits = min(num_splits, cfg.max_splits)

    want_if = cfg.mode in ("both", "if_params_u")
    want_zu = cfg.mode in ("both", "zu_upper")

    per_split_if: List[Dict] = []
    per_split_zu: List[Dict] = []
    per_split_combined: List[Dict] = []

    for split_idx in range(num_splits):
        try:
            model, model_config = load_trained_model(cfg.experiment_dir, split_idx, device)
        except Exception as e:
            # Best effort: a split without a loadable model is skipped, not fatal.
            print(f"[WARN] Skip split {split_idx}: {e}")
            continue

        dataset_name = model_config.get('data', {}).get('dataset', 'mnist')
        arch = model_config.get('model', {}).get('arch', 'mlp')
        # Select flatten based on arch
        flatten = arch in {"mlp", "hmlp"}
        if dataset_name.lower() in ('fashion_mnist', 'fashion', 'fmnist', 'fashion-mnist', 'fashionmnist'):
            train_dataset, _ = get_fashion_mnist_data(flatten=flatten)
        else:
            train_dataset, _ = get_mnist_data(flatten=flatten)

        train_indices, _ = splits[split_idx]
        train_subset = Subset(train_dataset, train_indices)

        this_if_val: Optional[float] = None
        this_zu_val: Optional[float] = None

        if want_if:
            beta = float(model_config['training']['beta'])
            enc_mask = build_encoder_mask_upto_l(model, l=cfg.depth_l)
            ef_diag = compute_empirical_fisher_diagonal_partial(
                model=model,
                beta=beta,
                dataset_subset=train_subset,
                max_samples=cfg.ef_max_train_samples,
                device=device,
                depth_l=cfg.depth_l,
                mask=enc_mask,
            )
            if_info = compute_if_upper_bound_for_split(
                model=model,
                beta=beta,
                train_subset=train_subset,
                ef_diag=ef_diag,
                damping=cfg.damping,
                max_samples_per_split=cfg.max_samples_per_split,
                device=device,
                depth_l=cfg.depth_l,
                mask=enc_mask,
            )
            if_info["split"] = split_idx
            if_info["depth_l"] = int(cfg.depth_l)
            per_split_if.append(if_info)
            this_if_val = if_info.get("upper_bound")

        if want_zu:
            zu_info = compute_mi_z_upper_for_split(
                model=model,
                train_subset=train_subset,
                device=device,
                batch_size=cfg.z_batch_size,
                max_train_samples=cfg.z_max_train_samples,
                cov_jitter=cfg.z_cov_jitter,
                depth_l=cfg.depth_l,
                target=cfg.z_target,
            )
            zu_info["split"] = split_idx
            zu_info["depth_l"] = int(cfg.depth_l)
            per_split_zu.append(zu_info)
            this_zu_val = zu_info.get("mi_upper_bound")

        # The combined bound exists only when both estimators produced a value.
        comb_val = None
        if (this_if_val is not None) and (this_zu_val is not None):
            comb_val = float(this_if_val) + float(this_zu_val)
        per_split_combined.append({
            "split": split_idx,
            "depth_l": int(cfg.depth_l),
            "if_params_u_upper": (float(this_if_val) if this_if_val is not None else None),
            "z1tol_upper": (float(this_zu_val) if this_zu_val is not None else None),
            "combined_upper": (float(comb_val) if comb_val is not None else None),
        })

    # Summaries
    if_summary = _summarize_upper_bounds(per_split_if, "upper_bound") if want_if else None
    zu_summary = _summarize_upper_bounds(per_split_zu, "mi_upper_bound") if want_zu else None
    comb_summary = (
        _summarize_upper_bounds(per_split_combined, "combined_upper")
        if (want_if and want_zu) else None
    )

    out: Dict[str, object] = {"estimator_config": asdict(cfg)}
    if want_if:
        out["if_upper_bounds_per_split"] = per_split_if
        out["if_upper_bounds_summary"] = if_summary
    if want_zu:
        out["mi_z_upper_per_split"] = per_split_zu
        out["mi_z_upper_summary"] = zu_summary
    if want_if and want_zu:
        out["combined_per_split"] = per_split_combined
        out["combined_summary"] = comb_summary

    # Default filename disambiguation by depth
    final_filename = cfg.save_filename
    if cfg.save_filename == "mi_hierarchical_ef.json":
        final_filename = f"mi_hierarchical_ef_l{int(cfg.depth_l)}.json"

    save_path = os.path.join(cfg.experiment_dir, final_filename)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"Saved hierarchical MI (EF) results to: {save_path}")


# Run the estimator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


