import os
import numpy as np
import trimesh
import math
import random
# import matplotlib.pyplot as plt
from pyrr import Matrix44
from simple_3dviz import Scene
from collections import Counter, defaultdict
from num2words import num2words
from nltk.corpus import cmudict
from utils import floor_plan_renderable, render
from llm_handler import LlamaHandler
from furni_rel import get_furni_rel

# Prompt prefix used by get_processed_room_dict_full_llm(): it is combined with
# the furniture-relation text via get_llm_prompt() and sent to the LLM handler.
# The text below is runtime data; edit it only to change the prompt itself.
LLM_INSTRUCTION = """Instruction: You are given texts describing the positions of furniture in a room. The texts are repetitive and unnatural. Your task is to transform them into short and natural English descriptions that preserve the essential spatial relationships.

Chain of Thought Reminder: Think through the text step by step, consider how a person would naturally describe the scene, then generate the most concise and human-sounding summary possible.

1. Read the text carefully.
2. Identify the main pieces of furniture.
3. Use chain-of-thought reasoning to summarize the scene in a short, natural way.
4. Focus on essential directions (left, right, behind, in front) without sounding repetitive.
5. Provide the final concise descriptions.

Finally, please remember you can only output one short and natural sentence to describe the room. 

The following is the provided text:
"""

# LLM_INSTRUCTION = """Instruction: You are given texts describing the positions of furniture in a room. The texts are repetitive and unnatural. Your task is to transform them into short and natural English descriptions that preserve the essential spatial relationships.
# Please remember you can only output one short and natural sentence to describe the room. 
# The following is the provided text: 
# """

def generate_messages(text):
    """Build the two-message chat payload asking an LLM to rephrase ``text``.

    Args:
        text: The description to be rephrased; it is embedded in double
            quotes inside the user message.

    Returns:
        A list with a system message followed by a user message, each a dict
        with ``"role"`` and ``"content"`` keys.
    """
    system_message = {"role": "system", "content": "You are helpful AI"}
    user_message = {
        "role": "user",
        "content": f'Please rephrase the following text, and directly tell me the answer: "{text}"',
    }
    return [system_message, user_message]

def get_furniture_json(furni_labels, transformed_bboxes):
    """Pair each furniture label with its bounding box in a JSON-like dict.

    Args:
        furni_labels: Iterable of furniture labels.
        transformed_bboxes: Iterable of bounding boxes, aligned with
            ``furni_labels``; pairing stops at the shorter of the two.

    Returns:
        ``{"furniture": [{"label": ..., "bbox": ...}, ...]}``.
    """
    entries = [
        {"label": label, "bbox": bbox}
        for label, bbox in zip(furni_labels, transformed_bboxes)
    ]
    return {"furniture": entries}
        
def get_llm_prompt(instruction, furni_json) -> str:
    """Format the LLM prompt as the instruction followed by the furniture data.

    The result has the layout::

        instruction
        furni_json

    Args:
        instruction: Instruction text placed on top.
        furni_json: Furniture data (any object); it is rendered with its
            default string formatting after a single newline.

    Returns:
        The combined prompt string.
    """
    return "{}\n{}".format(instruction, furni_json)

# def get_processed_room_legacy(scene, class_labels, aug_prob, llm_handler, pointcloud_size=2048):
#     class_num = len(class_labels)
#     num_objs = len(scene.bboxes)
    
