import os
import numpy as np
import cv2
from plyfile import PlyData
import json
import argparse

def normalize_uint8(data):
    """Linearly rescale *data* onto the full uint8 range [0, 255].

    The minimum maps to 0 and the maximum to 255; intermediate values are
    truncated (not rounded) by the integer cast.  The inverse mapping is
    approximately ``q / 255 * (max_val - min_val) + min_val``.

    Args:
        data: numeric NumPy array (any shape).

    Returns:
        ``(quantized, min_val, max_val)`` where ``quantized`` is a uint8 array
        with the shape of *data*.  A constant input (``max_val == min_val``)
        yields all zeros instead of dividing by zero.
    """
    min_val = np.min(data)
    max_val = np.max(data)
    if max_val == min_val:
        # Degenerate range: the division below would produce NaN/inf.
        return np.zeros_like(data, dtype=np.uint8), min_val, max_val
    normalized = (data - min_val) / (max_val - min_val) * 255.0
    return normalized.astype(np.uint8), min_val, max_val

def normalize_uint8_tog(data, min_val, max_val):
    """Quantize *data* to uint8 using a caller-supplied ``[min_val, max_val]`` range.

    The range is not derived from *data*, so some values may fall outside it.
    Those are saturated to 0 / 255 instead of wrapping around (or hitting
    platform-dependent behaviour) on the float -> uint8 cast.

    Args:
        data: numeric array-like.
        min_val: value mapped to 0.
        max_val: value mapped to 255.

    Returns:
        ``(quantized, min_val, max_val)``; the range is passed through unchanged.
        A degenerate range (``max_val == min_val``) yields all zeros instead of
        dividing by zero.
    """
    data = np.asarray(data)
    if max_val == min_val:
        return np.zeros_like(data, dtype=np.uint8), min_val, max_val
    normalized = (data - min_val) / (max_val - min_val) * 255.0
    # Saturate out-of-range inputs before the narrowing cast.
    normalized = np.clip(normalized, 0.0, 255.0)
    return normalized.astype(np.uint8), min_val, max_val

def normalize_uint16(data):
    """Linearly rescale *data* onto the full uint16 range [0, 65535].

    The minimum maps to 0 and the maximum to 65535; intermediate values are
    truncated by the integer cast.

    Args:
        data: numeric NumPy array (any shape).

    Returns:
        ``(quantized, min_val, max_val)`` where ``quantized`` is a uint16 array
        with the shape of *data*.  A constant input (``max_val == min_val``)
        yields all zeros instead of dividing by zero, matching
        ``normalize_uint32``.
    """
    min_val = np.min(data)
    max_val = np.max(data)
    if max_val == min_val:
        return np.zeros_like(data, dtype=np.uint16), min_val, max_val
    normalized = (data - min_val) / (max_val - min_val) * (2 ** 16 - 1)
    return normalized.astype(np.uint16), min_val, max_val

def normalize_uint32(data):
    """Linearly rescale *data* onto the full uint32 range [0, 2**32 - 1].

    Returns ``(quantized, min_val, max_val)``.  A constant input maps to all
    zeros; otherwise the scaled values are clipped into range before the
    truncating cast to uint32.
    """
    lo = np.min(data)
    hi = np.max(data)
    full_scale = 2 ** 32 - 1

    if lo == hi:
        # Nothing to spread out -- avoid a division by zero.
        return np.zeros_like(data, dtype=np.uint32), lo, hi

    scaled = (data - lo) / (hi - lo) * full_scale
    quantized = np.clip(scaled, 0, full_scale).astype(np.uint32)
    return quantized, lo, hi

def get_ply_matrix(file_path):
    """Load the ``vertex`` element of a PLY file as a 2-D float64 matrix.

    Each row is one vertex.  Each column is one field of the vertex record
    dtype, in field order.
    """
    vertex_element = PlyData.read(file_path)['vertex']
    records = vertex_element.data
    row_count = len(vertex_element)
    column_count = len(vertex_element.properties)

    matrix = np.zeros((row_count, column_count), dtype=float)
    for column, field_name in enumerate(records.dtype.names):
        matrix[:, column] = records[field_name]
    return matrix

def calculate_image_size(num_points):
    """Return the side of the smallest square image that can hold *num_points* pixels.

    The side is always a positive multiple of 8 (minimum 8).
    """
    tiles = 1
    while (8 * tiles) ** 2 < num_points:
        tiles += 1
    return 8 * tiles

def _to_square_image(values, image_size):
    """Zero-pad a 1-D uint8 array to ``image_size**2`` entries and reshape it row-major into a square image."""
    image = np.zeros(image_size * image_size, dtype=np.uint8)
    image[:values.shape[0]] = values
    return image.reshape((image_size, image_size))


