import os
import torch
import torch.nn.functional as F
import numpy as np
import json
from argparse import ArgumentParser

def convert_to_serializable(obj):
    """Recursively convert numpy/torch values into JSON-serializable Python objects.

    Conversions:
      * numpy bool / integer / floating scalars -> ``bool`` / ``int`` / ``float``
      * numpy arrays -> nested lists (via ``ndarray.tolist``)
      * torch tensors -> nested lists or a Python scalar (detached, moved to CPU)
      * dicts -> new dict with converted values; numpy-scalar keys become
        Python scalars (``json`` rejects numpy keys)
      * lists / tuples -> same container type with converted items
    Anything else is returned unchanged.

    Args:
        obj: The value to convert.

    Returns:
        The converted object, or ``obj`` itself when no conversion applies.
    """
    # np.bool_ is not a subclass of np.integer, and json cannot encode it.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        # detach(): tensors that require grad cannot be converted to numpy;
        # Tensor.tolist() also handles dtypes numpy lacks (e.g. bfloat16).
        return obj.detach().cpu().tolist()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): convert_to_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_to_serializable(item) for item in obj]
    if isinstance(obj, tuple):
        items = [convert_to_serializable(item) for item in obj]
        # Keep namedtuples as their own type; plain tuples stay tuples.
        return type(obj)(*items) if hasattr(obj, "_fields") else tuple(items)
    return obj

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, elementwise for arrays.

    For large negative ``x``, ``exp(-x)`` overflows to ``inf``. The resulting
    ``1 / (1 + inf) == 0.0`` is still the correct limit, so the floating-point
    overflow is silenced instead of emitting a RuntimeWarning.

    Args:
        x: Scalar or numpy array.

    Returns:
        ``1 / (1 + np.exp(-x))`` with the same broadcasting/dtype rules.
    """
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))

# Evaluation target classes: class ID -> class name.
# The IDs are not contiguous (e.g. 13, 15 and 17-23 are absent); presumably
# they are ScanNet's raw (NYU40-style) label IDs — TODO confirm. The iteration
# order of this dict is the order in which classes are processed.
scannet19_dict = dict(
    [
        (1, "wall"),
        (2, "floor"),
        (3, "cabinet"),
        (4, "bed"),
        (5, "chair"),
        (6, "sofa"),
        (7, "table"),
        (8, "door"),
        (9, "window"),
        (10, "bookshelf"),
        (11, "picture"),
        (12, "counter"),
        (14, "desk"),
        (16, "curtain"),
        (24, "refrigerator"),
        (28, "shower_curtain"),
        (33, "toilet"),
        (34, "sink"),
        (36, "bathtub"),
    ]
)

def load_clip_features(clip_features_path, clip_labels_path):
    """Load the CLIP feature bank and its per-sample labels from ``.npy`` files.

    Prints shape and label-distribution diagnostics after loading.

    Args:
        clip_features_path: Path to the feature array; row ``i`` is sample ``i``.
        clip_labels_path: Path to the label array; entry ``i`` labels row ``i``.

    Returns:
        ``(features, labels)`` as numpy arrays. Returns ``(None, None)`` if
        either file cannot be loaded as a plain array, or if the two arrays
        differ in length along their first axis.
    """
    try:
        features = np.load(clip_features_path)
        labels = np.load(clip_labels_path)
    except Exception as e:  # best-effort: the caller checks for (None, None)
        print(f"加载CLIP特征失败: {e}")
        return None, None

    # np.load returns an NpzFile (not an array) for .npz archives.
    if not isinstance(features, np.ndarray) or not isinstance(labels, np.ndarray):
        print("加载CLIP特征失败: 文件内容不是numpy数组")
        return None, None

    print(f"成功加载CLIP特征，形状: {features.shape}")
    print(f"成功加载CLIP标签，形状: {labels.shape}")

    # Downstream code indexes features with label masks and labels with
    # feature-row indices, so the two arrays must align row-for-row.
    if features.ndim == 0 or labels.ndim == 0 or len(features) != len(labels):
        print(f"加载CLIP特征失败: 特征数量与标签数量不一致 ({features.shape} vs {labels.shape})")
        return None, None

    # min()/max() raise on empty arrays; an empty bank is not a load failure.
    if labels.size == 0:
        print("警告: CLIP标签为空")
        return features, labels

    unique_labels, counts = np.unique(labels, return_counts=True)
    print(f"标签值范围: {labels.min()} - {labels.max()}")
    print(f"唯一标签值: {unique_labels.tolist()}")
    print("CLIP特征库标签分布:")
    for label, count in zip(unique_labels, counts):
        print(f"  标签 {label}: {count} 个样本")

    return features, labels

def load_clip_label_mapping(mapping_file=None):
    """Return ``(label_to_name, name_to_label)`` for the CLIP feature-bank labels.

    When ``mapping_file`` names an existing file, it is read as UTF-8 text with
    one ``label_id: label_name`` entry per line. Blank lines and lines starting
    with ``#`` are ignored. Only the first ``:`` splits the ID from the name,
    so names may contain ``:``. Lines with a non-integer ID or an empty name
    are skipped with a warning instead of aborting the whole file.

    Falls back to a built-in default mapping when ``mapping_file`` is not
    given, does not exist, or cannot be read/decoded.

    Args:
        mapping_file: Optional path to the mapping text file.

    Returns:
        Tuple ``(label_to_name, name_to_label)``. If a name appears more than
        once in the file, ``name_to_label`` keeps the last ID seen.
    """
    if mapping_file and os.path.exists(mapping_file):
        print(f"从文件加载CLIP标签映射: {mapping_file}")
        label_to_name = {}
        name_to_label = {}

        try:
            with open(mapping_file, 'r', encoding='utf-8') as f:
                for line_no, raw_line in enumerate(f, start=1):
                    line = raw_line.strip()
                    if not line or line.startswith('#'):
                        continue
                    id_text, sep, name_text = line.partition(':')
                    label_name = name_text.strip()
                    try:
                        label_id = int(id_text.strip())
                    except ValueError:
                        label_id = None
                    if not sep or label_id is None or not label_name:
                        # Skip just this line; one typo should not discard the file.
                        print(f"跳过无法解析的映射行 {line_no}: {line}")
                        continue
                    label_to_name[label_id] = label_name
                    name_to_label[label_name] = label_id

            print(f"成功加载 {len(label_to_name)} 个标签映射")
            return label_to_name, name_to_label

        except (OSError, UnicodeDecodeError) as e:
            print(f"加载标签映射文件失败: {e}")
    elif mapping_file:
        print(f"警告: 标签映射文件不存在: {mapping_file}")

    print("使用默认CLIP标签映射（请根据实际情况修改）")
    # Contiguous IDs: 0 = background, then 19 classes. Every name appears
    # exactly once, so name_to_label is a true inverse of this dict.
    # NOTE(review): the order follows the common ScanNet 20-class benchmark list
    # (without "otherfurniture"): ... counter, desk, curtain, refrigerator ...
    # Confirm it against the labels actually stored in the feature bank.
    default_mapping = {
        0: "background",
        1: "wall",
        2: "floor",
        3: "cabinet",
        4: "bed",
        5: "chair",
        6: "sofa",
        7: "table",
        8: "door",
        9: "window",
        10: "bookshelf",
        11: "picture",
        12: "counter",
        13: "desk",
        14: "curtain",
        15: "refrigerator",
        16: "shower_curtain",
        17: "toilet",
        18: "sink",
        19: "bathtub",
    }

    name_to_label = {v: k for k, v in default_mapping.items()}
    return default_mapping, name_to_label

def load_flexigaussian_labels(model_path, total_categories):
    """Collect the foreground point indices of each FlexiGaussian class.

    For every class ID in ``range(total_categories)``, looks in
    ``<model_path>/mid_result`` for a file named
    ``class_id_XXX_total_categories_YYY_label.pth`` (both numbers zero-padded
    to three digits). The file is loaded with ``torch.load`` onto the CPU. The
    first index tensor returned by ``torch.where(tensor == 1)`` is kept as that
    class's foreground points.

    Missing files, load/processing errors and classes without foreground
    points are reported and skipped.

    Args:
        model_path: Model directory that contains ``mid_result/``.
        total_categories: Number of class IDs to probe (also part of the file name).

    Returns:
        Dict mapping class ID -> index tensor of foreground points. It is empty
        if ``mid_result`` does not exist.
    """
    mid_dir = os.path.join(model_path, "mid_result")
    if not os.path.exists(mid_dir):
        print(f"警告: 中间结果目录不存在: {mid_dir}")
        return {}

    print(f"从目录加载标签: {mid_dir}")

    found = {}
    for cid in range(total_categories):
        fname = f"class_id_{cid:03d}_total_categories_{total_categories:03d}_label.pth"
        path = os.path.join(mid_dir, fname)

        if not os.path.exists(path):
            print(f"类别 {cid} 标签文件不存在: {path}")
            continue

        try:
            # NOTE(review): torch.load unpickles arbitrary objects; only use on
            # trusted files (consider weights_only=True on newer torch).
            mask = torch.load(path, map_location='cpu')
            print(f"加载类别 {cid} 标签文件: {path}, 形状: {mask.shape}")
            # For a multi-dimensional tensor, [0] yields indices along dim 0
            # only — presumably the tensor is 1-D per point; TODO confirm.
            fg = torch.where(mask == 1)[0]

            if len(fg) == 0:
                print(f"类别 {cid}: 没有前景点")
            else:
                found[cid] = fg
                print(f"类别 {cid}: 找到 {len(fg)} 个前景点")
        except Exception as e:
            print(f"加载类别 {cid} 标签文件失败: {e}")

    return found

# Hand-written synonyms/paraphrases appended to a class's prompt list.
# Keys are raw class names (underscores included) as passed to
# generate_text_prompts_for_class. Built once at import time.
_CLASS_SYNONYMS = {
    "wall": ["wall surface", "partition", "wall panel", "vertical surface", "interior wall"],
    "floor": ["floor surface", "ground", "flooring", "horizontal surface", "ground level"],
    "cabinet": ["cabinet", "cupboard", "closet", "storage unit", "storage cabinet"],
    "bed": ["bed", "mattress", "sleeping surface", "bed frame", "sleeping area"],
    "chair": ["chair", "seat", "stool", "armchair", "sitting furniture"],
    "sofa": ["sofa", "couch", "settee", "sofa chair", "living room furniture"],
    "table": ["table", "desk", "surface", "tabletop", "work surface"],
    "door": ["door", "entrance", "gateway", "doorway", "access point"],
    "window": ["window", "glass", "opening", "window pane", "natural light source"],
    "bookshelf": ["bookshelf", "shelf", "bookcase", "storage shelf", "book storage"],
    "picture": ["picture", "painting", "art", "frame", "wall decoration"],
    "counter": ["counter", "countertop", "surface", "work surface", "kitchen counter"],
    "curtain": ["curtain", "drape", "blind", "window covering", "fabric panel"],
    "refrigerator": ["refrigerator", "fridge", "freezer", "cooling appliance", "kitchen appliance"],
    "shower_curtain": ["shower curtain", "curtain", "shower", "bathroom curtain", "water barrier"],
    "toilet": ["toilet", "bathroom", "commode", "wc", "lavatory"],
    "sink": ["sink", "basin", "washbasin", "faucet", "water fixture"],
    "bathtub": ["bathtub", "bath", "tub", "bathroom fixture", "bathing area"],
    "desk": ["desk", "table", "workstation", "work surface", "office furniture"],
}


def generate_text_prompts_for_class(class_name):
    """Build an ordered, de-duplicated list of CLIP text prompts for ``class_name``.

    The list is the readable class name, then any hand-written synonyms for
    the class, then four context templates ("a X in a room", "indoor X",
    "3D X", "realistic X"). Underscores in ``class_name`` (e.g.
    ``"shower_curtain"``) become spaces in the generated prompts, so the text
    encoder sees natural language. The raw name is still used for the synonym
    lookup. Order is preserved; later duplicates are dropped.

    Args:
        class_name: Class name; unknown names get only the base and context prompts.

    Returns:
        List of unique prompt strings.
    """
    readable = class_name.replace("_", " ")
    prompts = [readable]
    prompts.extend(_CLASS_SYNONYMS.get(class_name, []))
    prompts.extend([
        f"a {readable} in a room",
        f"indoor {readable}",
        f"3D {readable}",
        f"realistic {readable}",
    ])
    # dict.fromkeys keeps first-seen order while removing duplicates.
    return list(dict.fromkeys(prompts))

# Loaded CLIP models keyed by (model name, device string). The model is loaded
# once per process instead of once per class.
_CLIP_MODEL_CACHE = {}


def _load_clip_model(model_name, device):
    """Return the CLIP model ``model_name`` on ``device``, loading it at most once."""
    key = (model_name, str(device))
    model = _CLIP_MODEL_CACHE.get(key)
    if model is None:
        import clip  # OpenAI CLIP; imported lazily, only when the encoder is used
        print("加载CLIP模型...")
        model, _ = clip.load(model_name, device=device)
        _CLIP_MODEL_CACHE[key] = model
    return model


def extract_class_clip_features(class_name, clip_features, clip_labels, device, name_to_label,
                                clip_model_name="ViT-B/32"):
    """Return a float32 CLIP feature tensor on ``device`` that represents ``class_name``.

    Two sources are tried in order:

    1. Feature bank: if ``class_name`` is in ``name_to_label`` and at least one
       entry of ``clip_labels`` equals that ID, return the mean of the matching
       rows of ``clip_features`` (assumes ``clip_features`` is (N, D) — TODO
       confirm).
    2. CLIP text encoder: encode the prompts from
       ``generate_text_prompts_for_class``, L2-normalize each embedding,
       average them and re-normalize the mean.

    Args:
        class_name: Target class name.
        clip_features: Feature-bank array, row-aligned with ``clip_labels``.
        clip_labels: Label array for the feature bank.
        device: torch device the result is placed on.
        name_to_label: Class name -> feature-bank label ID. Pass an empty dict
            to always use the text encoder.
        clip_model_name: Model name passed to ``clip.load`` (default "ViT-B/32").

    Returns:
        The feature tensor, or ``None`` if CLIP cannot be imported or text
        encoding fails.
    """
    print(f"为类别 '{class_name}' 生成CLIP特征...")
    if class_name in name_to_label:
        clip_label_id = name_to_label[class_name]
        print(f"尝试从CLIP特征库中查找标签ID {clip_label_id}")
        class_mask = clip_labels == clip_label_id
        if np.any(class_mask):
            class_features = clip_features[class_mask]
            class_feature = np.mean(class_features, axis=0)
            print(f"类别 '{class_name}' (CLIP标签ID: {clip_label_id}): 找到 {len(class_features)} 个样本，平均特征形状: {class_feature.shape}")
            return torch.from_numpy(class_feature).to(device).float()
    print(f"使用CLIP文本编码器为类别 '{class_name}' 生成特征...")

    try:
        import clip
        model = _load_clip_model(clip_model_name, device)
        text_prompts = generate_text_prompts_for_class(class_name)
        print(f"文本提示: {text_prompts}")
        text_tokens = clip.tokenize(text_prompts).to(device)
        with torch.no_grad():
            # Cast to float32 first: CLIP may run in fp16 on CUDA, and the
            # feature-bank branch above returns float32.
            text_features = model.encode_text(text_tokens).float()
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            class_feature = text_features.mean(dim=0)
            class_feature = class_feature / class_feature.norm()

        print(f"成功为类别 '{class_name}' 生成CLIP特征，形状: {class_feature.shape}")
        return class_feature

    except ImportError:
        # The PyPI package named "clip" is unrelated to OpenAI CLIP.
        print("错误: 无法导入CLIP库，请安装: pip install git+https://github.com/openai/CLIP.git")
        return None
    except Exception as e:
        print(f"CLIP文本编码失败: {e}")
        return None

def find_similar_classes(class_feature, clip_features, clip_labels, similarity_threshold, device):
    """Rank feature-bank labels by cosine similarity to ``class_feature``.

    Each row of ``clip_features`` whose cosine similarity to ``class_feature``
    is strictly greater than ``similarity_threshold`` is kept. Kept rows are
    ordered by descending similarity (``np.argsort`` reversed), then collapsed
    to one entry per label. Each label keeps the score of its first
    occurrence in that order, i.e. its best score.

    Labels are printed with names from ``scannet19_dict``. NOTE(review): this
    assumes feature-bank labels share ScanNet's ID space — confirm.

    Args:
        class_feature: 1-D query feature tensor.
        clip_features: Feature-bank array, row-aligned with ``clip_labels``.
        clip_labels: numpy label array for the bank.
        similarity_threshold: Strict lower bound on cosine similarity.
        device: torch device for the similarity computation.

    Returns:
        ``(labels, scores)``: parallel lists of bank labels and their best
        scores, best first. Both lists are empty when nothing exceeds the
        threshold.
    """
    query = class_feature if class_feature.dtype == torch.float32 else class_feature.float()
    query = F.normalize(query.unsqueeze(0), dim=1)
    bank = F.normalize(torch.from_numpy(clip_features).to(device).float(), dim=1)
    sims = (query @ bank.T).squeeze(0)

    hit_idx = torch.nonzero(sims > similarity_threshold, as_tuple=True)[0]
    if hit_idx.numel() == 0:
        print(f"未找到相似度超过 {similarity_threshold} 的类别")
        return [], []

    hit_labels = clip_labels[hit_idx.cpu().numpy()]
    hit_scores = sims[hit_idx].cpu().numpy()
    order = np.argsort(hit_scores)[::-1]

    # Walking in descending-score order, the first score seen per label is its best.
    best = {}
    for lbl, score in zip(hit_labels[order], hit_scores[order]):
        best.setdefault(lbl, score)
    unique_labels = list(best)
    unique_scores = list(best.values())

    print(f"找到 {len(unique_labels)} 个唯一相似类别:")
    for rank, (lbl, score) in enumerate(zip(unique_labels, unique_scores), start=1):
        name = scannet19_dict.get(lbl, f"unknown_{lbl}")
        print(f"  {rank}. {name} (ID: {lbl}): 相似度 {score:.4f}")

    return unique_labels, unique_scores

def _validate_clip_label_system(clip_labels, reference_dict):
    """Report feature-bank labels that are not keys of ``reference_dict``.

    NOTE(review): the original code called an undefined
    ``validate_clip_label_system``. This is a minimal stand-in (a subset check);
    confirm the intended semantics.

    Returns:
        True if every unique value in ``clip_labels`` is a key of
        ``reference_dict``, otherwise False (the unknown labels are printed).
    """
    unknown = [label for label in np.unique(clip_labels).tolist() if label not in reference_dict]
    if unknown:
        print(f"  CLIP标签中不在参考类别表中的标签: {unknown}")
    return not unknown


def _collect_matched_points(similar_labels, similarities, class_points):
    """Union the point indices of every similar label present in ``class_points``.

    Returns:
        ``(point_indices, matched_classes)``. ``point_indices`` is the sorted
        list of unique point indices (JSON-serializable). ``matched_classes``
        holds a ``(label, similarity, n_points)`` tuple for each label found in
        ``class_points``.
    """
    all_points = set()
    matched_classes = []

    print("  匹配结果:")
    for label, similarity in zip(similar_labels, similarities):
        if label not in class_points:
            print(f"    警告: FlexiGaussian类别 {label} 不存在")
            continue
        point_indices = class_points[label]
        if isinstance(point_indices, torch.Tensor):
            point_indices = point_indices.cpu().numpy()
        indices_list = point_indices.tolist() if hasattr(point_indices, 'tolist') else list(point_indices)
        all_points.update(indices_list)
        matched_classes.append((convert_to_serializable(label), float(similarity), int(len(point_indices))))
        print(f"    FlexiGaussian类别 {label}: 相似度 {similarity:.4f}, 点数 {len(point_indices)}")

    # Sort so the saved indices are deterministic (set iteration order is not).
    return convert_to_serializable(sorted(all_points)), matched_classes


def _save_matching_results(scene_name, results_dir, class_matching_results):
    """Write the JSON summary and one ``.pth`` file per matched class into ``results_dir``."""
    results_file = os.path.join(results_dir, f"{scene_name}_class_matching_results.json")
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(class_matching_results, f, indent=2, ensure_ascii=False)

    print(f"详细匹配结果已保存到: {results_file}")

    for class_id, result in class_matching_results.items():
        class_name = result['class_name']
        point_indices = result['point_indices']
        pth_file = os.path.join(results_dir, f"{scene_name}_{class_name}_points.pth")
        torch.save({
            'class_id': class_id,
            'class_name': class_name,
            'point_indices': torch.tensor(point_indices, dtype=torch.long),
            'total_points': len(point_indices),
            'matched_classes': result['matched_classes']
        }, pth_file)

        print(f"类别 {class_name} 的点云已保存到: {pth_file}")

    print(f"所有结果已保存到: {results_dir}")


def evaluate_flexigaussian_scene_clip(scene_name, model_path, output_dir, total_classes,
                                    clip_features_path, clip_labels_path, similarity_threshold=0.7,
                                    clip_label_mapping_file=None, use_clip_encoder=True):
    """Match each class in ``scannet19_dict`` to FlexiGaussian point clusters via CLIP.

    For every target class, the function:

    1. gets a CLIP feature (from the text encoder, or from the feature bank
       when ``use_clip_encoder`` is False);
    2. finds the feature-bank labels whose samples exceed ``similarity_threshold``;
    3. unions the foreground points of those labels loaded from
       ``<model_path>/mid_result``.

    Results are written to ``<output_dir>/<scene_name>_clip_results/``: one
    JSON summary plus one ``.pth`` file per matched class.

    NOTE(review): bank labels are looked up directly as FlexiGaussian class
    IDs — confirm the two ID spaces agree.

    Args:
        scene_name: Scene identifier used in output names.
        model_path: FlexiGaussian model directory (contains ``mid_result/``).
        output_dir: Root output directory.
        total_classes: Number of FlexiGaussian class IDs to load.
        clip_features_path: Path to the feature-bank ``.npy`` file.
        clip_labels_path: Path to the feature-bank labels ``.npy`` file.
        similarity_threshold: Strict cosine-similarity cutoff.
        clip_label_mapping_file: Optional ``id: name`` mapping file (feature-bank mode only).
        use_clip_encoder: Use the CLIP text encoder instead of the feature bank.

    Returns:
        Dict mapping class_id -> ``{'class_name', 'matched_classes',
        'total_points', 'point_indices'}`` (indices sorted ascending), or
        ``None`` if the CLIP features cannot be loaded or are empty.
    """
    print(f"Processing scene: {scene_name} with CLIP-based similarity")

    clip_features, clip_labels = load_clip_features(clip_features_path, clip_labels_path)
    if clip_features is None or clip_labels is None:
        print("CLIP特征加载失败，处理终止")
        return None
    if clip_labels.size == 0:
        print("CLIP标签为空，处理终止")
        return None

    class_points = load_flexigaussian_labels(model_path, total_classes)
    print(f"Loaded {len(class_points)} class point clusters")

    results_dir = os.path.join(output_dir, f"{scene_name}_clip_results")
    os.makedirs(results_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    if use_clip_encoder:
        print("\n🔧 使用CLIP文本编码器直接生成特征")
        print("   这将为每个类别生成实时CLIP特征，不依赖预定义的特征库")
        # With an empty mapping, extract_class_clip_features always takes the
        # text-encoder path.
        name_to_label = {}
    else:
        print("\n🔧 使用预定义的CLIP特征库")
        print("   需要确保CLIP标签映射正确")
        _, name_to_label = load_clip_label_mapping(clip_label_mapping_file)

    print("\nCLIP特征库信息:")
    print(f"  特征维度: {clip_features.shape}")
    print(f"  标签范围: {clip_labels.min()} - {clip_labels.max()}")
    print(f"  唯一标签: {np.unique(clip_labels).tolist()}")

    if not use_clip_encoder and not _validate_clip_label_system(clip_labels, scannet19_dict):
        print("\n⚠️  警告: CLIP标签系统与ScanNet19不完全兼容")
        print("   这可能导致类别匹配错误！")

    print("\nFlexiGaussian类别信息:")
    for class_id, point_indices in class_points.items():
        print(f"  类别 {class_id}: {len(point_indices)} 个点")

    class_matching_results = {}
    for class_id, class_name in scannet19_dict.items():
        print(f"\n处理类别 {class_id}: {class_name}")

        class_feature = extract_class_clip_features(class_name, clip_features, clip_labels, device, name_to_label)
        if class_feature is None:
            print(f"  跳过类别 {class_name} (无法提取CLIP特征)")
            continue

        similar_labels, similarities = find_similar_classes(
            class_feature, clip_features, clip_labels, similarity_threshold, device
        )
        if len(similar_labels) == 0:
            print(f"  跳过类别 {class_name} (未找到相似类别)")
            continue

        point_indices, matched_classes = _collect_matched_points(similar_labels, similarities, class_points)
        if not point_indices:
            print(f"  类别 {class_name} 没有匹配到任何点")
            continue

        class_matching_results[class_id] = {
            'class_name': class_name,
            'matched_classes': matched_classes,
            'total_points': len(point_indices),
            'point_indices': point_indices,
        }
        print(f"  类别 {class_name} 总共匹配到 {len(point_indices)} 个点")

    _save_matching_results(scene_name, results_dir, class_matching_results)
    return class_matching_results

def main():
    """Parse command-line arguments and run CLIP-based matching for one scene.

    Returns ``None`` early, after printing an error, if either CLIP input file
    is missing. Pass ``--no_clip_encoder`` to use the pre-computed feature bank
    and label mapping instead of the CLIP text encoder (the default).
    """
    parser = ArgumentParser(description="Process FlexiGaussian Scene Segmentation with CLIP Similarity")
    parser.add_argument("--scene_name", type=str, required=True,
                        help="Scene name to process")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to FlexiGaussian model directory (contains mid_result/)")
    parser.add_argument("--output_dir", type=str, default="./evaluation_results",
                        help="Output directory for results")
    parser.add_argument("--total_classes", type=int, default=20,
                        help="Total number of classes (including background)")
    parser.add_argument("--clip_features", type=str, default="clip/clip_output_features.npy",
                        help="Path to CLIP features file")
    parser.add_argument("--clip_labels", type=str, default="clip/clip_output_labels.npy",
                        help="Path to CLIP labels file")
    parser.add_argument("--clip_label_mapping", type=str, default=None,
                        help="Path to CLIP label mapping file (format: label_id:label_name)")
    # store_true with default=True can never turn the option off, so a
    # store_false flag sharing the same dest provides the opt-out.
    parser.add_argument("--use_clip_encoder", dest="use_clip_encoder", action="store_true", default=True,
                        help="Use CLIP text encoder directly instead of pre-computed features (default)")
    parser.add_argument("--no_clip_encoder", dest="use_clip_encoder", action="store_false",
                        help="Use the pre-computed CLIP feature bank and label mapping instead of the text encoder")
    # NOTE(review): in encoder mode a text embedding is compared against the
    # feature bank. If the bank holds image embeddings, text-image cosine
    # similarities may sit well below 0.85 — confirm the threshold on real data.
    parser.add_argument("--similarity_threshold", type=float, default=0.85,
                        help="Similarity threshold for CLIP feature matching (recommended: 0.85-0.95)")

    args = parser.parse_args()

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Validate the CLIP input files before doing any work
    if not os.path.exists(args.clip_features):
        print(f"错误: CLIP特征文件不存在: {args.clip_features}")
        return

    if not os.path.exists(args.clip_labels):
        print(f"错误: CLIP标签文件不存在: {args.clip_labels}")
        return

    # Process the scene
    results = evaluate_flexigaussian_scene_clip(
        scene_name=args.scene_name,
        model_path=args.model_path,
        output_dir=args.output_dir,
        total_classes=args.total_classes,
        clip_features_path=args.clip_features,
        clip_labels_path=args.clip_labels,
        similarity_threshold=args.similarity_threshold,
        clip_label_mapping_file=args.clip_label_mapping,
        use_clip_encoder=args.use_clip_encoder,
    )

    if results is None:
        print("CLIP-based processing failed!")
        return

    print(f"\n=== CLIP-based Processing Results for {args.scene_name} ===")
    print(f"Successfully processed {len(results)} ScanNet classes")

    # Per-class summary
    print("\nPer-class Results:")
    for class_id, result in results.items():
        class_name = result['class_name']
        total_points = result['total_points']
        matched_classes = result['matched_classes']
        print(f"  Class {class_id} ({class_name}): {total_points} points")
        print(f"    Matched FlexiGaussian classes: {[mc[0] for mc in matched_classes]}")


if __name__ == "__main__":
    main()
