import torch
from torch import Tensor
import torch.nn.functional as F
import math
from typing import Dict, Tuple


############# HELPER FUNCTIONS #############


# Distance Constraints:
# - N - - O to be 2.9 +/- 0.2

# Angle Constraints:
# - C=O - - N to be 155 +/- 10
# - CA-N - - O to be 120 +/- 10

# [:, 0], [:, 1], [:, 2], [:, 3]
# " N  ", " CA ", " C  ", " O  "


# Compute angle between three points
def compute_angle(a, b, c):
    """Return the angle at vertex ``b`` (∠ABC) in radians.

    Parameters:
    - a, c: 1-D tensors giving the two outer points
    - b: 1-D tensor giving the vertex

    Returns:
    - 0-dim tensor holding the angle between the rays b→a and b→c
    """
    to_a = a - b
    to_c = c - b
    cosine = torch.dot(to_a, to_c) / (to_a.norm() * to_c.norm())
    # Rounding can push the cosine marginally outside [-1, 1]; clamp before acos.
    return torch.clamp(cosine, -1.0, 1.0).acos()



# Joint constraint set of distance and angle constraints
# NOTE(review): `distance_penalty` is not defined in the visible part of this
# file, and the `angle_penalty` defined further down requires an
# `angle_constraints` argument, so the one-argument call below does not match
# it. Confirm which helpers this wrapper is meant to call before relying on it.
def joint_penalty(xi, start_index=None):
    """Return ``distance_penalty(xi, start_index) + angle_penalty(xi)``.

    Parameters:
    - xi: coordinates forwarded unchanged to both penalty functions
    - start_index: forwarded to ``distance_penalty`` only
    """
    return distance_penalty(xi, start_index) + angle_penalty(xi)



# Extract backbone chain from pdb file path
def extract_backbone_tensor(pdb_path, chain_id='P'):
    """Read the N, CA, C and O coordinates of one chain from a PDB file.

    Only ``ATOM`` records whose chain identifier (column 22) equals
    ``chain_id`` and whose atom name is N, CA, C or O are parsed. If the same
    residue number / atom name pair occurs more than once, the last record
    wins. Residues are ordered by residue number, and residues missing any of
    the four atoms are dropped.

    Parameters:
    - pdb_path: path of the PDB file to read
    - chain_id: single-character chain identifier to extract

    Returns:
    - float32 tensor of shape (R, 4, 3) with the atom axis ordered
      N, CA, C, O; shape (0, 4, 3) when no complete residue is found

    Raises:
    - OSError if the file cannot be opened
    - ValueError if a selected record has non-numeric residue number or
      coordinate fields
    """
    backbone_atoms = ('N', 'CA', 'C', 'O')
    residues = {}

    with open(pdb_path, 'r') as f:
        for line in f:
            if not line.startswith("ATOM"):
                continue
            # Select by chain and atom name BEFORE converting numbers, so a
            # truncated or malformed record that we would discard anyway
            # cannot abort the whole read.
            if line[21:22] != chain_id:
                continue
            atom_name = line[12:16].strip()
            if atom_name not in backbone_atoms:
                continue

            res_seq = int(line[22:26])
            xyz = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
            residues.setdefault(res_seq, {})[atom_name] = xyz

    # Keep only residues that carry all four backbone atoms, in sequence order.
    complete = [
        [residues[res_id][atom] for atom in backbone_atoms]
        for res_id in sorted(residues)
        if all(atom in residues[res_id] for atom in backbone_atoms)
    ]

    if not complete:
        # torch.tensor([]) would have shape (0,); keep the documented rank so
        # callers can index [:, atom] even on an empty selection.
        return torch.empty((0, len(backbone_atoms), 3), dtype=torch.float32)

    return torch.tensor(complete, dtype=torch.float32)



############# BETA SHEET (TOY EXAMPLE) #############


