import csv
import math
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
from facenet_pytorch import MTCNN
from PIL import Image
from skimage import color

# --- Configuration -----------------------------------------------------------
# MTCNN output crop size (pixels) and the margin added around the face box.
image_size = 224
margin = 2

# Input image: a file name inside a directory relative to the working directory.
dataset_path = r"./img/"
image_name = "3.png"

# Output locations, all anchored at the current working directory.
current_directory = os.getcwd()
csv_file_path = os.path.join(current_directory, "skintone.csv")
temp_image_path = os.path.join(current_directory, "temp.jpg")
# Joining with "" leaves a trailing path separator on the directory.
final_path = os.path.join(current_directory, "")


# MTCNN face detector. It produces the face crop here and the facial landmarks
# further down.
mtcnn = MTCNN(image_size=image_size, margin=margin)
path = os.path.join(dataset_path, image_name)

# Load the input image. cv2.imread returns None instead of raising on failure.
img = cv2.imread(path)
if img is None:
    raise FileNotFoundError(f"Could not read input image: {path}")
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Detect the face. MTCNN writes the cropped face to temp_image_path and
# returns None (writing nothing) when no face is found.
img_tensor = mtcnn(img_rgb, save_path=temp_image_path)
if img_tensor is None:
    raise RuntimeError(f"No face detected in {path}")

# Read the saved face crop once. Keep an untouched copy (img_crp) for the
# final output and the later landmark pass.
img_array = cv2.imread(temp_image_path)
if img_array is None:
    raise FileNotFoundError(f"Could not read face crop: {temp_image_path}")
img_crp = img_array.copy()

# Skin detection: intersect an HSV range mask with a YCrCb range mask, each
# cleaned with a morphological opening.
img_HSV = cv2.cvtColor(img_array, cv2.COLOR_BGR2HSV)
HSV_mask = cv2.inRange(img_HSV, (0, 15, 0), (17, 170, 255))
HSV_mask = cv2.morphologyEx(HSV_mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))

img_YCrCb = cv2.cvtColor(img_array, cv2.COLOR_BGR2YCrCb)
YCrCb_mask = cv2.inRange(img_YCrCb, (0, 135, 85), (255, 180, 135))
YCrCb_mask = cv2.morphologyEx(YCrCb_mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))

global_mask = cv2.bitwise_and(YCrCb_mask, HSV_mask)
global_mask = cv2.medianBlur(global_mask, 3)
global_mask = cv2.morphologyEx(global_mask, cv2.MORPH_OPEN, np.ones((4, 4), np.uint8))

# Apply the mask to get the skin region.
skin = cv2.bitwise_and(img_array, img_array, mask=global_mask)
skin_rgb = cv2.cvtColor(skin, cv2.COLOR_BGR2RGB)

# Plot the masks for visual inspection.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

ax[0].imshow(HSV_mask, cmap='gray')
ax[0].set_title("HSV Mask")
ax[0].axis('off')

ax[1].imshow(YCrCb_mask, cmap='gray')
ax[1].set_title("YCrCb Mask")
ax[1].axis('off')

ax[2].imshow(skin_rgb, cmap='gray')
ax[2].set_title("Apply Mask")
ax[2].axis('off')

plt.show()

# Re-run MTCNN on the (resized) crop to get landmarks, then black out the mouth
# so lips/teeth do not bias the skin-tone estimate.
img_cropped = Image.fromarray(cv2.cvtColor(img_crp, cv2.COLOR_BGR2RGB))
img_resized = img_cropped.resize((image_size, image_size))
boxes, probs, landmarks = mtcnn.detect(img_resized, landmarks=True)
img_cropped_np = np.array(img_resized)

mouth_pad = 20  # pixels added on every side of the mouth-corner bounding box
if landmarks is not None:
    for landmark in landmarks:
        # Assumes MTCNN's 5-point order (eyes, nose, mouth corners), so indices
        # 3 and 4 are the mouth corners. TODO confirm against facenet_pytorch.
        left_mouth = landmark[3]
        right_mouth = landmark[4]

        x1, y1 = int(left_mouth[0]), int(left_mouth[1])
        x2, y2 = int(right_mouth[0]), int(right_mouth[1])

        # Padded box spanning both mouth corners.
        top_left = (min(x1, x2) - mouth_pad, min(y1, y2) - mouth_pad)
        bottom_right = (max(x1, x2) + mouth_pad, max(y1, y2) + mouth_pad)

        # Filled (thickness -1) black rectangle over the mouth area.
        cv2.rectangle(img_cropped_np, top_left, bottom_right, (0, 0, 0), -1)

# Overwrite the temp crop with the mouth-masked version.
img_with_landmarks = Image.fromarray(img_cropped_np)
img_with_landmarks.save(temp_image_path)

