#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import torch
import torch.nn.functional as F
from random import randint
from utils.loss_chamfer_utils import chamfer_distance
from utils.loss_utils import l1_loss, ce_loss, or_loss, ssim
from gaussian_renderer import render_hair, network_gui, render_hair_weight, render_hair_weight_fine, render_hair_weight_wo_active,render_hair_weight_sparse,render_hair_weight_fine_test
import sys
import yaml
from scene import Scene, GaussianModel, GaussianModelCurves
from utils.general_utils import safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr, vis_orient,psnr_mask
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams, ModelHiddenParams, TextureHiddenParams
from torch.utils.data import DataLoader
import pickle as pkl
from utils.general_utils import build_rotation, custom_collate_fn, DynamicFrameSampler
import time
from kaolin.metrics.pointcloud import sided_distance
import numpy as np
from collections import deque

try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False

def print_autograd_graph(fn, seen=None, indent=0):
    """Print the autograd graph rooted at ``fn`` as an indented tree.

    Each node is printed as its class name, indented two spaces per depth
    level starting at ``indent``. Children are taken from
    ``node.next_functions`` (pairs whose first element is the child node);
    ``None`` children are skipped. A node is printed only the first time it
    is reached. ``seen`` (if given) is the set used for that bookkeeping and
    is updated in place. The root itself is not added to ``seen``.

    The walk uses an explicit stack of child iterators instead of recursion,
    so deep graphs do not hit Python's recursion limit, while the visit
    order stays depth-first pre-order.
    """
    visited = set() if seen is None else seen
    print(f"{'  ' * indent}{type(fn).__name__}")
    # Each frame: (iterator over a node's remaining children, depth of those children)
    frames = [(iter(fn.next_functions), indent + 1)]
    while frames:
        children, depth = frames[-1]
        for child, _ in children:
            if child is None or child in visited:
                continue
            visited.add(child)
            print(f"{'  ' * depth}{type(child).__name__}")
            # Descend into this child before looking at its siblings.
            frames.append((iter(child.next_functions), depth + 1))
            break
        else:
            # All children of this node handled: go back up one level.
            frames.pop()

def _grad_or_zeros(param):
    """Return ``param.grad``, or a zero tensor shaped like ``param`` when it has no gradient."""
    return torch.zeros_like(param) if param.grad is None else param.grad


def _collect_hair_densification_stats(gaussians, gaussians_hair, cam_id):
    """Push the current hair-parameter gradients into the per-camera densification stats.

    One row per hair Gaussian is built as
    ``[pts_wo_root_static grad | flattened features_dc+features_rest grad | orient_conf grad]``.
    An all-zero block with one row per entry of ``gaussians.mask_precomp``
    (the head model) is placed in front of it. The result is handed to
    ``gaussians_hair.add_densification_stats_average(cam_id, ...)``.

    Any of these parameters whose ``.grad`` is ``None`` contributes zeros
    instead of raising.
    """
    n_hair = gaussians_hair._pts_wo_root_static.shape[0]
    n_head = gaussians.mask_precomp.shape[0]
    features_grad = torch.cat(
        [_grad_or_zeros(gaussians_hair._features_dc), _grad_or_zeros(gaussians_hair._features_rest)],
        dim=1,
    ).view(n_hair, -1)
    hair_stats = torch.cat(
        [
            _grad_or_zeros(gaussians_hair._pts_wo_root_static),
            features_grad,
            _grad_or_zeros(gaussians_hair._orient_conf),
        ],
        dim=1,
    )
    head_stats = torch.zeros(n_head, hair_stats.shape[1], device="cuda")
    gaussians_hair.add_densification_stats_average(cam_id, torch.cat([head_stats, hair_stats], dim=0))


