import cv2
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import os
from glob import glob
import cv2
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import os
from glob import glob
import seaborn as sns
import matplotlib.pyplot as plt


'''
# To run this code, first run the skin-tone code on the entire dataset; it saves all the images into six different directories
(Dark, Brown, Light, Very_Light, Intermediate, Tan) under the dataset path assigned to `f` below.

# Then use this PCA code to test the variability of skin tone across the dataset.

'''
# Root folder holding one sub-folder per skin tone; replace the placeholder with the real path.
f = "Give Dataset Path"

# Paths to the different skin tone image folders, in the order they are processed
# (and therefore labelled). os.path.join uses the separator of the running OS;
# a hard-coded backslash would only produce valid paths on Windows.
skin_tone_folders = {
    tone: os.path.join(f, tone)
    for tone in ('Brown', 'Light', 'Tan', 'Intermediate', 'Dark', 'Very_Light')
}


# Function to extract color histogram features from an RGB image, excluding black pixels
def extract_color_histogram(image, bins=(8, 8, 8)):
    """
    Extracts per-channel color histogram features from an RGB image, excluding black pixels.

    A pixel contributes only if all three of its channels are >= 1 (the
    ``cv2.inRange`` mask below); any pixel with a zero channel is skipped
    instead of being counted in the first bin.

    Args:
        image: Input RGB image as a NumPy array of shape (H, W, 3); assumed to be
            uint8 (values 0-255) — the histogram range below relies on that.
        bins: Number of bins for each channel's histogram: either one int used for
            all three channels, or a sequence of three ints (R, G, B).
    Returns:
        hist: 1-D float array with the R, G and B histograms concatenated in that
            order, normalized so all entries sum to 1. All zeros when no pixel
            survives the mask.
    """
    # Resize image to 224x224
    image = cv2.resize(image, (224, 224))

    # Mask of non-black pixels: 255 where every channel is in [1, 255], else 0
    mask = cv2.inRange(image, np.array([1, 1, 1]), np.array([255, 255, 255]))

    # Keep only the masked pixels -> shape (N, 3). Histogramming a zeroed-out
    # image would still count every black pixel in each channel's first bin.
    pixels = image[mask > 0]

    # A scalar is shorthand for "same bin count in every channel"
    per_channel_bins = (bins,) * 3 if np.ndim(bins) == 0 else tuple(bins)

    hist_list = []
    for i in range(3):
        # np.histogram interprets a *sequence* as bin edges, so pass a plain int count
        hist, _ = np.histogram(pixels[:, i], bins=int(per_channel_bins[i]), range=(0, 256))
        hist_list.append(hist)

    # Concatenate and normalize the histograms (guarding against division by zero)
    hist = np.concatenate(hist_list).astype(float)
    total = hist.sum()
    if total > 0:
        hist = hist / total

    return hist

# Load images from the directory Extension: .png, .jpg, .jpeg
def load_images_from_folder(folder_path, image_size=(224, 224), extensions=(".png", ".jpg", ".jpeg")):
    """
    Loads every readable image in a folder as an RGB array of a fixed size.

    Files are matched by extension (case-insensitive) and visited in sorted
    filename order, so repeated calls return the images in a deterministic order.

    Args:
        folder_path: Directory to scan (not recursive).
        image_size: ``dsize`` passed to ``cv2.resize``.
        extensions: Lower-case file extensions to accept.
    Returns:
        List of RGB images. Files OpenCV cannot decode are skipped. Returns an
        empty list when the folder does not exist or has no matching files.
    """
    if not os.path.isdir(folder_path):
        return []

    suffixes = tuple(extensions)
    images = []
    for name in sorted(os.listdir(folder_path)):
        # A glob such as "*.[pj][pn]g" would miss ".jpeg" and upper-case
        # extensions while matching unintended ones like ".ppg".
        if not name.lower().endswith(suffixes):
            continue
        img = cv2.imread(os.path.join(folder_path, name))
        if img is None:
            continue  # not an image OpenCV can decode
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2.imread returns BGR
        images.append(cv2.resize(img_rgb, image_size))  # Resize the image
    return images



