import json
import numpy as np
import open3d as o3d
import torch
import os


# Torch device shared by the voxelization helpers in this module:
# CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Grid cell size that voxelize_batch passes to grid_indices (same units as
# the point coordinates).
resolution = 0.2

def load_segs(segs_json_path):
    """Read the ``segIndices`` list from a segmentation JSON file.

    Args:
        segs_json_path: path of a JSON file whose top-level object has a
            ``segIndices`` entry.

    Returns:
        The ``segIndices`` values as an ``np.int32`` array.
    """
    with open(segs_json_path, 'r') as handle:
        payload = json.load(handle)
    return np.array(payload['segIndices'], dtype=np.int32)

def load_aggregation(agg_json_path):
    """Build an ``objectId -> segments`` lookup from an aggregation JSON file.

    Args:
        agg_json_path: path of a JSON file whose top-level object has a
            ``segGroups`` list; every entry must provide ``objectId`` and
            ``segments``.

    Returns:
        Dict mapping each ``objectId`` to that group's ``segments`` value.
        If an ``objectId`` occurs more than once, the last group wins.
    """
    with open(agg_json_path, 'r') as handle:
        groups = json.load(handle)['segGroups']
    return {group['objectId']: group['segments'] for group in groups}

def get_object_vertices(ply_path, segs_json_path, agg_json_path, object_id):
    """
    Returns a numpy array of shape (num_vertices, 3) containing 3D points of the given object_id.

    Args:
        ply_path: mesh file read with ``o3d.io.read_triangle_mesh``.
        segs_json_path: segmentation JSON giving one segment id per mesh vertex.
        agg_json_path: aggregation JSON mapping object ids to segment ids.
        object_id: key to look up in the aggregation mapping.

    Raises:
        ValueError: if ``object_id`` is not present in the aggregation file.

    Note:
        A vertex-count / segment-count mismatch is only reported on stdout;
        the boolean indexing at the end is not adjusted for it.
    """
    # Load mesh vertices
    mesh = o3d.io.read_triangle_mesh(ply_path)
    vertices = np.asarray(mesh.vertices)

    # Load segment assignments
    seg_ids = load_segs(segs_json_path)
    if len(seg_ids) != vertices.shape[0]:
        print(f"Warning: number of vertices {vertices.shape[0]} vs segs length {len(seg_ids)} don't match")

    # Load aggregation
    obj_to_seg = load_aggregation(agg_json_path)
    if object_id not in obj_to_seg:
        raise ValueError(f"Object ID {object_id} not found in aggregation file. Available: {list(obj_to_seg.keys())}")

    # Vectorized membership test: same mask as checking every vertex's
    # segment id against the object's segments, without a Python-level loop
    # over all vertices.
    mask = np.isin(seg_ids, np.asarray(obj_to_seg[object_id]))

    # Return the 3D points
    return vertices[mask]

def points_to_grid(points,resolution):
    """Snap coordinates onto multiples of ``resolution``.

    Each coordinate is divided by ``resolution``, truncated toward zero by
    the integer cast (so negative values move up, positive values move
    down), and scaled back.

    Args:
        points: numpy array of coordinates.
        resolution: grid spacing.

    Returns:
        Array with the same shape as ``points`` holding the snapped
        coordinates; the cell counts pass through ``np.float32`` before
        being scaled back.
    """
    cell_counts = (points / resolution).astype(int)
    snapped = cell_counts.astype(np.float32) * resolution
    return snapped

def grid_indices(points,params,res):
    """Convert point coordinates into integer grid-cell indices.

    For each of the first three columns, the index is
    ``(coordinate - axis_min) / res`` truncated toward zero.

    Args:
        points: array-like or tensor indexed as ``points[:, k]``; columns
            0..2 are read as x, y, z.
        params: six values ``(x_min, y_min, z_min, x_max, y_max, z_max)``;
            only the three minima are used.
        res: grid cell size.

    Returns:
        ``torch.long`` tensor on the module-level ``device`` with the same
        shape as ``points``. Columns 0..2 hold the cell indices and any
        further columns are zero. Indices are not clamped, so a coordinate
        far enough below its minimum gives a negative index.
    """
    # as_tensor reuses an existing tensor/array instead of copy-constructing
    # it; torch.tensor() on a tensor copies and emits a UserWarning.
    points = torch.as_tensor(points, device=device)
    x_min, y_min, z_min, _x_max, _y_max, _z_max = params
    # Allocate the output as integers directly so indices never round-trip
    # through a floating-point buffer.
    indices = torch.zeros(points.shape, dtype=torch.long, device=points.device)
    indices[:,0] = ((points[:,0]-x_min)/res).long()
    indices[:,1] = ((points[:,1]-y_min)/res).long()
    indices[:,2] = ((points[:,2]-z_min)/res).long()
    return indices

