import torch
import matplotlib.pyplot as plt


def get_2d_projection(
    d,
    projection_type="random",
    projection_dims=(0, 1),
    device=None,
    dtype=None,
    generator=None,
):
    """
    Return a projection matrix P of shape [d, 2].

    Row-vector samples ``x`` of shape [n, d] are mapped to 2-D coordinates
    with ``x @ P`` (shape [n, 2]).

    Args:
        d: Ambient dimension; must be at least 2.
        projection_type:
            - "random"      : random 2D subspace in R^d, given by the
                              orthonormal columns of the QR factor of a
                              Gaussian [d, 2] matrix
            - "coordinates" : projection on the two coordinate axes
                              listed in ``projection_dims``
        projection_dims: Pair ``(i, j)`` of distinct axis indices in
            ``[0, d)``; only used for ``"coordinates"``.
        device: Device on which P is created (defaults to ``"cpu"``).
        dtype: Floating dtype of P (``None`` -> torch's default dtype).
        generator: Optional ``torch.Generator`` used for the "random"
            projection, for reproducibility; it must live on ``device``.

    Raises:
        ValueError: if ``d < 2``, ``projection_dims`` is invalid, or
            ``projection_type`` is unknown.
    """
    if device is None:
        device = "cpu"

    # With d < 2 there is no 2-D subspace: QR of a [1, 2] matrix would give
    # a [1, 1] factor instead of the promised [d, 2] matrix.
    if d < 2:
        raise ValueError(f"d must be at least 2 for a 2D projection, got {d}")

    if projection_type == "random":
        # Random orthonormal directions in R^d (reduced QR -> Q is [d, 2]).
        gaussian = torch.randn(
            d, 2, device=device, dtype=dtype, generator=generator
        )
        Q, _ = torch.linalg.qr(gaussian)
        return Q

    if projection_type == "coordinates":
        i, j = projection_dims
        # Explicit check instead of `assert`, which is stripped under -O.
        if not (0 <= i < d and 0 <= j < d) or i == j:
            raise ValueError(
                "projection_dims must be two distinct indices in "
                f"[0, {d}), got {projection_dims}"
            )

        P = torch.zeros(d, 2, device=device, dtype=dtype)
        P[i, 0] = 1.0
        P[j, 1] = 1.0
        return P

    raise ValueError(f"Unknown projection_type: {projection_type}")



def _select_time_steps(n_iter, n_mid, manual_time_steps):
    """Return the ordered list of iteration indices to snapshot.

    ``manual_time_steps`` (if not None) is used verbatim, as a list.
    Otherwise ``n_mid + 2`` indices evenly spaced over ``[0, n_iter - 1]``
    are used (only the two endpoints when ``n_mid <= 0``). They are rounded
    to the nearest integer and de-duplicated, so short runs do not produce
    repeated columns.
    """
    if manual_time_steps is not None:
        return list(manual_time_steps)
    n_points = n_mid + 2 if n_mid > 0 else 2
    raw = torch.linspace(0, n_iter - 1, n_points).round().long().tolist()
    # dict.fromkeys drops duplicates while keeping first-occurrence order.
    return list(dict.fromkeys(raw))


def _project_2d(samples, P):
    """Return ``samples @ P`` as a CPU tensor.

    ``P`` is cast to the samples' device and dtype, so that float64 samples
    do not clash with a float32 projection. The result is moved to CPU
    because matplotlib cannot convert tensors that live on a GPU.
    """
    return (samples @ P.to(device=samples.device, dtype=samples.dtype)).cpu()


