# %%
import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from entropy_estimators import mi
from hsic import hsic_gam
from mdcrl import LitAutoEncoder, SimDataset
from pycomets.gcm import GCM
from pycomets.regression import LM, DefaultMultiRegression
from sklearn.linear_model import LinearRegression
from torch.utils.data import DataLoader
from torch_losses import hsic_poly, hsic_rbf


def get_checkpoint_from_outputs(exp_name, sim_id, selection_strategy="best"):
    """
    Resolve a checkpoint file inside ``outputs/{exp_name}/{sim_id}/checkpoints/``.

    Lookup order:
        1. ``last.ckpt`` -- only when ``selection_strategy == "last"``; a
           warning is printed if it is missing.
        2. The most recently modified ``best-*.ckpt``.
        3. ``last.ckpt``.
        4. The most recently modified ``*.ckpt`` of any name.

    Any ``selection_strategy`` other than ``"last"`` behaves like ``"best"``.

    Args:
        exp_name: Name of the experiment folder.
        sim_id: The simulation ID (folder name); converted with ``str``.
        selection_strategy: ``"best"`` (default) or ``"last"``.

    Returns:
        The checkpoint path as a string, or ``None`` (after printing an
        error message) when the directory holds no ``.ckpt`` file.
    """
    ckpt_dir = os.path.join("outputs", exp_name, str(sim_id), "checkpoints")
    last_path = os.path.join(ckpt_dir, "last.ckpt")

    def _newest(pattern):
        # Newest file (by mtime) in ckpt_dir matching the glob pattern, or None.
        matches = sorted(
            glob.glob(os.path.join(ckpt_dir, pattern)),
            key=os.path.getmtime,
            reverse=True,
        )
        return matches[0] if matches else None

    if selection_strategy == "last":
        if os.path.exists(last_path):
            return last_path
        print(
            f"Warning: strategy='last' requested but {last_path} not found. Falling back."
        )

    newest_best = _newest("best-*.ckpt")
    if newest_best is not None:
        return newest_best

    # --- Fallback ---
    if os.path.exists(last_path):
        return last_path

    newest_any = _newest("*.ckpt")
    if newest_any is not None:
        return newest_any

    print(f"Error: No checkpoints found in {ckpt_dir}")
    return None


def load_model(exp_name, sim_id, selection_strategy="best"):
    """
    Load a ``LitAutoEncoder`` checkpoint from the Hydra outputs tree onto the CPU.

    The checkpoint file is chosen by ``get_checkpoint_from_outputs``.

    Args:
        exp_name: Name of the experiment folder under ``outputs/``.
        sim_id: The simulation ID (folder name).
        selection_strategy: ``"best"`` (default) or ``"last"``; forwarded to
            ``get_checkpoint_from_outputs``.

    Returns:
        The loaded ``LitAutoEncoder`` with tensors mapped to the CPU.

    Raises:
        FileNotFoundError: If no checkpoint exists for this experiment/sim.
    """
    ckpt_path = get_checkpoint_from_outputs(
        exp_name=exp_name, sim_id=sim_id, selection_strategy=selection_strategy
    )
    # Previously ``None`` was passed straight to load_from_checkpoint, which
    # fails with an unrelated-looking error; fail early with a clear message.
    if ckpt_path is None:
        raise FileNotFoundError(
            f"No checkpoint found for exp_name={exp_name!r}, sim_id={sim_id!r}"
        )
    print(f"Loading model from {ckpt_path}")

    return LitAutoEncoder.load_from_checkpoint(
        ckpt_path, map_location=torch.device("cpu")
    )


