from __future__ import print_function
import sys
sys.path.append("..")
sys.path.append(".")
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import copy
from tqdm import tqdm
import json
from helpers.psutil import FreeMemLinux
from helpers.util import normalize_box_params
from omegaconf import OmegaConf
import random
import pickle
import trimesh
import h5py
import open3d as o3d

from sklearn.cluster import DBSCAN
from collections import defaultdict

#from torch.utils.data import BatchSampler


def load_ckpt(ckpt):
    """Return a checkpoint object, loading it from disk when given a path.

    :param ckpt: either a filesystem path (``str`` or ``os.PathLike``) to a
        file readable by ``torch.load``, or an already-loaded checkpoint object
    :return: the object deserialized from ``ckpt`` when it is a path,
        otherwise ``ckpt`` itself, unchanged
    """
    if isinstance(ckpt, (str, os.PathLike)):
        # The identity map_location returns each storage as deserialized
        # instead of moving it back to the device it was saved from.
        return torch.load(ckpt, map_location=lambda storage, loc: storage)
    return ckpt

class ThreedFrontDatasetSceneGraphIncremental(data.Dataset):
    def __init__(self, root, root_3dfront='', split='train', shuffle_objs=False, pass_scan_id=False,
                 data_len=None,   eval=False,
                 eval_type='addition', with_feats=True,
                 seed=True, large=False, recompute_feats=False,
                 room_type='bedroom'):
        """Index the scene-graph annotations of one room type.

        :param root: dataset directory holding ``classes_<room_type>.txt``,
            ``relationships.txt``, ``mapping.json`` and the relationship / box
            json files
        :param root_3dfront: directory stored as ``self.root_3dfront``; when
            empty, ``<root>/visualization`` is used and created if missing
        :param split: ``'train_scans'`` selects the ``*_trainval`` files; any
            other value (including the default ``'train'``) selects ``*_test``
        :param shuffle_objs: shuffle the per-scene instance order in
            ``__getitem__``
        :param pass_scan_id: stored on the instance; not used in this method
        :param data_len: when not ``None``, the value reported by ``__len__``
        :param eval: when true together with ``seed``, numpy / torch / random
            are seeded with 47
        :param eval_type: stored on the instance; not used in this method
        :param with_feats: load per-object latent codes in ``__getitem__``;
            when true, ``recompute_feats`` is forced to ``False``
        :param seed: see ``eval``
        :param large: keep the fine-grained class names instead of mapping them
            through ``mapping.json``
        :param recompute_feats: stored on the instance (overridden when
            ``with_feats`` is true)
        :param room_type: suffix used to pick the class / json / stats files
        """
        self.room_type = room_type
        self.seed = seed
        self.with_feats = with_feats
        self.cond_model = None
        self.large = large
        self.recompute_feats = recompute_feats

        if eval and seed:
            np.random.seed(47)
            torch.manual_seed(47)
            random.seed(47)

        self.root = root
        # list of class categories
        self.catfile = os.path.join(self.root, 'classes_{}.txt'.format(self.room_type))
        self.cat = {}
        self.scans = []
        self.obj_paths = []
        self.data_len = data_len

        self.fm = FreeMemLinux('GB')
        self.vocab = {}
        with open(self.catfile, "r") as f:
            self.vocab['object_idx_to_name'] = f.readlines()
        with open(os.path.join(self.root, 'relationships.txt'), "r") as f:
            # predicate index 0 is the extra 'in' relation, followed by the file content
            self.vocab['pred_idx_to_name'] = ['in\n']
            self.vocab['pred_idx_to_name'] += f.readlines()

        # list of relationship categories (1-based ids) and the reverse lookup
        self.relationships = self.read_relationships(os.path.join(self.root, 'relationships.txt'))
        self.relationships_dict = dict(zip(self.relationships, range(1, len(self.relationships) + 1)))
        self.relationships_dict_r = dict(zip(self.relationships_dict.values(), self.relationships_dict.keys()))

        if split == 'train_scans':  # training set
            print('load training data for', self.room_type)
            self.rel_json_file = os.path.join(self.root, 'relationships_{}_trainval.json'.format(self.room_type))
            self.box_json_file = os.path.join(self.root, 'obj_boxes_{}_trainval.json'.format(self.room_type))
            self.box_normalized_stats = os.path.join(self.root, 'boxes_centered_stats_{}_trainval.txt'.format(self.room_type))
        else:  # test set
            print('load testing data', self.room_type)
            self.rel_json_file = os.path.join(self.root, 'relationships_{}_test.json'.format(self.room_type))
            self.box_json_file = os.path.join(self.root, 'obj_boxes_{}_test.json'.format(self.room_type))
            self.box_normalized_stats = os.path.join(self.root, 'boxes_centered_stats_{}_test.txt'.format(self.room_type))

        # read_relationship_json also fills self.scans as a side effect
        self.relationship_json, self.objs_json, self.tight_boxes_json = \
            self.read_relationship_json(self.rel_json_file, self.box_json_file)

        self.scan_name2idx = {name.strip(): i for i, name in enumerate(self.scans)}

        # Re-root the stored mesh paths under `root`: the first 36 characters of
        # each stored path are replaced by `root`.
        # NOTE(review): presumably 36 is the length of the path prefix used when
        # the json files were generated -- confirm.
        for infos in self.tight_boxes_json.values():
            for info in infos.values():
                if 'model_path' in info and info['model_path']:
                    info['model_path'] = root + info['model_path'][36:]

        self.padding = 0.2
        self.eval = eval
        self.pass_scan_id = pass_scan_id
        self.shuffle_objs = shuffle_objs
        self.root_3dfront = root_3dfront
        if self.root_3dfront == '':
            self.root_3dfront = os.path.join(self.root, 'visualization')
            if not os.path.exists(self.root_3dfront):
                os.makedirs(self.root_3dfront)

        # e.g. maps "lazy_sofa" to "sofa"; opened with a context manager so the
        # file handle is closed deterministically
        with open(os.path.join(self.root, "mapping.json"), "r") as f:
            self.mapping_full2simple = json.load(f)

        with open(self.catfile, 'r') as f:
            for line in f:
                category = line.rstrip()
                self.cat[category] = category

        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        self.classes_r = dict(zip(self.classes.values(), self.classes.keys()))

        points_classes = list(self.classes.keys())
        points_classes.remove('_scene_')

        # Keep the fine-grained vocabulary next to the coarse one. Objects are
        # mapped to coarse categories whose sizes differ a lot (many chairs, few
        # lamps); sampling batches by fine-grained class during diffusion
        # training evens out how often each kind of object is seen.
        self.vocab['object_idx_to_name_grained'] = self.vocab['object_idx_to_name']

        if not self.large:
            fine_names = sorted(voc.strip('\n') for voc in self.vocab['object_idx_to_name'])
            self.fine_grained_classes = {name: idx for idx, name in enumerate(fine_names)}
            self.vocab['object_idx_to_name'] = [self.mapping_full2simple[voc.strip('\n')] + '\n'
                                                for voc in self.vocab['object_idx_to_name']]
            # e.g. self.classes {'_scene_': 0, 'bookshelf': 1, 'cabinet': 2, 'chair': 3, 'desk': 4,
            #                    'floor': 5, 'lamp': 6, 'shelf': 7, 'sofa': 8, 'table': 9,
            #                    'tv_stand': 10, 'wardrobe': 11}
            coarse_names = sorted({voc.strip('\n') for voc in self.vocab['object_idx_to_name']})
            self.classes = {name: idx for idx, name in enumerate(coarse_names)}
            self.classes_r = dict(zip(self.classes.values(), self.classes.keys()))
            points_classes = list(set([self.mapping_full2simple[class_] for class_ in points_classes]))

        points_classes_idx = [self.classes[pc] for pc in points_classes]

        self.point_classes_idx = points_classes_idx + [0]
        self.sorted_cat_list = sorted(self.cat)
        self.files = {}  # per-scan cache of loaded points, filled in __getitem__
        self.eval_type = eval_type
        if with_feats:
            # Shape features are expected to be precomputed on disk, so they are
            # never recomputed here.
            print('Assume you downloaded the DeepSDF codes. If not, please download in README.md')
            self.recompute_feats = False

    def read_relationship_json(self, json_file, box_json_file):
        """ Reads from json files the relationship labels, objects and bounding boxes

        :param json_file: file that stores the objects and relationships
        :param box_json_file: file that stores the oriented 3D bounding box parameters
        :return: three dicts, relationships, objects and boxes
        """
        rel = {}
        objs = {}
        tight_boxes = {}

        with open(box_json_file, "r") as read_file:
            box_data = json.load(read_file)

        with open(json_file, "r") as read_file:
            data = json.load(read_file)
            for scan in data['scans']:

                relationships = []
                for relationship in scan["relationships"]:
                    relationship[2] -= 1
                    relationships.append(relationship)

                # for every scan in rel json, we append the scan id
                rel[scan["scan"]] = relationships
                self.scans.append(scan["scan"])

                objects = {}
                boxes = {}
                for k, v in scan["objects"].items():
                    # if not self.large:
                    #     objects[int(k)] = self.mapping_full2simple[v]
                    # else:
                    objects[int(k)] = v

                    try:
                        boxes[int(k)] = {}
                        boxes[int(k)]['param7'] = box_data[scan["scan"]][k]["param7"]
                        boxes[int(k)]['param7'][6] = boxes[int(k)]['param7'][6]
                        boxes[int(k)]['scale'] = box_data[scan["scan"]][k]["scale"]
                    except Exception as e:
                        # probably box was not saved because there were 0 points in the instance!
                        print(e)
                    try:
                        boxes[int(k)]['model_path']  = box_data[scan["scan"]][k]["model_path"]
                    except Exception as e:
                        print(e)
                        continue
                boxes["scene_center"] = box_data[scan["scan"]]["scene_center"]
                objs[scan["scan"]] = objects
                tight_boxes[scan["scan"]] = boxes
        return rel, objs, tight_boxes

    def read_relationships(self, read_file):
        """load list of relationship labels

        :param read_file: path of relationship list txt file
        """
        relationships = []
        with open(read_file, 'r') as f:
            for line in f:
                relationship = line.rstrip().lower()
                relationships.append(relationship)
        return relationships

    def norm_points(self, p):
        centroid = np.mean(p, axis=0)
        m = np.max(np.sqrt(np.sum(p ** 2, axis=1)))
        p = (p - centroid) / float(m)
        return p

    def get_key(self, dict, value):
        """Return the first key of ``dict`` whose value equals ``value``.

        :param dict: mapping to search, in iteration order
        :param value: value to look for (compared with ``==``)
        :return: the first matching key, or ``None`` when nothing matches
        """
        matches = (key for key, candidate in dict.items() if candidate == value)
        return next(matches, None)

    def __getitem__(self, index):
        """Build the scene-graph sample for the scan at position ``index``.

        :param index: position in ``self.scans``
        :return: dict with
            ``'encoder'``: ``objs`` / ``objs_grained`` (int64 class ids),
            ``triples`` (int64 ``[subject, predicate, object]`` rows using
            0-based positions in ``objs``), ``boxes`` (float32, normalized by
            ``normalize_box_params`` except index 6, which keeps the raw angle),
            ``words`` (list of "subject relation object" strings), ``feats``
            (float32, only when ``self.with_feats``) and ``valid_mask``
            (bool, False for ``floor`` / ``_scene_`` objects);
            ``'gt'``: a deep copy of the encoder entries taken before tensor
            conversion, then converted the same way (``words`` stays a list,
            no ``valid_mask``);
            ``'scan_id'``, ``'instance_id'`` (instance ids in output order) and
            ``'steps'`` (result of ``split_to_incremental``).
        """
        scan_id = self.scans[index]

        # instance id -> fine-grained class name as stored in the json file, e.g.
        # {1: 'armchair', 2: 'multi_seat_sofa', 3: 'stool', ..., 14: 'pendant_lamp', 16: 'floor'}
        instance2label = self.objs_json[scan_id] # instance id based on the json file
        keys = list(instance2label.keys())  # e.g. original keys [1, 2, 3, 4, 5]

        if self.shuffle_objs:
            random.shuffle(keys)

        feats_in = None

        feats_path = self.root + "/DEEPSDF_reconstruction/Codes/"

        # Load points for debug. This branch only runs when features are requested
        # and the feature directory is missing (or recomputation is forced); the
        # point arrays built here are cached but not used further in this method.
        if self.with_feats and (not os.path.exists(feats_path) or self.recompute_feats):
            if scan_id in self.files: # Caching
                (points_list, points_norm_list, instances_list) = self.files[scan_id]
            else:
                points_list=np.array([]).reshape(-1,3)
                points_norm_list = np.array([]).reshape(-1, 3)
                instances_list=np.array([]).reshape(-1,1)
                for key_, value_ in self.tight_boxes_json[scan_id].items():
                    # integer keys are object instances; the string key is "scene_center"
                    if isinstance(key_,int):
                        path = self.tight_boxes_json[scan_id][key_]["model_path"]
                        # object points
                        if path is not None:
                            raw_mesh = trimesh.load(path)
                            position = self.tight_boxes_json[scan_id][key_]["param7"][3:6]
                            theta = self.tight_boxes_json[scan_id][key_]["param7"][-1]
                            # rotation by theta that leaves the y axis (index 1) unchanged
                            R = np.zeros((3, 3))
                            R[0, 0] = np.cos(theta)
                            R[0, 2] = -np.sin(theta)
                            R[2, 0] = np.sin(theta)
                            R[2, 2] = np.cos(theta)
                            R[1, 1] = 1.
                            points = raw_mesh.copy().vertices
                            point_norm = self.norm_points(points) # normalized per individual object
                            points = points.dot(R) + position # not centered yet
                        # floor points: 1000 random samples on the y = 0 rectangle
                        # spanned by param7[0] x param7[2] around the box position
                        else:
                            position = self.tight_boxes_json[scan_id][key_]["param7"][3:6]
                            l,w = self.tight_boxes_json[scan_id][key_]["param7"][0], self.tight_boxes_json[scan_id][key_]["param7"][2]
                            x = l * np.random.random(1000)+ position[0] - l/2
                            z = w * np.random.random(1000)+ position[2] - w/2
                            y = np.repeat(0,1000)
                            points = np.vstack((x,y,z)).transpose()
                            point_norm = self.norm_points(points)
                        points_list = np.concatenate((points_list, points), axis=0)
                        points_norm_list = np.concatenate((points_norm_list, point_norm), axis=0)
                        instances = np.repeat(key_, points.shape[0]).reshape(-1, 1)
                        instances_list = np.concatenate((instances_list, instances), axis=0)

                # only cache while self.fm.user_free reports more than 5 (units as
                # configured by FreeMemLinux('GB') in __init__)
                if self.fm.user_free > 5:
                    self.files[scan_id] = (points_list, points_norm_list, instances_list)

            print("shifting points")
            points_list = points_list - np.array(self.tight_boxes_json[scan_id]['scene_center']) # centered in the scene

        # instance id -> 1-based position in the output lists; id 0 maps to 0 so it
        # becomes -1 (and is rejected) when triples are built below
        instance2mask = {}
        instance2mask[0] = 0

        cat_ids = []
        cat_ids_grained = []
        tight_boxes = []

        counter = 0

        instances_order = []
        selected_shapes = []
        # e.g. keys [2, 1, 5, 4, 3] -- all instance ids of the scene
        for key in keys:
            # get objects from the selected list of classes
            scene_instance_id = key
            scene_instance_class = instance2label[key]
            if not self.large:
                scene_class_id_grained = self.fine_grained_classes[scene_instance_class]
                scene_instance_class = self.mapping_full2simple[scene_instance_class]
                scene_class_id = self.classes[scene_instance_class]    # transfer into class id in the entire dataset

            else:
                scene_class_id = self.classes[scene_instance_class] # class id in the entire dataset (one object has the same class id across all scenes)


            instance2mask[scene_instance_id] = counter + 1  # 1-based local id of this instance
            counter += 1

            # mask to cat:
            if (scene_class_id >= 0) and (scene_instance_id > 0):
                selected_shapes.append(True)
                cat_ids.append(scene_class_id)
                if not self.large:
                    cat_ids_grained.append(scene_class_id_grained)
                else:
                    cat_ids_grained.append(scene_class_id)
                bbox = np.array(self.tight_boxes_json[scan_id][key]['param7'].copy())
                # express the box position relative to the scene center
                bbox[3:6] -= np.array(self.tight_boxes_json[scan_id]['scene_center'])

                instances_order.append(key)
                # keep the raw angle: it is saved before normalization and written
                # back afterwards
                angle = bbox[6]
                bbox = normalize_box_params(bbox,file=self.box_normalized_stats)
                bbox[6] = angle

                tight_boxes.append(bbox)



        if self.with_feats:
            # If precomputed features exist, we simply load them
            latents = []
            for key_ in instances_order: # get the objects in order
                if isinstance(key_, int):
                    path = self.tight_boxes_json[scan_id][key_]["model_path"]
                    if path is None:
                        latent_code = np.zeros([1, 256]) # for the floor, latent_code.shape[1]=256
                    else:
                        # the latent code is looked up by the parent directory name of the mesh file
                        model_id = path.split('/')[-2]
                        latent_code_path = feats_path + model_id + "/sdf.pth"
                        latent_code = torch.load(latent_code_path, map_location="cpu")[0]
                        latent_code = latent_code.detach().numpy()
                    latents.append(latent_code)
            latents.append(np.zeros([1, 256])) # for the room shape
            feats_in = list(np.concatenate(latents, axis=0))

        triples = []
        words = []
        rel_json = self.relationship_json[scan_id]
        for r in rel_json: # create relationship triplets from data
            if r[0] in instance2mask.keys() and r[1] in instance2mask.keys():
                subject = instance2mask[r[0]] - 1   # instance id to local id (aligned with cat_ids)
                object = instance2mask[r[1]] - 1
                # undo the -1 applied to the predicate id in read_relationship_json
                predicate = r[2] + 1
                if subject >= 0 and object >= 0:
                    triples.append([subject, predicate, object])
                    if not self.large:
                        words.append(self.mapping_full2simple[instance2label[r[0]]] + ' ' + r[3] + ' ' + self.mapping_full2simple[instance2label[r[1]]])
                    else:
                        words.append(instance2label[r[0]]+' '+r[3]+' '+instance2label[r[1]]) # TODO check
            else:
                continue

        output = {}

        # prepare outputs
        output['encoder'] = {}
        output['encoder']['objs'] = cat_ids
        output['encoder']['objs_grained'] = cat_ids_grained # not needed for encoder
        output['encoder']['triples'] = triples
        output['encoder']['boxes'] = tight_boxes
        output['encoder']['words'] = words


        if self.with_feats:
            output['encoder']['feats'] = feats_in

        # 'gt' starts as an independent copy of the plain-python encoder entries
        output['gt'] = copy.deepcopy(output['encoder'])

        # torchify
        output['encoder']['objs'] = torch.from_numpy(np.array(output['encoder']['objs'], dtype=np.int64)) # this is changed
        output['encoder']['objs_grained'] = torch.from_numpy(np.array(output['encoder']['objs_grained'], dtype=np.int64)) # this doesn't matter
        output['encoder']['triples'] = torch.from_numpy(np.array(output['encoder']['triples'], dtype=np.int64))
        output['encoder']['boxes'] = torch.from_numpy(np.array(output['encoder']['boxes'], dtype=np.float32))
        if self.with_feats:
            output['encoder']['feats'] = torch.from_numpy(np.array(output['encoder']['feats'], dtype=np.float32))

        # these two should have the same amount.
        output['gt']['objs'] = torch.from_numpy(np.array(output['gt']['objs'], dtype=np.int64))
        output['gt']['objs_grained'] = torch.from_numpy(np.array(output['gt']['objs_grained'], dtype=np.int64))

        output['gt']['triples'] = torch.from_numpy(np.array(output['gt']['triples'], dtype=np.int64)) # this is changed
        output['gt']['boxes'] = torch.from_numpy(np.array(output['gt']['boxes'], dtype=np.float32))

        if self.with_feats:
            output['gt']['feats'] = torch.from_numpy(np.array(output['gt']['feats'], dtype=np.float32))


        output['scan_id'] = scan_id
        output['instance_id'] = instances_order

        # True for every object whose class name is neither 'floor' nor '_scene_'
        valid_mask = torch.tensor([
        (self.classes_r[int(cid)] not in ['floor', '_scene_'])
        for cid in output['encoder']['objs']], dtype=torch.bool)

        output['encoder']['valid_mask'] = valid_mask          # read by split_to_incremental

        # cluster radius 2, r_touch 12 (positional arguments of split_to_incremental)
        output['steps'] = self.split_to_incremental(output,2, 12)   

        return output

    def split_to_incremental(self, output,
                            r_clu: float = 2.0,
                            r_touch: float = 8.0):


        # ========= 0. tensor =========
        centers = np.asarray([b[3:6] for b in output['encoder']['boxes'].numpy()])
        valid   = output['encoder']['valid_mask']          # BoolTensor (N,)
        idx_all = np.arange(len(centers))[valid]           

        # ========= 1. DBSCAN  =========
        lab = DBSCAN(eps=r_clu, min_samples=1).fit(centers[idx_all]).labels_
        clusters = defaultdict(list)
        for idx, cid in zip(idx_all, lab):
            clusters[cid].append(idx)

        # ========= 2. graph =========
        G = defaultdict(list)
        for a, na in clusters.items():
            for b, nb in clusters.items():
                if a >= b: continue
                if np.min([np.linalg.norm(centers[i]-centers[j])
                        for i in na for j in nb]) < r_touch:
                    G[a].append(b); G[b].append(a)
        for cid in clusters.keys():
            G.setdefault(cid, [])

        # ========= 3. order =========
        clu_center = {cid: np.mean(centers[idxs], 0) for cid, idxs in clusters.items()}
        g_center   = np.mean(list(clu_center.values()), 0)
        start_cid  = max(clu_center, key=lambda c: np.linalg.norm(clu_center[c]-g_center))

        cid_list   = list(clu_center)
        idx_of     = {cid:i for i, cid in enumerate(cid_list)}
        C          = np.stack([clu_center[c] for c in cid_list])
        dist_mat   = np.linalg.norm(C[:, None] - C[None, :], axis=-1)

        N          = len(cid_list)
        path_idx   = [idx_of[start_cid]]
        unvisited  = set(range(N)) - {path_idx[0]}
        while unvisited:
            cur = path_idx[-1]
            nxt = min(unvisited, key=lambda j: dist_mat[cur, j])
            path_idx.append(nxt); unvisited.remove(nxt)

        def two_opt(p):
            improved = True; iters = 0
            while improved and iters < 2 * N:
                improved = False; iters += 1
                for i in range(1, N - 2):
                    for j in range(i + 1, N - 1):
                        a, b = p[i-1], p[i]
                        c, d = p[j], p[j+1]
                        if dist_mat[a, b] + dist_mat[c, d] > dist_mat[a, c] + dist_mat[b, d]:
                            p[i:j+1] = reversed(p[i:j+1])
                            improved = True
                if not improved: break
            return p

        path_idx = two_opt(path_idx)
        clu_order = [cid_list[i] for i in path_idx]

        # ========= 4. steps =========
        steps, visited = [], set()
        objs   = output['encoder']['objs']
        objs_g = output['encoder']['objs_grained']
        boxes  = output['encoder']['boxes']
        triples= output['encoder']['triples']
        feats  = output['encoder'].get('feats')

        if not hasattr(torch, 'isin'):
            torch_isin = lambda a, b: (a.unsqueeze(1) == b.unsqueeze(0)).any(dim=1)
        else:
            torch_isin = torch.isin

        prev_nodes = []  

        for cid in clu_order:
            new_nodes = [i for i in clusters[cid] if i not in visited]
            if not new_nodes:
                continue

            ctx_nodes = prev_nodes[:]                    

            order = new_nodes + ctx_nodes

            def _sel(tensor):
                return tensor[order] if tensor is not None else None

            if triples.numel() > 0:
                s = triples[:, 0]; p = triples[:, 1]; o = triples[:, 2]
                device = triples.device

                new_t = torch.tensor(new_nodes, dtype=torch.long, device=device)
                mask_nn = torch_isin(s, new_t) & torch_isin(o, new_t)    

                if len(ctx_nodes) > 0:
                    ctx_t = torch.tensor(ctx_nodes, dtype=torch.long, device=device)
                    mask_nc = (torch_isin(s, new_t) & torch_isin(o, ctx_t)) | (torch_isin(s, ctx_t) & torch_isin(o, new_t))   # new↔prev
                    mask = mask_nn | mask_nc
                else:
                    mask = mask_nn

                if mask.any():
                    triple_local = triples[mask].clone()
                    remap = {int(gid): lid for lid, gid in enumerate(order)}
                    for col in (0, 2):
                        triple_local[:, col] = torch.tensor(
                            [remap[int(x)] for x in triple_local[:, col]],
                            dtype=torch.long, device=triple_local.device)
                else:
                    triple_local = torch.empty(0, 3, dtype=torch.long, device=triples.device)
            else:
                triple_local = torch.empty(0, 3, dtype=torch.long, device=triples.device)

            step = {
                'new_idx': torch.arange(len(new_nodes), dtype=torch.long),
                'old_idx': torch.arange(len(new_nodes), len(order), dtype=torch.long),
                'objs'   : _sel(objs),
                'objs_g' : _sel(objs_g),
                'boxes'  : _sel(boxes),
                'triples': triple_local,      # new-new and new↔prev 
                'feats'  : _sel(feats)
            }
            steps.append(step)

            visited.update(new_nodes)
            prev_nodes = new_nodes[:]       
        return steps

    def __len__(self):
        """Return ``self.data_len`` when it is set, else the number of scans."""
        return len(self.scans) if self.data_len is None else self.data_len


    def collate_fn_vaegan(self, batch, use_points=False):
        """
        Collate function to be used when wrapping a RIODatasetSceneGraph in a
        DataLoader. Returns a dictionary
        """

        out = {}

        out['scene_points'] = []
        out['scan_id'] = []
        out['instance_id'] = []

        global_node_id = 0
        global_dec_id = 0
        for i in range(len(batch)):
            if batch[i] == -1:
                return -1
            # notice only works with single batches
            out['scan_id'].append(batch[i]['scan_id'])
            out['instance_id'].append(batch[i]['instance_id'])

            if 'scene' in batch[i]:
                out['scene_points'].append(batch[i]['scene'])

            global_node_id += len(batch[i]['encoder']['objs'])
            global_dec_id += len(batch[i]['gt']['objs'])

        for key in ['encoder', 'gt']:
            all_objs, all_boxes, all_triples = [], [], []
            all_objs_grained = []
            all_obj_to_scene, all_triple_to_scene = [], []
            all_points = []
            all_feats = []
            all_text_feats = []
            all_rel_feats = []

            obj_offset = 0

            for i in range(len(batch)):
                if batch[i] == -1:
                    print('this should not happen')
                    continue
                (objs, triples, boxes) = batch[i][key]['objs'], batch[i][key]['triples'], batch[i][key]['boxes']

                if 'points' in batch[i][key]:
                    all_points.append(batch[i][key]['points'])

                if 'feats' in batch[i][key]:
                    all_feats.append(batch[i][key]['feats'])

                num_objs, num_triples = objs.size(0), triples.size(0)

                all_objs.append(batch[i][key]['objs'])
                all_objs_grained.append(batch[i][key]['objs_grained'])
                all_boxes.append(boxes)

                if triples.dim() > 1:
                    triples = triples.clone()
                    triples[:, 0] += obj_offset
                    triples[:, 2] += obj_offset

                    all_triples.append(triples)
                    all_triple_to_scene.append(torch.LongTensor(num_triples).fill_(i))

                all_obj_to_scene.append(torch.LongTensor(num_objs).fill_(i))

                obj_offset += num_objs

            all_objs = torch.cat(all_objs)
            all_objs_grained = torch.cat(all_objs_grained)
            all_boxes = torch.cat(all_boxes)

            all_obj_to_scene = torch.cat(all_obj_to_scene)

            if len(all_triples) > 0:
                all_triples = torch.cat(all_triples)
                all_triple_to_scene = torch.cat(all_triple_to_scene)
            else:
                return -1

            outputs = {'objs': all_objs,
                       'objs_grained': all_objs_grained,
                       'tripltes': all_triples,
                       'boxes': all_boxes,
                       'obj_to_scene': all_obj_to_scene,
                       'triple_to_scene': all_triple_to_scene}


            if len(all_points) > 0:
                all_points = torch.cat(all_points)
                outputs['points'] = all_points

            if len(all_feats) > 0:
                all_feats = torch.cat(all_feats)
                outputs['feats'] = all_feats

            out[key] = outputs

        return out

    def collate_fn_inc(self, batch):
        flat_steps = []
        step_idx_vec = []          
        step_batch_idx = []        
        for bidx, scene in enumerate(batch):
            for k, step in enumerate(scene['steps']):
                flat_steps.append(step)
                step_idx_vec.append(k)
                step_batch_idx.append(bidx)

        objs_all, objs_g_all, boxes_all = [], [], []
        feats_all = []
        new_mask, old_mask = [], []
        triples_all = []

        obj_to_step          = []  
        obj_to_scene         = []  

        triple_to_step       = []  
        triple_to_scene      = []  

        obj_offset = 0
        #print('one batch -----------------------------')
        for flat_sid, step in enumerate(flat_steps):
            k_in_scene = step_idx_vec[flat_sid]
            bidx       = step_batch_idx[flat_sid]

            n_obj = step['objs'].size(0)
            objs_all.append(step['objs'])
            objs_g_all.append(step['objs_g'])
            boxes_all.append(step['boxes'])
            #print('flat_sid', flat_sid)
            #print('k_in_scene', k_in_scene)
            obj_to_step.append(torch.full((n_obj,), k_in_scene, dtype=torch.long))
            obj_to_scene.append(torch.full((n_obj,), bidx,      dtype=torch.long))

            nm = torch.zeros(n_obj, dtype=torch.bool); nm[step['new_idx']] = True
            om = torch.zeros(n_obj, dtype=torch.bool); om[step['old_idx']] = True
            new_mask.append(nm); old_mask.append(om)

            t = step['triples']
            if t.size(0) > 0:
                t = t.clone()
                t[:, 0] += obj_offset
                t[:, 2] += obj_offset
                triples_all.append(t)
                triple_to_step.append(torch.full((t.size(0),), k_in_scene, dtype=torch.long))
                triple_to_scene.append(torch.full((t.size(0),), bidx,      dtype=torch.long))

            if step.get('feats', None) is not None:
                feats_all.append(step['feats'])

            obj_offset += n_obj

        out_enc = {
            'objs'              : torch.cat(objs_all, 0),
            'objs_g'            : torch.cat(objs_g_all, 0),
            'boxes'             : torch.cat(boxes_all, 0),
            'obj_to_step'       : torch.cat(obj_to_step, 0),        
            'obj_to_scene'      : torch.cat(obj_to_scene, 0),
            'new_mask'          : torch.cat(new_mask, 0),
            'old_mask'          : torch.cat(old_mask, 0),
        }

        if len(triples_all) > 0:
            out_enc['triples']         = torch.cat(triples_all, 0)
            out_enc['triple_to_step']  = torch.cat(triple_to_step, 0)   
            out_enc['triple_to_scene'] = torch.cat(triple_to_scene, 0)  
        else:
            out_enc['triples']         = torch.empty(0, 3, dtype=torch.long)
            out_enc['triple_to_step']  = torch.empty(0,     dtype=torch.long)
            out_enc['triple_to_scene'] = torch.empty(0,     dtype=torch.long)

        if feats_all:
            out_enc['feats'] = torch.cat(feats_all, 0)

        step_meta = {
            'step_to_batch': torch.tensor(step_batch_idx, dtype=torch.long),
            'scan_id_str'  : [batch[i]['scan_id'].strip() for i in step_batch_idx],
        }
        return {'encoder': out_enc, 'step_meta': step_meta}
    # def collate_fn_inc(self,batch):

    #     #print('batch', batch)
    #     flat_steps   = []          
    #     idx_in_batch_for_every_step_vec = []          #  e.g. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]: which scene of the batch each step belongs to
    #     step_idx_vec = []          # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    #     for batch_id, scene in enumerate(batch):  # This batch is a list of scenes, each scene is a dict with 'steps' key 
    #         #print('scene ',scene )
    #         for k, step in enumerate(scene['steps']):
    #             flat_steps.append(step)
    #             idx_in_batch_for_every_step_vec.append(batch_id)
    #             #print('idx_in_batch_vec', idx_in_batch_vec)
    #             step_idx_vec.append(k)
    #             #print('step_idx_vec', step_idx_vec)
    #     #print('flat_steps', flat_steps)
    #     objs_all, objs_g_all, boxes_all = [], [], []
    #     triples_all = []
    #     feats_all   = []
    #     new_mask, old_mask = [], []

    #     obj_to_step   = []     
    #     triple_to_step= []     
    #     obj_offset = 0

    #     for step_id, step in enumerate(flat_steps):
    #         n_obj = step['objs'].size(0)
    #         objs_all.append(step['objs'])   # objs_all keeps growing as objects are repeatedly appended step by step
    #         #print('objs_all', len(objs_all))
    #         objs_g_all.append(step['objs_g'])
    #         boxes_all.append(step['boxes'])

    #         obj_to_step.append(torch.full((n_obj,), step_id, dtype=torch.long))
    #         #print('obj_to_step', obj_to_step)

    #         # new/old mask
    #         nm = torch.zeros(n_obj, dtype=torch.bool)
    #         om = torch.zeros(n_obj, dtype=torch.bool)
    #         nm[step['new_idx']] = True
    #         om[step['old_idx']] = True
    #         new_mask.append(nm); old_mask.append(om)

    #         # remap triples because the batch merges objects from different steps
    #         t = step['triples']
    #         if t.size(0) > 0:
    #             t = t.clone()
    #             t[:, 0] += obj_offset
    #             t[:, 2] += obj_offset
    #             triples_all.append(t)
    #             triple_to_step.append(torch.full((t.size(0),), step_id, dtype=torch.long))

    #         if step.get('feats', None) is not None:
    #             feats_all.append(step['feats'])

    #         obj_offset += n_obj

    #     out_enc = {
    #         'objs'     : torch.cat(objs_all, 0),
    #         'objs_g'   : torch.cat(objs_g_all, 0),
    #         'boxes'    : torch.cat(boxes_all, 0),
    #         'obj_to_step': torch.cat(obj_to_step, 0),
    #         'new_mask' : torch.cat(new_mask, 0),
    #         'old_mask' : torch.cat(old_mask, 0)
    #     }

    #     if len(triples_all) > 0:
    #         out_enc['triples']        = torch.cat(triples_all, 0)
    #         out_enc['triple_to_step'] = torch.cat(triple_to_step, 0)
    #     else:
    #         out_enc['triples']        = torch.empty(0, 3, dtype=torch.long)
    #         out_enc['triple_to_step'] = torch.empty(0,     dtype=torch.long)


    #     if feats_all:
    #         out_enc['feats'] = torch.cat(feats_all, 0)

    #     step_meta = {
    #     'step_to_batch'     : torch.tensor(idx_in_batch_for_every_step_vec, dtype=torch.long),
    #     #'step_in_scene': torch.tensor(step_idx_vec, dtype=torch.long),
    #     'scan_id_str': [batch[i]['scan_id'].strip() for i in idx_in_batch_for_every_step_vec] 
    #     }
    #     return {'encoder': out_enc, 'step_meta': step_meta}
    # def collate_fn_vaegan_points(self,batch):
    #     """ Wrapper of the function collate_fn_vaegan to make it also return points
    #     """
    #     return self.collate_fn_vaegan(batch, use_points=True)
    
