import argparse
import json
import multiprocessing
import os
import pickle
from collections import Counter
from pathlib import Path
from typing import List, Optional, Tuple
from typing import Callable, Literal, Optional, Tuple

import numpy as np
import open3d as o3d
import PIL
import PIL.Image  # `import PIL` alone does not load the PIL.Image submodule used below.
import torch
from tqdm import tqdm

from geometry_util import get_colored_points_from_depth
#__path__.insert(0, "superpoint_transformer", "src")
from superpoint_transformer.scripts.preprocess_point_cloud import segment_point_cloud_superpoints

def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of the mesh-annotation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(
        description="Annotate a mesh with OneFormer and aggregate labels."
    )
    # required
    p.add_argument("--output-dir",      required=True,  help="Where to write all outputs")
    p.add_argument("--poses-file",      required=True,  help="Path to transforms.json or config.yml")
    p.add_argument("--mesh-file",       required=True,  help="Path to input .ply mesh")
    # optional / config
    p.add_argument("--load-config",     default=None,    help="(unused) placeholder for loading method config")
    p.add_argument("--label-type",      default="custom", help="Type of labels (just unused stub)")
    p.add_argument("--labels-file",     default="/mnt/usb_ssd/bieriv/oneformer/predictions/labels.txt",
                   help="Path to labels.txt")
    p.add_argument("--samples-per-frame",        type=int, default=3000, help="Number of rays per image")
    p.add_argument("--n-neighbors-for-aggregation", type=int, default=5,
                   help="K for KNN feature aggregation")
    p.add_argument('--keep-classes', type=str, default="ceiling,wall,floor,surface,door,stairs", help='Classes to keep in the output mesh')
    p.add_argument('--object-classes', type=str, default=None, help='Classes to project to the floor')
    # Window classes are excluded from the default objects selection in save_filtered.
    p.add_argument('--window-classes', type=str, default="window",
                   help='Comma-separated classes treated as windows (excluded from the objects mesh)')
    p.add_argument("--filter-edges-from-depth-maps", action="store_true",
                   help="Whether to filter edges (dummy stub)")
    p.add_argument("--edge-threshold",     type=float, default=0.1, help="Edge threshold (dummy stub)")
    p.add_argument("--edge-dilation-iterations", type=int, default=2,
                   help="Edge dilation iterations (dummy stub)")
    return p.parse_args(argv)


def image_path_to_pred_path(image_path, seg_path):
    """Map an RGB frame path to the path of its OneFormer prediction PNG.

    The dataset-specific input prefix is replaced by ``seg_path``. Because the
    replaced prefixes end right before the scene id, ``seg_path`` should end
    with ``/``. Image sub-directories are then dropped and the extension is
    changed to ``.png``.

    Args:
        image_path: Path of a Matterport (``.jpg``) or ScanNet++ (``.JPG``) frame.
        seg_path: Root directory of the prediction PNGs.

    Returns:
        The prediction path as a string.

    Raises:
        ValueError: If ``image_path`` contains neither ``"matterport"`` nor
            ``"scannet"``. Previously this returned ``None``, which later failed
            obscurely inside ``PIL.Image.open``.
    """
    if "matterport" in image_path:
        return (
            image_path.replace("/mnt/usb_ssd/bieriv/opennerf-data/nerfstudio/matterport_", seg_path)
            .replace("/images/", "/")
            .replace(".jpg", ".png")
        )
    if "scannet" in image_path:
        return (
            image_path.replace("/mnt/usb_ssd/bieriv/opennerf-data/scannet++/data/", seg_path)
            .replace("/dslr/undistort_colmap/images", "")
            .replace(".JPG", ".png")
        )
    raise ValueError(f"Cannot infer dataset (matterport/scannet) from image path: {image_path}")



def pick_indices_at_random(valid_mask, samples_per_frame):
    """Return flat indices of True entries of ``valid_mask``, randomly subsampled.

    When the mask has more than ``samples_per_frame`` True entries, a uniformly
    random subset of that size is drawn without replacement (via
    ``torch.randperm``). Otherwise every True index is returned in ascending
    order.

    Args:
        valid_mask: Boolean tensor of any shape. It is flattened in row-major
            order.
        samples_per_frame: Maximum number of indices to return.

    Returns:
        A 1-D tensor of indices into the flattened mask.
    """
    candidates = torch.ravel(valid_mask).nonzero()
    n_candidates = len(candidates)
    if n_candidates > samples_per_frame:
        chosen = torch.randperm(n_candidates)[:samples_per_frame]
        candidates = candidates[chosen]
    return candidates.ravel()


