import detecto
# !pip install detecto
from detecto import core, utils, visualize
from detecto.visualize import show_labeled_image, plot_prediction_grid
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
import torch
import cv2
import os
import csv
import math
from skimage import color

def ita_2_tone(ita):
    """Translate an ITA value (in degrees) into a skin-tone category name.

    Categories, from lightest to darkest:

    * ``ita > 55``        -> ``'Very_Light'``
    * ``41 < ita <= 55``  -> ``'Light'``
    * ``28 < ita <= 41``  -> ``'Intermediate'``
    * ``10 < ita <= 28``  -> ``'Tan'``
    * ``-30 < ita <= 10`` -> ``'Brown'``
    * ``ita <= -30``      -> ``'Dark'``

    Returns ``None`` when ``ita`` satisfies none of the comparisons
    (which is the case for NaN).
    """
    # (exclusive lower bound, category), ordered from lightest to darkest.
    # The first bound that ``ita`` exceeds decides the category.
    lower_bounds = (
        (55, 'Very_Light'),
        (41, 'Light'),
        (28, 'Intermediate'),
        (10, 'Tan'),
        (-30, 'Brown'),
    )
    for bound, category in lower_bounds:
        if ita > bound:
            return category
    # Explicit comparison (rather than a bare fall-through) so that a NaN
    # input still ends up returning None.
    if ita <= -30:
        return 'Dark'
    return None

# Load the trained detector, declaring 'finger' as its single class label.
model = core.Model.load(r"./model_rcnn.pth", ['finger'])

# Input image location
dataset_path = r"./img/"
image_name = "2.jpg"
image_path = f"{dataset_path}{image_name}"

# Output locations, all placed in the current working directory
current_directory = os.getcwd()
final_path = os.path.join(current_directory, "")
temp = os.path.join(current_directory, "temp1.jpg")
csv_file_path = os.path.join(current_directory, "skintone.csv")

# Load the image and run the detector on it
image = utils.read_image(image_path)
predictions = model.predict(image)
print(f"Processing image: {image_path}")
labels, boxes, scores = predictions

# Score cut-off for keeping a detection; the original author cites 0.6 as
# the optimal value according to literature.
thresh = 0.6

# Positions of the detections whose score exceeds the cut-off
filtered_indices = np.where(scores > thresh)

# Labels, scores and integer box coordinates of the surviving detections
num_list = filtered_indices[0].tolist()
filtered_labels = [labels[idx] for idx in num_list]
filtered_scores = scores[filtered_indices]
filtered_boxes = boxes[filtered_indices].to(torch.long)

# detecto's utils.read_image is documented to return the image in RGB order,
# while cv2.imwrite expects BGR, so swap the channels before writing.
# (COLOR_RGB2BGR performs the same channel swap as COLOR_BGR2RGB; it is used
# here because it names the direction actually intended.)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

if filtered_boxes.numel() == 0:
    # No detection passed the score threshold: analyse the whole image.
    print("No boxes detected, saving the whole image.")
    box_found = False
    roi = image
else:
    # Use the first surviving box. Despite the names, the four values are
    # used as corner coordinates: X/Y = left/top, W/H = right/bottom.
    # .tolist() turns the tensor entries into plain ints for slicing.
    box_found = True
    X, Y, W, H = filtered_boxes[0].tolist()
    roi = image[Y:H, X:W]

# Write to the same path (`temp`) that is read back afterwards, and do not
# silently continue when OpenCV reports that the write failed.
if not cv2.imwrite(temp, roi):
    raise OSError(f"Could not write temporary image: {temp}")

if box_found:
    print(f"Box detected, cropped and saved the region for image: {image_path}")


# Load the temporary crop back from disk. cv2.imread yields BGR data and
# returns None (instead of raising) when the file cannot be read.
img_array = cv2.imread(temp)
if img_array is None:
    raise FileNotFoundError(f"Could not read temporary image: {temp}")

# Untouched copy of the crop, kept apart from the masking steps below.
img_crp = img_array.copy()

# 3x3 structuring element shared by the two per-colour-space openings
open_kernel = np.ones((3, 3), np.uint8)

# Skin candidates in HSV: H in [0, 17], S in [15, 170], any V.
img_HSV = cv2.cvtColor(img_array, cv2.COLOR_BGR2HSV)
HSV_mask = cv2.inRange(img_HSV, (0, 15, 0), (17, 170, 255))
HSV_mask = cv2.morphologyEx(HSV_mask, cv2.MORPH_OPEN, open_kernel)

# Skin candidates in YCrCb (OpenCV channel order Y, Cr, Cb):
# any Y, Cr in [135, 180], Cb in [85, 135].
img_YCrCb = cv2.cvtColor(img_array, cv2.COLOR_BGR2YCrCb)
YCrCb_mask = cv2.inRange(img_YCrCb, (0, 135, 85), (255, 180, 135))
YCrCb_mask = cv2.morphologyEx(YCrCb_mask, cv2.MORPH_OPEN, open_kernel)