def load_model_and_dataloader(
    exp_name,
    sim_id,
    n_samples=10000,
    batch_size=200,
    max_epoch=500,
    version=None,
    seed=None,
    selection_strategy="best",
):
    """
    Load a trained model and build a DataLoader over freshly simulated data.

    The dataset is a ``SimDataset`` built from the model's saved
    ``hparams["dataset_args"]``. It uses ``num_draws=n_samples`` and
    ``n_samples`` observations for each population. The number of
    populations is taken as ``len(dataset_args["Sig_hs"])``.

    Args:
        exp_name: Name of the experiment folder under ``outputs/``.
        sim_id: The simulation ID (folder name).
        n_samples: Draws / observations per population to simulate.
        batch_size: DataLoader batch size.
        max_epoch: Unused; kept so existing callers keep working.
        version: Unused; kept so existing callers keep working.
        seed: Seed forwarded to ``SimDataset``.
        selection_strategy: ``"best"`` (default) or ``"last"``; selects the
            checkpoint file (see ``load_model``).

    Returns:
        Tuple ``(mod, loader)``.
    """
    # load_model only accepts (exp_name, sim_id, selection_strategy); the old
    # call passed max_epoch/version and raised TypeError on every invocation.
    mod = load_model(
        exp_name=exp_name, sim_id=sim_id, selection_strategy=selection_strategy
    )
    dataset_args = mod.hparams["dataset_args"]
    num_pop = len(dataset_args["Sig_hs"])
    loader = DataLoader(
        SimDataset(
            **dataset_args,
            num_draws=n_samples,
            num_obs=[n_samples] * num_pop,
            seed=seed,
        ),
        batch_size=batch_size,
    )

    return mod, loader


# %%
# Generate new samples, obtain hvws and reconstructions


def get_eval_dataframes(
    exp_name, sim_id, max_epoch=500, n_samples=10000, seed=42
):
    """
    Load a trained model, simulate fresh data and encode it.

    Thin wrapper: ``load_model_and_dataloader`` followed by
    ``mod.encode_dataset(loader)``.

    Returns:
        Whatever ``encode_dataset`` returns. Presumably this is the list of
        per-population DataFrames consumed by the ``compute_*`` helpers in
        this module (TODO confirm).
    """
    model, eval_loader = load_model_and_dataloader(
        exp_name=exp_name,
        sim_id=sim_id,
        n_samples=n_samples,
        max_epoch=max_epoch,
        seed=seed,
    )
    return model.encode_dataset(eval_loader)


# %%
# ATE and R2


def simpleIV(Y, T, Z, C=None):
    """
    Two-stage least squares (2SLS) estimate of the effect of ``T`` on ``Y``.

    Stage 1 regresses ``T`` on the instruments ``Z``. Stage 2 regresses ``Y``
    on the stage-1 fitted values. Both stages use sklearn
    ``LinearRegression``.

    When covariates ``C`` are given, ``Y``, ``T`` and every column of ``Z``
    are first residualised on ``C`` with pycomets linear models. 2SLS is then
    run on those residuals.

    Args:
        Y: Outcome. Expected 2-D (e.g. a one-column DataFrame) so that the
            stage-2 ``coef_`` is 2-D.
        T: Treatment / endogenous regressor(s); ``T.shape[1]`` is used.
        Z: Instrument(s); ``Z.shape[1]`` is used when ``C`` is given.
        C: Optional covariates to partial out.

    Returns:
        The first row of the stage-2 ``coef_`` restricted to the first
        ``T.shape[1]`` entries, i.e. the treatment coefficient(s).
    """

    def two_stage(outcome, treatment, instruments):
        # Stage 1: project the treatment onto the instruments.
        first_stage = LinearRegression()
        first_stage.fit(y=treatment, X=instruments)
        fitted_treatment = first_stage.predict(X=instruments)
        # Stage 2: regress the outcome on the projected treatment.
        second_stage = LinearRegression()
        second_stage.fit(y=outcome, X=fitted_treatment)
        return second_stage

    if C is not None:
        # The same LM instance is refit for Y and then for T.
        partial_lm = LM()
        partial_lm.fit(Y=Y, X=C)
        Y_res = partial_lm.residuals(Y=Y, X=C)
        partial_lm.fit(Y=T, X=C)
        T_res = partial_lm.residuals(Y=T, X=C)
        multi_lm = DefaultMultiRegression(LM(), dim=Z.shape[1])
        multi_lm.fit(Y=Z, X=C)
        Z_res = multi_lm.residuals(Y=Z, X=C)
        fitted = two_stage(Y_res, T_res, Z_res)
    else:
        fitted = two_stage(Y, T, Z)

    return fitted.coef_[0][range(T.shape[1])]


