import argparse
import json
import os
import re
import sys
import time

import warnings
warnings.filterwarnings("ignore")
from copy import deepcopy
import math
import cv2
import numpy as np
from mmdet.apis import inference_detector, init_detector
from tqdm import tqdm
from clip_for_color_detect import CLIPColorDetector
from copy import deepcopy


# Detection score a box must exceed to be kept by ImageEvaluator.eval_image.
CONFI_THRESHOLD = 0.3
# IoU above which a box counts as a duplicate of an already kept box of the
# same class in ImageEvaluator.eval_image.
IOU_THRESHOLD = 0.9
# Fraction of the summed box dimensions that relative_position subtracts from
# the centre offset before deciding on a direction.
POSITION_THRESHOLD = 0.1
# Box width and height (coordinate differences) must both exceed this value
# for a detection to be kept.
SIZE_THRESHOLD = 5

def compute_iou(box_a, box_b):
    """Return the intersection-over-union of two boxes.

    Only indices 0..3 of each box are read, as ``(x1, y1, x2, y2)``. Widths
    and heights are measured inclusively (``x2 - x1 + 1``) and clamped at
    zero, so boxes that do not overlap contribute an intersection of 0.

    Returns:
        ``intersection / union``, or ``0`` when the union area is zero.
    """
    def _inclusive_area(x1, y1, x2, y2):
        width = max(x2 - x1 + 1, 0)
        height = max(y2 - y1 + 1, 0)
        return width * height

    overlap = _inclusive_area(
        max(box_a[0], box_b[0]),
        max(box_a[1], box_b[1]),
        min(box_a[2], box_b[2]),
        min(box_a[3], box_b[3]),
    )
    area_a = _inclusive_area(box_a[0], box_a[1], box_a[2], box_a[3])
    area_b = _inclusive_area(box_b[0], box_b[1], box_b[2], box_b[3])
    union = area_a + area_b - overlap
    if not union:
        return 0
    return overlap / union
    
def reverse_select(length, num, idx):
    """Return the ``idx``-th ``num``-element combination of ``range(length)``.

    Combinations are numbered from 0 in lexicographic order, so for
    ``0 <= idx < math.comb(length, num)`` the result is a sorted list of
    ``num`` distinct positions. Each position is considered in turn:
    ``math.comb(remaining - 1, num - 1)`` combinations start with the current
    position, and ``idx`` either falls inside that group (position is taken)
    or is shifted past it (position is skipped).
    """
    chosen = []
    position = 0
    remaining = length
    while remaining > 0 and num > 0:
        # Number of combinations that include the current position.
        with_current = math.comb(remaining - 1, num - 1)
        if idx >= with_current:
            idx -= with_current
        else:
            chosen.append(position)
            num -= 1
        position += 1
        remaining -= 1
    return chosen

def assign_shuffle(length, idx):
    """Return the ``idx``-th permutation of ``range(length)``.

    Permutations are numbered from 0 in lexicographic order: ``idx`` is
    decoded digit by digit in the factorial number system, and each digit
    selects which of the still-unused values comes next.

    Raises:
        IndexError: if ``idx`` is too large for the given ``length``.
    """
    unused = list(range(length))
    ordering = []
    for slots_left in range(length, 0, -1):
        # Permutations of the remaining slots that share one leading value.
        group_size = math.factorial(slots_left - 1)
        pick, idx = divmod(idx, group_size)
        ordering.append(unused.pop(pick))
    return ordering

def relative_position(obj_a_bbox, obj_b_bbox):
    """Describe where box A lies relative to box B, allowing for box size.

    The first four entries of each box are read as two (x, y) corners. The
    offset between the box centres is shrunk towards zero, per axis, by
    ``POSITION_THRESHOLD`` times the summed box dimensions on that axis; if
    nothing is left (both components below 1e-3) no relation is reported.
    Otherwise the shrunk offset is divided by the length of the original
    offset and each axis is compared against +/-0.5.

    Returns:
        A set drawn from ``{"left", "right", "above", "below"}``. "left" and
        "above" correspond to negative x and y components respectively.
    """
    corners = np.array([obj_a_bbox, obj_b_bbox])[:, :4].reshape(2, 2, 2)
    centers = corners.mean(axis=-2)
    extents = np.abs(np.diff(corners, axis=-2))[..., 0, :]
    offset = centers[0] - centers[1]

    # Dead zone proportional to the combined size of the two boxes.
    margin = POSITION_THRESHOLD * (extents[0] + extents[1])
    revised_offset = np.maximum(np.abs(offset) - margin, 0) * np.sign(offset)
    if np.all(np.abs(revised_offset) < 1e-3):
        return set()

    dx, dy = revised_offset / np.linalg.norm(offset)
    direction_checks = (
        ("left", dx < -0.5),
        ("right", dx > 0.5),
        ("above", dy < -0.5),
        ("below", dy > 0.5),
    )
    return {label for label, holds in direction_checks if holds}

