import cv2
import json
import numpy as np
from random import shuffle
from PIL import Image
import ast
import json
import pickle
import os

def calculate_iou(box1, box2):
    """Compute the intersection-over-union of two axis-aligned boxes.

    Each box is a 4-sequence ``(x, y, w, h)`` whose items are anything
    ``float()`` accepts (numbers or numeric strings). ``(x, y)`` is treated as
    the minimum corner: the box spans ``[x, x + w]`` by ``[y, y + h]``.

    NOTE(review): ``get_coco_image`` in this file interprets ``(x, y)`` as the
    box *center*. The IoU here is only correct for center-format boxes if both
    boxes share the same w/h, so confirm which convention the callers of
    ``match_predict_n_label`` use.

    Returns:
        IoU as a float in ``[0, 1]``. Returns ``0.0`` when the union area is not
        positive (e.g. both boxes are degenerate) instead of dividing by zero.
    """
    x1, y1, w1, h1 = (float(v) for v in box1)
    x2, y2, w2, h2 = (float(v) for v in box2)

    # Overlap extent along each axis; a negative value means the boxes are disjoint.
    inter_w = min(x1 + w1, x2 + w2) - max(x1, x2)
    inter_h = min(y1 + h1, y2 + h2) - max(y1, y2)
    area_intersection = max(0.0, inter_w) * max(0.0, inter_h)

    union = w1 * h1 + w2 * h2 - area_intersection
    if union <= 0:
        return 0.0
    return area_intersection / union


def match_predict_n_label(target, predicts):
    """Find the prediction that overlaps ``target`` the most.

    Args:
        target: a box in the format accepted by ``calculate_iou``.
        predicts: iterable of boxes in the same format.

    Returns:
        ``(index, iou)`` of the prediction with the highest IoU (the first one on
        ties, following ``np.argmax``). Returns ``(-1, 0)`` when ``predicts`` is
        empty or when no prediction overlaps the target at all.
    """
    ious = [calculate_iou(target, predict) for predict in predicts]
    # np.max / np.argmax raise ValueError on an empty sequence.
    if not ious:
        return -1, 0
    best_idx = int(np.argmax(ious))
    best_iou = ious[best_idx]
    if best_iou == 0:
        return -1, 0
    return best_idx, best_iou


def _box_to_pixel_corners(box, i_w, i_h):
    """Convert a normalized center-format (x, y, w, h) box to integer pixel corners (x1, y1, x2, y2) clamped to the image."""
    x, y, w, h = (float(v) for v in box)
    return (
        max(int(x * i_w - 0.5 * w * i_w), 0),
        max(int(y * i_h - 0.5 * h * i_h), 0),
        min(int(x * i_w + 0.5 * w * i_w), i_w - 1),
        min(int(y * i_h + 0.5 * h * i_h), i_h - 1),
    )


def get_coco_image(path, predict_box, target_box, pad=0.15):
    """Draw the ground-truth and predicted boxes on an image and crop around both.

    Args:
        path: image file path, read with ``cv2.imread``.
        predict_box: normalized center-format ``(x, y, w, h)`` predicted box. It
            may be empty, in which case only the ground-truth box is drawn and
            used for the crop.
        target_box: normalized center-format ``(x, y, w, h)`` ground-truth box.
        pad: fraction of the combined boxes' width/height added as margin,
            split evenly between the two sides of each axis.

    Returns:
        The cropped image array, in the channel order returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``path``.
    """
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    i_h, i_w = image.shape[0], image.shape[1]

    gt = _box_to_pixel_corners(target_box, i_w, i_h)
    cv2.rectangle(image, gt[:2], gt[2:], (255, 0, 0), 2)

    # Without a prediction, the crop window comes from the ground-truth box alone.
    pred = gt
    if len(predict_box):
        pred = _box_to_pixel_corners(predict_box, i_w, i_h)
        cv2.rectangle(image, pred[:2], pred[2:], (0, 255, 0), 2)

    # Union of both boxes, then expanded by pad / 2 of its size on each side.
    x1, y1 = min(pred[0], gt[0]), min(pred[1], gt[1])
    x2, y2 = max(pred[2], gt[2]), max(pred[3], gt[3])
    # Measure the span once, before moving x1/y1, so the right/bottom margins
    # use the same size as the left/top margins.
    span_w, span_h = x2 - x1, y2 - y1
    x1 = max(int(x1 - pad / 2 * span_w), 0)
    y1 = max(int(y1 - pad / 2 * span_h), 0)
    x2 = min(int(x2 + pad / 2 * span_w), i_w)
    y2 = min(int(y2 + pad / 2 * span_h), i_h)
    return image[y1:y2, x1:x2]


