# %%
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

import numpy as np
import pandas as pd
from src.datagen import get_nested_circles, get_gaussian_linear_transform_data


# Maps each experiment key to the directory holding its pickled results and
# figures.
PATH_DICT = dict(
    linear_transform_gaussian="exp_linear_transform_gaussian",
    non_linear="exp_non_linear",
)

# Experiment to plot; must be a key of PATH_DICT
# ("linear_transform_gaussian" or "non_linear").
EXP = "linear_transform_gaussian"

FONTSIZE = 20

# Generate the synthetic data once for all figures below. EXP selects the
# setting and, through PATH_DICT, the directory the results are read from.
if EXP == "non_linear":
    (X, Y), (S_X, S_Y) = get_nested_circles(
        n_x=250,
        n_y=25,
        p_X0=0.5,
        p_Y0=0.5,
        noise_0=0.15,
        noise_1=0.2,
        diameter=4.0,
        rng=42,
        n_outliers_x=4,
        n_outliers_y=2,
    )
elif EXP == "linear_transform_gaussian":
    (X, Y), (S_X, S_Y) = get_gaussian_linear_transform_data(
        d=2,
        n_x=250,
        n_y=25,
        scale=0.2,
        p_x0=0.5,
        p_y0=0.5,
        centers_X=[np.array([0, 0]), np.array([2.0, 0.0])],
        centers_Y=[np.array([1.0, 1.0]), np.array([2.5, 0.5])],
        rng=42,
    )
else:
    # Fail fast with a readable message instead of an opaque KeyError on
    # PATH_DICT[EXP] further down.
    raise ValueError(
        f"Unknown EXP={EXP!r}; expected one of {sorted(PATH_DICT)}"
    )

# NOTE: pd.read_pickle can execute arbitrary code; only load result files
# produced locally by the experiment scripts.
exp_dir = PATH_DICT[EXP]
results_mlp = pd.read_pickle(f"{exp_dir}/results_mlp.pkl")
results_penalized = pd.read_pickle(f"{exp_dir}/results_penalized.pkl")
results_mahalanobis = pd.read_pickle(f"{exp_dir}/results_mahalanobis.pkl")
results_entropic_ot = pd.read_pickle(f"{exp_dir}/results_entropic_ot.pkl")

# The plots below read a "final_fairness" column for every method; these two
# tables provide it as "fairness_loss_value", so expose it under that name.
results_penalized["final_fairness"] = results_penalized["fairness_loss_value"]
results_entropic_ot["final_fairness"] = results_entropic_ot[
    "fairness_loss_value"
]


# %%
# Function to compute F_hat from transport plan by aggregating over groups
def compute_F_hat(pi, S_X, S_Y):
    """Compute the group-aggregated transport matrix F_hat from a plan.

    ``F_hat[i, j]`` is the total mass ``pi`` moves from source points whose
    label is the i-th sorted unique value of ``S_X`` to target points whose
    label is the j-th sorted unique value of ``S_Y``.

    Parameters
    ----------
    pi : array-like of shape (n_x, n_y)
        Transport plan. Objects exposing ``.numpy()`` (e.g. torch tensors)
        are converted first; anything else goes through ``np.asarray``.
    S_X : array-like of shape (n_x,)
        Group label of every source point.
    S_Y : array-like of shape (n_y,)
        Group label of every target point.

    Returns
    -------
    numpy.ndarray of shape (k_s, k_w)
        Where k_s / k_w are the numbers of distinct labels in S_X / S_Y.
    """
    if hasattr(pi, "numpy"):
        pi = pi.numpy()
    if hasattr(S_X, "numpy"):
        S_X = S_X.numpy()
    if hasattr(S_Y, "numpy"):
        S_Y = S_Y.numpy()
    pi = np.asarray(pi)
    S_X = np.asarray(S_X)
    S_Y = np.asarray(S_Y)

    unique_S_X = np.unique(S_X)
    unique_S_Y = np.unique(S_Y)

    # Target-group masks do not depend on the source group: build them once.
    masks_Y = [S_Y == w for w in unique_S_Y]

    F_hat = np.zeros((len(unique_S_X), len(unique_S_Y)))
    for i, s in enumerate(unique_S_X):
        rows_s = pi[S_X == s]  # sub-plan of the sources in group s
        for j, mask_Y in enumerate(masks_Y):
            F_hat[i, j] = rows_s[:, mask_Y].sum()

    return F_hat


