
import os
import glob
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr



def neg_sigmoid(x, a, b, c):
    """Evaluate a mirrored (falling, for b > 0) logistic curve.

    Computes ``a * (1 - 1 / (1 + exp(-b * (x - c))))``. The value equals
    ``a / 2`` at ``x == c``; for positive ``b`` it approaches ``a`` as ``x``
    decreases and ``0`` as ``x`` increases.

    Args:
        x: Scalar or NumPy array of evaluation points.
        a: Scale (upper plateau of the curve).
        b: Slope factor of the transition.
        c: Inflection point on the x axis.

    Returns:
        The curve value(s), with the same shape as ``x``.
    """
    exponent = -b * (x - c)
    rising_part = 1 / (1 + np.exp(exponent))
    return a * (1 - rising_part)


def _find_csv_parts(csv_path):
    """Return the CSV at *csv_path* plus its sibling files, in a stable order.

    Every file in the same directory whose name matches ``<base_name>*.csv``
    is collected. The main file (``<base_name>.csv``) sorts first, files named
    ``..._part<N>.csv`` follow in numeric order, and any other match comes
    last; ties are broken by path so the order is deterministic.

    Raises:
        FileNotFoundError: If no file matches the pattern.
    """
    base_dir = os.path.dirname(csv_path)
    file_name = os.path.basename(csv_path)
    # Strip only a trailing ".csv" extension, not every occurrence in the name.
    stem, ext = os.path.splitext(file_name)
    base_name = stem if ext == ".csv" else file_name

    # Escape glob metacharacters ("[", "*", "?") that may occur in the path.
    pattern = os.path.join(glob.escape(base_dir), glob.escape(base_name) + "*.csv")
    candidates = glob.glob(pattern)
    if not candidates:
        raise FileNotFoundError(f"No CSV found at {csv_path} or matching pattern {base_name}*.csv")

    main_name = f"{base_name}.csv"
    part_pattern = re.compile(r"_part(\d+)\.csv$")

    def sort_key(path):
        # Compare the exact file name so "x_<base_name>.csv" is not mistaken
        # for the main file.
        if os.path.basename(path) == main_name:
            return (0, path)
        match = part_pattern.search(path)
        return (int(match.group(1)) if match else float("inf"), path)

    return sorted(candidates, key=sort_key)


