import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import norm

# Read the CSV without a header so every cell is treated as data. The values
# may be laid out either as a single row or as a single column.
df = pd.read_csv("image_distribution.csv", header=None)

if df.shape[0] == 1:
    # Single row: the values run across the columns.
    data = df.iloc[0].dropna().to_numpy()
else:
    # Normal case: one number per row; only the first column is used.
    data = df.iloc[:, 0].dropna().to_numpy()

# Coerce to numeric; non-numeric cells (e.g. a stray text label) become NaN
# and are dropped on the next line.
data = pd.to_numeric(data, errors="coerce")
data = data[~np.isnan(data)]

# Fail early with a clear message: an empty sample would otherwise only fail
# later (in the min()/max() calls used for binning) with a misleading error.
if data.size == 0:
    raise ValueError("image_distribution.csv contains no numeric values")

# Fit a normal distribution to the sample (returns mean and std).
mu, std = norm.fit(data)

# Histogram with one bin per integer unit. Edges are computed with floor()
# rather than int(): int() truncates toward zero, so for negative values
# (e.g. -1.5 -> -1) the smallest observations would fall outside every bin.
# For non-negative data the edges are identical to the int() version.
plt.figure(figsize=(10, 6))
bin_start = int(np.floor(data.min()))
# +2: range() excludes its stop, and the maximum must lie inside the last bin.
bin_stop = int(np.floor(data.max())) + 2
plt.hist(data, bins=range(bin_start, bin_stop),
         density=True, alpha=0.6, color='skyblue', edgecolor='black', label="Histogram")

# Normal distribution curve. The histogram is drawn with density=True, so the
# plain pdf is already on the same vertical scale; no count scaling is needed.
x_vals = np.linspace(data.min(), data.max(), 500)
pdf = norm.pdf(x_vals, mu, std)
plt.plot(x_vals, pdf, color='red', linewidth=2, label=f"Normal Fit\nμ={mu:.2f}, σ={std:.2f}")

# Title and axis labels in large type; tick labels slightly smaller.
plt.title("Distribution of disease rate in dataset", fontsize=24)
plt.xlabel("Disease rate (%)", fontsize=24)
plt.ylabel("Density", fontsize=24)
for tick_setter in (plt.xticks, plt.yticks):
    tick_setter(fontsize=14)

# Legend and light dashed horizontal gridlines.
plt.legend(fontsize=14)
plt.grid(axis='y', linestyle='--', alpha=0.6)

# Save before show(): saving after show() can yield an empty figure on some
# backends.
# NOTE(review): machine-specific absolute path; consider making it relative
# or configurable.
output_pdf = "/Users/C00540403/Documents/research/Foliagen/FoliageGenerator/src/soybean/image_distribution.pdf"
plt.savefig(output_pdf, dpi=300, bbox_inches="tight")

plt.show()