def agg_metrics(obj_dict):
    """Add per-object and overall accuracy plus the total count bias.

    ``obj_dict`` maps a class name to ``{'number_bias': int, 'objects': [...]}``.
    For every object the scorable attributes are counted:

    * one for a ``'color'`` key (correct when ``'color_is_correct'`` is truthy);
    * one per entry of each list in ``item['relation']`` (correct when the
      object was found and the entry's third element is
      ``'relation correct!'``).

    The function mutates ``obj_dict`` in place: each object gets
    ``'local_acc'`` and the top level gets ``'all_acc'`` and ``'all_bias'``.
    An object (or a whole dict) with nothing to score gets an accuracy of
    ``0.0`` instead of raising ``ZeroDivisionError``.

    Returns:
        The same ``obj_dict`` object, after mutation.
    """
    number_bias = 0
    cor_att = 0
    total_att = 0
    for class_entry in obj_dict.values():
        number_bias += class_entry['number_bias']
        for item in class_entry['objects']:
            item_total = 0
            item_correct = 0
            if 'color' in item:
                item_total += 1
            if item.get('color_is_correct', False):
                item_correct += 1

            for targets in item.get('relation', {}).values():
                for target in targets:
                    item_total += 1
                    # The third element is only present (and only read) when
                    # the object itself was found.
                    if item['object_found'] and target[2] == 'relation correct!':
                        item_correct += 1

            # Objects without colour or relations have nothing to score.
            item['local_acc'] = item_correct / item_total if item_total else 0.0
            total_att += item_total
            cor_att += item_correct

    obj_dict['all_acc'] = cor_att / total_att if total_att else 0.0
    obj_dict['all_bias'] = number_bias
    return obj_dict

