import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# =========================================================
# 1. FILE PATH
# =========================================================
# Only the file name is given, so the path is resolved relative to the
# current working directory: run p.py from the folder that contains both
# p.py and the Excel file.
# Workbook that holds the differential-analysis sheets to plot.
file_path = "Statistical data analysis results_Proteomics_Differential analysis_updated_v1_HJ SK_07022025.xlsx"

# Output folder
# Every PNG is written here; the folder is created up front and an
# already-existing folder is not an error (exist_ok=True).
output_dir = "volcano_plots_500dpi"
os.makedirs(output_dir, exist_ok=True)

# =========================================================
# 2. SETTINGS
# =========================================================
# Minimum absolute log2 fold change for a point to be classed "Up" or
# "Down"; also the x-position of the dashed vertical guide lines.
FC_THRESHOLD = 1.0
# p-value cut-off: values strictly below it count as significant. Drawn as
# the dashed horizontal line at -log10(P_THRESHOLD).
P_THRESHOLD = 0.05
# Resolution (dots per inch) used both when creating and when saving figures.
DPI_VALUE = 500

# =========================================================
# 3. HELPER FUNCTIONS
# =========================================================
# Characters replaced when a sheet name is turned into a file name.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\\/*?:"<>|]')


def safe_name(name):
    """Return ``str(name)`` with each of ``\\ / * ? : " < > |`` replaced by "_"."""
    text = str(name)
    return _UNSAFE_FILENAME_CHARS.sub("_", text)

def clean_volcano_df(df, log2fc_col, pval_col, fc_threshold=None, p_threshold=None):
    """Build a tidy volcano-plot table from two columns of *df*.

    Both columns are coerced to numbers (unparseable cells become NaN) and
    rows are dropped when either value is missing or non-finite, or when the
    p-value is not strictly positive (its -log10 would be undefined).

    Args:
        df: Source table.
        log2fc_col: Label of the column holding log2 fold changes.
        pval_col: Label of the column holding p-values.
        fc_threshold: Absolute log2FC cut-off for "Up"/"Down". Defaults to
            the module-level ``FC_THRESHOLD``.
        p_threshold: p-value cut-off (strictly below is significant).
            Defaults to the module-level ``P_THRESHOLD``.

    Returns:
        A new DataFrame with columns ``log2FC``, ``pval``, ``neglog10p`` and
        ``group`` ("Up", "Down" or "Not Significant"). The same four columns
        are present even when no row survives the filtering.
    """
    # Resolve the cut-offs at call time so the module settings stay the defaults.
    if fc_threshold is None:
        fc_threshold = FC_THRESHOLD
    if p_threshold is None:
        p_threshold = P_THRESHOLD

    out = pd.DataFrame({
        "log2FC": pd.to_numeric(df[log2fc_col], errors="coerce"),
        "pval": pd.to_numeric(df[pval_col], errors="coerce"),
    })

    out = out.dropna(subset=["log2FC", "pval"])
    keep = (
        np.isfinite(out["log2FC"])
        & np.isfinite(out["pval"])
        & (out["pval"] > 0)
    )
    # Take an explicit copy: adding columns to a boolean-filtered frame is
    # what triggers pandas' SettingWithCopyWarning.
    out = out.loc[keep].copy()

    out["neglog10p"] = -np.log10(out["pval"])

    out["group"] = "Not Significant"
    significant = out["pval"] < p_threshold
    out.loc[(out["log2FC"] >= fc_threshold) & significant, "group"] = "Up"
    out.loc[(out["log2FC"] <= -fc_threshold) & significant, "group"] = "Down"

    return out

