import os
import ase
import math
import tqdm
import torch
import ase.io
import numpy as np
from rdkit import Chem, RDLogger
from typing import Any, Optional, List, Tuple, Dict


# Turn off every RDKit log channel matching 'rdApp.*' for this process.
RDLogger.DisableLog('rdApp.*')

# Upper bound on the weighted alignment MSE below which robust_kabsch_symmetry
# considers snapping a near-identity rotation to the exact identity.
EPS_SMALL = 1e-3

def is_geometrically_degenerate(P: np.ndarray, tol_rel: float = 0.05) -> str:
    """
    Classify the dimensionality of a point set P of shape (N, 3).

    The coordinates are centered on their mean here (callers need not
    pre-center), and the singular values S[0] >= S[1] >= S[2] of the centered
    matrix are compared against tol_rel * S[0]:
      - 'linear'        : S[1] is below the threshold (points ~ on a line),
      - 'planar'        : S[2] is below the threshold (points ~ in a plane),
      - 'nondegenerate' : otherwise (full 3D).

    Fewer than 3 points, or a set with zero spread (all points coincide),
    are reported as 'linear'.

    Returns one of the three strings.
    """
    if P.shape[0] < 3:
        # 1 or 2 points can always be placed on a line.
        return 'linear'
    P_c = P - P.mean(axis=0)
    _, S, _ = np.linalg.svd(P_c, full_matrices=False)
    if S[0] <= 0.0:
        # All points coincide: the relative tests below would compare 0 < 0
        # and wrongly fall through to 'nondegenerate'.
        return 'linear'
    if S[1] < tol_rel * S[0]:
        # second singular value tiny -> linear
        return 'linear'
    if S[2] < tol_rel * S[0]:
        # third singular value tiny -> planar
        return 'planar'
    return 'nondegenerate'

def numpy_to_torch_R_t(R_np: np.ndarray, t_np: np.ndarray, device: Optional[torch.device] = None):
    """
    Build float32 torch tensors from a numpy rotation and translation.

    Both arrays are cast to float32 and placed on `device` (torch's default
    placement when `device` is None). Returns the pair (R, t) as tensors.
    """
    def _to_float32_tensor(array: np.ndarray) -> torch.Tensor:
        return torch.tensor(array.astype(np.float32), dtype=torch.float32, device=device)

    return _to_float32_tensor(R_np), _to_float32_tensor(t_np)

def rotation_angle_distance(R1: np.ndarray, R2: np.ndarray) -> float:
    """
    Angle in radians of the relative rotation R1^T @ R2, i.e. the geodesic
    distance between the two rotation matrices.
    """
    relative = R1.T @ R2
    cos_theta = (np.trace(relative) - 1.0) / 2.0
    # Round-off can push the cosine marginally outside [-1, 1]; clamp it so
    # acos stays in its domain.
    cos_theta = min(1.0, cos_theta)
    cos_theta = max(-1.0, cos_theta)
    return float(math.acos(cos_theta))

def _filter_perms_preserve_atomic_numbers(perms: List[List[int]], exemplar_z: np.ndarray) -> List[List[int]]:
    """Keep only permutations p such that exemplar_z[p] == exemplar_z (element-wise).
       This protects against RDKit returning automorphisms that swap different elements."""
    out = []
    for p in perms:
        p_arr = np.asarray(p, dtype=int)
        if p_arr.shape[0] != exemplar_z.shape[0]:
            continue
        if np.array_equal(exemplar_z[p_arr], exemplar_z):
            out.append(p_arr.tolist())
    return out

def _clean_copy_for_automorphisms(mol: Chem.Mol) -> Chem.Mol:
    """
    Return a copy of `mol` stripped of annotations before automorphism search.

    On the copy, every atom has its aromatic flag cleared, formal charge set
    to 0 and chiral tag set to unspecified; every bond has its aromatic flag
    and stereo cleared. Bond types are not modified and the input molecule is
    left untouched.
    """
    def _best_effort(action) -> None:
        # Each reset is attempted on its own; a failing setter is ignored and
        # simply leaves that one property as it was.
        try:
            action()
        except Exception:
            pass

    editable = Chem.RWMol(mol)
    for atom in editable.GetAtoms():
        _best_effort(lambda: atom.SetIsAromatic(False))
        _best_effort(lambda: atom.SetFormalCharge(0))
        _best_effort(lambda: atom.SetChiralTag(Chem.CHI_UNSPECIFIED))
    for bond in editable.GetBonds():
        _best_effort(lambda: bond.SetIsAromatic(False))
        _best_effort(lambda: bond.SetStereo(Chem.BondStereo.STEREONONE))
    return editable.GetMol()

def _get_graph_automorphisms_preserve_element(mol: Chem.Mol,
                                              exemplar_z: Optional[np.ndarray] = None) -> List[List[int]]:
    """
    Enumerate self-matches of the cleaned molecule as atom-index lists.

    The molecule is first passed through _clean_copy_for_automorphisms and then
    matched against itself as a substructure. With uniquify=False every atom
    mapping is returned rather than one per matched atom set. When `exemplar_z`
    is given, mappings that do not preserve those atomic numbers are removed.

    NOTE(review): GetSubstructMatches is called with its default match limit;
    confirm that highly symmetric fragments are not truncated by it.
    """
    stripped = _clean_copy_for_automorphisms(mol)
    automorphisms = [list(match) for match in stripped.GetSubstructMatches(stripped, uniquify=False)]
    if exemplar_z is None:
        return automorphisms
    return _filter_perms_preserve_atomic_numbers(automorphisms, exemplar_z)

def _generate_candidate_fragment_smiles(mol: Chem.Mol, system_atom_indices: set) -> str:
    """
    Generate a canonical SMILES for the system as an induced subgraph.

    The subgraph contains the system atoms plus every hydrogen bonded to one
    of them. Nothing is added at cut sites: bonds to heavy atoms outside the
    system are dropped (no '*' stub, no hydrogen cap) and implicit hydrogens
    are disabled on every copied atom, so cut atoms are written with only the
    neighbours that remain. Per the original author, this matches the raw
    fragment library format.

    Args:
        mol: source molecule; atom indices refer to this molecule.
        system_atom_indices: indices of the atoms forming the system.

    Returns:
        Isomeric canonical SMILES of the induced subgraph, written with
        allHsExplicit=True.
    """
    # 1. Atoms to keep: the system atoms plus hydrogens attached to them.
    atoms_to_keep = set(system_atom_indices)
    for idx in system_atom_indices:
        atom = mol.GetAtomWithIdx(idx)
        for bond in atom.GetBonds():
            nbr = bond.GetOtherAtom(atom)
            nbr_idx = nbr.GetIdx()
            if nbr_idx in system_atom_indices:
                # Already part of the system.
                continue
            if nbr.GetAtomicNum() == 1:
                atoms_to_keep.add(nbr_idx)
            # Heavy neighbours outside the system are deliberately ignored:
            # the bond is dropped and the system atom keeps an open valence.

    # 2. Copy the kept atoms into a fresh editable molecule, in index order.
    ordered_keep = sorted(atoms_to_keep)
    label_mol = Chem.RWMol()
    node_map = {}  # orig_idx -> new_idx
    for orig_idx in ordered_keep:
        src_atom = mol.GetAtomWithIdx(orig_idx)
        new_atom = Chem.Atom(src_atom.GetAtomicNum())

        # Critical: without this RDKit would guess hydrogens on the atoms
        # whose bonds were cut.
        new_atom.SetNoImplicit(True)

        # Best-effort property copies. Only ordinary errors are suppressed;
        # KeyboardInterrupt / SystemExit now propagate.
        try:
            new_atom.SetFormalCharge(src_atom.GetFormalCharge())
        except Exception:
            pass
        try:
            new_atom.SetChiralTag(src_atom.GetChiralTag())
        except Exception:
            pass

        # Force non-aromatic so the SMILES is consistent with Kekulized input.
        new_atom.SetIsAromatic(False)

        node_map[orig_idx] = label_mol.AddAtom(new_atom)

    # 3. Copy bonds whose two ends are both kept. Walking each atom's own
    # bonds avoids probing every atom pair. Bonds are still inserted in
    # ascending (lower index, higher index) order with the lower index as the
    # begin atom, exactly as the pairwise scan did.
    for u in ordered_keep:
        higher_partners = {}
        for bond in mol.GetAtomWithIdx(u).GetBonds():
            v = bond.GetOtherAtomIdx(u)
            if v > u and v in node_map:
                higher_partners[v] = bond.GetBondType()
        for v in sorted(higher_partners):
            label_mol.AddBond(node_map[u], node_map[v], higher_partners[v])

    # 4. allHsExplicit=True writes the copied hydrogens as explicit nodes.
    return Chem.MolToSmiles(label_mol, isomericSmiles=True, canonical=True, allHsExplicit=True)


