import torch
import torch.nn.functional as F

from interfacediff.utils.protein.constants import (
    BBHeavyAtom, 
    backbone_atom_coordinates_tensor,
    bb_oxygen_coordinate_tensor,
)
from .topology import get_terminus_flag


def safe_norm(x, dim=-1, keepdim=False, eps=1e-8, sqrt=True):
    """Return the Euclidean norm of ``x`` along ``dim`` with a lower bound.

    The sum of squares is clamped to at least ``eps`` *before* the square
    root is taken, so the result is never smaller than ``sqrt(eps)``.

    Args:
        x: Input tensor.
        dim: Dimension(s) to reduce over.
        keepdim: Whether the reduced dimension is kept with size 1.
        eps: Minimum value allowed for the squared norm.
        sqrt: If False, return the clamped squared norm instead of the norm.

    Returns:
        The clamped norm (or clamped squared norm when ``sqrt`` is False).
    """
    squared = x.square().sum(dim=dim, keepdim=keepdim).clamp(min=eps)
    if sqrt:
        return squared.sqrt()
    return squared


def normalize_vector(v, dim, eps=1e-6):
    """Scale ``v`` to (approximately) unit 2-norm along ``dim``.

    ``eps`` is added to the norm before dividing, so an all-zero vector maps
    to zeros instead of producing a division by zero.

    Args:
        v: Input tensor.
        dim: Dimension along which the norm is computed.
        eps: Small constant added to the norm.

    Returns:
        A tensor with the same shape as ``v``.
    """
    length = torch.linalg.norm(v, ord=2, dim=dim, keepdim=True)
    denominator = length + eps
    return v / denominator



def local_to_global(R, t, p):
    """Map points from per-residue local frames to global coordinates.

    Computes ``q = R @ p + t`` for every point, where the first two
    dimensions of ``p`` index the frame and all remaining dimensions (except
    the trailing xyz axis) index points inside that frame.

    Args:
        R: Rotation matrices; must be matmul-compatible with an
            ``(N, L, 3, K)`` tensor (presumably ``(N, L, 3, 3)``).
        t: Translations; ``t.unsqueeze(-1)`` must broadcast against
            ``(N, L, 3, K)`` (presumably ``(N, L, 3)``).
        p: Local coordinates of shape ``(N, L, ..., 3)``.

    Returns:
        Global coordinates with the same shape as ``p``.
    """
    assert p.size(-1) == 3
    p_size = p.size()
    N, L = p_size[0], p_size[1]

    # ``reshape`` (rather than ``view``) so that non-contiguous inputs, e.g.
    # the result of a transpose/permute, are accepted; this mirrors what
    # ``global_to_local`` does for its input.
    p = p.reshape(N, L, -1, 3).transpose(-1, -2)    # (N, L, 3, K)
    q = torch.matmul(R, p) + t.unsqueeze(-1)        # (N, L, 3, K)
    q = q.transpose(-1, -2).reshape(p_size)         # back to the input shape
    return q


def global_to_local(R, t, q):
    """Map points from global coordinates into per-residue local frames.

    Computes ``p = R^T @ (q - t)`` for every point; the first two dimensions
    of ``q`` index the frame and the trailing dimension holds xyz.

    Args:
        R: Rotation matrices; presumably ``(N, L, 3, 3)``.
        t: Translations; presumably ``(N, L, 3)``.
        q: Global coordinates of shape ``(N, L, ..., 3)``.

    Returns:
        Local coordinates with the same shape as ``q``.
    """
    assert q.size(-1) == 3
    original_shape = q.size()
    batch, length = original_shape[0], original_shape[1]

    # Lay the points out as columns: (N, L, 3, K).
    columns = q.reshape(batch, length, -1, 3).transpose(-1, -2)
    centered = columns - t[..., None]
    local_columns = R.transpose(-1, -2) @ centered
    return local_columns.transpose(-1, -2).reshape(original_shape)