#     proc_scene_dict = {}
#     proc_scene_dict["one_hot_class_labels"] = np.zeros((num_objs, class_num), dtype=np.float32)
#     proc_scene_dict["room_layout"], _ = get_room_layout(scene)
#     # We transform the room layout 3d to point cloud
#     # print("your layout: ", proc_scene_dict["room_layout"].shape)
#     # use plt to plot the room layout
#     # plt.imsave("output2/room_layout.png", proc_scene_dict["room_layout"][:, :, 0])
#     # exit()
#     proc_scene_dict["centriod"] = scene.centroid
#     proc_scene_dict["all_furni"] = [] # We add all the furni to this list
#     for index, furni in enumerate(scene.bboxes):
#         proc_scene_dict["one_hot_class_labels"][index] = furni.one_hot_label(class_labels)
#         furni_dict = {}
#         furni_dict["label"] = furni.label # How to embed this label?
#         furni_dict["raw_model_path"] = os.path.join(*furni.raw_model_path.split("/")[-2:])
#         furni_dict["scale"] = furni.scale
#         # get point cloud
#         furni_dict["point_cloud"] = get_furni_point_cloud(furni, pointcloud_size=pointcloud_size)
#         # print("furni_dict['point_cloud']: ", furni_dict["point_cloud"])
#         # get position information
#         furni_dict["centroid"] = furni.centroid()
#         furni_dict["translation"] = get_furni_translations(furni, scene)
#         furni_dict["z_angle"] = furni.z_angle
#         furni_dict["size"] = furni.size
#         furni_dict["bbox"] = get_furni_bbox(furni_dict["translation"], furni_dict["size"])
#         proc_scene_dict["all_furni"].append(furni_dict)
    
#     proc_scene_dict["text_des_orig_list"] = get_text_des(proc_scene_dict, class_labels)
#     proc_scene_dict["text_des"] = "".join(proc_scene_dict["text_des_orig_list"])
#     # llm augmentation
#     if aug_prob > 0:
#         # We merge all the text list into one string
#         # print("use LLM aug, input_messages: ", proc_scene_dict["text_des"])
#         proc_scene_dict["text_des"] = llm_handler.gen_text(generate_messages(proc_scene_dict["text_des"]))
#     return proc_scene_dict

def get_processed_room_dict(scene, class_labels, aug_prob, llm_handler, pointcloud_size=2048):
    """Gather scene-level and per-furniture data for one room into a dict.

    Args:
        scene: Room object; this function reads ``bboxes`` (the furniture),
            ``centroid``, ``floor_plan_centroid`` and ``scene_id``.
        class_labels: Ordered class names; the length sets the one-hot width.
        aug_prob: If greater than 0, the rule-based description is rewritten
            by ``llm_handler``. NOTE(review): the value is only compared
            against 0 here, it is not used as a sampling probability.
        llm_handler: Object exposing ``gen_text``; only used when
            ``aug_prob > 0``.
        pointcloud_size: Number of points sampled from every furniture mesh.

    Returns:
        dict with keys ``one_hot_class_labels``, ``centroid``,
        ``floor_plan_centroid``, ``scene_id``, ``all_furni`` (a dict of
        parallel per-furniture lists), ``text_des_orig_list`` and
        ``text_des``. The ``all_furni["bbox"]`` list is left empty.
    """
    def _last_two(path):
        # Keep only the trailing "<dir>/<file>" part of a dataset path.
        return os.path.join(*path.split("/")[-2:])

    # all_furni is a dict of parallel lists, one list per attribute.
    furni_columns = (
        "label", "raw_model_path", "texture_image_path", "scale",
        "furni_pc", "furni_norms", "furni_loc",
        "furni_single_pc_scale", "furni_single_pc_bbox",
        "centroid", "translation", "z_angle", "size", "bbox",
    )
    all_furni = {column: [] for column in furni_columns}

    # (destination column, key in the point-cloud result dict)
    pc_columns = (
        ("furni_pc", "points"),
        ("furni_norms", "normals"),
        ("furni_loc", "loc"),
        # distinct from "scale": this one is derived from the sampled mesh bbox
        ("furni_single_pc_scale", "single_pc_scale"),
        ("furni_single_pc_bbox", "single_pc_bbox"),
    )

    one_hot = np.zeros((len(scene.bboxes), len(class_labels)), dtype=np.float32)

    # The room layout (get_room_layout) is intentionally not computed here;
    # it was disabled because it was buggy and slow.
    proc_scene_dict = {
        "one_hot_class_labels": one_hot,
        "centroid": scene.centroid,
        "floor_plan_centroid": scene.floor_plan_centroid,
        "scene_id": scene.scene_id,
        "all_furni": all_furni,
    }

    for row, furni in enumerate(scene.bboxes):
        one_hot[row] = furni.one_hot_label(class_labels)
        all_furni["label"].append(furni.label)
        all_furni["raw_model_path"].append(_last_two(furni.raw_model_path))
        all_furni["texture_image_path"].append(_last_two(furni.texture_image_path))
        # scale as read from the dataset
        all_furni["scale"].append(furni.scale)

        pc_result = get_furni_point_cloud(furni, pointcloud_size=pointcloud_size, scale=furni.scale)
        for column, pc_key in pc_columns:
            all_furni[column].append(pc_result[pc_key])

        all_furni["centroid"].append(furni.centroid())
        all_furni["translation"].append(get_furni_translations(furni, scene))
        all_furni["z_angle"].append(furni.z_angle)
        all_furni["size"].append(furni.size)

    sentences = get_text_des(all_furni["translation"], all_furni["size"], one_hot, class_labels)
    proc_scene_dict["text_des_orig_list"] = sentences
    joined_text = "".join(sentences)
    proc_scene_dict["text_des"] = joined_text

    # A "text_path" key is added by the caller, which also saves the text
    # outside of the npy file.
    if aug_prob > 0:
        # LLM augmentation: replace the rule-based text with a rephrased one.
        proc_scene_dict["text_des"] = llm_handler.gen_text(generate_messages(joined_text))
    return proc_scene_dict

