import torch
import torch.nn.functional as F
from SDE_model.rigid_utils import  Rigid


def loss_fn_tensor(pred_atom4, gt_atom4, t, mask):
    """Backbone auxiliary loss: masked atom MSE plus a local pairwise-distance loss.

    Args:
        pred_atom4: Predicted backbone atom coordinates, [bs, t, res, 4, 3].
        gt_atom4: Reference backbone atom coordinates, same shape as
            ``pred_atom4``.
        t: Diffusion time used to gate the two terms. A Python number or a
            tensor; ``t < cutoff`` must broadcast into a ``[bs]`` tensor
            because the gating is applied in place.
        mask: Residue mask of shape [bs, t, res]; non-zero marks residues
            that contribute to the loss.

    Returns:
        Tuple ``(final_loss, final_loss_ori)``, both of shape [bs].
        ``final_loss`` has the atom term zeroed where ``t >= 0.25`` and the
        distance term zeroed where ``t >= 0.1``; ``final_loss_ori`` is the
        same sum without any time gating. Both are scaled by 0.25.
    """
    # Constants of the loss, named once so the arithmetic below reads clearly.
    atom_term_weight = 1.0
    atom_term_t_cutoff = 0.25
    dist_term_weight = 1.0
    dist_term_t_cutoff = 0.1
    aux_scale = 0.25
    proximity_cutoff = 6  # only ground-truth pairs closer than this are scored

    n_batch, _, n_res = mask.shape

    # ---------------- backbone atom loss ----------------
    # Broadcast the residue mask over the 4 atoms x 3 coordinates.
    coord_mask = mask[..., None, None].expand(-1, -1, -1, 4, 3)

    sq_err = ((pred_atom4 - gt_atom4) ** 2) * coord_mask  # [bs, t, res, 4, 3]
    res_sq_err = sq_err.sum(dim=(-1, -2))  # [bs, t, res]
    coords_per_res = coord_mask.sum(dim=(-1, -2))  # [bs, t, res]
    res_loss = (res_sq_err / (coords_per_res + 1e-8)) * mask  # [bs, t, res]
    n_valid_res = mask.sum(dim=(-1, -2))  # [bs]
    bb_atom_loss = res_loss.sum(dim=(-1, -2)) / (n_valid_res + 1e-10)  # [bs]

    bb_atom_loss.mul_(atom_term_weight)
    # Snapshot the ungated value before the time gate is applied in place.
    bb_atom_loss_ori = bb_atom_loss * aux_scale
    bb_atom_loss.mul_(t < atom_term_t_cutoff)
    bb_atom_loss.mul_(aux_scale)

    # ---------------- pairwise distance loss ----------------
    n_time = gt_atom4.size(1)
    n_frames = n_batch * n_time
    n_atoms = n_res * 4

    def all_pair_dists(atom4):
        # [bs, t, res, 4, 3] -> [bs*t, res*4, 3] -> [bs*t, res*4, res*4]
        xyz = atom4.reshape([n_frames, n_atoms, 3])
        return torch.linalg.norm(
            xyz[:, :, None, :] - xyz[:, None, :, :], dim=-1)

    gt_dists = all_pair_dists(gt_atom4)
    pred_dists = all_pair_dists(pred_atom4)

    # Per-atom mask: every residue entry repeated for its 4 atoms.
    # [bs, t, res] -> [bs*t, res] -> [bs*t, res*4]
    frame_mask = mask.reshape([n_frames, n_res])
    atom_mask = torch.tile(frame_mask[:, :, None], (1, 1, 4))
    atom_mask = atom_mask.reshape([n_frames, n_atoms])

    # Zero the rows of masked atoms, then keep pairs where both atoms are
    # valid and the ground-truth distance is inside the proximity cutoff.
    gt_dists = gt_dists * atom_mask[..., None]
    pred_dists = pred_dists * atom_mask[..., None]
    pair_mask = atom_mask[..., None] * atom_mask[:, None, :]
    pair_mask = pair_mask * (gt_dists < proximity_cutoff)

    dist_mat_loss = torch.sum(
        (gt_dists - pred_dists) ** 2 * pair_mask,
        dim=(1, 2))
    # Normalise by the kept-pair count minus n_res. The quotient is
    # non-finite if that count happens to equal n_res.
    dist_mat_loss.div_(torch.sum(pair_mask, dim=(1, 2)) - n_res)

    # [bs*t] -> [bs, t] -> mean over time -> [bs]
    dist_mat_loss = dist_mat_loss.reshape([n_batch, n_time]).mean(dim=1)

    dist_mat_loss.mul_(dist_term_weight)
    dist_mat_loss_ori = aux_scale * dist_mat_loss
    dist_mat_loss.mul_(t < dist_term_t_cutoff)
    dist_mat_loss.mul_(aux_scale)

    final_loss = bb_atom_loss + dist_mat_loss
    final_loss_ori = bb_atom_loss_ori + dist_mat_loss_ori
    return final_loss, final_loss_ori

