import h5py
from scipy.spatial.transform import Rotation as R
import numpy as np
import torch
class BasicPointCloudDataset(torch.utils.data.Dataset):
    """Synthetic surface-patch dataset whose items are regenerated from HDF5 metadata.

    Item ``idx`` reads the attributes of ``point_clouds/point_cloud_<idx>`` in the
    HDF5 file (``class``, ``angle``, ``radius``, ``edge`` and, for the normal-vector
    path, ``d``, ``e``, ``k1``, ``k2``) and samples a fresh point cloud with
    ``samplePcl``.  Row 0 of every returned cloud is the anchor point at the origin.

    If ``args.contr_loss_weight`` is non-zero, an item contains the anchor cloud, a
    positive view (``point_cloud2``) and a curvature-perturbed negative
    (``contrastive_point_cloud``).  Otherwise it contains the anchor cloud and the unit
    normal ``normal_vec`` of the patch at its anchor point.

    NOTE(review): the HDF5 handle is opened in ``__init__`` and never closed;
    confirm this is safe with the DataLoader worker setup in use.
    """

    def __init__(self, file_path, args):
        self.file_path = file_path
        self.hdf5_file = h5py.File(file_path, 'r')
        self.point_clouds_group = self.hdf5_file['point_clouds']
        self.num_point_clouds = len(self.point_clouds_group)
        self.pcls_per_class = self.num_point_clouds // 4
        self.indices = list(range(self.num_point_clouds))
        self.std_dev = args.std_dev            # std of additive Gaussian noise; 0 disables noise
        self.clip = args.clip                  # noise samples are clamped to [-clip, clip]
        self.rotate_data = args.rotate_data    # randomly rotate the anchor cloud
        self.contr_loss_weight = args.contr_loss_weight
        self.sampled_points = args.sampled_points
        # Curvature / angle ranges used when building contrastive negatives.
        self.max_curve = 5
        self.min_curve = 1
        self.smallest_angle = 60
        self.max_angle = 120
        self.max_curve_diff = args.max_curve_diff
        self.min_curve_diff = args.min_curve_diff
        # Scale mapping a dihedral angle to a curvature-like value via
        # constant * 2 * cos(angle / 2) (see sampleContrastivePcl).
        self.constant = self.max_curve / (2 * np.cos(np.radians(self.smallest_angle) / 2)) + 0.05
        # Converts a vertex angle defect into a squared curvature (see sampleContrastivePcl).
        # NOTE(review): 10e-6 == 1e-5; kept as-is to preserve the existing scaling.
        self.int_K_const = ((self.max_curve + self.max_curve_diff + 10e-6) ** 2 / (2 * np.pi))

    def __len__(self):
        return self.num_point_clouds

    def _add_noise(self, point_cloud):
        """Add clamped Gaussian noise, then translate so that row 0 is the origin again."""
        noise = torch.normal(0, self.std_dev, size=point_cloud.shape, dtype=torch.float32,
                             device=point_cloud.device)
        noise = torch.clamp(noise, min=-self.clip, max=self.clip)
        point_cloud = point_cloud + noise
        return point_cloud - point_cloud[0, :]

    @staticmethod
    def _patch_normal(info, angle, radius, class_label, rot_orig, device):
        """Unit normal of the patch at its anchor point, expressed in the rotated frame.

        Starts from (-d, -e, 1), the upward normal at the origin of
        z = a x^2 + b y^2 + c x y + d x + e y, and flips it for the sign cases below.
        """
        normal = torch.tensor([-info['d'], -info['e'], 1], dtype=torch.float32, device=device)
        if not (angle > 0 or radius > 0):
            if ((class_label in [1, 2] and info['k1'] + info['k2'] > 0) or
                    (class_label == 3 and (abs(info['k1']) > abs(info['k2'])))):
                normal *= -1
        if angle > 0 and class_label == 1:
            normal *= -1
        normal = torch.matmul(normal, (torch.tensor(rot_orig, dtype=torch.float32, device=device)).T)
        return normal / torch.norm(normal)

    def __getitem__(self, idx):
        point_cloud_name = f"point_cloud_{self.indices[idx]}"
        attrs = self.point_clouds_group[point_cloud_name].attrs
        info = {key: attrs[key] for key in attrs}

        class_label = info['class']
        # 5% of the time, planes (0) and edges (4) become a flat patch: every attribute,
        # including info['class'], is zeroed, while class_label keeps the stored value.
        if class_label in [0, 4]:
            if np.random.rand() < 0.05:
                info = {k: 0 for k in info}
        info['idx'] = self.indices[idx]
        angle = info['angle']
        radius = info['radius']
        edge_label = info['edge']
        bias = 0.25
        min_len, max_len = 0.45, 0.55

        bounds, point_cloud = samplePcl(angle=angle, radius=radius, class_label=class_label,
                                        sampled_points=self.sampled_points, min_len=min_len,
                                        max_len=max_len, bias=bias, info=info, edge_label=edge_label)

        point_cloud1 = torch.tensor(point_cloud, dtype=torch.float32)
        # Identity when augmentation is off, so the normal path always has a rotation.
        rot_orig = np.eye(3)
        if self.rotate_data:
            rot_orig, point_cloud1 = random_rotation(point_cloud1)

        # Shuffle every point except row 0, which stays the anchor.
        shuffled_indices = torch.randperm(self.sampled_points) + 1
        permuted_indices = torch.cat((torch.tensor([0]), shuffled_indices), dim=0)
        point_cloud1 = point_cloud1[permuted_indices]

        if self.std_dev != 0:
            point_cloud1 = self._add_noise(point_cloud1)

        if self.contr_loss_weight == 0:
            normal_vec = self._patch_normal(info, angle, radius, class_label, rot_orig,
                                            point_cloud1.device)
            return {"point_cloud": point_cloud1, "info": info, "normal_vec": normal_vec}

        (_count, _old_k1, _old_k2, _new_k1, _new_k2, _bounds_neg,
         contrastive_point_cloud) = sampleContrastivePcl(
            angle=angle, radius=radius, class_label=class_label, sampled_points=self.sampled_points,
            min_len=min_len, max_len=max_len, bias=bias, info=info, min_curve_diff=self.min_curve_diff,
            max_curve_diff=self.max_curve_diff, constant=self.constant, edge_label=edge_label,
            bounds=bounds, min_curve=self.min_curve, max_curve=self.max_curve,
            int_K_const=self.int_K_const)

        if class_label == 4:
            positive_point_cloud = point_cloud
        else:
            _bounds_pos, positive_point_cloud = samplePcl(
                angle=angle, radius=radius, class_label=class_label,
                sampled_points=self.sampled_points, min_len=min_len, max_len=max_len, bias=bias,
                info=info, bounds=bounds, edge_label=edge_label)

        contrastive_point_cloud = torch.tensor(contrastive_point_cloud, dtype=torch.float32)
        _neg_rot, contrastive_point_cloud = random_rotation(contrastive_point_cloud)

        point_cloud2 = torch.tensor(positive_point_cloud, dtype=torch.float32)
        _pos_rot, point_cloud2 = random_rotation(point_cloud2)

        if self.std_dev != 0:
            point_cloud2 = self._add_noise(point_cloud2)
            contrastive_point_cloud = self._add_noise(contrastive_point_cloud)

        return {"point_cloud": point_cloud1, "point_cloud2": point_cloud2,
                "contrastive_point_cloud": contrastive_point_cloud, "info": info}