# Re-read the masked crop. Flag 3 == IMREAD_COLOR | IMREAD_ANYDEPTH. The JPEG
# round-trip is kept so the numeric results match the previous behaviour.
img_BGR = cv2.imread(temp_image_path, 3)
img_BGR_small = cv2.resize(img_BGR, (0, 0), fx=0.1, fy=0.1)
img_grayscale = cv2.cvtColor(img_BGR_small, cv2.COLOR_BGR2GRAY)

# Otsu threshold with THRESH_BINARY_INV marks pixels at/below the threshold as
# 255. Inverting again keeps only the brighter pixels (the blacked-out mouth and
# dark regions are dropped).
threshold_value, threshold_image = cv2.threshold(img_grayscale, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
threshold_image_binary = 1 - threshold_image / 255
threshold_image_binary = np.repeat(threshold_image_binary[:, :, np.newaxis], 3, axis=2)
img_face_only = np.multiply(threshold_image_binary, img_BGR_small).astype(np.uint8)
img_HSV = cv2.cvtColor(img_face_only, cv2.COLOR_BGR2HSV)
img_YCrCb = cv2.cvtColor(img_face_only, cv2.COLOR_BGR2YCrCb)


# # Plotting the masks for visualization
# fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# # Display the HSV mask
# ax[0].imshow(img_HSV, cmap='gray')
# ax[0].set_title("HSV Mask")
# ax[0].axis('off')

# # Display the YCrCb mask
# ax[1].imshow(img_YCrCb, cmap='gray')
# ax[1].set_title("YCrCb Mask")
# ax[1].axis('off')

# ax[2].imshow(img_face_only, cmap='gray')
# ax[2].set_title("Apply Mask")
# ax[2].axis('off')

# plt.show()


# Classify each pixel of the thresholded face as skin or not, using fixed
# ranges on OpenCV hue (channel 0 of HSV) and Cr/Cb (channels 1/2 of YCrCb).
# Vectorised equivalent of a per-pixel loop.
height, width, channels = img_face_only.shape
hue = img_HSV[:, :, 0]
cr = img_YCrCb[:, :, 1]
cb = img_YCrCb[:, :, 2]
skin_mask = (hue <= 170) & (cr >= 140) & (cr <= 170) & (cb >= 90) & (cb <= 120)

if not skin_mask.any():
    # Without this check np.mean([]) yields NaN and a meaningless tone class.
    raise RuntimeError("No skin-coloured pixels found; cannot estimate skin tone.")

skin_pixels = img_face_only[skin_mask]  # (N, 3) rows of B, G, R values
blue, green, red = skin_pixels[:, 0], skin_pixels[:, 1], skin_pixels[:, 2]
img_face_only[~skin_mask] = 0  # blank out non-skin pixels, as before

# Mean skin colour in both channel orders.
skin_tone_estimate_BGR = [np.mean(blue), np.mean(green), np.mean(red)]
skin_tone_estimate_RGB = [np.mean(red), np.mean(green), np.mean(blue)]

# Convert the mean RGB (scaled to [0, 1]) to CIE L*a*b* and compute the
# Individual Typology Angle in degrees: ITA = arctan((L* - 50) / b*).
tone_rgb = [[[x / 255 for x in skin_tone_estimate_RGB]]]
tone_lab = color.rgb2lab(tone_rgb)
l, a, b = tone_lab.flatten()
ita = np.arctan((l - 50) / b) * (180 / math.pi)

def ita_2_tone(ita):
    """Map an Individual Typology Angle (degrees) to a skin-tone category.

    Uses the commonly cited ITA classification (Chardon et al., 1991):
    > 55 very light, (41, 55] light, (28, 41] intermediate, (10, 28] tan,
    (-30, 10] brown, <= -30 dark.

    NOTE(review): this name was referenced but never defined in this file.
    If an earlier notebook cell defined it with different labels, keep those
    labels so existing output folders and CSV rows stay consistent.

    Raises:
        ValueError: if ``ita`` is NaN (e.g. no skin pixels were found).
    """
    if math.isnan(ita):
        raise ValueError("ITA is NaN; cannot classify skin tone.")
    if ita > 55:
        return "very_light"
    if ita > 41:
        return "light"
    if ita > 28:
        return "intermediate"
    if ita > 10:
        return "tan"
    if ita > -30:
        return "brown"
    return "dark"


# Save the unmodified face crop (as RGB) into a sub-folder named after its tone.
rgb_crp = cv2.cvtColor(img_crp, cv2.COLOR_BGR2RGB)
rgb_crp = Image.fromarray(rgb_crp)

tone = ita_2_tone(ita)
final_directory = os.path.join(final_path, tone)
os.makedirs(final_directory, exist_ok=True)
rgb_crp.save(os.path.join(final_directory, image_name))

# Append one row per processed image: name, ITA in degrees, tone class.
with open(csv_file_path, 'a', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerow([image_name, ita, tone])

print(ita)

print(f"Masked face saved to {temp_image_path} and data written to {csv_file_path}")