import os
import numpy as np
import torch
import json
from plyfile import PlyData, PlyElement
from argparse import ArgumentParser
import glob

# Class IDs evaluated by this script, mapped to the class names that appear in
# prediction file names (``<scene>_<class_name>_segmentation.ply``).
# The integer keys are compared directly against the GT PLY per-vertex
# ``label`` values, so they must use the same ID space as the GT files.
# NOTE(review): these look like the NYU40-style IDs used by the ScanNet
# benchmark classes — confirm against the labels in the GT PLYs.
scannet19_dict = {
    1: "wall", 2: "floor", 3: "cabinet", 4: "bed", 5: "chair",
    6: "sofa", 7: "table", 8: "door", 9: "window", 10: "bookshelf",
    11: "picture", 12: "counter", 14: "desk", 16: "curtain",
    24: "refrigerator", 28: "shower_curtain", 33: "toilet", 34: "sink",
    36: "bathtub"
}

# Predefined evaluation subsets, keyed by the ``class_count`` string accepted
# on the command line ("19", "15", "10"). Each value lists class IDs (keys of
# ``scannet19_dict``) in the order they are reported.
CLASS_COMBINATIONS = {
    "19": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36],
    "15": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 33, 34],
    "10": [1, 2, 4, 5, 6, 7, 8, 9, 10, 33]
}

def _parse_custom_class_ids(spec):
    """Return *spec* as a list of int class IDs, or None if it is not a custom list."""
    if isinstance(spec, (list, tuple)):
        return list(spec)
    # Command-line values always arrive as strings, e.g. "1,2,5" or "[1, 2, 5]".
    if isinstance(spec, str) and "," in spec:
        tokens = [tok.strip() for tok in spec.strip().strip("[]").split(",")]
        try:
            ids = [int(tok) for tok in tokens if tok]
        except ValueError:
            return None
        return ids or None
    return None


def get_class_dict_for_evaluation(class_count):
    """Resolve which classes to evaluate.

    Args:
        class_count: A key of ``CLASS_COMBINATIONS`` ("19", "15", "10"), a
            list/tuple of class IDs, or a comma-separated string of class IDs
            such as "1,2,5" (the form a command-line argument arrives in).
            Anything else falls back to the 19-class set.

    Returns:
        (class_dict, class_names): ``class_dict`` maps each known class ID to
        its name, and ``class_names`` lists those names in the requested order.
        IDs missing from ``scannet19_dict`` are skipped with a printed warning.
    """
    custom_ids = _parse_custom_class_ids(class_count)

    if isinstance(class_count, str) and class_count in CLASS_COMBINATIONS:
        target_ids = CLASS_COMBINATIONS[class_count]
        print(f"Using predefined {class_count} class combination: {target_ids}")
    elif custom_ids is not None:
        target_ids = custom_ids
        print(f"Using custom class combination: {target_ids}")
    else:
        target_ids = CLASS_COMBINATIONS["19"]
        print(f"Using default 19-class combination: {target_ids}")

    class_dict = {}
    class_names = []
    for class_id in target_ids:
        name = scannet19_dict.get(class_id)
        if name is None:
            print(f"Warning: class ID {class_id} not in ScanNet19 dictionary")
            continue
        class_dict[class_id] = name
        class_names.append(name)

    print(f"Evaluation classes: {class_names}")
    return class_dict, class_names

def read_gt_ply(file_path):
    """Load ground-truth vertex positions and per-vertex labels from a PLY file.

    Args:
        file_path: Path to a PLY whose ``vertex`` element has ``x``, ``y``,
            ``z`` and ``label`` properties.

    Returns:
        (points, labels): ``points`` is an (N, 3) array of coordinates and
        ``labels`` is the ``label`` property. Unsigned integer labels are cast
        to int64, because ``torch.from_numpy`` rejects uint16/uint32/uint64 on
        older torch releases and uint32 values may not fit in int32.
        Returns (None, None) if the file cannot be read or lacks these
        properties.
    """
    try:
        ply_data = PlyData.read(file_path)
        vertex_data = ply_data['vertex'].data
        points = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
        labels = np.asarray(vertex_data['label'])
    except Exception as e:
        # Best-effort loader: the caller treats (None, None) as "skip scene".
        print(f"Failed to read GT PLY: {e}")
        return None, None

    if np.issubdtype(labels.dtype, np.unsignedinteger):
        labels = labels.astype(np.int64)

    print(f"Loaded GT PLY: {file_path}, points: {len(points)}")
    if labels.size:
        print(f"  label dtype: {labels.dtype}, range: {labels.min()} - {labels.max()}")
    else:
        # min()/max() raise on empty arrays; report instead of failing the load.
        print(f"  label dtype: {labels.dtype}, no labels")
    return points, labels