def get_minimal_stub_cuts(mol: Chem.Mol, rare_smiles: Optional[List[str]] = None) -> List[int]:
    """
    Select the bonds of `mol` to cut when splitting it into fragments.

    Procedure:
    1. Group rings into ring systems: rings that share at least one atom end
       up in the same system.
    2. Per system, Policy A: if all atoms of the system are planar (or linear)
       according to the first conformer AND the system's fragment SMILES is
       not in `rare_smiles`, protect the whole system.
    3. Otherwise, Policy B: a single ring is eligible when it has <= 3 atoms
       or is planar, and its SMILES is not in `rare_smiles`. Eligible rings
       are taken greedily, largest first, skipping any ring that shares an
       atom with a ring already taken in this system.
    4. Cut selection over all bonds:
       - bonds with a hydrogen at either end are never cut,
       - ring bonds of protected rings/systems are never cut,
       - any other bond touching a protected atom is cut (whatever its order),
       - remaining non-single bonds are kept,
       - remaining single bonds are cut (this includes bonds inside rings
         that were not protected).

    If the molecule has no conformer, nothing is considered planar, so only
    rings with <= 3 atoms can be protected.

    Args:
        mol: molecule to analyse.
        rare_smiles: fragment SMILES that must not be protected; None or
            empty disables the check.

    Returns:
        Sorted list of unique bond indices to cut.
    """
    bonds_to_cut = []
    ri = mol.GetRingInfo()
    bond_rings = ri.BondRings()
    atom_rings = ri.AtomRings()
    n_rings = len(atom_rings)

    rare_smiles_set = set(rare_smiles) if rare_smiles else set()

    # Coordinates are optional: a ValueError from GetConformer() is treated
    # as "no geometry available".
    try:
        conf = mol.GetConformer()
        all_pos = conf.GetPositions()
    except ValueError:
        conf = None
        all_pos = None

    # --- Helper: rarity veto ---
    def _is_allowed(indices: set) -> bool:
        """Returns True if the fragment SMILES is NOT in the rare list."""
        if not rare_smiles_set:
            # No rare list supplied: skip SMILES generation entirely.
            return True
        smi = _generate_candidate_fragment_smiles(mol, indices)
        return smi not in rare_smiles_set

    # --- 1. Identify Ring Systems ---
    # Map each ring atom to the rings it belongs to.
    atom_to_rings = {}
    for r_idx, atoms in enumerate(atom_rings):
        for a_idx in atoms:
            if a_idx not in atom_to_rings: atom_to_rings[a_idx] = []
            atom_to_rings[a_idx].append(r_idx)

    # Two rings are adjacent when they share at least one atom.
    ring_adj = {i: set() for i in range(n_rings)}
    for a_idx, r_indices in atom_to_rings.items():
        if len(r_indices) > 1:
            for i in range(len(r_indices)):
                for j in range(i+1, len(r_indices)):
                    r1, r2 = r_indices[i], r_indices[j]
                    ring_adj[r1].add(r2)
                    ring_adj[r2].add(r1)

    # Connected components of the ring adjacency graph (iterative DFS).
    visited_rings = set()
    ring_systems = []
    for i in range(n_rings):
        if i not in visited_rings:
            component = set()
            stack = [i]
            visited_rings.add(i)
            while stack:
                r = stack.pop()
                component.add(r)
                for nbr in ring_adj[r]:
                    if nbr not in visited_rings:
                        visited_rings.add(nbr)
                        stack.append(nbr)
            ring_systems.append(component)

    # --- 2. Determine Which Rings are Rigid ---
    safe_bonds = set()  # ring bonds that must never be cut
    safe_atoms = set()  # atoms belonging to a protected ring/system

    def _is_set_planar(indices):
        """True when the atoms are planar or linear; always False without a conformer."""
        if conf is None: return False
        coords = all_pos[list(indices)]
        return is_geometrically_degenerate(coords, tol_rel=0.05) != 'nondegenerate'

    for sys_rings in ring_systems:
        # Union of atoms and ring bonds over every ring of this system.
        sys_atom_indices = set()
        sys_bond_indices = set()
        for r_idx in sys_rings:
            sys_atom_indices.update(atom_rings[r_idx])
            for b_idx in bond_rings[r_idx]:
                sys_bond_indices.add(b_idx)

        # Policy A: whole system is planar and not rare -> protect it as a unit.
        is_globally_rigid = False
        if _is_set_planar(sys_atom_indices):
            # SMILES is only generated once the cheaper geometry test has passed.
            if _is_allowed(sys_atom_indices):
                is_globally_rigid = True

        if is_globally_rigid:
            safe_bonds.update(sys_bond_indices)
            safe_atoms.update(sys_atom_indices)
            continue

        # Policy B: system is not planar, or was vetoed by the rare list ->
        # evaluate its rings one by one.
        eligible_rings = []
        for r_idx in sys_rings:
            r_atoms = atom_rings[r_idx]

            # 1. Geometric check: at most 3 atoms, or planar.
            is_r_rigid = (len(r_atoms) <= 3) or _is_set_planar(r_atoms)

            # 2. Rarity check (only if geometry passed).
            if is_r_rigid:
                if not _is_allowed(set(r_atoms)):
                    is_r_rigid = False

            if is_r_rigid:
                eligible_rings.append(r_idx)

        # Greedy selection: larger rings first; ties are broken by the sorted
        # atom-index tuple so the order is deterministic.
        eligible_rings.sort(key=lambda r: (-len(atom_rings[r]), tuple(sorted(atom_rings[r]))))

        kept_atoms_in_sys = set()
        for r_idx in eligible_rings:
            r_atoms_set = set(atom_rings[r_idx])
            # Take the ring only if it shares no atom with a ring already taken.
            if len(kept_atoms_in_sys.intersection(r_atoms_set)) == 0:
                kept_atoms_in_sys.update(r_atoms_set)
                safe_atoms.update(r_atoms_set)
                for b_idx in bond_rings[r_idx]:
                    safe_bonds.add(b_idx)

    # --- 3. Identify Cuts ---
    for b in mol.GetBonds():
        idx = b.GetIdx()
        u = b.GetBeginAtom()
        v = b.GetEndAtom()

        # Never cut a bond that has a hydrogen atom at either end.
        if u.GetAtomicNum() == 1 or v.GetAtomicNum() == 1:
            continue

        # Ring bonds of protected rings/systems stay intact.
        if idx in safe_bonds:
            continue

        # Isolation: any other bond touching a protected atom is cut. This
        # check runs before the bond-order test, so it applies to
        # double/triple bonds as well.
        if u.GetIdx() in safe_atoms or v.GetIdx() in safe_atoms:
            bonds_to_cut.append(idx)
            continue

        # Remaining non-single bonds are kept.
        if b.GetBondType() != Chem.BondType.SINGLE:
            continue

        # Remaining single bonds are cut, including those inside rings that
        # were not protected.
        bonds_to_cut.append(idx)

    return sorted(list(set(bonds_to_cut)))


