import os
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import medpy.metric.binary as mm
import argparse
from multiprocessing import Pool, cpu_count
from functools import partial

def get_bounding_box(mask, padding=5):
    """Return the padded bounding box of the non-zero entries of ``mask``.

    Args:
        mask (np.ndarray): Array of any dimensionality; non-zero entries are
            treated as foreground.
        padding (int): Number of voxels added on every side of the box. The
            padded box is clipped to the array bounds.

    Returns:
        tuple[tuple[int, ...], tuple[int, ...]]: ``(min_coords, max_coords)``
        with one entry per axis. ``min_coords`` is inclusive and
        ``max_coords`` is exclusive, so they can be used directly as slice
        bounds. For an all-zero mask the box covers the whole array.
    """
    if not mask.any():
        # Empty mask: fall back to the full extent, with one coordinate per
        # axis (a fixed (0, 0, 0) would only be correct for 3-D input).
        return (0,) * mask.ndim, tuple(mask.shape)

    coords = np.array(np.nonzero(mask))  # shape: (ndim, n_foreground)
    upper_limit = np.array(mask.shape) - 1

    # Grow the tight box by `padding` and clip it to the valid index range.
    min_coords = np.maximum(coords.min(axis=1) - padding, 0)
    max_coords = np.minimum(coords.max(axis=1) + padding, upper_limit)

    # +1 turns the inclusive upper index into an exclusive slice bound.
    return (tuple(int(c) for c in min_coords),
            tuple(int(c) + 1 for c in max_coords))

def crop_to_bounding_box(volume, min_coords, max_coords):
    """Return the part of ``volume`` inside the given bounding box.

    Args:
        volume (np.ndarray): Array to crop.
        min_coords (Sequence[int]): Inclusive lower corner, one entry per
            axis to crop (e.g. as returned by ``get_bounding_box``).
        max_coords (Sequence[int]): Exclusive upper corner, same length as
            ``min_coords``.

    Returns:
        np.ndarray: A basic-slicing view of ``volume``. Axes beyond
        ``len(min_coords)`` are left uncropped.
    """
    # One slice per coordinate pair, so 2-D, 3-D, ... volumes are all handled.
    window = tuple(slice(lo, hi) for lo, hi in zip(min_coords, max_coords))
    return volume[window]

def evaluate_metrics(pred_path, gt_path, labels=('both_lungs', 'liver')):
    """Compute overlap and surface-distance metrics per label for one case.

    Both segmentations are loaded with nibabel; voxel spacing is read from
    the prediction header. Label ``labels[k]`` corresponds to voxel value
    ``k + 1``. Before computing metrics, both masks are cropped to the union
    of their padded bounding boxes. This makes the medpy calls cheaper, and
    all foreground (and therefore every surface voxel) stays inside the crop.

    Args:
        pred_path (str): Path to the predicted segmentation.
        gt_path (str): Path to the ground-truth segmentation.
        labels (Sequence[str]): Label names in label-value order, starting at
            value 1. Defaults to ``('both_lungs', 'liver')``.

    Returns:
        dict: ``{label: {'dice', 'jc', 'asd', 'hd95'}}`` for every label.
        - Empty prediction for a label: dice is 0 and the other metrics are
          NaN.
        - Non-empty prediction but empty ground truth: dice/jc come from
          medpy, and hd95/asd are NaN. medpy's surface distances raise
          ``RuntimeError`` when either mask is empty.

    Raises:
        ValueError: If prediction and ground truth have different shapes.
    """
    pred_img = nib.load(pred_path)
    pred = pred_img.get_fdata()
    gt = nib.load(gt_path).get_fdata()
    voxelspacing = pred_img.header.get_zooms()

    if pred.shape != gt.shape:
        # Cropping both arrays to a shared box could otherwise yield
        # same-shaped but misaligned crops and silently wrong metrics.
        raise ValueError(
            f"Shape mismatch for {pred_path}: prediction {pred.shape} "
            f"vs ground truth {gt.shape}"
        )

    results = {}
    for value, label in enumerate(labels, start=1):
        pred_label = pred == value
        gt_label = gt == value

        if not pred_label.any():
            # NOTE(review): jc is mathematically 0 here when gt is non-empty;
            # NaN is kept to preserve the existing output convention.
            dice, jc, hd95, asd = 0, np.nan, np.nan, np.nan
        else:
            gt_has_foreground = gt_label.any()

            # Crop to the union of both padded bounding boxes. When gt is
            # empty, its box would be the whole volume, so only pred's box is
            # used (dice/jc are unaffected because gt is zero everywhere).
            min_coords, max_coords = get_bounding_box(pred_label)
            if gt_has_foreground:
                gt_min, gt_max = get_bounding_box(gt_label)
                min_coords = np.minimum(min_coords, gt_min)
                max_coords = np.maximum(max_coords, gt_max)

            pred_crop = crop_to_bounding_box(pred_label, min_coords, max_coords)
            gt_crop = crop_to_bounding_box(gt_label, min_coords, max_coords)

            dice = mm.dc(pred_crop, gt_crop)
            jc = mm.jc(pred_crop, gt_crop)
            if gt_has_foreground:
                hd95 = mm.hd95(pred_crop, gt_crop, voxelspacing=voxelspacing)
                asd = mm.asd(pred_crop, gt_crop, voxelspacing=voxelspacing)
            else:
                # Surface distances are undefined against an empty reference.
                hd95, asd = np.nan, np.nan

        results[label] = {
            'dice': dice,
            'jc': jc,
            'asd': asd,
            'hd95': hd95
        }

    return results

    
