import torch
from torch.func import vmap
from manifolds.general import Manifold_general # Import the parent class

def _dihedral(coords):
    """
    Computes the dihedral angle for a single molecular structure.
    Accepts coordinates for 4 atoms.
    """
    r12 = coords[1] - coords[0]
    r23 = coords[2] - coords[1]
    r34 = coords[3] - coords[2]
    
    n1 = torch.cross(r12, r23)
    n2 = torch.cross(r23, r34)
    
    cos_phi = torch.dot(n1, n2)
    sin_phi = torch.dot(n1, r34) * torch.norm(r23)
    
    return torch.atan2(sin_phi, cos_phi)

class Manifold_MD(Manifold_general):
    """
    Manifold class for Alanine Dipeptide, defined as a subclass of Manifold_general.

    This class specifies the dipeptide constraints and relies on the parent class
    for all projection, retraction, and differentiation logic. It can handle
    cases with or without inequality constraints on the psi angle.

    Per sample, the coordinates are 30 values (10 atoms x 3), flattened or not:
      * one equality constraint   h(x) = phi(x) - phi_target = 0
      * zero or one inequality    g(x) <= 0, satisfied when psi(x) lies inside
        at least one of the configured psi windows.
    """

    # Atom indices (into the 10-atom structure) defining each backbone dihedral.
    # Kept as tuples; they are converted to lists when indexing, because a tuple
    # index would be interpreted as multi-dimensional indexing.
    PHI_ATOMS = (1, 3, 4, 6)
    PSI_ATOMS = (3, 4, 6, 8)

    def __init__(self, psi_windows=((130, 170),), boundary_repulsion=0.1,
                 phi_target=-70.0):
        """
        Define the dipeptide constraints and hand them to Manifold_general.

        Args:
            psi_windows: iterable of (low, high) pairs in degrees with low < high.
                A single (low, high) pair is also accepted. A window that covers
                the +/-180 degree cut can be written with high > 180, e.g.
                (170, 190). A falsy value (e.g. an empty list or None) disables
                the psi inequality constraint.
            boundary_repulsion: forwarded to the parent as
                ``boundary_repulsion_rate``.
            phi_target: target phi angle in degrees for the equality constraint.

        Raises:
            ValueError: if psi_windows cannot be read as (low, high) pairs.
        """
        dim = 30  # 10 atoms * 3 coordinates

        # Equality constraints (m): 1 (the phi angle).
        m = 1
        self.phi_target_rad = torch.deg2rad(torch.tensor(float(phi_target)))

        # Inequality constraints (l): a single g(x) summarising all windows
        # (negative iff psi is inside at least one of them), or none at all.
        if psi_windows:
            windows = torch.tensor(psi_windows, dtype=torch.float32)
            if windows.ndim == 1:  # a single (low, high) pair
                windows = windows.unsqueeze(0)
            if windows.ndim != 2 or windows.shape[1] != 2:
                raise ValueError(
                    "psi_windows must be (low, high) pairs in degrees, "
                    f"got shape {tuple(windows.shape)}"
                )
            self.psi_windows_rad = torch.deg2rad(windows)
            l = 1
        else:
            self.psi_windows_rad = torch.empty(0, 2)
            l = 0

        super().__init__(dim=dim, m=m, l=l, h=self._h_phi, g=self._g_psi,
                         boundary_repulsion_rate=boundary_repulsion)

    # --- Dipeptide-specific angle calculation methods ---
    def angle_phi(self, x):
        """
        Compute the phi angle for a batch of conformations.

        Args:
            x: batch of conformations, each either (10, 3) or flattened (30,).

        Returns:
            Tensor of shape (B,) with phi in radians, in [-pi, pi].
        """
        idx = list(self.PHI_ATOMS)
        # Reshape per sample so flat (30,) samples work, as in _h_phi.
        return vmap(lambda s: _dihedral(s.reshape(-1, 3)[idx]))(x)

    def angle_psi(self, x):
        """
        Compute the psi angle for a batch of conformations.

        Args:
            x: batch of conformations, each either (10, 3) or flattened (30,).

        Returns:
            Tensor of shape (B,) with psi in radians, in [-pi, pi].
        """
        idx = list(self.PSI_ATOMS)
        return vmap(lambda s: _dihedral(s.reshape(-1, 3)[idx]))(x)

    # --- Single-instance constraint functions required by the parent class ---
    def _h_phi(self, x):
        """
        Equality constraint for a single sample: h(x) = phi(x) - phi_target = 0.

        Args:
            x: one conformation, (10, 3) or flattened (30,).

        Returns:
            Tensor of shape (1,).
        """
        coords = x.reshape(-1, 3)
        phi = _dihedral(coords[list(self.PHI_ATOMS)])
        # Match phi's device/dtype so low-precision inputs are not silently
        # promoted to float32 by the stored target.
        target = self.phi_target_rad.to(device=phi.device, dtype=phi.dtype)
        return (phi - target).unsqueeze(0)

    def _g_psi(self, x):
        """
        Inequality constraint for a single sample: g(x) <= 0.

        Returns a single value that is negative if psi is inside ANY of the
        windows, zero on a window edge, and positive otherwise. For each window
        the value is the signed angular distance to its nearest edge, measured
        on the circle. The constraint is the minimum over windows.

        Args:
            x: one conformation, (10, 3) or flattened (30,).

        Returns:
            Tensor of shape (1,), or shape (0,) if no windows are configured.
        """
        if self.psi_windows_rad.numel() == 0:
            return torch.empty(0, device=x.device, dtype=x.dtype)

        coords = x.reshape(-1, 3)
        psi = _dihedral(coords[list(self.PSI_ATOMS)])

        # Cast to psi's dtype too: psi is 0-d, so a float32 1-d window tensor
        # would otherwise downcast the result for float64 inputs.
        windows = self.psi_windows_rad.to(device=psi.device, dtype=psi.dtype)
        lows, highs = windows[:, 0], windows[:, 1]
        centers = 0.5 * (lows + highs)
        half_widths = 0.5 * (highs - lows)

        # Offset of psi from each window centre, wrapped into [-pi, pi]. The
        # wrap makes g continuous when psi crosses the +/-pi branch cut of atan2.
        # Without it, g jumps by ~2*pi there, which breaks projection steps.
        delta = psi - centers
        delta = torch.atan2(torch.sin(delta), torch.cos(delta))

        # |delta| - half_width equals max(psi - high, low - psi) whenever no wrap
        # occurs: negative inside that window, positive outside it.
        distance_to_outside = delta.abs() - half_widths

        # Inside ANY window <=> the smallest per-window distance is negative.
        return distance_to_outside.min().unsqueeze(0)

    @torch.enable_grad()
    def constrain_grad_fn(self, samples):
        """
        Gradient of ``self.constrain_fn(samples).sum()`` with respect to samples.

        ``constrain_fn`` is not defined in this class; presumably it is provided
        by Manifold_general (confirm there).

        Args:
            samples: batch of conformations.

        Returns:
            Detached tensor shaped like ``samples``.
        """
        # Differentiate a detached alias instead of calling requires_grad_ on the
        # caller's tensor. Flipping it to a grad-requiring leaf would make later
        # in-place updates on it raise, and would record autograd graphs for
        # every subsequent op. The gradient value is the same either way.
        inputs = samples.detach().requires_grad_(True)
        # No create_graph/retain_graph: the result is detached, so a
        # higher-order graph would only waste memory.
        (gradients,) = torch.autograd.grad(
            outputs=self.constrain_fn(inputs).sum(),
            inputs=inputs,
        )
        return gradients.detach()