def apply_rotation_to_vector(R, p):
    """Rotate vectors ``p`` by ``R`` without applying any translation.

    Delegates to ``local_to_global`` with a zero translation.

    Args:
        R: Rotation matrices, as accepted by ``local_to_global``.
        p: Vectors of shape ``(N, L, ..., 3)``.

    Returns:
        The rotated vectors, same shape as ``p``.
    """
    # One zero translation per (N, L) frame. Using ``zeros_like(p)`` would
    # carry along any extra point dimensions of ``p`` and fail to broadcast
    # inside ``local_to_global`` for inputs such as (N, L, A, 3). For the
    # plain (N, L, 3) case this is identical to ``zeros_like(p)``.
    t = p.new_zeros(p.shape[:2] + (3,))
    return local_to_global(R, t, p)


def compose_rotation_and_translation(R1, t1, R2, t2):
    """Compose two rigid transforms, applying ``(R2, t2)`` first.

    The result ``(R, t)`` satisfies ``R x + t == R1 (R2 x + t2) + t1``.

    Args:
        R1: Outer rotation matrices, ``(..., 3, 3)``.
        t1: Outer translations, ``(..., 3)``.
        R2: Inner rotation matrices, ``(..., 3, 3)``.
        t2: Inner translations, ``(..., 3)``.

    Returns:
        Tuple ``(R_new, t_new)`` with ``R_new = R1 @ R2`` and
        ``t_new = R1 @ t2 + t1``.
    """
    rotated_t2 = (R1 @ t2[..., None])[..., 0]
    return R1 @ R2, rotated_t2 + t1


def compose_chain(Ts):
    """Collapse a list of rigid transforms into a single ``(R, t)`` pair.

    The last two entries are repeatedly merged with
    ``compose_rotation_and_translation`` until one remains, so the chain is
    folded from the right: ``Ts[0] o (Ts[1] o (... o Ts[-1]))``. The input
    list itself is not modified.

    Args:
        Ts: Non-empty list of ``(R, t)`` tuples.

    Returns:
        The composed ``(R, t)`` tuple; for a single-element list, that
        element unchanged.
    """
    while len(Ts) > 1:
        (R_outer, t_outer), (R_inner, t_inner) = Ts[-2], Ts[-1]
        merged = compose_rotation_and_translation(R_outer, t_outer, R_inner, t_inner)
        Ts = Ts[:-2] + [merged]
    return Ts[0]







def quaternion_to_rotation_matrix(quaternions):
    """Convert quaternions (real part first) to 3x3 rotation matrices.

    The input is normalised to unit length first, so unnormalised
    quaternions are accepted.

    Args:
        quaternions: Tensor of shape ``(..., 4)`` ordered ``(r, i, j, k)``.

    Returns:
        Rotation matrices of shape ``(..., 3, 3)``.
    """
    quaternions = F.normalize(quaternions, dim=-1)
    w, x, y, z = quaternions.unbind(dim=-1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    # Row-major entries of the rotation matrix.
    entries = [
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    ]
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))



def _hat(v):
    """Return the skew-symmetric cross-product matrix of ``v``.

    For ``v`` of shape ``(..., 3)`` the result has shape ``(..., 3, 3)`` and
    satisfies ``_hat(v) @ w == cross(v, w)``.
    """
    vx, vy, vz = v.unbind(dim=-1)
    O = torch.zeros_like(vx)
    return torch.stack([
        O,   -vz,  vy,
        vz,   O,  -vx,
        -vy,  vx,  O
    ], dim=-1).reshape(v.shape[:-1] + (3, 3))


