#!/usr/bin/env python
import numpy as np
from glob import glob
from itertools import product
import SimpleITK as sitk
sitk.ProcessObject_SetGlobalWarningDisplay(False)
from tqdm import tqdm

# Atlas identifiers 1..10; every ordered (fixed, moving) pair of distinct ids
# is evaluated below.
image_ids = range(1, 10 + 1)

# Accumulator for the Dice scores of each evaluated registration pair.
all_dices = list()

def dice_score(p, q, empty_value=1.0):
    """Return the Dice overlap ``2 * sum(p * q) / (sum(p) + sum(q))``.

    For boolean masks this is the classic Dice similarity coefficient
    ``2|P & Q| / (|P| + |Q|)``. For real-valued inputs the same formula gives
    the "soft" Dice.

    Parameters
    ----------
    p, q : array_like
        Masks (typically boolean) of the same or broadcast-compatible shape.
    empty_value : float, optional
        Score returned when the denominator is zero, i.e. both masks are
        empty. Defaults to 1.0 (two empty masks agree perfectly). Pass
        ``float("nan")`` to mark such cases as undefined instead.

    Returns
    -------
    float
        The overlap score; in [0, 1] for boolean masks.
    """
    p = np.asarray(p)
    q = np.asarray(q)
    den = p.sum() + q.sum()
    if den == 0:
        # Without this guard the result is 0/0 -> NaN plus a RuntimeWarning.
        return float(empty_value)
    # For bool arrays p * q is an element-wise logical AND.
    num = 2.0 * (p * q).sum()
    return float(num / den)

# Evaluate every ordered (fixed, moving) registration pair. The fixed atlas
# label map is compared, label by label, with the moving labels warped onto it.
# Each pair's scores are stored as {label: dice} because the set of labels
# present can differ between pairs. They are aligned into a single
# (n_pairs, n_labels) matrix afterwards.
for fixed_id, moving_id in product(image_ids, image_ids):
    if fixed_id == moving_id:
        continue
    fixed_path = f"../Atlases/g{fixed_id}.img"
    moved_path = f"outputs/deformed_{fixed_id}_{moving_id}_seg.img"
    # read the reference label map and the warped moving label map
    fixed_image = sitk.ReadImage(fixed_path)
    moved_image = sitk.ReadImage(moved_path)
    # convert to numpy arrays for voxel-wise comparison
    fixed_array = sitk.GetArrayFromImage(fixed_image)
    moved_array = sitk.GetArrayFromImage(moved_image)
    if fixed_array.shape != moved_array.shape:
        raise ValueError(
            f"label map shape mismatch for pair ({fixed_id}, {moving_id}): "
            f"{fixed_array.shape} vs {moved_array.shape}"
        )
    # compute metrics; union1d flattens both inputs before taking the union
    labels_all = np.union1d(fixed_array, moved_array)
    pair_dices = {}
    # label 0 is skipped (presumably background)
    for lab in tqdm(labels_all[labels_all != 0]):
        pair_dices[lab] = dice_score(fixed_array == lab, moved_array == lab)
    print(np.mean(list(pair_dices.values())) if pair_dices else np.nan)
    all_dices.append(pair_dices)

# full results: align the per-pair dicts into a (n_pairs, n_labels) matrix.
# NaN marks a label that is absent from both maps of that pair.
labels = sorted(set().union(*all_dices))
label_column = {lab: j for j, lab in enumerate(labels)}
dice_matrix = np.full((len(all_dices), len(labels)), np.nan)
for row, pair_dices in zip(dice_matrix, all_dices):
    for lab, score in pair_dices.items():
        row[label_column[lab]] = score  # row is a view into dice_matrix
all_dices = dice_matrix
print("labels:", np.array(labels))
# per-label mean and std across pairs (both along axis 0)
print(np.nanmean(all_dices, 0), np.nanstd(all_dices, 0))
# per-pair mean over labels, then its mean and std across pairs
all_dices_0 = np.nanmean(all_dices, 1)
print(np.mean(all_dices_0), np.std(all_dices_0))

np.save("all_dices.npy", all_dices)