def run_oneformer_for_each_image(
    poses: dict,
    samples_per_frame: int,
    labels: List[str],
    filter_edges_from_depth_maps: bool,
    edge_threshold: float,
    edge_dilation_iterations: int,
    seg_path: str,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Back-project per-frame OneFormer class predictions into 3D using depth.

    For every frame, up to ``samples_per_frame`` random pixels are lifted to
    world space with the frame's depth map and camera pose (via
    ``get_colored_points_from_depth``). The predicted class id of each pixel is
    recorded alongside. Tensors are moved to CUDA, so a GPU is required.

    Args:
        poses: Parsed poses JSON. Each entry of ``poses["frames"]`` must provide
            ``file_path``, ``fl_x``, ``fl_y``, ``cx``, ``cy``, ``h``, ``w``,
            ``transform_matrix`` and either ``depth_file_path`` or ``depth``.
        samples_per_frame: Maximum number of pixels sampled per frame.
        labels: Class names. Every predicted class id must be < ``len(labels)``.
        filter_edges_from_depth_maps: Currently unused (kept for interface
            compatibility).
        edge_threshold: Currently unused.
        edge_dilation_iterations: Currently unused.
        seg_path: Root directory of the prediction PNGs. It is passed to
            ``image_path_to_pred_path``.

    Returns:
        ``(ray_origins, points, segmentations, depth_is_valid)``, concatenated
        over all frames and aligned row by row:
          - camera centres of shape (N, 3);
          - back-projected points (presumably (N, 3));
          - class ids of shape (N,);
          - whether the original depth was > 0, of shape (N,).

    Raises:
        ValueError: If there are no frames, a prediction or depth map has an
            unexpected shape/dtype/range, or a depth map is non-finite.
    """
    frames = poses["frames"]
    if len(frames) == 0:
        raise ValueError("poses['frames'] is empty; nothing to back-project")

    ray_origins = []
    points = []
    segmentations = []
    depth_is_valid = []

    for image_idx, pose in enumerate(tqdm(frames, desc="Processing frames")):
        file_path = pose["file_path"]
        H = pose["h"]
        W = pose["w"]

        depth_gt = _load_depth(pose, H, W)
        with PIL.Image.open(file_path) as rgb_img:
            image = np.array(rgb_img)
        pred = _load_prediction(image_path_to_pred_path(file_path, seg_path), H, W, len(labels))

        # Negate columns 1-2 of the rotation (camera y and z axes).
        # NOTE(review): presumably converts an OpenGL/nerfstudio pose to the
        # convention get_colored_points_from_depth expects. Confirm.
        c2w_np = np.array(pose["transform_matrix"]).astype(np.float32)
        c2w_np[0:3, 1:3] *= -1

        depth_map = torch.from_numpy(depth_gt).float().cuda()
        if not torch.isfinite(depth_map).all():
            raise ValueError(f"Non-finite depth values in frame {image_idx} ({file_path})")
        c2w = torch.from_numpy(c2w_np).cuda()
        image = torch.from_numpy(image).float().cuda()

        # Every pixel is a candidate. A random subset of at most samples_per_frame is kept.
        valid_mask = torch.ones(H, W, dtype=torch.bool, device=depth_map.device)
        indices = pick_indices_at_random(valid_mask, samples_per_frame)

        # Pixels with depth <= 0 are still back-projected, at a placeholder depth
        # of 0.5, so all per-ray arrays stay aligned. They are flagged in depth_is_valid.
        valid_depth = (depth_map > 0).view(-1)
        depth_map_inpainted = depth_map.clone()
        depth_map_inpainted[depth_map <= 0] = 0.5

        xyzs, _, _ = get_colored_points_from_depth(
            depths=depth_map_inpainted,
            rgbs=image,
            features={},
            fx=pose["fl_x"],
            fy=pose["fl_y"],
            cx=pose["cx"],  # type: ignore
            cy=pose["cy"],  # type: ignore
            img_size=(W, H),
            c2w=c2w,
            mask=indices,
        )

        seg = torch.from_numpy(pred).to(indices.device).view(-1)[indices]
        sampled_valid_depth = valid_depth[indices]
        assert len(seg) == len(xyzs), f"Oneformer seg: {len(seg)}, xyzs: {len(xyzs)}"

        points.append(xyzs.detach().cpu().numpy())
        # The translation column of c2w (the camera centre) is the origin of every ray of this frame.
        ray_origins.append(np.repeat(c2w_np[None, :3, 3], len(xyzs), axis=0))
        segmentations.append(seg.cpu().numpy())
        depth_is_valid.append(sampled_valid_depth.cpu().numpy())

    return (
        np.concatenate(ray_origins, axis=0),
        np.concatenate(points, axis=0),
        np.concatenate(segmentations, axis=0),
        np.concatenate(depth_is_valid, axis=0),
    )


def _load_depth(pose: dict, H: int, W: int) -> np.ndarray:
    """Return the frame's depth map as a float32 array of shape (H, W).

    Reads ``pose["depth_file_path"]`` if present, otherwise the inline
    ``pose["depth"]`` values, which may be a list or an array.
    """
    if "depth_file_path" in pose:
        with PIL.Image.open(pose["depth_file_path"]) as depth_img:
            # NOTE(review): the 0.00025 scale presumably converts 0.25 mm integer units to metres. Confirm.
            depth = (np.array(depth_img) * 0.00025).astype(np.float32)
        if depth.shape != (H, W):
            raise ValueError(f"Depth map {pose['depth_file_path']} has shape {depth.shape}, expected {(H, W)}")
        return depth
    return np.asarray(pose["depth"], dtype=np.float32).reshape(H, W)


def _load_prediction(pred_path: str, H: int, W: int, n_labels: int) -> np.ndarray:
    """Load a class-id PNG as a uint8 array of shape (H, W), resizing with nearest-neighbour if needed."""
    with PIL.Image.open(pred_path) as pred_img:
        # PIL reports size as (width, height).
        if pred_img.size != (W, H):
            pred_img = pred_img.resize((W, H), PIL.Image.NEAREST)
        pred = np.array(pred_img)
    if pred.shape != (H, W):
        raise ValueError(f"Prediction {pred_path} has shape {pred.shape}, expected {(H, W)}")
    if pred.dtype != np.uint8:
        raise ValueError(f"Prediction {pred_path} has dtype {pred.dtype}, expected uint8")
    if pred.max() >= n_labels:
        raise ValueError(f"Prediction {pred_path} contains class id {pred.max()} >= {n_labels} labels")
    return pred


def compute_per_point_feature_knn(vertices : np.ndarray, original_points : np.ndarray, original_point_features : List[np.ndarray], k : int = 10) -> List[np.ndarray]:
    """Transfer point features onto query vertices by k-nearest-neighbour averaging.

    An Open3D ``KDTreeFlann`` is built over ``original_points``. For each vertex,
    the neighbours found by ``search_knn_vector_3d(vertex, k)`` are looked up,
    and every feature array is averaged over them.

    Args:
        vertices: Query positions, one per row.
        original_points: Source positions the features are attached to.
        original_point_features: Arrays whose rows align with
            ``original_points``.
        k: Number of neighbours to average.

    Returns:
        One array per entry of ``original_point_features``, with one row per
        vertex (in vertex order).
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(original_points)
    tree = o3d.geometry.KDTreeFlann(cloud)

    gathered = [[] for _ in original_point_features]
    for vertex in vertices:
        _, neighbour_ids, _ = tree.search_knn_vector_3d(vertex, k)
        for bucket, feats in zip(gathered, original_point_features):
            bucket.append(np.mean(feats[neighbour_ids], axis=0))
    return [np.array(bucket) for bucket in gathered]
    