def get_auxiliary_loss(
    atom4_pred,  # [batch_size, timestep, seq_len, 4, 3]
    atom4_true,  # [batch_size, timestep, seq_len, 4, 3]
    mask):
    """Evaluate ``loss_fn_tensor`` at a fixed time of 0.

    Args:
        atom4_pred: Predicted backbone atom coordinates.
        atom4_true: Reference backbone atom coordinates.
        mask: Residue mask forwarded unchanged to ``loss_fn_tensor``.

    Returns:
        The ``(auxiliary_loss, auxiliary_loss_ori)`` pair produced by
        ``loss_fn_tensor``.
    """
    # t = 0 is below both time cutoffs used inside loss_fn_tensor, so neither
    # loss term is gated off.
    return loss_fn_tensor(atom4_pred, atom4_true, 0, mask)


def calc_violation_loss(atom_positions, residue_mask, clash_tol=1.2):
    """Sum the steric-clash penalties between atoms of different residues.

    For every (batch, time) frame, each pair of atoms that belong to two
    different residues is penalised by ``relu(clash_tol - distance)``. The
    penalties are summed over all frames; no normalisation by the number of
    pairs is applied.

    Args:
        atom_positions: Atom coordinates of shape
            [B, T, N_res, N_atoms, 3]. An atom whose three coordinates all
            have magnitude <= 1e-7 is treated as missing.
        residue_mask: Residue mask of shape [B, T, N_res]; non-zero marks
            residues that take part in the clash check.
        clash_tol: Distance below which two atoms count as clashing.

    Returns:
        A tensor of shape [1] holding the summed penalty when at least one
        valid atom pair exists; otherwise a 0-dim zero that is still
        connected to ``atom_positions`` in the autograd graph.
    """
    B, T, N_res, N_atoms, _ = atom_positions.shape
    # The original body referenced an undefined ``device`` name; take it from
    # the input so that all helper tensors live next to the coordinates.
    device = atom_positions.device

    # Atoms with all-(near-)zero coordinates are considered absent, and atoms
    # of masked-out residues are dropped as well.
    atom_mask = torch.any(torch.abs(atom_positions) > 1e-7, dim=-1)  # [B, T, N_res, N_atoms]
    atom_mask = atom_mask & residue_mask.bool().unsqueeze(-1)

    # Strict upper triangle over residues: each residue pair is counted once
    # and intra-residue pairs are excluded. Loop-invariant, so built once.
    tri_mask = torch.triu(
        torch.ones(N_res, N_res, device=device), diagonal=1).bool()
    tri_mask = tri_mask.unsqueeze(-1).unsqueeze(-1)  # [N_res, N_res, 1, 1]

    total_violations = torch.zeros(1, device=device)
    # Kept on-device so the loop does not force a host sync per frame.
    total_valid_pairs = torch.zeros((), dtype=torch.long, device=device)

    # Frames are processed one at a time to bound peak memory: each frame
    # materialises an [N_res, N_res, N_atoms, N_atoms] distance tensor.
    for b in range(B):
        for t in range(T):
            frame_xyz = atom_positions[b, t]  # [N_res, N_atoms, 3]
            frame_mask = atom_mask[b, t]  # [N_res, N_atoms]

            coords_i = frame_xyz.unsqueeze(1).unsqueeze(3)  # [N_res, 1, N_atoms, 1, 3]
            coords_j = frame_xyz.unsqueeze(0).unsqueeze(2)  # [1, N_res, 1, N_atoms, 3]
            # The epsilon keeps the sqrt gradient finite at zero distance.
            distances = torch.sqrt(
                torch.sum((coords_i - coords_j) ** 2, dim=-1) + 1e-10)

            mask_i = frame_mask.unsqueeze(1).unsqueeze(3)  # [N_res, 1, N_atoms, 1]
            mask_j = frame_mask.unsqueeze(0).unsqueeze(2)  # [1, N_res, 1, N_atoms]
            final_mask = (mask_i & mask_j) & tri_mask

            overlap = torch.nn.functional.relu(clash_tol - distances)
            total_violations = total_violations + (overlap * final_mask.float()).sum()
            total_valid_pairs = total_valid_pairs + final_mask.sum()

    if total_valid_pairs.item() > 0:
        return total_violations
    # No valid pair anywhere: return a zero that still depends on the input so
    # a backward pass through this loss does not fail.
    return atom_positions.sum() * 0.0