class LengthGroupedCurriculumSampler(data.BatchSampler):
    """Batch sampler yielding same-length scenes, shortest scenes first.

    Every dataset item is bucketed by ``len(item['steps'])``. Iteration walks
    the buckets in ascending step count, shuffles the indices inside each
    bucket and cuts them into batches, so a batch only ever contains scenes
    with the same number of steps.

    Args:
        dataset: indexable object supporting ``len()`` whose items expose a
            ``'steps'`` sequence.
        batch_size: number of scene indices per batch; must be >= 1.
        drop_last: if True, a bucket's trailing batch that is smaller than
            ``batch_size`` is discarded.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size: int, drop_last: bool = False):
        # Fail fast: batch_size == 0 would only blow up later inside range()
        # / integer division, and a negative value would silently produce no
        # batches at all.
        if batch_size < 1:
            raise ValueError(
                "batch_size should be a positive integer value, "
                f"but got batch_size={batch_size}"
            )

        # BatchSampler.__init__ is not invoked; this class keeps its own
        # state and overrides both __iter__ and __len__.
        self.dataset     = dataset
        self.batch_size  = batch_size
        self.drop_last   = drop_last

        # {n_steps: [scene_idx, ...]}; every item is fetched once here to
        # read its number of steps.
        self.len_buckets = defaultdict(list)
        for idx in range(len(dataset)):
            n_steps = len(dataset[idx]['steps'])
            self.len_buckets[n_steps].append(idx)

        # (n_steps, idx_list) pairs in ascending n_steps. The lists are the
        # same objects as the values stored in self.len_buckets.
        self.sorted_buckets = sorted(self.len_buckets.items(), key=lambda x: x[0])

    # -------------------------------------------------
    def __iter__(self):
        """Yield lists of scene indices, bucket by bucket in ascending length."""
        for _, idx_list in self.sorted_buckets:
            # In-place shuffle: the stored bucket order changes on every pass.
            random.shuffle(idx_list)
            for start in range(0, len(idx_list), self.batch_size):
                chunk = idx_list[start:start + self.batch_size]
                if len(chunk) == self.batch_size or not self.drop_last:
                    yield chunk

    # -------------------------------------------------
    def __len__(self):
        """Return the number of batches one full iteration produces."""
        total = 0
        for _, idx_list in self.sorted_buckets:
            n_full, remainder = divmod(len(idx_list), self.batch_size)
            total += n_full
            # A partial trailing batch only counts when it is not dropped.
            if remainder and not self.drop_last:
                total += 1
        return total
class RandomStepGroupedSampler(data.BatchSampler):
    """Batch sampler yielding same-length scenes, buckets drawn at random.

    Every dataset item is bucketed by ``len(item['steps'])`` and each bucket
    is shuffled once at construction. During iteration a non-empty bucket is
    picked at random and one batch is taken from its end, until all buckets
    are exhausted. A batch therefore only contains scenes with the same
    number of steps.

    NOTE: the order inside a bucket is fixed after construction; successive
    iterations only change the order in which the buckets are drawn.

    Args:
        dataset: indexable object supporting ``len()`` whose items expose a
            ``'steps'`` sequence.
        batch_size: number of scene indices per batch; must be >= 1.
        drop_last: if True, the leftover of a bucket that is smaller than
            ``batch_size`` is discarded.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset, batch_size: int, drop_last: bool = False):
        # Fail fast: with batch_size <= 0 the "full batch" branch of __iter__
        # is always taken but pops nothing, so it would yield empty batches
        # forever.
        if batch_size < 1:
            raise ValueError(
                "batch_size should be a positive integer value, "
                f"but got batch_size={batch_size}"
            )

        # BatchSampler.__init__ is not invoked; this class keeps its own
        # state and overrides both __iter__ and __len__.
        self.dataset    = dataset
        self.batch_size = batch_size
        self.drop_last  = drop_last

        # {n_steps: [idx, ...]}; every item is fetched once here to read its
        # number of steps.
        self.buckets = defaultdict(list)
        for idx in range(len(dataset)):
            n_steps = len(dataset[idx]['steps'])
            self.buckets[n_steps].append(idx)

        for lst in self.buckets.values():
            random.shuffle(lst)

    # -------------------------------------------------
    def __iter__(self):
        """Yield lists of scene indices, one randomly chosen bucket at a time."""
        # Work on copies so self.buckets survives for the next iteration.
        remain = {k: v.copy() for k, v in self.buckets.items()}
        active_keys = [k for k, v in remain.items() if v]

        while active_keys:
            k = random.choice(active_keys)
            bucket = remain[k]

            if len(bucket) >= self.batch_size:
                yield [bucket.pop() for _ in range(self.batch_size)]
            else:
                if not self.drop_last:
                    # Same order as popping every remaining element from the end.
                    yield bucket[::-1]
                bucket.clear()

            # Only the bucket just used can have become empty; dropping its
            # key keeps active_keys in the same order as a full rebuild.
            if not bucket:
                active_keys.remove(k)

    # -------------------------------------------------
    def __len__(self):
        """Return the number of batches one full iteration produces."""
        total = 0
        for lst in self.buckets.values():
            n_full, remainder = divmod(len(lst), self.batch_size)
            total += n_full
            # A partial trailing batch only counts when it is not dropped.
            if remainder and not self.drop_last:
                total += 1
        return total