def voxelize_batch(points,X,Y,Z,params):
    """Rasterize one point set into a dense ``(X, Y, Z)`` occupancy grid.

    Despite the name, a single point set is processed per call and the grid
    is returned un-flattened.

    Args:
        points: point coordinates, forwarded to ``grid_indices``.
        X, Y, Z: grid size along each axis.
        params: bounds tuple forwarded to ``grid_indices``.

    Returns:
        Tensor of shape ``(X, Y, Z)`` on the module-level ``device`` with 1
        in every cell addressed by at least one point and 0 elsewhere.

    NOTE(review): indices are not bounds-checked here, so the caller must
    pick X/Y/Z and ``params`` such that every point lands inside the grid.
    """
    occupancy = torch.zeros((X, Y, Z), device=device)
    cell_idx = grid_indices(points, params, resolution)
    ix = cell_idx[:, 0]
    iy = cell_idx[:, 1]
    iz = cell_idx[:, 2]
    occupancy[ix, iy, iz] = 1
    return occupancy



def _axis_aligned_bounds(points):
    """Return per-axis ``(minima, maxima)`` lists for columns 0..2 of ``points``."""
    mins = [np.min(points[:, axis]) for axis in range(3)]
    maxs = [np.max(points[:, axis]) for axis in range(3)]
    return mins, maxs


def iou_calc(pred_points, target_obj_id, scene, data_path, device='cuda'):
    """Bounding-box IoU between predicted points and a ground-truth object.

    The ground-truth points are the vertices returned by
    ``get_object_vertices`` for ``target_obj_id``, read from
    ``<data_path>/<scene>/data/``. Both point sets are reduced to their
    axis-aligned bounding boxes and the IoU of the two boxes is returned.

    Args:
        pred_points: predicted points, indexed as ``pred_points[:, k]`` for
            k in 0..2 (assumed to be a numpy array; confirm against callers).
        target_obj_id: object id looked up in the scene's aggregation file.
        scene: scene name, used both as directory name and file-name prefix.
        data_path: root directory containing one sub-directory per scene.
        device: unused; kept so existing call sites keep working.

    Returns:
        IoU of the two boxes. 0.0 is returned when either point set is
        empty, when the boxes do not overlap, or when the union volume is
        not positive (e.g. both boxes are flat along some axis).

    Raises:
        ValueError: propagated from ``get_object_vertices`` when
            ``target_obj_id`` is not in the aggregation file.
    """
    scene_dir = os.path.join(data_path, scene, "data")
    ply_path = os.path.join(scene_dir, f"{scene}_vh_clean_2.ply")
    segs_json_path = os.path.join(scene_dir, f"{scene}_vh_clean_2.0.010000.segs.json")
    agg_json_path = os.path.join(scene_dir, f"{scene}.aggregation.json")

    target_points = get_object_vertices(ply_path, segs_json_path, agg_json_path, target_obj_id)

    # Without points there is no bounding box to compare (np.min would raise).
    if len(target_points) == 0 or len(pred_points) == 0:
        return 0.0

    gt_min, gt_max = _axis_aligned_bounds(target_points)
    pred_min, pred_max = _axis_aligned_bounds(pred_points)

    volume_gt = (gt_max[0]-gt_min[0])*(gt_max[1]-gt_min[1])*(gt_max[2]-gt_min[2])
    volume_pred = (pred_max[0]-pred_min[0])*(pred_max[1]-pred_min[1])*(pred_max[2]-pred_min[2])

    print(volume_gt, volume_pred)

    # Per-axis extent of the intersection box; negative means no overlap.
    overlap = [
        np.minimum(gt_max[axis], pred_max[axis]) - np.maximum(gt_min[axis], pred_min[axis])
        for axis in range(3)
    ]
    # Decide "no overlap" before multiplying: two negative extents would
    # otherwise give a positive, meaningless intersection volume.
    if overlap[0] < 0.0 or overlap[1] < 0.0 or overlap[2] < 0.0:
        return 0.0

    intersection = overlap[0]*overlap[1]*overlap[2]
    union = volume_gt + volume_pred - intersection
    # Degenerate (zero-volume) boxes would divide by zero and yield NaN/inf.
    if union <= 0.0:
        return 0.0

    return intersection/union

