import os
import shutil
import torch
import numpy as np
import json_tricks as json
from argparse import ArgumentParser
from tqdm import tqdm
from collections import Counter, defaultdict
import mitsuba_render_func as mitsuba_render

from arrange_tools import arrange_room_mesh

def render_transformed_obj(scene_folder, save_folder, render_img_name=None, keep_source_file=False) -> np.array:
    """
    Render all .obj files found in ``scene_folder`` with mitsuba.

    A temporary ``render.xml`` scene description is written into ``scene_folder``,
    handed to ``mitsuba_render.mitsuba_render`` and always deleted afterwards,
    even if rendering raises.

    Args:
        scene_folder (str): the folder containing the scene files (.obj etc.)
        save_folder (str): the folder passed to the renderer as its output folder;
            created if it does not exist
        render_img_name (str, optional): the name of the rendered image
        keep_source_file (bool, optional): if False, delete ``scene_folder`` and
            everything under it after rendering -- including ``save_folder`` if it
            is nested inside ``scene_folder``. Defaults to False.

    Returns:
        Whatever ``mitsuba_render.mitsuba_render`` returns for the scene (the
        ``np.array`` annotation is kept from the original; the concrete type is
        defined by that helper -- TODO confirm).
    """
    os.makedirs(save_folder, exist_ok=True)

    render_xml = os.path.join(scene_folder, "render.xml")
    obj_paths = [os.path.join(scene_folder, obj_file) for obj_file in mitsuba_render.get_obj_files(scene_folder)]
    mitsuba_render.save_render_xml(obj_paths, render_xml, scene_folder)
    try:
        rendered_img = mitsuba_render.mitsuba_render(render_xml, save_folder, render_img_name=render_img_name)  # no save render
    finally:
        # render.xml is a temporary artifact; never leave it behind, even on failure
        os.remove(render_xml)

    # remove the source folder
    if not keep_source_file:
        # rmtree instead of per-entry os.remove: scene_folder can contain
        # sub-directories (e.g. a nested save_folder), on which os.remove fails
        shutil.rmtree(scene_folder)

    return rendered_img

parser = ArgumentParser()

# General experiment settings
parser.add_argument("--json_dir", type=str, default="json_dir", help="directory containing the json files generated by the LLM")
parser.add_argument("--pred_output_dir", type=str, default="pred_output_dir", help="output directory for the pred rendered results")
parser.add_argument("--gt_output_dir", type=str, default="gt_output_dir", help="output directory for the gt rendered results")
parser.add_argument("--source_dir", type=str, default="source_dir", help="directory containing the source files")
parser.add_argument("--model3d_base_dir", type=str, default="model3d_base_dir", help="directory containing the 3D models")
parser.add_argument("--pred_rendered_img_dir", type=str, default="rendered_img_dir", help="directory to save the pred rendered images")
parser.add_argument("--gt_rendered_img_dir", type=str, default="rendered_img_dir", help="directory to save the gt rendered images")
parser.add_argument("--bbox_dir", type=str, default="trimesh_bbox_dir", help="directory to save the trimesh bounding boxes")
# Optional subset selection, e.g. for debugging a single answer without editing the code
parser.add_argument("--answer_files", type=str, nargs="*", default=None, help="only process these file names from <json_dir>/Answer (default: all .json files)")

args = parser.parse_args()

# create dirs (pred/gt image dirs may share the same default; file names are prefixed pred_/gt_)
for out_dir in (
    args.pred_output_dir,
    args.gt_output_dir,
    args.pred_rendered_img_dir,
    args.gt_rendered_img_dir,
    args.bbox_dir,
):
    os.makedirs(out_dir, exist_ok=True)

llm_answers_dir = os.path.join(args.json_dir, "Answer")
# only .json files are LLM answers; skip stray entries (OS/editor metadata etc.) that json.load would choke on
llm_answers = sorted(name for name in os.listdir(llm_answers_dir) if name.lower().endswith(".json"))
if args.answer_files:
    wanted = set(args.answer_files)
    llm_answers = [name for name in llm_answers if name in wanted]