def distance_penalty_with_offset(xi, offset=0, start_index=None, beta_length=16, tolerance=0.0):
    """
    Hinge penalty on the two N/O cross-strand distances of one residue pair.

    The pair is residue ``start_index + offset`` against residue
    ``beta_length - 1 - offset``. Each of the two distances (atom 0 of one
    residue to atom 3 of the other, and vice versa) is free inside
    2.9 Å ± (base tolerance + ``tolerance``) and penalised linearly outside.
    The base tolerance is 0.6 Å at ``offset == 0`` and 0.2 Å otherwise.

    Parameters:
    - xi: coordinates tensor indexed as xi[residue, atom]; atoms 0 and 3 are used
    - offset: rung of the ladder to evaluate
    - start_index: first partner residue; when None it is the residue at or
      after ``beta_length + 3`` whose atom 3 is closest to atom 0 of residue
      ``beta_length - 1``
    - beta_length: length of the beta strand
    - tolerance: extra half-width added to the allowed distance band

    Returns:
    - 0-dim tensor with the summed violation, or 0.0 when either residue
      index falls outside the valid range
    """
    if start_index is None:
        search_from = beta_length + 3
        anchor = xi[beta_length - 1, 0]
        gaps = torch.norm(xi[search_from:, 3] - anchor, dim=1)
        start_index = torch.argmin(gaps) + search_from

    partner = start_index + offset
    if partner >= xi.shape[0]:
        return torch.tensor(0.0, device=xi.device)

    strand = beta_length - 1 - offset
    if strand < 0:
        return torch.tensor(0.0, device=xi.device)

    # The first rung gets a wider band than the rest of the ladder.
    half_width = (0.6 if offset == 0 else 0.2) + tolerance
    shortest = 2.9 - half_width
    longest = 2.9 + half_width

    def band_violation(d):
        return torch.clamp(shortest - d, min=0) + torch.clamp(d - longest, min=0)

    first = band_violation(torch.norm(xi[partner, 0] - xi[strand, 3]))
    second = band_violation(torch.norm(xi[partner, 3] - xi[strand, 0]))
    return first + second



def angle_penalty_with_offset(xi, offset=0, start_index=None, beta_length=16, tolerance=0.0):
    """
    Hinge penalty on the four hydrogen-bond angles of one residue pair.

    The pair is residue ``start_index + offset`` (the partner) against residue
    ``beta_length - 1 - offset`` (the strand residue). Two angles are measured
    with atom 3 as vertex (atom 2 of the same residue to atom 0 of the other
    residue, target 155°) and two with atom 0 as vertex (atom 1 of the same
    residue to atom 3 of the other residue, target 120°). Each angle is free
    within ±(10° + ``tolerance``) of its target and penalised linearly outside.

    Parameters:
    - xi: coordinates tensor indexed as xi[residue, atom]
    - offset: rung of the ladder to evaluate (default: 0)
    - start_index: first partner residue; when None it is the residue at or
      after ``beta_length + 3`` whose atom 3 is closest to atom 0 of residue
      ``beta_length - 1``
    - beta_length: length of beta sheet (default: 16)
    - tolerance: extra half-width in radians added to the 10° band

    Returns:
    - 0-dim tensor with the summed violation in radians, or 0.0 when either
      residue index falls outside the valid range
    """

    def angle_at(vertex, p, q):
        """Angle p-vertex-q in radians."""
        to_p, to_q = p - vertex, q - vertex
        cosine = torch.dot(to_p, to_q) / (to_p.norm() * to_q.norm())
        return torch.acos(torch.clamp(cosine, -1.0, 1.0))

    centre_155 = 155 * math.pi / 180
    centre_120 = 120 * math.pi / 180
    half_width = 10 * math.pi / 180 + tolerance

    def outside_band(angle, centre):
        low = centre - half_width
        high = centre + half_width
        return torch.clamp(low - angle, min=0) + torch.clamp(angle - high, min=0)

    if start_index is None:
        search_from = beta_length + 3
        anchor = xi[beta_length - 1, 0]
        gaps = torch.norm(xi[search_from:, 3] - anchor, dim=1)
        start_index = torch.argmin(gaps) + search_from

    partner = start_index + offset
    if partner >= xi.shape[0]:
        return torch.tensor(0.0, device=xi.device)

    strand = beta_length - 1 - offset
    if strand < 0:
        return torch.tensor(0.0, device=xi.device)

    # Vertex at atom 3 (O): C=O···N, once in each direction across the pair.
    at_partner_O = angle_at(xi[partner, 3], xi[partner, 2], xi[strand, 0])
    at_strand_O = angle_at(xi[strand, 3], xi[strand, 2], xi[partner, 0])

    # Vertex at atom 0 (N): CA-N···O, once in each direction across the pair.
    at_partner_N = angle_at(xi[partner, 0], xi[partner, 1], xi[strand, 3])
    at_strand_N = angle_at(xi[strand, 0], xi[strand, 1], xi[partner, 3])

    return (
        outside_band(at_partner_O, centre_155)
        + outside_band(at_strand_O, centre_155)
        + outside_band(at_partner_N, centre_120)
        + outside_band(at_strand_N, centre_120)
    )


################### PDZ SETTING ####################