def _write_level_images(level_data, frame, lvl, image_size, output_path, min_max_json, viewer_info):
    """Quantize every attribute column of one level chunk and write it as PNG image(s).

    Columns 0-2 are quantized to 16 bits and split into two 8-bit images: the low
    byte goes to channel ``2*i`` and the high byte to ``2*i+1``.  Every other
    column is quantized to 8 bits and written to channel ``i+3``.  Files are
    named ``{frame}_{channel}_{lvl}.png``.

    For each column, the min/max needed for dequantization are stored in
    *min_max_json* under ``{frame}_{i}_{lvl}_min`` / ``_max``.  They are also
    appended to *viewer_info* as ``min, max``.
    """
    for i in range(level_data.shape[1]):
        column = level_data[:, i]
        if i > 2:
            quantized, min_val, max_val = normalize_uint8(column)
        else:
            quantized, min_val, max_val = normalize_uint16(column)

        min_max_json[f'{frame}_{i}_{lvl}_min'] = float(min_val)
        min_max_json[f'{frame}_{i}_{lvl}_max'] = float(max_val)
        viewer_info.append(float(min_val))
        viewer_info.append(float(max_val))

        if i > 2:
            planes = [(i + 3, quantized)]
        else:
            # Split the 16-bit value into two 8-bit PNG planes (low byte first).
            planes = [
                (2 * i, (quantized & 0xFF).astype(np.uint8)),
                (2 * i + 1, (quantized >> 8).astype(np.uint8)),
            ]
        for channel, plane in planes:
            cv2.imwrite(
                os.path.join(output_path, f"{frame}_{channel}_{lvl}.png"),
                _to_square_image(plane, image_size),
            )


def main():
    """Quantize per-frame PLY point clouds into PNG attribute images plus JSON metadata.

    Reads ``<ply_path>/group_info.json``.  For every frame of every group it then
    reads ``<ply_path>/<frame>/point_cloud/stage_0/point_cloud.ply``.

    The points are split into ``--level`` equal consecutive chunks.  The trailing
    ``num_points % level`` points are dropped.  Each chunk's attribute columns
    are written as images under ``<output_folder>/group<g>/``.

    Finally it writes ``min_max.json``, ``viewer_min_max.json`` and
    ``group_info.json`` into ``--output_folder``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--frame_start", type=int, default=95)
    parser.add_argument("--frame_end", type=int, default=115)
    parser.add_argument("--interval", type=int, default=1)
    parser.add_argument("--ply_path", type=str, default="")
    parser.add_argument("--output_folder", type=str, default="")
    # NOTE: --sh_degree is accepted for CLI compatibility but is not used below.
    parser.add_argument("--sh_degree", type=int, default=0)
    parser.add_argument("--level", type=int, default=6)
    args = parser.parse_args()
    if args.level < 1:
        parser.error("--level must be >= 1")
    if args.interval < 1:
        parser.error("--interval must be >= 1")

    level = args.level
    interval = args.interval
    ply_path = args.ply_path
    output_folder = args.output_folder

    with open(os.path.join(ply_path, "group_info.json"), "r", encoding="utf-8") as f:
        frame_ranges = json.load(f)['frame_index']
    all_start = frame_ranges[0][0]

    # An empty output folder means "current directory"; os.makedirs("") would raise.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    min_max_json = {}
    viewer_min_max_json = {}
    group_info_json = {}

    # Running global frame offset.  Groups may differ in size, so the offset is
    # accumulated rather than computed as group * group_size.
    next_frame_index = 0
    for group, bounds in enumerate(frame_ranges):
        frame_start, frame_end = bounds[0], bounds[1]
        group_size = frame_end - frame_start + 1

        group_info_json[str(group)] = {
            'frame_index': [next_frame_index, next_frame_index + group_size - 1],
            'name_index': [frame_start - all_start, frame_end - all_start],
        }
        next_frame_index += group_size

        output_path = os.path.join(output_folder, f"group{group}")
        os.makedirs(output_path, exist_ok=True)

        for frame in range(frame_start, frame_end + 1, interval):
            # NOTE: only --frame_end limits the frames processed; --frame_start is
            # not applied as a lower bound.
            if frame > args.frame_end:
                break

            current_data = get_ply_matrix(
                os.path.join(ply_path, str(frame), "point_cloud", "stage_0", "point_cloud.ply")
            )
            num_points = current_data.shape[0] // level
            if num_points == 0:
                raise ValueError(
                    f"frame {frame}: {current_data.shape[0]} points cannot be split into {level} levels"
                )
            image_size = calculate_image_size(num_points=num_points)

            min_max_json[f'{frame}_num'] = num_points * level
            # Initialised once per frame so 'info' accumulates every level:
            # level-major, one (min, max) pair per attribute column.
            viewer_info = []
            viewer_min_max_json[frame] = {'num': num_points * level, 'info': viewer_info}

            for lvl in range(level):
                level_data = current_data[lvl * num_points:(lvl + 1) * num_points]
                _write_level_images(
                    level_data, frame, lvl, image_size, output_path, min_max_json, viewer_info
                )

    with open(os.path.join(output_folder, "min_max.json"), "w", encoding="utf-8") as f:
        json.dump(min_max_json, f, indent=4)

    with open(os.path.join(output_folder, "viewer_min_max.json"), "w", encoding="utf-8") as f:
        json.dump(viewer_min_max_json, f, indent=4)

    with open(os.path.join(output_folder, "group_info.json"), "w", encoding="utf-8") as f:
        json.dump(group_info_json, f, indent=4)


if __name__ == "__main__":
    main()