def plot_single_volcano(df, title, save_path):
    """Draw one volcano plot for *df* and save it to *save_path*.

    *df* must provide the columns ``log2FC``, ``neglog10p`` and ``group``.
    An empty *df* is reported on stdout and nothing is drawn or written.
    """
    if df.empty:
        print(f"[SKIPPED] {title} -> no valid numeric data found")
        return

    fig = plt.figure(figsize=(9, 7), dpi=DPI_VALUE)
    ax = fig.add_subplot(111)

    # (group label, marker size, opacity, colour), drawn back to front.
    layers = (
        ("Not Significant", 12, 0.45, "lightgray"),
        ("Down", 14, 0.70, "red"),
        ("Up", 14, 0.70, "#3b5ba9"),
    )
    for label, size, opacity, colour in layers:
        subset = df[df["group"] == label]
        ax.scatter(
            subset["log2FC"], subset["neglog10p"],
            s=size, alpha=opacity, c=colour, edgecolors="none"
        )

    # Dashed guides at the significance cut-offs, solid line at zero change.
    guide = dict(color="gray", linestyle="--", linewidth=1)
    ax.axhline(y=-np.log10(P_THRESHOLD), **guide)
    for x_cut in (-FC_THRESHOLD, FC_THRESHOLD):
        ax.axvline(x=x_cut, **guide)
    ax.axvline(x=0, color="black", linewidth=2.5)

    x_min = df["log2FC"].min()
    x_max = df["log2FC"].max()
    y_max = df["neglog10p"].max()

    # Pad the axes by 5 % of the data extent, with a fixed minimum margin.
    x_extent = 1 if x_max == x_min else x_max - x_min
    y_extent = y_max if y_max > 0 else 1
    x_pad = max(0.3, 0.05 * x_extent)
    y_pad = max(0.5, 0.05 * y_extent)

    ax.set_xlim(x_min - x_pad, x_max + x_pad)
    ax.set_ylim(0, y_max + y_pad)

    ax.set_xlabel("Fold Change (log$_2$)", fontsize=16)
    ax.set_ylabel("-log$_{10}$ Adjusted p-value", fontsize=16)
    ax.set_title(title, fontsize=14, fontweight="bold")

    # Corner captions just below the highest point, one per side.
    caption_y = y_max * 0.95
    ax.text(
        x_min, caption_y,
        "Negative change in\nprotein expression\ncompared to control",
        ha="left", va="top", fontsize=10
    )
    ax.text(
        x_max, caption_y,
        "Positive change in\nprotein expression\ncompared to control",
        ha="right", va="top", fontsize=10
    )
    ax.text(
        0, y_max * 0.98,
        "Zero Point",
        ha="center", va="bottom", fontsize=11, fontweight="bold"
    )

    ax.grid(alpha=0.2)
    fig.tight_layout()
    fig.savefig(save_path, dpi=DPI_VALUE, bbox_inches="tight")
    plt.close(fig)

    print(f"[SAVED] {save_path}")