def compute_estimates(dfs, batch_nums=None, plot=True, theta=1.0, ax=None):
    """
    Run ``simpleIV`` per batch and population for a fixed set of instruments.

    Each population DataFrame is split by ``batch_num``. Within every batch,
    the outcome (columns matching ``^Y``) is estimated against the treatment
    (``^D``) by 2SLS using each instrument set below. Columns are selected by
    prefix regex, so e.g. ``^V`` picks every column whose name starts with
    ``V``.

    Instrument sets (label: instruments [| partialled-out covariates]):
        Z: Z;  VW: V and W;  V: V;  W: W;  hVhW: hV and hW;  hV: hV;
        hW: hW;  WcV: W | V;  hWchV: hW | hV.

    Args:
        dfs: Sequence of per-population DataFrames, each with a ``batch_num``
            column.
        batch_nums: Batches to evaluate. ``None`` uses every batch present in
            each population, resolved separately per population.
        plot: Whether to draw a boxplot of the estimates.
        theta: Reference value drawn as a dashed red horizontal line
            (presumably the true effect in the simulation).
        ax: Matplotlib axes to draw on. If ``None``, a new figure is created
            and shown; otherwise the legend on ``ax`` is removed.

    Returns:
        Long-format DataFrame with columns ``pop_num``, ``instrument`` and
        ``estimate``. Assuming a single treatment column, there is one row
        per population x batch x instrument.
    """
    # (output label, instrument regex, covariate regex or None).
    # The order here defines the column order of the estimate table.
    specs = [
        ("Z", "^Z", None),
        ("VW", "^V|^W", None),
        ("V", "^V", None),
        ("W", "^W", None),
        ("hVhW", "^hV|^hW", None),
        ("hV", "^hV", None),
        ("hW", "^hW", None),
        ("WcV", "^W", "^V"),  # W as instrument, V as covariates
        ("hWchV", "^hW", "^hV"),
    ]
    labels = [label for label, _, _ in specs]
    plot_order = ["Z", "VW", "hVhW", "V", "hV", "W", "hW", "WcV", "hWchV"]

    df_lst = []

    for pop_num, df_pop in enumerate(dfs):
        # Resolve batches per population. Reassigning ``batch_nums`` here (as
        # before) made every later population reuse population 0's batches.
        pop_batches = (
            df_pop["batch_num"].unique() if batch_nums is None else batch_nums
        )

        estimates = {label: [] for label in labels}
        for batch_num in pop_batches:
            df_tmp = df_pop[df_pop["batch_num"] == batch_num]
            Y = df_tmp.filter(regex="^Y", axis=1)
            T = df_tmp.filter(regex="^D", axis=1)
            for label, z_regex, c_regex in specs:
                C = None if c_regex is None else df_tmp.filter(regex=c_regex, axis=1)
                estimates[label].append(
                    simpleIV(
                        Y=Y,
                        T=T,
                        Z=df_tmp.filter(regex=z_regex, axis=1),
                        C=C,
                    )
                )

        est_df = pd.DataFrame(
            np.column_stack([estimates[label] for label in labels]),
            columns=labels,
        )
        est_df["pop_num"] = pop_num
        df_lst.append(est_df)

    est_df = pd.concat(df_lst, ignore_index=True)
    est_df_long = est_df.melt(
        id_vars=["pop_num"], var_name="instrument", value_name="estimate"
    )

    if plot:
        target_ax = ax
        if target_ax is None:
            plt.figure()
            target_ax = plt.gca()
        sns.boxplot(
            est_df_long,
            x="instrument",
            y="estimate",
            hue="pop_num",
            ax=target_ax,
            order=plot_order,
        )
        target_ax.axhline(y=theta, color="red", linestyle="--", linewidth=1)
        if ax is None:
            # All populations are drawn (as hue); the old title named only
            # the last population index.
            plt.suptitle("IV estimates by instrument", y=1.02)
            plt.show()
        elif target_ax.legend_ is not None:
            target_ax.legend_.remove()

    return est_df_long


def get_r2(X, Y):
    """Return the in-sample R^2 of an sklearn ``LinearRegression`` (default settings) of ``Y`` on ``X``."""
    # fit() returns the estimator itself, so fitting and scoring chain.
    return LinearRegression().fit(X=X, y=Y).score(X=X, y=Y)