def samplePcl(angle,radius,class_label,sampled_points, bias, min_len,max_len, info,edge_label=0, bounds=None):
    """Generate one synthetic surface patch whose row 0 is the origin.

    The generator is chosen as follows:

    * ``angle != 0``: pyramid vertex (class/edge label 1) or dihedral ridge (label 2).
    * ``radius != 0``: sphere (label 1) or cylinder (label 2).
    * otherwise: the quadric ``z = a x^2 + b y^2 + c x y + d x + e y`` with
      coefficients ``info['a'..'e']``.

    For ``class_label == 4``, the cloud is re-anchored with ``find_representative_point``.

    Args:
        bounds: optional extent of a previous sample. Each entry is jittered by up to
            ±10% before reuse so that positive pairs have similar but not identical extents.

    Returns:
        (new_bounds, point_cloud)

    Raises:
        ValueError: if ``angle`` or ``radius`` is non-zero but neither ``class_label``
            nor ``edge_label`` is 1 or 2, so no generator applies.
    """
    if bounds is not None:
        cur_bounds = [bound * (1 + np.random.uniform(-0.1, 0.1)) for bound in bounds]
    else:
        cur_bounds = None

    new_bounds = None
    point_cloud = None
    if angle != 0:
        if class_label == 1 or edge_label == 1:
            r_tri, point_cloud = sample_pyramid(n_points=sampled_points, head_angle_rad=np.radians(angle),
                                                bias=bias, min_len=min_len, max_len=max_len, bounds=cur_bounds)
            new_bounds = [-r_tri, r_tri, -r_tri, r_tri]
        if class_label == 2 or edge_label == 2:
            new_bounds, point_cloud = generate_surfaces_angles_and_sample(
                sampled_points, angle, min_len=min_len, max_len=max_len, bounds=cur_bounds)
    elif radius != 0:
        if class_label == 1 or edge_label == 1:
            point_cloud = sample_sphere_point_cloud(radius=radius, num_of_points=sampled_points,
                                                    bounds=cur_bounds)
            new_bounds = [-radius, radius, -radius, radius]
        if class_label == 2 or edge_label == 2:
            new_bounds, point_cloud = sample_cylinder_point_cloud(
                radius=radius, min_len=min_len, max_len=max_len, num_of_points=sampled_points,
                bounds=cur_bounds)
    else:
        new_bounds, point_cloud = samplePoints(info['a'], info['b'], info['c'], info['d'], info['e'],
                                               count=sampled_points, min_len=min_len, max_len=max_len,
                                               bounds=cur_bounds)

    if point_cloud is None:
        raise ValueError(
            f"No generator for angle={angle}, radius={radius}, "
            f"class_label={class_label}, edge_label={edge_label}")

    if class_label == 4:
        point_cloud = find_representative_point(point_cloud)
    return new_bounds, point_cloud

