import argparse
import json

import matplotlib.pyplot as plt
import seaborn as sns

from mixed_diffusion.data_loading.mr_mash_data_generation import simulate_mr_mash_data


def main(args):
    """Simulate mr-mash data from parsed CLI options and show X and Y as heatmaps.

    Args:
        args: Parsed ``argparse.Namespace``; each attribute listed below is
            forwarded unchanged to ``simulate_mr_mash_data`` as a keyword
            argument of the same name.
    """
    # Options forwarded to the simulator, in the same keyword order as before.
    option_names = (
        "n",
        "p",
        "p_causal",
        "r",
        "r_causal",
        "intercepts",
        "pve",
        "B_cor",
        "B_scale",
        "w",
        "X_cor",
        "X_scale",
        "V_cor",
        "seed",
    )
    simulated = simulate_mr_mash_data(
        **{name: getattr(args, name) for name in option_names}
    )
    panels = (("X", simulated["X"]), ("Y", simulated["Y"]))

    # Plot X and Y side by side as heatmaps, one axis per matrix.
    _fig, axes = plt.subplots(1, len(panels), figsize=(12, 5))
    for ax, (title, matrix) in zip(axes, panels):
        sns.heatmap(matrix, cmap="viridis", ax=ax)
        ax.set_title(title)
    plt.tight_layout()
    plt.show()


def parse_json_arg(text):
    """Parse a command-line string as JSON (scalar, list or nested list).

    Used as an argparse ``type`` so options documented as JSON reach the
    simulator as Python values instead of raw strings. argparse only applies
    ``type`` to string defaults, so the list defaults below are kept as-is.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not valid JSON.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError as err:
        raise argparse.ArgumentTypeError(
            f"invalid JSON value {text!r}: {err}"
        ) from err


def build_parser():
    """Build the CLI parser whose options mirror simulate_mr_mash_data's parameters."""
    parser = argparse.ArgumentParser(description="Generate mr-mash synthetic data.")
    parser.add_argument(
        "--n", type=int, default=50, help="Scalar indicating the number of samples."
    )
    parser.add_argument(
        "--p", type=int, default=40, help="Scalar indicating the number of variables."
    )
    parser.add_argument(
        "--p_causal",
        type=int,
        default=20,
        help="Scalar indicating the number of causal variables.",
    )
    parser.add_argument(
        "--r", type=int, default=5, help="Scalar indicating the number of responses."
    )
    parser.add_argument(
        "--r_causal",
        type=parse_json_arg,
        default=[[1, 2], [3, 4]],
        help="List of numeric vectors (JSON) indicating in which responses the causal variables have an effect.",
    )
    parser.add_argument(
        "--intercepts",
        type=parse_json_arg,
        # None means "one intercept of 1 per response"; resolved against --r
        # in parse_cli_args so the length always matches the response count.
        default=None,
        help="Numeric vector (JSON) of intercept for each response (default: 1 for each of the --r responses).",
    )
    parser.add_argument(
        "--pve",
        type=float,
        default=0.20,
        help="Per-response proportion of variance explained by the causal variables.",
    )
    parser.add_argument(
        "--B_cor",
        type=parse_json_arg,
        default=[0, 1],
        help="Scalar or numeric vector (JSON) with positive correlation [0, 1] between causal effects.",
    )
    parser.add_argument(
        "--B_scale",
        type=parse_json_arg,
        default=[0.5, 1.0],
        help="Scalar or numeric vector (JSON) with the diagonal value for Sigma_k.",
    )
    parser.add_argument(
        "--w",
        type=parse_json_arg,
        default=[0.5, 0.5],
        help="Scalar or numeric vector (JSON) with mixture proportions for each mixture component.",
    )
    parser.add_argument(
        "--X_cor",
        type=float,
        default=0.5,
        help="Scalar indicating the positive correlation [0, 1] between variables.",
    )
    parser.add_argument(
        "--X_scale",
        type=float,
        default=1.0,
        help="Scalar indicating the diagonal value for Gamma.",
    )
    parser.add_argument(
        "--V_cor",
        type=float,
        default=0.0,
        help="Scalar indicating the positive correlation [0, 1] between residuals.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    return parser


def parse_cli_args(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None) and resolve dependent defaults.

    Returns:
        argparse.Namespace whose ``intercepts`` is ``[1] * r`` when the option
        was not given, so its length matches the number of responses.
    """
    args = build_parser().parse_args(argv)
    if args.intercepts is None:
        args.intercepts = [1] * args.r
    return args


if __name__ == "__main__":
    main(parse_cli_args())