def get_processed_room_dict_full_llm(scene, class_labels, llm_handler, pointcloud_size=2048):
    """Gather room data and let the LLM write the whole text description.

    Per-furniture data is collected as parallel lists, each furniture's
    transformed (rotated + translated) bounding box is computed, the
    relations between furniture are extracted with ``get_furni_rel`` and, if
    any were found, turned into text by ``llm_handler``.

    Args:
        scene: Room object; this function reads ``bboxes``, ``centroid``,
            ``floor_plan_centroid`` and ``scene_id``.
        class_labels: Ordered class names; the length sets the one-hot width.
        llm_handler: Object exposing ``gen_text``; called with the prompt
            string when relations exist.
        pointcloud_size: Number of points sampled from every furniture mesh.

    Returns:
        dict with keys ``one_hot_class_labels``, ``centroid``,
        ``floor_plan_centroid``, ``scene_id``, ``all_furni``, ``furni_rel``
        and ``text_des``. The ``all_furni["bbox"]`` list is left empty.
    """
    def _last_two(path):
        # Keep only the trailing "<dir>/<file>" part of a dataset path.
        return os.path.join(*path.split("/")[-2:])

    # all_furni is a dict of parallel lists, one list per attribute.
    furni_columns = (
        "label", "raw_model_path", "texture_image_path", "scale",
        "furni_pc", "furni_norms", "furni_loc",
        "furni_single_pc_scale", "furni_single_pc_bbox",
        "furni_single_transformed_bbox",
        "centroid", "translation", "z_angle", "size", "bbox",
    )
    all_furni = {column: [] for column in furni_columns}

    # (destination column, key in the point-cloud result dict)
    pc_columns = (
        ("furni_pc", "points"),
        ("furni_norms", "normals"),
        ("furni_loc", "loc"),
        # distinct from "scale": this one is derived from the sampled mesh bbox
        ("furni_single_pc_scale", "single_pc_scale"),
        ("furni_single_pc_bbox", "single_pc_bbox"),
    )

    one_hot = np.zeros((len(scene.bboxes), len(class_labels)), dtype=np.float32)

    # The room layout (get_room_layout) is intentionally not computed here;
    # it was disabled because it was buggy and slow.
    proc_scene_dict = {
        "one_hot_class_labels": one_hot,
        "centroid": scene.centroid,
        "floor_plan_centroid": scene.floor_plan_centroid,
        "scene_id": scene.scene_id,
        "all_furni": all_furni,
    }

    for row, furni in enumerate(scene.bboxes):
        one_hot[row] = furni.one_hot_label(class_labels)
        all_furni["label"].append(furni.label)
        all_furni["raw_model_path"].append(_last_two(furni.raw_model_path))
        all_furni["texture_image_path"].append(_last_two(furni.texture_image_path))
        # scale as read from the dataset
        all_furni["scale"].append(furni.scale)

        pc_result = get_furni_point_cloud(furni, pointcloud_size=pointcloud_size, scale=furni.scale)
        for column, pc_key in pc_columns:
            all_furni[column].append(pc_result[pc_key])

        all_furni["centroid"].append(furni.centroid())
        translation = get_furni_translations(furni, scene)
        all_furni["translation"].append(translation)
        all_furni["z_angle"].append(furni.z_angle)
        transformed_bbox, _ = get_transformed_furni(furni, translation, furni.z_angle)
        # stored as nested Python lists rather than an array
        all_furni["furni_single_transformed_bbox"].append(transformed_bbox.tolist())
        all_furni["size"].append(furni.size)

    labels = all_furni["label"]
    transformed_bboxes = all_furni["furni_single_transformed_bbox"]
    print("Your furni label: ", labels)
    print("Your transformed bbox: ", transformed_bboxes)

    # Only the LLM is used to produce the text description in this variant.
    furni_rel = get_furni_rel(get_furniture_json(labels, transformed_bboxes), False)
    if furni_rel is None:
        # No relation could be extracted: fall back to a generic request.
        proc_scene_dict["furni_rel"] = "No relations"
        proc_scene_dict["text_des"] = "Please arrange this room in a reasonable way."
        return proc_scene_dict

    proc_scene_dict["furni_rel"] = furni_rel
    proc_scene_dict["text_des"] = llm_handler.gen_text(get_llm_prompt(LLM_INSTRUCTION, furni_rel))
    return proc_scene_dict

