import os
import trimesh
import torch
import shutil
from torch_geometric.data import Data, Batch
from torch import Tensor
from PIL import Image
import numpy as np
from utils import export_scene
import mitsuba_render_func as mitsuba_render


def arrange_room_mesh(output_dir, scene_id, model3d_base_dir, raw_model_path_list: list, transformation_list: Tensor, scene_centroid, texture_image_path_list=None):
    """Load each object mesh, apply its pose, and export the arranged scene.

    Arrangement formula for one object::

        origin_pos + translation (furni_centroid - scene_centroid)
            = centered_model (origin_pos - furni_loc) + furni_loc
              + translation (furni_centroid - scene_centroid)

    During learning the centered model is the network input and the network
    predicts the translation, so the mesh has to be back at its original
    position before the translation is applied.

    Args:
        output_dir: directory in which ``transformed_{scene_id}`` is created.
        scene_id: suffix used in the output folder name.
        model3d_base_dir: base directory that every mesh / texture path is
            joined onto.
        raw_model_path_list: per-object mesh paths relative to
            ``model3d_base_dir``.
        transformation_list: one row per object, laid out as
            ``[theta, x, y, z, *scale]``. ``theta`` is a rotation about the
            y axis in radians and ``[x, y, z]`` is added to every vertex.
            The trailing scale entries are currently NOT applied.
        scene_centroid: currently unused (the ``- scene_centroid`` shift is
            disabled); kept for interface compatibility.
        texture_image_path_list: optional per-object texture paths relative
            to ``model3d_base_dir``; each is attached as the mesh's material
            image.

    Returns:
        ``(path_to_objs, bboxes)``: the folder the meshes were exported to,
        and for each object ``[[xmin, ymin, zmin], [xmax, ymax, zmax]]`` of
        the axis-aligned bounding box of the transformed mesh.
    """
    trimesh_meshes = []
    trimesh_bboxes = []
    for i, raw_model_path in enumerate(raw_model_path_list):
        # Load the mesh in its original (un-centered) position.
        tr_mesh = trimesh.load(os.path.join(model3d_base_dir, raw_model_path), force="mesh")

        if texture_image_path_list is not None:
            tr_mesh.visual.material.image = Image.open(
                os.path.join(model3d_base_dir, texture_image_path_list[i])
            )

        transformation = transformation_list[i]
        theta = transformation[0]
        translation = transformation[1:4]
        # transformation[4:] carries the scale; scaling is currently disabled.

        # With row vectors, ``v @ R`` for this R equals the standard
        # right-handed rotation R_y(theta) applied to column vectors.
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        R = np.array(
            [[cos_t, 0.0, -sin_t],
             [0.0, 1.0, 0.0],
             [sin_t, 0.0, cos_t]],
            dtype=np.float64,
        )
        tr_mesh.vertices[...] = tr_mesh.vertices.dot(R)

        # Translate the mesh (the scene-centroid shift is intentionally off).
        tr_mesh.vertices[...] = tr_mesh.vertices + translation
        trimesh_meshes.append(tr_mesh)

        # bounds is a (2, 3) array; take element-wise min/max so row 0 is
        # always the minimum corner and row 1 the maximum corner.
        bounds = tr_mesh.bounding_box.bounds
        trimesh_bboxes.append([
            np.minimum(bounds[0], bounds[1]).tolist(),
            np.maximum(bounds[0], bounds[1]).tolist(),
        ])

    # Export all transformed meshes into one folder.
    path_to_objs = os.path.join(output_dir, "transformed_{}".format(scene_id))
    # exist_ok avoids the race between an existence check and creation.
    os.makedirs(path_to_objs, exist_ok=True)
    export_scene(path_to_objs, trimesh_meshes)
    return path_to_objs, trimesh_bboxes
        
        