def process_file(fname, MERGED_GT, PREDS):
    """Evaluate a single case identified by its file name.

    Args:
        fname (str): File name expected in both directories.
        MERGED_GT (str): Directory containing the ground-truth segmentation.
        PREDS (str): Directory containing the predicted segmentation.

    Returns:
        dict: The per-label metrics from ``evaluate_metrics`` plus ``'fname'``,
        or ``{'fname': ..., 'error': ...}`` when either input file is missing.
    """
    gt_path = os.path.join(MERGED_GT, fname)
    pred_path = os.path.join(PREDS, fname)

    if not os.path.exists(pred_path):
        return {'fname': fname, 'error': f"Missing prediction for {fname}"}
    # When file names come from listing the predictions directory (as in
    # main()), the ground truth is the side that can actually be missing.
    if not os.path.exists(gt_path):
        return {'fname': fname, 'error': f"Missing ground truth for {fname}"}

    results = evaluate_metrics(pred_path, gt_path)
    results['fname'] = fname
    return results


def process_file_wrapper(fname, merged_gt, preds):
    """Evaluate one file via ``process_file`` and print a per-label summary.

    Args:
        fname (str): File name, looked up in both directories.
        merged_gt (str): Directory with ground-truth segmentations.
        preds (str): Directory with predicted segmentations.

    Returns:
        dict: The result of ``process_file``, unchanged. Results carrying an
        ``'error'`` key are reported and returned as-is, rather than raising
        ``KeyError`` while printing metrics that do not exist.
    """
    result = process_file(fname, merged_gt, preds)
    fname = result['fname']

    if 'error' in result:
        print(f"Skipped {fname}: {result['error']}")
        return result

    print(f"Processed {fname}")
    # Metric entries are the dict-valued items (label -> metrics); this
    # follows whatever labels were evaluated instead of a hard-coded list.
    for label, metrics in result.items():
        if not isinstance(metrics, dict):
            continue
        print(
            f"  {label} - dice: {metrics['dice']:.2f}, jc: {metrics['jc']:.2f}, "
            f"asd: {metrics['asd']:.2f}, hd95: {metrics['hd95']:.2f}"
        )
    return result

def main(MERGED_GT, PREDS, OUTPUT_CSV, num_processes=8):
    """
    Evaluate every ``.nii.gz`` prediction in parallel and save the metrics as a CSV.

    The CSV has one row per (file, label): columns ``fname`` and ``metric``
    (which holds the label name), followed by one column per metric. Files
    whose result carries an ``'error'`` key are reported and left out of the
    CSV.

    Args:
        MERGED_GT (str): Path to merged ground truth labels
        PREDS (str): Path to predicted labels
        OUTPUT_CSV (str): Path to save output CSV
        num_processes (int, optional): Number of worker processes. Defaults
            to 8; pass ``None`` to use the number of CPU cores minus 1.
    """
    if num_processes is None:
        num_processes = max(1, cpu_count() - 1)  # Leave one CPU free for system processes

    # os.listdir order is arbitrary; sorting makes the run and CSV reproducible.
    files_to_process = sorted(x for x in os.listdir(PREDS) if x.endswith('.nii.gz'))

    # Bind the directories so the pool only has to ship file names.
    process_func = partial(process_file_wrapper, merged_gt=MERGED_GT, preds=PREDS)

    print(f"Starting parallel processing with {num_processes} processes...")
    with Pool(processes=num_processes) as pool:
        results = pool.map(process_func, files_to_process)

    # Error results have no metric dicts and would corrupt the melt below.
    failed = [r for r in results if 'error' in r]
    succeeded = [r for r in results if 'error' not in r]
    if failed:
        print(f"Skipped {len(failed)} file(s):")
        for r in failed:
            print(f"  {r['fname']}: {r['error']}")
    if not succeeded:
        print("No files were evaluated successfully; no CSV written.")
        return

    # Wide (one column per label, cells are metric dicts) -> long format with
    # one column per metric.
    df = pd.DataFrame(succeeded)
    df_melted = df.melt(id_vars=['fname'], var_name='metric', value_name='value')
    metrics_df = pd.json_normalize(df_melted['value'].tolist())
    df_melted = df_melted.join(metrics_df).drop(columns=['value'])

    # A bare file name has an empty dirname, and os.makedirs('') raises.
    output_dir = os.path.dirname(OUTPUT_CSV)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df_melted.to_csv(OUTPUT_CSV, index=False)
    print(f"Results saved to {OUTPUT_CSV}")


if __name__ == "__main__":
    # Command-line entry point: parse paths and run the parallel evaluation.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--merged_gt', type=str, default='/path/to/data/labels/merged_labels/', help='Path to merged ground truth labels')
    argparser.add_argument('--preds', type=str, default='/path/to/data/predictions/labels/', help='Path to predicted labels')
    argparser.add_argument('--output_csv', type=str, default='results/metrics.csv', help='Path to save output CSV')
    argparser.add_argument('--num_processes', type=int, default=8, help='Number of worker processes')
    args = argparser.parse_args()
    # The output directory is created by main(). Calling os.makedirs here on
    # the dirname of a bare file name ('') would raise.
    main(args.merged_gt, args.preds, args.output_csv, num_processes=args.num_processes)