def compute_R2(dfs, batch_nums=None, combine_batches=True):
    """
    In-sample R^2 of linear regressions between column groups, per population.

    A column named ``r2_<A><B>`` holds the R^2 (via ``get_r2``) of regressing
    the columns matching ``^B`` on those matching ``^A``. For example,
    ``r2_hWV`` regresses ``V*`` on ``hW*``.

    Args:
        dfs: Sequence of per-population DataFrames with a ``batch_num`` column.
        batch_nums: Batches to use. ``None`` uses every batch present in each
            population, resolved separately per population.
        combine_batches: If True, pool the selected batches and return one row
            per population: ``pop`` followed by 12 R^2 columns. If False,
            return one row per selected batch with ``r2_hWW``, ``r2_hWV``,
            ``r2_hWH``, ``r2_hWD``, ``r2_WD`` followed by ``pop``.

    Returns:
        DataFrame of R^2 values (``pop`` is an int in both modes).
    """
    # (output column, regressor regex, response regex)
    pooled_pairs = [
        ("r2_hWW", "^hW", "^W"),
        ("r2_hWV", "^hW", "^V"),
        ("r2_hVW", "^hV", "^W"),
        ("r2_hVV", "^hV", "^V"),
        ("r2_hWH", "^hW", "^H"),
        ("r2_WH", "^W", "^H"),
        ("r2_hVH", "^hV", "^H"),
        ("r2_VH", "^V", "^H"),
        ("r2_hWD", "^hW", "^D"),
        ("r2_WD", "^W", "^D"),
        ("r2_hVD", "^hV", "^D"),
        ("r2_VD", "^V", "^D"),
    ]
    per_batch_pairs = [
        ("r2_hWW", "^hW", "^W"),
        ("r2_hWV", "^hW", "^V"),
        ("r2_hWH", "^hW", "^H"),
        ("r2_hWD", "^hW", "^D"),
        ("r2_WD", "^W", "^D"),
    ]

    def _r2_scores(frame, pairs):
        # Map each output column to its R^2 on ``frame``.
        return {
            name: get_r2(
                X=frame.filter(regex=x_regex).to_numpy(),
                Y=frame.filter(regex=y_regex).to_numpy(),
            )
            for name, x_regex, y_regex in pairs
        }

    df_lst = []

    for pop_num, df_pop in enumerate(dfs):
        # Resolve per population; reassigning ``batch_nums`` made later
        # populations reuse population 0's batch ids.
        pop_batches = (
            df_pop["batch_num"].unique() if batch_nums is None else batch_nums
        )

        if combine_batches:
            df_tmp = df_pop[df_pop["batch_num"].isin(pop_batches)]
            row = {"pop": pop_num, **_r2_scores(df_tmp, pooled_pairs)}
            df_lst.append(pd.DataFrame([row]))
        else:
            # Each column is filled with its own statistic; previously r2_hWV
            # was filled with the r2_hWW values.
            rows = [
                _r2_scores(df_pop[df_pop["batch_num"] == batch_num], per_batch_pairs)
                for batch_num in pop_batches
            ]
            est_df = pd.DataFrame(
                rows, columns=[name for name, _, _ in per_batch_pairs]
            )
            est_df["pop"] = pop_num
            df_lst.append(est_df)

    return pd.concat(df_lst, ignore_index=True)


# %%
# Other statistics


def compute_recon_err(dfs, dim_z=5, batch_nums=None, combine_batches=True):
    """
    Mean squared reconstruction error between ``Z_i`` and ``hZ_i`` columns.

    The error is averaged over rows and over the ``dim_z`` column pairs
    ``Z_0..Z_{dim_z-1}`` / ``hZ_0..hZ_{dim_z-1}``.

    Args:
        dfs: Sequence of per-population DataFrames with a ``batch_num`` column.
        dim_z: Number of ``Z_i`` / ``hZ_i`` column pairs.
        batch_nums: Batches to use. ``None`` uses every batch present in each
            population, resolved separately per population.
        combine_batches: If True, one row per population (columns ``pop``,
            ``mse``) over the pooled batches. If False, one row per batch
            (columns ``mse``, ``pop``).

    Returns:
        DataFrame of MSE values; ``pop`` is an int.
    """
    z_cols = [f"Z_{i}" for i in range(dim_z)]
    hz_cols = [f"hZ_{i}" for i in range(dim_z)]

    def _mse(frame):
        squared_errors = (
            frame[z_cols].to_numpy() - frame[hz_cols].to_numpy()
        ) ** 2
        return squared_errors.mean()

    df_lst = []

    for pop_num, df_pop in enumerate(dfs):
        # Resolve per population; reassigning ``batch_nums`` made later
        # populations filter on population 0's batch ids (empty -> NaN MSE).
        pop_batches = (
            df_pop["batch_num"].unique() if batch_nums is None else batch_nums
        )

        if combine_batches:
            df_tmp = df_pop[df_pop["batch_num"].isin(pop_batches)]
            df_lst.append(pd.DataFrame([{"pop": pop_num, "mse": _mse(df_tmp)}]))
        else:
            mses = [
                _mse(df_pop[df_pop["batch_num"] == batch_num])
                for batch_num in pop_batches
            ]
            mse_df = pd.DataFrame(mses, columns=["mse"])
            mse_df["pop"] = pop_num
            df_lst.append(mse_df)

    return pd.concat(df_lst, ignore_index=True)


