import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
from scipy.stats import pearsonr
import os


def plot_spatial_signature_with_enhanced_visuals(
    all_values,
    save_bool=False,
    save_dir="results/plots",
    alpha_direction=(2, 1),
    close_figures=False,
):
    """Draw one 2-D scatter per (dataset, semivalue) with a reference circle and direction marker.

    For every ``dataset_name -> {semivalue_name: data}`` entry of *all_values*,
    ``values = data[0]`` is read and ``values[1]`` / ``values[2]`` are plotted
    as x / y. The axis labels suggest these are phi(., omega, lambda) and
    phi(., omega, gamma) -- NOTE(review): confirm against the producer of
    *all_values*.

    A dashed circle centred at the origin is drawn with radius
    ``0.65 * max(x range, y range)``. A marker is placed where the
    *alpha_direction* vector, normalised to unit length, meets that circle.
    Both axes are limited to ``±1.5 * radius``.

    Args:
        all_values: nested mapping ``{dataset_name: {semivalue_name: data}}``.
        save_bool: if True, save each figure as
            ``<save_dir>/<dataset_name>_<semivalue_name>.png`` at 300 dpi.
        save_dir: output directory; created only when *save_bool* is True.
        alpha_direction: 2-vector whose direction is marked on the circle
            (default ``(2, 1)``, the previously hard-coded value).
        close_figures: if True, close each figure after drawing or saving it.
            This bounds the number of open figures in long loops. Leave it
            False when the figures should still be displayed (e.g. inline).

    Raises:
        ValueError: if *alpha_direction* is the zero vector.
    """
    u1, u2 = alpha_direction
    direction_norm = np.hypot(u1, u2)
    if direction_norm == 0:
        raise ValueError("alpha_direction must be a non-zero 2-vector")
    unit_u1, unit_u2 = u1 / direction_norm, u2 / direction_norm

    if save_bool:
        os.makedirs(save_dir, exist_ok=True)

    for dataset_name, semivalues in all_values.items():
        for semivalue_name, data in semivalues.items():
            values = data[0]
            x_values = np.asarray(values[1])
            y_values = np.asarray(values[2])

            fig, ax = plt.subplots(figsize=(7, 7))

            # nanmax/nanmin so a stray NaN does not turn the axis limits into NaN
            # (set_xlim rejects NaN limits).
            x_range = np.nanmax(x_values) - np.nanmin(x_values)
            y_range = np.nanmax(y_values) - np.nanmin(y_values)
            circle_radius = max(x_range, y_range) * 0.65
            if not np.isfinite(circle_radius) or circle_radius == 0:
                # Degenerate data (single point / constant): fall back to a unit
                # circle so the limits below are not identical.
                circle_radius = 1.0

            ax.scatter(
                x_values,
                y_values,
                color="royalblue",
                alpha=0.7,
                marker="x",
                s=20,
                label=r"$S_{\omega, \mathcal{D}}$",
            )

            theta = np.linspace(0, 2 * np.pi, 300)
            ax.plot(
                circle_radius * np.cos(theta),
                circle_radius * np.sin(theta),
                linestyle="dashed",
                color="gray",
                linewidth=1.2,
                label=r"$\mathcal{S}^1$",
            )

            ax.scatter(
                unit_u1 * circle_radius,
                unit_u2 * circle_radius,
                label="α",
                edgecolors="black",
                s=60,
                marker="o",
            )

            limit = circle_radius * 1.5
            ax.set_xlim(-limit, limit)
            ax.set_ylim(-limit, limit)

            ax.set_xlabel(r"$\phi(., \omega, \lambda)$", fontsize=14)
            ax.set_ylabel(r"$\phi(., \omega, \gamma)$", fontsize=14)
            ax.grid(True, linestyle="dotted", alpha=0.6)

            ax.legend(loc="upper left", fontsize=12)

            # One formatter per axis: a Formatter is bound to a single Axis, and
            # sharing one instance makes both axes use the last-bound axis' state.
            for axis in (ax.xaxis, ax.yaxis):
                formatter = ScalarFormatter(useMathText=True)
                formatter.set_powerlimits((-2, 2))
                formatter.set_scientific(True)
                axis.set_major_formatter(formatter)

            if save_bool:
                filename = f"{dataset_name}_{semivalue_name}.png"
                fig.savefig(os.path.join(save_dir, filename), dpi=300, bbox_inches="tight")

            if close_figures:
                plt.close(fig)