# %%

# Apply the seaborn style BEFORE the font rc: sns.set_style also sets
# "font.family" (to sans-serif), so calling it second discarded the serif font.
sns.set_style("whitegrid")
plt.rc("font", family="serif")

fig, axes = plt.subplots(1, 4, figsize=(20, 4), sharey=True, sharex=True)

# One panel per method; results, titles and colormaps are matched by position.
results_list = [
    results_penalized,
    results_mlp,
    results_mahalanobis,
    results_entropic_ot,
]
titles = ["Penalized OT", "MLP", "Mahalanobis", "Entropic OT"]
cmaps = ["Oranges", "Blues", "Greens", "Purples"]

# F matrix inset heatmap settings
inset_size = 0.25  # side of each square inset, as a fraction of the main axes
heatmap_cmap = "Blues"
# Fixed colour limits so every inset heatmap shares one colour scale.
vmin_F = 0.04
vmax_F = 0.62


def get_penalty_indices(results, levels=("low", "mid", "high")):
    """Pick the rows with the lowest, middle and highest penalty.

    Parameters
    ----------
    results : pandas.DataFrame
        Must contain a ``"penalty"`` column.
    levels : iterable of {"low", "mid", "high"}
        Which levels to return. "mid" is the row at position ``n // 2`` after
        sorting by penalty.

    Returns
    -------
    dict
        Maps each requested level to the corresponding row (a Series).

    Raises
    ------
    ValueError
        If ``results`` has no rows.
    """
    ranked = results.sort_values("penalty").reset_index(drop=True)
    n_rows = len(ranked)
    if n_rows == 0:
        raise ValueError("results is empty; cannot select penalty levels")
    positions = {"low": 0, "mid": n_rows // 2, "high": n_rows - 1}
    return {level: ranked.iloc[positions[level]] for level in levels}


def _get_inset_positions(exp, panel_idx, size):
    """Return ``{level: [x, y, width, height]}`` (axes fraction) for the
    low/mid/high F_hat insets of panel ``panel_idx`` in experiment ``exp``.

    Insets are placed high-to-low to avoid crossing arrows. Panels past the
    second reuse the third layout.
    """
    if exp == "non_linear":
        layouts = [
            {"high": (0.03, 0.6), "mid": (0.25, 0.08), "low": (0.7, 0.6)},
            {"high": (0.03, 0.3), "mid": (0.3, 0.2), "low": (0.6, 0.05)},
            {"high": (0.2, 0.7), "mid": (0.2, 0.4), "low": (0.2, 0.1)},
        ]
    elif exp == "linear_transform_gaussian":
        layouts = [
            {"high": (0.03, 0.5), "mid": (0.25, 0.2), "low": (0.72, 0.7)},
            {"high": (0.03, 0.3), "mid": (0.3, 0.2), "low": (0.6, 0.05)},
            {"high": (0.1, 0.6), "mid": (0.1, 0.32), "low": (0.1, 0.05)},
        ]
    else:
        raise ValueError(f"No inset layout for experiment {exp!r}")
    layout = layouts[min(panel_idx, len(layouts) - 1)]
    return {level: [x, y, size, size] for level, (x, y) in layout.items()}


def _draw_F_inset(ax, row, position):
    """Draw the group-aggregated plan of ``row`` as an inset heatmap at
    ``position`` (axes fraction) and point an arrow at its scatter point."""
    F_hat = compute_F_hat(row["fair_ot_plan"], S_X, S_Y)

    inset_ax = ax.inset_axes(position)
    sns.heatmap(
        F_hat,
        ax=inset_ax,
        cmap=heatmap_cmap,
        cbar=False,
        annot=True,
        fmt=".2f",
        annot_kws={"size": 14},
        vmin=vmin_F,
        vmax=vmax_F,
        linewidths=0.5,
        linecolor="black",
    )
    inset_ax.set_xticks([])
    inset_ax.set_yticks([])

    # sns.heatmap despines its axes: re-enable the spines and add an
    # unclipped rectangle on top so the inset always has a full black frame.
    for spine in inset_ax.spines.values():
        spine.set_visible(True)
        spine.set_color("black")
        spine.set_linewidth(1.5)
    inset_ax.patch.set_edgecolor("black")
    inset_ax.patch.set_linewidth(1.5)
    inset_ax.add_patch(
        Rectangle(
            (0, 0),
            1,
            1,
            transform=inset_ax.transAxes,
            fill=False,
            edgecolor="black",
            linewidth=1.5,
            clip_on=False,
            zorder=10,
        )
    )

    # Arrow from the bottom-centre of the inset (axes fraction) to the
    # scatter point of this run (data coordinates).
    x0, y0, width, _ = position
    ax.annotate(
        "",
        xy=(row["final_fairness"], row["cost_diff"]),
        xytext=(x0 + width / 2, y0),
        textcoords="axes fraction",
        arrowprops=dict(
            arrowstyle="->",
            color="black",
            lw=1.5,
            connectionstyle="arc3,rad=-0.2",
        ),
    )


for i, (results, title, cmap) in enumerate(zip(results_list, titles, cmaps)):
    ax = axes[i]

    # Presumably drops (near-)zero fairness losses, which cannot be shown
    # sensibly on the log-scaled x-axis.
    if EXP == "linear_transform_gaussian":
        results = results.query("final_fairness >= 1e-5")

    # Each method gets its own log colour scale over its own penalty range.
    norm = plt.matplotlib.colors.LogNorm(
        vmin=results["penalty"].min(), vmax=results["penalty"].max()
    )

    scatter = ax.scatter(
        results["final_fairness"],
        results["cost_diff"],
        c=results["penalty"],
        cmap=cmap,
        marker="o",
        s=200,
        edgecolor="black",
        norm=norm,
    )
    ax.set_xscale("log")  # once per panel

    ax.set_xlabel("Fairness Loss" if i == 1 else None, fontsize=FONTSIZE)
    ax.set_ylabel("Cost Difference" if i == 0 else None, fontsize=FONTSIZE)
    ax.set_title(title, fontsize=FONTSIZE, fontweight="bold")
    ax.grid(True, alpha=1)
    ax.tick_params(axis="both", labelsize=FONTSIZE)
    for spine in ax.spines.values():
        spine.set_color("black")
        spine.set_linewidth(1)
    ax.tick_params(axis="both", which="both", width=2, color="black")

    # Per-panel colorbar; the entropic method is tuned by an entropic penalty.
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label(
        label=(
            "Entropic penalty" if title == "Entropic OT" else "Fairness penalty"
        ),
        fontsize=16,
    )
    cbar.ax.tick_params(labelsize=16)

    # Inset heatmaps of F_hat for the low / mid / high penalty runs.
    inset_positions = _get_inset_positions(EXP, i, inset_size)
    for level, row in get_penalty_indices(results).items():
        _draw_F_inset(ax, row, inset_positions[level])

plt.tight_layout()
plt.savefig(f"{PATH_DICT[EXP]}/cost_diff_vs_fairness_{EXP}.pdf")
plt.show()


# %%
# Plot F matrix for each method (3 rows) and 5 penalty values (5 columns)
# Apply the seaborn style BEFORE the font rc: sns.set_style also sets
# "font.family" (to sans-serif), so calling it second discarded the serif font.
sns.set_style("whitegrid")
plt.rc("font", family="serif")

# Fixed colour limits shared by every F_hat heatmap in this grid.
vmin = 0.04
vmax = 0.51

# One row per method (Entropic OT is not part of this grid).
results_list = [results_penalized, results_mlp, results_mahalanobis]
method_names = ["Penalized OT", "MLP", "Mahalanobis"]


# Select 5 evenly spaced penalties for each method separately
def get_selected_penalties(results, n_select=5):
    """Pick ``n_select`` penalties spread evenly over the sorted unique values.

    Indices come from ``np.linspace(0, n - 1, n_select)`` cast to int, so for
    ``n_select >= 2`` the smallest and largest penalties are always included.
    If there are fewer than ``n_select`` unique penalties, some values repeat.

    Parameters
    ----------
    results : pandas.DataFrame
        Must contain a ``"penalty"`` column.
    n_select : int
        Number of penalties to return.

    Returns
    -------
    list
        The selected penalty values, in ascending order.

    Raises
    ------
    ValueError
        If ``results`` contains no penalty values.
    """
    sorted_penalties = sorted(results["penalty"].unique())
    if not sorted_penalties:
        raise ValueError("results contains no penalty values")
    picks = np.linspace(0, len(sorted_penalties) - 1, n_select, dtype=int)
    return [sorted_penalties[k] for k in picks]


selected_penalties_per_method = [
    get_selected_penalties(r) for r in results_list
]

fig, axes = plt.subplots(3, 5, figsize=(15, 9))

for row_idx, (results, method_name, selected_penalties) in enumerate(
    zip(results_list, method_names, selected_penalties_per_method)
):
    for col_idx, penalty in enumerate(selected_penalties):
        ax = axes[row_idx, col_idx]

        # Runs with this penalty; only the first one is drawn.
        result_row = results[results["penalty"] == penalty]

        if len(result_row) > 0:
            pi = result_row.iloc[0]["fair_ot_plan"]
            F_hat = compute_F_hat(pi, S_X, S_Y)

            sns.heatmap(
                F_hat,
                ax=ax,
                cmap="Blues",
                cbar=False,
                annot=True,
                fmt=".2f",
                annot_kws={"size": 8},
                vmin=vmin,
                vmax=vmax,
            )

        # Row labels (method names) on the left
        if col_idx == 0:
            ax.set_ylabel(method_name, fontsize=12, fontweight="bold")
        else:
            ax.set_ylabel(None)

        # Column labels (penalty values) on top. ".3g" keeps small penalties
        # readable (".1f" rendered e.g. 0.01 as "0.0").
        ax.set_title(f"λ={penalty:.3g}", fontsize=11)

        ax.set_xlabel(None)
        ax.tick_params(axis="both", labelsize=8)
        for spine in ax.spines.values():
            spine.set_color("black")
            spine.set_linewidth(1)

plt.tight_layout()
# The figure depends on EXP, so save it per experiment (like the cost/fairness
# figure) instead of overwriting a single file in the working directory.
plt.savefig(f"{PATH_DICT[EXP]}/F_matrix_comparison_{EXP}.pdf")
plt.show()

# %%
# LogNorm needs a strictly positive vmin: use the smallest positive entry over
# all plans (1e-10 for a plan with no positive entry). float() turns numpy or
# tensor scalars into plain Python floats for matplotlib.
vmin = float(
    min(
        min(
            results["fair_ot_plan"].apply(
                lambda x: x[x > 0].min() if (x > 0).any() else 1e-10
            )
        )
        for results in results_list
    )
)

vmax = float(
    max(
        max(results["fair_ot_plan"].apply(lambda x: x.max()))
        for results in results_list
    )
)

# One log norm shared by every heatmap and the colorbar, so the colorbar
# matches the log-scaled colours actually drawn.
plan_norm = plt.matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)