def so3_exp_map(delta, eps=1e-8):
    """Exponential map from axis-angle vectors to rotation matrices.

    Uses Rodrigues' formula in its un-normalised form,

        R = I + (sin(theta) / theta) * K + ((1 - cos(theta)) / theta^2) * K @ K

    with ``theta = ||delta||`` and ``K = _hat(delta)``.

    Args:
        delta: Axis-angle vectors of shape ``(..., 3)``; the direction is the
            rotation axis and the length is the angle in radians.
        eps: Lower bound applied to ``theta`` in the denominators so that a
            zero vector yields the identity rather than a division by zero.

    Returns:
        Rotation matrices of shape ``(..., 3, 3)``.
    """
    assert delta.shape[-1] == 3, f"delta last dim must be 3, got {delta.shape}"

    theta = torch.linalg.norm(delta, dim=-1, keepdim=True)      # (..., 1)
    theta_safe = theta.clamp_min(eps)

    # K is built from the *un-normalised* vector: the 1/theta and 1/theta^2
    # factors below already account for the axis normalisation, so
    # normalising delta as well would scale the result twice.
    K = _hat(delta)                                             # (..., 3, 3)

    # theta already carries a trailing singleton dim, so exactly one more
    # axis is needed to broadcast against (..., 3, 3).
    A = (torch.sin(theta) / theta_safe).unsqueeze(-1)           # (..., 1, 1)
    # (1 - cos(theta)) / theta^2 written as 2 * (sin(theta / 2) / theta)^2,
    # which avoids cancellation for small angles.
    half = torch.sin(0.5 * theta) / theta_safe
    B = (2.0 * half * half).unsqueeze(-1)                       # (..., 1, 1)

    I = torch.eye(3, device=delta.device, dtype=delta.dtype)    # broadcasts over batch

    R = I + A * K + B * (K @ K)
    return R









def quaternion_1ijk_to_rotation_matrix(q):
    """Convert ``(1 + bi + cj + dk)`` quaternions to rotation matrices.

    Only the three imaginary components are given; the real part is fixed to
    1 and the full quaternion is normalised before conversion.

    Args:
        q: Tensor of shape ``(..., 3)`` holding the ``(b, c, d)`` components.

    Returns:
        Rotation matrices of shape ``(..., 3, 3)``.
    """
    x, y, z = q.unbind(dim=-1)
    length = torch.sqrt(1 + x**2 + y**2 + z**2)
    w, x, y, z = 1 / length, x / length, y / length, z / length

    # Row-major entries of the rotation matrix of the unit quaternion.
    row0 = (w**2 + x**2 - y**2 - z**2, 2*x*y - 2*w*z, 2*x*z + 2*w*y)
    row1 = (2*x*y + 2*w*z, w**2 - x**2 + y**2 - z**2, 2*y*z - 2*w*x)
    row2 = (2*x*z - 2*w*y, 2*y*z + 2*w*x, w**2 - x**2 - y**2 + z**2)

    flat = torch.stack(row0 + row1 + row2, dim=-1)
    return flat.reshape(q.shape[:-1] + (3, 3))



def dihedral_from_four_points(p0, p1, p2, p3):
    """Signed dihedral angle defined by four points, around the p1-p2 axis.

    Args:
        p0, p1, p2, p3: Point coordinates of shape ``(..., 3)``.

    Returns:
        Dihedral angles in radians, shape ``(...)``. Degenerate inputs whose
        plane normals are undefined (e.g. collinear points) yield 0.
    """
    v0 = p2 - p1
    v1 = p0 - p1
    v2 = p3 - p2
    # Unit normals of the (p0, p1, p2) and (p1, p2, p3) planes. A zero-length
    # cross product gives NaN here, which is mapped to 0 at the end.
    u1 = torch.cross(v0, v1, dim=-1)
    n1 = u1 / torch.linalg.norm(u1, dim=-1, keepdim=True)
    u2 = torch.cross(v0, v2, dim=-1)
    n2 = u2 / torch.linalg.norm(u2, dim=-1, keepdim=True)
    sgn = torch.sign((torch.cross(v1, v2, dim=-1) * v0).sum(-1))
    # For exactly coplanar points the triple product is 0 and ``sign`` returns
    # 0, which would collapse a planar trans arrangement (angle pi) to 0.
    # Treat a zero sign as positive so the magnitude from ``acos`` survives.
    # NaN signs are left untouched and still end up as 0 below.
    sgn = torch.where(sgn == 0, torch.ones_like(sgn), sgn)
    # The cosine is clamped slightly inside [-1, 1] before ``acos``.
    dihed = sgn * torch.acos((n1 * n2).sum(-1).clamp(min=-0.999999, max=0.999999))
    dihed = torch.nan_to_num(dihed)
    return dihed