def _extract_submesh(mesh, vertex_mask):
    """Return a copy of ``mesh`` that keeps only the vertices where ``vertex_mask`` is True."""
    sub = o3d.geometry.TriangleMesh(mesh)
    sub.remove_vertices_by_mask(~vertex_mask)
    return sub


def save_filtered(mesh, vertex_class_probabilities, class_names, output_dir,
                  keep_classes=("ceiling", "wall", "floor", "surface", "door", "stairs"),
                  object_classes=None, window_classes=("window",)):
    """Split a per-vertex labelled mesh into structural, object and stair sub-meshes.

    Args:
        mesh: Open3D ``TriangleMesh`` whose vertices align with the rows of
            ``vertex_class_probabilities``.
        vertex_class_probabilities: Per-vertex class scores of shape (V, C), or
            hard class ids of shape (V,). Hard ids are turned into a boolean
            one-hot (V, C) matrix.
        class_names: Class names indexed by class id.
        output_dir: Output directory. It is created if missing.
        keep_classes: Classes written to ``ceiling_wall_floor_mesh.ply``.
        object_classes: Classes written to ``objects_mesh.ply``. If ``None``,
            every vertex that is neither in ``keep_classes`` nor in
            ``window_classes`` is used.
        window_classes: Classes excluded from the default objects selection.

    Writes ``ceiling_wall_floor_mesh.ply`` and ``objects_mesh.ply``, each with a
    matching ``*_classes.npy`` (float16 per-vertex class rows). Also writes
    ``stair_mesh.ply`` (vertices labelled ``stairs``/``stair``).
    """
    if vertex_class_probabilities.ndim > 1:
        vertex_class_labels = vertex_class_probabilities.argmax(axis=1)
    else:
        vertex_class_labels = vertex_class_probabilities
        vertex_class_probabilities = np.arange(len(class_names))[None, :] == vertex_class_labels[:, None]

    os.makedirs(output_dir, exist_ok=True)

    per_vertex_classes_text = np.array(class_names)[vertex_class_labels]

    # Structural surfaces (ceiling / wall / floor / ...).
    ceiling_wall_floor_mask = np.isin(per_vertex_classes_text, keep_classes)
    ceiling_wall_floor_mesh = _extract_submesh(mesh, ceiling_wall_floor_mask)
    ceiling_wall_floor_mesh_classes = vertex_class_probabilities[ceiling_wall_floor_mask]
    assert len(ceiling_wall_floor_mesh_classes) == len(ceiling_wall_floor_mesh.vertices)
    o3d.io.write_triangle_mesh(f"{output_dir}/ceiling_wall_floor_mesh.ply", ceiling_wall_floor_mesh)
    np.save(f"{output_dir}/ceiling_wall_floor_mesh_classes.npy", ceiling_wall_floor_mesh_classes.astype(np.float16))
    print(f"Classes found: {Counter(per_vertex_classes_text)}")
    print(f"Classes in ceiling_wall_floor_mesh: {Counter(per_vertex_classes_text[ceiling_wall_floor_mask])}")

    # Objects: explicit classes, or by default everything that is neither structure nor window.
    window_mask = np.isin(per_vertex_classes_text, window_classes)
    if object_classes is None:
        object_mask = ~ceiling_wall_floor_mask & ~window_mask
    else:
        object_mask = np.isin(per_vertex_classes_text, object_classes)
    objects_mesh = _extract_submesh(mesh, object_mask)
    objects_mesh_classes = vertex_class_probabilities[object_mask]
    assert len(objects_mesh_classes) == len(objects_mesh.vertices)
    o3d.io.write_triangle_mesh(f"{output_dir}/objects_mesh.ply", objects_mesh)
    # Same float16 dtype as the structural classes file (was bool for hard-label input).
    np.save(f"{output_dir}/objects_mesh_classes.npy", objects_mesh_classes.astype(np.float16))
    print(f"Classes in objects_mesh: {Counter(per_vertex_classes_text[object_mask])}")

    stair_mask = np.isin(per_vertex_classes_text, ["stairs", "stair"])
    o3d.io.write_triangle_mesh(f"{output_dir}/stair_mesh.ply", _extract_submesh(mesh, stair_mask))

    print(f"Saved all outputs to {output_dir}")