def _contrastive_boundaries(cur_curve, min_curve_diff, max_curve_diff, lower_clip, upper_clip):
    """Return curvatures [cur+max, cur+min, cur-max, cur-min], clipped to [lower_clip, upper_clip]."""
    return np.clip([cur_curve + max_curve_diff, cur_curve + min_curve_diff,
                    cur_curve - max_curve_diff, cur_curve - min_curve_diff],
                   lower_clip, upper_clip)


def _pick_contrastive_value(vals, boundaries, upper_clip, lower_clip, prob=0.5):
    """Draw a parameter value from the higher- or lower-curvature candidate interval.

    ``vals`` are the four entries of ``boundaries`` mapped into the sampled parameter
    (an angle or a radius).  The first two entries give the higher-curvature interval;
    the last two give the lower-curvature interval.  An interval that contains NaN, or
    whose outer curvature hit the clip limit, is replaced by the other interval.  The
    higher-curvature interval is chosen with probability ``prob``.
    """
    hi_outer, hi_inner, lo_outer, lo_inner = vals
    higher = [hi_outer, hi_inner]
    lower = [lo_inner, lo_outer]
    if np.any(np.isnan(higher)) or boundaries[0] == upper_clip:
        higher = lower
    if np.any(np.isnan(lower)) or boundaries[2] == lower_clip:
        lower = higher
    interval = higher if np.random.uniform(0, 1) < prob else lower
    return np.random.uniform(interval[0], interval[1])


