import os
import numpy as np
import cv2
from glob import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime
from medpy.metric.binary import dc
CONFIG = {
    'image_size': 560,  # side length (pixels) every mask is resized to before scoring
    'dataset': '',      # NOTE(review): not read anywhere else in this file as shown
    'category': '',     # NOTE(review): not read anywhere else in this file as shown
    'methods': ['point'],
}

# Directory setup.
# abspath() so the paths do not silently become CWD-relative when __file__ is a
# bare filename (e.g. `python evaluate.py` run from the script's own directory).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# The rest of this file reads exactly these three keys. The original literal was
# a placeholder set (`{" "}`), which made every DIRS['...'] lookup raise TypeError.
# TODO: point these at the real prediction / ground-truth / output locations.
DIRS = {
    'results': os.path.join(BASE_DIR, 'results'),            # must contain masks_point/
    'ground_truth': os.path.join(BASE_DIR, 'ground_truth'),  # GT masks named like predictions
    'output': os.path.join(BASE_DIR, 'output'),              # plots and CSVs are written here
}

# Create required directories
os.makedirs(DIRS['output'], exist_ok=True)

def process_image(image_path):
    """Load a mask in grayscale, resize it to a square, and binarize it.

    The mask is resized to ``CONFIG['image_size']`` x ``CONFIG['image_size']``.
    If its largest value is above 1 (e.g. a 0/255 mask), it is thresholded at
    ``> 127`` into a uint8 {0, 1} array. Otherwise it is returned as-is after
    resizing.

    Returns None when OpenCV cannot read ``image_path``.
    """
    mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        return None

    side = CONFIG['image_size']
    mask = cv2.resize(mask, (side, side))

    # Already in {0, 1} (or all zeros): nothing to threshold.
    if np.max(mask) <= 1:
        return mask
    return (mask > 127).astype(np.uint8)

def calculate_iou(y_true, y_pred):
    """Return the intersection-over-union (Jaccard index) of two masks.

    Any non-zero element counts as foreground. The result is the exact ratio
    |A ∩ B| / |A ∪ B| with no epsilon, so a perfect prediction scores exactly 1.0.
    When both masks are empty the union is zero and 0.0 is returned, the same
    convention medpy's ``dc`` uses for two empty masks.

    Args:
        y_true: ground-truth mask (array-like, same shape as ``y_pred``).
        y_pred: predicted mask.

    Returns:
        IoU as a Python float in [0, 1].
    """
    truth = np.asarray(y_true, dtype=bool)
    pred = np.asarray(y_pred, dtype=bool)

    union = np.count_nonzero(truth | pred)
    if union == 0:
        return 0.0
    intersection = np.count_nonzero(truth & pred)
    return intersection / union

def _find_ground_truth(ground_truth_dir, stem, extensions):
    """Return the first existing ``<ground_truth_dir>/<stem><ext>`` path, or None."""
    for ext in extensions:
        candidate = os.path.join(ground_truth_dir, stem + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def evaluate_single_method(pred_dir, ground_truth_dir, gt_extensions=('.jpg', '.png')):
    """Score every predicted ``*.png`` mask in ``pred_dir`` against its ground truth.

    Each prediction is matched to the ground-truth file with the same stem. The
    candidate extensions in ``gt_extensions`` are tried in order, ``.jpg`` first
    by default. Both images go through ``process_image`` and are then scored with
    Dice (medpy ``dc``) and IoU (``calculate_iou``).

    Predictions with no matching ground truth, or where either image cannot be
    read, are skipped. A single summary warning lists them, so missing data is
    visible instead of silently shrinking the evaluated set.

    Args:
        pred_dir: directory containing the predicted ``*.png`` masks.
        ground_truth_dir: directory containing the ground-truth masks.
        gt_extensions: ground-truth file extensions to try, in priority order.

    Returns:
        pandas DataFrame with one row per scored image and the columns
        ``filename``, ``dice``, ``iou``. The columns exist even when no image
        could be scored, so callers can always index them.

    Raises:
        FileNotFoundError: if ``pred_dir`` contains no ``*.png`` files.
    """
    result_files = sorted(glob(os.path.join(pred_dir, "*.png")))
    if not result_files:
        raise FileNotFoundError(f"No mask images found in {pred_dir}")

    results = []
    skipped = []  # (filename, reason) pairs, reported once after the loop
    for result_file in tqdm(result_files, desc="Processing images"):
        filename = os.path.basename(result_file)
        stem = os.path.splitext(filename)[0]

        gt_file = _find_ground_truth(ground_truth_dir, stem, gt_extensions)
        if gt_file is None:
            skipped.append((filename, "no ground truth"))
            continue

        result_img = process_image(result_file)
        gt_img = process_image(gt_file)
        if result_img is None or gt_img is None:
            skipped.append((filename, "unreadable image"))
            continue

        results.append({
            'filename': filename,
            'dice': dc(gt_img, result_img),
            'iou': calculate_iou(gt_img, result_img),
        })

    if skipped:
        preview = ", ".join(f"{name} ({reason})" for name, reason in skipped[:5])
        more = "" if len(skipped) <= 5 else f", ... (+{len(skipped) - 5} more)"
        print(f"Warning: skipped {len(skipped)} of {len(result_files)} predictions: {preview}{more}")

    return pd.DataFrame(results, columns=['filename', 'dice', 'iou'])


_METRICS = ('dice', 'iou')


def _summarize(results, method):
    """Return one summary row per metric (count, mean, std, median, min, max)."""
    rows = []
    for metric in _METRICS:
        # Guard: an empty result table may have no metric columns at all.
        values = results[metric] if metric in results else pd.Series(dtype=float)
        rows.append({
            'method': method,
            'metric': metric,
            'n': int(values.count()),
            'mean': values.mean(),
            'std': values.std(),
            'median': values.median(),
            'min': values.min(),
            'max': values.max(),
        })
    return pd.DataFrame(rows)


def _plot_results(results, method):
    """Return a figure with one histogram per metric over the [0, 1] range."""
    fig, axes = plt.subplots(1, len(_METRICS), figsize=(10, 4))
    for ax, metric in zip(axes, _METRICS):
        values = results[metric] if metric in results else []
        ax.hist(values, bins=20, range=(0, 1), edgecolor='black')
        ax.set_title(f"{method} - {metric} (n={len(values)})")
        ax.set_xlabel(metric)
        ax.set_ylabel("images")
    fig.tight_layout()
    return fig


def evaluate_segmentation():
    """Evaluate the point-prompt masks and write the plots and tables to DIRS['output'].

    Writes ``point_results.png`` (metric histograms), ``point_results.csv``
    (per-image scores) and ``summary.csv`` (per-metric statistics). It then
    prints the summary and shows the figure.
    """
    print("Evaluating Points method...")
    point_dir = os.path.join(DIRS['results'], 'masks_point')
    results = evaluate_single_method(point_dir, DIRS['ground_truth'])

    # evaluate_single_method returns only the per-image table; the figure and
    # summary are derived from it here. Unpacking the DataFrame directly would
    # iterate its column labels and fail.
    summary_df = _summarize(results, method='point')
    fig = _plot_results(results, method='point')
    fig.savefig(os.path.join(DIRS['output'], "point_results.png"), dpi=300, bbox_inches='tight')

    results.to_csv(os.path.join(DIRS['output'], "point_results.csv"), index=False)
    summary_df.to_csv(os.path.join(DIRS['output'], "summary.csv"), index=False)

    print("\nResults Summary:")
    print(summary_df.to_string(index=False))

    plt.show()


if __name__ == "__main__":
    evaluate_segmentation()