import torch


# Neighboring atom distance is upper bounded by
# distance in x0
def distance_upper_bound_objective(x0, xi_trainable, num_fixed=16, num_optimized=8, atoms_per_residue=4,
                                   *, boundary_residues=3, boundary_weight=1.5):
    """
    Primal objective that enforces original distances as upper bounds.
    Atoms can move closer but not further apart than in the original structure.

    The current structure is ``x0[:num_fixed]`` followed by ``xi_trainable``.
    Every considered atom pair contributes ``relu(curr_dist - orig_dist)``,
    where ``orig_dist`` is measured in ``x0``. Pairs considered:

    1. each unordered pair of atoms inside every optimized residue;
    2. every atom of residue i against every atom of residue i + 1, for all
       consecutive residues where at least one of the two is optimized;
    3. every atom of the last fixed residue against every atom of the first
       ``boundary_residues`` optimized residues, scaled by ``boundary_weight``.
       (The last-fixed / first-optimized pairs are therefore counted in both
       term 2 and term 3.)

    Distances are computed with batched tensor ops instead of one scalar op
    per pair, which keeps the autograd graph small.

    Args:
        x0: original representation [residues, atoms, xyz]; must contain at
            least num_fixed + num_optimized residues
        xi_trainable: trainable part of the representation, shaped
            [num_optimized, atoms, 3] or any shape with that many elements
            (e.g. 2-D or flattened), which is reshaped accordingly
        num_fixed: number of fixed residues
        num_optimized: number of optimized residues
        atoms_per_residue: number of atoms per residue
        boundary_residues: number of leading optimized residues tied to the
            last fixed residue in term 3 (default 3)
        boundary_weight: weight applied to term 3 (default 1.5)

    Returns:
        0-d tensor with the summed, non-negative distance violations.

    Raises:
        IndexError: if x0 has fewer than num_fixed + num_optimized residues.
    """
    relu = torch.nn.functional.relu
    num_total = num_fixed + num_optimized
    if x0.shape[0] < num_total:
        # Without this check a short x0 could silently broadcast below.
        raise IndexError(f"x0 has {x0.shape[0]} residues; need at least {num_total}")

    def cross_dist(a, b):
        # a: [..., P, 3], b: [..., Q, 3] -> [..., P, Q] with entry [p, q] = |a[p] - b[q]|
        return (a.unsqueeze(-2) - b.unsqueeze(-3)).norm(p=2, dim=-1)

    # reshape (unlike view) also accepts flattened or non-contiguous input
    if xi_trainable.dim() != 3:
        xi_trainable = xi_trainable.reshape(num_optimized, atoms_per_residue, 3)

    # Reconstruct full xi; x0_all is cut to the same length so slices line up
    xi_all = torch.cat([x0[:num_fixed], xi_trainable], dim=0)
    x0_all = x0[:num_total]

    # 1. Intra-residue: unordered atom pairs (a1 < a2) within each optimized residue
    a1, a2 = torch.triu_indices(atoms_per_residue, atoms_per_residue, offset=1, device=xi_all.device)
    orig = (x0_all[num_fixed:, a1] - x0_all[num_fixed:, a2]).norm(p=2, dim=-1)
    curr = (xi_all[num_fixed:, a1] - xi_all[num_fixed:, a2]).norm(p=2, dim=-1)
    total_loss = relu(curr - orig).sum()

    # 2. Consecutive residues (i, i + 1) with at least one optimized, i.e. i >= num_fixed - 1
    start = max(num_fixed - 1, 0)
    orig = cross_dist(x0_all[start:-1], x0_all[start + 1:])
    curr = cross_dist(xi_all[start:-1], xi_all[start + 1:])
    total_loss = total_loss + relu(curr - orig).sum()

    # 3. Last fixed residue vs. the first few optimized residues, with extra weight
    if num_fixed > 0:
        last = num_fixed - 1
        stop = num_fixed + min(boundary_residues, num_optimized)
        orig = cross_dist(x0_all[last:num_fixed], x0_all[num_fixed:stop])
        curr = cross_dist(xi_all[last:num_fixed], xi_all[num_fixed:stop])
        total_loss = total_loss + boundary_weight * relu(curr - orig).sum()

    return total_loss