def render_transformed_obj(scene_folder, keep_source_file=False, output_image_path=None) -> np.ndarray:
    """Render all OBJ files found in ``scene_folder`` with Mitsuba.

    A temporary ``render.xml`` scene description is written into
    ``scene_folder``. It is always removed afterwards, even when rendering
    raises.

    Args:
        scene_folder (str): folder containing the exported scene files.
        keep_source_file (bool, optional): when False (default),
            ``scene_folder`` is deleted after rendering, including on failure,
            so a reused cache folder does not keep stale meshes.
        output_image_path (str, optional): forwarded unchanged to
            ``mitsuba_render.mitsuba_render``.

    Returns:
        Whatever ``mitsuba_render.mitsuba_render`` returns. The annotation
        assumes this is the rendered image array (TODO confirm).
    """
    print(output_image_path)
    xml_path = os.path.join(scene_folder, "render.xml")
    obj_files = mitsuba_render.get_obj_files(scene_folder)
    try:
        mitsuba_render.save_render_xml(
            [os.path.join(scene_folder, obj_file) for obj_file in obj_files],
            xml_path,
            scene_folder,
        )
        rendered_img = mitsuba_render.mitsuba_render(xml_path, output_image_path)
    finally:
        # Clean up even if XML generation or rendering failed.
        if os.path.exists(xml_path):
            os.remove(xml_path)
        if not keep_source_file:
            shutil.rmtree(scene_folder)

    return rendered_img

def arrange_3d_to_2d(pos_data: torch.Tensor):
    """Placeholder for converting 3D positions into a 2D arrangement.

    Not implemented yet: the input is ignored and ``None`` is returned.
    """
    return None

def arrange_2d_to_3d(pos_data: torch.Tensor, y_axis_data: torch.Tensor):
    """Placeholder for lifting a 2D arrangement back to 3D using y-axis data.

    Not implemented yet: both inputs are ignored and ``None`` is returned.
    """
    return None

# def pyg_data_to_list(pyg_data, batch_code):
#     pass

def render_frames(pos_states: list, gt_pyg_data: Data, irregular_data, cache_dir, model3d_base_dir, y_axis_data=None, keep_source_file=False):
    """Render every predicted position state of a batch as a sequence of frames.

    Args:
        pos_states: sequence of tensors; each one replaces ``.y`` of a copy of
            ``gt_pyg_data`` to form the prediction for one frame.
        gt_pyg_data: batched PyG data holding the ground truth. It is not
            modified.
        irregular_data: per-batch dict of non-tensor data, forwarded to
            ``render_one_batch_frame``.
        cache_dir: directory for intermediate render files.
        model3d_base_dir: base directory of the 3D model files.
        y_axis_data: must be None; a non-None value (2D arrangement) raises
            ``NotImplementedError``.
        keep_source_file: keep the exported intermediate files if True.

    Returns:
        Six lists with one entry per frame: prediction renders, ground-truth
        renders, text descriptions, furniture relations, prediction bboxes,
        and ground-truth bboxes. Ground truth is rendered only for the first
        frame; later frames hold ``None`` values there.
    """
    # A non-None y_axis_data means 2D arrangement, which is not supported.
    if y_axis_data is not None:
        raise NotImplementedError("2D arrangement is not implemented yet.")
    pred_batch_frames = []
    gt_batch_frames = []
    text_des_list = []
    furni_rel_list = []
    pred_bbox_list = []
    gt_bbox_list = []
    # Clone first, then detach the copy. PyG's Data.detach() replaces the
    # tensors on the object it is called on, so detaching first would also
    # modify the caller's gt_pyg_data.
    pred_pyg_data = gt_pyg_data.clone().detach()
    for frame_idx, pos_state in enumerate(pos_states):
        pred_pyg_data.y = pos_state.detach().clone()
        (pred_render_batch, gt_render_batch, text_des_batch,
         furni_rel_batch, pred_bbox_batch, gt_bbox_batch) = render_one_batch_frame(
            pred_pyg_data, gt_pyg_data, irregular_data, cache_dir, model3d_base_dir,
            frame_idx == 0,  # ground truth is rendered only for the first frame
            keep_source_file=keep_source_file,
        )
        pred_batch_frames.append(pred_render_batch)
        gt_batch_frames.append(gt_render_batch)
        text_des_list.append(text_des_batch)
        furni_rel_list.append(furni_rel_batch)
        pred_bbox_list.append(pred_bbox_batch)
        gt_bbox_list.append(gt_bbox_batch)
    return pred_batch_frames, gt_batch_frames, text_des_list, furni_rel_list, pred_bbox_list, gt_bbox_list