def sorted_by_iou(results):
    """Return ``results`` ordered by ascending IoU.

    The IoU of each entry is read from ``entry[1][-1]`` (the last item of its
    second element); the ordering is the one produced by ``np.argsort``.
    """
    order = np.argsort([entry[1][-1] for entry in results])
    return [results[k] for k in order]


def update_form(attribute_form, description_form):
    """Append every value of ``description_form`` to the list kept under the same key.

    A key not yet present in ``attribute_form`` starts a new one-element list.
    ``attribute_form`` is modified in place and also returned.
    """
    for key, value in description_form.items():
        attribute_form.setdefault(key, []).append(value)
    return attribute_form


def sample_from_form(attribute_form, sample_num=3):
    """Randomly pick up to ``sample_num`` values from each list in ``attribute_form``.

    Uses NumPy's global random state, so results are reproducible after
    ``np.random.seed``. Returns a new dict with the same keys; the input is not
    modified.
    """
    return {
        key: [values[pos] for pos in np.random.permutation(len(values))[:sample_num]]
        for key, values in attribute_form.items()
    }


def response_to_json(response):
    """Extract the outermost ``{...}`` literal from a free-text response as a JSON-style dict.

    The text between the first ``{`` and the last ``}`` is parsed with
    ``ast.literal_eval`` and then round-tripped through JSON. The round trip
    normalizes the result to JSON types: tuples become lists and non-string
    keys become strings. On any failure the exception is printed and ``{}``
    is returned.
    """
    try:
        opening = response.index('{')
        closing = response.rindex('}')
        parsed = ast.literal_eval(response[opening:closing + 1])
        return json.loads(json.dumps(parsed))
    except Exception as err:
        print(err)
        return {}

