
import os
import torch
from random import randint
from utils.loss_utils import l1_loss, ssim, l2_loss
from gaussian_renderer import render, network_gui
import sys
from scene import DynamicScene, GaussianModel
from utils.general_utils import safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from utils.system_utils import searchForMaxIteration
import json

import numpy as np
from plyfile import PlyData

import torchvision

def save_tensor_img(img, name='rendering'):
    """Save ``img`` as ``<name>.png`` using ``torchvision.utils.save_image``.

    Args:
        img: Image data forwarded unchanged to ``save_image``.
        name: Output path without the extension; ``".png"`` is appended.
    """
    output_path = name + ".png"
    torchvision.utils.save_image(img, output_path)

def gaussian_kde_pytorch(data, bandwidth=None):
    """Build a table-lookup density estimate of ``data`` via FFT smoothing.

    A 1024-bin histogram of ``data`` is multiplied with a sampled Gaussian
    kernel in the Fourier domain, the result is normalised to sum to one, and
    a closure is returned that looks values up in that table.

    Args:
        data: Tensor of samples. ``data.shape[0]`` is used as the sample count
            when a bandwidth has to be derived.
        bandwidth: Kernel bandwidth. When ``None`` it is computed as
            ``(4 * std**5 / (3 * n)) ** 0.2`` (Silverman's rule of thumb).

    Returns:
        A function ``kde(x)`` that maps a tensor ``x`` to the table entries at
        the grid positions found with ``torch.searchsorted``. Queries beyond
        the last grid point are clamped to the last table entry instead of
        raising an out-of-range error.
    """
    if bandwidth is None:
        n = data.shape[0]
        std = torch.std(data)
        bandwidth = (4 * std**5 / (3 * n)) ** 0.2

    # Compute the data range once; it is shared by the grid and the histogram.
    data_min = data.min().item()
    data_max = data.max().item()

    grid = torch.linspace(data_min, data_max, 1024, device=data.device)
    grid_size = grid.shape[0]

    # NOTE(review): the kernel is sampled on its own [-3h, 3h] axis (not on the
    # data grid spacing) and is not re-centred before the circular convolution
    # below -- confirm that this is the intended approximation.
    kernel_grid = torch.linspace(-3 * bandwidth, 3 * bandwidth, grid_size, device=data.device)
    kernel = torch.exp(-0.5 * (kernel_grid / bandwidth)**2)
    kernel = kernel / (bandwidth * torch.sqrt(2 * torch.tensor(torch.pi)))

    hist = torch.histc(data, bins=grid_size, min=data_min, max=data_max)

    # Element-wise product in the Fourier domain == circular convolution.
    kernel_fft = torch.fft.fft(kernel)
    hist_fft = torch.fft.fft(hist)
    pdf_fft = kernel_fft * hist_fft
    pdf = torch.fft.ifft(pdf_fft).real

    pdf = pdf / torch.sum(pdf)

    def kde(x):
        """Return the table value at the grid position of each element of ``x``."""
        # searchsorted yields ``grid_size`` for values above the last grid
        # point (possible through float rounding or for unseen queries), which
        # would index one past the end of ``pdf``; clamp into range.
        indices = torch.clamp(torch.searchsorted(grid, x), max=grid_size - 1)
        return torch.take(pdf, indices)

    return kde

def entropy_regularization_loss(current_frame_gaussian):
    """Sum a KDE-based entropy estimate over the model's attribute channels.

    The channels are the slices ``[:, i]`` of ``_scaling``, ``_rotation``,
    ``_opacity`` and ``_features_dc`` (in that order). Each channel is
    min-max rescaled to ``[0, 255]``, shifted by one scalar uniform draw from
    ``[-0.5, 0.5)``, and scored as the mean negative ``log2`` of the density
    returned by ``gaussian_kde_pytorch``.

    Args:
        current_frame_gaussian: Object exposing the four attribute tensors
            listed above.

    Returns:
        The accumulated entropy, or the float ``0.0`` as soon as any channel
        sums to zero or produces NaN after rescaling.
    """
    attribute_tensors = (
        current_frame_gaussian._scaling,
        current_frame_gaussian._rotation,
        current_frame_gaussian._opacity,
        current_frame_gaussian._features_dc,
    )
    channels = [
        tensor[:, column]
        for tensor in attribute_tensors
        for column in range(tensor.shape[1])
    ]

    levels = 255
    total = 0.0
    for channel in channels:
        # NOTE(review): an early exit discards whatever was accumulated for
        # the preceding channels -- confirm this is intended.
        if torch.sum(channel) == 0:
            return 0.0

        low = torch.min(channel)
        high = torch.max(channel)
        rescaled = (channel - low) / (high - low) * levels

        # A single scalar offset is drawn per channel (not per element).
        jittered = rescaled + np.random.uniform(-0.5, 0.5)
        # high == low divides by zero above and surfaces here as NaN.
        if torch.isnan(jittered).any():
            return 0.0

        density = gaussian_kde_pytorch(jittered)
        probabilities = density(jittered)

        total += -torch.sum(torch.log2(probabilities + 1e-10)) / channel.shape[0]
    return total