def read_pred_ply(file_path):
    """Read a prediction PLY and return its vertex positions and RGB colors.

    Returns (points, colors), each with one row per vertex, or (None, None)
    if the file cannot be read or lacks the x/y/z/red/green/blue properties.
    """
    try:
        verts = PlyData.read(file_path)['vertex'].data
        xyz = np.column_stack([verts[axis] for axis in ('x', 'y', 'z')])
        rgb = np.column_stack([verts[channel] for channel in ('red', 'green', 'blue')])
    except Exception as err:
        print(f"Failed to read prediction PLY: {err}")
        return None, None

    print(f"Loaded prediction PLY: {file_path}, points: {len(xyz)}")
    return xyz, rgb

def color_to_label(colors, target_class_id):
    """Turn binary-mask vertex colors into integer class labels.

    Vertices colored pure red (255, 0, 0) receive ``target_class_id``. Every
    other vertex, including pure white (255, 255, 255), receives 0. The
    number of red and white vertices is printed for diagnostics.

    Returns:
        An int64 array with one label per row of ``colors``.
    """
    r, g, b = colors[:, 0], colors[:, 1], colors[:, 2]
    is_red = (r == 255) & (g == 0) & (b == 0)
    is_white = (r == 255) & (g == 255) & (b == 255)

    labels = np.where(is_red, target_class_id, 0).astype(np.int64)

    print(f"  Foreground(red): {np.sum(is_red)}, Background(white): {np.sum(is_white)}")
    return labels

def _match_points_brute_force(gt_points, pred_points, tolerance, max_pairs_per_chunk=2_000_000):
    """NumPy nearest-GT search (no SciPy); processes predictions in chunks to bound memory."""
    gt = np.asarray(gt_points, dtype=np.float64)
    pred = np.asarray(pred_points, dtype=np.float64)
    rows_per_chunk = max(1, max_pairs_per_chunk // len(gt))
    gt_hits, pred_hits = [], []
    for start in range(0, len(pred), rows_per_chunk):
        chunk = pred[start:start + rows_per_chunk]
        # Accumulate squared distances per coordinate so only a (chunk, M) buffer is allocated.
        sq_dist = np.zeros((len(chunk), len(gt)))
        for axis in range(gt.shape[1]):
            sq_dist += (chunk[:, axis, None] - gt[None, :, axis]) ** 2
        nearest = sq_dist.argmin(axis=1)
        nearest_dist = np.sqrt(sq_dist[np.arange(len(chunk)), nearest])
        hit = nearest_dist < tolerance
        gt_hits.append(nearest[hit])
        pred_hits.append(start + np.flatnonzero(hit))
    return (np.concatenate(gt_hits).astype(np.intp),
            np.concatenate(pred_hits).astype(np.intp))


def match_point_clouds(gt_points, pred_points, tolerance=1e-6):
    """Pair prediction vertices with GT vertices at (nearly) the same position.

    Each prediction point is matched to its nearest GT point when their
    Euclidean distance is strictly below ``tolerance``. Unmatched prediction
    points are dropped. The SciPy kd-tree path and the NumPy fallback use the
    same rule.

    Args:
        gt_points: (M, 3) ground-truth coordinates.
        pred_points: (N, 3) prediction coordinates.
        tolerance: Exclusive upper bound on the matching distance.

    Returns:
        (gt_indices, pred_indices): equal-length integer arrays such that
        ``gt_points[gt_indices[k]]`` matches ``pred_points[pred_indices[k]]``.
        ``pred_indices`` is in increasing order.
    """
    if len(gt_points) == 0 or len(pred_points) == 0:
        print("  Matched points: 0")
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.intp)

    try:
        from scipy.spatial import cKDTree
    except ImportError:
        print("Warning: scipy not installed, using simple matching")
        gt_idx, pred_idx = _match_points_brute_force(gt_points, pred_points, tolerance)
    else:
        # Points with no neighbour within the bound come back with distance=inf.
        distances, nearest = cKDTree(gt_points).query(pred_points, distance_upper_bound=tolerance)
        pred_idx = np.flatnonzero(distances < tolerance)
        gt_idx = nearest[pred_idx]

    print(f"  Matched points: {len(gt_idx)}")
    return gt_idx, pred_idx