def distance_penalty_pdz(xi, p_chain, offset=0, start_index=94, p_start=-2, tolerance=0.0):
    """
    Hinge penalty on one hydrogen-bond distance between ``xi`` and ``p_chain``.

    Offsets walk a ladder: each even/odd pair of offsets shares a rung. On
    rung ``k`` (``k = offset`` rounded down to even) the ``xi`` residue is
    ``start_index + k`` and the ``p_chain`` residue is ``p_start - k``.
    Even offsets pair atom 0 of ``xi`` with atom 3 of ``p_chain``; odd
    offsets pair atom 3 of ``xi`` with atom 0 of ``p_chain``.

    Parameters:
    - xi: coordinates tensor indexed as xi[residue, atom]
    - p_chain: coordinates tensor of the partner chain, same indexing
    - offset: position along the ladder
    - start_index: ``xi`` residue used at offset 0
    - p_start: ``p_chain`` residue used at offset 0 (may be negative)
    - tolerance: extra half-width added to the 0.2 Å band around 2.9 Å

    Returns:
    - 0-dim tensor: how far the distance lies outside the allowed band
    """
    parity = offset % 2
    rung = offset - parity

    if parity == 0:
        xi_atom, p_atom = 0, 3
    else:
        xi_atom, p_atom = 3, 0

    half_width = 0.2 + tolerance
    shortest = 2.9 - half_width
    longest = 2.9 + half_width

    separation = torch.norm(xi[start_index + rung, xi_atom] - p_chain[p_start - rung, p_atom])

    too_close = torch.clamp(shortest - separation, min=0)
    too_far = torch.clamp(separation - longest, min=0)
    return too_close + too_far





def peptide_clash_penalty(
    xi: Tensor,
    p_chain: Tensor,
    *,
    min_distance: float = 2.7,
    enabled: bool = False,
    include_self_clash: bool = False,
) -> Tensor:
    """
    Penalise any atom-pair distance < min_distance.

    The penalty has historically been switched off (the function returned a
    constant zero before doing any work). That remains the default; pass
    ``enabled=True`` to actually evaluate it.

    Parameters
    ----------
    xi        : (R, 4, 3)   – the fragment being optimised
    p_chain   : (M, 4, 3)   – reference / frozen part
    min_distance : Å cut-off for steric clash
    enabled   : when False (default) return a constant ``tensor(0.0)``
    include_self_clash : when True, also penalise clashes between atoms of
        ``xi`` itself, skipping pairs in the same residue and the
        C(i)–N(i+1) peptide bond

    Returns
    -------
    0-dim tensor: sum over clashing pairs of ``min_distance - distance``.
    """
    if not enabled:
        # Preserve the long-standing behaviour for existing callers.
        return torch.tensor(0.0)

    R, A, _ = xi.shape                         # A should be 4

    # ---- 1) xi  vs  reference peptide --------------------------------
    xi_flat = xi.reshape(-1, 3)                # (R*A, 3)
    peptide_flat = p_chain.reshape(-1, 3)      # (M*A, 3)

    d_ext = torch.cdist(xi_flat, peptide_flat, p=2)   # (R*A, M*A)
    penalty = torch.clamp(min_distance - d_ext, min=0.0).sum()

    if not include_self_clash:
        return penalty

    # ---- 2) xi  vs  xi  (self-clash) ---------------------------------
    d_int = torch.cdist(xi_flat, xi_flat, p=2)        # (R*A, R*A)

    # Strict upper triangle: each unordered pair once, no self-pairs.
    unique_pairs = torch.triu(torch.ones_like(d_int, dtype=torch.bool), diagonal=1)

    # Residue index and atom index of every flattened row.
    res_idx = torch.arange(R, device=xi.device).repeat_interleave(A)  # (R*A,)
    atom_idx = torch.arange(A, device=xi.device).repeat(R)            # 0,1,2,3,0,...

    # Atoms of the same residue are covalently close by construction.
    same_res = res_idx.unsqueeze(0) == res_idx.unsqueeze(1)

    # Entry [i, j] is True when j is C of residue r and i is N of residue r+1.
    peptide_bond = (
        (atom_idx.unsqueeze(0) == 2)
        & (atom_idx.unsqueeze(1) == 0)
        & ((res_idx.unsqueeze(0) + 1) == res_idx.unsqueeze(1))
    )
    peptide_bond = peptide_bond | peptide_bond.T      # symmetry

    keep = unique_pairs & ~same_res & ~peptide_bond

    # Summing an empty selection yields 0, so no special case is needed.
    self_penalty = torch.clamp(min_distance - d_int[keep], min=0.0).sum()
    return penalty + self_penalty