def get_backbone_dihedral_angles(pos_atoms, chain_nb, res_nb, mask):
    """Compute the omega, phi and psi backbone dihedrals of every residue.

    Each angle is evaluated between consecutive positions along dimension 1
    and padded with a zero so that all three have length ``L``: omega and phi
    are padded at the front, psi at the end. Angles are multiplied by their
    mask, so masked entries come out as 0.

    Args:
        pos_atoms: Atom coordinates indexed as ``[:, :, atom]`` with
            ``BBHeavyAtom`` indices; presumably ``(N, L, A, 3)``.
        chain_nb: Passed through to ``get_terminus_flag`` (presumably chain
            indices per residue).
        res_nb: Passed through to ``get_terminus_flag`` (presumably residue
            numbers).
        mask: Passed through to ``get_terminus_flag`` (presumably a residue
            validity mask).

    Returns:
        Tuple ``(bb_dihedral, mask_bb_dihed)`` whose last dimension is ordered
        ``(omega, phi, psi)``.
    """
    coord_N, coord_CA, coord_C = (
        pos_atoms[:, :, atom]
        for atom in (BBHeavyAtom.N, BBHeavyAtom.CA, BBHeavyAtom.C)
    )

    # omega/phi are undefined at an N-terminus, psi at a C-terminus.
    N_term_flag, C_term_flag = get_terminus_flag(chain_nb, res_nb, mask)
    not_N_term = torch.logical_not(N_term_flag)
    not_C_term = torch.logical_not(C_term_flag)
    mask_bb_dihed = torch.stack([not_N_term, not_N_term, not_C_term], dim=-1)

    # Atoms of position i ("prev") and position i + 1 ("next").
    N_prev, CA_prev, C_prev = coord_N[:, :-1], coord_CA[:, :-1], coord_C[:, :-1]
    N_next, CA_next, C_next = coord_N[:, 1:], coord_CA[:, 1:], coord_C[:, 1:]

    # omega and phi belong to the "next" residue -> pad at the front.
    omega = F.pad(
        dihedral_from_four_points(CA_prev, C_prev, N_next, CA_next),
        pad=(1, 0), value=0,
    )
    phi = F.pad(
        dihedral_from_four_points(C_prev, N_next, CA_next, C_next),
        pad=(1, 0), value=0,
    )
    # psi belongs to the "prev" residue -> pad at the end.
    psi = F.pad(
        dihedral_from_four_points(N_prev, CA_prev, C_prev, N_next),
        pad=(0, 1), value=0,
    )

    bb_dihedral = torch.stack([omega, phi, psi], dim=-1) * mask_bb_dihed
    return bb_dihedral, mask_bb_dihed


    

