import numpy as np
from scipy.spatial.transform import Rotation as R
import torch

def _to_numpy(x):
    """Return *x* as a NumPy array, detaching/moving torch tensors to CPU first."""
    if torch.is_tensor(x):
        # detach() is required: .numpy() refuses tensors that require grad.
        return x.detach().cpu().numpy()
    return np.asarray(x)


def modify_conformer(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False):
    """Apply per-edge torsion updates to a conformer via rigid rotations.

    For every edge ``(u, v)`` whose update is non-zero, the atoms selected by
    ``mask_rotate[idx_edge]`` are rotated by ``torsion_updates[idx_edge]``
    radians (right-hand rule) about the axis through ``pos[v]`` pointing along
    ``pos[u] - pos[v]``. Edges are processed in order, so each rotation acts
    on coordinates already moved by the previous ones.

    Args:
        pos: Atom coordinates as a NumPy array, torch tensor or nested list.
            Presumably shaped ``(n_atoms, 3)`` -- TODO confirm with callers.
            The input is never modified; a copy is rotated.
        edge_index: Iterable of ``(u, v)`` atom-index pairs, one per edge
            (NumPy array or torch tensor).
        mask_rotate: Per-edge atom mask indexed as ``mask_rotate[edge, atom]``.
            Row ``i`` must select ``v`` and must not select ``u``. Any
            truthy/falsy values are accepted and converted to ``bool``.
        torsion_updates: Per-edge rotation angles in radians.
        as_numpy: If True return a NumPy array, otherwise a float32 torch
            tensor.

    Returns:
        The rotated coordinates (NumPy array or float32 torch tensor).

    Raises:
        ValueError: If an edge's mask selects ``u`` or omits ``v``, or if the
            two atoms of an edge with a non-zero update coincide (the
            rotation axis would be undefined).
    """
    # Convert *before* copying: np.copy on a grad-requiring or GPU tensor
    # fails, and .numpy() may share memory with the tensor, so copy after.
    pos = np.array(_to_numpy(pos), copy=True)
    if not np.issubdtype(pos.dtype, np.floating):
        # Integer coordinates would silently truncate the rotated positions.
        pos = pos.astype(np.float64)

    edge_index = _to_numpy(edge_index)
    # Boolean mask is required: an int 0/1 row would act as fancy indexing.
    mask_rotate = _to_numpy(mask_rotate).astype(bool, copy=False)
    torsion_updates = _to_numpy(torsion_updates)

    for idx_edge, (u, v) in enumerate(edge_index):
        angle = torsion_updates[idx_edge]
        if angle == 0:
            continue

        target_mask = mask_rotate[idx_edge]
        # The moving side must contain v (the pivot) but not u (fixed side).
        if target_mask[u] or not target_mask[v]:
            raise ValueError(f"Edge {idx_edge} has incorrect mask: u={u}, v={v}")

        bond = pos[u] - pos[v]
        bond_len = np.linalg.norm(bond)
        if bond_len == 0:
            raise ValueError(f"Edge {idx_edge} has zero length: u={u}, v={v}")

        rot_mat = R.from_rotvec(bond / bond_len * angle).as_matrix()
        # Row-vector form of R @ (p - pivot) + pivot; RHS is fully evaluated
        # before assignment, so reading pos[v] here is safe.
        pos[target_mask] = (pos[target_mask] - pos[v]) @ rot_mat.T + pos[v]

    if not as_numpy:
        pos = torch.from_numpy(pos.astype(np.float32))
    return pos
    

def rotation_perturb_mol(data, pos, torsion_updates=None):
    """Perturb the rotatable torsions of a molecule.

    Selects the edges flagged by ``data.mask_edges`` from ``data.edge_index``
    (transposed so that each row is one edge) and rotates ``pos`` about them
    with ``modify_conformer``, using ``data.mask_rotate`` as the moving-atom
    mask.

    Args:
        data: Graph-like object exposing ``mask_edges``, ``edge_index`` and
            ``mask_rotate`` (presumably a PyG ``Data`` -- verify with callers).
        pos: Coordinates to perturb; forwarded unchanged to
            ``modify_conformer``.
        torsion_updates: Per-edge angles in radians. When omitted, one angle
            per selected edge is drawn uniformly from ``[-pi, pi)`` using
            NumPy's global RNG.

    Returns:
        Tuple ``(torsion_updates, pos_tau)``: the angles that were applied and
        the perturbed coordinates.
    """
    if torsion_updates is None:
        n_rotatable = data.mask_edges.sum()
        torsion_updates = np.random.uniform(low=-np.pi, high=np.pi, size=n_rotatable)

    rotatable_edges = data.edge_index.T[data.mask_edges]
    perturbed = modify_conformer(
        pos,
        rotatable_edges,
        data.mask_rotate,
        torsion_updates,
    )
    return torsion_updates, perturbed

# def perturb_seeds(data, pdb=None):
#     for i, data_conf in enumerate(data):
#         torsion_updates = np.random.uniform(low=-np.pi,high=np.pi, size=data_conf.edge_mask.sum())
#         data_conf.pos = modify_conformer(data_conf.pos, data_conf.edge_index.T[data_conf.edge_mask],
#                                          data_conf.mask_rotate, torsion_updates)
#         data_conf.total_perturb = torsion_updates
#     return data