def angle_penalty_pdz(xi, p_chain, offset=0, start_index=94, p_start=-2,
                      tolerance=0.0):
    """
    Quadratic penalty on the two hydrogen-bond angles of one ladder position.

    Zero penalty when
        ∠(C‑O‑N)  = 155° ± (10° + tolerance)   (vertex at O)
        ∠(CA‑N‑O) = 120° ± (10° + tolerance)   (vertex at N)
    Outside the band the cost is the squared excess in radians.

    Residue selection follows the same ladder as ``distance_penalty_pdz``:
    with ``k = offset`` rounded down to even, the ``xi`` residue is
    ``start_index + k`` and the ``p_chain`` residue is ``p_start - k``.
    For even offsets N/CA come from ``xi`` and O/C from ``p_chain``; for odd
    offsets the roles are swapped.

    ``tolerance`` is added in radians.
    """
    half_width = math.radians(10.0) + tolerance
    centre_at_O = math.radians(155.0)
    centre_at_N = math.radians(120.0)

    parity = offset % 2
    rung = offset - parity
    row = start_index + rung
    p_row = p_start - rung

    if parity == 0:
        # xi supplies the N-H side, p_chain the C=O side.
        donor_N, donor_CA = xi[row, 0], xi[row, 1]
        acceptor_O, acceptor_C = p_chain[p_row, 3], p_chain[p_row, 2]
    else:
        # xi supplies the C=O side, p_chain the N-H side.
        acceptor_O, acceptor_C = xi[row, 3], xi[row, 2]
        donor_N, donor_CA = p_chain[p_row, 0], p_chain[p_row, 1]

    theta_O = compute_angle(acceptor_C, acceptor_O, donor_N)   # C‑O‑N
    theta_N = compute_angle(donor_CA, donor_N, acceptor_O)     # CA‑N‑O

    def squared_excess(angle, centre):
        return torch.relu(torch.abs(angle - centre) - half_width)**2

    return squared_excess(theta_O, centre_at_O) + squared_excess(theta_N, centre_at_N)




# ---------------------------------------------------------------------
# (atom index, atom index) → (min Å, max Å, constraint label)
# Atom indices: 0 = N, 1 = CA, 2 = C, 3 = O.
# Keys are ORDERED tuples: (0, 2) and (2, 0) are separate entries.
# NOTE(review): the lookup inside gap_penalty_ra sorts each atom-index pair
# ascending before comparing it with these keys, so (2, 0) and (3, 0) can
# never match there — confirm how inter-residue pairs are meant to be keyed.
PAIR_BOUNDS: Dict[Tuple[int, int], Tuple[float, float, str]] = {
    # Covalent bonds (HARD constraints)
    (0, 1): (1.42, 1.50, "covalent"),    # N–CA
    (1, 2): (1.48, 1.56, "covalent"),    # CA–C
    (2, 3): (1.20, 1.26, "covalent"),    # C–O
    (2, 0): (1.30, 1.36, "peptide"),     # C–N peptide bond (inter-residue)
    
    # 2-step distances (SOFT constraints)
    (0, 2): (2.25, 2.65, "soft"),        # N–C intra-residue
    (1, 3): (2.20, 2.65, "soft"),        # CA–O intra-residue
    (3, 0): (2.80, 3.20, "soft"),        # O–N inter-residue
}

# ---------------------------------------------------------------------

def gap_penalty_ra(
    x: Tensor,
    pair_ra: torch.LongTensor,
    *,
    covalent_weight: float = 100.0,    # Heavy penalty for bond breaking
    peptide_weight: float = 50.0,      # Medium penalty for peptide bonds
    soft_weight: float = 5.0,          # Light penalty for conformational constraints
    global_min: float = 1.2,
    global_max: float = 3.5,
) -> Tensor:
    """
    Hierarchical penalty with different weights for different constraint types.

    NOTE: this function is currently disabled. The first statement returns a
    constant ``tensor(0.0)`` (default device), so nothing below it executes
    and every argument is ignored.

    The disabled implementation reads each row of ``pair_ra`` as
    (residue_i, atom_i, residue_j, atom_j, constraint_type), measures the
    distance between the two atoms in ``x``, looks up (min, max) bounds in
    PAIR_BOUNDS (falling back to ``global_min`` / ``global_max``), and sums
    the out-of-bounds distance weighted by constraint type
    (0 → covalent_weight, 1 → peptide_weight, 2 → soft_weight, other → 1).
    """
    # Stub: the penalty is switched off; everything below is unreachable.
    return torch.tensor(0.0)
    if pair_ra.size(0) == 0:
        return torch.tensor(0.0, device=x.device)
    
    pi_r, pi_a, pj_r, pj_a, constraint_type = pair_ra.t()
    vi = x[pi_r, pi_a]  # (K,3)
    vj = x[pj_r, pj_a]  # (K,3)
    d = (vi - vj).norm(dim=1)  # (K,)
    
    # Build per-pair bounds
    # Atom indices are sorted ascending per row, so key[:, 0] <= key[:, 1].
    key = torch.sort(torch.stack([pi_a, pj_a], dim=1), dim=1)[0]  # (K,2)
    min_b = torch.full_like(d, global_min)
    max_b = torch.full_like(d, global_max)
    
    # Apply specific bounds
    # NOTE(review): because `key` is sorted, PAIR_BOUNDS entries whose first
    # index is larger than the second, (2, 0) and (3, 0), never match here;
    # a C–N pair would pick up the (0, 2) bounds instead. Resolve before
    # re-enabling this function.
    for (a, b), (lo, hi, _) in PAIR_BOUNDS.items():
        mask = (key[:, 0] == a) & (key[:, 1] == b)
        min_b[mask] = lo
        max_b[mask] = hi
    
    # Calculate violations
    under = torch.clamp(min_b - d, min=0.0)
    over = torch.clamp(d - max_b, min=0.0)
    violations = under + over
    
    # Apply hierarchical weights
    weights = torch.ones_like(violations)
    weights[constraint_type == 0] = covalent_weight  # Covalent bonds
    weights[constraint_type == 1] = peptide_weight   # Peptide bonds
    weights[constraint_type == 2] = soft_weight      # Soft constraints
    
    return (violations * weights).sum()