# Keep only pixels accepted by BOTH colour spaces, then smooth the result.
# The blur has to run on the merged mask: blurring YCrCb_mask at this point
# would overwrite the merge and drop the HSV constraint entirely.
global_mask = cv2.bitwise_and(YCrCb_mask, HSV_mask)
global_mask = cv2.medianBlur(global_mask, 3)
global_mask = cv2.morphologyEx(global_mask, cv2.MORPH_OPEN, np.ones((4, 4), np.uint8))

# Black out everything outside the mask, leaving only the skin region
skin = cv2.bitwise_and(img_array, img_array, mask=global_mask)

# Shrink the masked skin image to 10% of its size along each axis
scale = 0.1
img_BGR_small = cv2.resize(skin, (0, 0), fx=scale, fy=scale)
img_grayscale = cv2.cvtColor(img_BGR_small, cv2.COLOR_BGR2GRAY)

# Otsu picks the threshold automatically. With THRESH_BINARY_INV, pixels
# above the threshold become 0 and all others become 255.
otsu_flags = cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
threshold_value, threshold_image = cv2.threshold(img_grayscale, 0, 255, otsu_flags)

# Flip the inverted result into a 0/1 weight plane: 1 for pixels above the
# Otsu threshold, 0 for the rest.
keep_plane = 1 - threshold_image / 255

# Replicate the plane over the three colour channels and apply it
threshold_image_binary = keep_plane[:, :, np.newaxis].repeat(3, axis=2)
img_face_only = (threshold_image_binary * img_BGR_small).astype(np.uint8)

# Colour-space views of the retained pixels for the per-pixel skin test
img_HSV = cv2.cvtColor(img_face_only, cv2.COLOR_BGR2HSV)
img_YCrCb = cv2.cvtColor(img_face_only, cv2.COLOR_BGR2YCrCb)

# Aggregate skin pixels
height, width, channels = img_face_only.shape

# A pixel counts as skin when H <= 170, 140 <= Cr <= 170 and 90 <= Cb <= 120.
# The test is evaluated for the whole image at once rather than in a
# per-pixel Python loop; the selected pixels and their order are the same.
hue = img_HSV[:, :, 0]
cr = img_YCrCb[:, :, 1]
cb = img_YCrCb[:, :, 2]
skin_pixel_mask = (
    (hue <= 170)
    & (cr >= 140) & (cr <= 170)
    & (cb >= 90) & (cb <= 120)
)

# Boolean indexing returns a copy, so these values are unaffected by the
# zeroing of the rejected pixels below.
skin_pixels = img_face_only[skin_pixel_mask]
blue = skin_pixels[:, 0].tolist()
green = skin_pixels[:, 1].tolist()
red = skin_pixels[:, 2].tolist()

# Black out every pixel that failed the skin test
img_face_only[~skin_pixel_mask] = [0, 0, 0]

# Without any skin pixel the means would be NaN, and the run would only
# fail later with an unrelated-looking error. Stop here with a clear
# message, removing the temporary crop first since the normal cleanup at
# the end of the script will not be reached.
if not blue:
    if os.path.exists(temp):
        os.remove(temp)
    raise SystemExit(f"No skin pixels found in {image_path}; cannot estimate skin tone.")

# Determine mean skin tone estimate in BGR
skin_tone_estimate_BGR = [np.mean(blue), np.mean(green), np.mean(red)]

# Determine mean skin tone estimate in RGB
skin_tone_estimate_RGB = [np.mean(red), np.mean(green), np.mean(blue)]


# Everything that still needs the run's intermediate results sits inside
# try/finally so the temporary crop is removed even if a step fails.
try:
    # Wrap the mean colour as a single 1x1 RGB pixel with channels scaled
    # from 0-255 down to 0-1, then convert it to CIELAB.
    tone_rgb = [[[x/255 for x in skin_tone_estimate_RGB]]]
    tone_lab = color.rgb2lab(tone_rgb)

    # Get individual colors in LAB color space
    l, a, b = tone_lab.flatten()

    # ITA formula: arctan((L - 50) / b), converted from radians to degrees
    ita = np.arctan((l-50)/b)*(180/math.pi)

    # Classify once and reuse the result for the folder, the CSV row and
    # the final printout.
    tone = ita_2_tone(ita)
    if tone is None:
        # ita_2_tone returns None when no range matches (e.g. NaN); fail
        # with a clear message instead of a TypeError from os.path.join.
        raise ValueError(f"Cannot classify ITA value {ita} for image: {image_name}")

    # Save the cropped image into a folder named after its skin tone
    final_directory = os.path.join(final_path, tone)
    os.makedirs(final_directory, exist_ok=True)

    skin_pil = Image.fromarray(cv2.cvtColor(img_crp, cv2.COLOR_BGR2RGB))
    skin_pil.save(os.path.join(final_directory, image_name))

    # Append one row per processed image: name, ITA value, tone category
    with open(csv_file_path, 'a', newline='') as file:
        writer = csv.writer(file)
        writer.writerow([image_name, ita, tone])
finally:
    # Remove the temp file
    if os.path.exists(temp):
        os.remove(temp)

print(f"ITA value = {ita}")
print(f"Skin tone = {tone}")