def _quantile_envelope(x, y, quantile, num_bins):
    """Return (bin centers, per-bin quantile of y) for equal-width bins over x.

    Bins are half-open ``[lo, hi)`` except the last one, which is closed so
    that samples located exactly at ``x.max()`` are not dropped. Bins that
    contain no samples (or whose quantile is NaN) are omitted from the result.
    """
    edges = np.linspace(x.min(), x.max(), num_bins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    envelope = np.full(num_bins, np.nan)

    last_bin = num_bins - 1
    for j in range(num_bins):
        lower_ok = x >= edges[j]
        upper_ok = (x <= edges[j + 1]) if j == last_bin else (x < edges[j + 1])
        in_bin = lower_ok & upper_ok
        if in_bin.any():
            envelope[j] = np.quantile(y[in_bin], quantile)

    keep = ~np.isnan(envelope)
    return centers[keep], envelope[keep]


def _save_current_figure(output_path):
    """Save the current matplotlib figure, choosing the format from the extension."""
    ext = os.path.splitext(output_path)[-1].lower()

    if ext == ".pdf":
        plt.savefig(output_path, format="pdf", bbox_inches="tight")
    elif ext == ".png":
        # High-res PNG (no quality param needed)
        plt.savefig(output_path, format="png", dpi=300, bbox_inches="tight")
    elif ext in (".jpg", ".jpeg"):
        # The bare ``quality=`` keyword was removed from savefig in
        # Matplotlib 3.5; JPEG quality is passed through ``pil_kwargs``.
        plt.savefig(output_path, format="jpeg", dpi=300,
                    pil_kwargs={"quality": 95}, bbox_inches="tight")
    else:
        print(f"[WARN] Unknown file extension '{ext}' — saving as default PDF.")
        plt.savefig(output_path, format="pdf", bbox_inches="tight")


def analyze_global_degradation_fit(csv_path, output_path=None, title=None,
                                   blindness_col="blindness",
                                   corr_analysis=False,
                                   quantile=0.99999,
                                   thresholds=(0.5, 0.8),
                                   num_bins=100, return_data=False, return_metrics=False):
    """Fit a falling sigmoid to the upper quantile envelope of confidence data.

    Loads ``csv_path`` together with every sibling file matching
    ``<base_name>*.csv`` (e.g. ``*_part<N>.csv``), bins the ``blindness_col``
    values into ``num_bins`` equal-width bins, takes the ``quantile`` of
    ``gt_confidence`` in each bin and fits ``neg_sigmoid`` to those points.
    Fitted parameters and derived metrics are printed; optionally a scatter
    plot with the fitted curve is saved to ``output_path`` and shown.

    Args:
        csv_path: Path of the main CSV file.
        output_path: Where to save the figure (.pdf, .png, .jpg/.jpeg; any
            other extension is written as PDF). ``None`` disables plotting.
        title: Figure title.
        blindness_col: Name of the column used as x values.
        corr_analysis: Currently unused; kept for interface compatibility.
        quantile: Quantile of ``gt_confidence`` taken within each bin.
        thresholds: Confidence levels ``tau`` for which ``ACP_<tau>`` (the
            smallest x at which the fitted curve is below ``tau``) is reported.
        num_bins: Number of equal-width bins over the x range.
        return_data: If true, return the fitted parameters ``(a, b, c)``.
            Takes precedence over ``return_metrics``.
        return_metrics: If true, return ``(aubc, rbs, acps)`` where ``aubc``
            is the trapezoidal area under the fitted curve, ``rbs`` is the x
            position of the steepest slope and ``acps`` maps ``"ACP_<tau>"``
            to its value.

    Returns:
        The tuple selected by ``return_data`` / ``return_metrics``, or ``None``
        when neither flag is set. If the fit fails or fewer than three bins
        contain data, every numeric entry of the returned tuple is NaN.

    Raises:
        FileNotFoundError: If no matching CSV file exists.
    """
    candidate_files = _find_csv_parts(csv_path)

    print(f"[INFO] Found {len(candidate_files)} relevant CSV files:")
    for f in candidate_files:
        print(f"  - {f}")

    # Concatenate all (even if only one file)
    data = pd.concat([pd.read_csv(f) for f in candidate_files], ignore_index=True)
    print(f"[INFO] Loaded {len(data)} rows from {len(candidate_files)} files.")

    x_all = data[blindness_col].values
    y_all = data["gt_confidence"].values

    make_plot = output_path is not None
    if make_plot:
        plt.figure(figsize=(8, 6))
        # NOTE(review): labels are assigned but no legend is ever drawn.
        plt.scatter(x_all, y_all, alpha=0.4, color="#00876C", label="Samples")

    bin_centers, quant_vals = _quantile_envelope(x_all, y_all, quantile, num_bins)

    # Defaults that are returned when no fit is available, so the return
    # statements below never touch an unbound name.
    popt = None
    aubc = np.nan
    rbs = np.nan
    acps = {f"ACP_{tau}": np.nan for tau in thresholds}

    # Three parameters need at least three envelope points; with fewer the
    # optimizer cannot run, so this is treated like a failed fit.
    if bin_centers.size >= 3:
        a0 = np.max(quant_vals)
        b0 = 10.0
        c0 = np.median(bin_centers)
        try:
            popt, _ = curve_fit(neg_sigmoid, bin_centers, quant_vals,
                                p0=[a0, b0, c0], maxfev=10000)
        except RuntimeError:
            popt = None

    if popt is None:
        print("Sigmoid fitting failed.")
    else:
        print(f"\n=== Fitted Sigmoid Parameters ===")
        print(f"a (scale):     {popt[0]:.4f}")
        print(f"b (slope):     {popt[1]:.4f}")
        print(f"c (inflection point): {popt[2]:.4f}")
        print("")
        print("")

        x_fit = np.linspace(bin_centers.min(), bin_centers.max(), 300)
        y_fit = neg_sigmoid(x_fit, *popt)

        if make_plot:
            plt.plot(x_fit, y_fit, color="#D5001C", linewidth=2, label=f"Sigmoid fit ({int(quantile*100)}th percentile)")

        # x position where the fitted curve changes fastest.
        dy = np.gradient(y_fit, x_fit)
        rbs = x_fit[np.argmax(np.abs(dy))]

        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # fall back to the old name on older NumPy versions.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        aubc = integrate(y_fit, x_fit)

        for tau in thresholds:
            below = y_fit < tau
            acps[f"ACP_{tau}"] = np.min(x_fit[below]) if np.any(below) else np.nan

        print("METRICS")
        print(f"AUBC: {aubc:.4f}")
        print(f"RBS (drop point): {rbs:.4f}")
        for k, v in acps.items():
            print(f"{k}: {v:.4f}")

    if make_plot:
        if blindness_col == "rel_blindness":
            plt.xlabel("Relative Blindness", fontsize=14)
        else:
            plt.xlabel("Blindness", fontsize=14)

        plt.ylabel("Confidence $C$", fontsize=14)
        plt.title(title, fontsize=16)

        plt.grid(True)
        plt.tight_layout()

        _save_current_figure(output_path)
        plt.show()

    if return_data:
        if popt is None:
            return np.nan, np.nan, np.nan
        return popt[0], popt[1], popt[2]

    if return_metrics:
        return aubc, rbs, acps

    return None

    


    