def planarity_penalty(x, res):
    """
    Bond-angle penalty for one residue.

    Measures the N–CA–C angle (ideal 110°) and the CA–C–O angle (ideal 120°)
    of residue ``res`` and returns the sum of their linear hinge costs: each
    angle is free within ±5° of its ideal and costs the excess (in radians)
    beyond that.

    Parameters:
    - x: coordinates tensor indexed as x[residue, atom] with atoms N, CA, C, O
    - res: index of the residue to evaluate

    Returns:
    - 0-dim tensor with the summed excess in radians
    """
    n_pos, ca_pos, c_pos, o_pos = (x[res, k] for k in range(4))

    def hinge(p, vertex, q, ideal_deg):
        # Angle p-vertex-q, then distance outside the ±5° band around the ideal.
        to_p = p - vertex
        to_q = q - vertex
        cosine = torch.dot(to_p, to_q) / (to_p.norm() * to_q.norm())
        theta = torch.acos(torch.clamp(cosine, -1.0, 1.0))
        return torch.clamp(
            torch.abs(theta - ideal_deg * math.pi / 180) - 5 * math.pi / 180, min=0
        )

    return hinge(n_pos, ca_pos, c_pos, 110) + hinge(ca_pos, c_pos, o_pos, 120)



##########################################################



import torch
from torch import Tensor
import math
from typing import List, Tuple, Dict

def bond_penalty(x: Tensor, bond_constraints: torch.LongTensor, weight: float = 100.0) -> Tensor:
    """Distance-based bond constraints.

    Each row of ``bond_constraints`` is
    ``(residue_i, atom_i, residue_j, atom_j, bond_type)``. The distance
    between the two atoms is compared with the target for its ``bond_type``;
    deviations beyond the tolerance are penalised linearly.

    Bond types: 0 = N-CA, 1 = C-N peptide, 2 = CA-C, 3 = C-O. Rows with any
    other type code contribute nothing.

    Parameters:
    - x: coordinates tensor indexed as x[residue, atom]
    - bond_constraints: (K, 5) integer tensor as described above
    - weight: multiplier applied to the summed violation

    Returns:
    - 0-dim tensor on ``x``'s device with ``x``'s dtype
    """
    if bond_constraints.size(0) == 0:
        return x.new_zeros(())

    # bond_type -> (target Å, tolerance Å)
    BOND_TARGETS = {
        0: (1.46, 0.02),  # N-CA: 1.46 ± 0.02 Å
        1: (1.33, 0.02),  # C-N peptide: 1.33 ± 0.02 Å
        2: (1.52, 0.02),  # CA-C: 1.52 ± 0.02 Å
        3: (1.23, 0.02),  # C-O: 1.23 ± 0.02 Å
    }

    ri, ai, rj, aj, bond_type = bond_constraints.t()
    distances = (x[ri, ai] - x[rj, aj]).norm(dim=1)

    # Accumulate in the input's dtype; a float32 accumulator updated in place
    # would silently downcast double-precision penalties.
    penalty = distances.new_zeros(())
    # Iterating the known table (rather than unique() + .item()) avoids a
    # device-to-host sync per type; an empty mask simply sums to zero.
    for code, (target, tolerance) in BOND_TARGETS.items():
        selected = distances[bond_type == code]
        excess = torch.abs(selected - target) - tolerance
        penalty = penalty + weight * torch.clamp(excess, min=0.0).sum()

    return penalty