def compute_gcm(dfs):
    """
    Run two Generalized Covariance Measure (GCM) tests for each population.

    Both tests use pycomets ``LM`` regressions, ``test_type="max"`` and
    ``B=4999``:
        * H0: hV indep W | V  -> columns ``gs_v`` (statistic), ``gp_v`` (p-value)
        * H0: hW indep V | W  -> columns ``gs_w``, ``gp_w``

    Columns are picked by prefix regex (``^hV_``, ``^hW_``, ``^V_``, ``^W_``).

    Returns:
        DataFrame with one row per population (``pop`` = position in ``dfs``).
    """

    def _run_gcm(X, Y, Z):
        tester = GCM()
        tester.test(
            X=X,
            Y=Y,
            Z=Z,
            reg_yz=LM(),
            reg_xz=LM(),
            test_type="max",
            B=4999,
            show_summary=False,
        )
        return tester

    frames = []
    for pop_num, df in enumerate(dfs):
        hV = df.filter(regex="^hV_").to_numpy()
        hW = df.filter(regex="^hW_").to_numpy()
        V = df.filter(regex="^V_").to_numpy()
        W = df.filter(regex="^W_").to_numpy()

        # The V-conditioned test runs first, as in the original ordering.
        given_v = _run_gcm(hV, W, V)
        given_w = _run_gcm(hW, V, W)

        frames.append(
            pd.DataFrame(
                [
                    {
                        "pop": pop_num,
                        "gs_w": given_w.stat,
                        "gp_w": given_w.pval,
                        "gs_v": given_v.stat,
                        "gp_v": given_v.pval,
                    }
                ]
            )
        )

    return pd.concat(frames, ignore_index=True)


def compute_gcm_cor(dfs):
    """
    Collect the GCM correlation (``GCM.get_cor()``) for two tests per population.

    The tests are H0: hV indep W | V and H0: hW indep V | W. They use pycomets
    ``LM`` regressions, ``test_type="max"`` and ``B=100``.

    Returns:
        Tuple ``(cors_hvw_v, cors_hwv_w)`` of lists with one entry per
        population, in the order of ``dfs``.
    """

    def _fitted_gcm(X, Y, Z):
        gcm = GCM()
        gcm.test(
            X=X,
            Y=Y,
            Z=Z,
            reg_yz=LM(),
            reg_xz=LM(),
            test_type="max",
            B=100,
            show_summary=False,
        )
        return gcm

    cors_hvw_v, cors_hwv_w = [], []
    for df in dfs:
        hV = df.filter(regex="^hV_").to_numpy()
        hW = df.filter(regex="^hW_").to_numpy()
        V = df.filter(regex="^V_").to_numpy()
        W = df.filter(regex="^W_").to_numpy()

        cors_hvw_v.append(_fitted_gcm(hV, W, V).get_cor())
        cors_hwv_w.append(_fitted_gcm(hW, V, W).get_cor())

    return cors_hvw_v, cors_hwv_w