def plot_flat_Rp_df(df, p_values, save_path=None):
    """Plot R_p against p for every semivalue method, one panel per row of *df*.

    Method names are discovered from columns named ``R<something>_<method>_mean``
    (exactly three ``_``-separated parts). For each method, the columns
    ``R{p}_{method}_mean`` and ``R{p}_{method}_se`` are read for every p in
    *p_values* and drawn as an error-bar line. If any of those columns is
    missing, that method is skipped for the panel.

    Panels are laid out 4 per row with at least 2 rows, which is the original
    2x4 grid at (12, 6) inches. More rows are added when *df* has more than
    8 rows, and unused panels are hidden. Each row's ``"Dataset"`` value is the
    panel title, and the legend is drawn on the first panel only.

    Args:
        df: DataFrame with one row per dataset, a ``"Dataset"`` column and the
            R_p mean/se columns described above.
        p_values: p values for the x axis; each is substituted for ``{p}`` in
            the column name templates.
        save_path: if given, the figure is saved there at 300 dpi before
            ``plt.show()``.
    """
    methods = set()
    for col in df.columns:
        # Guard non-string column labels (e.g. integer columns) before str methods.
        if isinstance(col, str) and col.startswith("R") and col.endswith("_mean"):
            parts = col.split("_")
            if len(parts) == 3:
                methods.add(parts[1])

    palette = ["#AEC6CF", "#FFD1DC", "#77DD77", "#FFB347", "#CBAACB", "#B0E0E6"]
    # Sorted so colours and legend order are reproducible: iteration order of a
    # set of str varies between interpreter runs (hash randomization).
    semivalue_specs = {
        method: (f"R{{p}}_{method}_mean", f"R{{p}}_{method}_se", palette[i % len(palette)])
        for i, method in enumerate(sorted(methods))
    }

    n_panels = len(df)
    ncols = 4
    nrows = max(2, int(np.ceil(n_panels / ncols)))
    _, axes = plt.subplots(nrows, ncols, figsize=(12, 3 * nrows), sharey=True)
    axes = axes.flatten()

    # Position (not the DataFrame index label) selects the panel, so filtered
    # or non-default indices still map onto axes 0..n-1.
    for pos, (_, row) in enumerate(df.iterrows()):
        ax = axes[pos]
        for name, (mean_tmpl, se_tmpl, color) in semivalue_specs.items():
            try:
                means = [row[mean_tmpl.format(p=p)] for p in p_values]
                errs = [row[se_tmpl.format(p=p)] for p in p_values]
            except KeyError:
                # Method not available for every requested p: skip it here.
                continue
            ax.errorbar(
                p_values,
                means,
                yerr=errs,
                marker="o",
                capsize=3,
                label=name,
                color=color,
                linestyle="-",
            )

        ax.set_title(row["Dataset"])
        ax.set_xticks(p_values)
        ax.set_xlabel(r"$p$")
        ax.set_ylabel(r"$R_p$")
        ax.grid(True, linestyle=":", linewidth=0.5)
        if pos == 0:
            ax.legend(fontsize=8)

    for ax in axes[n_panels:]:
        ax.set_visible(False)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300)
        print(f"Plot saved to {save_path}")
    plt.show()


def compute_normalized_rj(marg_contrib_dict):
    """Per-point covariance of two contribution samples, max-normalised per run and averaged.

    *marg_contrib_dict* maps ``dataset -> runs``, where each run is a 3-tuple
    ``(_, mu_dict, nu_dict)``. ``mu_dict[j]`` and ``nu_dict[j]`` are
    equal-length numeric sequences for point ``j``, indexed ``0 .. n-1``. The
    point count ``n`` is taken from the first run's ``mu_dict``.

    For each run and point, r_j is the population covariance (ddof=0) of
    ``mu_dict[j]`` and ``nu_dict[j]``. This equals Pearson r times the two
    ddof=0 standard deviations. r_j is NaN when either sample has fewer than
    two entries or is constant. Each run's r vector is then divided by its
    largest finite ``|r_j|``, unless that value is 0 or no finite value exists.

    Returns:
        dict ``dataset -> (rj_mean, rj_se)``, arrays of length ``n``.
        ``rj_mean`` is the NaN-ignoring mean over runs, or NaN if a point has
        no finite value. ``rj_se`` is the NaN-ignoring sample std (ddof=1)
        divided by sqrt(#finite runs), or NaN when fewer than two runs are
        finite. A dataset with no runs maps to two empty arrays.

    Raises:
        ValueError: if a point's two (non-degenerate) samples differ in length.
    """
    results = {}

    for dataset, runs in marg_contrib_dict.items():
        if len(runs) == 0:
            results[dataset] = (np.array([]), np.array([]))
            continue

        n_points = len(runs[0][1])
        all_rj = np.full((len(runs), n_points), np.nan)

        for r_idx, (_, mu_dict, nu_dict) in enumerate(runs):
            rj_vec = np.full(n_points, np.nan)

            for j in range(n_points):
                mu = np.asarray(mu_dict[j], dtype=float)
                nu = np.asarray(nu_dict[j], dtype=float)

                # Size check first: np.std([]) warns, and a covariance needs >= 2 samples.
                if mu.size < 2 or nu.size < 2 or np.std(mu) == 0 or np.std(nu) == 0:
                    continue
                if mu.shape != nu.shape:
                    raise ValueError(
                        f"{dataset!r}, run {r_idx}, point {j}: samples have "
                        f"different lengths ({mu.size} vs {nu.size})"
                    )
                rj_vec[j] = np.mean((mu - mu.mean()) * (nu - nu.mean()))

            finite_abs = np.abs(rj_vec[~np.isnan(rj_vec)])
            max_val = finite_abs.max() if finite_abs.size else 0.0
            all_rj[r_idx] = rj_vec / max_val if max_val != 0 else rj_vec

        # Aggregate only over columns that have enough finite values; the rest
        # stay NaN (same values as nanmean/nanstd would give, minus the warnings).
        counts = np.sum(~np.isnan(all_rj), axis=0)
        rj_mean = np.full(n_points, np.nan)
        rj_se = np.full(n_points, np.nan)

        has_one = counts > 0
        if has_one.any():
            rj_mean[has_one] = np.nanmean(all_rj[:, has_one], axis=0)

        has_two = counts > 1
        if has_two.any():
            rj_se[has_two] = np.nanstd(all_rj[:, has_two], axis=0, ddof=1) / np.sqrt(
                counts[has_two]
            )

        results[dataset] = (rj_mean, rj_se)

    return results