# Neighboring atom distance is consistent with
# distance in x0
def distance_preservation_objective(x0, xi_trainable, num_fixed=16, num_optimized=8, atoms_per_residue=4,
                                    *, boundary_residues=3, boundary_weight=1.5):
    """
    Simple primal objective that preserves distances between neighboring atoms,
    including connections to the last non-trainable atoms.

    The current structure is ``x0[:num_fixed]`` followed by ``xi_trainable``.
    Every considered atom pair contributes ``|orig_dist - curr_dist|``, where
    ``orig_dist`` is measured in ``x0``. Pairs considered:

    1. each unordered pair of atoms inside every optimized residue;
    2. every atom of residue i against every atom of residue i + 1, for all
       consecutive residues where at least one of the two is optimized;
    3. every atom of the last fixed residue against every atom of the first
       ``boundary_residues`` optimized residues, scaled by ``boundary_weight``.
       (The last-fixed / first-optimized pairs are therefore counted in both
       term 2 and term 3.)

    Distances are computed with batched tensor ops instead of one scalar op
    per pair, which keeps the autograd graph small.

    Args:
        x0: original representation [residues, atoms, xyz]; must contain at
            least num_fixed + num_optimized residues
        xi_trainable: trainable part of the representation, shaped
            [num_optimized, atoms, 3] or any shape with that many elements
            (e.g. 2-D or flattened), which is reshaped accordingly
        num_fixed: number of fixed residues
        num_optimized: number of optimized residues
        atoms_per_residue: number of atoms per residue
        boundary_residues: number of leading optimized residues tied to the
            last fixed residue in term 3 (default 3)
        boundary_weight: weight applied to term 3 (default 1.5)

    Returns:
        0-d tensor with the summed absolute distance deviations.

    Raises:
        IndexError: if x0 has fewer than num_fixed + num_optimized residues.
    """
    num_total = num_fixed + num_optimized
    if x0.shape[0] < num_total:
        # Without this check a short x0 could silently broadcast below.
        raise IndexError(f"x0 has {x0.shape[0]} residues; need at least {num_total}")

    def cross_dist(a, b):
        # a: [..., P, 3], b: [..., Q, 3] -> [..., P, Q] with entry [p, q] = |a[p] - b[q]|
        return (a.unsqueeze(-2) - b.unsqueeze(-3)).norm(p=2, dim=-1)

    # reshape (unlike view) also accepts flattened or non-contiguous input
    if xi_trainable.dim() != 3:
        xi_trainable = xi_trainable.reshape(num_optimized, atoms_per_residue, 3)

    # Reconstruct full xi; x0_all is cut to the same length so slices line up
    xi_all = torch.cat([x0[:num_fixed], xi_trainable], dim=0)
    x0_all = x0[:num_total]

    # 1. Intra-residue: unordered atom pairs (a1 < a2) within each optimized residue
    a1, a2 = torch.triu_indices(atoms_per_residue, atoms_per_residue, offset=1, device=xi_all.device)
    orig = (x0_all[num_fixed:, a1] - x0_all[num_fixed:, a2]).norm(p=2, dim=-1)
    curr = (xi_all[num_fixed:, a1] - xi_all[num_fixed:, a2]).norm(p=2, dim=-1)
    total_loss = (orig - curr).abs().sum()

    # 2. Consecutive residues (i, i + 1) with at least one optimized, i.e. i >= num_fixed - 1
    start = max(num_fixed - 1, 0)
    orig = cross_dist(x0_all[start:-1], x0_all[start + 1:])
    curr = cross_dist(xi_all[start:-1], xi_all[start + 1:])
    total_loss = total_loss + (orig - curr).abs().sum()

    # 3. Last fixed residue vs. the first few optimized residues; the extra
    #    weight adds stability at the fixed/optimized boundary
    if num_fixed > 0:
        last = num_fixed - 1
        stop = num_fixed + min(boundary_residues, num_optimized)
        orig = cross_dist(x0_all[last:num_fixed], x0_all[num_fixed:stop])
        curr = cross_dist(xi_all[last:num_fixed], xi_all[num_fixed:stop])
        total_loss = total_loss + boundary_weight * (orig - curr).abs().sum()

    return total_loss