def compute_hsic(dfs, kernel_type="rbf", sigma=None):
    """
    HSIC dependence values between learned and true blocks, per population.

    For each population DataFrame three HSIC values are computed on float32
    tensors built from the columns matching ``^hV_``, ``^hW_``, ``^V_``,
    ``^W_``:
        ``hwhv``: HSIC(hW, hV);  ``hwv``: HSIC(hW, V);  ``hvw``: HSIC(hV, W).

    Args:
        dfs: Sequence of per-population DataFrames.
        kernel_type: ``"rbf"`` (uses ``hsic_rbf`` with bandwidth ``sigma`` on
            both inputs, keeping the first returned value) or ``"poly"``
            (uses ``hsic_poly`` with degree 2 and ``c=1.0``).
        sigma: Bandwidth forwarded to ``hsic_rbf``; ignored for ``"poly"``.

    Returns:
        DataFrame with columns ``pop``, ``hwhv``, ``hwv``, ``hvw``.

    Raises:
        ValueError: If ``kernel_type`` is not ``"rbf"`` or ``"poly"``.
    """
    # Validate up front: an unknown kernel used to surface as an
    # UnboundLocalError on ``hwhv``.
    if kernel_type not in ("rbf", "poly"):
        raise ValueError(
            f"kernel_type must be 'rbf' or 'poly', got {kernel_type!r}"
        )

    def _hsic(a, b):
        if kernel_type == "rbf":
            value, _, _ = hsic_rbf(a, b, sigma_x=sigma, sigma_y=sigma)
            return value
        return hsic_poly(a, b, degree_x=2, degree_y=2, c=1.0)

    res = []

    for pop_num, df in enumerate(dfs):
        hV = torch.tensor(df.filter(regex="^hV_").values, dtype=torch.float32)
        hW = torch.tensor(df.filter(regex="^hW_").values, dtype=torch.float32)
        V = torch.tensor(df.filter(regex="^V_").values, dtype=torch.float32)
        W = torch.tensor(df.filter(regex="^W_").values, dtype=torch.float32)

        # H0: hV indep hW
        hwhv = _hsic(hW, hV)
        # H0: hW indep V
        hwv = _hsic(hW, V)
        # H0: hV indep W
        hvw = _hsic(hV, W)

        new_row = {
            "pop": pop_num,
            "hwhv": hwhv.detach().cpu().item(),
            "hwv": hwv.detach().cpu().item(),
            "hvw": hvw.detach().cpu().item(),
        }
        res.append(pd.DataFrame([new_row]))

    return pd.concat(res, ignore_index=True)


def test_linear_indep(dfs, test_type="max"):
    """
    Intercept-only GCM tests of hV vs W and hW vs V, per population.

    The conditioning set is a column of ones, so the pycomets ``LM``
    regressions only remove the mean. Both tests use ``B=4999`` and the given
    ``test_type``.

    Returns:
        DataFrame with one row per population and columns ``pop``,
        ``hwv_stat``, ``hwv_pval`` (hW vs V), ``hvw_stat``, ``hvw_pval``
        (hV vs W).
    """

    def _intercept_only_gcm(X, Y, intercept):
        tester = GCM()
        tester.test(
            X=X,
            Y=Y,
            Z=intercept,
            reg_yz=LM(),
            reg_xz=LM(),
            test_type=test_type,
            B=4999,
            show_summary=False,
        )
        return tester

    frames = []
    for pop_num, df in enumerate(dfs):
        hV = df.filter(regex="^hV_").to_numpy()
        hW = df.filter(regex="^hW_").to_numpy()
        V = df.filter(regex="^V_").to_numpy()
        W = df.filter(regex="^W_").to_numpy()
        intercept = np.ones((len(df), 1))

        # hV vs W first, then hW vs V (same order as before).
        hv_vs_w = _intercept_only_gcm(hV, W, intercept)
        hw_vs_v = _intercept_only_gcm(hW, V, intercept)

        frames.append(
            pd.DataFrame(
                [
                    {
                        "pop": pop_num,
                        "hwv_stat": hw_vs_v.stat,
                        "hwv_pval": hw_vs_v.pval,
                        "hvw_stat": hv_vs_w.stat,
                        "hvw_pval": hv_vs_w.pval,
                    }
                ]
            )
        )

    return pd.concat(frames, ignore_index=True)