def numpy_kabsch_with_reflection_flag(P: np.ndarray, Q: np.ndarray, weights: Optional[np.ndarray] = None
                                      ) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]:
    """
    Weighted Kabsch alignment of P onto Q (both (N, 3), same point order).

    Finds a proper rotation R and translation t such that R @ P[i] + t
    approximates Q[i] under the given weights. Weights are normalised to sum
    to 1; None means uniform weights.

    Returns:
        (R, t, weighted_mse, per_atom_errors, reflected_flag) where
        per_atom_errors are the per-point residual norms and reflected_flag is
        True when the raw SVD solution had negative determinant and its last
        axis was flipped to obtain a proper rotation.

    Raises:
        ValueError: on shape mismatch, weight length mismatch, or weights
            summing to zero.
    """
    if P.shape != Q.shape:
        raise ValueError("P and Q must have same shape for Kabsch")
    n_points = P.shape[0]

    if weights is not None:
        w = weights.astype(float).copy()
        if w.shape[0] != n_points:
            raise ValueError("weights length mismatch in kabsch")
    else:
        w = np.ones(n_points, dtype=float)

    total_weight = w.sum()
    if total_weight == 0:
        raise ValueError("sum of weights is zero in kabsch")
    w = w / total_weight

    # Weighted centroids are needed by both the trivial and the general path.
    centroid_P = (w[:, None] * P).sum(axis=0)
    centroid_Q = (w[:, None] * Q).sum(axis=0)

    if n_points < 2:
        # A single point fixes no orientation: identity rotation, pure shift.
        rotation = np.eye(3, dtype=float)
        translation = centroid_Q - rotation @ centroid_P
        residual_norms = np.linalg.norm((rotation @ P.T).T + translation - Q, axis=1)
        return rotation, translation, float((w * (residual_norms**2)).sum()), residual_norms, False

    # Weighted cross-covariance of the centered point sets.
    cross_cov = ((P - centroid_P) * w[:, None]).T @ (Q - centroid_Q)
    U, _, Vt = np.linalg.svd(cross_cov)
    V = Vt.T
    rotation = V @ U.T

    reflected = bool(np.linalg.det(rotation) < 0)
    if reflected:
        # Improper solution: flip the last right-singular direction.
        V[:, -1] *= -1
        rotation = V @ U.T

    translation = centroid_Q - rotation @ centroid_P

    residuals = (rotation @ P.T).T + translation - Q
    sq_errors = np.sum(residuals * residuals, axis=1)
    return rotation, translation, float((w * sq_errors).sum()), np.sqrt(sq_errors), reflected


def robust_kabsch_symmetry(P: np.ndarray, Q: np.ndarray, weights: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]:
    """
    Best proper-rotation alignment P -> Q with a snap-to-identity shortcut.

    Delegates to numpy_kabsch_with_reflection_flag. If the resulting weighted
    MSE is below EPS_SMALL and the rotation is within 1e-4 of the identity,
    the rotation is replaced by the exact identity, the translation by the
    difference of the plain (unweighted) means, the errors are recomputed
    for that transform, and the reflection flag is reported as False.

    Args:
        P, Q: (N, 3) point sets in matching order, N >= 3.
        weights: optional per-point weights; None means uniform weights.

    Returns:
        (R, t, weighted_mse, per_atom_errors, reflected_flag).

    Raises:
        ValueError: if fewer than 3 points are supplied.
    """
    N = P.shape[0]

    if N < 3:
        raise ValueError(f"robust_kabsch_symmetry called with {N} points. Pipeline guarantees >= 3.")

    R, t, mse, per_atom, reflected = numpy_kabsch_with_reflection_flag(P, Q, weights)

    # Snap only when the fit is already very good AND the rotation is
    # numerically indistinguishable from the identity.
    if mse < EPS_SMALL and np.allclose(R, np.eye(3), atol=1e-4):
        # weights defaults to None; the previous code multiplied by it
        # directly and crashed here for unweighted calls.
        if weights is None:
            w = np.ones(N, dtype=float)
        else:
            w = np.asarray(weights, dtype=float)
        R = np.eye(3)
        # NOTE(review): unweighted means are kept from the original code even
        # when weights are given; confirm this is intended.
        t = np.mean(Q, axis=0) - np.mean(P, axis=0)
        per_atom = np.linalg.norm((R @ P.T).T + t - Q, axis=1)
        snapped_mse = float((w * (per_atom**2)).sum() / w.sum())
        return R, t, snapped_mse, per_atom, False

    return R, t, mse, per_atom, reflected

def is_molecule_globally_linear(atoms: "ase.Atoms", tol: float = 0.05) -> bool:
    """
    Check whether an entire molecule is (approximately) collinear.

    The positions are centered on their mean and their singular values
    computed. The molecule is linear when the second-largest singular value
    is below tol times the largest one. Molecules with fewer than 3 atoms, or
    whose atoms all coincide (zero spread), are reported as linear.

    Args:
        atoms: object exposing `positions` (N, 3) and `len()`.
        tol: relative threshold on the second singular value.
    """
    if len(atoms) < 3:
        return True  # Considered linear/degenerate for safety

    pos = atoms.positions - atoms.positions.mean(axis=0)
    # Singular values come back sorted in descending order.
    S = np.linalg.svd(pos, compute_uv=False)
    if S[0] <= 0.0:
        # All atoms coincide: the relative test below would compare 0 < 0 and
        # wrongly report a non-linear molecule.
        return True
    return bool(S[1] < tol * S[0])


class FragmentClass:
    """
    One fragment class: an exemplar geometry plus its index permutations.

    Attributes:
      smiles              : fragment SMILES used as the class label
      exemplar_pos        : numpy (N,3) exemplar coordinates
      exemplar_z          : numpy (N,) exemplar atomic numbers
      exemplar_edge_index : copy of the edge index array passed in
      exemplar_edge_attr  : copy of the edge attribute array passed in
      perms               : list of int numpy arrays of shape (N,), each a
                            permutation of the exemplar indices 0..N-1
    """
    def __init__(
        self,
        smiles: str,
        exemplar_pos: np.ndarray,
        exemplar_z: np.ndarray,
        exemplar_edge_index: np.ndarray,
        exemplar_edge_attr: np.ndarray,
        perms: List[List[int]]):
        self.smiles = smiles
        # Store private copies so later edits to the caller's arrays cannot
        # change the exemplar.
        self.exemplar_pos = exemplar_pos.copy()
        self.exemplar_z = exemplar_z.copy()
        self.exemplar_edge_index = exemplar_edge_index.copy()
        self.exemplar_edge_attr = exemplar_edge_attr.copy()
        # Permutations are expected to be expressed in the exemplar's own ordering.
        self.perms = [np.array(entry, dtype=int) for entry in perms]

        # Consistency checks: coordinates vs atomic numbers, then every permutation.
        n_pos = self.exemplar_pos.shape[0]
        n_z = self.exemplar_z.shape[0]
        if n_pos != n_z:
            raise ValueError(f"FragmentClass exemplar size mismatch: pos rows {n_pos} vs z length {n_z}")
        offending = next((p for p in self.perms if p.shape[0] != n_pos), None)
        if offending is not None:
            raise ValueError(f"FragmentClass permutation length mismatch: expected {n_pos}, got {offending.shape[0]}")

