#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render_hair,render_hair_weight_fine
import torchvision
from utils.general_utils import safe_state
from utils.image_utils import vis_orient
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams, get_combined_args, ModelHiddenParams, TextureHiddenParams
from scene import Scene, GaussianModel, GaussianModelCurves
import pickle as pkl
import yaml
import math
import shutil
import numpy as np
from plyfile import PlyData, PlyElement

def construct_list_of_attributes():
    """Return the PLY vertex property names: position (x, y, z) followed by normal (nx, ny, nz).

    A new list is returned on every call.
    """
    position = ['x', 'y', 'z']
    normal = ['n' + axis for axis in position]
    return position + normal

def render_set(model_path, name, iteration, views, gaussians, gaussians_hair, pipeline, background, scene_suffix,num_strands=50_000):
    """Render every view in ``views`` and write the per-view outputs to disk.

    For each view the hair strands are re-initialised at ``view.time_step`` and
    rendered together with the head Gaussians via ``render_hair_weight_fine``.
    Outputs go under ``<model_path>/<name><scene_suffix>/ours_<iteration>/``:

    * ``renders``, ``hair_masks``, ``head_masks``, ``orients``, ``orients_vis``,
      ``orient_confs_vis``: PNGs in a per-camera sub-directory;
    * ``orient_confs``: the masked orientation confidence saved with
      ``torch.save`` (``.pth``) in a per-camera sub-directory;
    * ``strands``: the current strand points as a PLY point cloud (zero
      normals), written flat, i.e. not split per camera.

    Args:
        model_path: Root output directory.
        name: Split name (e.g. ``"train"``/``"test"``) used in the output dir name.
        iteration: Value used in the ``ours_<iteration>`` directory name.
        views: Indexable sequence of cameras; each must expose ``time_step``,
            ``image_name`` and ``image_path``.
        gaussians: Head Gaussian model forwarded to the renderer.
        gaussians_hair: Strand hair model; re-initialised for every view.
        pipeline: Pipeline parameters forwarded to the renderer.
        background: Background tensor forwarded to the renderer.
        scene_suffix: Suffix appended to ``name`` for the output directory.
        num_strands: Number of strands sampled when re-initialising the hair.
    """
    out_root = os.path.join(model_path, f"{name}{scene_suffix}", "ours_{}".format(iteration))
    # Per-camera output sub-directories.
    image_dirs = ("renders", "hair_masks", "head_masks", "orients",
                  "orients_vis", "orient_confs", "orient_confs_vis")
    strands_path = os.path.join(out_root, "strands")
    for sub in image_dirs:
        makedirs(os.path.join(out_root, sub), exist_ok=True)
    makedirs(strands_path, exist_ok=True)

    render_state = 'fine'
    ply_fields = construct_list_of_attributes()
    ply_dtype = [(field, 'f4') for field in ply_fields]

    # Index-based loop kept on purpose: ``views`` only needs __len__/__getitem__.
    for idx in tqdm(range(len(views)), desc="Rendering progress"):
        view = views[idx]
        # Re-sample the strands at this view's time step before rendering.
        gaussians_hair.initialize_gaussians_hair(num_strands, time_step=view.time_step)
        output = render_hair_weight_fine(view, gaussians, gaussians_hair, pipeline, background, render_state)

        image = output["render"]
        hair_mask = output["mask"][:1]   # first mask channel
        head_mask = output["mask"][1:]   # remaining mask channel(s)
        orient_angle = output["orient_angle"]
        orient_angle_vis = vis_orient(orient_angle, hair_mask)
        orient_conf = output["orient_conf"] * hair_mask
        # 1 - 1/(c + 1) squashes non-negative confidences into [0, 1) for display.
        orient_conf_vis = vis_orient(orient_angle, 1 - 1 / (orient_conf + 1))

        basename = os.path.basename(view.image_name).split('.')[0]
        # Camera name = name of the directory that contains the source image.
        cam_name = view.image_path.split('/')[-2]

        cam_dirs = {sub: os.path.join(out_root, sub, cam_name) for sub in image_dirs}
        for cam_dir in cam_dirs.values():
            makedirs(cam_dir, exist_ok=True)

        # Export the current strand points as a PLY point cloud with zero normals.
        xyz = gaussians_hair._pts.detach().reshape(-1, 3).cpu().numpy()
        attributes = np.concatenate((xyz, np.zeros_like(xyz)), axis=1)
        elements = np.empty(xyz.shape[0], dtype=ply_dtype)
        # Column-wise (vectorised) fill instead of building one Python tuple per point.
        for col, field in enumerate(ply_fields):
            elements[field] = attributes[:, col]
        PlyData([PlyElement.describe(elements, 'vertex')]).write(os.path.join(strands_path, basename + ".ply"))

        png_name = basename + ".png"
        torchvision.utils.save_image(image, os.path.join(cam_dirs["renders"], png_name))
        torchvision.utils.save_image(hair_mask, os.path.join(cam_dirs["hair_masks"], png_name))
        torchvision.utils.save_image(head_mask, os.path.join(cam_dirs["head_masks"], png_name))
        torchvision.utils.save_image(orient_angle, os.path.join(cam_dirs["orients"], png_name))
        torchvision.utils.save_image(orient_angle_vis, os.path.join(cam_dirs["orients_vis"], png_name))
        torch.save(orient_conf, os.path.join(cam_dirs["orient_confs"], basename + ".pth"))
        torchvision.utils.save_image(orient_conf_vis, os.path.join(cam_dirs["orient_confs_vis"], png_name))

