# knee_yolo_v8_train.py

import os
import numpy as np
import h5py
import cv2
from torch.utils.data import Dataset
from ultralytics import YOLO

# Source data: the dataset described in
# https://pmc.ncbi.nlm.nih.gov/articles/PMC9531250/
#
# Directory passed as `project=` to the model.val() calls in the __main__
# block of this file.
custom_project_dir = "/home/acc/Treatment_Modeling/Temporal-Treatment-Modeling/yolo"

class KneeDataset(Dataset):
    """Dataset over the ``.h5`` files of one directory, yielding YOLO-style labels.

    Each item is read from a single HDF5 file and returned as
    ``(image, labels)``, where every label is
    ``[class_id, x_center, y_center, box_width, box_height]`` and the four box
    values are divided by the image width/height.
    """

    def __init__(self, data_dir):
        """Collect the paths of all ``.h5`` files directly inside *data_dir*.

        Args:
            data_dir: Directory that is listed (non-recursively) for files
                whose name ends with ``.h5``.
        """
        super().__init__()
        self.data_dir = data_dir
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping (and any file names derived from the index)
        # reproducible across runs and machines.
        self.h5_files = [
            os.path.join(data_dir, name)
            for name in sorted(os.listdir(data_dir))
            if name.endswith('.h5')
        ]

    def __len__(self):
        """Return the number of ``.h5`` files found in the directory."""
        return len(self.h5_files)

    def __getitem__(self, idx):
        """Load one sample and convert its boxes to normalized YOLO labels.

        Args:
            idx: Index into the sorted list of ``.h5`` files.

        Returns:
            A tuple ``(img, yolo_labels)``. ``img`` is the ``images`` dataset
            cast to ``uint8``; 2-D or single-channel images are expanded to
            three channels. ``yolo_labels`` is a list of
            ``[class_id, x_center, y_center, box_width, box_height]`` entries;
            boxes with non-positive width or height are dropped.
        """
        h5_file = self.h5_files[idx]

        with h5py.File(h5_file, 'r') as data:
            # NOTE(review): the cast assumes pixel values already fit in
            # 0..255 - confirm against the H5 files.
            img = data['images'][()].astype(np.uint8)

            boxes_group = data['gt_boxes']
            classes_group = data['gt_classes']

            # Keys are ordered by the integer that follows their first
            # character, so a numeric (not lexicographic) order is used.
            box_keys = sorted(boxes_group.keys(), key=lambda x: int(x[1:]))
            # Attributes i0..i3 of each box are used below as
            # (x_min, y_min, x_max, y_max) in pixels.
            boxes = [
                np.array([boxes_group[key].attrs[f'i{i}'] for i in range(4)])
                for key in box_keys
            ]

            # NOTE(review): assumes exactly three attributes of 'gt_classes'
            # are not class entries - confirm against the H5 schema.
            class_count = len(classes_group.attrs) - 3
            classes = [int(classes_group.attrs[f'i{i}']) for i in range(class_count)]

        if len(img.shape) == 2 or (len(img.shape) == 3 and img.shape[2] == 1):
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        height, width = img.shape[:2]

        yolo_labels = []
        # zip() stops at the shorter sequence, so surplus boxes or classes
        # are ignored.
        for box, cls in zip(boxes, classes):
            box_width = (box[2] - box[0]) / width
            box_height = (box[3] - box[1]) / height
            if box_width <= 0 or box_height <= 0:
                # Degenerate or inverted box: skip it.
                continue

            x_center = ((box[0] + box[2]) / 2) / width
            y_center = ((box[1] + box[3]) / 2) / height
            yolo_labels.append([cls, x_center, y_center, box_width, box_height])

        return img, yolo_labels


def save_yolo_dataset(dataset, image_folder, label_folder):
    """Export *dataset* as PNG images plus matching YOLO ``.txt`` label files.

    Sample ``idx`` is written to ``image_{idx:05d}.png`` in *image_folder* and
    ``image_{idx:05d}.txt`` in *label_folder*. A label file is created for
    every sample, even when it ends up empty.

    Args:
        dataset: Object supporting ``len()`` and integer indexing, where each
            item is ``(img, labels)`` and every label is
            ``(class_id, x_center, y_center, box_width, box_height)``.
        image_folder: Output directory for the images (created if missing).
        label_folder: Output directory for the label files (created if missing).

    Raises:
        OSError: If an image cannot be written to disk.
    """
    os.makedirs(image_folder, exist_ok=True)
    os.makedirs(label_folder, exist_ok=True)

    for idx in range(len(dataset)):
        img, labels = dataset[idx]
        stem = f'image_{idx:05d}'
        img_path = os.path.join(image_folder, f'{stem}.png')
        label_path = os.path.join(label_folder, f'{stem}.txt')

        # cv2.imwrite signals failure through its return value rather than an
        # exception; surface it so no label file is left without an image.
        if not cv2.imwrite(img_path, img):
            raise OSError(f"Failed to write image: {img_path}")

        with open(label_path, 'w') as f:
            # Only labels whose four normalized values all lie in [0, 1] are
            # written.
            f.writelines(
                f"{int(cls_id)} {x_center:.6f} {y_center:.6f} {box_width:.6f} {box_height:.6f}\n"
                for cls_id, x_center, y_center, box_width, box_height in labels
                if all(0 <= val <= 1 for val in (x_center, y_center, box_width, box_height))
            )


if __name__ == '__main__':
    # Script flow: convert every H5 split to YOLO images/labels, train the
    # detector, save the weights, then evaluate on the val and test splits.
    root = '/local2/acc/OAI/Knee_Joint/KneeXrayData'

    # NOTE(review): this YAML is not part of this file; presumably it points
    # at the YOLO/images/<split> folders written below - confirm.
    data_yaml = 'knee_dataset.yaml'
    # Single source of truth for the output name, so the completion message
    # cannot drift from the file that is actually written.
    weights_path = 'knee_detector_yolo11xb32.pt'

    # Output split name -> H5 sub-directory under DetKneeData/H5.
    datasets = {'train': 'trainH5', 'val': 'valH5', 'test': 'testH5'}

    for split, ds in datasets.items():
        dataset = KneeDataset(os.path.join(root, 'DetKneeData', 'H5', ds))
        save_yolo_dataset(
            dataset,
            os.path.join(root, 'YOLO', 'images', split),
            os.path.join(root, 'YOLO', 'labels', split)
        )

    model = YOLO('yolo11x.pt')

    results = model.train(
        data=data_yaml,
        epochs=100,
        imgsz=640,
        batch=32,
        workers=4,
        device='cuda:2'
    )

    model.save(weights_path)

    print(f"Training complete! Model saved as {weights_path}")

    val_metrics = model.val(data=data_yaml, split='val', project=custom_project_dir, name='knee_val_yolo11')
    print("Validation metrics:", val_metrics)

    test_metrics = model.val(data=data_yaml, split='test', project=custom_project_dir, name='knee_test_yolo11')
    print("Test metrics:", test_metrics)