def angle_penalty(x: Tensor, angle_constraints: torch.LongTensor, weight: float = 10.0) -> Tensor:
    """Angular constraints for proper backbone geometry.

    Each row of ``angle_constraints`` is
    ``(r1, a1, r2, a2, r3, a3, angle_type)`` where atom 2 is the vertex. The
    angle is compared with the target for its ``angle_type``; deviations
    beyond the tolerance are penalised linearly (in radians).

    Angle types: 0 = N-CA-C, 1 = CA-C-O, 2 = CA-C-N, 3 = C-N-CA. Rows with
    any other type code contribute nothing.

    Parameters:
    - x: coordinates tensor indexed as x[residue, atom]
    - angle_constraints: (K, 7) integer tensor as described above
    - weight: multiplier applied to the summed violation

    Returns:
    - 0-dim tensor on ``x``'s device with ``x``'s dtype
    """
    if angle_constraints.size(0) == 0:
        return x.new_zeros(())

    # angle_type -> (target, tolerance), both in radians
    ANGLE_TARGETS = {
        0: (math.radians(109.5), math.radians(5)),   # Tetrahedral angles (N-CA-C)
        1: (math.radians(121.0), math.radians(3)),   # Planar angles (CA-C-O)
        2: (math.radians(117.0), math.radians(3)),   # Peptide angles (CA-C-N)
        3: (math.radians(122.0), math.radians(5)),   # Peptide angles (C-N-CA)
    }

    # Format: (r1, a1, r2, a2, r3, a3, type)
    r1, a1, r2, a2, r3, a3, angle_type = angle_constraints.t()

    vertex = x[r2, a2]
    arm1 = x[r1, a1] - vertex
    arm2 = x[r3, a3] - vertex

    # The epsilon guards against zero-length arms.
    unit1 = arm1 / (arm1.norm(dim=1, keepdim=True) + 1e-8)
    unit2 = arm2 / (arm2.norm(dim=1, keepdim=True) + 1e-8)

    # Stay strictly inside (-1, 1): the derivative of acos is unbounded at ±1.
    cos_angles = torch.clamp((unit1 * unit2).sum(dim=1), -1.0 + 1e-6, 1.0 - 1e-6)
    angles = torch.acos(cos_angles)

    # Accumulate in the input's dtype; a float32 accumulator updated in place
    # would silently downcast double-precision penalties.
    penalty = angles.new_zeros(())
    for code, (target, tolerance) in ANGLE_TARGETS.items():
        selected = angles[angle_type == code]
        excess = torch.abs(selected - target) - tolerance
        penalty = penalty + weight * torch.clamp(excess, min=0.0).sum()

    return penalty


def omega_dihedral_penalty(x: Tensor, dihedral_constraints: torch.LongTensor, weight: float = 5.0) -> Tensor:
    """Omega dihedral constraints for peptide planarity.

    Each row of ``dihedral_constraints`` is
    ``(r1, a1, r2, a2, r3, a3, r4, a4, type)`` naming the four atoms
    CA(i), C(i), N(i+1), CA(i+1). The last column is read but not used.

    The unsigned dihedral |ω| ∈ [0, π] is obtained from the angle between
    the normals of the two planes. A trans peptide has |ω| = π; the penalty
    is zero while |ω| is within 15° of π and grows linearly (in radians) as
    the peptide twists toward cis (|ω| = 0).

    Parameters:
    - x: coordinates tensor indexed as x[residue, atom]
    - dihedral_constraints: (K, 9) integer tensor as described above
    - weight: multiplier applied to the summed violation

    Returns:
    - 0-dim tensor with the weighted, summed violation
    """
    if dihedral_constraints.size(0) == 0:
        return torch.tensor(0.0, device=x.device)

    r1, a1, r2, a2, r3, a3, r4, a4, _dihed_type = dihedral_constraints.t()

    v1 = x[r1, a1]  # CA(i)
    v2 = x[r2, a2]  # C(i)
    v3 = x[r3, a3]  # N(i+1)
    v4 = x[r4, a4]  # CA(i+1)

    # Bond vectors along the chain and the normals of the two planes.
    b1 = v2 - v1
    b2 = v3 - v2
    b3 = v4 - v3

    n1 = torch.cross(b1, b2, dim=1)
    n2 = torch.cross(b2, b3, dim=1)

    n1_unit = n1 / (n1.norm(dim=1, keepdim=True) + 1e-8)
    n2_unit = n2 / (n2.norm(dim=1, keepdim=True) + 1e-8)

    # Stay strictly inside (-1, 1): the derivative of acos is unbounded at ±1.
    cos_dihed = torch.clamp((n1_unit * n2_unit).sum(dim=1), -1.0 + 1e-6, 1.0 - 1e-6)

    target_omega = math.pi
    tolerance = math.radians(15)  # Allow some flexibility

    # acos already yields the magnitude |ω|, so +180° and -180° coincide.
    # Do NOT take abs(cos) first: that folds the angle into [0, π/2], which
    # made a perfectly planar trans peptide score the MAXIMUM penalty.
    omega_abs = torch.acos(cos_dihed)
    deviations = torch.abs(omega_abs - target_omega) - tolerance

    return weight * torch.clamp(deviations, min=0.0).sum()


