import torch
from color_refinement import PointCloud2nxGraph, CR_E
import miniball

def C_refinement(atom_type, pos, device, round_num):
    """Refine per-atom labels using each atom's distance to the centroid.

    Two atoms end up with the same new label iff they have the same original
    ``atom_type`` entry and the same distance to the mean of ``pos`` after
    rounding that distance to ``round_num`` decimal places.

    Args:
        atom_type: tensor with one label entry per row of ``pos``.
        pos: tensor of coordinates, one row per atom.
        device: device on which the returned tensor is allocated.
        round_num: number of decimal places kept for the centroid distance,
            i.e. distances are compared with a resolution of about
            ``10**-round_num``.

    Returns:
        ``torch.long`` tensor of new labels in ``[0, num_classes)``. Labels are
        numbered in order of first appearance, so the result is the same in
        every process (no dependence on string-hash randomization).

    Raises:
        ValueError: if ``atom_type`` and ``pos`` describe different numbers of
            atoms.
    """
    # Center augmentation: Euclidean distance of every atom to the centroid.
    center = torch.mean(pos, dim=0, keepdim=True)
    dist2center = torch.norm(pos - center, dim=1).tolist()
    # Keep round_num decimal places, i.e. the tolerated error is ~1e-round_num.
    dist2center = [round(d, round_num) for d in dist2center]

    atom_type_list = atom_type.tolist()
    if len(atom_type_list) != len(dist2center):
        raise ValueError(
            f"atom_type has {len(atom_type_list)} entries but pos has "
            f"{len(dist2center)} rows"
        )

    # Colour = (original type, rounded distance). The textual key mirrors the
    # original implementation and also works if per-atom entries are lists.
    colors = [str([t, d]) for t, d in zip(atom_type_list, dist2center)]

    # Number colours by first appearance. Ordering by a set would depend on
    # str hashing, which Python randomizes per process.
    color_to_label = {}
    for c in colors:
        color_to_label.setdefault(c, len(color_to_label))
    labels = [color_to_label[c] for c in colors]
    return torch.tensor(labels, dtype=torch.long, device=device)
    
def VD_refinement(atom_type, pos, device, round_num):
    """Refine per-atom labels with the ``CR_E`` colour-refinement routine.

    Two graphs are built from the same point cloud and the same initial labels
    via ``PointCloud2nxGraph``, and ``CR_E`` is run on the pair. Because both
    graphs are built identically, the left and right labelings must agree.

    Args:
        atom_type: tensor of initial labels; ``atom_type[i].item()`` is used as
            the label of node ``i``.
        pos: tensor of coordinates, converted to a NumPy array for graph
            construction.
        device: device on which the returned tensor is allocated.
        round_num: forwarded unchanged to ``CR_E``.

    Returns:
        ``torch.long`` tensor built from the values of the label mapping that
        ``CR_E`` returns, in that mapping's iteration order.
        NOTE(review): presumably this is node order 0..N-1; confirm against CR_E.
    """
    initial_labels = {i: atom_type[i].item() for i in range(len(atom_type))}

    # detach() so that positions which require grad can still be converted.
    pos_np = pos.detach().cpu().numpy()
    graph_l = PointCloud2nxGraph(pos_np, initial_labels=initial_labels)
    graph_r = PointCloud2nxGraph(pos_np, initial_labels=initial_labels)
    _, new_label_list_l, new_label_list_r, _ = CR_E(graph_l, graph_r, round_num=round_num)

    # Both graphs are identical, so their refined labelings must match.
    assert new_label_list_r == new_label_list_l

    return torch.tensor(list(new_label_list_l.values()), dtype=torch.long, device=device)




