
import os
import torch
from random import randint
from utils.loss_utils import l1_loss, ssim, l2_loss
from gaussian_renderer import render, network_gui
import sys
from scene import DynamicScene, GaussianModel
from utils.general_utils import safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from utils.system_utils import searchForMaxIteration
import json

import numpy as np
from plyfile import PlyData

import torchvision

def save_tensor_img(img, name='rendering'):
    """Write ``img`` to ``<name>.png`` using torchvision's ``save_image``.

    Args:
        img: Image data forwarded unchanged to ``torchvision.utils.save_image``.
        name: Output path without the extension; ``".png"`` is appended.
    """
    output_path = name + ".png"
    torchvision.utils.save_image(img, output_path)

def get_ply_matrix(file_path):
    """Load the ``vertex`` element of a PLY file into a dense float matrix.

    Args:
        file_path: Path of the PLY file to read.

    Returns:
        A NumPy array with one row per vertex and one column per vertex
        property, with columns in the order of the element's dtype field names.
    """
    vertex = PlyData.read(file_path)['vertex']
    records = vertex.data
    matrix = np.zeros((len(vertex), len(vertex.properties)), dtype=float)
    # Copy the structured array field by field into the plain 2-D matrix.
    for column, field in enumerate(records.dtype.names):
        matrix[:, column] = records[field]
    return matrix

def get_attribute(sh_degree):
    """Return the ordered PLY property names for a given SH degree.

    The order is: position (x, y, z), normal (nx, ny, nz), three ``f_dc_*``
    entries, ``3 * (sh_degree + 1)**2 - 3`` ``f_rest_*`` entries, ``opacity``,
    three ``scale_*`` entries and four ``rot_*`` entries.

    Args:
        sh_degree: Spherical-harmonics degree used to size the ``f_rest_*`` run.

    Returns:
        A list of property-name strings.
    """
    rest_count = 3 * (sh_degree + 1) ** 2 - 3

    names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    names += ['f_dc_' + str(k) for k in range(3)]
    names += ['f_rest_' + str(k) for k in range(rest_count)]
    names.append('opacity')
    names += ['scale_' + str(k) for k in range(3)]
    names += ['rot_' + str(k) for k in range(4)]
    return names


def calculating_importance_score(pcd, ratio=1e5):
    """Rank points from most to least important.

    The score of each row is ``sigmoid(pcd[:, -8]) + ratio * exp(sum of
    pcd[:, -7], pcd[:, -6], pcd[:, -5])``, i.e. an opacity term plus a
    weighted volume term.

    Args:
        pcd: 2-D array of per-point attributes. Presumably laid out so that
            column -8 is the raw opacity and columns -7..-5 are the log
            scales -- verify against the PLY property order in use.
        ratio: Weight applied to the volume term.

    Returns:
        Row indices of ``pcd`` ordered by descending score.
    """
    # Volume term: exponential of the summed scale columns.
    volume = np.exp(pcd[:, -7] + pcd[:, -6] + pcd[:, -5])
    # Opacity term: logistic sigmoid of the raw opacity column.
    opacity = 1 / (1 + np.exp(-pcd[:, -8]))
    score = ratio * volume + opacity
    ascending = np.argsort(score)
    return ascending[::-1]

if __name__ == "__main__":
    # Prune the lowest-opacity points from
    # <model_path>/point_cloud/stage_0/point_cloud_before_prune.ply and write
    # the survivors to point_cloud_before_hierarchical.ply in the same folder.

    # Set up command line argument parser
    parser = ArgumentParser(description="Training script parameters")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[4000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
    parser.add_argument("--start_checkpoint", type=str, default=None)

    args = parser.parse_args(sys.argv[1:])
    args.save_iterations.append(args.iterations)

    print(f"Pruining {args.model_path}")

    # Initialize system state (RNG)
    safe_state(args.quiet)

    torch.autograd.set_detect_anomaly(args.detect_anomaly)

    # Fraction of points to discard.
    prune_percentage = 0.4

    # Input: the point cloud saved before pruning.
    pcd_path = os.path.join(args.model_path, "point_cloud")
    last_ckpt_path = os.path.join(pcd_path, "stage_0", "point_cloud_before_prune.ply")

    # SH degree used to build the output header's property list.
    sh_degree = 0

    pcd = get_ply_matrix(last_ckpt_path)
    print("Loaded point cloud with shape: ", pcd.shape)

    # The header written below declares exactly these properties, so the
    # matrix must have the same number of columns; otherwise the binary
    # payload would not match the header and the output would be unreadable.
    attribute_list = get_attribute(sh_degree)
    if pcd.shape[1] != len(attribute_list):
        raise ValueError(
            "Point cloud has %d attributes per vertex but the header for "
            "sh_degree=%d declares %d; refusing to write a corrupt PLY file."
            % (pcd.shape[1], sh_degree, len(attribute_list))
        )

    # Sort ascending by column -8 (the opacity column under the layout from
    # get_attribute) and drop the first `num_points_to_prune` rows.
    num_points = pcd.shape[0]
    num_points_to_prune = int(num_points * prune_percentage)
    sorted_indices = np.argsort(pcd[:, -8])
    pruned_pcd = pcd[sorted_indices[num_points_to_prune:]]

    pruned_num_points = pruned_pcd.shape[0]

    print("Pruned point cloud with shape: ", pruned_pcd.shape)

    before_hierarchical_path = os.path.join(lp.extract(args).model_path, "point_cloud/stage_0/point_cloud_before_hierarchical.ply")

    # Assemble the PLY header in one piece.
    header_parts = [
        b"ply\n",
        b"format binary_little_endian 1.0\n",
        b"element vertex %d\n" % pruned_num_points,
    ]
    header_parts.extend(
        b"property float %s\n" % attribute_name.encode()
        for attribute_name in attribute_list
    )
    header_parts.append(b"end_header\n")

    # write the new ply file
    with open(before_hierarchical_path, 'wb') as ply_file:
        ply_file.write(b"".join(header_parts))
        # Explicit little-endian float32 to match the declared format on any
        # host; a single row-major dump replaces the per-vertex write loop.
        ply_file.write(pruned_pcd.astype('<f4').tobytes())

    print("\nTraining complete.")