class ImageEvaluator(object):
    """Detect objects in one image and score them against its metadata.

    Typical use is ``set_image_path(...)`` followed by ``eval_image()``.
    ``eval_image`` runs the detector, keeps confident, large enough,
    non-duplicate boxes per class, and then tries every one-to-one assignment
    between kept detections and the objects listed in the metadata, writing
    one ``test_<n>.json`` report per assignment into ``root_path``.
    """

    def __init__(self, args):
        """Load the detector, the colour classifier and the class-name tables.

        Args:
            args: Namespace providing ``device``, ``det_path`` (detector
                checkpoint) and ``clip_path`` (handed to ``CLIPColorDetector``).
        """
        self.device = args.device
        CONFIG_PATH = "configs/mmdet_config/configs/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py"
        CKPT_PATH = args.det_path

        self.object_detector = init_detector(CONFIG_PATH, CKPT_PATH, device=self.device)
        self.color_detector = CLIPColorDetector(args.clip_path, self.device)

        # Maps each evaluated class name to its list of candidate colours.
        with open('m3t2ibench_code/object_names.json', 'r') as f:
            self.class_to_color = json.load(f)

        self.classnames = list(self.class_to_color.keys())
        for k in self.class_to_color:
            # Classes without a colour list fall back to a generic palette.
            if len(self.class_to_color[k]) == 0:
                self.class_to_color[k] = ['green', 'red', 'yellow', 'brown', 'black', 'white', 'blue', 'grey']

        # One class name per line; the line index is used to index the
        # detector output in eval_image.
        with open("m3t2ibench_code/object_names.txt") as cls_file:
            self.classnames_ori = [line.strip() for line in cls_file]

    def set_image_path(self, root_path, image_path, metadata):
        """Select the image to evaluate and load its metadata.

        Args:
            root_path: Directory that receives ``test_det.png`` and the
                ``test_<n>.json`` reports.
            image_path: Path of the image to evaluate.
            metadata: Path of the JSON metadata file. When ``root_path``
                contains the substring ``'fix'`` the payload is read from the
                file's ``'meta_dict'`` entry.

        Raises:
            OSError: if ``cv2.imread`` cannot read ``image_path``.
        """
        self.image = cv2.imread(image_path)
        if self.image is None:
            # cv2.imread signals failure by returning None; fail here with a
            # clear message rather than later inside a drawing/writing call.
            raise OSError(f'cv2.imread could not read image: {image_path}')
        self.image_path = image_path
        self.root_path = root_path
        self.color_detector.set_image_path(image_path)
        with open(metadata, 'r') as f:
            self.metadata = json.load(f)
            if 'fix' in root_path:
                self.metadata = self.metadata['meta_dict']

    def eval_image(self):
        """Run detection on the current image and score every assignment.

        Writes ``test_det.png`` (kept boxes drawn on a copy of the image) into
        ``root_path`` and resets the report counter ``self.tot`` to 0 before
        the assignment search starts.
        """
        # NOTE(review): indexed as (per-class boxes, per-class masks), which
        # matches the mmdet 2.x result tuple — confirm against the installed
        # mmdet version.
        result = inference_detector(self.object_detector, self.image_path)
        bbox = result[0]
        seg = result[1]

        detected = {}
        self.tot = 0
        temp_image = deepcopy(self.image)
        for index, classname in enumerate(self.classnames_ori):
            if classname not in self.classnames:
                continue
            possible_bbox = bbox[index]
            possible_seg = np.array(seg[index])

            # Highest score first (column 4 holds the score).
            bbox_idx_sorted = np.argsort(possible_bbox[:, 4])[::-1]
            possible_bbox = possible_bbox[bbox_idx_sorted]
            possible_seg = possible_seg[bbox_idx_sorted]

            kept = []
            for p_bbox, p_seg in zip(possible_bbox, possible_seg):
                if not p_bbox[-1] > CONFI_THRESHOLD:
                    # Sorted by descending score: nothing further can pass.
                    break
                wide_enough = p_bbox[2] - p_bbox[0] > SIZE_THRESHOLD
                tall_enough = p_bbox[3] - p_bbox[1] > SIZE_THRESHOLD
                if not (wide_enough and tall_enough):
                    # The order says nothing about size, so a small box must
                    # only be skipped; later boxes can still qualify.
                    continue
                if any(compute_iou(p_bbox, kept_bbox) > IOU_THRESHOLD for kept_bbox, _ in kept):
                    # Near-duplicate of a higher-scoring box of this class.
                    continue
                colors = self.color_detector.check_color_list(
                    p_bbox[:4], classname, self.class_to_color[classname], p_seg
                )
                kept.append((p_bbox, colors))
            detected[classname] = kept

            for p_bbox, _ in kept:
                cv2.rectangle(temp_image, (int(p_bbox[0]), int(p_bbox[1])), (int(p_bbox[2]), int(p_bbox[3])), color=(255, 0, 0))
                cv2.putText(temp_image, classname, (int(p_bbox[0]), int(p_bbox[1]) - 10), cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0))
        cv2.imwrite(os.path.join(self.root_path, 'test_det.png'), temp_image)
        self.reverse_assign_and_match(detected, self.metadata['obj_dict'], 0)

    def check_metrics(self, obj_dict):
        """Score one complete assignment and write it to ``test_<tot>.json``.

        Works on a deep copy of ``obj_dict``, so the assignment being explored
        by ``reverse_assign_and_match`` is not modified. For every object:

        * ``'object_found'`` records whether a ``'bbox'`` was assigned;
        * ``'bbox'`` is converted to a plain list for JSON output;
        * ``'color_is_correct'`` is set when the object has a ``'color'``;
        * each relation entry ``[class_name, object_index]`` gets a third
          element: ``'not found'``, ``'relation correct!'`` or
          ``'relation failed!'``.

        Objects without a box are only marked as not found; their relation
        entries are left untouched.
        """
        ret_dict = deepcopy(obj_dict)
        for k in ret_dict:
            for item in ret_dict[k]['objects']:
                if 'bbox' not in item:
                    item['object_found'] = False
                    continue
                item['bbox'] = item['bbox'].tolist()
                item['object_found'] = True
                bbox = item['bbox']
                if 'color' in item:
                    item['color_is_correct'] = bool(item['color'] in item['detected_colors'])
                for rel, targets in item.get('relation', {}).items():
                    for target in targets:
                        target_item = ret_dict[target[0]]['objects'][target[1]]
                        if 'bbox' not in target_item:
                            target.append('not found')
                        elif rel in relative_position(bbox, target_item['bbox']):
                            target.append('relation correct!')
                        else:
                            target.append('relation failed!')
        ret_dict = agg_metrics(ret_dict)
        with open(os.path.join(self.root_path, f'test_{self.tot}.json'), 'w') as f:
            json.dump(ret_dict, f, indent=4)
        self.tot += 1

    @staticmethod
    def _enumerate_assignments(num_detected, num_targets):
        """Yield each one-to-one pairing as ``[(target_index, detected_index), ...]``.

        ``min(num_detected, num_targets)`` pairs are formed. A subset of the
        larger side is chosen with ``reverse_select`` and ordered with
        ``assign_shuffle``, both enumerated in increasing index order.
        """
        matched = min(num_detected, num_targets)
        larger = max(num_detected, num_targets)
        for combo_idx in range(math.comb(larger, matched)):
            subset = reverse_select(larger, matched, combo_idx)
            for perm_idx in range(math.factorial(matched)):
                order = assign_shuffle(matched, perm_idx)
                if num_detected <= num_targets:
                    # Fewer (or equally many) detections: pick which targets
                    # receive a detection.
                    yield [(subset[k], order[k]) for k in range(matched)]
                else:
                    # Surplus detections: pick which detections are used.
                    yield [(k, subset[order[k]]) for k in range(matched)]

    def reverse_assign_and_match(self, detected, obj_dict, depth):
        """Recursively try every detection-to-object assignment.

        Classes of ``obj_dict`` are processed in key order, one per recursion
        level. For the class at ``depth`` every pairing is applied by writing
        ``'bbox'`` and ``'detected_colors'`` into the target objects, the next
        class is processed, and the keys are removed again. Once all classes
        are assigned, ``check_metrics`` scores the combination.

        Args:
            detected: Maps detected class name to a list of
                ``(bbox, detected_colors)`` tuples.
            obj_dict: Metadata objects per class; mutated during the search
                (``'number_bias'`` is stored per class, assignment keys are
                added and removed again).
            depth: Index of the class to assign at this level.
        """
        class_names = list(obj_dict.keys())
        if depth >= len(class_names):
            self.check_metrics(obj_dict)
            return
        obj_item = class_names[depth]

        # NOTE(review): substring match, so a metadata class also collects
        # detections of any class whose name contains it — confirm this is
        # intended rather than an exact-name lookup.
        detected_list = []
        for obj_item_name in detected:
            if obj_item in obj_item_name:
                detected_list.extend(detected[obj_item_name])

        target_list = obj_dict[obj_item]['objects']
        num1 = len(detected_list)
        num2 = len(target_list)
        obj_dict[obj_item]['number_bias'] = abs(num1 - num2)

        for pairing in self._enumerate_assignments(num1, num2):
            for target_idx, det_idx in pairing:
                det_bbox, det_colors = detected_list[det_idx]
                target_list[target_idx]['bbox'] = det_bbox
                target_list[target_idx]['detected_colors'] = det_colors
            self.reverse_assign_and_match(detected, obj_dict, depth + 1)
            # Undo this pairing before trying the next one.
            for target_idx, _ in pairing:
                target_list[target_idx].pop('bbox', None)
                target_list[target_idx].pop('detected_colors', None)

