from __future__ import print_function
import sys
sys.path.append("..")
sys.path.append(".")
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import copy
from tqdm import tqdm
import json
from helpers.psutil import FreeMemLinux
from helpers.util import normalize_box_params
from omegaconf import OmegaConf
import random
import pickle
import trimesh
import h5py
import open3d as o3d


def load_ckpt(ckpt):
    """Return a checkpoint object, reading it from disk when given a path.

    :param ckpt: either a filesystem path (``str`` or ``os.PathLike``) to a
        file readable by ``torch.load``, or an already-loaded object.
    :return: the object deserialized by ``torch.load`` when ``ckpt`` is a path;
        otherwise ``ckpt`` itself, unchanged.
    """
    # isinstance (rather than ``type(ckpt) == str``) also accepts str subclasses
    # and pathlib paths, which would otherwise be handed back as if they were
    # the state dict itself.
    if isinstance(ckpt, (str, os.PathLike)):
        # The identity map_location returns every storage as it was
        # deserialized instead of moving it to the device it was saved from.
        return torch.load(ckpt, map_location=lambda storage, loc: storage)
    return ckpt

class ThreedFrontDatasetSceneGraph(data.Dataset):
    def __init__(self, root, root_3dfront='', split='train', shuffle_objs=False, pass_scan_id=False,
                 use_scene_rels=False, data_len=None, eval=False,
                 eval_type='addition', with_feats=False,
                 seed=True, large=False, recompute_feats=False,
                 room_type='bedroom'):
        """Load vocabularies, relationships and boxes of one room type.

        :param root: dataset directory holding ``classes_<room_type>.txt``,
            ``relationships.txt``, ``mapping.json`` and the per-split json files
        :param root_3dfront: stored as ``self.root_3dfront``; when empty,
            ``<root>/visualization`` is used and created if missing
        :param split: ``'train_scans'`` selects the ``*_trainval`` files; ANY
            other value (including the default ``'train'``) selects the
            ``*_test`` files
        :param shuffle_objs: shuffle the object order of every sample
        :param pass_scan_id: stored as ``self.pass_scan_id``
        :param use_scene_rels: add a scene node and 'in' edges to every sample
        :param data_len: optional length reported by ``__len__``
        :param eval: together with ``seed``, seeds numpy / torch / random with 47
        :param eval_type: stored as ``self.eval_type``
        :param with_feats: also return shape latent codes in every sample
        :param seed: see ``eval``
        :param large: keep the fine-grained class names instead of mapping them
            through ``mapping.json`` to the coarse ones
        :param recompute_feats: stored, but reset to ``False`` when
            ``with_feats`` is set
        :param room_type: selects the class list and the json files
        """
        self.room_type = room_type
        self.seed = seed
        self.with_feats = with_feats
        self.cond_model = None
        self.large = large
        self.recompute_feats = recompute_feats

        if eval and seed:
            np.random.seed(47)
            torch.manual_seed(47)
            random.seed(47)

        self.root = root
        # list of class categories
        self.catfile = os.path.join(self.root, 'classes_{}.txt'.format(self.room_type))
        self.cat = {}
        self.scans = []
        self.obj_paths = []
        self.data_len = data_len
        self.use_scene_rels = use_scene_rels

        self.fm = FreeMemLinux('GB')
        self.vocab = {}
        with open(self.catfile, "r") as f:
            self.vocab['object_idx_to_name'] = f.readlines()

        relationships_file = os.path.join(self.root, 'relationships.txt')
        with open(relationships_file, "r") as f:
            # predicate 0 is the extra 'in' relation; the file entries follow it
            self.vocab['pred_idx_to_name'] = ['in\n']
            self.vocab['pred_idx_to_name'] += f.readlines()

        # list of relationship categories, numbered from 1
        self.relationships = self.read_relationships(relationships_file)
        self.relationships_dict = dict(zip(self.relationships, range(1, len(self.relationships) + 1)))
        self.relationships_dict_r = dict(zip(self.relationships_dict.values(), self.relationships_dict.keys()))

        if split == 'train_scans':  # training set
            print('loaded training data')
            self.rel_json_file = os.path.join(self.root, 'relationships_{}_trainval.json'.format(self.room_type))
            self.box_json_file = os.path.join(self.root, 'obj_boxes_{}_trainval.json'.format(self.room_type))
            self.box_normalized_stats = os.path.join(self.root, 'boxes_centered_stats_{}_trainval.txt'.format(self.room_type))
        else:  # test set
            print('loaded testing data')
            self.rel_json_file = os.path.join(self.root, 'relationships_{}_test.json'.format(self.room_type))
            self.box_json_file = os.path.join(self.root, 'obj_boxes_{}_test.json'.format(self.room_type))
            self.box_normalized_stats = os.path.join(self.root, 'boxes_centered_stats_{}_test.txt'.format(self.room_type))

        self.relationship_json, self.objs_json, self.tight_boxes_json = \
            self.read_relationship_json(self.rel_json_file, self.box_json_file)

        # Re-root every non-empty model path onto `root`. The first 36
        # characters of the stored path are dropped -- presumably the absolute
        # prefix of the machine that produced the json files; TODO confirm.
        for infos in self.tight_boxes_json.values():
            for info in infos.values():
                if 'model_path' in info and info['model_path']:
                    info['model_path'] = root + info['model_path'][36:]

        self.padding = 0.2
        self.eval = eval
        self.pass_scan_id = pass_scan_id
        self.shuffle_objs = shuffle_objs
        self.root_3dfront = root_3dfront
        if self.root_3dfront == '':
            self.root_3dfront = os.path.join(self.root, 'visualization')
            if not os.path.exists(self.root_3dfront):
                os.makedirs(self.root_3dfront)

        # context manager so the file handle is closed right after parsing
        with open(os.path.join(self.root, "mapping.json"), "r") as f:
            self.mapping_full2simple = json.load(f)

        with open(self.catfile, 'r') as f:
            for line in f:
                category = line.rstrip()
                self.cat[category] = category

        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        self.classes_r = dict(zip(self.classes.values(), self.classes.keys()))

        points_classes = list(self.classes.keys())
        points_classes.remove('_scene_')

        # Why keep the fine-grained names? Mapping objects into coarse
        # categories makes the number of objects per category very uneven
        # (e.g. chairs are the most frequent and lamps the rarest). Sampling a
        # batch over the fine-grained classes alleviates this imbalance during
        # diffusion training.
        self.vocab['object_idx_to_name_grained'] = self.vocab['object_idx_to_name']

        if not self.large:
            fine_names = [voc.strip('\n') for voc in self.vocab['object_idx_to_name']]
            self.fine_grained_classes = dict(zip(sorted(fine_names), range(len(fine_names))))
            self.vocab['object_idx_to_name'] = [self.mapping_full2simple[name] + '\n' for name in fine_names]
            coarse_names = sorted({voc.strip('\n') for voc in self.vocab['object_idx_to_name']})
            self.classes = dict(zip(coarse_names, range(len(coarse_names))))
            self.classes_r = dict(zip(self.classes.values(), self.classes.keys()))
            points_classes = list(set([self.mapping_full2simple[class_] for class_ in points_classes]))

        points_classes_idx = [self.classes[pc] for pc in points_classes]

        self.point_classes_idx = points_classes_idx + [0]
        self.sorted_cat_list = sorted(self.cat)
        self.files = {}  # per-scan cache of loaded points, filled by __getitem__
        self.eval_type = eval_type
        if with_feats:
            print('Assume you downloaded the DeepSDF codes. If not, please download in README.md')
            # the precomputed codes are used as they are, never recomputed
            self.recompute_feats = False


    def read_relationship_json(self, json_file, box_json_file):
        """ Reads from json files the relationship labels, objects and bounding boxes

        Side effects: every scan id is appended to ``self.scans``, and the third
        entry of each relationship is decremented by one in place.

        :param json_file: file that stores the objects and relationships
        :param box_json_file: file that stores the oriented 3D bounding box parameters
        :return: three dicts keyed by scan id: relationships, objects
            (instance id -> label) and boxes (instance id -> dict with the
            available 'param7' / 'scale' / 'model_path' entries, plus a
            'scene_center' entry per scan)
        """
        rel = {}
        objs = {}
        tight_boxes = {}

        with open(box_json_file, "r") as read_file:
            box_data = json.load(read_file)

        # named `scene_graphs` so it does not shadow the module alias `data`
        with open(json_file, "r") as read_file:
            scene_graphs = json.load(read_file)

        for scan in scene_graphs['scans']:
            scan_id = scan["scan"]

            relationships = []
            for relationship in scan["relationships"]:
                relationship[2] -= 1
                relationships.append(relationship)

            # for every scan in rel json, we append the scan id
            rel[scan_id] = relationships
            self.scans.append(scan_id)

            objects = {}
            boxes = {}
            for k, v in scan["objects"].items():
                instance_id = int(k)
                objects[instance_id] = v

                # The entry exists even when the box file has nothing for this
                # instance, so later lookups find an (empty) dict.
                boxes[instance_id] = {}
                try:
                    boxes[instance_id]['param7'] = box_data[scan_id][k]["param7"]
                    boxes[instance_id]['scale'] = box_data[scan_id][k]["scale"]
                except (KeyError, TypeError) as e:
                    # probably box was not saved because there were 0 points in the instance!
                    print(e)
                # 'model_path' is read separately so that it is still picked up
                # when one of the fields above is missing.
                try:
                    boxes[instance_id]['model_path'] = box_data[scan_id][k]["model_path"]
                except (KeyError, TypeError) as e:
                    print(e)
            boxes["scene_center"] = box_data[scan_id]["scene_center"]
            objs[scan_id] = objects
            tight_boxes[scan_id] = boxes
        return rel, objs, tight_boxes

    def read_relationships(self, read_file):
        """Load the list of relationship labels.

        :param read_file: path of the relationship list txt file, one label per line
        :return: the labels in file order, with trailing whitespace removed and
            converted to lower case
        """
        with open(read_file, 'r') as handle:
            return [raw_line.rstrip().lower() for raw_line in handle]

    def norm_points(self, p):
        """Centre a point set on its mean and divide it by a single scale.

        :param p: array of points, one point per row
        :return: new array ``(p - mean(p, axis=0)) / m`` where ``m`` is the
            largest row norm of the *input* points; ``p`` is not modified
        """
        # NOTE(review): the scale is the largest distance from the origin of
        # the un-centred points, not from the centroid -- confirm this is
        # intended before relying on the result lying inside the unit sphere.
        row_norms = np.sqrt(np.sum(np.square(p), axis=1))
        scale = float(np.max(row_norms))
        centred = p - np.mean(p, axis=0)
        return centred / scale

    def get_key(self, dict, value):
        """Reverse lookup: find the key stored under ``value``.

        :param dict: mapping to search (the name shadows the builtin; it is
            kept because it is part of the method's signature)
        :param value: value to look for
        :return: the first key, in iteration order, whose value equals
            ``value``; ``None`` when there is no such key
        """
        matches = (key for key, candidate in dict.items() if candidate == value)
        return next(matches, None)

    def __getitem__(self, index):
        """Build the scene-graph sample of the scan at position ``index``.

        :param index: position in ``self.scans``
        :return: dict with the entries

            * ``'encoder'`` -- ``'objs'`` / ``'objs_grained'`` (int64 tensors of
              class ids), ``'triples'`` (int64 tensor of
              ``[subject, predicate, object]`` rows), ``'boxes'`` (float32
              tensor), ``'words'`` (list of "subject relation object" strings)
              and, when ``self.with_feats`` is set, ``'feats'`` (float32 tensor)
            * ``'gt'`` -- a deep copy of the encoder entries taken before the
              tensor conversion and then converted the same way
            * ``'scan_id'`` -- id of the scan
            * ``'instance_id'`` -- instance ids in the order used for the tensors
            * ``'classes'`` -- ``self.classes``
        """
        scan_id = self.scans[index]

        # instance2label, the whole instance ids in this scene e.g. {1: 'floor', 2: 'wall', 3: 'picture', 4: 'picture'}
        instance2label = self.objs_json[scan_id]
        keys = list(instance2label.keys())
        #print('keys: ', keys)

        if self.shuffle_objs:
            random.shuffle(keys)

        feats_in = None

        feats_path = self.root + "/DEEPSDF_reconstruction/Codes/" # for Graph-to-3D

        # Load points for debug. Nothing computed in this branch is part of the
        # returned sample; it only runs when the latent-code directory is
        # missing or self.recompute_feats is set.
        if self.with_feats and (not os.path.exists(feats_path) or self.recompute_feats):
            if scan_id in self.files: # Caching
                (points_list, points_norm_list, instances_list) = self.files[scan_id]
            else:
                points_list=np.array([]).reshape(-1,3)
                points_norm_list = np.array([]).reshape(-1, 3)
                instances_list=np.array([]).reshape(-1,1)
                for key_, value_ in self.tight_boxes_json[scan_id].items():
                    # integer keys are object instances; the 'scene_center' entry is skipped
                    if isinstance(key_,int):
                        path = self.tight_boxes_json[scan_id][key_]["model_path"]
                        # object points
                        if path is not None:
                            raw_mesh = trimesh.load(path)
                            position = self.tight_boxes_json[scan_id][key_]["param7"][3:6]
                            theta = self.tight_boxes_json[scan_id][key_]["param7"][-1]
                            # rotation by theta that leaves the second (index 1) axis unchanged
                            R = np.zeros((3, 3))
                            R[0, 0] = np.cos(theta)
                            R[0, 2] = -np.sin(theta)
                            R[2, 0] = np.sin(theta)
                            R[2, 2] = np.cos(theta)
                            R[1, 1] = 1.
                            points = raw_mesh.copy().vertices
                            point_norm = self.norm_points(points) # normalized per individual object
                            points = points.dot(R) + position # not centered yet
                        # floor points: no mesh, so 1000 random points are drawn
                        # on the y = 0 rectangle spanned by param7[0] and param7[2]
                        else:
                            position = self.tight_boxes_json[scan_id][key_]["param7"][3:6]
                            l,w = self.tight_boxes_json[scan_id][key_]["param7"][0], self.tight_boxes_json[scan_id][key_]["param7"][2]
                            x = l * np.random.random(1000)+ position[0] - l/2
                            z = w * np.random.random(1000)+ position[2] - w/2
                            y = np.repeat(0,1000)
                            points = np.vstack((x,y,z)).transpose()
                            point_norm = self.norm_points(points)
                        points_list = np.concatenate((points_list, points), axis=0)
                        points_norm_list = np.concatenate((points_norm_list, point_norm), axis=0)
                        instances = np.repeat(key_, points.shape[0]).reshape(-1, 1)
                        instances_list = np.concatenate((instances_list, instances), axis=0)

                # cache only while enough memory is left -- presumably
                # user_free is the free memory in GB (FreeMemLinux('GB')); TODO confirm
                if self.fm.user_free > 5:
                    self.files[scan_id] = (points_list, points_norm_list, instances_list)

            print("shifting points")
            points_list = points_list - np.array(self.tight_boxes_json[scan_id]['scene_center']) # centered in the scene

        # instance id -> 1-based node index; instance 0 maps to 0, which turns
        # into index -1 below and is filtered out of the triples
        instance2mask = {}
        instance2mask[0] = 0

        cat_ids = []
        cat_ids_grained = []
        tight_boxes = []

        counter = 0

        instances_order = []
        selected_shapes = []

        # key: 1 of 1: 'floor' instance_id              keys: whole instance ids
        for key in keys:
            # get objects from the selected list of classes of 3dssg
            scene_instance_id = key
            scene_instance_class = instance2label[key]
            
            if not self.large:
                # fine-grained id from the raw label, coarse id from the mapped label
                scene_class_id_grained = self.fine_grained_classes[scene_instance_class]
                scene_instance_class = self.mapping_full2simple[scene_instance_class]
                #print('scene_instance_class: ', scene_instance_class)
                scene_class_id = self.classes[scene_instance_class]
                #print('scene_class_id: ', scene_class_id)

            else:
                scene_class_id = self.classes[scene_instance_class] # class id in the entire dataset ids

            # NOTE(review): the node index advances for every key, even one
            # rejected by the filter below, while cat_ids only grows for
            # accepted keys -- the two only stay aligned if no key is rejected.
            instance2mask[scene_instance_id] = counter + 1
            counter += 1

            # mask to cat:
            if (scene_class_id >= 0) and (scene_instance_id > 0):
                selected_shapes.append(True)
                cat_ids.append(scene_class_id)
                if not self.large:
                    cat_ids_grained.append(scene_class_id_grained)
                else:
                    cat_ids_grained.append(scene_class_id)
                # copy so the stored box is not modified by the in-place shift
                bbox = np.array(self.tight_boxes_json[scan_id][key]['param7'].copy())
                bbox[3:6] -= np.array(self.tight_boxes_json[scan_id]['scene_center'])

                instances_order.append(key)
                #bins = np.linspace(np.deg2rad(-180), np.deg2rad(180), 24)
                #angle = np.digitize(bbox[6], bins)
                # entry 6 is saved and written back, so it keeps its
                # un-normalized value
                angle = bbox[6]
                bbox = normalize_box_params(bbox,file=self.box_normalized_stats)
                bbox[6] = angle

                tight_boxes.append(bbox)



        if self.with_feats:
            # If precomputed features exist, we simply load them
            latents = []
            #for key_, value_ in self.tight_boxes_json[scan_id].items():
            for key_ in instances_order: # get the objects in order
                if isinstance(key_, int):
                    path = self.tight_boxes_json[scan_id][key_]["model_path"]
                    if path is None:
                        latent_code = np.zeros([1, 256]) #for the floor, latent_code.shape[1]=256
                        #print("why is it none?")
                    else:
                        # the code is stored under the name of the model's parent directory
                        model_id = path.split('/')[-2]
                        latent_code_path = feats_path + model_id + "/sdf.pth"
                        latent_code = torch.load(latent_code_path, map_location="cpu")[0]
                        latent_code = latent_code.detach().numpy()
                    latents.append(latent_code)
            latents.append(np.zeros([1, 256])) # for the room shape
            feats_in = list(np.concatenate(latents, axis=0))

        triples = []
        words = []
        rel_json = self.relationship_json[scan_id]
        for r in rel_json: # create relationship triplets from data
            if r[0] in instance2mask.keys() and r[1] in instance2mask.keys():
                subject = instance2mask[r[0]] - 1
                object = instance2mask[r[1]] - 1
                # r[2] was decremented when the json was read; +1 restores the
                # numbering that leaves 0 for the 'in' relation
                predicate = r[2] + 1
                if subject >= 0 and object >= 0:
                    triples.append([subject, predicate, object])
                    if not self.large:
                        words.append(self.mapping_full2simple[instance2label[r[0]]] + ' ' + r[3] + ' ' + self.mapping_full2simple[instance2label[r[1]]])
                    else:
                        words.append(instance2label[r[0]]+' '+r[3]+' '+instance2label[r[1]]) # TODO check
            else:
                continue

        if self.use_scene_rels:
            # add _scene_ object and _in_scene_ connections
            scene_idx = len(cat_ids)
            for i, ob in enumerate(cat_ids):
                triples.append([i, 0, scene_idx])
                words.append(self.get_key(self.classes, ob) + ' ' + 'in' + ' ' + 'room')
            cat_ids.append(0) # TODO check
            cat_ids_grained.append(0)
            # dummy scene box
            tight_boxes.append([-1, -1, -1, -1, -1, -1, -1])

        output = {}
        #print('words: ')

        # prepare outputs
        output['encoder'] = {}
        output['encoder']['objs'] = cat_ids
        output['encoder']['objs_grained'] = cat_ids_grained # not needed for encoder
        output['encoder']['triples'] = triples
        output['encoder']['boxes'] = tight_boxes
        output['encoder']['words'] = words


        if self.with_feats:
            output['encoder']['feats'] = feats_in

        # 'gt' starts as an independent copy of the plain-python encoder entries
        output['gt'] = copy.deepcopy(output['encoder'])

        # torchify
        output['encoder']['objs'] = torch.from_numpy(np.array(output['encoder']['objs'], dtype=np.int64)) # this is changed
        output['encoder']['objs_grained'] = torch.from_numpy(np.array(output['encoder']['objs_grained'], dtype=np.int64)) # this doesn't matter
        output['encoder']['triples'] = torch.from_numpy(np.array(output['encoder']['triples'], dtype=np.int64))
        output['encoder']['boxes'] = torch.from_numpy(np.array(output['encoder']['boxes'], dtype=np.float32))
        if self.with_feats:
            output['encoder']['feats'] = torch.from_numpy(np.array(output['encoder']['feats'], dtype=np.float32))

        # these two should have the same amount.
        output['gt']['objs'] = torch.from_numpy(np.array(output['gt']['objs'], dtype=np.int64))
        output['gt']['objs_grained'] = torch.from_numpy(np.array(output['gt']['objs_grained'], dtype=np.int64))

        output['gt']['triples'] = torch.from_numpy(np.array(output['gt']['triples'], dtype=np.int64)) # this is changed
        output['gt']['boxes'] = torch.from_numpy(np.array(output['gt']['boxes'], dtype=np.float32))

        if self.with_feats:
            output['gt']['feats'] = torch.from_numpy(np.array(output['gt']['feats'], dtype=np.float32))


        output['scan_id'] = scan_id
        output['instance_id'] = instances_order
        output['classes'] = self.classes

        return output





    def __len__(self):
        """Return the number of samples the dataset exposes.

        :return: ``len(self.scans)`` when ``self.data_len`` is ``None``;
            otherwise ``self.data_len``, capped at the number of loaded scans
        """
        num_scans = len(self.scans)
        if self.data_len is None:
            return num_scans
        # __getitem__ indexes self.scans directly, so advertising more items
        # than there are scans would end in an IndexError during iteration.
        return min(self.data_len, num_scans)


    def collate_fn_vaegan(self, batch, use_points=False):
        """
        Collate function to be used when wrapping a ThreedFrontDatasetSceneGraph
        in a DataLoader.

        The per-scene graphs are concatenated into one big graph: the subject
        and object indices of every triple are shifted by the number of objects
        of the scenes placed before it.

        :param batch: list of samples as returned by ``__getitem__``; an entry
            equal to -1 marks a failed sample
        :param use_points: not used by this implementation; kept so existing
            callers keep working
        :return: -1 when a sample is -1 or when no sample contributes any
            triple; otherwise a dict with 'scene_points', 'scan_id' and
            'instance_id' (lists with one entry per sample) and, under both
            'encoder' and 'gt', a dict with 'objs', 'objs_grained', 'tripltes'
            (also available as 'triples'), 'boxes', 'obj_to_scene',
            'triple_to_scene' and, when the samples carry them, 'points' and
            'feats'
        """
        if any(sample == -1 for sample in batch):
            return -1

        out = {
            'scene_points': [sample['scene'] for sample in batch if 'scene' in sample],
            'scan_id': [sample['scan_id'] for sample in batch],
            'instance_id': [sample['instance_id'] for sample in batch],
        }

        for key in ['encoder', 'gt']:
            all_objs, all_objs_grained, all_boxes = [], [], []
            all_triples, all_triple_to_scene, all_obj_to_scene = [], [], []
            all_points, all_feats = [], []

            obj_offset = 0  # number of objects contributed by the previous scenes

            for scene_idx, sample in enumerate(batch):
                graph = sample[key]
                objs, triples, boxes = graph['objs'], graph['triples'], graph['boxes']

                if 'points' in graph:
                    all_points.append(graph['points'])

                if 'feats' in graph:
                    all_feats.append(graph['feats'])

                num_objs, num_triples = objs.size(0), triples.size(0)

                all_objs.append(objs)
                all_objs_grained.append(graph['objs_grained'])
                all_boxes.append(boxes)

                # A scene without triples yields a 1-d empty tensor; it adds no
                # edges but its objects are still counted in the offset.
                if triples.dim() > 1:
                    # clone so the sample's own tensor is not modified
                    triples = triples.clone()
                    triples[:, 0] += obj_offset
                    triples[:, 2] += obj_offset

                    all_triples.append(triples)
                    all_triple_to_scene.append(torch.full((num_triples,), scene_idx, dtype=torch.long))

                all_obj_to_scene.append(torch.full((num_objs,), scene_idx, dtype=torch.long))

                obj_offset += num_objs

            # Checked before concatenating, so that an empty batch also gives
            # the -1 sentinel instead of failing inside torch.cat.
            if not all_triples:
                return -1

            all_triples = torch.cat(all_triples)

            outputs = {'objs': torch.cat(all_objs),
                       'objs_grained': torch.cat(all_objs_grained),
                       # The misspelled key is the one existing callers read, so
                       # it must stay.
                       'tripltes': all_triples,
                       'boxes': torch.cat(all_boxes),
                       'obj_to_scene': torch.cat(all_obj_to_scene),
                       'triple_to_scene': torch.cat(all_triple_to_scene),
                       # correctly spelled alias of 'tripltes' (same tensor)
                       'triples': all_triples}

            if len(all_points) > 0:
                outputs['points'] = torch.cat(all_points)

            if len(all_feats) > 0:
                outputs['feats'] = torch.cat(all_feats)

            out[key] = outputs

        return out


    def collate_fn_vaegan_points(self,batch):
        """Collate ``batch`` through ``collate_fn_vaegan`` with ``use_points=True``.

        :param batch: list of samples, forwarded unchanged
        :return: whatever ``collate_fn_vaegan`` returns for the batch
        """
        collated = self.collate_fn_vaegan(batch, use_points=True)
        return collated