def get_image(img_path, patch_path=None, box=None, padding_ratio=0.2, resize_to=256, show_box=False):
    """
    Read an image, optionally crop it around a box, resize it to a square and
    optionally attach a patch image on its left.

    arguments:
    - img_path: path to the image
    - patch_path: optional path to a patch image. When given, the patch is resized
      to (resize_to // 6) x (resize_to // 6) and placed at the top of a white strip
      added to the left of the image
    - box: normalized xywh box [x_center, y_center, width, height]. None or an
      empty sequence means the whole image is used
    - padding_ratio: padding added on each side of the box, as a fraction of the
      box width / height
    - resize_to: side length of the square image (before the patch strip is added)
    - show_box: if True, draw a (0, 0, 255) rectangle (red in BGR order) slightly
      larger than the box before cropping

    returns:
    - uint8 array of shape (resize_to, resize_to, 3), or
      (resize_to, resize_to + resize_to // 6, 3) when a patch is attached.
      Channels are in the order produced by cv2.imread (B, G, R)

    raises:
    - FileNotFoundError if patch_path cannot be read by cv2.imread
    """
    # The context manager closes the file once the pixels are copied into numpy.
    # convert('RGB') yields 3 channels for grayscale, palette and RGBA inputs,
    # which would otherwise break cvtColor / the shape unpacking below.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    # Swap R and B: PIL yields RGB, while the patch below is loaded with
    # cv2.imread (BGR), so the output is consistently BGR.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w, _ = img.shape

    if box is not None and len(box) > 0:
        # denormalization
        x_center, y_center, box_w, box_h = box
        x_center, y_center = int(x_center * w), int(y_center * h)
        box_w, box_h = int(box_w * w), int(box_h * h)

        # padded crop window, clamped to the image
        pad_w = int(box_w * padding_ratio)
        pad_h = int(box_h * padding_ratio)
        x_min = max(0, x_center - box_w // 2 - pad_w)
        y_min = max(0, y_center - box_h // 2 - pad_h)
        x_max = min(w, x_center + box_w // 2 + pad_w)
        y_max = min(h, y_center + box_h // 2 + pad_h)

        if show_box:
            # outline 5% larger than the box on each side; line width scales
            # with the crop width
            outline_w = int(box_w * 0.05)
            outline_h = int(box_h * 0.05)
            box_x_min = max(0, x_center - box_w // 2 - outline_w)
            box_y_min = max(0, y_center - box_h // 2 - outline_h)
            box_x_max = min(w, x_center + box_w // 2 + outline_w)
            box_y_max = min(h, y_center + box_h // 2 + outline_h)
            cv2.rectangle(img, (box_x_min, box_y_min), (box_x_max, box_y_max), (0, 0, 255), max(1, (x_max - x_min) // 50))

        img = img[y_min:y_max, x_min:x_max]

    img = cv2.resize(img, (resize_to, resize_to))

    if patch_path is None:
        return img.astype(np.uint8)

    patch = cv2.imread(patch_path)
    if patch is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f'cv2.imread could not read patch: {patch_path}')
    h, w, c = img.shape
    side = w // 6
    patch = cv2.resize(patch, (side, side))
    # White canvas: a strip of width `side` on the left, the image on the right.
    canvas = np.full((h, w + side, c), 255, dtype=np.uint8)
    canvas[:side, :side] = patch
    canvas[:, side:] = img
    return canvas

def sample_from_dataset(data, data_root, labels=(), samples_per_class=-1):
    """Randomly pick image paths for the requested class indices.

    Args:
        data: dict mapping a class token to a list of image paths. The lists are
            not modified; shuffling happens on a copy.
        data_root: directory containing ``cls_names.json``, a JSON object mapping
            class token -> class name. Class index ``l`` refers to the ``l``-th
            entry of that object, in file order.
        labels: iterable of class indices to sample from.
        samples_per_class: number of paths to keep per class. A negative value
            (the default ``-1``) or ``None`` keeps every path.

    Returns:
        dict ``{img_path: [class_name, class_index]}``.
    """
    with open(os.path.join(data_root, 'cls_names.json'), 'r') as f:
        cls_names = json.load(f)
    tokens = list(cls_names.keys())
    names = list(cls_names.values())

    sampled = {}
    for label in labels:
        img_paths = list(data[tokens[label]])  # copy: don't reorder the caller's list
        shuffle(img_paths)
        # A plain slice with the -1 sentinel would silently drop the last path.
        if samples_per_class is not None and samples_per_class >= 0:
            img_paths = img_paths[:samples_per_class]
        for img_path in img_paths:
            sampled[img_path] = [names[label], label]
    return sampled

def tuple_to_str(d):
    """Return a copy of ``d`` with every tuple key replaced by its ``str()`` form.

    Nested dict values are converted recursively; other values (including
    lists and tuples used as values) are kept as-is. This makes the dict
    usable with ``json.dump``, which rejects tuple keys.
    """
    return {
        (str(key) if isinstance(key, tuple) else key):
            (tuple_to_str(value) if isinstance(value, dict) else value)
        for key, value in d.items()
    }

def str_to_tuple(d):
    """Recursively restore tuple keys that were serialized with ``str()``.

    A string key is replaced only when ``ast.literal_eval`` parses it to a
    tuple. Every other key, including strings that parse to other literals
    such as ``"'abc'"`` or ``"1e3"``, is kept exactly as it was. Nested dict
    values are processed recursively. Returns a new dict; ``d`` is not modified.

    NOTE: tuple reprs containing non-literal values (e.g. ``np.int64(1)``)
    are not parsed by ``literal_eval`` and stay strings.
    """
    new_dict = {}
    for key, value in d.items():
        if isinstance(key, str):
            # literal_eval only accepts Python literals, so a key such as
            # "__import__('os').system(...)" cannot execute code the way eval could.
            try:
                parsed = ast.literal_eval(key)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                parsed = None
            if isinstance(parsed, tuple):
                key = parsed
        if isinstance(value, dict):
            value = str_to_tuple(value)
        new_dict[key] = value
    return new_dict

def _split_label_tag(tag):
    """Turn a comma-joined tag string, or comma-joined entries of a tag list, into a flat list."""
    # Strings are detected by ',' but split on ', ', so "a,b" becomes ['a,b'].
    # NOTE(review): kept for compatibility; confirm whether a bare ',' should split too.
    if isinstance(tag, str):
        return tag.split(', ') if ',' in tag else tag
    if isinstance(tag, list):
        flat = []
        for item in tag:
            if isinstance(item, str) and ',' in item:
                flat.extend(item.split(', '))
            else:
                flat.append(item)
        return flat
    # Any other value (None, numbers, bools, dicts, ...) is returned unchanged.
    return tag


def _split_refined_tags(tags):
    """Return the de-duplicated tags with each ', '-joined entry replaced by its parts."""
    flat = set()
    for tag in tags:
        if isinstance(tag, str) and ', ' in tag:
            flat.update(tag.split(', '))
        else:
            flat.add(tag)
    return list(flat)


def postprocess(labels, tags_refined):
    """Split comma-joined tags that should have been returned as lists.

    Handles the case where gpt cannot use [] to provide multiple tags and
    joins them with commas instead.

    Args:
        labels: nested dict ``labels[cls][data][attr_type][attr] -> tag``, where
            a tag is a string, a list of strings, or some other value. Other
            values are left untouched.
        tags_refined: nested dict ``tags_refined[cls][attr_type][attr] -> tags``,
            where ``tags`` is an iterable of tags.

    Returns:
        ``(labels, tags_refined)``. Both dicts are modified in place. Tag lists
        in ``tags_refined`` come back de-duplicated, in arbitrary order.
    """
    for per_data in labels.values():
        for per_type in per_data.values():
            for attrs in per_type.values():
                # Reassigning existing keys while iterating is safe (size is unchanged).
                for attr in attrs:
                    attrs[attr] = _split_label_tag(attrs[attr])
    for per_type in tags_refined.values():
        for attrs in per_type.values():
            for attr in attrs:
                attrs[attr] = _split_refined_tags(attrs[attr])
    return labels, tags_refined

def merge_label_files(label_file, total_worker):
    """Merge the per-worker label pickles stored next to ``label_file``.

    Worker ``i`` is read from ``<label_file without extension>_<i>.pkl``. Each
    pickle holds a dict ``{cls: {key: label, ...}, ...}``. The first non-empty
    worker dict is used as the base. The per-class dicts of later workers are
    merged into it with ``dict.update``, so later workers win on duplicate
    keys. Classes present only in later workers are added.

    If any worker file is missing, the process exits quietly (``SystemExit``
    with no code), presumably so that only the last worker to finish proceeds.

    Security: ``pickle.load`` can execute arbitrary code, so only point this
    at files written by trusted workers.
    """
    stem = os.path.splitext(label_file)[0]
    labels = {}
    for i in range(total_worker):
        worker_file = f'{stem}_{i}.pkl'
        if not os.path.exists(worker_file):
            # Not every worker has finished yet; same effect as the builtin exit().
            raise SystemExit
        with open(worker_file, 'rb') as f:
            worker_labels = pickle.load(f)
        if not labels:
            labels = worker_labels
            continue
        for cls, entries in worker_labels.items():
            labels.setdefault(cls, {}).update(entries)
    return labels
    
def merge_tag_files(tags_refined_file, total_worker):
    """Merge the per-worker refined-tag JSON files stored next to ``tags_refined_file``.

    Worker ``i`` is read from ``<tags_refined_file without extension>_<i>.json``.
    Each file holds ``{cls: {category: {attribute: [tag, ...]}}}``. The first
    non-empty worker dict is used as the base as-is. Each later worker's tag
    lists are unioned into it, giving lists in arbitrary order. Classes,
    categories and attributes that appear only in later workers are added.

    If any worker file is missing, the process exits quietly (``SystemExit``
    with no code), presumably so that only the last worker to finish proceeds.
    """
    stem = os.path.splitext(tags_refined_file)[0]
    all_new_tags = {}
    for i in range(total_worker):
        worker_file = f'{stem}_{i}.json'
        if not os.path.exists(worker_file):
            # Not every worker has finished yet; same effect as the builtin exit().
            raise SystemExit
        with open(worker_file) as f:
            worker_tags = json.load(f)
        if not all_new_tags:
            all_new_tags = worker_tags
            continue
        for cls, categories in worker_tags.items():
            merged_cls = all_new_tags.setdefault(cls, {})
            for category, attributes in categories.items():
                merged_category = merged_cls.setdefault(category, {})
                for attribute, tags in attributes.items():
                    merged_category[attribute] = list(
                        set(tags + merged_category.get(attribute, []))
                    )
    return all_new_tags