def _load_optimized_cameras(dataset, scene):
    """Overwrite the train cameras' residual parameters from a saved camera pickle.

    Reads ``<scene.model_path>/cameras/<scene.loaded_iter>.pkl``, which is
    unpacked into ``(rotation, translation, fov)`` mappings keyed by
    ``camera.image_name``. Rotation and translation residuals are always
    restored. The FoV residual is restored only when
    ``dataset.trainable_intrinsics`` is set.
    """
    print(f'Loading optimized cameras from iter {scene.loaded_iter}')
    cam_file = scene.model_path + "/cameras/" + str(scene.loaded_iter) + ".pkl"
    # NOTE(review): pickle.load can execute arbitrary code; only load camera files produced by this pipeline.
    with open(cam_file, 'rb') as f:
        params_cam_rotation, params_cam_translation, params_cam_fov = pkl.load(f)
    for k in scene.train_cameras.keys():
        for camera in scene.train_cameras[k]:
            camera._rotation_res.data = params_cam_rotation[camera.image_name]
            camera._translation_res.data = params_cam_translation[camera.image_name]
            if dataset.trainable_intrinsics:
                camera._fov_res.data = params_cam_fov[camera.image_name]


def _precompute_head_gaussians(gaussians):
    """Cache detached attributes of the head Gaussians on ``gaussians``.

    A Gaussian counts as "head" when channel 0 of ``get_label()`` is below 0.6.
    The mask, the head indices and the masked xyz / opacity / scaling /
    rotation / covariance / SH features are stored as attributes. All cached
    tensors are detached from the autograd graph.
    """
    with torch.no_grad():
        mask = gaussians.get_label()[..., 0] < 0.6
        gaussians.mask_precomp = mask
        gaussians.points_mask_head_indices = mask.nonzero(as_tuple=True)[0]
        gaussians.xyz_precomp = gaussians.get_xyz()[mask].detach()
        gaussians.opacity_precomp = gaussians.get_opacity()[mask].detach()
        gaussians.scaling_precomp = gaussians.get_scaling()[mask].detach()
        gaussians.rotation_precomp = gaussians.get_rotation()[mask].detach()
        gaussians.cov3D_precomp = gaussians.get_covariance(1.0)[mask].detach()
        gaussians.shs_view = gaussians.get_features()[mask].detach().transpose(1, 2).view(-1, 3, (gaussians.max_sh_degree + 1)**2)