def calculate_class_metrics(gt_labels, pred_labels, class_id):
    """Compute one-vs-rest confusion counts and scores for a single class.

    Counting is done in NumPy. Unlike the ``torch.from_numpy`` route, this
    accepts any integer dtype (older torch rejects uint16/uint32/uint64) and
    plain sequences.

    Args:
        gt_labels: Ground-truth label per matched point.
        pred_labels: Predicted label per matched point, aligned with ``gt_labels``.
        class_id: The class treated as positive. Every other label is negative.

    Returns:
        Dict with ``class_id``, ``class_name``, integer counts ``tp``/``fp``/
        ``fn``/``tn``, and float ``precision``, ``recall``, ``f1_score``,
        ``iou`` and ``class_accuracy``. Each ratio is 0.0 when its denominator
        is 0.
    """
    gt_pos = np.asarray(gt_labels) == class_id
    pred_pos = np.asarray(pred_labels) == class_id

    tp = int(np.count_nonzero(gt_pos & pred_pos))
    fp = int(np.count_nonzero(~gt_pos & pred_pos))
    fn = int(np.count_nonzero(gt_pos & ~pred_pos))
    tn = int(np.count_nonzero(~gt_pos & ~pred_pos))

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
    # Per-class accuracy (fraction of this class's GT points predicted as it) is recall.
    class_accuracy = recall

    return {
        'class_id': class_id,
        'class_name': scannet19_dict.get(class_id, f"unknown_{class_id}"),
        'tp': tp, 'fp': fp, 'fn': fn, 'tn': tn,
        'precision': precision, 'recall': recall, 'f1_score': f1_score,
        'iou': iou, 'class_accuracy': class_accuracy
    }