if __name__ == "__main__":
    # Manual smoke test: load one sample and turn its graph back into text.
    dataset = ThreedFrontDatasetSceneGraph(
        root="/media/xxx/FRONT",
        split='val_scans',
        shuffle_objs=True,
        use_scene_rels=True,
        with_feats=False,
        large=False,
        seed=False,
        room_type='livingroom')
    a = dataset[0]

    for x in ['encoder', 'gt']:
        en_obj = a[x]['objs'].cpu().numpy().astype(np.int32)
        en_triples = a[x]['triples'].cpu().numpy().astype(np.int32)
        # node indices used as subject / object of the triples
        sub = en_triples[:, 0]
        obj = en_triples[:, 2]
        # class name of every node that takes part in a triple
        instance_ids = np.array(sorted(list(set(sub.tolist() + obj.tolist()))))  # 0-n
        cat_ids = en_obj[instance_ids]
        texts = [dataset.classes_r[cat_id] for cat_id in cat_ids]
        objs = dict(zip(instance_ids.tolist(), texts))
        objs = {str(key): value for key, value in objs.items()}
        # Start from an empty list on every pass: the name was never
        # initialised before, and it is rebound to an ndarray (which has no
        # append) at the end of each pass.
        txt_list = []
        for rel in en_triples[:, 1]:
            if rel == 0:
                # predicate 0 is the 'in' relation added for the scene node
                txt_list.append('in')
                continue
            txt_list.append(dataset.relationships_dict_r[rel])
        txt_list = np.array(txt_list)
        rel_list = np.vstack((sub, obj, en_triples[:, 1], txt_list)).transpose()
        print(a['scan_id'])