import os
import torch
import torch.nn.functional as F
import numpy as np
import json
from argparse import ArgumentParser

def convert_to_serializable(obj):
    """Recursively convert NumPy / PyTorch values into plain Python objects.

    Handles NumPy scalars (bool, integer, floating), NumPy arrays, PyTorch
    tensors, and dicts / lists / tuples containing any of these.  Anything
    else is returned unchanged.

    Args:
        obj: The value to convert.

    Returns:
        An equivalent value built from Python ``bool`` / ``int`` / ``float`` /
        ``list`` / ``tuple`` / ``dict`` where a conversion applies.  Tuples stay
        tuples (``json`` writes them as arrays); dict keys are left untouched.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch;
    # json.dump rejects it otherwise.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        # detach() first: converting a tensor that requires grad would raise.
        return obj.detach().cpu().tolist()
    if isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, tuple):
        # Recurse so NumPy scalars nested inside tuples are converted as well.
        return tuple(convert_to_serializable(item) for item in obj)
    return obj

def load_clip_features(clip_features_path, clip_labels_path):
    """Load the CLIP feature bank and its per-row labels from ``.npy`` files.

    Besides loading, this prints the array shapes, the label value range and
    the number of samples per label.

    Args:
        clip_features_path: Path to the ``.npy`` file holding the features
            (one row per sample).
        clip_labels_path: Path to the ``.npy`` file holding one label per
            feature row.

    Returns:
        ``(features, labels)`` as NumPy arrays, or ``(None, None)`` if loading
        fails, the label array is empty, or the two arrays do not have the
        same number of rows.
    """
    try:
        features = np.load(clip_features_path)  # [N, 512]
        labels = np.load(clip_labels_path)      # [N]

        # Labels are later indexed by feature row, so the two arrays must line
        # up; fail here with a clear message instead of an IndexError later.
        if features.ndim == 0 or labels.ndim == 0 or features.shape[0] != labels.shape[0]:
            raise ValueError(
                f"features/labels row count mismatch: {features.shape} vs {labels.shape}"
            )

        # Computed once and reused for both the value list and the histogram.
        unique_labels, counts = np.unique(labels, return_counts=True)

        print(f"成功加载CLIP特征，形状: {features.shape}")
        print(f"成功加载CLIP标签，形状: {labels.shape}")
        print(f"标签值范围: {labels.min()} - {labels.max()}")
        print(f"唯一标签值: {unique_labels.tolist()}")

        # Report the label distribution.
        print("CLIP特征库标签分布:")
        for label, count in zip(unique_labels, counts):
            print(f"  标签 {label}: {count} 个样本")

        return features, labels

    except Exception as e:
        # Best-effort loader: callers test for None rather than catching.
        print(f"加载CLIP特征失败: {e}")
        return None, None

def load_flexigaussian_labels(model_path, total_categories):
    """Collect the foreground point indices of every class under ``mid_result``.

    For each ``class_id`` in ``range(total_categories)`` the file
    ``class_id_{id:03d}_total_categories_{total:03d}_label.pth`` inside
    ``<model_path>/mid_result`` is loaded, and the indices of entries equal to
    1 are kept.

    Args:
        model_path: Directory that contains the ``mid_result`` sub-directory.
        total_categories: Number of class ids to probe (also part of the
            file name).

    Returns:
        Dict mapping ``class_id`` to a tensor of foreground indices.  Classes
        whose file is missing, fails to load, or has no foreground entries are
        left out; an empty dict is returned if ``mid_result`` does not exist.
    """
    mid_result_dir = os.path.join(model_path, "mid_result")
    foreground_by_class = {}

    if not os.path.exists(mid_result_dir):
        print(f"警告: 中间结果目录不存在: {mid_result_dir}")
        return foreground_by_class

    print(f"从目录加载标签: {mid_result_dir}")

    for class_id in range(total_categories):
        # File name pattern: class_id_{:03d}_total_categories_{:03d}_label.pth
        file_name = f"class_id_{class_id:03d}_total_categories_{total_categories:03d}_label.pth"
        label_file = os.path.join(mid_result_dir, file_name)

        if not os.path.exists(label_file):
            print(f"类别 {class_id} 标签文件不存在: {label_file}")
            continue

        try:
            label_tensor = torch.load(label_file, map_location='cpu')
            print(f"加载类别 {class_id} 标签文件: {label_file}, 形状: {label_tensor.shape}")

            # Foreground entries are the ones labelled 1.
            foreground_indices = torch.where(label_tensor == 1)[0]

            if len(foreground_indices) == 0:
                print(f"类别 {class_id}: 没有前景点")
            else:
                foreground_by_class[class_id] = foreground_indices
                print(f"类别 {class_id}: 找到 {len(foreground_indices)} 个前景点")
        except Exception as e:
            # A broken file only skips this class; the others are still loaded.
            print(f"加载类别 {class_id} 标签文件失败: {e}")

    return foreground_by_class

def generate_text_prompts_for_custom_text(custom_text):
    """Build the list of text prompts used to describe ``custom_text``.

    The raw text comes first, followed by six variants that wrap it in scene
    or object context.  Duplicates are dropped while keeping first-seen order.

    Args:
        custom_text: The query text.

    Returns:
        List of prompts, starting with ``custom_text`` itself.
    """
    candidates = (
        # The bare query.
        custom_text,
        # Variants that add scene / object context.
        f"a {custom_text} in a room",
        f"indoor {custom_text}",
        f"3D {custom_text}",
        f"realistic {custom_text}",
        f"{custom_text} object",
        f"{custom_text} furniture",
    )

    # Order-preserving de-duplication.
    already_seen = set()
    prompts = []
    for candidate in candidates:
        if candidate in already_seen:
            continue
        already_seen.add(candidate)
        prompts.append(candidate)
    return prompts

def extract_custom_text_clip_features(custom_text, clip_features, clip_labels, device):
    """Encode ``custom_text`` into a single L2-normalised CLIP text embedding.

    Several prompt variants of the text are encoded with the CLIP "ViT-B/32"
    model; their normalised embeddings are averaged and the mean is normalised
    again.

    Args:
        custom_text: The query text.
        clip_features: Unused; kept so existing callers keep working.
        clip_labels: Unused; kept so existing callers keep working.
        device: Device on which the CLIP model is loaded and run.

    Returns:
        A 1-D tensor on ``device`` holding the averaged text embedding, or
        ``None`` if the ``clip`` package cannot be imported or encoding fails.
    """
    print(f"为自定义文本 '{custom_text}' 生成CLIP特征...")

    # The import is guarded separately so that only a missing package produces
    # the installation hint.
    try:
        import clip
    except ImportError:
        # `pip install clip` installs an unrelated PyPI package; OpenAI CLIP is
        # installed from its GitHub repository.
        print("错误: 无法导入CLIP库，请安装: pip install git+https://github.com/openai/CLIP.git")
        return None

    try:
        print("加载CLIP模型...")

        # Load the CLIP model; the image preprocessing transform is not needed.
        model, _ = clip.load("ViT-B/32", device=device)

        # Expand the query into several prompt variants.
        text_prompts = generate_text_prompts_for_custom_text(custom_text)
        print(f"文本提示: {text_prompts}")

        # Tokenise all prompts in one batch.
        text_tokens = clip.tokenize(text_prompts).to(device)

        with torch.no_grad():
            text_features = model.encode_text(text_tokens)  # [num_prompts, 512]
            # Normalise each prompt embedding so every prompt weighs the same.
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Average over the prompts ...
            text_feature = text_features.mean(dim=0)  # [512]
            # ... and re-normalise, since a mean of unit vectors is not unit length.
            text_feature = text_feature / text_feature.norm()

        print(f"成功为文本 '{custom_text}' 生成CLIP特征，形状: {text_feature.shape}")
        return text_feature

    except Exception as e:
        # Best-effort: callers check for None instead of handling exceptions.
        print(f"CLIP文本编码失败: {e}")
        return None

def find_similar_classes(text_feature, clip_features, clip_labels, similarity_threshold, device):
    """Return the labels whose CLIP features are similar to ``text_feature``.

    Cosine similarity is computed between the query and every row of
    ``clip_features``.  Rows scoring strictly above ``similarity_threshold``
    are kept, sorted by descending score, and reduced to one entry per label
    (the highest-scoring one).  The result is also printed.

    Args:
        text_feature: Query embedding as a tensor of shape ``[D]`` or
            ``[1, D]``; it is moved to ``device`` and cast to float32.
        clip_features: Feature bank of shape ``[N, D]`` (NumPy array or
            anything ``torch.as_tensor`` accepts).
        clip_labels: One label per row of ``clip_features``.
        similarity_threshold: Minimum cosine similarity (exclusive).
        device: Device on which the similarity is computed.

    Returns:
        ``(labels, scores)``: two parallel lists ordered by descending score,
        or ``([], [])`` if no row exceeds the threshold.
    """
    # reshape(1, -1) accepts both [D] and [1, D]; moving the query to `device`
    # avoids a device mismatch against the feature bank below.
    query = text_feature.detach().to(device=device, dtype=torch.float32).reshape(1, -1)
    query = F.normalize(query, dim=1)  # [1, D]

    bank = torch.as_tensor(clip_features, dtype=torch.float32, device=device)
    bank = F.normalize(bank, dim=1)  # [N, D]

    # With both sides unit length, the dot product is the cosine similarity.
    similarities = torch.matmul(query, bank.T).squeeze(0)  # [N]

    similar_indices = torch.where(similarities > similarity_threshold)[0]

    if len(similar_indices) == 0:
        print(f"未找到相似度超过 {similarity_threshold} 的类别")
        return [], []

    # Gather labels and scores of the rows that passed the threshold.
    similar_labels = np.asarray(clip_labels)[similar_indices.cpu().numpy()]
    similar_scores = similarities[similar_indices].cpu().numpy()

    # Sort by similarity, highest first.
    order = np.argsort(similar_scores)[::-1]
    similar_labels = similar_labels[order]
    similar_scores = similar_scores[order]

    # Keep only the first (best-scoring) occurrence of each label.
    unique_labels = []
    unique_scores = []
    seen_labels = set()
    for label, score in zip(similar_labels, similar_scores):
        if label in seen_labels:
            continue
        seen_labels.add(label)
        unique_labels.append(label)
        unique_scores.append(score)

    print(f"找到 {len(unique_labels)} 个唯一相似类别:")
    for rank, (label, score) in enumerate(zip(unique_labels, unique_scores), start=1):
        print(f"  {rank}. Category {label}: similarity {score:.4f}")

    return unique_labels, unique_scores

def evaluate_flexigaussian_scene_custom_text(scene_name, model_path, output_dir, total_classes, 
                                           clip_features_path, clip_labels_path, custom_text,
                                           similarity_threshold=0.7):
    """Select the scene points whose class matches a free-form text query.

    Steps: load the CLIP feature bank, load the per-class foreground point
    indices from ``<model_path>/mid_result``, encode ``custom_text`` with CLIP,
    find every bank label whose similarity exceeds ``similarity_threshold``,
    and take the union of the foreground points of those classes.  The result
    is written to ``<output_dir>/<scene_name>_custom_text_results`` as a
    ``*_points.pth`` file and a ``*_matching_results.json`` file.

    Args:
        scene_name: Scene identifier; used in log output and output file names.
        model_path: Directory containing the ``mid_result`` label files.
        output_dir: Directory under which the results directory is created.
        total_classes: Number of class ids to look for in ``mid_result``.
        clip_features_path: Path to the ``.npy`` CLIP feature bank.
        clip_labels_path: Path to the ``.npy`` labels of the feature bank.
        custom_text: The text query.
        similarity_threshold: Minimum cosine similarity for a label to match.

    Returns:
        Dict with keys ``custom_text``, ``matched_classes`` (list of
        ``(class_id, similarity, point_count)`` tuples), ``total_points`` and
        ``point_indices`` (sorted list of unique indices).  ``None`` if the
        feature bank cannot be loaded, the text cannot be encoded, no label is
        similar enough, or none of the similar labels has points.
    """
    import re  # function-scope import: only used to sanitise output file names

    print(f"Processing scene: {scene_name} with custom text: '{custom_text}'")

    # 1. Load the CLIP feature bank and its labels.
    clip_features, clip_labels = load_clip_features(clip_features_path, clip_labels_path)
    if clip_features is None or clip_labels is None:
        print("CLIP特征加载失败，处理终止")
        return None

    # 2. Load label files
    class_points = load_flexigaussian_labels(model_path, total_classes)
    print(f"Loaded {len(class_points)} class point clusters")

    # 3. Create the output directory.
    results_dir = os.path.join(output_dir, f"{scene_name}_custom_text_results")
    os.makedirs(results_dir, exist_ok=True)

    # 4. Match the custom text against the CLIP feature bank.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Summary of the feature bank.
    print(f"\nCLIP特征库信息:")
    print(f"  特征维度: {clip_features.shape}")
    print(f"  标签范围: {clip_labels.min()} - {clip_labels.max()}")
    print(f"  唯一标签: {np.unique(clip_labels).tolist()}")

    # Print category information
    print(f"\nCategory information:")
    for class_id, point_indices in class_points.items():
        print(f"  类别 {class_id}: {len(point_indices)} 个点")

    print(f"\n处理自定义文本: '{custom_text}'")

    # Encode the query text.
    text_feature = extract_custom_text_clip_features(custom_text, clip_features, clip_labels, device)
    if text_feature is None:
        print(f"无法为文本 '{custom_text}' 提取CLIP特征，处理终止")
        return None

    # Find the bank labels that are similar enough to the query.
    similar_labels, similarities = find_similar_classes(
        text_feature, clip_features, clip_labels, similarity_threshold, device
    )

    if len(similar_labels) == 0:
        print(f"No categories found similar to text '{custom_text}'")
        return None

    # Gather the point indices of every similar class that exists in the scene.
    matched_index_arrays = []
    matched_classes = []

    print(f"\n匹配结果:")
    for label, similarity in zip(similar_labels, similarities):
        if label not in class_points:
            print(f"  警告: FlexiGaussian类别 {label} 不存在")
            continue

        point_indices = class_points[label]
        if isinstance(point_indices, torch.Tensor):
            point_indices = point_indices.cpu().numpy()
        point_indices = np.asarray(point_indices, dtype=np.int64).reshape(-1)

        matched_index_arrays.append(point_indices)
        # Store plain Python values so the tuple is JSON-serialisable.
        matched_classes.append((convert_to_serializable(label), float(similarity), int(len(point_indices))))

        print(f"  FlexiGaussian类别 {label}: 相似度 {similarity:.4f}, 点数 {len(point_indices)}")

    # Union of all matched indices; np.unique also sorts them, so the saved
    # files are deterministic (a Python set would give an arbitrary order).
    if matched_index_arrays:
        unique_points = np.unique(np.concatenate(matched_index_arrays))
    else:
        unique_points = np.empty(0, dtype=np.int64)

    if unique_points.size == 0:
        print(f"文本 '{custom_text}' 没有匹配到任何点")
        return None

    total_points = int(unique_points.size)
    results = {
        'custom_text': custom_text,
        'matched_classes': matched_classes,
        'total_points': total_points,
        'point_indices': unique_points.tolist()
    }

    # Spaces become underscores as before; path separators and other characters
    # that are invalid in file names are replaced too, so a query such as
    # "tv/monitor" cannot point into a non-existent sub-directory.
    safe_text = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', custom_text.replace(' ', '_'))

    # Save the point selection as a .pth file.
    pth_file = os.path.join(results_dir, f"{scene_name}_{safe_text}_points.pth")
    torch.save({
        'custom_text': custom_text,
        'point_indices': torch.from_numpy(unique_points),
        'total_points': total_points,
        'matched_classes': matched_classes
    }, pth_file)

    print(f"\n文本 '{custom_text}' 匹配结果:")
    print(f"  总共匹配到 {total_points} 个点")
    print(f"  匹配的FlexiGaussian类别: {[mc[0] for mc in matched_classes]}")
    print(f"  结果已保存到: {pth_file}")

    # Save the detailed matching result as JSON.
    results_file = os.path.join(results_dir, f"{scene_name}_{safe_text}_matching_results.json")
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"详细匹配结果已保存到: {results_file}")
    return results

def main():
    """Command-line entry point.

    Parses the arguments, creates the output directory, checks that both CLIP
    input files exist, runs the custom-text evaluation for the scene and
    prints a short summary (or a failure message when no result is returned).
    """
    cli = ArgumentParser(description="Process FlexiGaussian Scene Segmentation with Custom Text")

    # (flag, add_argument keyword options), registered in this order.
    argument_specs = [
        ("--scene_name", dict(type=str, required=True,
                              help="Scene name to process")),
        ("--model_path", dict(type=str, required=True,
                              help="Path to FlexiGaussian model directory (contains mid_result/)")),
        ("--output_dir", dict(type=str, default="./evaluation_results",
                              help="Output directory for results")),
        ("--total_classes", dict(type=int, default=20,
                                 help="Total number of classes (including background)")),
        ("--clip_features", dict(type=str, default="clip/clip_output_features.npy",
                                 help="Path to CLIP features file")),
        ("--clip_labels", dict(type=str, default="clip/clip_output_labels.npy",
                               help="Path to CLIP labels file")),
        ("--custom_text", dict(type=str, required=True,
                               help="Custom text to match (e.g., 'trash can', 'coffee table')")),
        ("--similarity_threshold", dict(type=float, default=0.7,
                                        help="Similarity threshold for CLIP feature matching (recommended: 0.7-0.9)")),
    ]
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)

    args = cli.parse_args()

    # Make sure the output directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    # Both CLIP input files must be present before any work starts.
    if not os.path.exists(args.clip_features):
        print(f"错误: CLIP特征文件不存在: {args.clip_features}")
        return

    if not os.path.exists(args.clip_labels):
        print(f"错误: CLIP标签文件不存在: {args.clip_labels}")
        return

    # Process the scene.
    results = evaluate_flexigaussian_scene_custom_text(
        scene_name=args.scene_name,
        model_path=args.model_path,
        output_dir=args.output_dir,
        total_classes=args.total_classes,
        clip_features_path=args.clip_features,
        clip_labels_path=args.clip_labels,
        custom_text=args.custom_text,
        similarity_threshold=args.similarity_threshold,
    )

    if results is None:
        print("Custom text processing failed!")
        return

    print(f"\n=== Custom Text Processing Results for {args.scene_name} ===")
    print(f"Successfully processed text: '{args.custom_text}'")
    print(f"Total matched points: {results['total_points']}")
    matched_ids = [entry[0] for entry in results['matched_classes']]
    print(f"Matched FlexiGaussian classes: {matched_ids}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