def test_symmetry(atom_type, pos, error, device="cuda"):
    """Check whether the per-type centroids of a point cloud (nearly) coincide.

    For every distinct value in ``atom_type`` the mean position of the atoms
    with that value is computed. The cloud is reported as symmetric when the
    minimum enclosing ball of these type centroids (computed with
    ``miniball``) has radius ``<= error``.

    If the enclosing ball cannot be computed, or fails its sanity check, a
    relaxed measure is used instead: the largest distance from a type centroid
    to the mean of all type centroids. This is the radius of a ball that
    encloses all centroids, so it is never smaller than the exact radius.

    Args:
        atom_type: 1-D tensor of labels, one per row of ``pos``.
        pos: 2-D tensor of coordinates.
        error: tolerance on the radius.
        device: unused; kept for interface compatibility.

    Returns:
        ``(symm, radius)``, where ``symm`` is a Python ``bool`` and ``radius``
        is a Python float, on both the exact and the relaxed path.
    """
    type_centers = torch.stack(
        [pos[atom_type == t].mean(dim=0) for t in torch.unique(atom_type)],
        dim=0,
    )

    try:
        center, radius_squared = miniball.get_bounding_ball(
            type_centers.detach().cpu().numpy()
        )
        center = torch.tensor(center, dtype=type_centers.dtype, device=type_centers.device)
        radius = torch.sqrt(
            torch.tensor(radius_squared, dtype=type_centers.dtype, device=type_centers.device)
        )
        # Sanity check: the farthest centroid must lie on the ball's surface.
        # Raised explicitly (not asserted) so the check survives `python -O`.
        type_center_distances = torch.norm(type_centers - center, dim=1)
        if not torch.allclose(type_center_distances.max(), radius):
            raise ValueError("bounding-ball radius disagrees with farthest type center")
    except Exception:
        print(f"error met for {type_centers}")
        # Use a relaxed version: spread of type centroids around their mean.
        dist = torch.norm(type_centers - torch.mean(type_centers, dim=0), dim=1)
        return bool(torch.all(dist <= error)), dist.max().item()

    radius_value = radius.item()
    return radius_value <= error, radius_value


# The name matches pytest's default ``test_*`` pattern; this is a utility, not
# a test, so stop pytest from collecting it when it is imported into a test module.
test_symmetry.__test__ = False


def farthest_point_sample(point, npoint, device):
    """Select ``npoint`` rows of ``point`` by farthest point sampling.

    The first pick is uniformly random (``torch.randint`` on the global RNG).
    Each later pick is the row whose squared distance to its nearest
    already-selected row is largest. Only the first (up to) three columns are
    used as coordinates; the returned rows keep all ``D`` columns.

    Args:
        point: 2-D tensor of shape ``[N, D]``.
        npoint: number of rows to select.
        device: device on which the distance computation runs.

    Returns:
        Tensor of shape ``[npoint, D]`` with the selected rows of ``point``,
        on ``point``'s device.
    """
    N, _ = point.shape
    xyz = point[:, :3].to(device)
    # Integer index buffer: a float32 buffer is inexact above 2**24 points.
    centroids = torch.zeros(npoint, dtype=torch.long, device=device)
    # Start at +inf (not a finite sentinel like 1e10) so arbitrarily distant
    # points are ranked correctly.
    distance = torch.full((N,), float("inf"), device=device)
    farthest = torch.randint(0, N, (1,), dtype=torch.long).squeeze().item()
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :].unsqueeze(0)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # Track each point's squared distance to its nearest selected centroid.
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.argmax(distance, -1)
    # Index on point's own device; point may not live on `device`.
    return point[centroids.to(point.device), :]


def farthest_pretransform(data, npoint, device):
    """Replace ``data.pos`` by ``npoint`` of its rows chosen via farthest point sampling.

    ``data.pos`` is moved to ``device`` before sampling; ``data`` is modified
    in place and also returned.

    NOTE(review): only ``pos`` is subsampled; any other per-point attributes on
    ``data`` are left untouched. Confirm callers do not expect them to stay
    aligned with ``pos``.
    """
    sampled_pos = farthest_point_sample(data.pos.to(device), npoint, device)
    data.pos = sampled_pos
    return data