@torch.no_grad()
def render_sets(dataset, defor, texture_hidden, optimizer, optimizer_hair, iteration : int, pipeline : PipelineParams, model_hair_path : str, pointcloud_path_head : str, checkpoint_hair : str, checkpoint_curves : str, skip_train : bool, skip_test : bool, scene_suffix : str):
    """Load the head and hair models, then render the train and/or test splits.

    Args:
        dataset: Extracted ``ModelParams``; reads ``source_path``, ``flame_mesh_dir``,
            ``start_time_step``, ``num_time_steps`` and ``white_background``.
        defor: Extracted ``ModelHiddenParams`` forwarded to ``GaussianModelCurves``.
        texture_hidden: Extracted ``TextureHiddenParams`` forwarded to ``GaussianModelCurves``.
        optimizer: Extracted ``OptimizationParams`` for the head Gaussians' ``training_setup``.
        optimizer_hair: Hair-strand configuration used to build and restore the curves model.
        iteration: Passed to ``Scene`` as ``load_iteration`` (CLI default is -1).
        pipeline: Pipeline parameters forwarded to the renderer.
        model_hair_path: Root directory for the rendered outputs.
        pointcloud_path_head: Head point cloud path forwarded to ``Scene``.
        checkpoint_hair: ``torch.load``-able ``(params, _)`` tuple used by ``create_from_pcd``.
        checkpoint_curves: ``torch.load``-able ``(params, _)`` tuple restored into the curves model.
        skip_train: Skip rendering the training cameras.
        skip_test: Skip rendering the test cameras.
        scene_suffix: Suffix appended to the split name in the output directory.
    """
    # Number of strands sampled whenever the hair geometry is (re)initialised.
    num_strands = 20_000

    gaussians = GaussianModel(3)
    # Use the hair config handed in by the caller (not a script-level global),
    # so this function also works when the module is imported.
    gaussians_hair = GaussianModelCurves(dataset.source_path, dataset.flame_mesh_dir, optimizer_hair, texture_hidden, defor, 3, dataset.start_time_step, dataset.num_time_steps)
    # Honour the requested iteration instead of always loading -1.
    scene = Scene(dataset, gaussians, pointcloud_path=pointcloud_path_head, load_iteration=iteration)
    gaussians.training_setup(optimizer)

    # Initialize hair gaussians from the hair checkpoint, then restore the
    # strand (curves) parameters on top of it.
    model_params, _ = torch.load(checkpoint_hair)
    gaussians_hair.create_from_pcd(dataset.source_path, model_params, 20_000, gaussians.spatial_lr_scale)
    model_params, _ = torch.load(checkpoint_curves)
    gaussians_hair.restore(model_params, optimizer_hair)
    gaussians_hair.use_sds = False
    gaussians_hair.initialize_gaussians_hair(num_strands, time_step=0)
    # Fixed deformation scales for rendering (attribute names as defined by GaussianModelCurves).
    gaussians_hair.deformaton_pts_scale = 1e-2
    gaussians_hair.deformaton_color_scale = 1
    gaussians_hair.deformaton_hf_pts_scale = 1e-7
    gaussians_hair.deformaton_hf_color_scale = 1e-3
    gaussians_hair.deformaton_coarse_scale = 1e-1

    # Precompute and cache the attributes of the head Gaussians, i.e. those whose
    # first label channel is below 0.6, detached from the autograd graph.
    gaussians.mask_precomp = gaussians.get_label()[..., 0] < 0.6
    gaussians.points_mask_head_indices = gaussians.mask_precomp.nonzero(as_tuple=True)[0]
    gaussians.xyz_precomp = gaussians.get_xyz()[gaussians.mask_precomp].detach()
    gaussians.opacity_precomp = gaussians.get_opacity()[gaussians.mask_precomp].detach()
    gaussians.scaling_precomp = gaussians.get_scaling()[gaussians.mask_precomp].detach()
    gaussians.rotation_precomp = gaussians.get_rotation()[gaussians.mask_precomp].detach()
    gaussians.cov3D_precomp = gaussians.get_covariance(1.0)[gaussians.mask_precomp].detach()
    gaussians.shs_view = gaussians.get_features()[gaussians.mask_precomp].detach().transpose(1, 2).view(-1, 3, (gaussians.max_sh_degree + 1)**2)

    # 10-channel background: the first three entries are 1 for a white background;
    # all other channels are always 0.
    # NOTE(review): presumably matches the channel layout of the hair renderer — confirm.
    bg_color = [1, 1, 1] + [0] * 7 if dataset.white_background else [0] * 10
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    if not skip_train:
        render_set(model_hair_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, gaussians_hair, pipeline, background, scene_suffix, num_strands)

    if not skip_test:
        render_set(model_hair_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, gaussians_hair, pipeline, background, scene_suffix, num_strands)


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser)
    optimizer = OptimizationParams(parser)
    pipeline = PipelineParams(parser)
    hp = ModelHiddenParams(parser)
    tp = TextureHiddenParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--data_dir", type=str, default = None)
    parser.add_argument("--model_hair_path", type=str, default = None)
    parser.add_argument("--hair_conf_path", type=str, default = None)
    parser.add_argument("--checkpoint_hair", type=str, default = None)
    parser.add_argument("--checkpoint_curves", type=str, default = None)
    parser.add_argument("--scene_suffix", default="", type=str)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--pointcloud_path_head", type=str, default = None)
    parser.add_argument("--configs", type=str, default = "")
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    if args.hair_conf_path is None:
        parser.error("--hair_conf_path is required")

    # Configuration of hair strands. The DATASET_TYPE placeholder is substituted
    # in the raw YAML text *before* parsing, so every value keeps its YAML type.
    # (Re-parsing str(dict) would turn e.g. None into the string 'None' and
    # 1e-05 into a string.) yaml.Loader is kept for compatibility with configs
    # that use Python tags — only load trusted config files.
    with open(args.hair_conf_path, 'r') as f:
        opt_hair = yaml.load(f.read().replace('DATASET_TYPE', 'monocular'), Loader=yaml.Loader)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    if args.configs:
        import mmcv
        from utils.params_utils import merge_hparams
        config = mmcv.Config.fromfile(args.configs)
        args = merge_hparams(args, config)

    dataset = model.extract(args)
    if args.data_dir is not None:
        dataset.source_path = args.data_dir

    # Pass the (possibly overridden) dataset itself; re-extracting from args
    # would discard the --data_dir override.
    render_sets(dataset, hp.extract(args), tp.extract(args), optimizer.extract(args), opt_hair, args.iteration, pipeline.extract(args), args.model_hair_path, args.pointcloud_path_head, args.checkpoint_hair, args.checkpoint_curves, args.skip_train, args.skip_test, args.scene_suffix)

    # Clean extra files to conserve space
    # shutil.rmtree(f'{args.model_hair_path}/train/ours_30000/hair_masks')
    # shutil.rmtree(f'{args.model_hair_path}/train/ours_30000/head_masks')
    # shutil.rmtree(f'{args.model_hair_path}/train/ours_30000/orient_confs')
    # shutil.rmtree(f'{args.model_hair_path}/train/ours_30000/orient_confs_vis')
    # shutil.rmtree(f'{args.model_hair_path}/train/ours_30000/orients')
    # shutil.rmtree(f'{args.model_hair_path}/train/ours_30000/orients_vis')