def plot_three_panel_volcano(datasets, main_title, save_path):
    """Draw a row of three volcano panels and save the figure to *save_path*.

    *datasets* is an iterable of ``(panel_title, df)`` pairs that is matched
    against the three axes in order. A pair whose *df* is empty yields a
    blank panel showing "No valid data" instead of a plot.
    """
    fig, axes = plt.subplots(1, 3, figsize=(20, 6), dpi=DPI_VALUE)

    # (group label, marker size, opacity, colour), drawn back to front.
    layers = (
        ("Not Significant", 12, 0.45, "lightgray"),
        ("Down", 14, 0.70, "red"),
        ("Up", 14, 0.70, "#3b5ba9"),
    )
    guide = dict(color="gray", linestyle="--", linewidth=1)

    for ax, (sub_title, df) in zip(axes, datasets):
        ax.set_title(sub_title, fontsize=11, fontweight="bold")

        if df.empty:
            ax.text(0.5, 0.5, "No valid data", ha="center", va="center", transform=ax.transAxes)
            ax.axis("off")
            continue

        for label, size, opacity, colour in layers:
            subset = df[df["group"] == label]
            ax.scatter(
                subset["log2FC"], subset["neglog10p"],
                s=size, alpha=opacity, c=colour, edgecolors="none",
            )

        # Dashed guides at the significance cut-offs, solid line at zero change.
        ax.axhline(y=-np.log10(P_THRESHOLD), **guide)
        for x_cut in (-FC_THRESHOLD, FC_THRESHOLD):
            ax.axvline(x=x_cut, **guide)
        ax.axvline(x=0, color="black", linewidth=2.2)

        x_min = df["log2FC"].min()
        x_max = df["log2FC"].max()
        y_max = df["neglog10p"].max()

        # Pad the axes by 5 % of the data extent, with a fixed minimum margin.
        x_extent = 1 if x_max == x_min else x_max - x_min
        y_extent = y_max if y_max > 0 else 1
        x_pad = max(0.3, 0.05 * x_extent)
        y_pad = max(0.5, 0.05 * y_extent)

        ax.set_xlim(x_min - x_pad, x_max + x_pad)
        ax.set_ylim(0, y_max + y_pad)

        ax.set_xlabel("Fold Change (log$_2$)", fontsize=11)
        ax.set_ylabel("-log$_{10}$ Adjusted p-value", fontsize=11)
        ax.grid(alpha=0.2)
        ax.tick_params(labelsize=9)

    fig.suptitle(main_title, fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig(save_path, dpi=DPI_VALUE, bbox_inches="tight")
    plt.close(fig)

    print(f"[SAVED] {save_path}")

# =========================================================
# 4. CHECK FILE EXISTS
# =========================================================
# Fail early, with a hint about where the script is looking, when the
# workbook is not reachable from the current working directory.
if not os.path.exists(file_path):
    missing_file_message = (
        f"Excel file not found: {file_path}\n"
        f"Current working folder: {os.getcwd()}\n"
        f"Make sure p.py and the Excel file are in the same folder."
    )
    raise FileNotFoundError(missing_file_message)

# =========================================================
# 5. LOAD EXCEL FILE
# =========================================================
# 0-based positions of Excel columns G and J, used when a regular sheet
# does not carry the "Log2FC" / "Adjusted_pval" headers.
FALLBACK_FC_COL = 6
FALLBACK_P_COL = 9

# Panel layout of each known Venn sheet:
# (panel title, 0-based log2FC column, 0-based p-value column).
VENN_PANEL_SPECS = {
    "Protein_List_Venn1": [
        ("Snx13 ++ vs Snx13 +-", 2, 5),
        ("Snx13 +- vs Untransfected +-", 7, 10),
        ("Snx13 ++ vs Untransfected +-", 12, 15),
    ],
    "Protein_List_Venn2": [
        ("Snx13 ++ vs Snx13 +-", 2, 5),
        ("Snx13 +- vs Snx13 --", 7, 10),
        ("Snx13 ++ vs Snx13 --", 12, 15),
    ],
    "Protein_List_Venn3": [
        ("Snx13 ++ vs Snx13 +-", 2, 5),
        ("Snx13 +- vs mCherry +-", 7, 10),
        ("Snx13 ++ vs mCherry +-", 12, 15),
    ],
}

# Open the workbook once and reuse the handle for every sheet instead of
# re-reading the file per sheet; the context manager closes it afterwards,
# including when a sheet fails to parse or plot.
with pd.ExcelFile(file_path) as xls:
    sheet_names = xls.sheet_names

    print("\nSheets found:")
    for i, s in enumerate(sheet_names, 1):
        print(f"{i}. {s}")

    # =====================================================
    # 6. PROCESS SHEETS
    # =====================================================
    for sheet in sheet_names:
        print(f"\nProcessing sheet: {sheet}")
        save_path = os.path.join(output_dir, safe_name(sheet) + ".png")

        if not sheet.startswith("Protein_List_Venn"):
            df = pd.read_excel(xls, sheet_name=sheet)

            # First try named columns
            if "Log2FC" in df.columns and "Adjusted_pval" in df.columns:
                temp = clean_volcano_df(df, "Log2FC", "Adjusted_pval")
            elif df.shape[1] > max(FALLBACK_FC_COL, FALLBACK_P_COL):
                # Fallback to G and J columns
                temp_df = pd.DataFrame({
                    "log2FC": df.iloc[:, FALLBACK_FC_COL],
                    "pval": df.iloc[:, FALLBACK_P_COL],
                })
                temp = clean_volcano_df(temp_df, "log2FC", "pval")
            else:
                # Too narrow for the positional fallback: skip this sheet
                # rather than abort the whole run with an IndexError.
                print(
                    f"[SKIPPED] {sheet} -> no Log2FC/Adjusted_pval headers "
                    f"and only {df.shape[1]} column(s)"
                )
                continue

            plot_single_volcano(temp, sheet, save_path)
            continue

        specs = VENN_PANEL_SPECS.get(sheet)
        if specs is None:
            print(f"[SKIPPED] Unknown Venn sheet format: {sheet}")
            continue

        # Venn sheets are read without a header; the first row is dropped
        # and the panels are addressed purely by column position.
        raw = pd.read_excel(xls, sheet_name=sheet, header=None)
        data = raw.iloc[1:].reset_index(drop=True)

        needed_cols = 1 + max(max(fc_idx, p_idx) for _, fc_idx, p_idx in specs)
        if data.shape[1] < needed_cols:
            print(
                f"[SKIPPED] {sheet} -> expected at least {needed_cols} "
                f"columns, found {data.shape[1]}"
            )
            continue

        datasets = []
        for sub_title, fc_col_idx, p_col_idx in specs:
            temp_df = pd.DataFrame({
                "log2FC": data.iloc[:, fc_col_idx],
                "pval": data.iloc[:, p_col_idx],
            })
            clean_df = clean_volcano_df(temp_df, "log2FC", "pval")
            datasets.append((sub_title, clean_df))

        plot_three_panel_volcano(datasets, sheet, save_path)

print("\nDone.")
print(f"All plots saved in: {output_dir}")