def training(dataset, defor, texture_hidden, opt, opt_hair, pipe, testing_iterations, saving_iterations, checkpoint_iterations, model_path_curves, pointcloud_path_head, checkpoint, checkpoint_hair, checkpoint_all, debug_from):
    """Optimize the strand-based hair Gaussians against the training cameras.

    The head model (``GaussianModel``) is loaded through ``Scene`` and
    optionally ``checkpoint``. Its attributes are precomputed once, detached,
    and its optimizer is never stepped in this loop. The hair model
    (``GaussianModelCurves``) is built from ``checkpoint_hair`` and may be
    overwritten by ``checkpoint_all``. Only ``gaussians_hair.optimizer`` is
    stepped.

    Args:
        dataset, defor, texture_hidden, opt, opt_hair, pipe: parsed argument groups.
        testing_iterations: forwarded to ``training_report``.
        saving_iterations, checkpoint_iterations: currently unused. PLYs are
            saved every 2000 iterations and checkpoints every 1000.
        model_path_curves: output directory for PLYs and checkpoints.
        pointcloud_path_head: head point cloud passed to ``Scene``.
        checkpoint: optional head-model checkpoint.
        checkpoint_hair: hair checkpoint used to create the strands (required).
        checkpoint_all: optional full hair-model checkpoint restored after creation.
        debug_from: ``pipe.debug`` is switched on when ``iteration - 1 == debug_from``.
    """
    first_iter = 0
    time_str = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    print(f"log and checkpoint will be saved at {time_str}")
    tb_writer = prepare_output_and_logger(dataset, model_path_curves, time_str)
    # NOTE(review): prepare_output_and_logger falls back to ./output/<id> when
    # model_path_curves is empty, but does not return that path. The save
    # paths below are therefore still built from the original value.
    gaussians = GaussianModel(dataset.sh_degree)
    gaussians_hair = GaussianModelCurves(dataset.source_path, dataset.flame_mesh_dir, opt_hair, texture_hidden, defor, dataset.sh_degree, dataset.start_time_step, dataset.num_time_steps)
    scene = Scene(dataset, gaussians, pointcloud_path = pointcloud_path_head, load_iteration = None)
    gaussians.training_setup(opt)
    model_params, _ = torch.load(checkpoint_hair)
    num_strands = 20_000
    gaussians_hair.create_from_pcd(dataset.source_path, model_params, num_strands, gaussians.spatial_lr_scale)
    gaussians_hair.training_setup(opt)
    if checkpoint:
        print("Loading model parameters from checkpoint")
        model_params, _ = torch.load(checkpoint)
        gaussians.restore(model_params, opt)
    if checkpoint_all:
        print("Loading all model parameters from checkpoint")
        model_params, _ = torch.load(checkpoint_all)
        gaussians_hair.restore(model_params, opt)
    if dataset.trainable_cameras:
        _load_optimized_cameras(dataset, scene)
    gaussians_hair.initialize_gaussians_hair(num_strands, time_step=0)
    _precompute_head_gaussians(gaussians)

    n_hair_feat = gaussians_hair._features_dc.shape[0]
    n_head = gaussians.mask_precomp.shape[0]
    gaussians_hair.xyz_gradient_accum = torch.zeros((n_hair_feat, 1), device="cuda")
    gaussians_hair.denom = torch.zeros((n_hair_feat, 1), device="cuda")
    # NOTE(review): 15 (camera count) and 52 (stats row width) are hard-coded.
    # Confirm they match the dataset and the rows built by
    # _collect_hair_densification_stats.
    gaussians_hair.xyz_gradient_accum_multi_view = torch.zeros((15, n_hair_feat + n_head, 52), device="cuda")
    gaussians_hair.denom_multi_view = torch.zeros((15, n_hair_feat + n_head), device="cuda")

    batch_size = opt.batch_size
    viewpoint_stack = scene.getTrainCameras()
    viewpoint_stack_loader = DataLoader(viewpoint_stack, batch_size=batch_size, shuffle=True, num_workers=16, collate_fn=list, pin_memory=True)
    loader = iter(viewpoint_stack_loader)

    # NOTE(review): the last channel differs between the white (0) and black (100) backgrounds; confirm this is intended.
    bg_color = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0] if dataset.white_background else [0, 0, 0, 0, 0, 0, 0, 0, 0, 100]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    iter_start = torch.cuda.Event(enable_timing = True)
    iter_end = torch.cuda.Event(enable_timing = True)
    # NOTE(review): this overrides whatever iteration count was passed in via `opt`.
    opt.iterations = 30000
    coarse_iteration = 200
    grid_active_interval = 100
    sparse_iteration = 0
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training")
    first_iter += 1
    render_state = "fine"
    gaussians_hair.deformation_state = 'coarse'
    iteration_fine = 1
    gaussians_hair.deformaton_pts_scale = 1e-2
    gaussians_hair.deformaton_color_scale = 1
    gaussians_hair.deformaton_hf_pts_scale = 1e-7
    gaussians_hair.deformaton_hf_color_scale = 1e-3
    gaussians_hair.deformaton_coarse_scale = 1e-1
    time_list = {}
    cam_id = 0
    for iteration in range(first_iter, opt.iterations + 1):
        gaussians_hair.iteration = iteration
        if network_gui.conn is None:
            network_gui.try_connect()
        while network_gui.conn is not None:
            try:
                net_image_bytes = None
                custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
                if custom_cam is not None:
                    custom_cam.time = 0.0
                    net_image = render_hair_weight_fine(custom_cam, gaussians, gaussians_hair, pipe, background, render_state, scaling_modifer)["render"]
                    net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
                network_gui.send(net_image_bytes, dataset.source_path)
                if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
                    break
            except Exception:
                # Any viewer error just drops the GUI connection; training continues.
                network_gui.conn = None
        iter_start.record()
        torch.cuda.synchronize()
        start_time = time.time()
        gaussians_hair.update_learning_rate(iteration, iteration_fine)
        try:
            viewpoint_cams = next(loader)
        except StopIteration:
            # Pass over the cameras finished: start a new (reshuffled) one.
            loader = iter(viewpoint_stack_loader)
            viewpoint_cams = next(loader)

        Ll1 = 0.0
        Lssim = 0.0
        Lmask = 0.0
        Lorient = 0.0
        Lsds = 0.0  # not part of the loss; still passed to training_report
        tv_loss = 0.0  # not part of the loss; still passed to training_report
        hair_smoothness = 0.0
        # Active-set refresh: at coarse_iteration, then every grid_active_interval iterations.
        refresh_active_set = (iteration > coarse_iteration and iteration % grid_active_interval == 0) or iteration == coarse_iteration
        torch.cuda.synchronize()
        time2 = time.time()
        for viewpoint_cam in viewpoint_cams:
            # NOTE(review): cam_id follows the iteration counter, not viewpoint_cam,
            # although the loader shuffles cameras. Confirm whether
            # viewpoint_cam.camera_index was intended for the per-camera stats.
            cam_id = (iteration-1) % 15
            torch.cuda.synchronize()
            time21 = time.time()
            gaussians_hair.training = False
            torch.cuda.synchronize()
            time3 = time.time()
            gaussians_hair.initialize_gaussians_hair(num_strands, time_step = viewpoint_cam.time_step)
            if (iteration - 1) == debug_from:
                pipe.debug = True
            torch.cuda.synchronize()
            time4 = time.time()
            if refresh_active_set:
                # Clear the cached active-set indices before rendering on refresh iterations.
                gaussians_hair.idx_active_mask = None
                gaussians_hair.points_mask_active_hair_indices = None
                gaussians_hair.points_mask_wo_active_hair_indices = None
            render_pkg = render_hair_weight_fine(viewpoint_cam, gaussians, gaussians_hair, pipe, background, render_state)
            torch.cuda.synchronize()
            time5 = time.time()
            image = render_pkg["render"]
            mask = render_pkg["mask"]
            orient_angle = render_pkg["orient_angle"]
            orient_conf = render_pkg["orient_conf"]

            gt_image = viewpoint_cam.original_image
            gt_mask = viewpoint_cam.original_mask
            gt_orient_angle = viewpoint_cam.original_orient_angle
            gt_orient_conf = viewpoint_cam.original_orient_conf

            Ll1 += l1_loss(image, gt_image)
            Lssim += (1.0 - ssim(image, gt_image))
            Lmask += l1_loss(mask, gt_mask)

            orient_weight = torch.ones_like(gt_mask[:1])
            if opt.use_gt_orient_conf:
                orient_weight = orient_weight * gt_orient_conf
            if not opt.train_orient_conf:
                orient_conf = None
            orient_term = or_loss(orient_angle, gt_orient_angle, orient_conf, weight=orient_weight, mask=gt_mask[:1])
            # Zero only this view's term when it is NaN. The orientation loss
            # already accumulated from earlier views in the batch is kept.
            if torch.isnan(orient_term).any():
                orient_term = torch.zeros_like(orient_term)
            Lorient += orient_term

            hair_smoothness += gaussians_hair.hair_smoothness if gaussians_hair.hair_smoothness is not None else torch.zeros_like(Ll1)
            torch.cuda.synchronize()
            time51 = time.time()
        loss = (
            Ll1 * opt.lambda_dl1
            + Lssim * opt.lambda_dssim
            + Lmask * opt.lambda_dmask
            + Lorient * opt.lambda_dorient
            + hair_smoothness * opt.lambda_hair_smoothness
        )
        loss.backward()
        torch.cuda.synchronize()
        time6 = time.time()
        if coarse_iteration - 15 <= iteration < coarse_iteration:
            # Fill the per-camera stats during the 15 iterations before the
            # first update_active_set call at coarse_iteration.
            _collect_hair_densification_stats(gaussians, gaussians_hair, cam_id)
        if refresh_active_set:
            _collect_hair_densification_stats(gaussians, gaussians_hair, cam_id)
            gaussians_hair.update_active_set(nums_gaussian_head = gaussians.mask_precomp.shape[0], iteration = iteration, sparse_iteration = sparse_iteration,
                                             coarse_iteration = coarse_iteration, grid_active_interval = grid_active_interval, grid_active_threshold = 1e-1)
        torch.cuda.synchronize()
        time7 = time.time()

        # Optimizer step
        if iteration < opt.iterations:
            has_nan_grad = any(
                param.grad is not None and torch.isnan(param.grad).any()
                for param_group in gaussians_hair.optimizer.param_groups
                for param in param_group["params"]
            )
            if has_nan_grad:
                # Drop every gradient. torch optimizers skip parameters whose
                # grad is None, so the step below leaves all parameters untouched.
                gaussians_hair.optimizer.zero_grad(set_to_none = True)
            gaussians_hair.optimizer.step()
            gaussians_hair.optimizer.zero_grad(set_to_none = True)
        iter_end.record()
        torch.cuda.synchronize()
        end_time = time.time()
        time_list["D_L"] = time2 - start_time
        time_list["G_P"] = time3 - time21
        time_list["GH_P"] = time4 - time3
        time_list["Render"] = time5 - time4
        time_list["Loss_Cal"] = time51 - time5
        time_list["Loss_Backward"] = time6 - time51
        time_list["active_set_update"] = time7 - time6
        time_list["Opt"] = end_time - time7
        # NOTE(review): emptying the CUDA cache every iteration is costly; confirm it is needed.
        torch.cuda.empty_cache()
        with torch.no_grad():
            # Progress bar
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Log and save
            training_report(tb_writer, iteration, Ll1, Lssim, Lmask, Lorient, Lsds, tv_loss, hair_smoothness, loss, gaussians_hair.loss_pts, gaussians_hair.pts_wo_rot, gaussians_hair.loss_feat, gaussians_hair.shs_view_hair, gaussians_hair._d_xyz_norm,
                            l1_loss, or_loss, iter_start.elapsed_time(iter_end), end_time - start_time, time_list, testing_iterations, scene, gaussians_hair, render_hair, (pipe, background, render_state))
            if iteration % 2000 == 0:
                print("\n[ITER {}] Saving Gaussians".format(iteration))
                hair_path_curves = os.path.join(model_path_curves, "point_cloud/iteration_{}/point_cloud.ply".format(iteration))
                gaussians_hair.save_ply(hair_path_curves)

            if iteration % 1000 == 0:
                print("\n[ITER {}] Saving Checkpoint".format(iteration))
                checkpoint_dir = os.path.join(model_path_curves, "checkpoints", time_str)
                os.makedirs(checkpoint_dir, exist_ok=True)
                # Checkpoints are captured with the strands initialised at time step 0.
                gaussians_hair.initialize_gaussians_hair(num_strands, time_step=0)
                torch.save((gaussians_hair.capture(), iteration), os.path.join(checkpoint_dir, str(iteration) + ".pth"))

