#!/usr/bin/env python
import numpy as np
from glob import glob
from itertools import product
import SimpleITK as sitk
sitk.ProcessObject_SetGlobalWarningDisplay(False)
from tqdm import tqdm

# Atlas ids evaluated: 1 through 12 inclusive.
image_ids = range(1, 12 + 1)

# Label values to score, in ascending order: 0 plus every value from 3 to 133,
# except 31, 32, 66 and 84 (1 and 2 are absent as well). 128 values in total.
labels = np.setdiff1d(np.arange(134), [1, 2, 31, 32, 66, 84])

# Per-pair lists of Dice scores, filled in by the evaluation loop.
all_dices = []

def dice_score(p, q, *, empty_value=np.nan):
    """Return the Dice overlap ``2 * sum(p * q) / (sum(p) + sum(q))``.

    For boolean masks this is the usual hard Dice coefficient. Because it is
    written with products and sums rather than set operations, real-valued
    arrays (e.g. soft masks) are also accepted.

    Parameters
    ----------
    p, q : array_like
        The two masks; must have the same (or broadcast-compatible) shape.
    empty_value : float, optional
        Returned when the denominator is zero (both masks empty), where Dice
        is undefined. Defaults to NaN, which is what the plain ``0 / 0``
        produced, but without numpy's "invalid value" RuntimeWarning.

    Returns
    -------
    numpy.float64
        The Dice score, or ``empty_value`` for two empty masks.
    """
    p = np.asarray(p)
    q = np.asarray(q)
    den = p.sum() + q.sum()
    # Check explicitly instead of letting numpy evaluate 0/0 (warning + NaN).
    if den == 0:
        return np.float64(empty_value)
    return np.float64(2.0 * (p * q).sum() / den)

# Evaluate every ordered (fixed, moving) pair of distinct atlas ids: compare the
# fixed atlas's label image with the corresponding deformed segmentation.
foreground_labels = labels[labels != 0]  # label 0 skipped (presumably background)
pairs = [(f, m) for f, m in product(image_ids, image_ids) if f != m]
for fixed_id, moving_id in tqdm(pairs):
    fixed_path = f"../Atlases/m{fixed_id}.img"
    moved_path = f"outputs/deformed_{fixed_id}_{moving_id}_seg.img"
    # Read both label images and convert them to numpy arrays.
    fixed_array = sitk.GetArrayFromImage(sitk.ReadImage(fixed_path))
    moved_array = sitk.GetArrayFromImage(sitk.ReadImage(moved_path))
    # Without this check, mismatched shapes either fail with an opaque
    # broadcasting error or, if broadcastable, compare the wrong voxels.
    if fixed_array.shape != moved_array.shape:
        raise ValueError(
            f"shape mismatch for pair ({fixed_id}, {moving_id}): "
            f"{fixed_path} has shape {fixed_array.shape}, "
            f"{moved_path} has shape {moved_array.shape}"
        )
    pair_dices = [
        dice_score(fixed_array == lab, moved_array == lab)
        for lab in foreground_labels
    ]
    # dice_score gives NaN for a label absent from both images; ignore such
    # entries rather than letting one NaN turn the whole mean into NaN.
    # tqdm.write keeps the line from corrupting the progress bar.
    tqdm.write(str(np.nanmean(pair_dices)))
    all_dices.append(pair_dices)

# full results: rows are (fixed, moving) pairs, columns are foreground labels
all_dices = np.array(all_dices)
# per-label mean / std over all pairs (NaN entries ignored)
print(np.nanmean(all_dices, 0), np.nanstd(all_dices, 0))
# per-pair mean over labels, then mean / std of those over all pairs
all_dices_0 = np.nanmean(all_dices, 1)
print(np.nanmean(all_dices_0), np.nanstd(all_dices_0))

np.save("all_dices.npy", all_dices)