def _write_points_and_rays(output_dir, ray_origins, points, point_colors):
    """Write the back-projected points and their camera rays as debug .ply files."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points.astype(np.float64))
    pcd.colors = o3d.utility.Vector3dVector(point_colors)
    o3d.io.write_point_cloud(str(output_dir / "point_cloud.ply"), pcd)

    # Line i joins vertex i (ray_origins[i]) to vertex i + n (points[i]).
    n = len(ray_origins)
    line_indices = np.stack([np.arange(n), np.arange(n) + n], axis=1)
    ray_ls = o3d.geometry.LineSet(
        points=o3d.utility.Vector3dVector(np.vstack([ray_origins, points]).astype(np.float64)),
        lines=o3d.utility.Vector2iVector(line_indices),
    )
    o3d.io.write_line_set(str(output_dir / "rays.ply"), ray_ls)


def _aggregate_per_segment(vertex_probabilities, vertex_segment_ids):
    """Average per-vertex class probabilities within each segment, then renormalise.

    Args:
        vertex_probabilities: Float tensor of shape (V, C).
        vertex_segment_ids: Long tensor of shape (V,). Negative ids are ignored.

    Returns:
        A tensor of shape (max_id + 1, C). Each row is the segment mean divided
        by its sum (+1e-6).

    Raises:
        ValueError: If some id in [0, max_id] has no vertices.
    """
    in_range = vertex_segment_ids >= 0
    seg_ids = vertex_segment_ids[in_range]
    probs = vertex_probabilities[in_range]
    n_segments = max(int(vertex_segment_ids.max()) + 1, 0)
    n_classes = vertex_probabilities.shape[1]

    sums = torch.zeros((n_segments, n_classes), dtype=vertex_probabilities.dtype,
                       device=vertex_probabilities.device)
    sums.index_add_(0, seg_ids, probs)
    counts = torch.bincount(seg_ids, minlength=n_segments)
    if (counts == 0).any():
        empty = torch.nonzero(counts == 0).flatten()[:10].tolist()
        raise ValueError(f"Superpoint segmentation has segment ids without vertices, e.g. {empty}")
    means = sums / counts.unsqueeze(1).to(sums.dtype)
    return means / (means.sum(dim=1, keepdim=True) + 1e-6)


def extract_skeleton(args):
    """Annotate a mesh with back-projected OneFormer labels and write all outputs.

    Pipeline:
      1. Back-project per-frame predictions to 3D rays.
      2. Transfer labels onto mesh vertices by k-NN averaging over rays with
         valid depth.
      3. Compute superpoints with superpoint_transformer.
      4. Average labels per superpoint for levels 1-3.
      5. Save per-vertex / per-segment arrays, debug geometry and the filtered
         sub-meshes (``save_filtered``).

    Args:
        args: Namespace from ``parse_args``. Reads ``output_dir``, ``poses_file``
            (parsed with ``json.load``), ``mesh_file``, ``labels_file``,
            ``samples_per_frame``, ``n_neighbors_for_aggregation``, the edge
            stub options and the three class-list options.

    Requires CUDA (tensors are moved with ``.cuda()``).
    """
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Writing outputs to {output_dir}")

    # The prediction root depends on the dataset, which is inferred from the poses file path.
    is_scannetpp = "scannet" in args.poses_file.lower()
    if is_scannetpp:
        seg_path = "/mnt/usb_ssd/bieriv/oneformer/scannetpp_predictions/"
    else:
        seg_path = "/mnt/usb_ssd/bieriv/oneformer/predictions/"

    mesh = o3d.io.read_triangle_mesh(args.mesh_file)

    # NOTE(review): the --poses-file help mentions config.yml, but only JSON is parsed here.
    with open(args.poses_file, "r") as f:
        poses = json.load(f)

    with open(args.labels_file, "r") as f:
        labels = [l.strip() for l in f if l.strip()]

    ray_origins, points, segmentations, depth_is_valid = run_oneformer_for_each_image(
        poses=poses,
        samples_per_frame=args.samples_per_frame,
        labels=labels,
        filter_edges_from_depth_maps=args.filter_edges_from_depth_maps,
        edge_threshold=args.edge_threshold,
        edge_dilation_iterations=args.edge_dilation_iterations,
        seg_path=seg_path,
    )

    # One random colour per class id, shared by the point cloud and the class-coloured mesh.
    colors = np.random.rand(len(labels), 3)
    _write_points_and_rays(output_dir, ray_origins, points, colors[segmentations])

    # Re-save the original mesh with normals; it is also the superpoint input below.
    mesh.compute_vertex_normals()
    o3d.io.write_triangle_mesh(str(output_dir / "mesh.ply"), mesh)
    # Copy so the vertex array owns its data once the mesh is released.
    mesh_vertices = np.array(mesh.vertices)
    del mesh

    if not depth_is_valid.any():
        raise ValueError("No sampled ray has valid (> 0) depth; cannot aggregate labels")

    print("Computing aggregated segmentations")
    # Only rays with valid depth vote, so only those are one-hot encoded.
    segmentations_one_hot = np.eye(len(labels), dtype=np.float16)[segmentations[depth_is_valid]]
    aggregated_segmentations, = compute_per_point_feature_knn(
        vertices=mesh_vertices,
        original_points=points[depth_is_valid],
        original_point_features=[segmentations_one_hot],
        k=args.n_neighbors_for_aggregation,
    )
    aggregated_vertex_hard_labels = np.argmax(aggregated_segmentations, axis=-1)

    print("Segmenting mesh")
    segment_point_cloud_superpoints(
        input_pcd=str(output_dir / "mesh.ply"),
        output_directory=str(output_dir / "spt"),
    )

    aggregated_segmentations_torch = torch.from_numpy(aggregated_segmentations).float().cuda()
    print("Aggregating features")
    for level in tqdm(range(1, 4), desc="Aggregating features"):
        mesh_segmentation = torch.from_numpy(
            np.load(output_dir / "spt" / f"level_{level}_segmentation.npy").astype(np.int32)
        ).long().cuda()
        if len(mesh_segmentation) != len(mesh_vertices):
            raise ValueError(
                f"level_{level} segmentation has {len(mesh_segmentation)} entries "
                f"but the mesh has {len(mesh_vertices)} vertices"
            )
        segment_probabilities = _aggregate_per_segment(aggregated_segmentations_torch, mesh_segmentation)
        per_segment_hard_assignments = torch.argmax(segment_probabilities, dim=-1).cpu().numpy()

        np.save(output_dir / "spt" / f"level_{level}_segment_hard_assignments_simplified.npy", per_segment_hard_assignments)
        np.save(output_dir / "spt" / f"level_{level}_segment_probabilities_simplified.npy", segment_probabilities.cpu().numpy().astype(np.float16))

    # Colour the mesh by the class of each vertex's level-3 segment (the last loop iteration).
    mesh = o3d.io.read_triangle_mesh(str(output_dir / "mesh.ply"))
    mesh.vertex_colors = o3d.utility.Vector3dVector(colors[per_segment_hard_assignments[mesh_segmentation.cpu().numpy()]])
    o3d.io.write_triangle_mesh(str(output_dir / "spt" / "mesh_class_colored.ply"), mesh)
    np.save(output_dir / "vertex_probabilities.npy", aggregated_segmentations.astype(np.float16))
    np.save(output_dir / "vertex_hard_assignments.npy", aggregated_vertex_hard_labels.astype(np.uint16))
    np.save(output_dir / "simplified_segmentation_labels.npy", np.array(labels))

    label_array = np.array(labels)
    print("Statistics of hard labels")
    print(Counter(label_array[aggregated_vertex_hard_labels]))
    print("Statistics of simplified labels")
    print(Counter(label_array[per_segment_hard_assignments]))

    # Save ray / point-cloud info.
    np.save(output_dir / "full_ray_origins.npy", ray_origins.astype(np.float16))
    np.save(output_dir / "full_ray_dests.npy", points.astype(np.float16))
    np.save(output_dir / "ray_is_valid.npy", depth_is_valid.astype(bool))
    label_dtype = np.uint8 if len(labels) < 256 else np.uint16
    np.save(output_dir / "hard_labels_simplified_segmentations.npy", segmentations.astype(label_dtype))

    save_filtered(
        mesh,
        aggregated_segmentations,
        labels,
        output_dir,
        keep_classes=args.keep_classes.split(","),
        object_classes=args.object_classes.split(",") if args.object_classes is not None else None,
        window_classes=args.window_classes.split(","),
    )
    print("Done!, wrote to", output_dir)





if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the full annotation pipeline.
    extract_skeleton(parse_args())

        

"""
#!/bin/bash
# run_all_scenes.sh