# create a scene dictionary
def create_scene_dict(window_size=(64, 64), background=(0,0,0,1), up_vector=(0,0,-1), camera_target=(0, 0, 0), 
                      camera_position=(0, 4, 0), room_side=3.1):
    """Bundle the camera/scene settings into a plain dict.

    Every argument is stored unchanged under a key of the same name; the
    resulting dict is what ``scene_from_dict`` consumes.

    Returns:
        dict with keys ``window_size``, ``background``, ``up_vector``,
        ``camera_target``, ``camera_position`` and ``room_side``.
    """
    return {
        "window_size": window_size,
        "background": background,
        "up_vector": up_vector,
        "camera_target": camera_target,
        "camera_position": camera_position,
        "room_side": room_side,
    }
    
# load scene setting from scene_dict
def scene_from_dict(scene_camera_dict):
    """Create a ``Scene`` configured from a settings dict.

    Args:
        scene_camera_dict: dict with the keys produced by
            ``create_scene_dict`` (``window_size``, ``background``,
            ``up_vector``, ``camera_target``, ``camera_position``,
            ``room_side``).

    Returns:
        The configured ``Scene``. Its light is placed at the camera position
        and its camera matrix is an orthographic projection spanning
        ``[-room_side, room_side]`` horizontally, with ``bottom`` set to
        ``+room_side`` and ``top`` to ``-room_side``.
    """
    settings = scene_camera_dict
    scene_camera = Scene(size=settings["window_size"], background=settings["background"])
    scene_camera.up_vector = settings["up_vector"]
    scene_camera.camera_target = settings["camera_target"]

    eye = settings["camera_position"]
    scene_camera.camera_position = eye
    # the light source shares the camera position
    scene_camera.light = eye

    half_side = settings["room_side"]
    scene_camera.camera_matrix = Matrix44.orthogonal_projection(
        left=-half_side,
        right=half_side,
        bottom=half_side,
        top=-half_side,
        near=0.1,
        far=6,
    )
    return scene_camera