def check_hvhw_indep(dfs):
    """
    Summarise several dependence diagnostics for hV / hW, per population.

    Per population DataFrame:
        * ``cor_min`` / ``cor_max``: min / max absolute Pearson correlation
          over all off-diagonal column pairs of ``[hV, hW]``. This includes
          pairs within hV and within hW; both are 0 if there are no such pairs.
        * ``mi``: ``entropy_estimators.mi(hV, hW)``; also printed.
        * GCM tests (pycomets ``LM``, ``B=4999``, default ``test_type``):
          hV indep W | V -> ``gcm_*_v``; hW indep V | W -> ``gcm_*_w``.
        * ``hsic_gam(hV, hW, alph=0.05)``: statistic, threshold, and
          ``hsic_reject = stat > thre``.

    Returns:
        DataFrame with one row per population (key ``pop_num``).
    """

    def _abs_corr_range(matrix):
        abs_corr = np.abs(np.corrcoef(matrix.T))
        off_diagonal = ~np.eye(abs_corr.shape[0], dtype=bool)
        if not off_diagonal.any():
            return 0, 0
        values = abs_corr[off_diagonal]
        return np.min(values), np.max(values)

    frames = []
    for pop_num, df in enumerate(dfs):
        hV = df.filter(regex="^hV_").to_numpy()
        hW = df.filter(regex="^hW_").to_numpy()
        V = df.filter(regex="^V_").to_numpy()
        W = df.filter(regex="^W_").to_numpy()

        cor_min, cor_max = _abs_corr_range(np.hstack([hV, hW]))

        mutinfo = mi(x=hV, y=hW)
        print(f"Pop {pop_num} Mutual Info: {mutinfo:.4f}")

        # Keep the test order (V-conditioned first) in case the tests draw
        # from a shared random state.
        gcm_v = GCM()
        gcm_v.test(
            X=hV, Y=W, Z=V, reg_yz=LM(), reg_xz=LM(), B=4999, show_summary=False
        )
        gcm_w = GCM()
        gcm_w.test(
            X=hW, Y=V, Z=W, reg_yz=LM(), reg_xz=LM(), B=4999, show_summary=False
        )

        stat, thre = hsic_gam(X=hV, Y=hW, alph=0.05)

        frames.append(
            pd.DataFrame(
                [
                    {
                        "pop_num": pop_num,
                        "gcm_stat_w": gcm_w.stat,
                        "gcm_pval_w": gcm_w.pval,
                        "gcm_stat_v": gcm_v.stat,
                        "gcm_pval_v": gcm_v.pval,
                        "cor_min": cor_min,
                        "cor_max": cor_max,
                        "mi": mutinfo,
                        "hsic_stat": stat,
                        "hsic_thre": thre,
                        "hsic_reject": stat > thre,
                    }
                ]
            )
        )

    return pd.concat(frames, ignore_index=True)


# %%
# Make plots


def pair_selected(df, filter_regex=".*", plot_title=None):
    """
    Draw a corner pairplot (with regression fits) of the columns matching a regex.

    Args:
        df: Source DataFrame.
        filter_regex: Column-name regex passed to ``DataFrame.filter``.
        plot_title: Optional figure title; when truthy, the layout is
            tightened and the top margin adjusted to fit it.

    Returns:
        The seaborn ``PairGrid``.
    """
    grid = sns.pairplot(
        df.filter(regex=filter_regex),
        kind="reg",
        height=1,
        aspect=1,
        corner=True,
        plot_kws={"scatter_kws": {"s": 0.5}},
    )
    if not plot_title:
        return grid

    fig = grid.figure
    fig.suptitle(plot_title, y=1.02)
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    return grid


def plot_values(dfs, value_vars=None):
    """
    Plot values from a list of DataFrames as per-batch boxplots.

    One subplot (stacked vertically, shared x axis) is drawn per DataFrame.
    Within each subplot, ``batch_num`` is on the x axis and the selected
    columns are distinguished by hue.

    Parameters
    ----------
    dfs : list of pd.DataFrame
        Each DataFrame should have a 'batch_num' column.
    value_vars : list of str, optional
        Columns to plot. Defaults to ``["W0", "W1", "W2"]``.
    """
    # None sentinel instead of a mutable list default shared across calls.
    if value_vars is None:
        value_vars = ["W0", "W1", "W2"]

    # All frames are melted to long format before any figure is created.
    dfs_long = [
        df.melt(
            id_vars="batch_num",
            value_vars=value_vars,
            var_name="variable",
            value_name="value",
        )
        for df in dfs
    ]

    num_dfs = len(dfs)
    fig, axes = plt.subplots(num_dfs, 1, figsize=(16, 5 * num_dfs), sharex=True)

    if num_dfs == 1:
        axes = [axes]

    # Tick count comes from the first DataFrame (the x axis is shared).
    n_ticks = len(dfs[0]["batch_num"].unique())

    for i, (df_long, ax) in enumerate(zip(dfs_long, axes)):
        sns.boxplot(
            x="batch_num", y="value", hue="variable", data=df_long, ax=ax
        )
        # NOTE(review): titles are 1-based, unlike the 0-based population
        # indices reported by the other helpers in this module.
        ax.set_title(f"Pop {i+1}")
        ax.set_xlabel("" if i < num_dfs - 1 else "batch_num")
        ax.set_ylabel("Value")
        ax.legend(title="Variable")
        ax.set_xticks(range(n_ticks))

    plt.tight_layout()
    plt.show()
