import sys
import os
from PIL import Image
sys.path.append('./')

import logging

import numpy as np
from tqdm import tqdm

from corr.utils import pose_error


def _accumulate_per_part(totals, new_values, key):
    """Return per-part running totals with ``new_values`` added element-wise.

    On the first call (``totals is None``) a shallow list copy of
    ``new_values`` is returned. Later calls build a fresh list using ``+``
    rather than ``+=``, so array-valued entries belonging to an earlier
    ``model.evaluate`` result are never modified in place.

    Args:
        totals: Running per-part totals, or None before the first sample.
        new_values: Per-part values reported by the current sample.
        key: Result key the values came from; used only in the error message.

    Returns:
        A new list of per-part totals.

    Raises:
        ValueError: If the current sample reports a different number of parts
            than the running totals (the per-part sums would be meaningless).
    """
    if totals is None:
        return list(new_values)
    if len(new_values) != len(totals):
        raise ValueError(
            f"results['{key}'] has {len(new_values)} parts, "
            f"expected {len(totals)}"
        )
    return [total + value for total, value in zip(totals, new_values)]


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN if that raises ZeroDivisionError.

    Only plain-Python numbers raise on division by zero; other numeric types
    (e.g. NumPy scalars) keep their own division semantics.
    """
    try:
        return numerator / denominator
    except ZeroDivisionError:
        return float('nan')


def _write_chamfer_results(model, path):
    """Write per-part chamfer distance and IoU, averaged by count, to ``path``.

    The file is always (re)created and stays empty when ``model.chamfer`` is
    falsy. Parts whose ``model.chamfer_count`` entry is 0 are reported on
    stdout and skipped. Each written line is ``"<part_id> <chamfer> <iou>"``.
    """
    with open(path, 'w', encoding='utf-8') as f:
        if not model.chamfer:
            return
        # Positional indexing, matching how chamfer_distances / ious are read
        # for the same part.
        for part_id in range(len(model.chamfer_count)):
            count = model.chamfer_count[part_id]
            if count == 0:
                print(part_id, 'no chamfer')
                continue
            chamfer_distance = model.chamfer_distances[part_id] / count
            iou = model.ious[part_id] / count
            print(' chamfer: ', chamfer_distance, ' iou: ', iou)
            f.write(f'{part_id} {chamfer_distance} {iou}\n')


def _write_part_ious(intersections, unions, path):
    """Compute per-part IoU (intersection / union), write it to ``path``, return it.

    Each written line is ``"<part_id> <iou>"``. The file is always (re)created,
    so it is empty when no parts were accumulated.
    """
    iou_list = [
        _safe_ratio(inter, union) for inter, union in zip(intersections, unions)
    ]
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f'{part_id} {iou}\n' for part_id, iou in enumerate(iou_list))
    return iou_list


def inference_part(
        cfg,
        cate,
        model,
        dataloader,
        cached_pred=None,
        *,
        results_root='../no_scale_iou_results'
):
    """Evaluate ``model`` on every sample of ``dataloader`` and report part IoUs.

    ``model.evaluate(sample)`` is called for each sample. Samples for which it
    returns None are skipped. For the other samples, the ``'intersections'``
    and ``'unions'`` entries (one value per part) are summed across samples,
    and the per-part IoU is computed from those totals.

    Results are written under ``<results_root>/<exp_name>/<cate>/``, where
    ``exp_name`` is the last path component of ``cfg.args.save_dir``:

    * ``3d.txt``: ``"<part_id> <chamfer> <iou>"`` per part, with
      ``model.chamfer_distances`` / ``model.ious`` divided by
      ``model.chamfer_count``. It is empty unless ``model.chamfer`` is truthy.
    * ``2d.txt``: ``"<part_id> <iou>"`` per part, from the summed counts.

    The per-part IoUs and their mean are also printed.

    Args:
        cfg: Config object; ``cfg.task`` labels the progress bar and
            ``cfg.args.save_dir`` names the experiment.
        cate: Category name, used in the progress label and the output path.
        model: Object providing ``evaluate`` plus the chamfer attributes above.
        dataloader: Iterable of samples passed to ``model.evaluate``.
        cached_pred: Accepted for interface compatibility but ignored; every
            sample is always evaluated.
        results_root: Output root directory. A relative path is resolved
            against the current working directory.

    Returns:
        None.

    Raises:
        ValueError: If samples report differing numbers of parts.
    """
    logger = logging.getLogger(__name__)

    intersection_list = None
    union_list = None
    for sample in tqdm(dataloader, desc=f"{cfg.task}_{cate}"):
        results = model.evaluate(sample)
        if results is None:
            continue
        logger.debug('results: %s', results)
        intersection_list = _accumulate_per_part(
            intersection_list, results['intersections'], 'intersections')
        union_list = _accumulate_per_part(
            union_list, results['unions'], 'unions')

    # normpath drops a trailing separator, so 'runs/exp1/' still yields
    # 'exp1' (split('/')[-1] would give '').
    exp_name = os.path.basename(os.path.normpath(cfg.args.save_dir))
    save_iou_path = os.path.join(results_root, exp_name, cate)
    os.makedirs(save_iou_path, exist_ok=True)

    _write_chamfer_results(model, os.path.join(save_iou_path, '3d.txt'))

    print('exp name: ', exp_name)
    if intersection_list is None:
        # Every sample was skipped: nothing to aggregate, so write an empty
        # 2d.txt instead of calling len() on None.
        logger.warning('No samples were evaluated for %s; 2d.txt will be empty.', cate)
        intersection_list, union_list = [], []

    iou_list = _write_part_ious(
        intersection_list, union_list, os.path.join(save_iou_path, '2d.txt'))

    print('ious: ', iou_list)
    # np.mean([]) warns and returns nan, so report nan directly in that case.
    print('miou: ', np.mean(np.array(iou_list)) if iou_list else float('nan'))

    return None


# Dispatch table: task name -> inference helper that evaluates that task.
helper_func_by_task = dict(image_part=inference_part)
