import numpy as np
import logging
import skimage as io
from tqdm import tqdm
import matplotlib as mpl
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import pylab
import urllib
from io import BytesIO
import requests as req
from PIL import Image
import os
import json
from sklearn.cluster import MiniBatchKMeans
import numpy as np
import pickle

logger = logging.getLogger(__name__)

class sg2sentence():
    def __init__(self, ob_json_path, rel_json_path, img_json_path):
        """
        Load the three annotation files and check that they are aligned.

        args
            ob_json_path: path to the object.json
            rel_json_path: path to the relationship.json
            img_json_path: path to the per-image json; its entries are read for
                'image_id' here and for 'width' / 'height' in parse_Sg2List.
            (vg file can be downloaded here https://cs.stanford.edu/~danfei/scene-graph/?fbclid=IwAR1tQPF0aGiWuP0kZCDrUaWnvefWAWJXefas6BWptCKVldj-Ra9psIZnJJQ)
        Raises
            AssertionError (from _json_id_checker) when entry i of the image json
            and entry i of the object / relationship json disagree on 'image_id'.
        """
        self.ob_json = self._load_json(ob_json_path)
        self.rel_json = self._load_json(rel_json_path)
        self.img_json = self._load_json(img_json_path)
        # All three files are indexed by position later on, so verify the
        # alignment once, up front.
        self._json_id_checker(self.img_json, self.ob_json)
        self._json_id_checker(self.img_json, self.rel_json)
        # Filled in by parse_Sg2List(); created here so that the getters can be
        # called (and warn) before any parsing has happened.
        self.cate_list = []
        self.pred_list = []
        self.sorted_cate_dict = {}
        self.sorted_pred_dict = {}
        self.cate_dict = {}
        # Maximum sequence length used when padding / filtering annotations.
        self.threshold = 128
            
    def _load_json(self, json_path):
        """
        Read one JSON file and return its decoded content.

        args
            json_path: path of the JSON file to read.
        Return
            the object produced by json.load for that file.
        """
        logger.info("[Loading] Json from {}.".format(json_path))
        # JSON text is UTF-8; do not rely on the platform's locale default,
        # which differs between systems.
        with open(json_path, encoding='utf-8') as f:
            return json.load(f)
    
    def _json_id_checker(self, x_json, y_json):
        """
        Check that x_json and y_json describe the same images in the same order.

        args
            x_json: sequence of dicts with an 'image_id' key (drives the loop).
            y_json: sequence of dicts with an 'image_id' key, indexed by the
                positions of x_json.
        Raises
            AssertionError: entry i of the two inputs has a different 'image_id'.
            IndexError: y_json is shorter than x_json.
        """
        for i, x_entry in enumerate(x_json):
            # Raised explicitly (not via `assert`) so the check is still
            # performed when Python runs with -O.
            if not (x_entry['image_id'] == y_json[i]['image_id']):
                raise AssertionError("ERROR! Wrong Image ID matched!")

    def _creat_size_shape_cluster(self, list_result, n_size, n_shape):
        """
        Cluster every box found in list_result by area and by (w, h).

        args
            list_result: list of dicts with a 'rel_pairs' list; each pair provides
                'object_box' / 'subject_box' (elements 2 and 3 are used as w, h)
                and 'object_id' / 'subject_id'.
            n_size: number of clusters for the box area (w * h).
            n_shape: number of clusters for the (w, h) pair.
        Return
            dict(). {object id: [size cluster label, shape cluster label]}.
            An id that occurs several times keeps the labels of its last occurrence.
        """
        areas, dims, ids = [], [], []
        for image_entry in list_result:
            for rel in image_entry['rel_pairs']:
                obj_box, sub_box = rel['object_box'], rel['subject_box']
                # object first, subject second - the three lists stay in lockstep
                areas.extend((obj_box[2]*obj_box[3], sub_box[2]*sub_box[3]))
                dims.extend(([obj_box[2], obj_box[3]], [sub_box[2], sub_box[3]]))
                ids.extend((rel['object_id'], rel['subject_id']))

        size_model = MiniBatchKMeans(n_clusters=n_size, random_state=0)
        shape_model = MiniBatchKMeans(n_clusters=n_shape, random_state=0)
        size_model = size_model.fit(np.array(areas).reshape(-1,1))
        shape_model = shape_model.fit(np.array(dims).reshape(-1,2))
        size_labels, shape_labels = size_model.labels_, shape_model.labels_
        assert len(size_labels) == len(ids)
        assert len(shape_labels) == len(ids)

        return {
            box_id: [size_label, shape_label]
            for box_id, size_label, shape_label in zip(ids, size_labels, shape_labels)
        }
    
    def _parse_ob_json(self, ob_json, names_idx = 0):
        """
        Collect one name per annotated object into self.cate_list.

        args
            ob_json: list of per-image dicts, each with an 'objects' list whose
                items carry a 'names' list.
            names_idx: which entry of every 'names' list to take (default: the
                first one).
        Raises
            IndexError: an object has no name at position names_idx.
        """
        self.cate_list = [
            ob['names'][names_idx]
            for image_entry in ob_json
            for ob in image_entry['objects']
        ]
    
    def _parse_rel_json(self, rel_json, names_idx = 0):
        """
        Collect the predicate of every relationship into self.pred_list.

        args
            rel_json: list of per-image dicts, each with a 'relationships' list
                whose items carry a 'predicate'.
            names_idx: accepted for signature symmetry with _parse_ob_json; it is
                not used here.
        """
        self.pred_list = [
            relationship['predicate']
            for image_entry in rel_json
            for relationship in image_entry['relationships']
        ]
    
    def _simple_top_k(self, input_list, k):
        """
        Count the items of input_list and keep the k most frequent ones.

        args
            input_list: iterable of hashable items (e.g. category names).
            k: number of entries to keep; used as the end of a slice.
        Return
            dict(). ex. {'trees': 102310, 'man': 99394, ....} (sorted from high to low;
            items with the same count stay in first-seen order).
        """
        logger.info("Start sorting and sampling top-{} from input_list.".format(k))
        occurrences = {}
        for item in input_list:
            occurrences[item] = occurrences.get(item, 0) + 1
        # sorted() is stable, so equal counts keep their insertion order
        ranked = sorted(occurrences.items(), key=lambda pair: pair[1], reverse=True)
        return dict(ranked[:k])