def reconstruct_backbone(R, t, aa, chain_nb, res_nb, mask):
    """Rebuild backbone atom positions from per-residue frames.

    Template coordinates are looked up per amino-acid type and placed with
    ``(R, t)``. The oxygen is placed afterwards: its template coordinate is
    additionally rotated about the local x-axis by the psi angle computed
    from the placed atoms.

    Args:
        R: Frame rotations; presumably ``(N, L, 3, 3)``.
        t: Frame translations; presumably ``(N, L, 3)``. Template tensors are
            moved to its device/dtype.
        aa: Amino-acid type indices of shape ``(N, L)``; clamped to [0, 20].
        chain_nb, res_nb, mask: Forwarded to ``get_backbone_dihedral_angles``.

    Returns:
        Atom positions with the oxygen appended along dimension 2.
    """
    N, L = aa.size()
    aa_index = aa.clamp(min=0, max=20).flatten()

    # Per-residue template coordinates in the local frame.
    template_bb = backbone_atom_coordinates_tensor.clone().to(t)
    template_O = bb_oxygen_coordinate_tensor.clone().to(t)
    local_bb = template_bb[aa_index].reshape(N, L, -1, 3)
    local_O = template_O[aa_index].reshape(N, L, 1, 3)

    bb_pos = local_to_global(R, t, local_bb)

    # psi is the third entry of the (omega, phi, psi) stack.
    dihedrals, _ = get_backbone_dihedral_angles(bb_pos, chain_nb, res_nb, mask)
    psi = dihedrals[..., 2].reshape(N, L, 1, 1)
    sin_psi, cos_psi = torch.sin(psi), torch.cos(psi)
    zero, one = torch.zeros_like(sin_psi), torch.ones_like(sin_psi)

    # Rotation by psi about the local x-axis, assembled row by row.
    rows = [
        torch.cat([one, zero, zero], dim=-1),
        torch.cat([zero, cos_psi, -sin_psi], dim=-1),
        torch.cat([zero, sin_psi, cos_psi], dim=-1),
    ]
    R_psi = torch.cat(rows, dim=-2)     # (N, L, 3, 3)

    # Frame for the oxygen: the psi rotation applied first, then (R, t).
    R_O, t_O = compose_rotation_and_translation(R, t, R_psi, torch.zeros_like(t))
    O_pos = local_to_global(R_O, t_O, local_O)

    return torch.cat([bb_pos, O_pos], dim=2)
    

def reconstruct_backbone_partially(pos_ctx, R_new, t_new, aa, chain_nb, res_nb, mask_atoms, mask_recons):
    """Replace selected residues of ``pos_ctx`` with a rebuilt backbone.

    For residues where ``mask_recons`` is true, coordinates come from
    ``reconstruct_backbone(R_new, t_new, ...)`` (zero-padded along the atom
    axis up to ``A`` slots) and the atom mask is set to "first four atoms
    only". All other residues keep their values from ``pos_ctx`` and
    ``mask_atoms``.

    Args:
        pos_ctx: Existing atom coordinates; presumably ``(N, L, A, 3)``.
        R_new: Frame rotations for the rebuilt residues.
        t_new: Frame translations for the rebuilt residues.
        aa, chain_nb, res_nb: Forwarded to ``reconstruct_backbone``.
        mask_atoms: Atom mask of shape ``(N, L, A)``.
        mask_recons: Per-residue selector of shape ``(N, L)``.

    Returns:
        Tuple ``(pos_new, mask_new)``.
    """
    _, _, num_atoms = mask_atoms.size()

    # The CA slot of the atom mask is used as the residue mask.
    residue_mask = mask_atoms[:, :, BBHeavyAtom.CA]
    rebuilt = reconstruct_backbone(R_new, t_new, aa, chain_nb, res_nb, residue_mask)
    # Zero-pad the atom axis from the 4 rebuilt atoms up to num_atoms.
    rebuilt = F.pad(rebuilt, pad=(0, 0, 0, num_atoms - 4), value=0)

    use_rebuilt_pos = mask_recons[:, :, None, None].expand_as(pos_ctx)
    pos_new = torch.where(use_rebuilt_pos, rebuilt, pos_ctx)

    # Rebuilt residues expose only their first four atom slots.
    backbone_only = torch.zeros_like(mask_atoms)
    backbone_only[:, :, :4] = True
    use_rebuilt_mask = mask_recons[:, :, None].expand_as(mask_atoms)
    mask_new = torch.where(use_rebuilt_mask, backbone_only, mask_atoms)

    return pos_new, mask_new