def distance_ub_objective(x0, xi_trainable, num_fixed=16, num_optimized=8, atoms_per_residue=4,
                          *, max_dist=1.5, boundary_residues=3, boundary_weight=1.5):
    """
    Primal objective that enforces a fixed absolute upper bound on distances.
    Neighboring atoms are penalized for being more than ``max_dist`` apart.
    Unlike distance_upper_bound_objective, the bound does NOT come from x0.

    The current structure is ``x0[:num_fixed]`` followed by ``xi_trainable``.
    Every considered atom pair contributes ``relu(curr_dist - max_dist)``.
    Pairs considered:

    1. each unordered pair of atoms inside every optimized residue;
    2. every atom of residue i against every atom of residue i + 1, for all
       consecutive residues where at least one of the two is optimized;
    3. every atom of the last fixed residue against every atom of the first
       ``boundary_residues`` optimized residues, scaled by ``boundary_weight``.
       (The last-fixed / first-optimized pairs are therefore counted in both
       term 2 and term 3.)

    NOTE(review): this function was previously documented as using the original
    x0 distances as bounds, but its code always compared against the constant
    1.5. The x0 distances it computed were discarded. The constant-threshold
    behavior is kept here; confirm that this is intended.

    Args:
        x0: original representation [residues, atoms, xyz]; only the first
            num_fixed residues are read
        xi_trainable: trainable part of the representation, shaped
            [num_optimized, atoms, 3] or any shape with that many elements
            (e.g. 2-D or flattened), which is reshaped accordingly
        num_fixed: number of fixed residues
        num_optimized: number of optimized residues
        atoms_per_residue: number of atoms per residue
        max_dist: absolute distance threshold, in the same units as the
            coordinates (default 1.5)
        boundary_residues: number of leading optimized residues tied to the
            last fixed residue in term 3 (default 3)
        boundary_weight: weight applied to term 3 (default 1.5)

    Returns:
        0-d tensor with the summed, non-negative threshold violations.

    Raises:
        IndexError: if x0 has fewer than num_fixed residues.
    """
    relu = torch.nn.functional.relu
    if x0.shape[0] < num_fixed:
        # A short x0 would shift every optimized residue's index in xi_all.
        raise IndexError(f"x0 has {x0.shape[0]} residues; need at least {num_fixed}")

    def cross_dist(a, b):
        # a: [..., P, 3], b: [..., Q, 3] -> [..., P, Q] with entry [p, q] = |a[p] - b[q]|
        return (a.unsqueeze(-2) - b.unsqueeze(-3)).norm(p=2, dim=-1)

    # reshape (unlike view) also accepts flattened or non-contiguous input
    if xi_trainable.dim() != 3:
        xi_trainable = xi_trainable.reshape(num_optimized, atoms_per_residue, 3)

    # Reconstruct full xi for calculations
    xi_all = torch.cat([x0[:num_fixed], xi_trainable], dim=0)

    # 1. Intra-residue: unordered atom pairs (a1 < a2) within each optimized residue
    a1, a2 = torch.triu_indices(atoms_per_residue, atoms_per_residue, offset=1, device=xi_all.device)
    curr = (xi_all[num_fixed:, a1] - xi_all[num_fixed:, a2]).norm(p=2, dim=-1)
    total_loss = relu(curr - max_dist).sum()

    # 2. Consecutive residues (i, i + 1) with at least one optimized, i.e. i >= num_fixed - 1
    start = max(num_fixed - 1, 0)
    curr = cross_dist(xi_all[start:-1], xi_all[start + 1:])
    total_loss = total_loss + relu(curr - max_dist).sum()

    # 3. Last fixed residue vs. the first few optimized residues, with extra weight
    if num_fixed > 0:
        last = num_fixed - 1
        stop = num_fixed + min(boundary_residues, num_optimized)
        curr = cross_dist(xi_all[last:num_fixed], xi_all[num_fixed:stop])
        total_loss = total_loss + boundary_weight * relu(curr - max_dist).sum()

    return total_loss