# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import os
import re
import json
import torch
import PIL.Image as Image
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
from UniDet_eval.dataset.randaugment import RandAugment



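# NOTE: post_label_process() reads the module-level tables below. While they are
# commented out, calling it on 'seg_coco', 'seg_ade', 'obj_detection' or
# 'ocr_detection' inputs may raise NameError unless they are defined elsewhere.
# COCO_FEATURES = torch.load('dataset/coco_features.pt')['features']
# ADE_FEATURES = torch.load('dataset/ade_features.pt')['features']
# DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features']
# BACKGROUND_FEATURES = torch.load('dataset/background_features.pt')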


class Transform:
    """Jointly pre-process an RGB image and its per-pixel expert label maps.

    Geometric operations (random crop, resize, horizontal flip) are applied with
    the same parameters to the image and to every label map, so they stay
    spatially aligned. The image is then converted to a tensor and normalised
    with fixed per-channel mean/std values.

    Args:
        resize_resolution: Side length in pixels of the square output image.
        scale_size: ``(min, max)`` area-fraction range handed to
            ``RandomResizedCrop.get_params`` when ``train`` is True.
        train: Enable random resized crop, random horizontal flip and
            RandAugment.
        label_resolution: Side length in pixels of the square output label
            maps. Defaults to 224, the value that used to be hard-coded.
    """

    def __init__(self, resize_resolution=384, scale_size=(0.5, 1.0), train=False, label_resolution=224):
        self.resize_size = [resize_resolution, resize_resolution]
        # Copy into a fresh list: avoids a shared mutable default and decouples
        # the instance from a list the caller might mutate later.
        self.scale_size = list(scale_size)
        self.label_size = [label_resolution, label_resolution]
        self.train = train
        self.randaugment = RandAugment(2, 5)

    def __call__(self, image, labels):
        """Transform ``image`` and, optionally, its expert ``labels``.

        Args:
            image: Input RGB image (presumably a PIL image, as produced by
                ``get_expert_labels`` -- confirm for other callers).
            labels: ``None`` or a dict mapping expert name to label image. Its
                entries are overwritten in place during cropping/resizing.

        Returns:
            ``{'rgb': image_tensor}`` plus, when ``labels`` is given, one entry
            per expert. 'depth', 'normal' and 'edge' become float tensors from
            ``to_tensor``; every other expert becomes an int64 tensor of the
            original 0-255 pixel ids.
        """
        if self.train:
            # Sample one crop window and apply it to the image and every label map.
            i, j, h, w = transforms.RandomResizedCrop.get_params(img=image, scale=self.scale_size, ratio=[3. / 4, 4. / 3])
            image = transforms_f.crop(image, i, j, h, w)
            if labels is not None:
                for exp in labels:
                    labels[exp] = transforms_f.crop(labels[exp], i, j, h, w)

        # Resize to the configured shapes. Labels use NEAREST so that discrete
        # ids are never blended into non-existent intermediate values.
        image = transforms_f.resize(image, self.resize_size, transforms_f.InterpolationMode.BICUBIC)
        if labels is not None:
            for exp in labels:
                labels[exp] = transforms_f.resize(labels[exp], self.label_size, transforms_f.InterpolationMode.NEAREST)

        if self.train:
            # Random horizontal flip, shared between image and labels.
            if torch.rand(1) > 0.5:
                image = transforms_f.hflip(image)
                if labels is not None:
                    for exp in labels:
                        labels[exp] = transforms_f.hflip(labels[exp])

            # Random augmentation.
            image, labels = self.randaugment(image, labels)

        # Convert to tensors.
        image = transforms_f.to_tensor(image)
        if labels is not None:
            for exp in labels:
                if exp in ['depth', 'normal', 'edge']:
                    labels[exp] = transforms_f.to_tensor(labels[exp])
                else:
                    # to_tensor scales uint8 pixels by 1/255. Round before the
                    # integer cast so float32 error after *255 cannot truncate
                    # an id to its lower neighbour.
                    labels[exp] = (transforms_f.to_tensor(labels[exp]) * 255).round().long()

        # Apply normalisation (CLIP image mean/std).
        image = transforms_f.normalize(image, mean=[0.48145466, 0.4578275, 0.40821073],
                                       std=[0.26862954, 0.26130258, 0.27577711])
        if labels is not None:
            return {'rgb': image, **labels}
        return {'rgb': image}


def _expert_label_path(label_path, exp, dataset, image_path, ext):
    """Return ``<label_path>/<exp>/<dataset>/<image_path without suffix><ext>``."""
    # splitext only swaps the final suffix; str.replace could also rewrite
    # matching substrings in directory names.
    stem, _ = os.path.splitext(image_path)
    return os.path.join(label_path, exp, dataset, stem + ext)


def _open_or_blank(path, mode, size, fill):
    """Open ``path`` converted to ``mode``; for a 0-byte file return a ``size`` image filled with ``fill``."""
    if os.stat(path).st_size > 0:
        return Image.open(path).convert(mode)
    return Image.new(mode, size, fill)


def get_expert_labels(data_path, label_path, image_path, dataset, experts):
    """Load an RGB image together with the requested expert label maps.

    Args:
        data_path: Root directory of the raw images.
        label_path: Root directory of the expert outputs, laid out as
            ``<label_path>/<expert>/<dataset>/<image stem>.<ext>``.
        image_path: Image path relative to ``<data_path>/<dataset>``.
        dataset: Dataset sub-directory name.
        experts: The string ``'none'`` or an iterable of expert names. Names
            other than the ones handled below are ignored.

    Returns:
        ``(image, labels, labels_info)``. ``image`` is the RGB PIL image.
        ``labels`` maps expert name to a PIL label image. ``labels_info`` holds
        per-expert metadata for 'obj_detection' (parsed JSON) and
        'ocr_detection' (``torch.load`` result or ``None``). Both are ``None``
        when ``experts == 'none'``.

    Raises:
        FileNotFoundError: If an expected label/metadata file is missing.
    """
    image_full_path = os.path.join(data_path, dataset, image_path)
    image = Image.open(image_full_path).convert('RGB')
    if experts == 'none':
        return image, None, None

    labels = {}
    labels_info = {}
    for exp in experts:
        png_path = _expert_label_path(label_path, exp, dataset, image_path, '.png')
        if exp in ['seg_coco', 'seg_ade', 'edge', 'depth']:
            # Empty label file -> all-zero single-channel placeholder.
            labels[exp] = _open_or_blank(png_path, 'L', image.size, 0)
        elif exp == 'normal':
            # Empty label file -> black RGB placeholder. (Image.fromarray on a
            # float64 HxWx3 array is rejected by Pillow, hence Image.new.)
            labels[exp] = _open_or_blank(png_path, 'RGB', image.size, (0, 0, 0))
        elif exp == 'obj_detection':
            # Empty label file -> all-255 placeholder (255 is treated as
            # background by post_label_process).
            labels[exp] = _open_or_blank(png_path, 'L', image.size, 255)
            info_path = _expert_label_path(label_path, exp, dataset, image_path, '.json')
            with open(info_path, 'r') as f:
                labels_info[exp] = json.load(f)
        elif exp == 'ocr_detection':
            info_path = _expert_label_path(label_path, exp, dataset, image_path, '.pt')
            if os.path.exists(info_path):
                labels[exp] = Image.open(png_path).convert('L')
                # NOTE(review): torch.load unpickles arbitrary objects; only use
                # on trusted, locally generated label files.
                labels_info[exp] = torch.load(info_path)
            else:
                labels[exp] = Image.new('L', image.size, 255)
                labels_info[exp] = None
    return image, labels, labels_info


def _inpaint_features(label_map, feature_of):
    """Replace every label id in ``label_map`` by a 64-dim feature column.

    The ids are read from ``label_map[0]``, and the output has shape
    ``[64, *label_map.shape[1:]]``. Id 255 gets ``BACKGROUND_FEATURES``. Any
    other id ``k`` gets ``feature_of(k)``, which is called lazily, only for ids
    that actually occur.
    """
    emb = torch.empty([64, *label_map.shape[1:]])
    ids = label_map[0]
    for idx in label_map.unique():
        region = ids == idx
        vector = BACKGROUND_FEATURES if idx == 255 else feature_of(idx)
        emb[:, region] = vector.unsqueeze(-1)
    return emb


def post_label_process(inputs, labels_info):
    """Convert transformed expert labels into model-ready tensors, in place.

    * 'depth', 'normal', 'edge': min-max rescaled to roughly [-1, 1]. A small
      epsilon guards against division by zero for constant maps, which
      therefore become all -1.
    * 'seg_coco', 'seg_ade', 'ocr_detection': each label id is in-painted with
      a 64-dim CLIP feature vector (background id 255 -> BACKGROUND_FEATURES).
    * 'obj_detection': same in-painting, returned as
      ``{'label': features, 'instance': original id map}``.
    * Any other key (e.g. 'rgb') is left untouched.

    Args:
        inputs: Dict of tensors as returned by ``Transform``. It is modified in
            place and also returned.
        labels_info: Per-expert metadata. Indexed only for 'obj_detection'
            (mapping ``str(id)`` -> row of DETECTION_FEATURES) and
            'ocr_detection' (mapping ``id`` -> dict with a 'features' entry).

    Returns:
        The same ``inputs`` dict.
    """
    eps = 1e-6
    for name in inputs:
        value = inputs[name]
        if name in ('depth', 'normal', 'edge'):
            lo, hi = value.min(), value.max()
            inputs[name] = 2 * (value - lo) / (hi - lo + eps) - 1
        elif name == 'seg_coco':
            inputs[name] = _inpaint_features(value, lambda k: COCO_FEATURES[k])
        elif name == 'seg_ade':
            inputs[name] = _inpaint_features(value, lambda k: ADE_FEATURES[k])
        elif name == 'obj_detection':
            mapping = labels_info[name]
            emb = _inpaint_features(value, lambda k, m=mapping: DETECTION_FEATURES[m[str(k.item())]])
            inputs[name] = {'label': emb, 'instance': value}
        elif name == 'ocr_detection':
            mapping = labels_info[name]
            inputs[name] = _inpaint_features(value, lambda k, m=mapping: m[k.item()]['features'])
    return inputs


def pre_caption(caption, max_words=50):
    caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.capitalize())  # remove special characters
    caption = re.sub(r"\s{2,}", ' ', caption)  # remove two white spaces

    caption = caption.rstrip('\n')  # remove \num_ans_per_q symbol
    caption = caption.strip(' ')    # remove leading and trailing white spaces

    # truncate caption to the max words
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])
    return caption


def pre_question(question, max_words=50):
    """Normalise a question string and make sure it ends with '?'.

    Steps: capitalise (first character upper, rest lower), replace each of
    ``. ! " ( ) * # : ; ~`` with a space, strip surrounding whitespace, keep at
    most ``max_words`` space-separated tokens, and append '?' if missing.

    Args:
        question: Raw question text.
        max_words: Maximum number of space-separated tokens to keep.

    Returns:
        The cleaned question. An input that is empty after cleaning yields '?'.
    """
    question = re.sub(r"([.!\"()*#:;~])", ' ', question.capitalize())  # remove special characters
    question = question.strip()

    # Truncate the question.
    question_words = question.split(' ')
    if len(question_words) > max_words:
        question = ' '.join(question_words[:max_words])
    # endswith() instead of question[-1]: the latter raises IndexError when the
    # cleaned question is empty (e.g. input "" or "...").
    if not question.endswith('?'):
        question += '?'
    return question