def get_room_layout(scene, pointcloud_size=2048):
    """Render the floor plan of a room and sample a point cloud from it.

    A scene camera is built from the default settings of
    ``create_scene_dict``; the floor mesh returned by
    ``floor_plan_renderable`` is rendered with it and also sampled.

    Args:
        scene: Room object passed to ``floor_plan_renderable``.
        pointcloud_size: Number of points sampled from the floor mesh.

    Returns:
        ``(room_layout, floor_pc)`` where ``room_layout`` is the first channel
        of the rendered image (kept as a trailing axis of length 1) and
        ``floor_pc`` holds the points sampled from the floor mesh.
    """
    scene_camera = scene_from_dict(create_scene_dict())
    floor_renderable, floor_vertices, floor_faces = floor_plan_renderable(scene)

    # presumably (colour, rendering mode, optional extra) — TODO confirm
    # against utils.render
    rendered = render(
        scene_camera,
        [floor_renderable],
        (1.0, 1.0, 1.0),
        "flat",
        None,
    )
    room_layout = rendered[:, :, 0:1]

    # Rebuild the floor as a trimesh mesh so that points can be sampled on it.
    floor_trimesh = trimesh.Trimesh(vertices=floor_vertices, faces=floor_faces)
    floor_pc = floor_trimesh.sample(pointcloud_size, return_index=False)

    return room_layout, floor_pc
    
def get_transformed_furni(furni, translation, theta):
    """Load a furniture mesh, rotate it by ``theta`` and translate it.

    Args:
        furni: Furniture object; only ``raw_model_path`` is read.
        translation: Offset added to every vertex after the rotation (the
            centralized translation of the furniture).
        theta: Rotation angle in radians. The rotation mixes the first and
            third vertex coordinates and leaves the second one untouched.

    Returns:
        ``(transformed_bbox, tr_mesh)``: the axis-aligned bounds of the
        transformed mesh and the transformed mesh itself.

    NOTE(review): unlike ``get_furni_point_cloud``, no scale is applied to the
    vertices here — confirm that this is intended.
    """
    tr_mesh = trimesh.load(
        furni.raw_model_path,
        process=False,
        force="mesh",
        skip_materials=True,
        skip_texture=True,
    )

    # Rotation matrix, applied by right-multiplying the vertex rows.
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    R = np.eye(3)
    R[0, 0] = cos_t
    R[0, 2] = -sin_t
    R[2, 0] = sin_t
    R[2, 2] = cos_t
    tr_mesh.vertices[...] = np.dot(tr_mesh.vertices, R)

    # Shift the rotated mesh to its place in the room.
    tr_mesh.vertices[...] = tr_mesh.vertices + translation
    return tr_mesh.bounding_box.bounds, tr_mesh

def get_furni_point_cloud(furni, bbox_padding=0.0, pointcloud_size=2048, scale=None, apply_loc=True): # By default no scale is applied and the mesh is re-centred
    """Load a furniture mesh and sample a point cloud with normals from it.

    Args:
        furni: Furniture object; only ``raw_model_path`` is read.
        bbox_padding: Padding fraction; the longest bbox side is divided by
            ``1 - bbox_padding`` to obtain ``single_pc_scale``.
        pointcloud_size: Number of surface points to sample.
        scale: Optional factor multiplied into the vertices before anything
            else is computed; skipped when ``None``.
        apply_loc: When true, the mesh is translated so that its bounding-box
            centre sits at the origin before sampling.

    Returns:
        dict with keys ``points`` and ``normals`` (float32 arrays), ``loc``
        (bounding-box centre before re-centring), ``single_pc_scale`` and
        ``single_pc_bbox`` (bounds after the optional re-centring).
    """
    mesh = trimesh.load(
        furni.raw_model_path,
        process=False,
        force="mesh",
        skip_materials=True,
        skip_texture=True,
    )

    if scale is not None:
        mesh.vertices *= scale

    # Location and scale are derived from the (scaled) bounding box.
    bbox = mesh.bounding_box.bounds
    lower, upper = bbox[0], bbox[1]
    loc = (lower + upper) / 2
    single_pc_scale = (upper - lower).max() / (1 - bbox_padding)

    if apply_loc:
        mesh.apply_translation(-loc)
        # the bounds changed with the translation, so read them again
        bbox = mesh.bounding_box.bounds

    # Sample surface points and look up the normal of each sampled face.
    sampled, face_idx = mesh.sample(pointcloud_size, return_index=True)
    normals = mesh.face_normals[face_idx]

    out_dtype = np.float32
    return {
        # np.array() turns the tracked array into a plain numpy array
        "points": np.array(sampled.astype(out_dtype)),
        "normals": normals.astype(out_dtype),
        "loc": loc,
        "single_pc_scale": single_pc_scale,
        "single_pc_bbox": bbox,
    }