def sampleContrastivePcl(angle,radius,class_label,sampled_points, bias, min_len,max_len, info,min_curve_diff, max_curve_diff, constant,max_curve, min_curve,int_K_const,  edge_label=0, bounds=None):
    """Generate a negative patch of the same family whose curvature differs from the anchor's.

    The curvature change is kept roughly within (min_curve_diff, max_curve_diff):

    * pyramid (angle, label 1): curvature is sqrt((2*pi - 3*angle) * int_K_const),
      i.e. the scaled angle defect of a vertex with three equal face angles;
    * ridge (angle, label 2): curvature is constant * 2 * cos(angle / 2);
    * sphere/cylinder (radius): curvature is 1 / radius, and the new curvature is
      restricted to [1.5, 3];
    * quadric (no angle/radius): the coefficients a..e are perturbed with N(0, 0.065)
      noise until the largest principal-curvature change lies strictly inside
      (min_curve_diff, max_curve_diff).

    Returns:
        (count, old_k1, old_k2, new_k1, new_k2, new_bounds, contrastive_point_cloud).
        ``count`` is the number of rejected perturbations; only the quadric path retries.

    Raises:
        ValueError: if no generator applies to the given labels, or if the quadric path
            is requested with ``min_curve_diff >= max_curve_diff`` (it could never terminate).
    """
    count = 0
    if bounds is not None:
        cur_bounds = [bound * (1 + np.random.uniform(-0.1, 0.1)) for bound in bounds]
    else:
        cur_bounds = None

    old_k1 = old_k2 = new_k1 = new_k2 = None
    new_bounds = None
    contrastive_point_cloud = None

    if angle != 0:
        if class_label == 1 or edge_label == 1:
            angle_rad = np.radians(angle)
            cur_gauss_curv = (2 * np.pi - angle_rad * 3) * int_K_const
            cur_curve = np.clip(np.sqrt(cur_gauss_curv), min_curve, max_curve)
            old_k1 = old_k2 = cur_curve
            boundaries = _contrastive_boundaries(cur_curve, min_curve_diff, max_curve_diff,
                                                 min_curve, max_curve)
            angle_vals = [(2 * np.pi - (cur_val ** 2 / int_K_const)) / 3 for cur_val in boundaries]
            new_angle_rad = _pick_contrastive_value(angle_vals, boundaries, max_curve, min_curve)

            r_tri, contrastive_point_cloud = sample_pyramid(
                n_points=sampled_points, head_angle_rad=new_angle_rad, bias=bias,
                min_len=min_len, max_len=max_len, bounds=cur_bounds)
            new_bounds = [-r_tri, r_tri, -r_tri, r_tri]
            new_k1 = new_k2 = np.sqrt((2 * np.pi - new_angle_rad * 3) * int_K_const)

        if class_label == 2 or edge_label == 2:
            angle_rad = np.radians(angle)
            cur_curve = constant * (2 * np.cos(angle_rad / 2))
            old_k1 = cur_curve
            old_k2 = 0
            boundaries = _contrastive_boundaries(cur_curve, min_curve_diff, max_curve_diff,
                                                 min_curve, max_curve)
            angle_vals = [np.degrees(2 * np.arccos(np.clip(cur_val / (2 * constant), -1, 1)))
                          for cur_val in boundaries]
            new_angle_deg = _pick_contrastive_value(angle_vals, boundaries, max_curve, min_curve)
            new_bounds, contrastive_point_cloud = generate_surfaces_angles_and_sample(
                sampled_points, new_angle_deg, min_len=min_len, max_len=max_len, bounds=cur_bounds)
            new_k1 = constant * (2 * np.cos(np.radians(new_angle_deg) / 2))
            new_k2 = 0

    elif radius != 0:
        # Curvatures of contrastive spheres/cylinders are kept in [1.5, 3] so they stay
        # comparable to the rest of the data.
        cur_curve = 1 / radius
        old_k1 = old_k2 = cur_curve
        max_curve_diff = min(max_curve_diff, 1)
        min_curve_diff = min(min_curve_diff, 0.5)
        boundaries = _contrastive_boundaries(cur_curve, min_curve_diff, max_curve_diff, 1.5, 3)
        rad_vals = [1 / cur_val for cur_val in boundaries]
        new_radius = _pick_contrastive_value(rad_vals, boundaries, 3, 1.5)
        if class_label == 1 or edge_label == 1:
            contrastive_point_cloud = sample_sphere_point_cloud(radius=new_radius, num_of_points=sampled_points,
                                                                bounds=cur_bounds)
            # Bounds describe the generated (new-radius) sphere, as in samplePcl.
            new_bounds = [-new_radius, new_radius, -new_radius, new_radius]
            new_k1 = new_k2 = (1 / new_radius)
        if class_label == 2 or edge_label == 2:
            old_k1 = (1 / radius)
            old_k2 = 0
            new_bounds, contrastive_point_cloud = sample_cylinder_point_cloud(
                radius=new_radius, min_len=min_len, max_len=max_len, num_of_points=sampled_points,
                bounds=cur_bounds)
            new_k1 = (1 / new_radius)
            new_k2 = 0
    else:
        if min_curve_diff >= max_curve_diff:
            raise ValueError(
                f"min_curve_diff ({min_curve_diff}) must be smaller than max_curve_diff ({max_curve_diff})")
        a, b, c, d, e = info['a'], info['b'], info['c'], info['d'], info['e']
        K_orig, H_orig = compute_curvatures([a, b, c, d, e])
        # H^2 - K >= 0 in exact arithmetic; clamp so rounding cannot produce NaN
        # (a NaN anchor curvature would make the rejection loop below never terminate).
        root_orig = np.sqrt(np.maximum(H_orig ** 2 - K_orig, 0.0))
        old_k1 = H_orig + root_orig
        old_k2 = H_orig - root_orig
        while True:
            noise_to_add = np.random.normal(0, 0.065, 5)
            K_cont, H_cont = compute_curvatures([a, b, c, d, e] + noise_to_add)
            root_cont = np.sqrt(np.maximum(H_cont ** 2 - K_cont, 0.0))
            k1_cont = H_cont + root_cont
            k2_cont = H_cont - root_cont

            temp_max_diff = max(abs(k1_cont - old_k1), abs(k2_cont - old_k2))
            if min_curve_diff < temp_max_diff < max_curve_diff:
                a = info['a'] + noise_to_add[0]
                b = info['b'] + noise_to_add[1]
                c = info['c'] + noise_to_add[2]
                d = info['d'] + noise_to_add[3]
                e = info['e'] + noise_to_add[4]
                break
            count += 1
        new_bounds, contrastive_point_cloud = samplePoints(a, b, c, d, e, count=sampled_points,
                                                           min_len=min_len, max_len=max_len,
                                                           bounds=cur_bounds)
        new_k1 = k1_cont
        new_k2 = k2_cont

    if contrastive_point_cloud is None:
        raise ValueError(
            f"No contrastive generator for angle={angle}, radius={radius}, "
            f"class_label={class_label}, edge_label={edge_label}")

    if class_label == 4:
        contrastive_point_cloud = find_representative_point(contrastive_point_cloud)
    return count, old_k1, old_k2, new_k1, new_k2, new_bounds, contrastive_point_cloud

