from torchvision.models import resnet50, ResNet50_Weights
import torchvision
import os
import torch

import numpy as np

from ultralytics import YOLO

import torchvision.transforms as transforms
from torchvision.io import decode_image
from transformers import pipeline


def get_classifier_results(images, labels, device):
    """Classify ``images`` with an ImageNet-pretrained ResNet-50.

    Args:
        images: Batched image tensor fed straight to the model. Assumed to be
            ``(B, 3, H, W)`` and already normalised the way the ResNet-50
            weights expect. ``ResNet50_Weights.DEFAULT.transforms()`` is NOT
            applied here (TODO confirm with callers).
        labels: Accepted for interface symmetry with the other helpers; unused.
        device: Device the model and the images are moved to.

    Returns:
        The predicted class index as an ``int`` when the batch holds a single
        image, otherwise a list with one class index per image.
    """
    classifier = resnet50(weights=ResNet50_Weights.DEFAULT)
    classifier.to(device)
    classifier.eval()

    with torch.no_grad():
        logits = classifier(images.to(device))

    # Reduce over the class dimension only. A bare ``.argmax()`` flattens the
    # (B, num_classes) logits, so for B > 1 it returns an index into
    # B * num_classes values instead of a class id.
    predictions = logits.argmax(dim=1).cpu()
    if predictions.numel() == 1:
        # Single image: keep the original scalar-int return value.
        return predictions.item()
    return predictions.tolist()

def cxcywh_to_xyxy(box):
    """Convert a centre-format box ``(cx, cy, w, h)`` to corner format.

    ``box`` must unpack into exactly four values.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as a list.
    """
    cx, cy, w, h = box
    half_w, half_h = w / 2, h / 2
    return [cx - half_w, cy - half_h, cx + half_w, cy + half_h]

def iou(box1, box2):
    """Return the intersection-over-union of two ``(cx, cy, w, h)`` boxes.

    Both boxes are converted to corner format first. A 1e-6 term in the
    denominator keeps the division finite for degenerate (zero-area) boxes.
    Non-overlapping boxes give 0.0.
    """
    ax0, ay0, ax1, ay1 = cxcywh_to_xyxy(box1)
    bx0, by0, bx1, by1 = cxcywh_to_xyxy(box2)

    # Overlap extents are clamped at 0 so disjoint boxes contribute no area.
    overlap_w = max(0, min(ax1, bx1) - max(ax0, bx0))
    overlap_h = max(0, min(ay1, by1) - max(ay0, by0))
    inter_area = overlap_w * overlap_h

    area_a = (ax1 - ax0) * (ay1 - ay0)
    area_b = (bx1 - bx0) * (by1 - by0)
    union = area_a + area_b - inter_area + 1e-6
    return inter_area / float(union)

def _match_class_predictions(pred_boxes, gt_boxes, iou_threshold):
    """Greedily match one class's predictions to its ground-truth boxes.

    Predictions are visited in descending confidence order (element 1). Each
    prediction is compared only against the ground-truth box it overlaps most.
    It is a true positive when that IoU reaches ``iou_threshold`` and the box
    is still unclaimed. Otherwise it is a false positive, even if another
    unclaimed box would also have matched.

    Returns:
        ``(tp, fp, matched_ious)``, where ``matched_ious`` lists the IoU of
        each true positive in match order.
    """
    claimed = [False] * len(gt_boxes)
    tp = fp = 0
    matched_ious = []
    for pred in sorted(pred_boxes, key=lambda p: p[1], reverse=True):
        if not gt_boxes:
            # Nothing of this class to match against.
            fp += 1
            continue
        overlaps = [iou(pred[2:], gt_box) for gt_box in gt_boxes]
        best = int(np.argmax(overlaps))
        best_iou = overlaps[best]
        if best_iou >= iou_threshold and not claimed[best]:
            claimed[best] = True
            tp += 1
            matched_ious.append(best_iou)
        else:
            fp += 1
    return tp, fp, matched_ious