def finetune(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, last_ckpt_path, last_ckpt_iter):
    """Fine-tune Gaussians loaded from ``last_ckpt_path`` with a multi-level loss.

    Every iteration renders one randomly chosen training camera with the full
    model and additionally with ``part - 1`` truncated models (via the
    renderer's ``numofgaussians`` argument), sums the photometric losses, adds
    a small entropy regulariser, and takes an optimizer step. After the loop
    the model is written to ``<dataset.model_path>/point_cloud/stage_0/point_cloud.ply``.

    Args:
        dataset: Model parameters; ``sh_degree``, ``white_background`` and
            ``model_path`` are read here, and it is passed to ``DynamicScene``.
        opt: Optimisation parameters; ``iterations``, ``random_background``,
            ``lambda_dssim`` and ``densify_from_iter`` are read here.
        pipe: Pipeline parameters forwarded to ``render``; ``pipe.debug`` is
            switched on once ``iteration - 1 == debug_from``.
        testing_iterations: Iterations at which ``training_report_stage``
            runs its evaluation.
        saving_iterations: Not used in this function body.
        checkpoint_iterations: Not used in this function body.
        checkpoint: Not used in this function body.
        debug_from: Iteration index (zero-based) from which debug is enabled.
        last_ckpt_path: PLY file loaded into the Gaussian model.
        last_ckpt_iter: Not used in this function body.
    """
    first_iter = 0
    # Number of levels: the full render plus (part - 1) truncated renders.
    part = 6
    gaussians = GaussianModel(dataset.sh_degree)
    scene = DynamicScene(dataset)
    gaussians.load_ply(last_ckpt_path)
    gaussians.training_setup(opt)

    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

    # CUDA events used to time each iteration (forward + backward + step).
    iter_start = torch.cuda.Event(enable_timing = True)
    iter_end = torch.cuda.Event(enable_timing = True)

    viewpoint_stack = None
    ema_loss_for_log = 0.0
    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
    first_iter += 1
    # Iterations are 1-based and run through opt.iterations inclusive.
    for iteration in range(first_iter, opt.iterations + 1):

        iter_start.record()

        gaussians.update_learning_rate(iteration)

        # Pick a random camera without replacement; refill the stack from the
        # training cameras whenever it is empty.
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))

        # Render
        if (iteration - 1) == debug_from:
            pipe.debug = True

        bg = torch.rand((3), device="cuda") if opt.random_background else background

        # Full-model render; the screen-space tensors kept here are the ones
        # used for the densification statistics further down.
        render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
        # Loss: L1 / D-SSIM mix on the full render plus the entropy regulariser
        # weighted by 1e-7.
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))  + 1e-7 * entropy_regularization_loss(gaussians)

        # Truncated renders with i/part of the Gaussian count (integer
        # division), each weighted by 0.25 / i.
        # presumably numofgaussians keeps the first `num` Gaussians -- confirm
        # in gaussian_renderer.render.
        # Note: `image` and `Ll1` are rebound here, so the Ll1 reported below
        # is the one from the last truncated render (i == part - 1).
        for i in range(1,part):
            num = gaussians.get_xyz.shape[0]//part*i
            render_pkg = render(viewpoint_cam, gaussians, pipe, bg, numofgaussians=num)
            image = render_pkg["render"]
            Ll1 = l1_loss(image, gt_image)
            a = 0.25/i
            loss += a * ((1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image)))

        # Gradients are cleared right before backward; the optimizer step is
        # skipped on the final iteration.
        gaussians.optimizer.zero_grad(set_to_none = True)
        loss.backward()
        if iteration < opt.iterations:
            gaussians.optimizer.step()

        iter_end.record()

        with torch.no_grad():
            # Progress bar: exponential moving average of the total loss,
            # refreshed every 10 iterations.
            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
            if iteration % 10 == 0:
                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
                progress_bar.update(10)
            if iteration == opt.iterations:
                progress_bar.close()

            # Only statistics are recorded here; no densify/prune call is made
            # in this function.
            if iteration > opt.densify_from_iter:
                # Keep track of max radii in image-space for pruning
                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
            # No TensorBoard writer is passed; `part` is used as the number of
            # evaluation stages and evaluation renders use the fixed
            # `background`, not the per-iteration `bg`.
            training_report_stage(None, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, gaussians, render, part, (pipe, background))

    save_pcd_path = os.path.join(dataset.model_path, "point_cloud/stage_0")


    with torch.no_grad():
        # `iteration` is the loop variable; it is unbound if the loop body
        # never ran (opt.iterations < 1).
        print("\n[ITER {}] Saving Gaussians".format(iteration))
        gaussians.save_ply(os.path.join(save_pcd_path, "point_cloud.ply"),save_type="base")
    