def compute_curvatures(coeffs):
    """Return (K, H): Gaussian and mean curvature at the origin of a quadric patch.

    The patch is z = a x^2 + b y^2 + c x y + d x + e y.  At the origin f_x = d,
    f_y = e, f_xx = 2a, f_yy = 2b and f_xy = c.  With g = 1 + d^2 + e^2 the Monge-patch
    formulas reduce to

        K = (4ab - c^2) / g^2
        H = (a(1 + e^2) - c d e + b(1 + d^2)) / g^(3/2)

    so the principal curvatures are H +/- sqrt(H^2 - K).

    Args:
        coeffs: the five coefficients (a, b, c, d, e).
    """
    a, b, c, d, e = coeffs
    d_sq = d**2
    e_sq = e**2
    g = 1 + d_sq + e_sq
    gauss_numer = 4 * a * b - c**2
    mean_numer = a * (1 + e_sq) - d * e * c + b * (1 + d_sq)
    return gauss_numer / (g**2), mean_numer / (g**(3/2))


def samplePoints(a, b, c, d, e, count, center_point=np.array([0,0,0]), min_len=0.5,max_len=2,bounds=None):
    """Sample ``count`` points on the quadric z = a x^2 + b y^2 + c x y + d x + e y.

    When ``bounds`` is None, each axis extent is drawn from [2*min_len, 2*max_len] and
    split around 0 at a random 20-80% fraction.  For each axis, at least five
    coordinates fall on each side of 0.  Coordinates are shuffled independently per
    axis, so x and y signs are not paired.

    Returns:
        (bounds, points): ``[x_lo, x_hi, y_lo, y_hi]`` and a ``(count + 1, 3)`` array
        whose row 0 is ``center_point``.
    """
    if bounds is not None:
        lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y = bounds
    else:
        extent_x = np.random.uniform(2 * min_len, 2 * max_len)
        frac_x = np.random.uniform(0.2, 0.8)
        upper_bound_x = extent_x * frac_x
        lower_bound_x = -(extent_x * (1 - frac_x))
        extent_y = np.random.uniform(2 * min_len, 2 * max_len)
        frac_y = np.random.uniform(0.2, 0.8)
        upper_bound_y = extent_y * frac_y
        lower_bound_y = -(extent_y * (1 - frac_y))

    def draw_axis(lo, hi):
        # Random negative/positive split (>= 5 each), then a shuffle of the coordinates.
        share = np.clip(np.random.normal(loc=0.5, scale=0.1), 0.2, 0.8)
        n_neg, n_pos = np.random.multinomial(count - 10, [share, 1 - share]) + np.array([5, 5])
        values = np.concatenate((np.random.uniform(lo, 0, n_neg), np.random.uniform(0, hi, n_pos)))
        return values[np.random.permutation(values.shape[0])]

    xs = draw_axis(lower_bound_x, upper_bound_x)
    ys = draw_axis(lower_bound_y, upper_bound_y)
    zs = a * xs**2 + b * ys**2 + c * xs * ys + d * xs + e * ys

    surface = np.column_stack((xs, ys, zs))
    anchor = np.expand_dims(center_point, axis=0)
    return [lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y], np.concatenate((anchor, surface), axis=0)