def get_furni_translations(furni, scene):
    """Return the furniture centroid offset by the negated scene centroid.

    Args:
        furni: Furniture object whose ``centroid`` method accepts an offset.
        scene: Room object providing ``centroid``.
    """
    offset = -scene.centroid
    return furni.centroid(offset)
    
def get_furni_bbox(furni_trans, furni_size):
    """Placeholder for a bounding-box computation; not implemented yet.

    Both arguments are currently ignored and ``None`` is returned.
    """
    # TODO: implement; nothing in this module calls it at the moment.
    return None
    
def compute_rel(box1, box2):
    """Classify the spatial relation of ``box1`` with respect to ``box2``.

    Both boxes are 6-element sequences ``[x0, y0, z0, x1, y1, z1]`` (min
    corner followed by max corner). The second coordinate is treated as the
    vertical axis; directions are measured in the plane of the first and
    third coordinates.

    Args:
        box1: The subject box.
        box2: The reference box.

    Returns:
        ``(relation, distance)``. ``relation`` is one of ``'on'``,
        ``'above'``, ``'surrounding'``, ``'inside'``, ``'left of'``,
        ``'right of'``, ``'behind'``, ``'in front of'``, one of the
        ``'... touching'`` variants, or ``None`` when no relation applies.
        ``distance`` is the horizontal distance between the two box centres.
    """
    sx0, sy0, sz0, sx1, sy1, sz1 = box1
    ox0, oy0, oz0, ox1, oy1, oz1 = box2

    center1 = np.array([(sx0 + sx1) / 2, (sy0 + sy1) / 2, (sz0 + sz1) / 2])
    center2 = np.array([(ox0 + ox1) / 2, (oy0 + oy1) / 2, (oz0 + oz1) / 2])

    d = center1 - center2
    theta = math.atan2(d[2], d[0])  # range -pi to pi
    distance = (d[2]**2 + d[0]**2)**0.5

    p = None

    # "on" / "above": only meaningful when the centre of box1 lies inside the
    # horizontal footprint of box2 along BOTH horizontal axes. The early
    # return must stay inside this combined test; returning as soon as only
    # the first axis matched dropped every relation between objects that are
    # merely aligned along that axis (typically "behind" / "in front of").
    if ox0 <= center1[0] <= ox1 and oz0 <= center1[2] <= oz1:
        delta1 = center1[1] - center2[1]
        delta2 = (sy1 - sy0 + oy1 - oy0) / 2
        # vertical gap between the bottom of box1 and the top of box2
        gap = delta1 - delta2
        if 0 < gap < 0.05:
            p = 'on'
        elif 0.05 < gap:
            p = 'above'
        return p, distance

    # No horizontal relation for objects at clearly different heights.
    if abs(d[1]) > 0.5:
        return p, distance

    # Footprint overlap (IoU in the horizontal plane) decides "touching".
    area_s = (sx1 - sx0) * (sz1 - sz0)
    area_o = (ox1 - ox0) * (oz1 - oz0)
    ix0, ix1 = max(sx0, ox0), min(sx1, ox1)
    iz0, iz1 = max(sz0, oz0), min(sz1, oz1)
    area_i = max(0, ix1 - ix0) * max(0, iz1 - iz0)
    area_u = area_s + area_o - area_i
    # Degenerate (zero-area) footprints cannot overlap; avoid dividing by 0.
    iou = area_i / area_u if area_u > 0 else 0.0
    touching = 0.0001 < iou < 0.5

    if sx0 < ox0 and sx1 > ox1 and sz0 < oz0 and sz1 > oz1:
        p = 'surrounding'
    elif sx0 > ox0 and sx1 < ox1 and sz0 > oz0 and sz1 < oz1:
        p = 'inside'
    # 60 degree intervals along each direction; the diagonal sectors in
    # between intentionally yield no relation.
    elif theta >= 5 * math.pi / 6 or theta <= -5 * math.pi / 6:
        p = 'right touching' if touching else 'left of'
    elif -2 * math.pi / 3 <= theta < -math.pi / 3:
        p = 'behind touching' if touching else 'behind'
    elif -math.pi / 6 <= theta < math.pi / 6:
        p = 'left touching' if touching else 'right of'
    elif math.pi / 3 <= theta < 2 * math.pi / 3:
        p = 'front touching' if touching else 'in front of'

    return p, distance