def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, gaussians, scene : DynamicScene, renderFunc, renderArgs, last_ckpt_iter=None):
    """Evaluate the model on test and sampled training cameras.

    Nothing happens unless ``iteration`` is in ``testing_iterations``. When it
    is, every test camera and five training cameras (indices 5, 10, ..., 25
    modulo the training-set size) are rendered, clamped to ``[0, 1]`` and
    compared to ground truth; mean L1 and PSNR are printed per split. Test
    renders are also saved under ``<scene.model_path>/point_cloud/render_result``.

    Args:
        tb_writer: Unused.
        iteration: Current training iteration.
        Ll1: Unused.
        loss: Unused.
        l1_loss: Callable ``(image, gt_image) -> tensor`` used for the L1 metric.
        elapsed: Unused.
        testing_iterations: Container of iterations at which to evaluate.
        gaussians: Model passed to ``renderFunc``.
        scene: Provides ``getTestCameras()``, ``getTrainCameras()`` and
            ``model_path``.
        renderFunc: Called as ``renderFunc(viewpoint, gaussians, *renderArgs)``
            and expected to return a mapping with a ``"render"`` entry.
        renderArgs: Extra positional arguments for ``renderFunc``.
        last_ckpt_iter: Offset added to ``iteration`` in the saved image name.
            When ``None`` the module-level ``last_ckpt_iter`` is used if it
            exists (it is only defined when the file runs as a script),
            otherwise ``0``. Previously the global was read unconditionally,
            which raised ``NameError`` when this module was imported.
    """
    if iteration not in testing_iterations:
        return

    if last_ckpt_iter is None:
        last_ckpt_iter = globals().get("last_ckpt_iter", 0)

    torch.cuda.empty_cache()
    test_cameras = scene.getTestCameras()
    train_cameras = scene.getTrainCameras()
    # Guard the modulo below against an empty training set.
    sampled_train = [train_cameras[idx % len(train_cameras)] for idx in range(5, 30, 5)] if train_cameras else []
    validation_configs = ({'name': 'test', 'cameras' : test_cameras},
                          {'name': 'train', 'cameras' : sampled_train})

    render_dir = os.path.join(scene.model_path, "point_cloud/render_result")
    # NOTE(review): the file name does not depend on the viewpoint, so each
    # test view overwrites the previous one and only the last one remains.
    img_path = os.path.join(scene.model_path, "point_cloud/render_result/iteration_{}".format(last_ckpt_iter + iteration))

    for config in validation_configs:
        cameras = config['cameras']
        if not cameras:
            continue
        l1_test = 0.0
        psnr_test = 0.0
        for viewpoint in cameras:
            image = torch.clamp(renderFunc(viewpoint, gaussians, *renderArgs)["render"], 0.0, 1.0)
            gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
            if config['name'] == 'test':
                os.makedirs(render_dir, exist_ok = True)
                save_tensor_img(image, img_path)
            l1_test += l1_loss(image, gt_image).mean().double()
            psnr_test += psnr(image, gt_image).mean().double()
        psnr_test /= len(cameras)
        l1_test /= len(cameras)
        print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
    torch.cuda.empty_cache()