def sampleHalfSpacePoints(sampled_points_with_centroid):
    """Re-anchor a cloud on one of its five points farthest from the origin.

    The chosen point becomes the new origin.  Its row receives the old row 0 (shifted
    like every other point), and row 0 is set to zero.  The input array is not modified.
    """
    by_distance = np.argsort(np.linalg.norm(sampled_points_with_centroid, axis=1))
    offset_from_end = np.random.choice(np.arange(-5, 0))
    anchor = by_distance[offset_from_end]

    shifted = sampled_points_with_centroid - sampled_points_with_centroid[anchor, :]
    shifted[anchor, :] = shifted[0, :].copy()
    shifted[0, :] = 0
    return shifted


def find_representative_point(point_cloud):
    """Re-anchor ``point_cloud`` on an axis-extreme point that is not near the origin.

    Candidates are the per-axis argmin/argmax points (x, y, z minima, then maxima).
    Candidates whose norm is <= the 6th-smallest norm in the cloud are discarded.
    NOTE(review): that excludes the six closest points, although the error message
    says "top 5"; confirm the intended count.  Among the remaining candidates, the one
    closest to the origin becomes the new origin.  Its row receives the old row 0
    (shifted), and row 0 is set to zero.

    Raises:
        ValueError: if every candidate is filtered out.
        IndexError: if the cloud has fewer than six points.
    """
    extreme_idx = [np.argmin(point_cloud[:, axis]) for axis in range(3)]
    extreme_idx += [np.argmax(point_cloud[:, axis]) for axis in range(3)]

    dists = np.linalg.norm(point_cloud, axis=1)
    threshold = sorted(dists)[5]

    candidates = [i for i in extreme_idx if dists[i] > threshold]
    if len(candidates) == 0:
        raise ValueError("No points left after filtering the top 5 smallest norms.")
    anchor = min(candidates, key=dists.__getitem__)

    shifted = point_cloud - point_cloud[anchor, :]
    shifted[anchor, :] = (shifted[0, :]).copy()
    shifted[0, :] = np.array([[0, 0, 0]])
    return shifted
def generate_surfaces_angles_and_sample(N, angle,min_len,max_len, bounds=None):
    """Sample ``N`` points from a ridge of two planes meeting at ``angle`` degrees.

    The surface is z = -tan(radians((180 - angle) / 2)) * |x|: two half-planes joined
    along the y-axis.  Row 0 of the result is the origin, which lies on the ridge line.

    Args:
        N: number of surface samples.  At least five fall on each side of 0 per axis,
            so N must be >= 10.
        angle: opening angle between the two planes, in degrees.
        min_len, max_len: when ``bounds`` is None, each axis extent is drawn from
            [2*min_len, 2*max_len] and split around 0 at a random 20-80% fraction.
        bounds: optional ``[x_lo, x_hi, y_lo, y_hi]`` used instead of a random extent.

    Returns:
        (bounds, points): the ``[x_lo, x_hi, y_lo, y_hi]`` used and an ``(N + 1, 3)`` array.
    """
    angle_rad = np.radians((180 - angle) / 2)
    if bounds is None:
        size_x = np.random.uniform(2 * min_len, 2 * max_len)
        pct_pos_x = np.random.uniform(0.2, 0.8)
        upper_bound_x = size_x * pct_pos_x
        lower_bound_x = -(size_x * (1 - pct_pos_x))
        size_y = np.random.uniform(2 * min_len, 2 * max_len)
        pct_pos_y = np.random.uniform(0.2, 0.8)
        upper_bound_y = size_y * pct_pos_y
        lower_bound_y = -(size_y * (1 - pct_pos_y))
    else:
        [lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y] = bounds

    slope = np.tan(angle_rad)  # left face (x < 0) rises with slope, right face falls with -slope

    def axis_samples(lower, upper):
        # Random negative/positive split (>= 5 each).  The shuffle makes x and y signs
        # independent: without it the leading (negative) x values are paired with the
        # leading (negative) y values, and the (-x, +y) / (+x, -y) quadrants are
        # left (almost) empty.
        alpha = np.clip(np.random.normal(loc=0.5, scale=0.1), 0.2, 0.8)
        n_neg, n_pos = np.random.multinomial(N - 10, [alpha, 1 - alpha]) + np.array([5, 5])
        coords = np.concatenate((np.random.uniform(lower, 0, n_neg), np.random.uniform(0, upper, n_pos)))
        return coords[np.random.permutation(coords.shape[0])]

    x_coords = axis_samples(lower_bound_x, upper_bound_x)
    y_coords = axis_samples(lower_bound_y, upper_bound_y)
    z_coords = np.where(x_coords < 0, slope * x_coords, -slope * x_coords)

    surface = np.stack((x_coords, y_coords, z_coords), axis=-1)
    # The origin lies on the ridge and is already the anchor; no recentering is needed.
    points = np.vstack((np.zeros(3), surface))
    return [lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y], points
