import os, sys
import torch, json
import numpy as np
import matplotlib.pyplot as plt
from main import build_model_main
from util.slconfig import SLConfig
from datasets import build_dataset
from util.visualizer import COCOVisualizer
from util import box_ops
import albumentations as A
from torch.utils.data import Dataset, DataLoader
import util.misc as utils


# Paths to the model definition and the trained weights evaluated by this script.
model_config_path = "config/DINO/DINO_5scale.py" # change the path of the model config file
model_checkpoint_path = "SWIM_MS5_Villi_Crypt_SME_Chamfer_matcher_Manhatloss_L_PL_Augment_mixup_start_end_points/checkpoint_best_regular.pth" # change the path of the model checkpoint

# See our Model Zoo section in README.md for more details about our pretrained models.
# (A stray bare `s` expression used to follow this comment and raised NameError
# as soon as the script ran; it has been removed.)

# Build the model from the config, then restore the trained weights.
args = SLConfig.fromfile(model_config_path)
args.device = 'cuda'
model, criterion, postprocessors = build_model_main(args)
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source. Weights are first materialised on the CPU and moved to the target
# device afterwards.
checkpoint = torch.load(model_checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
# Switch to inference mode; the return value of eval() is deliberately discarded.
_ = model.eval()
model.to(args.device)

# Reload the config and rebind `args`; the dataset options below are set on
# this newly loaded object.
args = SLConfig.fromfile(model_config_path)
args.dataset_file = 'measurement_augment_mask_mixup_mid_point'
args.coco_path = "MeasureNet" # the path of coco
args.fix_size = False

# Build the three splits, then one sequential sampler and one batch-size-1
# loader per split (same construction order as before: all datasets, then all
# samplers, then all loaders).
dataset_train, dataset_val, dataset_test = (
    build_dataset(image_set=split_name, args=args)
    for split_name in ('train', 'val', 'test')
)
sampler_train, sampler_val, sampler_test = (
    torch.utils.data.SequentialSampler(split_dataset)
    for split_dataset in (dataset_train, dataset_val, dataset_test)
)
data_loader_train, data_loader_val, data_loader_test = (
    DataLoader(split_dataset, 1, sampler=split_sampler, drop_last=False,
               collate_fn=utils.collate_fn, num_workers=6)
    for split_dataset, split_sampler in (
        (dataset_train, sampler_train),
        (dataset_val, sampler_val),
        (dataset_test, sampler_test),
    )
)

# NOTE(review): the visible code below multiplies coordinates by a literal 512
# rather than reading this value - keep the two in sync if it changes.
image_size_scaled = 512

import torch
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import os
import seaborn as sns

import torch
import numpy as np
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment

# Device that each evaluation batch (image and both masks) is moved to in the
# evaluation loop below.
device = 'cuda'

def compute_average_length(lengths_dict, class_id):
    """Return the mean of the lengths stored under ``class_id``.

    Args:
        lengths_dict: Mapping from class id to a sized collection of lengths.
        class_id: Key to look up in ``lengths_dict``.

    Returns:
        ``np.mean`` of the stored lengths, or ``0.0`` when the key is missing
        or its collection is empty.
    """
    if class_id not in lengths_dict:
        return 0.0

    class_lengths = lengths_dict[class_id]
    if len(class_lengths) == 0:
        return 0.0

    return np.mean(class_lengths)

def compute_detection(output, threshold=0.3, image_size=512):
    """Keep predictions scoring above ``threshold`` and scale their coordinates.

    Args:
        output: Mapping with ``'scores'``, ``'labels'`` and ``'boxes'`` tensors.
            The same boolean mask derived from the scores is applied to boxes
            and labels, so all three are assumed to be aligned along their
            first dimension.
        threshold: Predictions are kept when ``score > threshold`` (strict).
        image_size: Factor the box coordinates are multiplied by. Defaults to
            512, the value that used to be hard-coded here.

    Returns:
        Tuple ``(boxes, labels)`` of NumPy arrays holding the scaled
        coordinates and the labels of the kept predictions.
    """
    keep = output['scores'] > threshold

    # Scaling only the surviving rows yields the same values as scaling
    # everything and filtering afterwards.
    selected_boxes = output['boxes'][keep] * image_size
    selected_labels = output['labels'][keep]

    return selected_boxes.cpu().numpy(), selected_labels.cpu().numpy()

def chamfer_distance(set1, set2):
    """
    Build the pairwise symmetric Chamfer distance matrix between two polyline sets.

    Each polyline is reshaped to a list of 2-D points. Entry ``[i, j]`` is the
    mean nearest-point distance from ``set1[i]`` to ``set2[j]`` plus the mean
    nearest-point distance in the opposite direction.

    Returns an array of shape ``(len(set1), len(set2))``; entries stay zero
    when there is nothing to iterate over (either set empty).
    """
    pairwise = np.zeros((len(set1), len(set2)))

    for row, line_a in enumerate(set1):
        for col, line_b in enumerate(set2):
            point_gaps = cdist(
                line_a.reshape(-1, 2),
                line_b.reshape(-1, 2),
                metric='euclidean',
            )
            a_to_b = np.mean(np.min(point_gaps, axis=1))
            b_to_a = np.mean(np.min(point_gaps, axis=0))
            pairwise[row, col] = a_to_b + b_to_a

    return pairwise


def compute_length(points):
    """
    Compute the length of a polyline as the sum of its segment lengths.

    Args:
        points: NumPy array holding the polyline vertices as consecutive
            ``x, y`` values; it is reshaped to ``(-1, 2)``, so a flat
            6-element array is read as 3 points.

    Returns:
        Sum of the Euclidean distances between consecutive vertices. For a
        3-point polyline this equals the previous
        ``|p1 - p0| + |p2 - p1|`` result. Polylines with more vertices now
        include every segment (vertices past the third used to be ignored),
        and a polyline with fewer than two vertices has length ``0.0``
        instead of raising ``IndexError``.
    """
    vertices = points.reshape(-1, 2)
    # One row per segment: the difference between consecutive vertices.
    segments = np.diff(vertices, axis=0)
    return np.linalg.norm(segments, axis=1).sum()

# Evaluation Loop
# Accumulators filled while iterating over the test set, one group per class.
# "GP" lists collect errors measured from each ground-truth polyline to its
# closest prediction, "PG" lists from each prediction to its closest
# ground-truth polyline; the "rel" variants hold the relative errors.
villi_GP_errors, villi_PG_errors = [], []
villi_GP_rel_errors, villi_PG_rel_errors = [], []

crypt_GP_errors, crypt_PG_errors = [], []
crypt_GP_rel_errors, crypt_PG_rel_errors = [], []

# Class label -> (GP abs, GP rel, PG abs, PG rel) accumulator lists.
# Label 1 feeds the villi lists, label 2 the crypt lists.
_ERROR_SINKS = {
    1: (villi_GP_errors, villi_GP_rel_errors, villi_PG_errors, villi_PG_rel_errors),
    2: (crypt_GP_errors, crypt_GP_rel_errors, crypt_PG_errors, crypt_PG_rel_errors),
}

# Score thresholds tried in order until a class has at least one prediction.
# They are listed explicitly because repeatedly subtracting 0.05 from 0.3
# drifts in floating point (0.3 -> ... -> 0.10000000000000002), which made the
# old `threshold > 0.1` check pass once more and try an unintended 0.05.
_SCORE_THRESHOLDS = (0.3, 0.25, 0.2, 0.15, 0.1)


def _record_length_errors(queries, references, nearest_ref_idx, abs_errors, rel_errors):
    """Append the length error of each query polyline against its matched reference.

    For every query, the absolute difference between its length and the length
    of ``references[nearest_ref_idx[k]]`` is appended to ``abs_errors``; the
    same error divided by the query length is appended to ``rel_errors`` when
    that length is positive.
    """
    for query, ref_idx in zip(queries, nearest_ref_idx):
        query_length = compute_length(query)
        length_error = abs(query_length - compute_length(references[ref_idx]))
        abs_errors.append(length_error)
        if query_length > 0:
            rel_errors.append(length_error / query_length)


# Size tensor handed to the post-processor; built once on the evaluation device
# instead of calling .cuda() on every iteration.
_unit_size = torch.tensor([[1.0, 1.0]], device=device)

# Inference only: no autograd bookkeeping is needed for the forward passes.
with torch.no_grad():
    for image, AP_Mask_strong, AP_Mask_weak, targets in data_loader_test:
        image = image.to(device)
        AP_Mask_strong = AP_Mask_strong.to(device)
        AP_Mask_weak = AP_Mask_weak.to(device)

        output = model(image, AP_Mask_strong, AP_Mask_weak)
        output = postprocessors['bbox'](output, _unit_size, not_to_xyxy=True)[0]

        # The loader uses batch size 1, so only the first target is read.
        # Ground truth is consumed as NumPy, so it is not moved to the device.
        # The 512 factor must match the scale applied inside compute_detection.
        target = targets[0]
        gt_points_np = (target['boxes'] * 512).cpu().numpy()
        gt_labels_np = target['labels'].cpu().numpy()

        # Detections at the default (highest) threshold, shared by both classes.
        base_boxes_np, base_labels_np = compute_detection(
            output, threshold=_SCORE_THRESHOLDS[0])

        for class_label in (1, 2):
            class_gt_points = gt_points_np[gt_labels_np == class_label]
            class_pred_points = base_boxes_np[base_labels_np == class_label]

            # Lower the threshold for THIS class only, and only while it has no
            # predictions. Previously the lowered threshold and its detections
            # leaked into the next class of the same image.
            for threshold in _SCORE_THRESHOLDS[1:]:
                if len(class_pred_points) > 0:
                    break
                selected_boxes_np, selected_labels_np = compute_detection(
                    output, threshold=threshold)
                class_pred_points = selected_boxes_np[selected_labels_np == class_label]

            # With either side empty there is nothing to match in either direction.
            if len(class_gt_points) == 0 or len(class_pred_points) == 0:
                continue

            # Compute Chamfer distance matrix (rows: ground truth, columns: predictions)
            distance_matrix = chamfer_distance(class_gt_points, class_pred_points)
            gp_abs, gp_rel, pg_abs, pg_rel = _ERROR_SINKS[class_label]

            # GP: every ground-truth polyline against its closest prediction.
            _record_length_errors(class_gt_points, class_pred_points,
                                  distance_matrix.argmin(axis=1), gp_abs, gp_rel)
            # PG: every prediction against its closest ground-truth polyline.
            _record_length_errors(class_pred_points, class_gt_points,
                                  distance_matrix.argmin(axis=0), pg_abs, pg_rel)

def _mean_or_zero(samples):
    """Return the mean of ``samples``, or 0.0 when the list is empty."""
    if not samples:
        return 0.0
    return np.mean(samples)


# Aggregate the per-polyline errors: for each class and each error type, the
# reported figure is the larger of the two directional means (GP and PG).
villi_GP_avg = _mean_or_zero(villi_GP_errors)
villi_PG_avg = _mean_or_zero(villi_PG_errors)
villi_max_error = max(villi_GP_avg, villi_PG_avg)

villi_GP_rel_avg = _mean_or_zero(villi_GP_rel_errors)
villi_PG_rel_avg = _mean_or_zero(villi_PG_rel_errors)
villi_max_rel_error = max(villi_GP_rel_avg, villi_PG_rel_avg)

crypt_GP_avg = _mean_or_zero(crypt_GP_errors)
crypt_PG_avg = _mean_or_zero(crypt_PG_errors)
crypt_max_error = max(crypt_GP_avg, crypt_PG_avg)

crypt_GP_rel_avg = _mean_or_zero(crypt_GP_rel_errors)
crypt_PG_rel_avg = _mean_or_zero(crypt_PG_rel_errors)
crypt_max_rel_error = max(crypt_GP_rel_avg, crypt_PG_rel_avg)

# Report absolute (MAE) and relative (MRE) figures per class.
print(f"Villi - Max Length Error (Chamfer-based MAE): {villi_max_error}")
print(f"Villi - Max Length Error (Chamfer-based MRE): {villi_max_rel_error}")
print(f"Crypt - Max Length Error (Chamfer-based MAE): {crypt_max_error}")
print(f"Crypt - Max Length Error (Chamfer-based MRE): {crypt_max_rel_error}")
