import os
import random
import numpy as np
import torch
import matplotlib.pyplot as plt

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from SceneForge.data.utils.data import build_scene_from_point_clouds

# Load the Qwen2.5-7B-Instruct tokenizer and vLLM engine once, at import time,
# together with the sampling settings used for every generation request.
_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.8,
    repetition_penalty=1.05,
    max_tokens=512,
)

_llm_options = {
    "model": _MODEL_ID,
    "dtype": "half",
    # Cap the maximum sequence length the engine has to support.
    "max_model_len": 8192,
    # Fraction of GPU memory handed to vLLM (weights plus KV cache).
    "gpu_memory_utilization": 0.9,
}
llm = LLM(**_llm_options)

# --- Utility: Set 3D Axes Equal ---
def set_axes_equal(ax):
    """
    Set equal scaling for a 3D plot.

    The current x/y/z limits are replaced by three intervals of identical
    length (the largest of the three current ranges), each centred on the
    midpoint of the corresponding original interval.  When the axes object
    offers ``set_box_aspect`` the drawing box is additionally made cubic;
    otherwise the box keeps its default, non-cubic proportions and equal
    data ranges would still be rendered with unequal on-screen lengths.

    Args:
        ax: A 3D axes object exposing ``get_xlim3d``/``set_xlim3d`` and the
            y/z equivalents.

    Returns:
        None. ``ax`` is modified in place.
    """
    limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)

    # Half of the widest extent; every axis gets this radius around its centre.
    plot_radius = 0.5 * max(abs(hi - lo) for lo, hi in limits)

    for (lo, hi), set_limits in zip(limits, setters):
        middle = 0.5 * (lo + hi)
        set_limits([middle - plot_radius, middle + plot_radius])

    # Equal limit ranges only look equal inside a cubic box.  Guarded with
    # getattr so axes objects without this method keep working.
    set_box_aspect = getattr(ax, "set_box_aspect", None)
    if set_box_aspect is not None:
        set_box_aspect((1, 1, 1))

# Seed Python's, NumPy's and PyTorch's random number generators with the
# same fixed value.
seed = 23
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)


# --- Main Script ---
def _to_channels_first(arr):
    """Return a 2-D array laid out as (channels, N_points)."""
    arr = np.asarray(arr)
    # Only a (3, N) array with N != 3 is unambiguously channels-first already.
    # Everything else -- including the ambiguous (3, 3) case -- is treated as
    # the expected (N_points, channels) layout and transposed.
    if arr.ndim == 2 and arr.shape[0] == 3 and arr.shape[1] != 3:
        return arr
    return arr.T


def _load_point_cloud(path):
    """Load one sample file and return ``(pc_tensor, caption)``.

    ``pc_tensor`` is a float tensor of shape (6, N_points): rows 0-2 are the
    coordinates with the second and third axes swapped, rows 3-5 are the
    colour channels exactly as stored in the file.
    """
    # NOTE(security): allow_pickle=True unpickles the file content, which can
    # execute arbitrary code.  Only run this on sample files you trust.
    data = np.load(path, allow_pickle=True).item()

    # Use blip_caption if available; otherwise, use msft_caption.
    caption = data.get('blip_caption') or data.get('msft_caption') or "No caption available."

    # Resolve the layout first and swap y/z on the channel rows afterwards, so
    # the swap is correct for both (N_points, 3) and (3, N_points) inputs.
    xyz = _to_channels_first(data['xyz'])[[0, 2, 1], :]
    rgb = _to_channels_first(data['rgb'])

    # Stack xyz and rgb to form a (6, N_points) array.
    pc = np.vstack([xyz, rgb])
    return torch.from_numpy(pc).float(), caption


def main():
    """Build a combined scene from randomly chosen samples and save a plot.

    Picks between 1 and 5 ``.npy`` files from the ``samples`` folder, turns
    each into a (6, N_points) tensor plus caption, passes them to
    ``build_scene_from_point_clouds`` and writes a 3D scatter plot of the
    result to ``combined_point_cloud.png``.
    """
    folder_path = "samples"  # Replace with your folder path containing .npy files.
    if not os.path.isdir(folder_path):
        print(f"Folder not found: {folder_path}")
        return

    # os.listdir returns entries in arbitrary order; sort so that the seeded
    # random selection below picks the same files on every run.
    files = sorted(
        os.path.join(folder_path, f)
        for f in os.listdir(folder_path)
        if f.endswith('.npy')
    )
    if not files:
        print("No npy files found in the folder.")
        return

    # Randomly select between 1 and min(5, total files).
    N = random.randint(1, min(5, len(files)))
    selected_files = random.sample(files, N)
    print(f"Selected {N} files:")
    for file in selected_files:
        print(file)

    point_clouds = []
    captions = []
    for file in selected_files:
        pc_tensor, caption = _load_point_cloud(file)
        point_clouds.append(pc_tensor)
        captions.append(caption)

    # Request at most 2048 points in total.
    total_points = sum(pc.shape[1] for pc in point_clouds)
    out_size = min(2048, total_points)
    min_size_per_sample = 256  # Adjust as needed.

    scene, refined_caption = build_scene_from_point_clouds(
        tokenizer, llm, sampling_params, point_clouds, captions, out_size, min_size_per_sample
    )

    print("Refined Caption:\n", refined_caption)

    # Rows 0-2 of the scene are used as coordinates, rows 3-5 as colours.
    coords = scene[:3, :].T  # (num_points, 3)
    colors = scene[3:6, :].T  # (num_points, 3)
    # Values above 1 are taken to be on a 0-255 scale and rescaled to 0-1.
    if colors.max() > 1:
        colors = colors / 255.0

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c=colors, s=3)
    set_axes_equal(ax)
    ax.set_title("Combined Point Cloud")

    # Rotate the camera view.
    ax.view_init(elev=30, azim=60)

    # Save the figure to a file instead of displaying it.
    fig.savefig("combined_point_cloud.png", dpi=300)
    plt.close(fig)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