def sample_cylinder_point_cloud(radius, min_len,max_len, num_of_points,top_half=True, bounds=None):
    """Sample ``num_of_points`` points on a half-cylinder whose axis runs along x.

    The cross-section lies in the y-z plane: y = radius * sin(theta) and
    z = radius * cos(theta) - offset.  With ``top_half``, theta is in [-pi/2, pi/2]
    and offset = radius; otherwise theta is in [pi/2, 3pi/2] and offset = -radius.
    Either way, the line y = z = 0 lies on the surface.  x coordinates span the x
    bounds, with at least five on each side of 0.  When ``bounds`` is given, ``radius``
    is capped at its upper y value.

    Returns:
        (bounds, points): ``[x_lo, x_hi, y_lo, y_hi]`` and a ``(num_of_points + 1, 3)``
        array whose row 0 is the origin.
    """
    if bounds is not None:
        lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y = bounds
    else:
        extent_x = np.random.uniform(2 * min_len, 2 * max_len)
        frac_x = np.random.uniform(0.2, 0.8)
        upper_bound_x = extent_x * frac_x
        lower_bound_x = -(extent_x * (1 - frac_x))
        extent_y = np.random.uniform(2 * min_len, 2 * max_len)
        frac_y = np.random.uniform(0.2, 0.8)
        upper_bound_y = extent_y * frac_y
        lower_bound_y = -(extent_y * (1 - frac_y))

    theta_lo, theta_hi = (-0.5 * np.pi, 0.5 * np.pi) if top_half else (0.5 * np.pi, 1.5 * np.pi)
    theta = np.random.uniform(theta_lo, theta_hi, num_of_points)

    # Positions along the cylinder axis, split (>= 5 each) around x = 0.
    share = np.clip(np.random.normal(loc=0.5, scale=0.1), 0.2, 0.8)
    n_neg, n_pos = np.random.multinomial(num_of_points - 10, [share, 1 - share]) + np.array([5, 5])
    along_axis = np.hstack([np.random.uniform(lower_bound_x, 0, n_neg),
                            np.random.uniform(0, upper_bound_x, n_pos)])

    if bounds is not None and radius > upper_bound_y:
        radius = upper_bound_y

    offset = radius if top_half else -radius
    surface = np.stack((along_axis, radius * np.sin(theta), radius * np.cos(theta) - offset), axis=-1)
    points = np.vstack([np.zeros((1, 3)), surface])
    return [lower_bound_x, upper_bound_x, lower_bound_y, upper_bound_y], points


def sample_sphere_point_cloud(radius, num_of_points, top_half=True, bounds=None):
    """Sample ``num_of_points`` points on a hemisphere anchored at its pole.

    The azimuth is uniform in [0, 2*pi).  The polar angle is uniform in [0, pi/2]
    (``top_half``) or [pi/2, pi], and the sphere is translated along z so that the pole
    sits at the origin.  Because the polar angle is uniform in angle rather than in
    cos(angle), samples are denser near the pole than an area-uniform sample would be.
    ``bounds`` is accepted for signature compatibility with the other generators but is
    not used.

    Returns:
        A ``(num_of_points + 1, 3)`` array whose row 0 is the origin.
    """
    azimuth = np.random.uniform(0, 2 * np.pi, num_of_points)
    polar_lo, polar_hi = (0, np.pi / 2) if top_half else (np.pi / 2, np.pi)
    polar = np.random.uniform(polar_lo, polar_hi, num_of_points)

    ring = radius * np.sin(polar)
    offset = radius if top_half else -radius
    surface = np.stack((ring * np.cos(azimuth), ring * np.sin(azimuth), radius * np.cos(polar) - offset), axis=-1)
    cloud = np.vstack([np.zeros((1, 3)), surface])

    # Move the point closest to the origin into row 0 and make it the origin.
    nearest = np.argsort(np.linalg.norm(cloud, axis=1))[0]
    cloud = cloud - cloud[nearest, :]
    cloud[nearest, :] = cloud[0, :].copy()
    cloud[0, :] = 0
    return cloud