def training_report_stage(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : DynamicScene,  gaussians, renderFunc, stage, renderArgs):
    """Log training scalars and, on test iterations, evaluate every stage.

    When ``tb_writer`` is truthy the L1 loss, total loss and iteration time
    are logged. If ``iteration`` is in ``testing_iterations`` the test cameras
    and five sampled training cameras are rendered once per stage ``i`` in
    ``range(stage)`` with ``numofgaussians = N // stage * (i + 1)``, where
    ``N`` is ``gaussians.get_xyz.shape[0]``; mean L1 / PSNR are printed per
    stage and test renders are saved as ``stage_<i>.png``.

    Args:
        tb_writer: Optional summary writer; skipped when falsy.
        iteration: Current training iteration.
        Ll1: L1 loss tensor of the current step (logged only).
        loss: Total loss tensor of the current step (logged only).
        l1_loss: Callable used for the evaluation L1 metric.
        elapsed: Iteration time to log.
        testing_iterations: Sequence of iterations at which to evaluate.
        scene: Provides the cameras and ``model_path``.
        gaussians: Model passed to ``renderFunc``.
        renderFunc: Render callable accepting a ``numofgaussians`` keyword.
        stage: Number of stages to evaluate.
        renderArgs: Extra positional arguments for ``renderFunc``.
    """
    if tb_writer:
        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
        tb_writer.add_scalar('iter_time', elapsed, iteration)

    # Everything below is evaluation and only runs on test iterations.
    if iteration not in testing_iterations:
        return

    torch.cuda.empty_cache()
    eval_sets = (
        {'name': 'test', 'cameras' : scene.getTestCameras()},
        {'name': 'train', 'cameras' : [scene.getTrainCameras()[k % len(scene.getTrainCameras())] for k in range(5, 30, 5)]},
    )

    for eval_set in eval_sets:
        split = eval_set['name']
        cameras = eval_set['cameras']
        if not cameras or len(cameras) <= 0:
            continue

        for level in range(stage):
            l1_test = 0.0
            psnr_test = 0.0
            num = gaussians.get_xyz.shape[0] // stage * (level + 1)
            for view_idx, viewpoint in enumerate(cameras):
                rendered = renderFunc(viewpoint, gaussians, *renderArgs, numofgaussians=num)["render"]
                image = torch.clamp(rendered, 0.0, 1.0)
                if split == 'test':
                    # The path depends only on the stage, so later views
                    # replace earlier ones.
                    os.makedirs(os.path.join(scene.model_path, "point_cloud/render_result"), exist_ok = True)
                    img_path = os.path.join(scene.model_path, "point_cloud/render_result/stage_{}".format(level))
                    save_tensor_img(image, img_path)
                gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
                if tb_writer and (view_idx < 5):
                    view_tag = split + "_view_{}/render".format(viewpoint.image_name)
                    tb_writer.add_images(view_tag, image[None], global_step=iteration)
                    # Ground truth is logged only at the first test iteration.
                    if iteration == testing_iterations[0]:
                        gt_tag = split + "_view_{}/ground_truth".format(viewpoint.image_name)
                        tb_writer.add_images(gt_tag, gt_image[None], global_step=iteration)
                l1_test += l1_loss(image, gt_image).mean().double()
                psnr_test += psnr(image, gt_image).mean().double()
            psnr_test /= len(cameras)
            l1_test /= len(cameras)
            print("\n[Stage {} Iteration {}] Evaluating {}: L1 {} PSNR {} N {}".format(level, iteration, split, l1_test, psnr_test, num))

        # Scalars logged here are those of the last evaluated stage.
        if tb_writer:
            tb_writer.add_scalar(split + '/loss_viewpoint - l1_loss', l1_test, iteration)
            tb_writer.add_scalar(split + '/loss_viewpoint - psnr', psnr_test, iteration)

    if tb_writer:
        tb_writer.add_histogram("scene/opacity_histogram", gaussians.get_opacity, iteration)
        tb_writer.add_scalar('total_points', gaussians.get_xyz.shape[0], iteration)
    torch.cuda.empty_cache()

def get_ply_matrix(file_path):
    """Load the ``vertex`` element of a PLY file into a float matrix.

    Args:
        file_path: Path handed to ``PlyData.read``.

    Returns:
        A ``(num_vertices, num_properties)`` float array with one column per
        vertex field, in the order of the element's dtype names.
    """
    vertex = PlyData.read(file_path)['vertex']
    matrix = np.zeros((len(vertex), len(vertex.properties)), dtype=float)
    for column, field in enumerate(vertex.data.dtype.names):
        matrix[:, column] = vertex.data[field]
    return matrix