class FragmentClassManager:
    """
    Library of fragment classes grouped by SMILES.

    A fragment instance is either matched to an existing class (same SMILES,
    identical atomic numbers, and RMSD / maximum deviation within the
    configured thresholds under one of the class permutations) or becomes the
    exemplar of a new class. While `frozen` is True no new class is created
    and an unmatched fragment raises ValueError.

    Atoms with z == 0 are treated as stubs, z == 1 as hydrogens, everything
    else as heavy atoms; each group has its own weight.
    """
    def __init__(self,
                 rmsd_thresh: float = 1.0,
                 max_dev_thresh: float = 1.0,
                 heavy_weight: float = 1.0,
                 h_weight: float = 0.1,
                 stub_weight: float = 0.1):
        self.classes: List[FragmentClass] = []
        # smiles -> indices into self.classes
        self.by_smiles: Dict[str, List[int]] = {}
        self.rmsd_thresh = float(rmsd_thresh)
        self.max_dev_thresh = float(max_dev_thresh)
        self.heavy_weight = float(heavy_weight)
        self.h_weight = float(h_weight)
        self.stub_weight = float(stub_weight)
        self.frozen = False

    def _atom_weights_from_z(self, z: np.ndarray) -> np.ndarray:
        """Weights for Orientation/Rotation (Includes Stubs)"""
        w = np.ones_like(z, dtype=float)
        w[z == 1] = self.h_weight          # Hydrogens
        w[(z != 1) & (z != 0)] = self.heavy_weight # Heavy Atoms
        w[z == 0] = self.stub_weight       # Stubs (Ghost Heavy Atoms)
        return w

    def _weights_for_centering(self, z: np.ndarray) -> np.ndarray:
        """Weights for Translation/Centroid (Real Atoms Only)"""
        w = np.zeros_like(z, dtype=float)
        w[z == 1] = self.h_weight
        w[(z != 1) & (z != 0)] = self.heavy_weight
        w[z == 0] = 0.0  # Stubs excluded from physical center of mass
        return w

    def _select_orientation_indices(self, z: np.ndarray) -> np.ndarray:
        """
        Use ALL indices; let weights suppress noise. Stubs are included for orientation.
        """
        return np.arange(len(z), dtype=int)

    def _compute_canonical_rotation(self, cls: "FragmentClass", pos: np.ndarray, z: np.ndarray, perm: np.ndarray, weights_center: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute the rotation and translation mapping the exemplar onto an instance.

        `perm` pairs instance atom k with exemplar atom perm[k]. The instance
        coordinates (and their centering weights) are scattered into exemplar
        order, then:
          - the rotation comes from robust_kabsch_symmetry using the rotation
            weights of the exemplar (stubs included),
          - the translation is cQ - R @ cP with centroids weighted by
            `weights_center` (real atoms only), falling back to plain means
            when those weights sum to ~0.

        `z` is accepted for signature symmetry but is not used.

        Returns:
            (R_canon, t_canon)
        """
        # Scatter into exemplar order: slot perm[k] receives instance atom k.
        Q_reordered = np.zeros_like(pos)
        w_center_reordered = np.zeros_like(weights_center)
        Q_reordered[perm] = pos
        w_center_reordered[perm] = weights_center

        weights_rotation = self._atom_weights_from_z(cls.exemplar_z)

        # Robust Kabsch / orientation using rotation weights (all points matter)
        R_canon, _, _, _, _ = robust_kabsch_symmetry(cls.exemplar_pos, Q_reordered, weights=weights_rotation)

        # Translation: compute weighted centroids using weights_center (real atoms only)
        if np.sum(w_center_reordered) > 1e-12:
            cQ = np.average(Q_reordered, axis=0, weights=w_center_reordered)
            cP = np.average(cls.exemplar_pos, axis=0, weights=w_center_reordered)
        else:
            cQ = np.mean(Q_reordered, axis=0)
            cP = np.mean(cls.exemplar_pos, axis=0)

        t_canon = cQ - (R_canon @ cP)
        return R_canon, t_canon

    def _compute_inst_to_canonical(self, pos: np.ndarray, z: np.ndarray, weights_center: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute a deterministic canonical exemplar frame for a fragment instance.

        CRITICAL:
         - The positions are centered so that the weighted centroid (weights_center,
           i.e. real-atoms-only centroid) is at the origin. This keeps the translation
           computed later (cQ - R @ cP) consistent.
         - Orientation (rotation) is found via weighted PCA using rotation-weights
           (which include stubs). Axis signs are chosen so that the sum of cubed
           coordinates along each axis is not negative, and the frame is forced
           to be right-handed.

        Returns:
            pos_centered : pos - centroid (so positional information is relative to that centroid)
            R_inst_to_canonical : rotation matrix such that exemplar_pos = pos_centered @ R_inst_to_canonical

        Raises:
            ValueError: for an empty fragment, fewer than 3 points, or
                non-positive total rotation weight.
        """
        if pos.shape[0] == 0:
            raise ValueError("Empty fragment")

        # Compute center using weights_center (real atoms only) so exemplar has physical centroid at origin.
        wsum = float(np.sum(weights_center))
        if wsum < 1e-12:
            center = np.mean(pos, axis=0)
        else:
            center = np.average(pos, axis=0, weights=weights_center)

        pos_centered = pos - center

        # Rotation: use rotation-weights (stubs included) to define principal axes.
        weights_rotation = self._atom_weights_from_z(z)
        # Callers are expected to have filtered such fragments out earlier.
        if pos.shape[0] < 3 or np.sum(weights_rotation) <= 0:
            raise ValueError("Insufficient points/weights for PCA frame definition")

        # Weighted PCA on pos_centered. The rotation-weighted mean is removed
        # first because it need not coincide with the centering-weighted one.
        wrot_sum = np.sum(weights_rotation)
        rot_mean = np.average(pos_centered, axis=0, weights=weights_rotation)
        Pm = pos_centered - rot_mean

        cov = (Pm.T @ (Pm * weights_rotation[:, None])) / (wrot_sum + 1e-20)
        _, evecs = np.linalg.eigh(cov)
        # eigh returns eigenvalues in ascending order: reverse the columns so
        # the largest-variance axis comes first. Rows of R_pca are the axes.
        R_pca = np.column_stack([evecs[:, 2], evecs[:, 1], evecs[:, 0]]).T

        # Fix the sign of each axis from the skew of the projected coordinates,
        # then enforce right-handedness.
        pos_rot = pos_centered @ R_pca.T
        for i in range(3):
            skew = np.sum(pos_rot[:, i] ** 3)
            if skew < -1e-9:
                R_pca[i, :] *= -1.0
        if np.linalg.det(R_pca) < 0:
            R_pca[2, :] *= -1.0

        # Returned so that exemplar_pos = pos_centered @ R_inst_to_canonical
        R_inst_to_canonical = R_pca.T
        return pos_centered, R_inst_to_canonical

    def _matches_fragment(self, cls: "FragmentClass", pos: np.ndarray, z: np.ndarray) -> Optional[Tuple[np.ndarray, np.ndarray, float, np.ndarray]]:
        """
        Try to match `pos` (instance) to class exemplar `cls`.

        Every permutation of the class is aligned onto the instance with the
        rotation weights (stubs included). Permutations whose alignment needed
        a reflection are rejected unless the exemplar is linear or planar. The
        permutation with the lowest rotation-weighted MSE is selected and
        accepted only if its real-atom RMSD and its maximum deviation (over
        atoms with z > 0 when any exist) are within the thresholds.

        Returns:
            (R_canon, t_canon, real_rmsd, per_atom_errors) on success,
            None when sizes / atomic numbers differ or no permutation passes.
        """
        if pos.shape[0] != cls.exemplar_pos.shape[0]:
            return None
        if not np.array_equal(z, cls.exemplar_z):
            return None

        weights_align = self._atom_weights_from_z(z)
        weights_center = self._weights_for_centering(z)

        # Everything below depends only on the instance or the exemplar, not
        # on the permutation, so it is computed once instead of per perm.
        use_center_weights = np.sum(weights_center) > 1e-12
        if use_center_weights:
            cQ = np.average(pos, axis=0, weights=weights_center)
        else:
            cQ = np.mean(pos, axis=0)
        allow_reflection = is_geometrically_degenerate(cls.exemplar_pos) != 'nondegenerate'
        real_atoms = z > 0
        restrict_to_real = np.any(real_atoms)

        best_selection_score = float('inf')
        best_result = None

        for perm in cls.perms:
            P_ex_perm = cls.exemplar_pos[perm]

            # Align using rotation weights (stubs included)
            R_sub, _, _, _, reflected = robust_kabsch_symmetry(P_ex_perm, pos, weights=weights_align)

            # A mirror-image match is only meaningful for linear/planar exemplars.
            if reflected and not allow_reflection:
                continue

            # Translation from real-atom centroids
            if use_center_weights:
                cP = np.average(P_ex_perm, axis=0, weights=weights_center)
            else:
                cP = np.mean(P_ex_perm, axis=0)
            t_sub = cQ - (R_sub @ cP)

            P_pred_full = (R_sub @ P_ex_perm.T).T + t_sub
            per_atom_diff_sq = np.sum((P_pred_full - pos)**2, axis=1)

            selection_mse = np.average(per_atom_diff_sq, weights=weights_align)
            if selection_mse < best_selection_score:
                best_selection_score = selection_mse
                per_atom_full = np.sqrt(per_atom_diff_sq)
                if use_center_weights:
                    real_mse = np.average(per_atom_diff_sq, weights=weights_center)
                else:
                    real_mse = np.mean(per_atom_diff_sq)
                real_rmsd = float(np.sqrt(real_mse))
                if restrict_to_real:
                    max_dev = float(np.max(per_atom_full[real_atoms]))
                else:
                    max_dev = float(np.max(per_atom_full))
                best_result = (perm, real_rmsd, per_atom_full, max_dev)

        if best_result is not None:
            perm, real_rmsd, per_atom_full, max_dev = best_result
            if real_rmsd <= self.rmsd_thresh and max_dev <= self.max_dev_thresh:
                R_canon, t_canon = self._compute_canonical_rotation(cls, pos, z, perm, weights_center)
                return R_canon, t_canon, real_rmsd, per_atom_full

        return None

    def find_or_create_class(
        self,
        smiles: str,
        pos: np.ndarray,
        z: np.ndarray,
        edge_index: np.ndarray,
        edge_attr: np.ndarray,
        perms_from_canon: List[List[int]]) -> Tuple[int, "FragmentClass", Optional[Tuple[np.ndarray, np.ndarray, float, np.ndarray]]]:
        """
        Find the class matching this fragment, or create a new one.

        Existing classes registered under `smiles` are tried in insertion
        order; the first geometric match wins. Otherwise a new class is created
        whose exemplar is the instance expressed in its canonical frame, so the
        weighted centroid (weights_center) is at the origin. An empty
        `perms_from_canon` is replaced by the identity permutation.

        Returns:
            (class_index, class, match) where match is the tuple returned by
            _matches_fragment for an existing class, or None for a new class.

        Raises:
            ValueError: if the library is frozen and nothing matched.
        """
        if not perms_from_canon:
            perms = [list(range(pos.shape[0]))]
        else:
            perms = perms_from_canon

        existing_indices = self.by_smiles.get(smiles, [])

        for cls_idx in existing_indices:
            cls = self.classes[cls_idx]
            match = self._matches_fragment(cls, pos, z)
            if match is not None:
                return cls_idx, cls, match

        if self.frozen:
            # No existing class matched this fragment within tolerance
            raise ValueError(f"FROZEN_FRAGMENT_LIBRARY: Fragment '{smiles}' not found in frozen library.")

        # Create new class: compute exemplar in canonical frame
        weights_center = self._weights_for_centering(z)
        pos_centered, R_inst_to_canonical = self._compute_inst_to_canonical(pos, z, weights_center)
        exemplar_pos = (pos_centered @ R_inst_to_canonical).astype(np.float64)

        new_cls = FragmentClass(smiles=smiles, exemplar_pos=exemplar_pos, exemplar_z=z,
                                exemplar_edge_index=edge_index, exemplar_edge_attr=edge_attr,
                                perms=perms)
        new_idx = len(self.classes)
        self.classes.append(new_cls)
        self.by_smiles.setdefault(smiles, []).append(new_idx)

        return new_idx, new_cls, None

# -------------------------
# FragmentData container
# -------------------------

class FragmentData:
    """Per-fragment record: geometry and label, plus fields filled in later."""

    def __init__(self, pos: torch.Tensor, z: torch.Tensor, smiles: str):
        """
        pos: torch (N,3) canonical-order positions for real atoms (float32)
        z:   torch (N,) atomic numbers (int64)
        smiles: canonical fragment smiles string (real-only, no '*' dummies)
        """
        self.pos, self.z, self.smiles = pos, z, smiles

        # Not known at construction time; start unset.
        self.rot_gt: Optional[torch.Tensor] = None
        self.trans_gt: Optional[torch.Tensor] = None
        self.orbit_rot: Optional[torch.Tensor] = None
        self.class_id: Optional[int] = None

        # A fresh list per instance.
        self.original_indices: List[int] = list()
        
def _frag_points_nonlinear(points, rel_tol: float = 1e-3) -> bool:
    """Return True when >= 3 points are given and they are not collinear.

    Collinearity is judged from the singular values S of the centered
    coordinates: the set counts as non-linear when S[1] > rel_tol * S[0].
    """
    pts = np.asarray(points, dtype=np.float64)
    if pts.shape[0] < 3:
        return False
    S = np.linalg.svd(pts - pts.mean(axis=0), compute_uv=False)
    return bool(S[1] > rel_tol * S[0])


def decompose_to_rigid_fragments(data: Any,
                                 class_manager: Optional[Any] = None,
                                 rare_smiles_check = None
                                 ) -> List[Any]:
    """
    Cut one molecule into fragments and register each with ``class_manager``.

    Pipeline:
      1. Parse ``data.smiles`` with RDKit, add explicit hydrogens, kekulize,
         and attach ``data.pos`` as a conformer (row i -> atom i).
      2. Ask ``get_minimal_stub_cuts`` which bonds to cut and split the
         molecule on them.
      3. For every fragment whose real atoms are fewer than three or
         collinear, add "stub" points (z == 0): unit-length directions to
         bonded neighbours outside the fragment and, if still collinear, one
         more distant atom found by breadth-first search.
      4. Build a label molecule (real atoms + stubs), order atoms by RDKit
         canonical rank (real atoms first, stubs last), extract bonds and
         element-preserving graph automorphisms.
      5. Look up / create the fragment class via
         ``class_manager.find_or_create_class`` and wrap the result in a
         ``FragmentData``.

    Args:
        data: object exposing ``smiles`` (str) and ``pos`` (tensor supporting
            ``.detach().cpu().numpy()``); positions must follow the atom
            order of ``Chem.AddHs(Chem.MolFromSmiles(smiles))``.
        class_manager: fragment class registry; required despite the default.
        rare_smiles_check: forwarded unchanged to ``get_minimal_stub_cuts``.

    Returns:
        One ``FragmentData`` per non-empty fragment, with ``class_id`` and
        ``original_indices`` set, and ``rot_gt`` / ``trans_gt`` set when the
        class manager returned an alignment.

    Raises:
        ValueError: ``class_manager`` is None. Exceptions raised by
            ``class_manager.find_or_create_class`` propagate unchanged.
        RuntimeError: SMILES cannot be parsed, atom count does not match the
            positions, no non-collinear anchor exists for a fragment, or a
            fragment ends up with fewer than three points.
    """
    if class_manager is None:
        raise ValueError("decompose_to_rigid_fragments requires a FragmentClassManager")

    smi_original = data.smiles

    # 1. Parse Molecule
    src_mol = Chem.MolFromSmiles(smi_original)
    if src_mol is None:
        raise RuntimeError(f"Parse failed: {smi_original}")

    src_mol = Chem.AddHs(src_mol)

    pos_np = data.pos.detach().cpu().numpy().astype(np.float64)

    num_src_atoms = src_mol.GetNumAtoms()
    if pos_np.shape[0] != num_src_atoms:
        # The count is taken after AddHs, so it includes explicit hydrogens.
        raise RuntimeError(
            f"Atom count mismatch: SMILES atoms incl. explicit H ({num_src_atoms}) "
            f"vs Input Positions ({pos_np.shape[0]}).")

    # Kekulize to expose aromatic bonds as Single/Double
    Chem.Kekulize(src_mol, clearAromaticFlags=True)

    # Setup Conformer; tag each atom with its index so it survives fragmentation
    conf = Chem.Conformer(num_src_atoms)
    for i in range(num_src_atoms):
        conf.SetAtomPosition(i, pos_np[i])
        src_mol.GetAtomWithIdx(i).SetIntProp('orig_idx', int(i))
    src_mol.AddConformer(conf)

    parent_ranks = list(Chem.CanonicalRankAtoms(src_mol, breakTies=True))

    # 3. Identify Cuts
    bonds_to_cut = get_minimal_stub_cuts(src_mol, rare_smiles_check)

    # 4. Fragment
    if not bonds_to_cut:
        frag_atom_indices = [set(range(num_src_atoms))]
    else:
        frag_mol = Chem.FragmentOnBonds(src_mol, bonds_to_cut, addDummies=False)
        rdkit_frags = Chem.GetMolFrags(frag_mol, asMols=True, sanitizeFrags=False)
        frag_atom_indices = [
            {a.GetIntProp('orig_idx') for a in frag.GetAtoms() if a.HasProp('orig_idx')}
            for frag in rdkit_frags
        ]

    output_fragments: List[Any] = []

    # 6. Build Rigid Fragments
    for f_idx, frag_indices in enumerate(frag_atom_indices):
        if not frag_indices:
            continue

        real_atoms = []
        for orig_idx in sorted(frag_indices):
            pt = conf.GetAtomPosition(orig_idx)
            real_atoms.append({
                'z': src_mol.GetAtomWithIdx(orig_idx).GetAtomicNum(),
                'pos': np.array([pt.x, pt.y, pt.z]),
                'orig_idx': orig_idx,
                'is_stub': False
            })

        # --- A. Check Geometric Stability ---
        # Needs >= 3 real atoms that are not collinear.
        is_real_stable = _frag_points_nonlinear([x['pos'] for x in real_atoms])

        selected_stubs = []

        # --- B. Add Stubs ONLY if Unstable ---
        if not is_real_stable:
            immediate_stubs = []
            for item in real_atoms:
                a_obj = src_mol.GetAtomWithIdx(item['orig_idx'])
                atom_pos = item['pos']

                for b in a_obj.GetBonds():
                    nbr_idx = b.GetOtherAtom(a_obj).GetIdx()
                    if nbr_idx in frag_indices:
                        continue

                    # Neighbor is outside fragment -> stub at unit distance
                    # along the bond direction.
                    nbr_pos = np.array(conf.GetAtomPosition(nbr_idx))
                    vec = nbr_pos - atom_pos
                    norm = np.linalg.norm(vec)
                    if norm < 1e-6:
                        norm = 1.0  # coincident atoms: avoid division by ~0
                    immediate_stubs.append({
                        'rank_key': parent_ranks[nbr_idx], 'z': 0,
                        'pos': atom_pos + (vec / norm),
                        'orig_idx': nbr_idx, 'is_stub': True,
                        'attached_to': item['orig_idx']
                    })

            immediate_stubs.sort(key=lambda x: x['rank_key'])

            # Check stability again with immediate stubs
            base_points = [x['pos'] for x in real_atoms] + [x['pos'] for x in immediate_stubs]
            is_real_stable = _frag_points_nonlinear(base_points)

            distant_stubs = []
            if not is_real_stable:
                # Deep search (BFS outward from the stubs) for the first atom
                # whose true position breaks the collinearity.
                visited = set(frag_indices)
                visited.update(s['orig_idx'] for s in immediate_stubs)

                queue = [s['orig_idx'] for s in immediate_stubs]
                idx_ptr = 0
                found_anchor = None

                while idx_ptr < len(queue):
                    curr_idx = queue[idx_ptr]
                    idx_ptr += 1
                    curr_pos = np.array(conf.GetAtomPosition(curr_idx))

                    if _frag_points_nonlinear(base_points + [curr_pos]):
                        found_anchor = curr_idx
                        break

                    curr_atom = src_mol.GetAtomWithIdx(curr_idx)
                    for b in curr_atom.GetBonds():
                        far_idx = b.GetOtherAtom(curr_atom).GetIdx()
                        if far_idx not in visited:
                            visited.add(far_idx)
                            queue.append(far_idx)

                if found_anchor is None:
                    raise RuntimeError(f"Fragment {f_idx} is globally linear in {smi_original}")

                # The distant anchor keeps its real position (not normalised)
                # and is attached to an arbitrary atom of the fragment.
                distant_stubs.append({
                    'rank_key': 999999, 'z': 0,
                    'pos': np.array(conf.GetAtomPosition(found_anchor)),
                    'orig_idx': found_anchor, 'is_stub': True,
                    'attached_to': next(iter(frag_indices))
                })

            selected_stubs = immediate_stubs + distant_stubs

        if len(real_atoms) + len(selected_stubs) < 3:
             raise RuntimeError(f"Fragment has < 3 points total. Mol: {smi_original}")

        # --- C. Build Label Molecule ---
        # node_map: real atoms are keyed by orig_idx, stubs by id(item)
        # (several stubs may share the same outside neighbour index).
        label_mol = Chem.RWMol()
        node_map = {}

        for item in real_atoms:
            a_obj = Chem.Atom(item['z'])
            a_obj.SetNoImplicit(True)
            p_a = src_mol.GetAtomWithIdx(item['orig_idx'])
            # Copy properties (Stereo is valuable even for heavy-only).
            # Best-effort: a failure here must not abort the fragment.
            try:
                a_obj.SetFormalCharge(p_a.GetFormalCharge())
            except Exception:
                pass
            try:
                a_obj.SetChiralTag(p_a.GetChiralTag())
            except Exception:
                pass
            a_obj.SetIsAromatic(False)
            node_map[item['orig_idx']] = label_mol.AddAtom(a_obj)

        for item in selected_stubs:
            a_obj = Chem.Atom(0)
            a_obj.SetNoImplicit(True)
            node_map[id(item)] = label_mol.AddAtom(a_obj)

        real_ids = [r['orig_idx'] for r in real_atoms]
        for i in range(len(real_ids)):
            for j in range(i + 1, len(real_ids)):
                u, v = real_ids[i], real_ids[j]
                b = src_mol.GetBondBetweenAtoms(u, v)
                if b:
                    label_mol.AddBond(node_map[u], node_map[v], b.GetBondType())

        for s in selected_stubs:
            label_mol.AddBond(node_map[s['attached_to']], node_map[id(s)], Chem.BondType.SINGLE)

        final_label_mol = label_mol.GetMol()
        try:
            ranks = list(Chem.CanonicalRankAtoms(final_label_mol, breakTies=True))
        except Exception:
            # Fall back to insertion order when RDKit cannot rank the graph.
            ranks = list(range(final_label_mol.GetNumAtoms()))

        for item in real_atoms:
            item['rank'] = ranks[node_map[item['orig_idx']]]
        for item in selected_stubs:
            # Stubs sort by (rank of the atom they hang on, own rank).
            item['rank'] = (ranks[node_map[item['attached_to']]], ranks[node_map[id(item)]])

        real_atoms.sort(key=lambda x: x['rank'])
        selected_stubs.sort(key=lambda x: x['rank'])

        ordered_data = real_atoms + selected_stubs
        ordered_pos = np.array([x['pos'] for x in ordered_data], dtype=np.float32)
        ordered_z = np.array([x['z'] for x in ordered_data], dtype=np.int64)
        ordered_orig = [x['orig_idx'] for x in real_atoms]

        # Check canonical layout: every entry before the first stub is a real atom
        stub_indices = np.where(ordered_z == 0)[0]
        if stub_indices.size > 0:
            if not np.all(ordered_z[:stub_indices[0]] > 0):
                raise AssertionError("Canonical ordering violated")

        # --- Extract Connectivity and Bond Types ---
        # Label-molecule node index of each entry, in canonical order.
        label_nodes_in_order = [
            node_map[d['orig_idx'] if not d['is_stub'] else id(d)] for d in ordered_data
        ]
        label_idx_to_ordered_idx = {
            lbl_idx: ord_idx for ord_idx, lbl_idx in enumerate(label_nodes_in_order)
        }

        edge_list = []
        attr_list = []

        for b in final_label_mol.GetBonds():
            u_lbl = b.GetBeginAtomIdx()
            v_lbl = b.GetEndAtomIdx()
            if u_lbl in label_idx_to_ordered_idx and v_lbl in label_idx_to_ordered_idx:
                u_ord = label_idx_to_ordered_idx[u_lbl]
                v_ord = label_idx_to_ordered_idx[v_lbl]

                # Convert RDKit types to Integer Order
                bt = b.GetBondType()
                if bt == Chem.BondType.SINGLE: order = 1
                elif bt == Chem.BondType.DOUBLE: order = 2
                elif bt == Chem.BondType.TRIPLE: order = 3
                else: order = 1 # Fallback

                # Add Undirected (both directions)
                edge_list.append([u_ord, v_ord])
                edge_list.append([v_ord, u_ord])
                attr_list.append(order)
                attr_list.append(order)

        if edge_list:
            exemplar_edge_index = np.array(edge_list, dtype=np.int64).T
            exemplar_edge_attr = np.array(attr_list, dtype=np.int64)
        else:
            exemplar_edge_index = np.empty((2, 0), dtype=np.int64)
            exemplar_edge_attr = np.empty((0,), dtype=np.int64)

        # Graph Automorphisms, re-expressed in canonical (ordered) indices
        mol_auto = Chem.Mol(final_label_mol)
        graph_z = np.array([a.GetAtomicNum() for a in final_label_mol.GetAtoms()])
        raw_perms = _get_graph_automorphisms_preserve_element(mol_auto, exemplar_z=graph_z)
        real_perms = []
        for p in raw_perms:
            try:
                real_perms.append([label_idx_to_ordered_idx[p[n]] for n in label_nodes_in_order])
            except Exception:
                # Skip permutations that cannot be mapped onto the ordering.
                continue
        real_perms = _filter_perms_preserve_atomic_numbers(real_perms, ordered_z)
        real_perms = [list(x) for x in sorted({tuple(x) for x in real_perms})]
        if not real_perms:
            real_perms = [list(range(len(ordered_pos)))]

        # --- Create Class ---
        try:
            lbl = Chem.MolToSmiles(final_label_mol, isomericSmiles=True, canonical=True, allHsExplicit=True)
        except Exception:
            lbl = "Error"

        class_idx, _, align = class_manager.find_or_create_class(
            lbl, ordered_pos, ordered_z,
            exemplar_edge_index, exemplar_edge_attr,
            perms_from_canon=real_perms
        )

        frag_d = FragmentData(pos=torch.tensor(ordered_pos, dtype=torch.float32),
                              z=torch.tensor(ordered_z, dtype=torch.int64),
                              smiles=lbl)
        frag_d.class_id = class_idx
        frag_d.original_indices = ordered_orig

        # An alignment is only returned when an existing class matched.
        if align is not None:
            R, t, _, _ = align
            frag_d.rot_gt = torch.tensor(R.astype(np.float32))
            frag_d.trans_gt = torch.tensor(t.astype(np.float32))

        output_fragments.append(frag_d)

    return output_fragments

def compute_se3_targets(fragments: List['FragmentData'],
                        class_manager: 'FragmentClassManager') -> List['FragmentData']:
    """
    For each fragment (already assigned class_id), compute frag.rot_gt and frag.trans_gt.

    For every fragment that does not yet carry both targets, each stored
    permutation of the class exemplar is aligned onto the fragment with
    ``robust_kabsch_symmetry`` (using only the orientation-selected atoms) and
    the permutation with the lowest centering-weighted RMSD over all atoms
    wins. Reflected solutions are rejected unless the exemplar is classified
    as linear or planar. The winning rotation is then re-expressed relative to
    the un-permuted (canonical) exemplar, and the translation is recomputed
    from the weighted centroids.

    Args:
        fragments: fragments to update in place; those with both targets
            already set are skipped.
        class_manager: registry providing ``classes`` and the weighting /
            selection helpers. The helpers are assumed to be deterministic
            functions of ``z`` and are therefore evaluated once per fragment.

    Returns:
        The same ``fragments`` list, for convenience.

    Raises:
        ValueError: missing manager, missing/invalid ``class_id``, atom-count
            mismatch with the exemplar, empty orientation selection, or no
            permutation yielding an acceptable alignment.
    """
    if class_manager is None:
        raise ValueError("compute_se3_targets requires a FragmentClassManager")

    def _weighted_centroid(pts, w):
        # Fall back to the plain mean when the weights carry no mass.
        if np.sum(w) > 1e-12:
            return np.average(pts, axis=0, weights=w)
        return np.mean(pts, axis=0)

    for i, frag in enumerate(fragments):
        # skip already-computed
        if frag.rot_gt is not None and frag.trans_gt is not None:
            continue
        if frag.class_id is None:
            raise ValueError(f"Fragment {i} ({frag.smiles}) has no class_id")

        class_idx = frag.class_id
        if class_idx < 0 or class_idx >= len(class_manager.classes):
            raise ValueError(f"Invalid class index {class_idx} for fragment {i}")

        cls = class_manager.classes[class_idx]
        P_tmpl = cls.exemplar_pos.astype(np.float64)  # canonical exemplar
        z = cls.exemplar_z.astype(np.int64)

        if P_tmpl.shape[0] != frag.pos.shape[0]:
            raise ValueError(f"Atom count mismatch between exemplar (N={P_tmpl.shape[0]}) and fragment {i} (N={frag.pos.shape[0]})")

        P_inst = frag.pos.detach().cpu().numpy().astype(np.float64)

        weights_center = class_manager._weights_for_centering(z)
        sel = class_manager._select_orientation_indices(z)
        if sel.size == 0:
            raise ValueError("No orientation indices found for fragment")

        # Mirror images are only acceptable for linear / planar exemplars.
        allow_reflection_globally = (is_geometrically_degenerate(P_tmpl) != 'nondegenerate')

        # Rotation weights restricted to the selected atoms; drop them
        # entirely (unweighted fit) when absent or without positive mass.
        # The same guarded weights are used for both Kabsch calls below.
        weights_rotation = class_manager._atom_weights_from_z(z)
        w_sub = weights_rotation[sel] if weights_rotation is not None else None
        if w_sub is not None and np.sum(w_sub) <= 0:
            w_sub = None

        w_center_sum = np.sum(weights_center)
        Q_sub = P_inst[sel]

        best_rmsd = float('inf')
        best = None  # (rotation, permutation) of the best alignment so far

        for perm in cls.perms:
            P_tmpl_perm = P_tmpl[perm]

            R_np, t_np, _, _, reflected = robust_kabsch_symmetry(P_tmpl_perm[sel], Q_sub, weights=w_sub)

            # Strict 3D Reflection Check
            if reflected and not allow_reflection_globally:
                continue

            # Score on ALL atoms, not just the ones used for the fit.
            P_pred_all = (R_np @ P_tmpl_perm.T).T + t_np
            per_atom_full = np.linalg.norm(P_pred_all - P_inst, axis=1)

            if w_center_sum <= 0:
                weighted_mse = float(np.mean(per_atom_full**2))
            else:
                weighted_mse = float(np.sum(weights_center * (per_atom_full**2)) / w_center_sum)
            weighted_rmsd = float(np.sqrt(weighted_mse))

            if weighted_rmsd < best_rmsd:
                best_rmsd = weighted_rmsd
                best = (R_np, perm)

        if best is None:
            raise ValueError(f"No valid alignment found for fragment {i} of class {class_idx}")

        best_R_np, best_perm = best

        # Symmetry S that maps the winning permuted exemplar back onto the
        # canonical one: S @ P_perm.T ~= P_canon.T
        P_tmpl_perm = P_tmpl[best_perm]
        S_corr, _, _, _, _ = robust_kabsch_symmetry(P_tmpl_perm[sel], P_tmpl[sel], weights=w_sub)

        # best_R_np maps Permuted -> Global. We want R_canon with
        #   R_canon @ P_canon.T = Global = best_R_np @ P_perm.T
        # and P_canon.T = S @ P_perm.T, hence R_canon = best_R_np @ S.T
        R_target_canon = best_R_np @ S_corr.T

        # Translation consistent with the corrected rotation:
        #   t = Centroid_Global - R_canon @ Centroid_Canon
        cP = _weighted_centroid(P_tmpl, weights_center)
        cQ = _weighted_centroid(P_inst, weights_center)
        t_target_canon = cQ - (R_target_canon @ cP)

        # Save corrected targets on the fragment's own device
        device = frag.pos.device if isinstance(frag.pos, torch.Tensor) else torch.device('cpu')
        R_torch, t_torch = numpy_to_torch_R_t(R_target_canon, t_target_canon, device=device)
        frag.rot_gt = R_torch
        frag.trans_gt = t_torch

    return fragments


def generate_orbit_from_exemplar_perms(class_manager: FragmentClassManager,
                                       angle_tol: float = 1e-3,
                                       ) -> Dict[int, List[np.ndarray]]:
    """
    Derive candidate symmetry matrices for every class from its stored perms.

    For each class, every distinct element-preserving permutation of the
    exemplar is aligned back onto the exemplar itself (fit on the
    orientation-selected atoms, scored on all atoms with the rotation
    weights). Permutations whose weighted MSE is below 0.01 contribute their
    fitted 3x3 matrix. Matrices closer than ``angle_tol`` (as measured by
    ``rotation_angle_distance``) to one already kept are dropped; if nothing
    survives, the identity is used.

    The result is stored on each class as ``candidate_symmetries_exemplar``
    and also returned keyed by class index.
    """
    mse_cutoff = 0.01
    orbits: Dict[int, List[np.ndarray]] = {}

    for class_id, frag_cls in enumerate(class_manager.classes):
        ref_pos = frag_cls.exemplar_pos.astype(np.float64)
        atomic_nums = frag_cls.exemplar_z.astype(np.int64)
        rot_w = class_manager._atom_weights_from_z(atomic_nums)

        # Distinct, element-preserving permutations in lexicographic order.
        element_safe = _filter_perms_preserve_atomic_numbers(
            [list(p) for p in frag_cls.perms], atomic_nums
        )
        ordered_perms = sorted({tuple(p) for p in element_safe})

        orient_idx = class_manager._select_orientation_indices(atomic_nums)
        n_atoms = ref_pos.shape[0]

        accepted: List[np.ndarray] = []
        for perm_tuple in ordered_perms:
            perm_arr = np.asarray(perm_tuple, dtype=int)
            if perm_arr.shape[0] != n_atoms:
                continue

            permuted = ref_pos[perm_arr]
            rot, shift, _, _, _ = robust_kabsch_symmetry(
                permuted[orient_idx], ref_pos[orient_idx], weights=rot_w[orient_idx]
            )

            # Residual of the fitted transform over every atom.
            residual = np.linalg.norm((rot @ permuted.T).T + shift - ref_pos, axis=1)
            mse = float(np.sum(rot_w * (residual**2)) / np.sum(rot_w))
            if mse < mse_cutoff:
                accepted.append(rot)

        # Collapse matrices that are within angle_tol of one already kept.
        unique_rots: List[np.ndarray] = []
        for rot in accepted:
            is_duplicate = any(
                rotation_angle_distance(kept, rot) <= angle_tol for kept in unique_rots
            )
            if not is_duplicate:
                unique_rots.append(rot.copy())

        if not unique_rots:
            unique_rots.append(np.eye(3, dtype=float))

        frag_cls.candidate_symmetries_exemplar = unique_rots
        orbits[class_id] = unique_rots

    return orbits

def prepare_and_save_dataset(train_data, val_data, class_manager, output_path=".", rare_smiles_check=None):
    """
    End-to-end driver: fragment both splits, compute targets, save everything.

    Stages:
      1. Decompose every molecule of 'train' then 'val'. The class manager is
         unfrozen for 'train' and frozen for 'val'. Molecules with fewer than
         three heavy atoms, with collinear heavy atoms, or (when frozen) with
         a fragment missing from the library are skipped and counted.
      2. Compute ground-truth rotation/translation for all fragments.
      3. Generate per-class symmetry candidates.
      4. Save the fragment library to ``{output_path}/fragment_library.pt``.
      5. Save one ``{output_path}/{split}_processed.pt`` per split.
    """
    # --- A. Build Instances ---
    print("1. Decomposing Molecules...")

    frags_by_split = {'train': [], 'val': []}
    # molecule index -> list of its fragments, per split
    mol_fragments_map = {'train': {}, 'val': {}}

    skipped_small = skipped_linear = skipped_unknown = 0

    for split_name, dataset in (('train', train_data), ('val', val_data)):
        print(f"  Processing {split_name}...")
        # Only the training split may add new classes to the vocabulary.
        class_manager.frozen = (split_name != 'train')

        progress = tqdm.tqdm(enumerate(dataset), total=len(dataset), desc=f"Decomposing {split_name}")
        for idx, data in progress:
            # --- Heavy Atom Filtering ---
            if hasattr(data.pos, 'detach'):
                raw_pos = data.pos.detach().cpu().numpy()
            else:
                raw_pos = data.pos

            raw_z = data.h
            if hasattr(raw_z, 'detach'):
                z_np = raw_z.detach().cpu().numpy()
            else:
                z_np = np.asarray(raw_z)

            heavy_pos = raw_pos[z_np > 1]

            # 1. Too few heavy atoms to define an orientation.
            if len(heavy_pos) < 3:
                skipped_small += 1
                continue

            # 2. Heavy atoms (nearly) on a line.
            singular = np.linalg.svd(heavy_pos - heavy_pos.mean(axis=0), compute_uv=False)
            if singular[1] < 0.05 * singular[0]:
                skipped_linear += 1
                continue

            try:
                frags = decompose_to_rigid_fragments(data, class_manager, rare_smiles_check)
            except ValueError as e:
                # Unknown fragment under a frozen library: skip the molecule.
                if "FROZEN_FRAGMENT_LIBRARY" not in str(e):
                    raise
                skipped_unknown += 1
                continue

            if not frags:
                raise RuntimeError(f"Decomposition returned zero fragments for molecule index {idx} in split {split_name}.")

            frags_by_split[split_name].extend(frags)
            mol_fragments_map[split_name][idx] = frags

    print(f"Skipped: {skipped_small} small (<3 heavy atoms), {skipped_linear} globally linear, {skipped_unknown} unknown/new fragments.")

    # --- B. Compute Targets (R_gt, t_gt) ---
    print("2. Computing Ground Truth SE(3)...")
    for split_name in ('train', 'val'):
        split_frags = frags_by_split[split_name]
        if split_frags:
            compute_se3_targets(split_frags, class_manager)

    # --- C. Compute Symmetries (Orbits) ---
    print("3. Generating Symmetry Orbits...")
    frozen_before = class_manager.frozen
    class_manager.frozen = False
    generate_orbit_from_exemplar_perms(class_manager)
    class_manager.frozen = frozen_before

    # --- D. Serialize Library ---
    print("4. saving Fragment Library...")
    fragment_library = {
        i: {
            'smiles': cls.smiles,
            'exemplar_z': torch.tensor(cls.exemplar_z, dtype=torch.long),
            'exemplar_pos': torch.tensor(cls.exemplar_pos, dtype=torch.float32),
            'exemplar_edge_index': torch.tensor(cls.exemplar_edge_index, dtype=torch.long),
            'exemplar_edge_attr': torch.tensor(cls.exemplar_edge_attr, dtype=torch.long),
            'symmetries': torch.tensor(np.array(cls.candidate_symmetries_exemplar), dtype=torch.float32)
        }
        for i, cls in enumerate(class_manager.classes)
    }

    os.makedirs(output_path, exist_ok=True)
    torch.save(fragment_library, f"{output_path}/fragment_library.pt")

    # --- E. Serialize Data Splits ---
    print("5. Saving Processed Splits...")
    for split_name in ('train', 'val'):
        per_mol = mol_fragments_map[split_name]
        dataset_list = []

        # Molecules are written in ascending order of their dataset index.
        for midx in tqdm.tqdm(sorted(per_mol), desc=f"Saving {split_name}"):
            mol_frags = per_mol[midx]
            dataset_list.append({
                'frag_ids': torch.tensor([f.class_id for f in mol_frags], dtype=torch.long),
                'trans': torch.stack([f.trans_gt for f in mol_frags]),
                'rots': torch.stack([f.rot_gt for f in mol_frags]),
                'num_frags': len(mol_frags),
                'original_indices': [f.original_indices for f in mol_frags]
            })

        torch.save(dataset_list, f"{output_path}/{split_name}_processed.pt")

    print("Done.")
    


if __name__ == "__main__":
    # Script entry point: build the fragment dataset for one rarity setting.
    KIND = 'largest'
    FOLDER_NAME_FIRST = 'fragmented_geom_ablation_planar_rings_and_fused_with_h_freq_based_first_part_with_rare'
    FOLDER_NAME = f'fragmented_geom_ablation_planar_rings_and_fused_with_h_freq_based_{KIND}_threshold'

    # One rare-fragment SMILES per line, produced by the first-pass folder.
    with open(f"{FOLDER_NAME_FIRST}/rare_fragments_{KIND}.txt", 'r') as f:
        rare_smiles_check = [line.strip() for line in f]

    train = torch.load('/MotiFlow/data/geom_drugs/train_data_40k_fixed_aligned.pt')
    val = torch.load('/MotiFlow/data/geom_drugs/val_data_5k_fixed_aligned.pt')

    class_manager = FragmentClassManager(1, 1)
    prepare_and_save_dataset(
        train,
        val,
        class_manager,
        output_path=FOLDER_NAME,
        rare_smiles_check=rare_smiles_check,
    )