def complete_geometry_penalty(
    x: Tensor,
    constraints: Dict[str, torch.LongTensor],
    *,
    bond_weight: float = 100.0,
    angle_weight: float = 20.0,
    dihedral_weight: float = 10.0,
    full_weight: float = 10.0,
) -> Tensor:
    """Combined bond, angle and omega-dihedral penalty.

    For every key present in ``constraints`` the matching penalty is
    evaluated and added to a running total:

    - ``'bonds'``     → ``bond_penalty`` with ``bond_weight``
    - ``'angles'``    → ``angle_penalty`` with ``angle_weight``
    - ``'dihedrals'`` → ``omega_dihedral_penalty`` with ``dihedral_weight``

    Keys that are absent are skipped; omit ``'dihedrals'`` to leave the
    planarity term out.

    Returns:
    - 0-dim tensor: ``full_weight`` times the running total
    """
    total = torch.tensor(0.0, device=x.device)

    # Bond lengths carry the largest default weight.
    if 'bonds' in constraints:
        total += bond_penalty(x, constraints['bonds'], bond_weight)

    # Bond angles.
    if 'angles' in constraints:
        total += angle_penalty(x, constraints['angles'], angle_weight)

    # Peptide planarity; applied only when the caller supplies the key.
    if 'dihedrals' in constraints:
        total += omega_dihedral_penalty(x, constraints['dihedrals'], dihedral_weight)

    return full_weight * total