def equilateral_triangle_coordinates(r, a):
    """Return ``(h, vertices)`` for a pyramid whose three side faces have apex angle ``a``.

    The apex is at the origin.  The base is an equilateral triangle with circumradius
    ``r`` lying in the plane z = h, with vertices at azimuths 0, 120 and 240 degrees.

    Geometry: half a base edge is r*sqrt(3)/2 and the base inradius is r/2.  The slant
    height s of a face therefore satisfies s = (sqrt(3)/2) * r / tan(a/2) and
    h^2 + (r/2)^2 = s^2, giving

        h = r * sqrt(3 / (4 * tan(a/2)^2) - 1/4).

    At a = 120 degrees the three faces close flat (h = 0).  For larger angles no such
    pyramid exists, and h is clamped to 0.

    Args:
        r: circumradius of the base triangle.
        a: apex face angle in radians.

    Returns:
        (h, vertices): ``vertices`` is a (4, 3) array holding the apex followed by the
        three base vertices.
    """
    beta = np.tan(a / 2)
    # Clamp guards against tiny negative values at a = 120 deg and impossible a > 120 deg.
    h = r * np.sqrt(np.maximum(3 / (4 * beta ** 2) - 0.25, 0.0))

    vertices = [np.array([0, 0, 0])]
    for i in range(3):
        azimuth = 2 * np.pi * i / 3  # 120-degree steps
        vertices.append((r * np.cos(azimuth), r * np.sin(azimuth), h))

    return h, np.array(vertices)

def sample_pyramid(n_points, head_angle_rad, min_len,max_len, bias=0.0, bounds=None):
    """Sample ``n_points`` points on the three side faces of a triangular pyramid.

    The apex is at the origin.  The base circumradius is ``abs(bounds[0])`` when
    ``bounds`` is given, otherwise uniform(min_len, max_len) * 4/3.  Each base vertex
    is then scaled by an independent factor in [1 - bias, 1 + bias].  Points are
    allocated to faces in proportion to face area and placed uniformly within each face
    using reflected barycentric coordinates.

    Returns:
        (r, points): the base circumradius before jitter, and a ``(n_points + 1, 3)``
        array whose row 0 is the origin.
    """
    r = abs(bounds[0]) if bounds is not None else np.random.uniform(min_len, max_len) * (4/3)
    _, corners = equilateral_triangle_coordinates(r, head_angle_rad)
    corners = corners[1:, :]

    apex = np.array([0, 0, 0])
    for k in range(3):
        corners[k] *= (np.random.uniform(1 - bias, 1 + bias))

    faces = np.array([
        [apex, corners[0], corners[1]],
        [apex, corners[1], corners[2]],
        [apex, corners[2], corners[0]],
    ])
    face_areas = np.array([0.5 * np.linalg.norm(np.cross(f[1] - f[0], f[2] - f[0])) for f in faces])
    per_face = np.random.multinomial(n_points, face_areas / face_areas.sum())

    # Uniform barycentric coordinates: reflect (u, v) back into the lower triangle.
    u = np.random.rand(n_points)
    v = np.random.rand(n_points)
    outside = u + v > 1
    u[outside] = 1 - u[outside]
    v[outside] = 1 - v[outside]
    w = 1 - (u + v)

    chosen = faces[np.repeat(np.arange(3), per_face)]
    samples = u[:, None] * chosen[:, 0] + v[:, None] * chosen[:, 1] + w[:, None] * chosen[:, 2]
    cloud = np.vstack([apex.reshape(1, 3), samples])

    # Move the point closest to the origin into row 0 and make it the origin.
    nearest = np.argsort(np.linalg.norm(cloud, axis=1))[0]
    cloud = cloud - cloud[nearest, :]
    cloud[nearest, :] = cloud[0, :].copy()
    cloud[0, :] = 0
    return r, cloud


def random_rotation(point_cloud):
    """Rotate ``point_cloud`` (rows are points) by a uniformly random 3-D rotation.

    Returns:
        (matrix, rotated): the 3x3 NumPy rotation matrix and the rotated float32 tensor,
        computed as ``point_cloud @ matrix.T``.
    """
    matrix = R.random().as_matrix()
    rotated = point_cloud @ torch.tensor(matrix, dtype=torch.float32).T
    return matrix, rotated