# Plot transport plans as heatmaps for each method (3 rows) and 5 penalty
# values (5 columns)

fig, axes = plt.subplots(3, 5, figsize=(25, 12))

for row_idx, (results, method_name, selected_penalties) in enumerate(
    zip(results_list, method_names, selected_penalties_per_method)
):
    for col_idx, penalty in enumerate(selected_penalties):
        ax = axes[row_idx, col_idx]

        # Runs with this penalty; only the first one is drawn.
        result_row = results[results["penalty"] == penalty]

        if len(result_row) > 0:
            pi = result_row.iloc[0]["fair_ot_plan"]
            if hasattr(pi, "numpy"):
                pi = pi.numpy()

            sns.heatmap(
                pi,
                norm=plan_norm,
                cmap="YlGnBu",
                ax=ax,
                cbar=False,
                vmin=vmin,
                vmax=vmax,
            )

        # Column labels (penalty values) on top; ".3g" keeps small
        # penalties readable.
        ax.set_title(f"λ={penalty:.3g}", fontsize=11)

        # Row labels (method names) on the left
        if col_idx == 0:
            ax.set_ylabel(method_name, fontsize=12, fontweight="bold")
        else:
            ax.set_ylabel("")

        ax.set_xlabel("")
        ax.set_xticks([])
        ax.set_yticks([])