#         self.sorted_cate_dict = sorted_vg_cate_dict
        
    def parse_Sg2List(self, names_idx = 0, top_num_cate = 150, top_num_pred = 50):
        """
        Transform sgs into lists.

        Side effects: refreshes self.cate_list, self.pred_list,
        self.sorted_cate_dict, self.sorted_pred_dict and self.cate_dict, and
        prints the retained categories / predicates.

        args
            names_idx: forwarded to _parse_ob_json / _parse_rel_json.
            top_num_cate: number of most frequent object names to keep.
            top_num_pred: number of most frequent predicates to keep (before the
                user-defined removals below).
        Return
            (list) [dict(), dict(), dict(), ...]
            <Note>
            Each dictionary has keys() = ['rel_pairs', 'image_id', 'image_w',
            'image_h'] and represents the annotation of one image. Images without
            any retained relationship are left out.
            dict['rel_pairs'] is a list of dict() with keys 'object', 'predicate',
            'subject', 'object_id', 'subject_id', 'object_box', 'subject_box'.
            Ex. dict['rel_pairs'] = [{'object':'car', 'predicate':'ON', 'subject':'cat', ...},
            {'object':'dog', 'predicate':'Next to', 'subject':'tree', ...}, ...]
            box = [x, y, w, h], x and w divided by the image width, y and h by
            the image height.
            Only the first relationship of every (object_id, subject_id) pair is kept.
        """
        self._parse_ob_json(self.ob_json, names_idx = names_idx)
        self._parse_rel_json(self.rel_json, names_idx = names_idx)
        self.sorted_cate_dict = self._simple_top_k(self.cate_list, k = top_num_cate)
        self.sorted_pred_dict = self._simple_top_k(self.pred_list, k = top_num_pred)
        #### user defined remove predicate
        # pop(..., None) instead of `del`: a listed predicate that did not make
        # the top-k (other data, smaller top_num_pred) must not abort the run.
        for unwanted in ('OF', 'of a', 'for', 'of', 'playing'):
            self.sorted_pred_dict.pop(unwanted, None)
        cate_list = list(self.sorted_cate_dict.keys())
        pred_list = list(self.sorted_pred_dict.keys())

        print("Sample categories: ",cate_list)
        print("Sample predicates: ",pred_list)
        self.cate_dict = {name: i for i, name in enumerate(cate_list)}

        # Sets give O(1) membership tests inside the per-relationship loop.
        kept_cates = set(cate_list)
        kept_preds = set(pred_list)

        def norm_box(entity, image_w, image_h):
            # x,y,w,h scaled by the image size
            return [entity['x']/image_w, entity['y']/image_h,
                    entity['w']/image_w, entity['h']/image_h]

        result = []
        for idx, rel_dict in enumerate(tqdm(self.rel_json)):
            image_w, image_h = self.img_json[idx]['width'], self.img_json[idx]['height']
            assert rel_dict['image_id'] == self.img_json[idx]['image_id']
            seen_pairs = set()
            rel_list_single_image = []
            for rel in rel_dict['relationships']:
                object_id_pair = (rel['object']['object_id'], rel['subject']['object_id'])
                if object_id_pair in seen_pairs \
                or rel['object']['name'] not in kept_cates \
                or rel['subject']['name'] not in kept_cates \
                or rel['predicate'] not in kept_preds:
                    continue
                seen_pairs.add(object_id_pair)
                rel_list_single_image.append({
                    'object': rel['object']['name'],
                    'predicate': rel['predicate'],
                    'subject': rel['subject']['name'],
                    'object_id': rel['object']['object_id'],
                    'subject_id': rel['subject']['object_id'],
                    'object_box': norm_box(rel['object'], image_w, image_h),
                    'subject_box': norm_box(rel['subject'], image_w, image_h),
                })
            if rel_list_single_image:
                ## Do not import image without any annotation.
                result.append({
                    'rel_pairs': rel_list_single_image,
                    'image_id': rel_dict['image_id'],
                    'image_w': image_w,
                    'image_h': image_h,
                })
        return result
            
    def parse_Sg2Sent(self, names_idx = 0, top_num_cate = 150, top_num_pred = 50, n_size = 8, n_shape = 8):
        """
        Transform sgs into sentences plus per-object ground truth.

        args
            names_idx, top_num_cate, top_num_pred: forwarded to parse_Sg2List.
            n_size, n_shape: cluster counts forwarded to _creat_size_shape_cluster
                and map_box_shape.
        Return
            (list) [dict(), dict(), dict(), ...]
            Each dict() is an entry produced by parse_Sg2List with 'rel_pairs'
            removed and these keys added:
            'sentence'          (String) "[CLS] sub rel obj [SEP] sub rel obj [SEP]"
            'gt_obj_id_by_sent' ids in sentence order (subject, object per relation)
            'gt_obj_id'         every id once, in first-seen order (the object of a
                                relation is looked at before its subject)
            'gt_classes', 'gt_boxes_center', 'gt_boxes_shape', 'gt_boxes_xy'
                                one value per entry of 'gt_obj_id'
        """
        START_TOKEN = '[CLS]'
        SEPERATE_TOKEN = '[SEP]'
        result = self.parse_Sg2List(names_idx = names_idx, top_num_cate = top_num_cate, top_num_pred = top_num_pred)
        shape_dict = self._creat_size_shape_cluster(result, n_size=n_size, n_shape=n_shape)
        for image_entry in tqdm(result):
            sentence_parts = [START_TOKEN]
            ids_in_sentence_order = []
            unique_ids = []
            classes, centers, shapes, boxes = [], [], [], []
            for rel in image_entry['rel_pairs']:
                sentence_parts.append(
                    ' ' + rel['subject'] + ' ' + rel['predicate'] + ' ' +
                    rel['object'] + ' ' + SEPERATE_TOKEN)
                ids_in_sentence_order.append(rel['subject_id'])
                ids_in_sentence_order.append(rel['object_id'])
                # Register each box once per image; object before subject.
                for role in ('object', 'subject'):
                    box_id = rel[role + '_id']
                    if box_id in unique_ids:
                        continue
                    unique_ids.append(box_id)
                    box = rel[role + '_box']
                    classes.append(self.map_classes(rel[role]))
                    centers.append(self.map_box_center(box))
                    shapes.append(self.map_box_shape(box_id, shape_dict, n_size, n_shape))
                    boxes.append(box)

            image_entry['gt_classes'] = classes
            image_entry['gt_boxes_center'] = centers
            image_entry['gt_boxes_shape'] = shapes
            image_entry['gt_boxes_xy'] = boxes
            image_entry['sentence'] = ''.join(sentence_parts)
            image_entry['gt_obj_id_by_sent'] = ids_in_sentence_order
            image_entry['gt_obj_id'] = unique_ids
            del image_entry['rel_pairs']
        return result

    def parse_Sg2Idx(self, names_idx = 0, top_num_cate = 150, top_num_pred = 50):
        """
        Transform sgs into padded index sequences.

        Runs parse_Sg2List and parse_Sg2Sent, keeps the images whose token
        sequence has at most self.threshold entries, and writes two vocabulary
        pickles through save_word_dict ('./data/rel_dict_45.pkl' for the sentence
        words, './data/cls_dict_45.pkl' for the category list).

        args
            names_idx, top_num_cate, top_num_pred: forwarded to parse_Sg2List and
                parse_Sg2Sent.
        Return
            dict() with one entry per kept image under each key:
            'rel'      object array; each element is a 1-D array of word indices,
                       ex. "1 4 7 8 2 4 6 5 2 0 0 0", 0-padded to self.threshold
            'rel_box'  per-token boxes [x, y, w, h]; [2,2,2,2] for special / pad
                       tokens, the union of subject and object box for a predicate
            'id'       object array; each element is a 1-D array with the
                       1-based rank (after sorting by center) of the box behind
                       every subject / object token, 0 elsewhere
            'cls', 'pos', 'shape', 'box_xy'  outputs of parse_img_anns /
                       box_pad_and_cut
            'image_id', 'image_wh'
        """
        # Create Sentence list
        START_TOKEN = '[CLS]'
        SEPERATE_TOKEN = '[SEP]'
        all_anns = {}
        word_anns = []
        pos_anns = []
        shape_anns = []
        cls_anns = []
        boxes_per_word_anns = []
        boxes_anns = []
        id_anns = []
        image_id_anns = []
        image_wh_anns = []
        word_set = set()
        result = self.parse_Sg2List(names_idx = names_idx, top_num_cate = top_num_cate, top_num_pred = top_num_pred)
        sent_result = self.parse_Sg2Sent(names_idx = names_idx, top_num_cate = top_num_cate, top_num_pred = top_num_pred)
        for idx, rel_dict in enumerate(tqdm(result)):
            assert rel_dict['image_id'] == sent_result[idx]['image_id'], 'Rel and Img idx different!' 
            single_ann = [START_TOKEN]
            single_box_ann = [[2.,2.,2.,2.]]
            for rel in rel_dict['rel_pairs']:
                sub_box, obj_box = rel['subject_box'], rel['object_box']
                single_ann.append(rel['subject'])
                single_ann.append(rel['predicate'])
                single_ann.append(rel['object'])
                # x,y,w,h
                single_box_ann.append(sub_box)
                # The predicate gets the smallest box enclosing subject and object.
                min_x = min(sub_box[0], obj_box[0])
                min_y = min(sub_box[1], obj_box[1])
                max_x = max(sub_box[0]+sub_box[2], obj_box[0]+obj_box[2])
                max_y = max(sub_box[1]+sub_box[3], obj_box[1]+obj_box[3])
                single_box_ann.append([min_x, min_y, max_x - min_x, max_y - min_y])
                single_box_ann.append(obj_box)
                single_ann.append(SEPERATE_TOKEN)
                single_box_ann.append([2.,2.,2.,2.])
                word_set.add(rel['subject'])
                word_set.add(rel['predicate'])
                word_set.add(rel['object'])
            
            center_pos, cats_id, shape_centroid, boxes_xywh, obj_id_by_sent = \
                        self.parse_img_anns(sent_result[idx])
            boxes_xywh = self.box_pad_and_cut(boxes_xywh)
            single_id_ann = self.id_pad_and_cut(obj_id_by_sent)
            assert len(single_id_ann) == len(single_ann)
            
            if len(single_ann) <= self.threshold and len(single_ann) > 1:
                word_anns.append(single_ann)
                boxes_per_word_anns.append(single_box_ann)
                image_id_anns.append(rel_dict['image_id'])
                image_wh_anns.append([rel_dict['image_w'], rel_dict['image_h']])
                cls_anns.append(cats_id)
                pos_anns.append(center_pos)
                shape_anns.append(shape_centroid)
                boxes_anns.append(boxes_xywh)
                id_anns.append(single_id_ann)
        # Create Dict
        w2i = self.save_word_dict(word_set, './data/rel_dict_45.pkl')
        _ = self.save_word_dict(self.get_cate_list(), './data/cls_dict_45.pkl')

        # Tokenize & Padding
        # The sentences have different lengths. np.array() on such ragged input
        # raises on recent NumPy (and yields a 2-D array that cannot take the
        # padded rows when all lengths happen to be equal), so the object
        # arrays are allocated explicitly and filled one padded row at a time.
        n_kept = len(word_anns)
        padded_word_anns = np.empty(n_kept, dtype=object)
        padded_id_anns = np.empty(n_kept, dtype=object)
        for i in range(n_kept):
            token_ids = [w2i[token] for token in word_anns[i]]
            padded_word_anns[i] = np.pad(token_ids, (0, self.threshold - len(token_ids)), 'constant', constant_values = 0)
            padded_id_anns[i] = np.pad(id_anns[i], (0, self.threshold - len(id_anns[i])), 'constant', constant_values = 0)
            for _ in range(self.threshold - len(boxes_per_word_anns[i])):
                boxes_per_word_anns[i].append([2.,2.,2.,2.])

        all_anns['rel'] = padded_word_anns
        all_anns['rel_box'] = boxes_per_word_anns
        all_anns['id'] = padded_id_anns
        all_anns['cls'] = cls_anns
        all_anns['pos'] = pos_anns
        all_anns['shape'] = shape_anns
        all_anns['box_xy'] = boxes_anns
        all_anns['image_id'] = image_id_anns
        all_anns['image_wh'] = image_wh_anns
        return all_anns

    def save_word_dict(self, word_set, fn):
        """
        Build a vocabulary, pickle its index -> word table and return word -> index.

        Indices 0..3 are reserved for '[PAD]', '[CLS]', '[SEP]', '[MASK]'; every
        other word gets the next free index starting at 4.

        args
            word_set: iterable of words. A set / frozenset is sorted first so the
                indices are the same on every run (set iteration order of strings
                changes between interpreter runs); any other iterable keeps its
                own order.
            fn: path of the pickle file that receives the index -> word dict.
        Return
            dict(). word -> index.
        """
        w2i = {'[PAD]': 0, '[CLS]': 1, '[SEP]': 2, '[MASK]': 3}
        i2w = {index: word for word, index in w2i.items()}

        if isinstance(word_set, (set, frozenset)):
            # key=str keeps plain string sorting and still works for mixed items
            words = sorted(word_set, key=str)
        else:
            words = word_set

        idx = len(w2i)
        for w in words:
            if w not in w2i:
                w2i[w] = idx
                i2w[idx] = w
                idx += 1

        with open(fn, 'wb') as file:
            pickle.dump(i2w, file)

        return w2i


    def parse_img_anns(self, sent_result):
        """
        Sort the objects of one image by center index and frame the sequences.

        args
            sent_result: dict with 'gt_boxes_center', 'gt_classes',
                'gt_boxes_shape', 'gt_boxes_xy', 'gt_obj_id_by_sent', 'gt_obj_id'.
        Return
            (center_pos, cats_id, shape_centroid, boxes_xywh, obj_id_by_sent)
            The first three are sorted and passed through pad_and_cut (BOS, EOS,
            PAD added); boxes_xywh is only sorted; obj_id_by_sent holds, for
            every id in sentence order, its 1-based position in the sorted
            object list.
        """
        centers = np.array(sent_result['gt_boxes_center'])
        classes = np.array(sent_result['gt_classes'])
        shapes = np.array(sent_result['gt_boxes_shape'])
        boxes = np.array(sent_result['gt_boxes_xy'])
        ids_by_sent = np.array(sent_result['gt_obj_id_by_sent'])
        ids = np.array(sent_result['gt_obj_id'])

        # One ordering, shared by every per-object sequence.
        order = np.argsort(centers)

        ## parse obj_id and re-index
        ranked_ids = list(ids[order])
        for position, raw_id in enumerate(ids_by_sent):
            ids_by_sent[position] = ranked_ids.index(raw_id) + 1

        return (
            self.pad_and_cut(centers, order),
            self.pad_and_cut(classes, order),
            self.pad_and_cut(shapes, order),
            boxes[order],
            ids_by_sent,
        )

    def id_pad_and_cut(self, obj_id_by_sent):
        """
        Interleave the ids with 0 so they line up with the sentence tokens.

        Return
            (list) [0, id_0, 0, id_1, 0, ...] - a leading 0, then every id
            followed by a 0.
        """
        interleaved = [0]
        for obj_id in obj_id_by_sent:
            interleaved.extend((obj_id, 0))
        return interleaved
        
        
    def pad_and_cut(self, arr, idx):
        """
        obj (array) -> bos obj eos pad (array)

        Reorders arr by idx, puts 1 in front and 2 behind it, then fills with 0
        up to self.threshold. A sequence that is already that long (or longer)
        is returned unpadded; nothing is cut off.
        """
        framed = np.append(np.insert(arr[idx], 0, 1), [2])
        missing = self.threshold - len(framed)
        if missing <= 0:
            return framed
        return np.pad(framed, (0, missing), 'constant', constant_values = 0)
    
    def box_pad_and_cut(self, arr):
        """
        boxes (n x 4) -> bos boxes eos pad (array with 4 columns)

        Puts a [1,1,1,1] row in front and a [2,2,2,2] row behind the boxes, then
        fills with [0,0,0,0] rows up to self.threshold rows. Input that is
        already long enough is returned without padding; nothing is cut off.

        args
            arr: array-like of [x, y, w, h] rows; may be empty.
        """
        boxes = np.asarray(arr)
        if boxes.size == 0:
            # An empty input has no column axis yet; give it one so the marker
            # rows can be attached.
            boxes = boxes.reshape(0, 4)
        boxes = np.insert(boxes, 0, [1,1,1,1], 0)
        boxes = np.append(boxes, [[2,2,2,2]], 0)
        missing = self.threshold - len(boxes)
        if missing > 0:
            # One allocation for all pad rows instead of re-copying the array
            # once per appended row.
            padding = np.zeros((missing, 4), dtype=boxes.dtype)
            boxes = np.concatenate((boxes, padding), axis=0)
        return boxes


    def map_classes(self, class_name):
        """
        Map a category name to its class index.

        The index is the name's value in self.cate_dict shifted by 4, because
        the first four indices are reserved:
        0 for PAD
        1 for BOS
        2 for EOS
        3 for MASK
        """
        reserved = 4
        return reserved + self.cate_dict[class_name]
