# 
# Copyright (C) 2021 NVIDIA Corporation.  All rights reserved.
# Licensed under the NVIDIA Source Code License.
# See LICENSE at https://github.com/nv-tlabs/ATISS.
# Authors: Despoina Paschalidou, Amlan Kar, Maria Shugrina, Karsten Kreis,
#          Andreas Geiger, Sanja Fidler
# 

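"""Script used for generating an LLM-annotated dataset from 3D-FRONT rooms.

Every scene of the selected split is processed into a scene dictionary that is
saved as a .npy file, together with text files holding its description and
furniture relations.
"""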
import argparse
import logging
import os
import sys

import numpy as np
import torch
from PIL import Image
import pyrr
import trimesh
import random

from scene_synthesis.datasets.threed_front import ThreedFront
from scene_synthesis.datasets import filter_function

# from simple_3dviz import Scene
# from simple_3dviz.behaviours.keyboard import SnapshotOnKey
# from simple_3dviz.behaviours.io import SaveFrames
# from simple_3dviz.renderables.textured_mesh import TexturedMesh
# from simple_3dviz.utils import render
# from simple_3dviz.window import show

from utils import floor_plan_from_scene, export_scene, create_directory
from gen_utils import get_processed_room_dict_full_llm
from llm_handler import LlamaHandler, load_LLM

def process_dataset(dataset, llm_model, pointcloud_size, dataset_save_path, start_idx=7444):
    """Process every scene of ``dataset`` and write the samples to disk.

    For each scene, ``get_processed_room_dict_full_llm`` builds a scene
    dictionary. Three files, named after a running sample index, are then
    written into ``dataset_save_path``:

    * ``<idx>.txt`` -- the value of ``scene_dict["text_des"]``.
    * ``<idx>.npy`` -- the whole scene dictionary, extended with a
      ``"text_path"`` entry holding the file name of the text file.
    * ``<idx>_furni_rel.txt`` -- the value of ``scene_dict["furni_rel"]``.

    Args:
        dataset: Object exposing ``scenes`` (iterable of scenes) and
            ``class_labels``; both are forwarded to the scene processor.
        llm_model: LLM handle forwarded to the scene processor.
        pointcloud_size: Sampling size forwarded as ``pointcloud_size``.
        dataset_save_path: Directory receiving the files. It is created
            (including parents) when it does not exist yet.
        start_idx: Index given to the first scene; later scenes count up
            from it. Defaults to 7444, the value that used to be hard-coded.

    Raises:
        OSError: If the directory cannot be created or a file cannot be
            written.
    """
    # Create the target up front: otherwise the first open() fails only after
    # a full (and expensive) LLM call for the first scene.
    os.makedirs(dataset_save_path, exist_ok=True)

    for data_idx, scene in enumerate(dataset.scenes, start=start_idx):
        scene_dict = get_processed_room_dict_full_llm(
            scene,
            dataset.class_labels,
            llm_model,
            pointcloud_size=pointcloud_size,
        )

        # The description is stored next to the .npy file, and the dictionary
        # only keeps the (relative) file name of that text file.
        text_name = f"{data_idx}.txt"
        # Explicit UTF-8: generated text may contain non-ASCII characters and
        # the platform default encoding is not guaranteed to handle them.
        with open(os.path.join(dataset_save_path, text_name), "w", encoding="utf-8") as f:
            f.write(scene_dict["text_des"])
        scene_dict["text_path"] = text_name

        # np.save stores a dict as a pickled object array, so it has to be
        # read back with np.load(..., allow_pickle=True).
        np.save(os.path.join(dataset_save_path, f"{data_idx}.npy"), scene_dict)

        furni_rel_save_path = os.path.join(dataset_save_path, f"{data_idx}_furni_rel.txt")
        with open(furni_rel_save_path, "w", encoding="utf-8") as f:
            f.write(scene_dict["furni_rel"])

def main(argv=None):
    """Generate the LLM-annotated training dataset from 3D-FRONT.

    Parses the command line, seeds the random number generators, loads the
    LLM, loads the "train" split of 3D-FRONT and hands it to
    ``process_dataset``, which writes the samples into
    ``<dataset_save_path>/train``.

    Args:
        argv: List of command-line arguments without the program name. When
            ``None``, argparse falls back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Generate an LLM-annotated dataset from 3D-FRONT rooms"
    )
    ################## Original dataset paths ##################
    parser.add_argument(
        "path_to_3d_front_dataset_directory",
        help="Path to the 3D-FRONT dataset"
    )
    parser.add_argument(
        "path_to_3d_future_dataset_directory",
        help="Path to the 3D-FUTURE dataset"
    )
    parser.add_argument(
        "path_to_model_info",
        help="Path to the 3D-FUTURE model_info.json file"
    )
    # Not read anywhere in this function; kept as a positional argument so
    # that existing command lines keep working.
    parser.add_argument(
        "path_to_floor_plan_textures",
        help="Path to floor texture images"
    )
    parser.add_argument(
        "--path_to_invalid_bbox_jids",
        default="../config/black_list.txt",
        help="Path to objects that are blacklisted"
    )
    parser.add_argument(
        "--path_to_invalid_scene_ids",
        default="../config/invalid_threed_front_rooms.txt",
        help="Path to invalid scenes"
    )
    parser.add_argument(
        "--annotation_file",
        default="../config/bedroom_threed_front_splits.csv",
        help="Path to the train/test splits file"
    )
    parser.add_argument(
        "--dataset_filtering",
        default="threed_front_bedroom",
        choices=[
            "threed_front_bedroom",
            "threed_front_livingroom",
            "threed_front_diningroom",
            "threed_front_library"
        ],
        help="The type of dataset filtering to be used"
    )
    parser.add_argument(
        "--without_lamps",
        action="store_true",
        help="If set ignore lamps when rendering the room"
    )
    parser.add_argument(
        "--pointcloud_size",
        type=int,
        default=2048,
        help="Sampling size for the point cloud"
    )
    ################## LLM and other dataset generation parameters ##################
    parser.add_argument("--seed", type=int, default=100) # seed for reproducibility
    parser.add_argument("--llm_model", type=str, default="Llama-3.1-8B", help="Options: Llama-3-8B, Llama-3.1-8B, Llama-3-70B, Llama-3.1-70B")
    parser.add_argument("--openai_api_key", type=str, default="") # OpenAI API key
    parser.add_argument("--max_new_tokens", type=int, default=500)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--dataset_save_path", type=str, default="my_dataset")

    args = parser.parse_args(argv)

    # Echo the configuration, but never print the API key in clear text: the
    # output of long-running jobs usually ends up in shared log files.
    shown_args = dict(vars(args))
    if shown_args["openai_api_key"]:
        shown_args["openai_api_key"] = "***"
    print("Your args: ", argparse.Namespace(**shown_args))

    # Fix the seeds; a negative seed leaves the generators unseeded.
    if args.seed >= 0:
        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    config = {
        "filter_fn":                 args.dataset_filtering,
        "min_n_boxes":               -1,
        "max_n_boxes":               -1,
        "path_to_invalid_scene_ids": args.path_to_invalid_scene_ids,
        "path_to_invalid_bbox_jids": args.path_to_invalid_bbox_jids,
        "annotation_file":           args.annotation_file
    }

    # Create the output directory before the expensive loading steps so that
    # an unwritable location fails immediately. An existing directory is
    # reused, and files with the same names in it are overwritten.
    train_dataset_save_path = os.path.join(args.dataset_save_path, "train")
    os.makedirs(train_dataset_save_path, exist_ok=True)

    # load LLM
    print("Loading LLM model")
    llm_model = load_LLM(args)

    # ------------------ Training ------------------
    train_data = ThreedFront.from_dataset_directory(
        args.path_to_3d_front_dataset_directory,
        args.path_to_model_info,
        args.path_to_3d_future_dataset_directory,
        filter_fn=filter_function(config, ["train"], args.without_lamps)
    )

    print("Creating training dataset.")
    print("Loading train dataset with {} rooms".format(len(train_data)))
    print("class labels: ", train_data.class_labels)

    process_dataset(train_data, llm_model, args.pointcloud_size, train_dataset_save_path)

    print("Training dataset generation completed.")

    # ------------------ Testing ------------------
    # Currently disabled; enable to also generate the val/test split.
    # test_data = ThreedFront.from_dataset_directory(
    #     args.path_to_3d_front_dataset_directory,
    #     args.path_to_model_info,
    #     args.path_to_3d_future_dataset_directory,
    #     filter_fn=filter_function(config, ["val", "test"], args.without_lamps)
    # )
    # print("Creating testing dataset")
    # print("Loading test dataset with {} rooms".format(len(test_data)))
    # print("class labels: ", test_data.class_labels)

    # test_dataset_save_path = os.path.join(args.dataset_save_path, "test")
    # os.makedirs(test_dataset_save_path, exist_ok=True)
    # process_dataset(test_data, llm_model, args.pointcloud_size, test_dataset_save_path)

    # print("Testing dataset generation completed.")

    # ------------------ END ------------------
    print("All dataset generation completed.")

# Run the generation pipeline only when this file is executed as a script,
# forwarding the command-line arguments (without the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
    