#%% import libraries and configuration

import os
import cv2
import yaml
import math
import random
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from scaling_tool import *

# Set the cwd to the directory containing the script so that the relative
# config path below resolves no matter where the script is launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

config_path = os.path.normpath(os.path.join(os.getcwd(), "config.yaml"))
# Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can misread
# non-ASCII characters in the config.
with open(config_path, "r", encoding="utf-8") as ymlfile:
    cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)

# NOTE(review): plain string concatenation, so this assumes "training_dir"
# ends with a path separator — confirm against config.yaml.
input_dir = cfg["training_dir"] + cfg["dataset_name"]
output_dir = cfg["output_dir"]
test_dir = cfg["test_dir"]
classes = cfg["classes"]

epochs = cfg["epochs"]
img_size = cfg["img_size"]

autoscale = cfg["autoscale"]
train_size_in_nm = cfg["train_size_in_nm"]
train_size_in_pxl = img_size
test_size_in_nm = cfg["test_size_in_nm"]


# Matplotlib colour codes, indexed by class id; only max_classes classes get
# distinct colours.
all_colors = ['r', 'g', 'b', 'y', 'c', 'm', 'k', 'w']
max_classes = len(all_colors)

#%% define functions
def plot_predictions(image, pred, output_dir = None, name = None, confidence = False):
    """Show an image with YOLO-format bounding boxes drawn on top of it.

    Parameters
    ----------
    image : array
        Image passed unchanged to ``ax.imshow``. Matplotlib treats 3-channel
        arrays as RGB, so convert OpenCV (BGR) images before calling.
    pred : 2-D array
        One row per box: ``[class_id, x_center, y_center, width, height]``,
        with centre and size given as fractions of the image width/height.
        Column 5 must hold the confidence when ``confidence`` is true.
    output_dir, name : str, optional
        When both are given, the figure is also saved as
        ``<output_dir>/<name>.eps`` (``<name>_c.eps`` when confidences are drawn).
    confidence : bool
        Annotate each box with its confidence value (column 5 of ``pred``).

    Uses the module-level ``classes`` (legend labels) and ``all_colors``.
    """
    img_h, img_w = image.shape[:2]
    # Keep the flag separate from per-box values instead of overwriting it.
    show_confidence = bool(confidence)
    n_colors = len(all_colors)

    fig, ax = plt.subplots(figsize=(15, 15))
    fig.suptitle('Predictions')
    ax.imshow(image)
    ax.axis('off')

    for box in pred:
        # Wrap class ids so more classes than colours does not raise IndexError.
        color = all_colors[int(box[0]) % n_colors]

        # Scale the normalized box to pixels; Rectangle is anchored at the
        # top-left corner, not the centre.
        width = box[3] * img_w
        height = box[4] * img_h
        anchor = (box[1] * img_w - width / 2, box[2] * img_h - height / 2)
        ax.add_patch(patches.Rectangle(anchor, width, height, linewidth=1, alpha=0.8,
                                       edgecolor=color, facecolor='none'))

        if show_confidence:
            ax.text(anchor[0], anchor[1], str(box[5]), fontsize=10, color=color)

    # Legend entry for every configured class, using the same colour wrapping.
    legend_elements = [
        patches.Patch(facecolor=all_colors[i % n_colors], edgecolor=all_colors[i % n_colors], label=label)
        for i, label in enumerate(classes)
    ]
    ax.legend(handles=legend_elements)

    # Save before showing so the file is written even while a blocking
    # plot window is still open.
    if output_dir is not None and name is not None:
        suffix = "_c.eps" if show_confidence else ".eps"
        fig.savefig(os.path.join(output_dir, name) + suffix, format="eps")

    plt.show()


#%% main

# load model
# NOTE: machine-specific hard-coded paths; adjust for your setup.
# Despite its name, weights_dir is the path to the weights *file*.
weights_dir = "D:/Measurement_Data/2025/Object_Recognition/Yolo_output/Phthalocyanine/testtesttest/train_001/weights/best.pt"
model = YOLO(weights_dir)

# select data
data_dir = os.path.normpath('D:/Measurement_Data/2024/Data_Michi/YoloInput_Phthalocyanine/Phthalocyanine_raw/')

#data_dir = "/data/michael/Object_Recognition/Yolo_input/Nanocar_Feedback/turning_test/valid"

# Run the model on every .jpeg image in data_dir, in sorted filename order.
# all_boxes[i] holds the normalized [x_center, y_center, width, height] boxes
# of image_files[i].
all_files = sorted(os.listdir(data_dir))
image_files = [f for f in all_files if f.endswith(".jpeg")]
all_boxes = []
for file in image_files:
    img_path = os.path.join(data_dir, file)
    img = cv2.imread(img_path)
    # imread returns None instead of raising. Fail loudly here: passing None
    # to the model would make ultralytics fall back to its default sample
    # source, silently mixing unrelated predictions into all_boxes.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    results = model(img)
    all_boxes.append(results[0].boxes.xywhn.cpu().numpy())

# Area of each box as a fraction of the image area (normalized width * height),
# flattened over all images in the same order as image_files.
all_box_sizes = [box[2] * box[3] for boxes in all_boxes for box in boxes]

# %% plot box sizes

# Rebuild the sorted .jpeg list that produced all_boxes, so the index below
# selects the matching image. Raw os.listdir() output has an arbitrary order
# and includes non-image files.
plot_files = sorted(f for f in os.listdir(data_dir) if f.endswith(".jpeg"))

# Fixed image index for reproducible inspection; set to None to pick one at random.
random_image = 73
if random_image is None:
    random_image = random.randrange(len(plot_files))
if not 0 <= random_image < len(plot_files):
    raise IndexError(f"Image index {random_image} out of range: {len(plot_files)} .jpeg files in {data_dir}")

image_path = os.path.join(data_dir, plot_files[random_image])
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

results = model(image)
detections = results[0].boxes
pred_class = detections.cls.cpu().numpy()
pred_bbox = detections.xywhn.cpu().numpy()
pred_confidence = np.round(detections.conf.cpu().numpy(), 2)
# One row per box: [class_id, x_center, y_center, width, height, confidence]
pred = np.column_stack((pred_class, pred_bbox, pred_confidence))

# cv2.imread yields BGR while matplotlib displays RGB. Convert for display
# only; the model above was fed the array as loaded.
plot_predictions(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), pred)

# Forward moving average over `smoothing_window` samples (the window shrinks
# near the end). Covers every sample so both curves have the same length.
smoothing_window = 10
smoothed_box_sizes = [np.mean(all_box_sizes[i:i + smoothing_window]) for i in range(len(all_box_sizes))]

plt.figure()
plt.plot(all_box_sizes)
plt.plot(smoothed_box_sizes)
plt.legend(["original", "smoothed"])
# NOTE(review): the x-axis is the box index in sorted-file order; it only
# corresponds to a rotation angle if the images form a rotation series.
plt.xlabel("rotation angle")
plt.ylabel("box size")
plt.show()


# %%