def parse_args():
    """Parse the evaluation script's command-line options.

    Returns:
        An ``argparse.Namespace`` with ``root_path``, ``device``,
        ``start_index``, ``end_index``, ``det_path`` and ``clip_path``.
    """
    option_specs = (
        ('--root_path', {'type': str, 'default': '', 'help': 'root path'}),
        ('--device', {'type': str, 'default': 'cpu', 'help': 'device'}),
        ('--start_index', {'type': int, 'default': 0}),
        ('--end_index', {'type': int, 'default': 1000}),
        ('--det_path', {'type': str, 'default': ''}),
        ('--clip_path', {'type': str, 'default': ''}),
    )
    cli = argparse.ArgumentParser(description='Evaluate benchmark')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

if __name__ == '__main__':
    cli_args = parse_args()
    image_evaluator = ImageEvaluator(cli_args)

    # Each sample lives in <root_path>/<index>/ and consists of the generated
    # image 'gen.png' plus its 'metadata.json'.
    samples_root = cli_args.root_path
    for sample_idx in tqdm(range(cli_args.start_index, cli_args.end_index)):
        sample_dir = os.path.join(samples_root, f'{sample_idx}')
        image_evaluator.set_image_path(
            sample_dir,
            os.path.join(sample_dir, 'gen.png'),
            os.path.join(sample_dir, 'metadata.json'),
        )
        image_evaluator.eval_image()