def differentiable_break_penalty(
    bb_tensor: torch.Tensor,
    break_thresh: float = 4.5,
    softness: float = 10.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """Numerically-stable, differentiable chain-break penalty.

    For every pair of consecutive residues the distance ``d`` between their
    atom-1 (CA) positions is turned into

    ``penalty = softplus(d - break_thresh, beta=softness) / softness``

    where ``softplus(x, beta) = 1/beta * log(1 + exp(beta * x))`` is
    PyTorch's implementation, which switches to the linear branch for large
    arguments and therefore does not overflow for very large CA-CA distances.

    Parameters:
    - bb_tensor: coordinates indexed as bb_tensor[residue, atom, xyz]
    - break_thresh: distance (Å) above which a gap counts as a break
    - softness: softplus ``beta``; also the final divisor
    - reduction: ``"mean"``, ``"sum"`` or ``"none"``

    Returns:
    - 0-dim tensor for ``"mean"`` / ``"sum"``; tensor of shape
      (n_res - 1,) for ``"none"``. With fewer than two residues there are no
      gaps and both ``"mean"`` and ``"sum"`` return 0.

    Raises:
    - ValueError if ``reduction`` is not one of the three accepted strings
    """
    import torch.nn.functional as F

    # Reject a bad argument before doing any tensor work.
    if reduction not in ("mean", "sum", "none"):
        raise ValueError("reduction must be 'mean', 'sum', or 'none'")

    ca = bb_tensor[:, 1, :]
    gaps = torch.linalg.norm(ca[1:] - ca[:-1], dim=1)

    excess = gaps - break_thresh  # [n_res-1]
    penalty = F.softplus(excess, beta=softness) / softness  # stable soft hinge

    if reduction == "none":
        return penalty
    if reduction == "sum" or penalty.numel() == 0:
        # mean() of an empty tensor is NaN; an empty sum is a clean zero.
        return penalty.sum()
    return penalty.mean()
        
        
        
##########################################################


# Atom indices in your (R,4,3): 0:N, 1:CA, 2:C, 3:O

def place_amide_H(N, CA, C, bond_len=1.01, angle_deg=120.0):
    """
    Deterministic virtual-H placement from N, CA and C positions.

    The unit vectors N→CA and N→C are summed and renormalised; H is put at
    distance ``bond_len`` from N along that direction. Inputs may carry
    leading batch dimensions; the last dimension is xyz.

    ``angle_deg`` is accepted but not used by the computation.

    NOTE(review): the direction used is the +bisector, i.e. H ends up on the
    same side of N as CA and C — confirm this is the intended side and which
    residue's C callers are expected to pass.
    """
    def _unit(v):
        return v / (torch.norm(v, dim=-1, keepdim=True) + 1e-8)

    bisector = _unit(_unit(CA - N) + _unit(C - N))
    return N + bond_len * bisector

def dssp_energy_CO_NH(C_i, O_i, N_j, H_j):
    """
    Kabsch–Sander electrostatic energy (kcal/mol) of a C=O ··· H–N pair:

        E = 332 * 0.084 * ( 1/r_ON + 1/r_CH - 1/r_OH - 1/r_CN )

    Distances are taken over the last dimension, so inputs may carry leading
    batch dimensions. A small epsilon is added to every distance before
    inverting to avoid division by zero.
    """
    def inv_dist(p, q):
        return 1.0 / (torch.norm(p - q, dim=-1) + 1e-8)

    scale = 332.0 * 0.084
    return scale * (
        inv_dist(O_i, N_j) + inv_dist(C_i, H_j) - inv_dist(O_i, H_j) - inv_dist(C_i, N_j)
    )

def beta_sheet_penalty(
    xi_full: torch.Tensor,                # (R,4,3)
    window=(90, 100),                     # residue index bounds, used as-is to index xi_full (no 1-based shift is applied)
    min_len: int = 5,
    tau_energy: float = 0.15,             # softness around DSSP cutoff -0.5
    tau_min: float = 0.25,                # softness for softmin across candidates
    max_shift: int = 1,                   # allow small registry shifts
    require_pairs: int = None,            # if None (or 0), a candidate must keep all L2 pairs
):
    """
    Soft penalty that is small when some two-strand pairing inside ``window``
    has DSSP-like hydrogen bonds.

    Candidate pairings are enumerated exhaustively: a left strand, a right
    strand starting after it, a registry shift in ``[-max_shift, max_shift]``
    and an orientation (antiparallel or parallel). For each candidate the
    Kabsch–Sander energy is evaluated in both directions for every residue
    pair, and the candidate loss is

        mean softplus( (E + 0.5) / tau_energy )

    The result is the soft minimum of the candidate losses,
    ``-tau_min * logsumexp(-loss / tau_min)``.

    Returns a 0-dim tensor. When no candidate can be built, a constant 1.0
    is returned.

    NOTE(review): the soft minimum is bounded above by the smallest candidate
    loss and can fall below it by up to ``tau_min * log(num_candidates)``, so
    the value is not guaranteed to be >= 0 when many candidates are near
    zero.

    NOTE(review): with ``require_pairs`` left at None a candidate must keep
    all ``L2`` pairs, but any non-zero shift pushes ``|shift|`` pairs outside
    the right strand, so only ``shift == 0`` candidates survive by default —
    confirm this is intended.
    """
    # convert to 0-based indices assuming xi_full[0] is residue 0; adjust if needed
    lo, hi = window
    # if your residues are 0-based already, keep as is; otherwise subtract 1 here
    lo0, hi0 = lo, hi

    # Per-atom coordinate slices; atom axis order is N, CA, C, O.
    N = xi_full[..., 0, :]   # (R,3)
    CA = xi_full[..., 1, :]
    C  = xi_full[..., 2, :]
    O  = xi_full[..., 3, :]

    # virtual H for all residues
    H = place_amide_H(N, CA, C)

    # Each candidate is (orientation label, list of (i, j) residue pairs).
    candidates = []
    for a_start in range(lo0, hi0 - min_len + 2):              # left strand
        for L in range(min_len, hi0 - a_start + 1):
            a_end = a_start + L - 1
            if a_end >= hi0: break
            # right strand must start after left; hairpin-ish
            for b_start in range(a_end + 1, hi0 - min_len + 2):
                b_maxL = hi0 - b_start + 1
                L2 = min(L, b_maxL)                            # balanced length
                if L2 < min_len: continue
                b_end = b_start + L2 - 1

                # build pair lists for both registries with small shifts
                for shift in range(-max_shift, max_shift + 1):
                    # antiparallel: i in a increasing, j in b decreasing with registry shift
                    pairs_anti = []
                    for t in range(L2):
                        i = a_start + t
                        j = b_end - (t + shift)
                        # Drop pairs that the shift pushes outside the right strand.
                        if j < b_start or j > b_end: continue
                        pairs_anti.append((i, j))
                    if len(pairs_anti) >= (require_pairs or L2):
                        candidates.append(("anti", pairs_anti))

                    # parallel: i in a increasing, j in b increasing with shift
                    pairs_para = []
                    for t in range(L2):
                        i = a_start + t
                        j = b_start + (t + shift)
                        # Drop pairs that the shift pushes outside the right strand.
                        if j < b_start or j > b_end: continue
                        pairs_para.append((i, j))
                    if len(pairs_para) >= (require_pairs or L2):
                        candidates.append(("para", pairs_para))

    if not candidates:
        # No feasible candidate geometry inside window; penalize softly
        return torch.tensor(1.0, device=xi_full.device, dtype=xi_full.dtype)

    # compute per-candidate loss
    losses = []
    # `kind` is carried for readability only; both orientations are scored identically.
    for kind, pairs in candidates:
        # Inter-strand HBond: use CO (i) with NH (j) AND the reciprocal ladder
        E_list = []
        for i, j in pairs:
            E_ij = dssp_energy_CO_NH(C[i], O[i], N[j], H[j])
            # Reciprocal direction (CO of j with NH of i) is always included.
            E_ji = dssp_energy_CO_NH(C[j], O[j], N[i], H[i])
            E_list.extend([E_ij, E_ji])
        E = torch.stack(E_list)
        # soft hinge at -0.5 (kcal/mol): approaches zero as E drops well below -0.5
        per_pair = torch.nn.functional.softplus((E + 0.5)/tau_energy)
        loss = per_pair.mean()
        losses.append(loss)

    Ls = torch.stack(losses)  # (num_candidates,)
    # softmin across candidates
    softmin_val = -tau_min * torch.logsumexp(-Ls / tau_min, dim=0)
    # Close to the best candidate's loss; see the docstring note on its sign.
    return softmin_val