def dict_bbox_to_vec(dict_box):
    '''
    Flatten a min/max box dict into one vector.

    input: {'min': [1,2,3], 'max': [4,5,6]}
    output: [1,2,3,4,5,6]

    The two values are joined with ``+``, so both are expected to be lists.
    '''
    lower = dict_box['min']
    upper = dict_box['max']
    return lower + upper

def get_relation(translation_list, size_list):
    """Compute the pairwise relations of every object to the objects before it.

    Each box spans ``translation - size`` to ``translation + size``, so the
    entries must support element-wise arithmetic (e.g. numpy arrays).

    Args:
        translation_list: Per-object centre positions.
        size_list: Per-object sizes, aligned with ``translation_list``.

    Returns:
        List of ``(index, relation, other_index, distance)`` tuples with
        ``other_index < index``; pairs for which ``compute_rel`` finds no
        relation are left out.
    """
    def _box_vec(center, extent):
        # [min corner..., max corner...] as one flat list
        return list(center - extent) + list(center + extent)

    relations = []
    for ndx in range(len(translation_list)):
        box1 = _box_vec(translation_list[ndx], size_list[ndx])
        # only backward relations: compare against earlier objects
        for other_ndx in range(ndx):
            box2 = _box_vec(translation_list[other_ndx], size_list[other_ndx])
            relation_str, distance = compute_rel(box1, box2)
            if relation_str is None:
                continue
            relations.append((ndx, relation_str, other_ndx, distance))

    return relations

def clean_obj_name(name):
    """Return ``name`` with every underscore replaced by a space."""
    return ' '.join(name.split('_'))

def _default_pronunciations():
    """Load the CMU pronouncing dictionary on first use and reuse it after."""
    cached = getattr(_default_pronunciations, "_cache", None)
    if cached is None:
        cached = cmudict.dict()
        _default_pronunciations._cache = cached
    return cached


def starts_with_vowel_sound(word, pronunciations=None):
    """Tell whether ``word`` is pronounced with a leading vowel sound.

    Args:
        word: A single word. It is looked up as given and, if missing, in
            lower case.
        pronunciations: Mapping from word to a list of pronunciations, each a
            list of phoneme strings. Defaults to the CMU pronouncing
            dictionary, which is loaded lazily on the first call instead of
            at import time.

    Returns:
        ``True`` when the first phoneme of the first listed pronunciation
        ends with a digit (the stress marker carried by vowel phonemes),
        ``False`` otherwise, including for words missing from the mapping.
    """
    if pronunciations is None:
        pronunciations = _default_pronunciations()

    candidates = pronunciations.get(word) or pronunciations.get(word.lower(), [])
    if not candidates:
        # unknown word: callers fall back to the article "a"
        return False

    first_phoneme = candidates[0][0]
    return first_phoneme[-1].isdigit()


def get_article(word):
    """Pick the indefinite article ("a" or "an") for a phrase.

    Only the part of ``word`` before its first space is considered; "an" is
    chosen when ``starts_with_vowel_sound`` reports a vowel sound for it.
    """
    first_token, _, _ = word.partition(" ")
    if starts_with_vowel_sound(first_token):
        return "an"
    return "a"