def _to_json_serializable(obj):
    """Recursively convert NumPy scalars/arrays and tuples into JSON-compatible Python values."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: _to_json_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_json_serializable(item) for item in obj]
    return obj


def evaluate_scene(scene_name, gt_ply_path, pred_ply_dir, output_dir, class_count="19"):
    """Evaluate every per-class prediction PLY of one scene against its GT PLY.

    Prediction files are discovered as
    ``<pred_ply_dir>/<scene_name>_<class_name>_segmentation.ply``. The class
    name embedded in the file name selects the class ID from the evaluation
    set chosen by ``class_count``. A file is skipped if its class is not in
    that set, if it fails to load, or if it shares no vertices with the GT.

    Args:
        scene_name: Scene identifier, used as the prediction file-name prefix
            and in the output file name.
        gt_ply_path: Path to the ground-truth PLY with a per-vertex ``label``.
        pred_ply_dir: Directory searched for prediction PLY files.
        output_dir: Directory for the JSON results file (created if missing).
        class_count: Class-set selector forwarded to
            ``get_class_dict_for_evaluation``.

    Returns:
        The results dict, which is also written as JSON into ``output_dir``.
        Returns None when the GT cannot be read, no prediction files are
        found, or no class could be evaluated.
    """
    print(f"\n{'='*60}")
    print(f"Evaluate scene: {scene_name}")
    print(f"{'='*60}")

    evaluation_class_dict, evaluation_class_names = get_class_dict_for_evaluation(class_count)
    # Reverse lookup built once; setdefault keeps the first ID for a name, like the old linear scan.
    class_id_by_name = {}
    for cid, cname in evaluation_class_dict.items():
        class_id_by_name.setdefault(cname, cid)

    gt_points, gt_labels = read_gt_ply(gt_ply_path)
    if gt_points is None:
        return None

    prefix = f"{scene_name}_"
    suffix = "_segmentation.ply"
    pred_ply_pattern = os.path.join(pred_ply_dir, f"{prefix}*{suffix}")
    # Escape the literal parts so '[', '?' or '*' in the path/scene name are not wildcards.
    search_pattern = os.path.join(glob.escape(pred_ply_dir), f"{glob.escape(prefix)}*{suffix}")
    # Sorted so processing order and the saved JSON are deterministic.
    pred_ply_files = sorted(glob.glob(search_pattern))

    if not pred_ply_files:
        print(f"No prediction PLY found: {pred_ply_pattern}")
        return None

    print(f"Found {len(pred_ply_files)} prediction files")

    all_metrics = []
    class_results = {}

    for pred_ply_file in pred_ply_files:
        filename = os.path.basename(pred_ply_file)
        # The glob guarantees this prefix/suffix; slicing (unlike str.replace) never
        # strips repeats of them that occur inside the class name.
        class_name = filename[len(prefix):len(filename) - len(suffix)]

        print(f"\nProcess class: {class_name}")

        pred_points, pred_colors = read_pred_ply(pred_ply_file)
        if pred_points is None:
            continue

        class_id = class_id_by_name.get(class_name)
        if class_id is None:
            print(f"  Warning: class '{class_name}' not in current evaluation class list, skip")
            continue

        pred_labels = color_to_label(pred_colors, class_id)

        gt_indices, pred_indices = match_point_clouds(gt_points, pred_points)

        if len(gt_indices) == 0:
            print(f"  Warning: no matched points, skip class {class_name}")
            continue

        matched_gt_labels = gt_labels[gt_indices]
        matched_pred_labels = pred_labels[pred_indices]

        metrics = calculate_class_metrics(matched_gt_labels, matched_pred_labels, class_id)
        all_metrics.append(metrics)
        class_results[class_name] = metrics

        print(f"  Class {class_name} (ID: {class_id}) metrics:")
        print(f"    IoU: {metrics['iou']:.4f}")
        print(f"    Precision: {metrics['precision']:.4f}")
        print(f"    Recall: {metrics['recall']:.4f}")
        print(f"    F1-Score: {metrics['f1_score']:.4f}")
        print(f"    Class Accuracy: {metrics['class_accuracy']:.4f}")

    if not all_metrics:
        print("No valid evaluation results")
        return None

    miou = np.mean([m['iou'] for m in all_metrics])
    macc = np.mean([m['class_accuracy'] for m in all_metrics])

    total_tp = sum(m['tp'] for m in all_metrics)
    total_fp = sum(m['fp'] for m in all_metrics)
    total_fn = sum(m['fn'] for m in all_metrics)
    total_tn = sum(m['tn'] for m in all_metrics)
    total = total_tp + total_tn + total_fp + total_fn

    # NOTE: this pools the per-class one-vs-rest confusion counts, so each matched
    # point contributes once per evaluated class. It is a pooled binary accuracy,
    # not the multi-class point accuracy.
    overall_accuracy = (total_tp + total_tn) / total if total > 0 else 0.0

    print(f"\nOverall evaluation:")
    print(f"  mIoU: {miou:.4f}")
    print(f"  mAcc: {macc:.4f}")
    print(f"  Overall Accuracy: {overall_accuracy:.4f}")
    print(f"  Num classes evaluated: {len(all_metrics)}")
    print(f"  Classes: {list(class_results.keys())}")

    results = {
        'scene_name': scene_name,
        'class_count': class_count,
        'evaluation_classes': evaluation_class_names,
        'miou': miou,
        'macc': macc,
        'overall_accuracy': overall_accuracy,
        'class_metrics': class_results,
        'all_metrics': all_metrics
    }
    os.makedirs(output_dir, exist_ok=True)
    results_file = os.path.join(output_dir, f"{scene_name}_{class_count}classes_evaluation_results.json")
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(_to_json_serializable(results), f, indent=2, ensure_ascii=False)
    print(f"Saved results to: {results_file}")
    return results

def main():
    """Command-line entry point: parse options, evaluate one scene, print a summary line."""
    parser = ArgumentParser(description="Evaluate segmentation metrics for each class")
    option_specs = (
        ("--scene_name", {"type": str, "required": True,
                          "help": "Scene name to evaluate"}),
        ("--gt_ply_path", {"type": str, "required": True,
                           "help": "Path to ground truth PLY file"}),
        ("--pred_ply_dir", {"type": str, "default": "./evaluation_results",
                            "help": "Directory containing prediction PLY files"}),
        ("--output_dir", {"type": str, "default": "./evaluation_results",
                          "help": "Output directory for evaluation results"}),
        ("--class_count", {"type": str, "default": "19",
                           "help": "Number of classes to evaluate (e.g., '19', '15', '10', or a custom list)"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    opts = parser.parse_args()

    results = evaluate_scene(
        scene_name=opts.scene_name,
        gt_ply_path=opts.gt_ply_path,
        pred_ply_dir=opts.pred_ply_dir,
        output_dir=opts.output_dir,
        class_count=opts.class_count,
    )
    if not results:
        print("Evaluation failed")
        return
    print(f"\nEvaluation finished. Scene {opts.scene_name} mIoU: {results['miou']:.4f}, mAcc: {results['macc']:.4f}")


if __name__ == "__main__":
    main()