def get_attribute(sh_degree):
    """Return the ordered PLY property names for a given SH degree.

    The order is: position (x, y, z), normals (nx, ny, nz), three ``f_dc_*``
    entries, ``3 * (sh_degree + 1)**2 - 3`` ``f_rest_*`` entries, ``opacity``,
    three ``scale_*`` entries and four ``rot_*`` entries.

    Args:
        sh_degree: Spherical-harmonics degree that sizes the ``f_rest_*`` run.

    Returns:
        A new list of property-name strings.
    """
    frest_dim = 3 * (sh_degree + 1) ** 2 - 3

    names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    names += ['f_dc_' + str(k) for k in range(3)]
    names += ['f_rest_' + str(k) for k in range(frest_dim)]
    names.append('opacity')
    names += ['scale_' + str(k) for k in range(3)]
    names += ['rot_' + str(k) for k in range(4)]
    return names

def calculating_importance_score(pcd, ratio=1e5):
    """Rank rows of ``pcd`` from most to least important.

    score = ratio * exp(col[-7] + col[-6] + col[-5]) + sigmoid(col[-8])

    With the column layout produced by ``get_attribute`` the last eight
    columns are opacity, three scales and four rotation components, so the
    first term is the exponentiated scale sum (the original comment calls it
    the volume) and the second the activated opacity -- this assumes the
    matrix follows that layout.

    Args:
        pcd: 2-D array with at least eight columns.
        ratio: Weight of the volume term.

    Returns:
        Row indices ordered by descending score (the reversed ``argsort``).
    """
    log_scale_sum = pcd[:, -7] + pcd[:, -6] + pcd[:, -5]
    volume_term = ratio * np.exp(log_scale_sum)
    opacity_term = 1 / (1 + np.exp(-pcd[:, -8]))
    importance_score = volume_term + opacity_term

    ascending = np.argsort(importance_score)
    return ascending[::-1]

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[1500,2000,3000,4000,6000,7000,8000,9000,10000,11000,12000,13000,14000,15000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default = None)
    args = parser.parse_args(sys.argv[1:])
    args.save_iterations.append(args.iterations)

    print(f"Hierarchical {args.model_path}")

    # Initialize system state (RNG)
    safe_state(args.quiet)

    torch.autograd.set_detect_anomaly(args.detect_anomaly)

    # Module-level value; training_report reads it as a global.
    last_ckpt_iter = 4000
    # The input point cloud lives at a fixed location under the model path.
    pcd_path = os.path.join(args.model_path, "point_cloud")
    last_ckpt_path = os.path.join(pcd_path, "stage_0", "point_cloud_before_hierarchical.ply")

    # The rewritten PLY header is built for SH degree 0.
    sh_degree = 0

    pcd = get_ply_matrix(last_ckpt_path)
    print("Loaded point cloud with shape: ", pcd.shape)
    num_points = pcd.shape[0]

    # The header declares one float property per name, so the matrix must have
    # exactly that many columns or the written file would be malformed.
    attribute_list = get_attribute(sh_degree)
    if pcd.shape[1] != len(attribute_list):
        raise ValueError(
            "Point cloud has {} properties per vertex but the header for sh_degree={} declares {}".format(
                pcd.shape[1], sh_degree, len(attribute_list)
            )
        )

    # Reorder the vertices from most to least important (nothing is dropped).
    sorted_indices = calculating_importance_score(pcd,ratio=1e5)
    pruned_pcd = pcd[sorted_indices]
    pruned_num_points = pruned_pcd.shape[0]

    # Insert the suffix before the extension of the file name only, so a
    # ".ply" occurring elsewhere in the path is left untouched.
    ckpt_root, ckpt_ext = os.path.splitext(last_ckpt_path)
    pruned_pcd_path = ckpt_root + "_hierarchical" + ckpt_ext

    # write the new ply file
    header_lines = [
        "ply",
        "format binary_little_endian 1.0",
        "element vertex %d" % pruned_num_points,
    ]
    header_lines.extend("property float %s" % attribute_name for attribute_name in attribute_list)
    header_lines.append("end_header")

    with open(pruned_pcd_path, 'wb') as ply_file:
        ply_file.write(("\n".join(header_lines) + "\n").encode())
        # One write of the row-major buffer replaces the per-vertex loop; the
        # explicit "<f4" matches the little-endian format declared above.
        ply_file.write(np.ascontiguousarray(pruned_pcd, dtype="<f4").tobytes())

    finetune(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, pruned_pcd_path, last_ckpt_iter)

    # All done
    print("\nTraining complete.")