def add_description(one_hot_class_labels, class_labels, relations):
    '''
    Turn pairwise object relations into English sentences describing a scene.

    Args:
        one_hot_class_labels: Array with one row per object; the arg-max of
            each row selects that object's entry in ``class_labels``.
        class_labels: Ordered class names (underscores become spaces).
        relations: Iterable of ``(index, relation, other_index, distance)``
            tuples as produced by ``get_relation``.

    Returns:
        List of sentences, e.g. ``'The chair is next to the table . '``.
        Objects 1 and 2 are described with "The X is ...", later ones with
        "There is a/an X ...". Relations with a distance of 5 or more and
        relations of an unrecognised kind are skipped. When nothing can be
        described, the list holds the single fallback sentence
        ``"Please arrange this room in a reasonable way."``.
    '''
    # Number of leading objects that can be referred to from the start.
    first_n = 100
    # Relations between objects at least this far apart are not described.
    max_distance = 5
    # Relations rendered verbatim between the two object names.
    plain_relations = ('surrounding', 'inside', 'behind', 'in front of', 'on', 'above')

    class_index = one_hot_class_labels.argmax(-1)
    obj_names = [clean_obj_name(class_labels[ind]) for ind in class_index]

    # objects that can be referred to
    refs = set(range(first_n))

    # for each object, the "position" of that object within its class
    # eg: sofa table table sofa
    #   -> 1    1    2      2
    # use this to get "first", "second"
    seen_counts = defaultdict(int)
    in_cls_pos = [0] * len(obj_names)
    for ndx, name in enumerate(obj_names[:first_n]):
        seen_counts[name] += 1
        in_cls_pos[ndx] = seen_counts[name]

    sentences = []
    for ndx in range(1, len(obj_names)):
        # possible backward references for this object
        possible_relations = [
            r for r in relations
            if r[0] == ndx and r[2] in refs and r[3] < max_distance
        ]
        if not possible_relations:
            continue

        # now future objects can refer to this object
        refs.add(ndx)

        # objects beyond the first_n have no in-class position yet
        if in_cls_pos[ndx] == 0:
            seen_counts[obj_names[ndx]] += 1
            in_cls_pos[ndx] = seen_counts[obj_names[ndx]]

        use_definite = ndx in (1, 2)
        for n1, rel, n2, _dist in possible_relations:
            o1 = obj_names[n1]
            o2 = obj_names[n2]

            # prepend "first", "second", ... for repeated objects
            if seen_counts[o1] > 1:
                o1 = f'{num2words(in_cls_pos[n1], ordinal=True)} {o1}'
            if seen_counts[o2] > 1:
                o2 = f'{num2words(in_cls_pos[n2], ordinal=True)} {o2}'

            # Skip pairs whose rendered names coincide.
            # NOTE(review): this runs after the ordinal prefixes were added,
            # so two objects of the same class are usually NOT skipped.
            if o1 == o2:
                continue

            if 'touching' in rel:
                link = 'next to'
            elif rel in ('left of', 'right of'):
                link = f'to the {rel}'
            elif rel in plain_relations:
                link = rel
            else:
                # Unknown relation: skip it instead of emitting a stale or
                # undefined sentence.
                continue

            if use_definite:
                s = f'The {o1} is {link} the {o2}'
            else:
                s = f'There is {get_article(o1)} {o1} {link} the {o2}'
            sentences.append(s + ' . ')

    if not sentences:
        # If no relations, we just return this sentence
        sentences.append("Please arrange this room in a reasonable way.")
    return sentences

# def add_description(one_hot_class_labels, class_labels, relations):
#     sentences = []
#     if len(relations) == 0: # no relations
#         sentences.append("Please arrange this room in a reasonable way.")
#         return sentences



def get_text_des(translation_list, size_list, one_hot_class_labels, class_labels):
    """Produce the rule-based text description of a scene.

    The pairwise relations are computed with ``get_relation`` and converted
    to sentences with ``add_description``; both intermediate results are
    printed for debugging.

    Returns:
        The list of sentences returned by ``add_description``.
    """
    pairwise_relations = get_relation(translation_list, size_list)
    print("----- new -----")
    print("relation: ", pairwise_relations)
    sentences = add_description(one_hot_class_labels, class_labels, pairwise_relations)
    print("text_des: ", sentences)
    print("----- end -----")
    return sentences