def render_one_batch_frame(pred_pyg_data: Data, gt_pyg_data: Data, irregular_data, cache_dir, model3d_base_dir, need_gt, keep_source_file=False):
    """Render every scene of one batched frame.

    Both batches are split with ``Batch.to_data_list``. Each scene is rendered
    via ``get_one_render`` using the matching entries of ``irregular_data``
    (keys ``raw_model_path``, ``texture_image_path``, ``scene_centroid``,
    ``text_des``, ``furni_rel``).

    Returns:
        Six per-scene lists: prediction renders, ground-truth renders, text
        descriptions, furniture relations, prediction bboxes, and ground-truth
        bboxes (gt entries are None when ``need_gt`` is False).
    """
    print("Your proc_pyg_data is: ", gt_pyg_data)
    pred_pyg_data_list = Batch.to_data_list(pred_pyg_data)
    gt_pyg_data_list = Batch.to_data_list(gt_pyg_data)
    # Loop-invariant, so it is converted once rather than once per scene.
    # NOTE(review): unlike the other irregular_data fields, this is not
    # indexed by scene, so every scene receives the whole batch's value. That
    # is harmless while arrange_room_mesh ignores scene_centroid; confirm the
    # collated layout before indexing it per scene.
    scene_centroid = np.array(irregular_data["scene_centroid"])

    pred_render_batch, gt_render_batch = [], []
    text_des_batch, furni_rel_batch = [], []
    pred_bbox_batch, gt_bbox_batch = [], []
    for idx, gt_one_pyg_data in enumerate(gt_pyg_data_list):
        pred_render, gt_render, text_des, furni_rel, pred_bbox, gt_bbox = get_one_render(
            pred_pyg_data_list[idx],
            gt_one_pyg_data,
            irregular_data["raw_model_path"][idx],
            irregular_data["texture_image_path"][idx],
            scene_centroid,
            irregular_data["text_des"][idx],
            irregular_data["furni_rel"][idx],
            cache_dir,
            model3d_base_dir,
            need_gt,
            keep_source_file=keep_source_file,
        )
        pred_render_batch.append(pred_render)
        gt_render_batch.append(gt_render)
        text_des_batch.append(text_des)
        furni_rel_batch.append(furni_rel)
        pred_bbox_batch.append(pred_bbox)
        gt_bbox_batch.append(gt_bbox)

    return pred_render_batch, gt_render_batch, text_des_batch, furni_rel_batch, pred_bbox_batch, gt_bbox_batch

def _transformation_array(one_pyg_data) -> np.ndarray:
    """Concatenate ``one_pyg_data.y`` and ``one_pyg_data.scale`` on the last axis as NumPy."""
    pose = one_pyg_data.y.detach().cpu().numpy()
    scale = one_pyg_data.scale.detach().cpu().numpy()
    return np.concatenate([pose, scale], axis=-1)


def get_one_render(pred_one_pyg_data: Data,
                   gt_one_pyg_data: Data, 
                   raw_model_path_list, 
                   texture_image_path_list,
                   scene_centroid,
                   text_des,
                   furni_rel,
                   cache_dir,
                   model3d_base_dir, 
                   need_gt,
                   keep_source_file=False):
    """Render the predicted (and optionally ground-truth) arrangement of one scene.

    Intermediate files are written to ``cache_dir/gt/transformed_0`` and
    ``cache_dir/pred/transformed_0``. They are removed after rendering unless
    ``keep_source_file`` is True. Because these folder names are fixed, only
    one process may use a given ``cache_dir`` at a time; give each process its
    own ``cache_dir`` for multi-process rendering.

    Returns:
        ``(pred_render, gt_render, text_des, furni_rel, pred_bbox, gt_bbox)``.
        ``gt_render`` and ``gt_bbox`` are None when ``need_gt`` is False.
        ``text_des`` and ``furni_rel`` are passed through unchanged.
    """
    if need_gt:
        gt_transformation_list = _transformation_array(gt_one_pyg_data)
        scene_folder, gt_bbox = arrange_room_mesh(os.path.join(cache_dir, "gt"), 0, model3d_base_dir, raw_model_path_list,
                                scene_centroid=scene_centroid, transformation_list=gt_transformation_list,
                                texture_image_path_list=texture_image_path_list)
        gt_render = render_transformed_obj(scene_folder, keep_source_file=keep_source_file)
    else:
        gt_render = None
        gt_bbox = None

    pred_transformation_list = _transformation_array(pred_one_pyg_data)
    scene_folder, pred_bbox = arrange_room_mesh(os.path.join(cache_dir, "pred"), 0, model3d_base_dir, raw_model_path_list,
                              scene_centroid=scene_centroid, transformation_list=pred_transformation_list,
                              texture_image_path_list=texture_image_path_list)
    pred_render = render_transformed_obj(scene_folder, keep_source_file=keep_source_file)
    return pred_render, gt_render, text_des, furni_rel, pred_bbox, gt_bbox