def calculate_metrics(preds, gts, iou_threshold=0.5):
    """Score detections against ground truth, per class and overall.

    Args:
        preds: Iterable (iterated twice) of rows
            ``(class, confidence, cx, cy, w, h)``.
        gts: Iterable (iterated twice) of rows ``(class, cx, cy, w, h)``.
        iou_threshold: Minimum IoU for a prediction to count as a match.

    Per-class precision, recall, TP/FP/FN counts and mean matched IoU are
    printed as they are computed.

    Returns:
        ``(precision, recall, f1, mean_iou)``:

        * Precision and recall are unweighted means over every class seen in
          either input. A class with no predictions scores precision 0.0, and
          a class with no ground truth scores recall 0.0.
        * F1 is the harmonic mean of those two averages.
        * ``mean_iou`` is pooled over every matched detection of every class.
    """
    gt_by_class = {}
    for row in gts:
        gt_by_class.setdefault(int(row[0]), []).append([float(v) for v in row[1:]])

    pred_by_class = {}
    for row in preds:
        cls = int(row[0])
        pred_by_class.setdefault(cls, []).append([row[0]] + [float(row[k]) for k in range(1, 6)])

    class_ids = [int(p[0]) for p in preds] + [int(g[0]) for g in gts]

    precisions, recalls, all_match_ious = [], [], []
    for cls in set(class_ids):
        gt_boxes = gt_by_class.get(cls, [])
        tp, fp, match_ious = _match_class_predictions(
            pred_by_class.get(cls, []), gt_boxes, iou_threshold
        )
        all_match_ious.extend(match_ious)
        fn = len(gt_boxes) - tp

        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        class_mean_iou = float(np.mean(match_ious)) if match_ious else 0.0

        precisions.append(precision)
        recalls.append(recall)

        print(f"Class {cls}: Precision={precision:.3f}, Recall={recall:.3f}, TP={tp}, FP={fp}, FN={fn}, Mean IoU={class_mean_iou:.3f}")
        if match_ious:
            print(f"  Matched IoUs: {[round(v, 3) for v in match_ious]}")

    overall_precision = float(np.mean(precisions)) if precisions else 0.0
    overall_recall = float(np.mean(recalls)) if recalls else 0.0
    pr_sum = overall_precision + overall_recall
    f1 = 2 * overall_precision * overall_recall / pr_sum if pr_sum > 0 else 0.0
    overall_mean_iou = float(np.mean(all_match_ious)) if all_match_ious else 0.0

    print(f"Overall Precision: {overall_precision:.3f}")
    print(f"Overall Recall: {overall_recall:.3f}")
    print(f"Overall F1-score: {f1:.3f}")
    print(f"Overall Mean IoU (over matched detections): {overall_mean_iou:.3f}")

    return overall_precision, overall_recall, f1, overall_mean_iou

def get_detection_results(images, labels):
    """Run YOLO11n on ``images`` and score its detections against ``labels``.

    Args:
        images: Any input accepted by an Ultralytics ``YOLO`` call.
        labels: Ground-truth rows ``(class, cx, cy, w, h)`` in the same
            coordinate convention as ``Boxes.xywhn`` (presumably normalised to
            [0, 1]; confirm against the dataset).

    Returns:
        ``(precision, recall, f1, mean_iou)`` from ``calculate_metrics`` for
        the last result YOLO produced. Returns ``(0.0, 0.0, 0.0, 0.0)`` when
        YOLO returns no results.
    """
    yolo = YOLO("yolo11n.pt")
    yolo.eval()

    results = yolo(images)

    # Start from zeros so an empty result list returns a value instead of
    # raising UnboundLocalError on the return below.
    metrics = (0.0, 0.0, 0.0, 0.0)
    for result in results:
        boxes = result.boxes
        # Rows of (class, confidence, cx, cy, w, h). Moved to CPU once so the
        # per-element int()/float() conversions in calculate_metrics don't
        # each force a device sync.
        preds = torch.cat(
            [boxes.cls.view(-1, 1), boxes.conf.view(-1, 1), boxes.xywhn], -1
        ).cpu()
        # NOTE(review): the same `labels` are compared against every image's
        # predictions, and only the last image's metrics are kept. This
        # presumably assumes a single-image input; confirm with callers.
        metrics = calculate_metrics(preds, labels)
    return metrics

def mae(images, labels):
    """Return the mean absolute error between ``images`` and ``labels``.

    ``images`` is resized to ``labels``' size with the leading dimension
    dropped. This presumably assumes ``labels`` is ``(C, H, W)``; a 2-D label
    would yield a single-int size, which ``Resize`` interprets as "smaller
    edge" (TODO confirm). Both tensors are moved to CPU before comparison,
    and the scalar result is returned via ``.item()``.
    """
    target_size = labels.size()[1:]
    resized = transforms.Resize(target_size)(images)
    diff = resized.cpu() - labels.cpu()
    return diff.abs().mean().item()

def get_depth_map(pipe, img, device):
    """Estimate a min-max normalised depth map for the first image in ``img``.

    Args:
        pipe: A Hugging Face depth-estimation pipeline whose output dict
            contains ``"predicted_depth"``.
        img: Image batch; only ``img[0]`` is used. It is converted to a PIL
            image, so presumably ``(C, H, W)`` in a range ``ToPILImage``
            accepts (TODO confirm).
        device: Device the returned tensor is moved to.

    Returns:
        ``predicted_depth`` rescaled to [0, 1], with a new leading dimension.
        A constant depth map becomes all zeros.
    """
    depth = pipe(transforms.ToPILImage()(img[0].cpu()))["predicted_depth"]

    lo, hi = depth.min(), depth.max()
    span = hi - lo
    if span == 0:
        # A constant map would otherwise divide 0 by 0 and become all-NaN.
        depth = torch.zeros_like(depth)
    else:
        depth = (depth - lo) / span
    depth = depth.unsqueeze(0)
    return depth.to(device)

def get_depth_results(images, labels, device):
    """Score Depth-Anything-V2-Small against ground-truth depth ``labels``.

    A fresh ``depth-estimation`` pipeline is built on ``device`` for every
    call. Its normalised prediction for the first image in ``images`` is
    compared to ``labels`` with ``mae``.

    Returns:
        The mean absolute error as returned by ``mae``.
    """
    depth_pipe = pipeline(
        task="depth-estimation",
        model="depth-anything/Depth-Anything-V2-Small-hf",
        device=device,
    )

    with torch.no_grad():
        predicted = get_depth_map(depth_pipe, images, device)

    return mae(predicted, labels)
