import sqlite3
import struct
import torch
import json
import numpy as np
from .read_write_model import read_points3D_binary, read_images_binary
import random
import re
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

from collections import defaultdict
import heapq
import os
def save_custom_ply(filename, pcd_points, feature_dim):
    """Write a point cloud with per-point feature vectors to an ASCII PLY file.

    The vertex element declares these properties, in order:
      - x, y, z (float coordinates)
      - red, green, blue (uchar colour)
      - feature0 ... feature{feature_dim-1} (float feature vector)

    Args:
        filename: Output file name (for example "output.ply").
        pcd_points: Sequence of rows, each laid out as
            [x, y, z, r, g, b, feature0, ..., feature{feature_dim-1}].
        feature_dim: Number of feature values carried by every row.

    Raises:
        ValueError: If a row does not hold exactly ``6 + feature_dim`` values.
            The check runs before the file is opened, so no partial file is
            left behind.
    """
    # The header promises a fixed number of properties per vertex; a row with a
    # different number of values would make the whole file unparseable.
    expected_len = 6 + feature_dim
    for index, pt in enumerate(pcd_points):
        if len(pt) != expected_len:
            raise ValueError(
                f"Point {index} has {len(pt)} values, expected {expected_len} "
                f"(x, y, z, r, g, b + {feature_dim} features)")

    header_lines = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(pcd_points)}",
        "property float x",
        "property float y",
        "property float z",
        "property uchar red",
        "property uchar green",
        "property uchar blue",
    ]
    header_lines.extend(f"property float feature{i}" for i in range(feature_dim))
    header_lines.append("end_header")

    with open(filename, "w") as f:
        f.writelines(line + "\n" for line in header_lines)
        for pt in pcd_points:
            # Coordinates and features use six decimal places; colours are
            # truncated to integers (input is assumed to already be in 0..255).
            fields = [f"{pt[0]:.6f}", f"{pt[1]:.6f}", f"{pt[2]:.6f}"]
            fields.extend(str(int(channel)) for channel in pt[3:6])
            fields.extend(f"{value:.6f}" for value in pt[6:])
            f.write(" ".join(fields) + "\n")


def _load_image_context(image_id, images, features_unused, data_folder, cache):
    """Return ``(image_name, feature_index, width, height)`` for one image.

    The feature index is the first run of digits in the image name, and the
    size is read from the image file under ``<data_folder>/images``. Results
    are memoised in *cache* so each image file is opened at most once.
    """
    context = cache.get(image_id)
    if context is None:
        image_name = images[image_id].name
        match = re.search(r'\d+', image_name)
        if match is None:
            raise ValueError(
                f"Cannot derive a feature index from image name {image_name!r}: "
                "it contains no digits")
        with Image.open(os.path.join(data_folder, "images", image_name)) as img:
            width, height = img.size
        context = (image_name, int(match.group()), width, height)
        cache[image_id] = context
    return context


def _observation_feature(point_id, image_id, point2d_idx, images, features,
                         data_folder, cache):
    """Return the feature vector at one 2D observation of a 3D point.

    Prints a warning and returns ``None`` when the observation's pixel
    coordinates fall outside the image.
    """
    image_name, feature_index, width, height = _load_image_context(
        image_id, images, features, data_folder, cache)
    u, v = images[image_id].xys[point2d_idx]
    # Half-open range: u == width or v == height would index one past the end.
    if not (0 <= u < width and 0 <= v < height):
        print(
            f"Warning: Point ({point_id}, {width}, {height}) is out of bounds for image {image_name}. Skipping.")
        return None
    # NOTE(review): indexing as [row, col] assumes each feature map is laid out
    # (H, W, C) at the same resolution as the image file - confirm with caller.
    return features[feature_index][int(v), int(u)]


def save_pcd(features, data_folder):
    """Attach image features to sparse 3D points and export them as a PLY file.

    Reads ``sparse/points3D.bin`` and ``sparse/images.bin`` under
    *data_folder*. For every 3D point one observation is chosen at random and
    the feature vector at that observation's pixel becomes the point's feature.
    A point whose chosen observation lies outside its image is skipped
    entirely. For each kept point, the pairwise cosine similarity of the
    features over all of its in-bounds observations is summarised, and a
    histogram of the per-point mean similarity is shown.

    The result is written to ``sparse/points3D_feature.ply`` before the
    histogram window is opened, so the file exists even if plotting blocks or
    fails.

    Args:
        features: Indexable by the integer parsed from the first run of digits
            in each image name. Each entry is indexed as ``[row, col]`` and
            must support ``.tolist()`` and ``.detach().cpu().numpy()`` -
            presumably torch tensors; confirm with the caller.
        data_folder: Root folder containing the ``sparse`` and ``images``
            sub-folders.

    Raises:
        ValueError: If an image name contains no digits to index *features*.
    """
    points3D_path = os.path.join(data_folder, "sparse", "points3D.bin")
    images_file = os.path.join(data_folder, "sparse", "images.bin")
    output_path = os.path.join(data_folder, "sparse", "points3D_feature.ply")

    images = read_images_binary(images_file)
    points3D = read_points3D_binary(points3D_path)

    image_cache = {}
    pcd_points = []
    similarity_stats = []

    for point_id, point_data in points3D.items():
        num_observations = len(point_data.image_ids)
        if num_observations == 0:
            # Nothing to sample from; random.randint would raise on an empty range.
            continue

        x, y, z = point_data.xyz
        r, g, b = point_data.rgb

        # Pick one observation at random to supply this point's feature.
        random_idx = random.randint(0, num_observations - 1)
        feature_vector = _observation_feature(
            point_id,
            point_data.image_ids[random_idx],
            point_data.point2D_idxs[random_idx],
            images, features, data_folder, image_cache)
        if feature_vector is None:
            continue
        # Row layout: [x, y, z, r, g, b, feature...]
        pcd_points.append([x, y, z, r, g, b] + feature_vector.tolist())

        # Collect the feature at every in-bounds observation of this point.
        feature_vectors = []
        for image_id, point2d_idx in zip(point_data.image_ids,
                                         point_data.point2D_idxs):
            observed = _observation_feature(
                point_id, image_id, point2d_idx,
                images, features, data_folder, image_cache)
            if observed is not None:
                feature_vectors.append(observed.detach().cpu().numpy())

        # One summary per point, computed after the whole track is collected.
        if len(feature_vectors) > 1:
            # Matrix of shape (n_samples, n_features)
            feature_matrix = np.vstack(feature_vectors)
            sim_matrix = cosine_similarity(feature_matrix)
            # Upper triangle without the diagonal: each unordered pair once.
            upper_triangle = sim_matrix[np.triu_indices_from(sim_matrix, k=1)]
            similarity_stats.append({
                "point_id": point_id,
                "mean": np.mean(upper_triangle),
                "max": np.max(upper_triangle),
                "min": np.min(upper_triangle),
                "std": np.std(upper_triangle)
            })

    # Derive the feature width from the data; 32 is only the fallback used
    # when no point survived and there is nothing to measure.
    feature_dim = len(pcd_points[0]) - 6 if pcd_points else 32
    save_custom_ply(output_path, pcd_points, feature_dim)
    print("自定义 PCD 文件已保存到", output_path)

    # Visualise the distribution of per-point mean similarity.
    means = [s['mean'] for s in similarity_stats]
    plt.figure(figsize=(10, 6))
    plt.hist(means, bins=50, alpha=0.7)
    plt.title("Cosine Similarity Distribution Across Points")
    plt.xlabel("Similarity Score")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.show()