cbar = fig.colorbar(
    plt.cm.ScalarMappable(cmap="YlGnBu", norm=plan_norm),
    ax=axes,
    shrink=0.8,
)
# Colorbar has no `fontsize` keyword; size its tick labels explicitly.
cbar.ax.tick_params(labelsize=16)
plt.show()


# %%
# =============================================================================
# Figure: Both settings (Gaussian and Nested Circles) point clouds
# =============================================================================

# Generate data for both settings
# Both settings are regenerated here, with the same parameters as the
# top-level data generation, so this figure does not depend on EXP.
gaussian_points, gaussian_groups = get_gaussian_linear_transform_data(
    d=2,
    n_x=250,
    n_y=25,
    scale=0.2,
    p_x0=0.5,
    p_y0=0.5,
    centers_X=[np.array([0, 0]), np.array([2.0, 0.0])],
    centers_Y=[np.array([1.0, 1.0]), np.array([2.5, 0.5])],
    rng=42,
)
X_gaussian, Y_gaussian = gaussian_points
S_X_gaussian, S_Y_gaussian = gaussian_groups

circles_points, circles_groups = get_nested_circles(
    n_x=250,
    n_y=25,
    p_X0=0.5,
    p_Y0=0.5,
    noise_0=0.15,
    noise_1=0.2,
    diameter=4.0,
    rng=42,
    n_outliers_x=4,
    n_outliers_y=2,
)
X_circles, Y_circles = circles_points
S_X_circles, S_Y_circles = circles_groups