# Set the environment variable as in your debug configuration.
export NERFSTUDIO_METHOD_CONFIGS="opengs=opennerf.opennerf_config:opengs_method,opengs-scannetpp=opennerf.opennerf_config:opengs_scannetpp_method"

# List of scene IDs to process.
matterport_val_split=(
    "2t7WUuJeko7"
    "WYY7iVyf5p8"
    "TbHJrupSAjP"
    "YFuZgdQ5vWj"
    "jtcxE69GiFV"
    "1LXtFkjw3qL"
    "5LpN3gDmAk7"
    "e9zR4mvMWw7"
    "i5noydFURQK"
    "HxpKQynjfin"
    "JeFG25nYj2p"
    "JmbYfDe2QKZ"
    "p5wJjkQkbXX"
    "r47D5H71a5s"
    "S9hNv5qa7GM"
    "17DRP5sb8fy"
)

# Loop over each scene ID.
for scene in "${matterport_val_split[@]}"; do
    echo "Processing scene: $scene"
    
    # Construct the paths for the current scene.
    output_dir="/mnt/usb_ssd/bieriv/segmented_matterport/02-25-matterport-oneformer/${scene}"
    poses_file="/mnt/usb_ssd/bieriv/opennerf-data/nerfstudio/matterport_${scene}/transforms.json"
    mesh_file="/mnt/usb_ssd/bieriv/opennerf-data/matterport/v1/scans/${scene}/poisson_meshes/${scene}_10.ply"
    
    # Run the Python script with the specified arguments.
    python scripts/annotate_mesh_with_oneformer.py \
        --output-dir="${output_dir}" \
        --poses-file="${poses_file}" \
        --mesh-file="${mesh_file}"
    
    # Check if the command failed.
    if [ $? -ne 0 ]; then
        echo "Error processing scene: ${scene}" >&2
        # Optionally exit here or continue with the next scene.
    fi