def _align_transformations(furniture_list, transformation_list, source_file):
    """Return the LLM transforms reordered to match ``furniture_list`` one-to-one.

    Labels are matched with multiplicity: every label occurrence in
    ``furniture_list`` that has no corresponding LLM entry gets a default
    transform (translation (0, 0, 0), rotation 0). LLM entries whose label is
    not needed by ``furniture_list`` are dropped. Entries with the same label
    are assigned in their original LLM order. The input list is not modified.

    Args:
        furniture_list: furniture labels from the source file, in source order.
        transformation_list: list of dicts from the LLM answer; each must have
            a "label" key.
        source_file: only used to prefix the diagnostic messages.

    Returns:
        A new list of transform dicts whose labels equal ``furniture_list``.
    """
    transforms = list(transformation_list)
    llm_labels = [t["label"] for t in transforms]
    if llm_labels == furniture_list:
        return transforms

    print(f"{source_file}: Furniture list {llm_labels} in llm answer and source file {furniture_list} do not match. Reordering the furniture_list_llm")
    if len(furniture_list) != len(llm_labels):
        print(f"{source_file}: Furniture list in llm answer and source file do not match. The length of the two lists are different")

    # Counter difference keeps multiplicity: two source "chair"s vs one LLM
    # "chair" yields one missing chair. A set difference would report nothing
    # and the reorder below would then pop from an exhausted label list.
    missing = Counter(furniture_list) - Counter(llm_labels)
    for label in missing.elements():
        transforms.append({"label": label, "translation": {"x": 0, "y": 0, "z": 0}, "rotation": 0})

    positions = defaultdict(list)
    for idx, t in enumerate(transforms):
        positions[t["label"]].append(idx)
    return [transforms[positions[label].pop(0)] for label in furniture_list]


def _to_pose(transform):
    """Flatten one LLM transform dict into ``[rotation, x, y, z]``; missing or null values default to 0."""
    translation = transform.get("translation") or {}
    return [
        transform.get("rotation", 0),
        translation.get("x", 0),
        translation.get("y", 0),
        translation.get("z", 0),
    ]


def _arrange_and_render(output_dir, scene_id, raw_model_path_list, transformation_list):
    """Arrange the scene's meshes under ``output_dir`` with the given poses and render it.

    Returns:
        ``(rendered_img, bbox_list)`` -- the image from ``render_transformed_obj``
        and the bounding boxes returned by ``arrange_room_mesh``.
    """
    scene_folder, bbox_list = arrange_room_mesh(
        output_dir=output_dir,
        scene_id=scene_id,
        model3d_base_dir=args.model3d_base_dir,
        raw_model_path_list=raw_model_path_list,
        transformation_list=transformation_list,
        scene_centroid=None,
        texture_image_path_list=None,
    )
    rendered_img = render_transformed_obj(
        scene_folder=scene_folder,
        save_folder=os.path.join(scene_folder, "rendered_img"),
        render_img_name="rendered_img",
        keep_source_file=True,
    )
    return rendered_img, bbox_list


# iterate through all llm answers and render each scene
for answer_name in tqdm(llm_answers):
    with open(os.path.join(llm_answers_dir, answer_name), "r", encoding="utf-8") as f:
        answer = json.load(f)

    # source file contains information about the scene
    load_file = answer["load_file"]
    source_file = os.path.join(args.source_dir, load_file)
    # NOTE: allow_pickle=True unpickles arbitrary objects -- only point --source_dir at trusted data
    data = np.load(source_file, allow_pickle=True).item()
    source_file_idx = load_file.split(".")[0]

    # load necessary information from source file
    scene_id = data["scene_id"]
    all_furni = data["all_furni"]
    furniture_list = all_furni["label"]
    raw_model_path_list = all_furni["raw_model_path"]
    # NOTE(review): the source file also has all_furni["texture_image_path"], but both renders
    # pass texture_image_path_list=None -- confirm textures are meant to be skipped.

    # match the LLM transforms to the source furniture order (filling gaps with defaults)
    transformations = _align_transformations(furniture_list, answer["output_structure"]["transformation"], source_file)
    pred_poses = [_to_pose(t) for t in transformations]

    # render the pred scene
    pred_rendered_img, pred_bbox_list = _arrange_and_render(args.pred_output_dir, scene_id, raw_model_path_list, pred_poses)

    # render the gt scene, with poses taken straight from the source file
    gt_rotation = all_furni["z_angle"]
    gt_poses = [[gt_rotation[i], t[0], t[1], t[2]] for i, t in enumerate(all_furni["translation"])]
    gt_rendered_img, gt_bbox_list = _arrange_and_render(args.gt_output_dir, scene_id, raw_model_path_list, gt_poses)

    # save the rendered images
    mitsuba_render.mi_write_img(os.path.join(args.pred_rendered_img_dir, f"pred_rendered_img_{source_file_idx}.png"), pred_rendered_img)
    mitsuba_render.mi_write_img(os.path.join(args.gt_rendered_img_dir, f"gt_rendered_img_{source_file_idx}.png"), gt_rendered_img)

    # save the trimesh_bboxes as .pt in another folder
    torch.save(
        {"pred_bbox": pred_bbox_list, "gt_bbox": gt_bbox_list},
        os.path.join(args.bbox_dir, f"bbox_{source_file_idx}.pt"),
    )