def plot_em2c_lambda_kernel_grid(
    results,
    target,
    metadata,
    n_mid=4,
    n_samples_plot=1500,
    projection_type="coordinates",
    projection_dims=(0, 1),
    text_col_width=0.1,
    verbose=True,
    manual_time_steps=None,
):
    """
    Grid visualization for a fixed kernel.

    Draws one row per λ, in the insertion order of ``results``, and one
    column per snapshot time step. A narrow leading column holds the λ label.
    Each panel overlays the same target samples (gray) with samples from
    that λ's proposal at that step (red). All samples are projected to 2-D
    with ``get_2d_projection`` and shown on shared axis limits.

    Args:
        results: dict keyed by lamda -> {"proposal_history", ...}, where
            ``proposal_history[t]`` exposes ``.sample(n)``. Samples are
            assumed to be tensors of shape [n, d] (TODO confirm).
        target: object exposing ``.sample(n)`` and a tensor attribute
            ``base_means``; only the device of ``base_means`` is used here.
        metadata: dict with keys "model", "d", "kernel", "kernel_params",
            "n_iter" and "lamdas". "d" and "n_iter" drive the plot; all of
            them are printed when ``verbose``.
        n_mid: number of evenly spaced intermediate snapshots between the
            first and last iteration (ignored if ``manual_time_steps``).
        n_samples_plot: samples drawn from the target and from each
            plotted proposal.
        projection_type, projection_dims: forwarded to
            ``get_2d_projection``.
        text_col_width: width ratio of the label column relative to one
            snapshot column.
        verbose: print a summary of ``metadata`` before plotting.
        manual_time_steps: explicit indices into ``proposal_history``,
            used as given.

    Raises:
        ValueError: if ``results`` is empty or there is no time step to plot.
    """
    # Printing params
    model = metadata["model"]
    d = metadata["d"]
    kernel = metadata["kernel"]
    kernel_params = metadata["kernel_params"]
    n_iter = metadata["n_iter"]
    reported_lamdas = metadata["lamdas"]

    if verbose:
        print(
            f" Model: {model} | d={d} | Kernel: {kernel} | "
            f" Kernel params: {kernel_params}"
        )
        print(
            f" T={n_iter} | Lambdas={reported_lamdas} | Projection={projection_type}"
            + (f" {projection_dims}" if projection_type == "coordinates" else "")
        )
        print("==========================================\n")

    # λ values in insertion order of `results` (these define the rows).
    lamdas = list(results)
    n_rows = len(lamdas)
    if n_rows == 0:
        raise ValueError("Empty results dictionary.")

    time_steps = _select_time_steps(n_iter, n_mid, manual_time_steps)
    n_snapshots = len(time_steps)
    if n_snapshots == 0:
        raise ValueError("No time steps to plot.")

    # squeeze=False always yields a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(
        n_rows,
        n_snapshots + 1,
        figsize=(4.0 * n_snapshots + 1.8, 3.8 * n_rows),
        gridspec_kw={"width_ratios": [text_col_width] + [1.0] * n_snapshots},
        squeeze=False,
    )

    # Projection matrix
    P = get_2d_projection(
        d,
        projection_type=projection_type,
        projection_dims=projection_dims,
        device=target.base_means.device,
    )

    # Single sampling pass, used for both the axis limits and the scatters.
    # The target is sampled first, then proposals in (λ, t) order.
    with torch.no_grad():
        x_tgt = _project_2d(target.sample(n_samples_plot), P)
        x_prop = {
            lamda: {
                t: _project_2d(
                    results[lamda]["proposal_history"][t].sample(n_samples_plot),
                    P,
                )
                for t in time_steps
            }
            for lamda in lamdas
        }

    # Global axis limits from the same samples (as Python floats).
    X_all = torch.cat(
        [x_tgt] + [x_prop[lamda][t] for lamda in lamdas for t in time_steps],
        dim=0,
    )
    xmin, xmax = X_all[:, 0].min().item(), X_all[:, 0].max().item()
    ymin, ymax = X_all[:, 1].min().item(), X_all[:, 1].max().item()

    for row, lamda in enumerate(lamdas):
        # Left column: rotated λ label.
        ax_text = axes[row, 0]
        ax_text.axis("off")
        ax_text.text(
            0.5,
            0.5,
            rf"$\lambda = {lamda}$",
            ha="center",
            va="center",
            rotation=90,
            fontsize=30,
            fontweight="bold",
        )

        # Snapshots: target (gray) vs proposal at step t (red).
        for col, t in enumerate(time_steps, start=1):
            ax = axes[row, col]
            x_p = x_prop[lamda][t]

            ax.scatter(x_tgt[:, 0], x_tgt[:, 1], s=5, alpha=0.25, c="gray")
            ax.scatter(x_p[:, 0], x_p[:, 1], s=5, alpha=0.35, c="red")

            ax.set_xlim(xmin, xmax)
            ax.set_ylim(ymin, ymax)
            ax.set_aspect("equal", adjustable="box")
            ax.set_xticks([])
            ax.set_yticks([])

            if row == 0:
                ax.set_title(rf"$t = {t}$", fontsize=30)

    plt.tight_layout(rect=[0, 0.08, 1, 0.94])
    plt.show()