


import csv
import numpy as np

def load_saliency_csv(csv_path):
    """Load per-image saliency values from a CSV file.

    The first row is treated as a header and skipped. Every following
    non-blank row is expected to be ``image_name, label, value_1, ..., value_k``.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        A tuple ``(image_names, labels, saliency_maps)``:
        ``image_names`` is a list of str (first column), ``labels`` is a list
        of int (second column), and ``saliency_maps`` is a 2-D float ndarray
        with one row per data row (remaining columns). When the file has no
        data rows the array has shape ``(0, k)``, where ``k`` is the header
        width minus the two leading columns.

    Raises:
        ValueError: If the file has no header row at all, if a label or value
            cell cannot be parsed as a number, or if data rows have differing
            numbers of value columns.
    """
    with open(csv_path, newline='') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            # A bare next() would leak StopIteration, which is confusing to callers.
            raise ValueError(f"{csv_path}: CSV file is empty (no header row)")
        # csv.reader yields [] for blank lines (e.g. a trailing newline); skip them.
        data = [row for row in reader if row]

    image_names = [row[0] for row in data]
    labels = [int(row[1]) for row in data]

    if not data:
        # np.array([]) would be 1-D; keep the 2-D (rows, values) shape instead.
        n_values = max(len(header) - 2, 0)
        return image_names, labels, np.empty((0, n_values), dtype=float)

    # dtype=float makes ragged rows fail loudly instead of producing an object array.
    saliency_maps = np.array([[float(v) for v in row[2:]] for row in data], dtype=float)
    return image_names, labels, saliency_maps

def _summarize(values):
    """Return mean/sum/median of a 1-D array of values; all NaN when it is empty."""
    if values.size == 0:
        # Report "no data" uniformly for every category and avoid numpy's
        # "Mean of empty slice" RuntimeWarning.
        return {"mean": np.nan, "sum": np.nan, "median": np.nan}
    return {
        "mean": np.mean(values),
        "sum": np.sum(values),
        "median": np.median(values),
    }


def analyze_saliency(csv_path, *, sentinel=-1.0):
    """Print and return summary statistics of the saliency values in a CSV.

    The CSV is read with ``load_saliency_csv``. Cells equal to ``sentinel``
    (default ``-1.0``, marking unperturbed regions) are excluded, as are any
    NaN cells. The remaining values are summarized four ways: all values,
    strictly negative values, strictly positive values, and absolute values.
    Zeros therefore count toward "All" and "Absolute" only.

    Args:
        csv_path: Path to the saliency CSV file.
        sentinel: Cell value to ignore. Keyword-only.

    Returns:
        dict mapping each category ("All", "Negative", "Positive",
        "Absolute") to a dict with keys "mean", "sum" and "median". A
        category with no values reports NaN for all three.
    """
    image_names, _labels, saliency_maps = load_saliency_csv(csv_path)

    # Keep cells that are neither the sentinel nor already NaN; boolean
    # indexing also flattens the result to 1-D (row-major order).
    keep = (saliency_maps != sentinel) & ~np.isnan(saliency_maps)
    valid_vals = saliency_maps[keep]

    stats = {
        "All": _summarize(valid_vals),
        "Negative": _summarize(valid_vals[valid_vals < 0]),
        "Positive": _summarize(valid_vals[valid_vals > 0]),
        "Absolute": _summarize(np.abs(valid_vals)),
    }

    # Display summary
    print(f"\nLoaded {len(image_names)} images from: {csv_path}")
    print(f"Total valid entries (non-NaN): {len(valid_vals)}\n")
    for key, vals in stats.items():
        print(f"--- {key} Values ---")
        print(f"Mean   : {vals['mean']:.4f}")
        print(f"Sum    : {vals['sum']:.4f}")
        print(f"Median : {vals['median']:.4f}\n")

    return stats