# Load images and extract features
all_features = []      # one (n_images, n_features) array per non-empty folder
labels = []            # integer class index per image (position in skin_tone_folders)
skintone_labels = []   # skin-tone name per image, aligned with the rows of all_features

for i, (skintone, folder_path) in enumerate(skin_tone_folders.items()):
    images = load_images_from_folder(folder_path)
    print(f"Loaded {len(images)} images from {folder_path}")
    if not images:
        # np.array([]) is 1-D and would break the 2-D concatenation below
        continue
    features = np.array([extract_color_histogram(img) for img in images])
    all_features.append(features)
    labels.extend([i] * len(features))
    skintone_labels.extend([skintone] * len(features))

if not all_features:
    raise ValueError("No images were loaded; check the dataset path `f` and the skin tone folder names.")

# Combine all features into one array
all_features = np.concatenate(all_features, axis=0)

# Replace any NaNs with 0 before applying PCA (safeguard)
all_features = np.nan_to_num(all_features)

labels = np.array(labels)  # Convert to NumPy array

# Project the RGB features onto their first two principal components
pca = PCA(n_components=2)
pca_result_rgb = pca.fit_transform(all_features)


######################################################################################

# Function to extract color histogram features from a Color Cube (RGB, HSV, YCbCr) image, excluding black pixels
def extract_colorcube_histogram(image, bins=(8, 8, 8)):
    """
    Extracts color histogram features from an image in three color spaces
    (RGB, HSV and YCrCb), excluding black pixels.

    Only pixels whose R, G and B values are all >= 1 (the ``cv2.inRange`` mask
    below) contribute. Each color space is converted from the full image and the
    mask is applied afterwards. Converting a zeroed-out image would instead turn
    background pixels into real colors (e.g. black maps to Cr = Cb = 128 in YCrCb).

    Args:
        image: Input RGB image as a NumPy array of shape (H, W, 3); assumed to be
            uint8 — the HSV hue range below follows OpenCV's 8-bit encoding
            (H in [0, 180)).
        bins: Number of bins for each channel's histogram: either one int used for
            all channels, or a sequence of three ints applied to the first,
            second and third channel of every color space.
    Returns:
        hist: 1-D float array with the R, G, B, H, S, V, Y, Cr, Cb histograms
            concatenated in that order, normalized so all entries sum to 1.
            All zeros when no pixel survives the mask.
    """
    # Resize image to 224x224
    image = cv2.resize(image, (224, 224))

    # Mask of non-black pixels: 255 where every channel is in [1, 255], else 0
    mask = cv2.inRange(image, np.array([1, 1, 1]), np.array([255, 255, 255]))
    keep = mask > 0

    # Convert the original image; masking happens per pixel afterwards.
    # OpenCV's YCrCb conversion yields channels in the order Y, Cr, Cb.
    hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
    ycrcb_image = cv2.cvtColor(image, cv2.COLOR_RGB2YCrCb)

    # A scalar is shorthand for "same bin count in every channel"
    per_channel_bins = (bins,) * 3 if np.ndim(bins) == 0 else tuple(bins)

    # (image, exclusive upper value bound per channel); 8-bit hue only spans 0-179
    color_spaces = (
        (image, (256, 256, 256)),
        (hsv_image, (180, 256, 256)),
        (ycrcb_image, (256, 256, 256)),
    )

    hist_list = []
    for space_image, upper_bounds in color_spaces:
        pixels = space_image[keep]  # (N, 3) non-black pixels in this color space
        for i in range(3):
            # np.histogram interprets a *sequence* as bin edges, so pass a plain int count
            hist, _ = np.histogram(
                pixels[:, i], bins=int(per_channel_bins[i]), range=(0, upper_bounds[i])
            )
            hist_list.append(hist)

    # Concatenate and normalize the histograms (guarding against division by zero)
    hist = np.concatenate(hist_list).astype(float)
    total = hist.sum()
    if total > 0:
        hist = hist / total

    return hist