#         return list(self.sorted_cate_dict.keys()).index(class_name) + 1
    
    def map_box_center(self, box, map_w = 8, map_h = 8):
        """
        Map the center of a normalized box to a cell index of a map_w x map_h grid.

        Indices below 4 are reserved:
        0 for PAD
        1 for BOS
        2 for EOS
        3 for MASK

        args
            box: [x, y, w, h]; the grid covers 0..1 on both axes.
            map_w, map_h: number of grid columns / rows.
        Return
            (int) column + row * map_w + 4. Centers that fall outside the grid on
            any side are clamped to the nearest border cell, so the result is
            always within 4 .. map_w * map_h + 3.
        """
        x_center = box[0] + box[2]/2.
        y_center = box[1] + box[3]/2.
        # Scaling by the grid size (instead of floor-dividing by 1/size) avoids
        # landing one cell too low when 1/size is not exactly representable.
        # int() truncates toward zero, which matches floor for everything that
        # survives the clamp at 0.
        idx_x = min(max(int(x_center * map_w), 0), map_w - 1)
        idx_y = min(max(int(y_center * map_h), 0), map_h - 1)
        return int(idx_x + idx_y * map_w + 4)
    
    def map_box_shape(self, obj_id, shape_dict, n_size, n_shape):
        """
        Combine the size and shape cluster labels of one box into a single index.

        Indices below 4 are reserved:
        0 for PAD
        1 for BOS
        2 for EOS
        3 for MASK

        args
            obj_id: key into shape_dict.
            shape_dict: {id: [size label, shape label]}.
            n_size: number of size clusters (the stride of the shape label).
            n_shape: not used in the computation.
        Return
            (int) size label + shape label * n_size + 4
        """
        labels = shape_dict[obj_id]
        size_label = int(labels[0])
        shape_label = int(labels[1])
        return int(4 + size_label + n_size * shape_label)
    
    def get_cate_list(self):
        """Return the retained category names as a new list (warns when it is empty)."""
        categories = list(self.sorted_cate_dict.keys())
        if not categories:
            logger.warning("Please construct the cate dict first.")
        return categories
    
    def get_pred_list(self):
        """
        Return the retained predicates as a new list.

        Logs a warning and returns an empty list when no predicate dict has been
        built yet, including the case where the attribute does not exist because
        nothing was parsed.
        """
        predicates = list(getattr(self, 'sorted_pred_dict', {}).keys())
        if not predicates:
            logger.warning("Please construct the pred dict first.")
        return predicates
    
        