# Create figure with GridSpec for custom layout
plt.rc("font", family="serif")
# No sns.set_style here: the style from the previous cells is kept, and
# calling it after plt.rc would reset "font.family" to sans-serif.

fig = plt.figure(figsize=(8, 4))
# Row 0 is a thin legend strip; row 1 holds the two point clouds.
gs = fig.add_gridspec(2, 2, height_ratios=[0.15, 1], hspace=0.1)

# Define colors for groups
colors_X = {0: "tab:blue", 1: "tab:orange"}
colors_Y = {0: "tab:green", 1: "tab:red"}


def _legend_handle(marker, color, label):
    """Marker-only legend proxy (white line, coloured marker)."""
    return plt.Line2D(
        [0],
        [0],
        marker=marker,
        color="w",
        markerfacecolor=color,
        markersize=12,
        label=label,
    )


def _plot_point_cloud(ax, X, Y, S_X, S_Y):
    """Scatter source points X (circles) and target points Y (squares),
    coloured by group label 0/1, and apply the shared axes styling."""
    for group in [0, 1]:
        mask_X = S_X == group
        ax.scatter(
            X[mask_X, 0],
            X[mask_X, 1],
            c=colors_X[group],
            marker="o",
            s=50,
            alpha=0.7,
            label=f"X - Group {group}",
        )
        mask_Y = S_Y == group
        ax.scatter(
            Y[mask_Y, 0],
            Y[mask_Y, 1],
            c=colors_Y[group],
            marker="s",
            s=100,
            alpha=0.9,
            edgecolor="black",
            linewidth=1,
            label=f"Y - Group {group}",
        )
    ax.tick_params(axis="both", labelsize=14)
    ax.set_box_aspect(1)
    for spine in ax.spines.values():
        spine.set_color("black")
        spine.set_linewidth(1)


# Row 0: single shared legend spanning both columns
ax_legend = fig.add_subplot(gs[0, :])
ax_legend.axis("off")
legend_elements = [
    _legend_handle("o", colors_X[0], r"$X$, $S=0$"),
    _legend_handle("o", colors_X[1], r"$X$, $S=1$"),
    _legend_handle("s", colors_Y[0], r"$Y$, $W=0$"),
    _legend_handle("s", colors_Y[1], r"$Y$, $W=1$"),
]
ax_legend.legend(
    handles=legend_elements,
    loc="center",
    fontsize=14,
    ncol=4,
    frameon=True,
    fancybox=True,
)

# Row 1: Gaussian point cloud (left) and nested circles (right)
ax_gaussian = fig.add_subplot(gs[1, 0])
_plot_point_cloud(
    ax_gaussian, X_gaussian, Y_gaussian, S_X_gaussian, S_Y_gaussian
)

ax_circles = fig.add_subplot(gs[1, 1])
_plot_point_cloud(ax_circles, X_circles, Y_circles, S_X_circles, S_Y_circles)

plt.tight_layout()
plt.savefig("simulated_settings.pdf")
plt.show()

# %%