def prepare_output_and_logger(args, model_path_curves, time_str):
    """Create the output directory, dump the run arguments and open a TensorBoard writer.

    Args:
        args: Argument namespace; ``Namespace(**vars(args))`` is written as text to
            ``<output dir>/cfg_args``.
        model_path_curves: Output directory. When falsy, a fallback directory
            ``./output/<id>`` is used, where ``<id>`` is the first 10 characters of
            ``$OAR_JOB_ID`` if that variable is set and non-empty, otherwise of a
            random UUID4.
        time_str: Name of the sub-directory under ``<output dir>/logs`` that
            receives this run's TensorBoard events.

    Returns:
        A ``SummaryWriter`` when TensorBoard is importable, otherwise ``None``.

    Note:
        Only the writer is returned. When the fallback directory is generated it
        is printed, but the caller gets no handle on it.
    """
    output_dir = model_path_curves
    if not output_dir:
        # A non-empty scheduler job id wins; otherwise fall back to a random UUID.
        run_id = os.getenv('OAR_JOB_ID') or str(uuid.uuid4())
        output_dir = os.path.join("./output/", run_id[:10])

    print(f"Output folder: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    cfg_path = os.path.join(output_dir, "cfg_args")
    with open(cfg_path, 'w') as cfg_file:
        cfg_file.write(str(Namespace(**vars(args))))

    if not TENSORBOARD_FOUND:
        print("Tensorboard not available: not logging progress")
        return None
    return SummaryWriter(os.path.join(output_dir, "logs", time_str))

def training_report(tb_writer, iteration, Ll1, Lssim, Lmask, Lorient, Lsds, tv_loss, hair_smoothness, loss, loss_pts, pts_wo_rot, loss_feat, shs_view_hair, _d_xyz_norm, l1_loss, or_loss, GPU_elapsed ,elapsed, time_list, testing_iterations, scene : Scene, gaussians_hair, renderFunc, renderArgs):
    gaussians_hair.time_used += elapsed
    gaussians_hair.GPU_time_used += GPU_elapsed
    if tb_writer:
        tb_writer.add_scalar('active_set/hair_active_nums', gaussians_hair.uvs_mask_active_count, iteration)
        tb_writer.add_scalar('gaussian_grad/pos_grad', gaussians_hair.gaussian_hair_pos_grad_average, iteration)
        tb_writer.add_scalar('gaussian_grad/features_grad', gaussians_hair.gaussian_hair_features_grad_average, iteration)
        tb_writer.add_scalar('gaussian_grad/orient_conf_grad', gaussians_hair.gaussian_hair_orient_conf_grad_average, iteration)
        tb_writer.add_scalar('train_loss_patches/hair_active_nums', gaussians_hair.uvs_mask_active_count, iteration)
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/Lssim_loss', Lssim.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/ce_loss', Lmask.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/or_loss', Lorient.item(), iteration)
        # tb_writer.add_scalar('train_loss_patches/df_loss', Lsds.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/tv_loss', tv_loss.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/hair_smoothness', hair_smoothness.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/loss_pts', loss_pts, iteration)
        tb_writer.add_scalar('train_loss_patches/loss_feat', loss_feat, iteration)
        tb_writer.add_scalar('train_loss_patches/shs_view_hair', shs_view_hair, iteration)
        tb_writer.add_scalar('train_loss_/xyz_loss', (gaussians_hair._pts_wo_root_static).norm().item(), iteration)
        # tb_writer.add_scalar('train_loss_/opacity_loss', (gaussians_hair._opacity - gaussians_hair._opacity_init).norm().item(), iteration)
        tb_writer.add_scalar('train_loss_/rotation_loss', (gaussians_hair._rotation).norm().item(), iteration)
        tb_writer.add_scalar('train_loss_/scaling_loss', (gaussians_hair._scaling).norm().item(), iteration)
        tb_writer.add_scalar('train_loss_/features_dc_loss', (gaussians_hair._features_dc).norm().item(), iteration)
        tb_writer.add_scalar('train_loss_/features_rest_loss', (gaussians_hair._features_rest).norm().item(), iteration)
        tb_writer.add_scalar('time/iter_gpu_time', GPU_elapsed, iteration)
        tb_writer.add_scalar('time/iter_time', elapsed, iteration)
        tb_writer.add_scalar('time/total_gpu_time',gaussians_hair.GPU_time_used/1000.00, iteration)
        tb_writer.add_scalar('time/total_time',gaussians_hair.time_used, iteration)
        # tb_writer.add_scalar('time/average_gpu_time',(gaussians_hair.GPU_time_used/1000.00)/iteration, iteration)
        # tb_writer.add_scalar('time/average_time',gaussians_hair.time_used/iteration, iteration)
        tb_writer.add_scalar('time/average_gpu_time(h/w)',(gaussians_hair.GPU_time_used/1000.00)/iteration*10000/3600, iteration)
        tb_writer.add_scalar('time/average_time(h/w)',gaussians_hair.time_used/iteration*10000/3600, iteration)
        tb_writer.add_scalar('time/data_preprocess',time_list["D_L"], iteration)
        tb_writer.add_scalar('time/gaussian_preprocess',time_list["G_P"], iteration)
        tb_writer.add_scalar('time/gaussian_hair_preprocess',time_list["GH_P"], iteration)
        tb_writer.add_scalar('time/render',time_list["Render"], iteration)
        tb_writer.add_scalar('time/loss_back',time_list["Loss_Backward"], iteration)
        tb_writer.add_scalar('time/loss_cal',time_list["Loss_Cal"], iteration)
        tb_writer.add_scalar('time/active_set_update',time_list["active_set_update"], iteration)
        tb_writer.add_scalar('time/optimizer',time_list["Opt"], iteration)

    # Report test and samples of training set
    if iteration % 2000 == 0:
        torch.cuda.empty_cache()
        validation_configs = [{'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(0, 2500, 219)]},
                              {'name': 'test', 'cameras' : [scene.getTestCameras()[idx % len(scene.getTestCameras())] for idx in range(0, 127, 20)]}]
        for config in validation_configs:
            if config['cameras'] and len(config['cameras']) > 0:
                Ll1_test = 0.0
                Lmask_test = 0.0
                Lorient_test = 0.0
                psnr_test = 0.0
                Lssim = 0.0
                psnr_test_mask = 0.0
                for idx, viewpoint in enumerate(config['cameras']):
                    num_strands = 20_000
                    # scene.gaussians.mask_precomp = scene.gaussians.get_label(viewpoint.time_step)[..., 0] < 0.6
                    # scene.gaussians.xyz_precomp = scene.gaussians.get_xyz(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                    # scene.gaussians.points_mask_head_indices = scene.gaussians.mask_precomp.nonzero(as_tuple=True)[0]
                    # scene.gaussians.opacity_precomp = scene.gaussians.get_opacity(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                    # scene.gaussians.scaling_precomp = scene.gaussians.get_scaling(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                    # scene.gaussians.rotation_precomp = scene.gaussians.get_rotation(viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                    # scene.gaussians.cov3D_precomp = scene.gaussians.get_covariance(scaling_modifier = 1.0,time_index = viewpoint.time_step)[scene.gaussians.mask_precomp].detach()
                    # scene.gaussians.shs_view = scene.gaussians.get_features(viewpoint.time_step)[scene.gaussians.mask_precomp].detach().transpose(1, 2).view(-1, 3, (scene.gaussians.max_sh_degree + 1)**2)
                    gaussians_hair.initialize_gaussians_hair(num_strands, time_step = viewpoint.time_step)
                    gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
                    gt_mask = torch.clamp(viewpoint.original_mask.to("cuda"), 0.0, 1.0)
                    gt_orient_angle = torch.clamp(viewpoint.original_orient_angle.to("cuda"), 0.0, 1.0)
                    gt_orient_conf = viewpoint.original_orient_conf.to("cuda")
                    gt_orient_conf_vis = (1 - 1 / (gt_orient_conf + 1)) * gt_mask[:1]
                    # render_pkg = renderFunc(viewpoint, scene.gaussians, gaussians_hair, *renderArgs)
                    render_pkg = render_hair_weight_fine(viewpoint, scene.gaussians, gaussians_hair, *renderArgs)
                    # render_pkg = render_hair_weight(viewpoint, scene.gaussians, gaussians_hair, *renderArgs)
                    # render_pkg = render_hair_torch(gaussians_hair.gaussRender,viewpoint, scene.gaussians, gaussians_hair, *renderArgs)
                    # render_pkg = render_hair_torch_fine(gaussians_hair.gaussRender,viewpoint, scene.gaussians, gaussians_hair, *renderArgs)
                    image = torch.clamp(render_pkg["render"] * torch.any(gt_mask>0, dim=0,keepdim=True).cuda(), 0.0, 1.0)
                    mask = torch.clamp(render_pkg["mask"] * torch.any(gt_mask>0, dim=0,keepdim=True).cuda(), 0.0, 1.0)
                    orient_angle = torch.clamp(render_pkg["orient_angle"], 0.0, 1.0)
                    orient_conf = render_pkg["orient_conf"]
                    orient_conf_vis = (1 - 1 / (orient_conf + 1)) * mask[:1]
                    # if tb_writer and (idx < 5):
                    if tb_writer :
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/render", image[None], global_step=iteration)
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/render_gt", torch.abs(gt_image[None] - image[None]), global_step=iteration)
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/render_mask", F.pad(mask, (0, 0, 0, 0, 0, 3-mask.shape[0]), 'constant', 0)[None], global_step=iteration)
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/render_orient", vis_orient(orient_angle, mask[:1])[None], global_step=iteration)
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/render_orient_conf", vis_orient(orient_angle, orient_conf_vis)[None], global_step=iteration)
                        # if iteration == testing_iterations[0]:
                    if (iteration == 10  or iteration == 200 or iteration == 10000 or iteration == 40200 or iteration == 1000) and tb_writer:
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/ground_truth", gt_image[None], global_step=iteration)
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/ground_truth_mask", F.pad(gt_mask, (0, 0, 0, 0, 0, 3-gt_mask.shape[0]), 'constant', 0)[None], global_step=iteration)
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/ground_truth_orient", vis_orient(gt_orient_angle, gt_mask[:1])[None], global_step=iteration)
                        tb_writer.add_images(config['name'] + f"_cam_{viewpoint.camera_id}_time_{viewpoint.image_name}/ground_truth_orient_conf", vis_orient(gt_orient_angle, gt_orient_conf_vis)[None], global_step=iteration)
                    Ll1_test += l1_loss(image, gt_image).mean().double()
                    Lmask_test += l1_loss(mask, gt_mask).mean().double()
                    Lorient_test += or_loss(orient_angle, gt_orient_angle, mask=gt_mask[:1], weight=gt_orient_conf).mean().double()
                    psnr_test += psnr(image, gt_image).mean().double()
                    Lssim += (1.0 - ssim(image, gt_image)).mean().double()
                    psnr_test_mask += psnr_mask(gt_image, image, gt_mask[:1]).mean().double()
                Ll1_test /= len(config['cameras'])
                Lmask_test /= len(config['cameras'])
                Lorient_test /= len(config['cameras'])
                psnr_test /= len(config['cameras'])
                Lssim /= len(config['cameras'])
                psnr_test_mask /= len(config['cameras'])
                print("\n[ITER {}] Evaluating {}: L1 {} CE {} OR {} PSNR {}".format(iteration, config['name'], Ll1_test, Lmask_test, Lorient_test, psnr_test))
                if tb_writer:
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', Ll1_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - ce_loss', Lmask_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - or_loss', Lorient_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - Lssim', Lssim, iteration)
                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
                    tb_writer.add_scalar(config['name'] + '/metrics_PSNR',psnr_test_mask,iteration)
                    tb_writer.add_scalar(config['name'] + '/metrics_Time',gaussians_hair.time_used, iteration)

        if tb_writer:
            tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity(), iteration)
            tb_writer.add_histogram("scene/label_histogram", scene.gaussians.get_label(), iteration)
            tb_writer.add_scalar('total_points', scene.gaussians.get_xyz().shape[0], iteration)
        torch.cuda.empty_cache()

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    hp_coarse = ModelHiddenParams(parser)
    tp = TextureHiddenParams(parser)
    parser.add_argument('--ip', type=str, default="127.0.0.1")
    parser.add_argument('--port', type=int, default=6009)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[1_000, 5_000, 10_000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[1_000, 5_000, 10_000])
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[1_000, 5_000, 10_000])
    parser.add_argument("--start_checkpoint", type=str, default = None)
    parser.add_argument("--start_checkpoint_hair", type=str, default = None)
    parser.add_argument("--start_checkpoint_all", type=str, default = None)
    parser.add_argument("--hair_conf_path", type=str, default = None)
    parser.add_argument("--model_path_curves", type=str, default = None)
    parser.add_argument("--pointcloud_path_head", type=str, default = None)
    parser.add_argument("--configs", type=str, default = "")
    args = parser.parse_args(sys.argv[1:])

    # Merge the optional mmcv config before deriving anything from args. This way
    # an overridden `iterations` (or output path) is what gets saved and printed.
    if args.configs:
        import mmcv
        from utils.params_utils import merge_hparams
        config = mmcv.Config.fromfile(args.configs)
        args = merge_hparams(args, config)

    # Always save at the final iteration.
    args.save_iterations.append(args.iterations)

    # --model_path_curves defaults to None, and prepare_output_and_logger then falls
    # back to a generated directory. Avoid `str + None` concatenation here.
    print("Optimizing {}".format(args.model_path_curves))

    # Configuration of hair strands: substitute the DATASET_TYPE placeholder in the
    # raw YAML text, then parse once. Parsing, calling str() on the result and
    # re-parsing would corrupt values: YAML reads None as the string 'None', and
    # backslashes in strings get doubled.
    # NOTE: yaml.Loader can construct arbitrary Python objects; only load trusted configs.
    with open(args.hair_conf_path, 'r') as f:
        replaced_conf = f.read().replace('DATASET_TYPE', 'monocular')
    opt_hair = yaml.load(replaced_conf, Loader=yaml.Loader)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    # Start GUI server, configure and run training
    network_gui.init(args.ip, args.port)
    torch.autograd.set_detect_anomaly(args.detect_anomaly)
    training(lp.extract(args), hp_coarse.extract(args), tp.extract(args), op.extract(args), opt_hair, pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.model_path_curves, args.pointcloud_path_head, args.start_checkpoint, args.start_checkpoint_hair, args.start_checkpoint_all, args.debug_from)

    # All done
    print("\nTraining complete.")
