import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde


def generate_skewed_distribution(n=1000, rng=None):
    """
    Generate a shuffled two-valued sample that is skewed towards the value 2.

    The share of 2s is drawn uniformly from [0.70, 0.80) on every call and
    truncated to a whole count; every remaining sample is a 1. No other
    values are produced.

    :param n: total number of samples
    :param rng: optional ``numpy.random.Generator`` (or any object exposing
        compatible ``uniform`` and ``shuffle`` methods) used for all random
        draws, e.g. to make the output reproducible; when ``None`` the global
        ``np.random`` state is used
    :return: tuple ``(data, num_1, num_2)`` where ``data`` is a numpy array of
        length ``n`` and ``num_1`` / ``num_2`` are the counts of 1s and 2s in it
    """
    # Route every random draw through one source so a caller-supplied
    # generator fully determines the result.
    random_source = np.random if rng is None else rng

    num_2 = int(n * random_source.uniform(0.70, 0.8))  # 70-80% of samples are 2s
    num_1 = n - num_2  # everything else is a 1, so the counts always sum to n

    data = np.concatenate([np.full(num_2, 2), np.full(num_1, 1)])
    # Shuffle in place so the array is not one run of 2s followed by 1s.
    random_source.shuffle(data)

    return data, num_1, num_2


def plot_distribution(
    data,
    num_1,
    num_2,
    output_path="/Users/C00540403/Documents/research/Foliagen/FoliageGenerator/src/soybean/hotspot_distribution.pdf",
):
    """
    Plot a smooth Gaussian-KDE curve of the sample, save it and show it.

    The density is evaluated on 500 evenly spaced points over [0, 4], drawn as
    a line with a translucent fill underneath, and annotated with dashed
    vertical markers at x=1 and x=2 whose legend entries carry the sample
    counts.

    :param data: 1-D array of samples passed to ``gaussian_kde``
    :param num_1: number of 1s in ``data`` (used only for the legend label)
    :param num_2: number of 2s in ``data`` (used only for the legend label)
    :param output_path: file the figure is written to before it is shown;
        pass ``None`` to skip saving. The default is the original
        machine-specific location, kept so existing callers behave the same.
    """
    kde = gaussian_kde(data)
    x_vals = np.linspace(0, 4, 500)
    y_vals = kde(x_vals)

    plt.figure(figsize=(8, 5))
    plt.plot(x_vals, y_vals)
    plt.fill_between(x_vals, y_vals, alpha=0.3)

    # Mark the two sample values and expose their counts through the legend.
    for val, count in zip([1, 2], [num_1, num_2]):
        plt.axvline(val, linestyle="--", alpha=0.6, label=f"{val}: {count} samples")

    plt.title("Disease Hotspot Distribution", fontsize=20)
    plt.xlabel("Number of patches", fontsize=20)
    plt.ylabel("Density (KDE)", fontsize=20)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.legend(fontsize=14)

    # Save before show(): interactive backends may tear the figure down once
    # its window is closed.
    if output_path is not None:
        plt.savefig(output_path, dpi=300, bbox_inches="tight")

    plt.show()


if __name__ == "__main__":
    # Script entry point: draw 8000 samples, report the per-value counts on
    # stdout, then render the KDE plot.
    result = generate_skewed_distribution(n=8000)
    data, num_1, num_2 = result
    print(f"Counts -> 1: {num_1}, 2: {num_2}")
    plot_distribution(*result)