done

"""

"""
#!/bin/bash
# run_all_scannetpp_scenes.sh

# Set the environment variable for both methods.
export NERFSTUDIO_METHOD_CONFIGS="opengs=opennerf.opennerf_config:opengs_method,opengs-scannetpp=opennerf.opennerf_config:opengs_scannetpp_method"

# Read the list of Scannet++ scene directories from the file.
readarray -t scannetpp_val_split < "/mnt/usb_ssd/bieriv/opennerf-data/scannet++/splits/nvs_sem_val.txt"

demo_split=(
"6115eddb86"
"8b5caf3398"
)

# Loop over each scene ID.
# for scene in "${scannetpp_val_split[@]}"; do
for scene in "${demo_split[@]}"; do
    echo "Processing Scannet++ scene: $scene"
    
    # Construct the paths for the current scene.
    output_dir="/mnt/usb_ssd/bieriv/layout-estimation-outputs/02-27-dslr-oneformer/scannet++/${scene}"
    poses_file="/mnt/usb_ssd/bieriv/layout-estimation-outputs/01-20-dslr-mesh-fitting/scannet++/nerfstudio-outputs/outputs/scannet++_${scene}/opengs-scannetpp/01-20-dslr-mesh-fitting/config.yml"
    mesh_file="/mnt/usb_ssd/bieriv/layout-estimation-outputs/01-20-dslr-mesh-fitting/scannet++/nerfstudio-outputs/meshes/scannet++_${scene}/opengs-scannetpp/01-20-dslr-mesh-fitting/full-DepthAndNormalMapsPoisson_mesh.ply"
    
    # Run the Python script with the specified arguments.
    python scripts/annotate_mesh_with_oneformer.py \
        --output-dir="${output_dir}" \
        --poses-file="${poses_file}" \
        --mesh-file="${mesh_file}"
    
    # Check if the command failed.
    if [ $? -ne 0 ]; then
        echo "Error processing scene: ${scene}" >&2
        # Optionally exit here or continue with the next scene.
    fi
done

"""