if __name__ == '__main__':
    # Script entry: parse the scene-graph files under ./data/vg/scene_graph and
    # either export the annotations as pickles or print one sample.
    ob_json_path = './data/vg/scene_graph/objects.json'
    rel_json_path = './data/vg/scene_graph/relationships.json'
    img_json_path = './data/vg/scene_graph/image_data.json'
    VG_parser = sg2sentence(ob_json_path, rel_json_path, img_json_path)

    # True: write the pickles. False: print a sample for manual inspection.
    export_pickles = True
    if export_pickles:
        sent_ann = VG_parser.parse_Sg2Sent()
        with open('./data/vg_sents_45.pkl', 'wb+') as file:
            pickle.dump(sent_ann, file)

        all_anns = VG_parser.parse_Sg2Idx()
        with open('./data/vg_anns_45.pkl', 'wb+') as file:
            pickle.dump(all_anns, file)
    else:
        list_result = VG_parser.parse_Sg2List(names_idx = 0, top_num_cate = 150, top_num_pred = 50)
        sent_result = VG_parser.parse_Sg2Sent(names_idx = 0, top_num_cate = 150, top_num_pred = 50)
        separator = "="*20
        print(separator)
        print("Instance in Image_id = 0")
        print(separator)
        print(list_result[0])
        print(separator)
        print(sent_result[0])
        print(separator)
        print(VG_parser.get_cate_list())
        print(separator)
        print(VG_parser.get_pred_list())
    