if __name__ == "__main__":
    # Manual smoke test: build the living-room validation split and dump one
    # scene step by step (new/old objects, their boxes and the relationships).
    dataset = ThreedFrontDatasetSceneGraphIncremental(
        root="/media/xxx/xxx_ssd/FRONT",
        split='val_scans',
        shuffle_objs=True,
        with_feats=True,
        large=False,
        seed=False,
        room_type='livingroom')
    
    # Disabled example: batching the dataset with one of the step-grouped
    # samplers defined above.
    # batch_size = 4
    # #sampler = LengthGroupedCurriculumSampler(dataset, batch_size)
    # sampler =RandomStepGroupedSampler(dataset, batch_size)
    # loader  = data.DataLoader(dataset,
    #                     batch_sampler=sampler,
    #                     collate_fn=dataset.collate_fn_inc,
    #                     num_workers=4)
    
    # for bi, batch in enumerate(loader):
    #     names = [s.strip() for s in batch['step_meta']['scan_id_str']]
    #     uniq  = list(dict.fromkeys(names))                 
    #     #print('uniq[0] = ', uniq[0])

    #     #print('dataset.scan_name2idx ',dataset.scan_name2idx)
    #     #print('scan_name2idx ', dataset.scan_name2idx)
    #     first_idx = dataset.scan_name2idx[uniq[0]]         # str → int
    #     n_steps   = len(dataset[first_idx]['steps'])

    #     print(f'\n=== Mini-batch {bi} | each scene has {n_steps} steps ===')
    #     for nm in uniq:
    #         print(f'  • {nm:<25}  total_steps: {n_steps}')
    # for bi, idx_batch in enumerate(sampler):
    #     print(f"\n=== Mini-batch {bi} ===")
    #     first_len = len(dataset[idx_batch[0]]['steps'])
    #     for si in idx_batch:
    #         sid = dataset[si]['scan_id']
    #         print(f"  • {sid:<25}  total_steps: {first_len}")


    # Fetch the scene once so that the printed scan id and the steps listed
    # below come from the same item (previously the id was read from index
    # 188 while the steps were read from index 139).
    scene_idx = 139
    scene = dataset[scene_idx]
    scene_steps = scene['steps']
    print(f'scene_id : {scene["scan_id"]}')

    print(f'there are {len(scene_steps)} steps\n')

    for sid, step in enumerate(scene_steps):
        # Positions (within this step) of the newly added / already present objects.
        new_ids  = step['new_idx'].tolist()
        old_ids  = step['old_idx'].tolist()

        # Map the class id of every selected object to its category name.
        obj_ids  = step['objs'].cpu().numpy()
        new_cats = [dataset.classes_r[int(obj_ids[i])] for i in new_ids]
        old_cats = [dataset.classes_r[int(obj_ids[i])] for i in old_ids]

        box_np   = step['boxes'].cpu().numpy()
        new_box  = box_np[new_ids]   # box rows of the new objects
        old_box  = box_np[old_ids]   # box rows of the old objects

        # Column 1 of a triple is the predicate id; id 0 is printed as 'in',
        # every other id is looked up in relationships_dict_r.
        triples  = step['triples'].cpu().numpy()
        rel_txt  = ['in' if int(t[1]) == 0
                    else dataset.relationships_dict_r[int(t[1])]
                    for t in triples]
        
        print(f'─ Step {sid:02d} ──────────────────────────')
        for cat, bp in zip(new_cats, new_box):
            print(f'  NEW  [{cat:>10}]  box: {np.round(bp, 3)}')
        for cat, bp in zip(old_cats, old_box):
            print(f'  OLD  [{cat:>10}]  box: {np.round(bp, 3)}')
        print('  relationships:', rel_txt, '\n')
    #------------- END -------------