# Load images and extract Color Cube features
# Each folder is reloaded and visited in the same order as the RGB pass, so the
# rows stay aligned with `skintone_labels` (which the plot pairs with them).
all_colorcube_features = []
for folder_path in skin_tone_folders.values():
    images = load_images_from_folder(folder_path)
    if not images:
        # np.array([]) is 1-D and would break the 2-D concatenation below
        continue
    all_colorcube_features.append(np.array([extract_colorcube_histogram(img) for img in images]))

if not all_colorcube_features:
    raise ValueError("No images were loaded for the Color Cube features; check the dataset path.")

# Combine all Color Cube features into one array
all_colorcube_features = np.concatenate(all_colorcube_features, axis=0)

# The plot pairs these rows with skintone_labels; fail loudly if the folders
# changed between the two loading passes.
if len(all_colorcube_features) != len(skintone_labels):
    raise ValueError(
        f"Color Cube pass loaded {len(all_colorcube_features)} images, "
        f"but the RGB pass loaded {len(skintone_labels)}."
    )

# Replace any NaNs with 0 before applying PCA (safeguard)
all_colorcube_features = np.nan_to_num(all_colorcube_features)

# Project the Color Cube features onto their first two principal components
pca_result_cube = PCA(n_components=2).fit_transform(all_colorcube_features)



def plot_pca_violinplots_multiple_skintones(pca_result_rgb, pca_result_cube, labels, skintone_labels):
    """
    Draws a 2x2 grid of violin plots comparing skin-tone groups along PC1 and PC2.

    Top row: PC1, bottom row: PC2. Left column: RGB-feature projection,
    right column: Color Cube-feature projection. The figure is shown with plt.show().

    Args:
        pca_result_rgb: PCA projection of the RGB features; columns 0 and 1 are plotted.
        pca_result_cube: PCA projection of the Color Cube features; columns 0 and 1 are plotted.
        labels: Integer class labels (accepted for interface compatibility; unused here).
        skintone_labels: Skin-tone name per row, used for the x-axis and the hue;
            every name must be a key of the palette below.
    """
    palette = {
        'Brown': '#8B4513',         # SaddleBrown
        'Light': '#FFDAB9',         # PeachPuff
        'Tan': '#D2B48C',           # Tan
        'Intermediate': '#F5DEB3',  # Wheat
        'Dark': '#8B0000',          # DarkRed
        'Very_Light': '#FFE4E1',    # MistyRose
    }

    # (subplot slot, projection, component column, title, y-axis label)
    panels = (
        (1, pca_result_rgb, 0, 'PC1 (RGB Features)', 'Principal Component 1'),
        (2, pca_result_cube, 0, 'PC1 (Color Cube Features)', 'Principal Component 1'),
        (3, pca_result_rgb, 1, 'PC2 (RGB Features)', 'Principal Component 2'),
        (4, pca_result_cube, 1, 'PC2 (Color Cube Features)', 'Principal Component 2'),
    )

    plt.figure(figsize=(14, 10))

    for slot, projection, column, title, y_label in panels:
        plt.subplot(2, 2, slot)
        sns.violinplot(x=skintone_labels, y=projection[:, column], hue=skintone_labels, palette=palette, dodge=False)
        plt.title(title, color='black')
        plt.xlabel('Skin Tone', color='black')
        plt.ylabel(y_label, color='black')
        plt.legend([], [], frameon=False)  # replace the per-hue legend with an empty one

    plt.tight_layout()
    plt.show()

# Plot the PCA projections computed above. skintone_labels holds one folder name
# from skin_tone_folders per image (e.g. 'Brown', 'Light', 'Tan', 'Very_Light').
plot_pca_violinplots_multiple_skintones(
    pca_result_rgb=pca_result_rgb,
    pca_result_cube=pca_result_cube,
    labels=labels,
    skintone_labels=skintone_labels,
)

