import sys
import os
from PIL import Image
sys.path.append('./')

import logging

import numpy as np
from tqdm import tqdm

from corr.utils import pose_error


def inference_part(
        cfg,
        cate,
        model,
        dataloader,
        cached_pred=None
):
    """Evaluate part segmentation over ``dataloader`` and print per-part IoUs.

    Each sample is passed to ``model.evaluate``; samples for which it returns
    ``None`` are skipped. The per-part ``'intersections'`` and ``'unions'``
    returned for every evaluated sample are summed, and the IoU of each part
    is the ratio of those sums (NaN for a part whose summed union is zero).
    The mean of the per-part IoUs is printed as ``miou``. If ``model.chamfer``
    is truthy, the per-part chamfer distances and IoUs accumulated on the
    model are averaged by ``model.chamfer_count`` and printed as well.

    Args:
        cfg: Config object; ``cfg.task`` labels the progress bar and the last
            component of ``cfg.args.save_dir`` is printed as the experiment
            name.
        cate: Category name, used only in the progress-bar label.
        model: Object exposing ``evaluate(sample)`` (returning ``None`` or a
            dict with ``'intersections'`` and ``'unions'`` sequences, one entry
            per part — presumably scalar counts, TODO confirm) and a
            ``chamfer`` flag; when the flag is truthy it must also expose
            ``chamfer_count``, ``chamfer_distances`` and ``ious`` sequences
            indexed by part id.
        dataloader: Iterable of samples accepted by ``model.evaluate``.
        cached_pred: Currently ignored — every sample is re-evaluated with
            ``model.evaluate``. Kept so existing callers keep working.

    Returns:
        None. All results are reported via ``print``.
    """
    intersections = None
    unions = None
    for sample in tqdm(dataloader, desc=f"{cfg.task}_{cate}"):
        # NOTE(review): this used to be guarded by
        # ``if cached_pred is None or True``, which is always true, so
        # ``cached_pred`` never had any effect.
        results = model.evaluate(sample)
        if results is None:
            continue
        logging.getLogger(__name__).debug('results: %s', results)

        if intersections is None:
            # Copy the per-part sequences so accumulating into them does not
            # mutate the dict returned for the first sample.
            intersections = list(results['intersections'])
            unions = list(results['unions'])
        else:
            for part_id in range(len(intersections)):
                # Rebind instead of ``+=`` so array-valued elements shared
                # with the first sample's results are never modified in place.
                intersections[part_id] = (
                    intersections[part_id] + results['intersections'][part_id]
                )
                unions[part_id] = unions[part_id] + results['unions'][part_id]

    if model.chamfer:
        for part_id in range(len(model.chamfer_count)):
            count = model.chamfer_count[part_id]
            if count == 0:
                print(part_id, 'no chamfer')
            else:
                chamfer_distance = model.chamfer_distances[part_id] / count
                chamfer_iou = model.ious[part_id] / count
                print(part_id, ' chamfer: ', chamfer_distance,
                      ' iou: ', chamfer_iou)

    # normpath strips a trailing separator so the last real component is used.
    print('exp name: ', os.path.basename(os.path.normpath(cfg.args.save_dir)))

    if intersections is None:
        # Empty dataloader, or model.evaluate returned None for every sample:
        # there is nothing to aggregate.
        print('no samples were evaluated; ious unavailable')
        return None

    # A zero union would otherwise raise ZeroDivisionError for Python ints;
    # NaN matches what numpy scalar division already yields in that case.
    iou_list = [
        inter / union if union != 0 else float('nan')
        for inter, union in zip(intersections, unions)
    ]
    print('ious: ', iou_list)
    print('miou: ', np.mean(np.array(iou_list)))

    return None


# Dispatch table from a task name (presumably the value of ``cfg.task``) to
# the helper that runs evaluation for that task.
helper_func_by_task = dict(image_part=inference_part)
