from math import prod

import torch
from jaxtyping import Bool, Float
from scipy.spatial.transform import Rotation as Scipy_Rotation
from torch import Tensor

from openfold.np.residue_constants import (
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
    restype_atom14_to_rigid_group,
    restype_rigid_group_default_frame,
)
from openfold.utils.all_atom_multimer import atom14_to_atom37
from openfold.utils.feats import (
    frames_and_literature_positions_to_atom14_pos,
    torsion_angles_to_frames,
)
from openfold.utils.rigid_utils import Rigid, Rotation

# Number of Angstrom in one nanometre.
nm_to_ang_scale = 10.0


def ang_to_nm(trans):
    """Convert a length (Python number or tensor) from Angstrom to nanometres."""
    return trans / nm_to_ang_scale


def nm_to_ang(trans):
    """Convert a length (Python number or tensor) from nanometres to Angstrom."""
    return trans * nm_to_ang_scale


def get_atom37_ca_mask(n: int, device: torch.device) -> Bool[Tensor, "n 37"]:
    """Return an ``(n, 37)`` boolean atom37 mask selecting only the C-alpha slot.

    Args:
        n: Number of residues (rows of the mask).
        device: Device on which the mask is allocated.

    Returns:
        Bool tensor that is True in column 1 and False everywhere else.
    """
    mask = torch.zeros((n, 37), device=device, dtype=torch.bool)
    mask[:, 1] = True
    # The tensor is allocated as torch.bool, so no re-cast is needed.
    return mask


def get_atom37_bb3_mask(n: int, device: torch.device) -> Bool[Tensor, "n 37"]:
    """Return an ``(n, 37)`` boolean atom37 mask over the first three atom slots.

    Columns 0, 1 and 2 (N, CA, C in atom37 order) are True. All other columns
    are False.

    Args:
        n: Number of residues (rows of the mask).
        device: Device on which the mask is allocated.

    Returns:
        Bool tensor of shape ``(n, 37)``.
    """
    mask = torch.zeros((n, 37), device=device, dtype=torch.bool)
    mask[:, :3] = True
    # Already torch.bool; returning it directly avoids a redundant cast.
    return mask


def get_atom37_bb3o_mask(n: int, device: torch.device) -> Bool[Tensor, "n 37"]:
    """Return an ``(n, 37)`` boolean atom37 mask over N, CA, C and O.

    Columns 0, 1, 2 and 4 are True. Column 3 and all remaining columns are
    False.

    Args:
        n: Number of residues (rows of the mask).
        device: Device on which the mask is allocated.
    """
    selected_columns = [0, 1, 2, 4]
    bb_mask = torch.zeros((n, 37), dtype=torch.bool, device=device)
    bb_mask[:, selected_columns] = True
    return bb_mask


def trans_nm_and_rot_to_atom37(trans, rot, impute_ox=False):
    """Build atom37 coordinates from nanometre translations and rotations.

    Args:
        trans: Per-residue translations in nanometres. They are rescaled to
            Angstrom with ``nm_to_ang`` before decoding.
        rot: Rotation matrices, forwarded unchanged.
        impute_ox: Forwarded to ``trans_ang_and_rot_to_atom37``.

    Returns:
        The result of ``trans_ang_and_rot_to_atom37`` on the Angstrom
        translations.
    """
    trans_ang = nm_to_ang(trans)
    return trans_ang_and_rot_to_atom37(trans_ang, rot, impute_ox=impute_ox)


def trans_ang_and_rot_to_atom37(trans, rot, impute_ox=False):
    """Build atom37 coordinates from Angstrom translations and rotation matrices.

    Args:
        trans: Per-residue translations (Angstrom), used as the frame origins.
        rot: Rotation matrices, wrapped as OpenFold ``Rotation(rot_mats=...)``.
        impute_ox: Forwarded to ``openfold_bb_frames_to_atom37``.

    Returns:
        The output of ``openfold_bb_frames_to_atom37`` for the assembled frames.
    """
    rotation = Rotation(rot_mats=rot, quats=None)
    frames = Rigid(rotation, trans)
    return openfold_bb_frames_to_atom37(frames, impute_ox=impute_ox)


def trans_nm_to_atom37(ca_coors_nm):
    """Embed nanometre C-alpha coordinates into an atom37 tensor (in Angstrom).

    Args:
        ca_coors_nm: C-alpha coordinates in nanometres.

    Returns:
        The output of ``trans_ang_to_atom37`` on the coordinates after
        conversion to Angstrom with ``nm_to_ang``.
    """
    ca_coors_ang = nm_to_ang(ca_coors_nm)
    return trans_ang_to_atom37(ca_coors_ang)


def trans_ang_to_atom37(ca_coors):
    """Embed C-alpha coordinates into an otherwise all-zero atom37 tensor.

    Args:
        ca_coors: Tensor of shape ``(..., D)``, one coordinate row per residue
            (presumably D == 3).

    Returns:
        Tensor of shape ``(..., 37, D)`` with the same dtype and device. Atom
        slot 1 holds ``ca_coors`` and every other slot is zero.
    """
    coor_shape = ca_coors.shape
    # new_zeros inherits dtype and device from ca_coors.
    atom37 = ca_coors.new_zeros((*coor_shape[:-1], 37, coor_shape[-1]))
    atom37[..., 1, :] = ca_coors
    return atom37


def openfold_bb_frames_to_atom37(frames, impute_ox=False):
    """Decode per-residue backbone frames into atom37 coordinates via OpenFold.

    The frames pass through OpenFold's torsion-angle / literature-position
    machinery with placeholder torsions and residue types. The resulting atom14
    coordinates are then scattered into the atom37 layout.

    Args:
        frames: OpenFold ``Rigid`` with one frame per residue. ``frames.shape``
            is used as the residue batch shape.
        impute_ox: If True, the oxygen slot is recomputed geometrically by
            ``batch_adjust_oxygen_pos`` on a copy of the decoded coordinates.

    Returns:
        The atom37 coordinate tensor returned (element 0) by ``atom14_to_atom37``,
        presumably of shape ``(*frames.shape, 37, 3)`` (confirm against
        OpenFold). Oxygens are imputed when ``impute_ox`` is True.

    Note:
        The placeholder torsions come from torch's RNG. Any output that depends
        on them is not reproducible across calls unless the RNG is seeded.
    """
    device = frames.device
    dtype = frames.dtype

    # OpenFold residue-constant lookup tables, built on the frames' device.
    default_frames = torch.tensor(
        restype_rigid_group_default_frame, dtype=dtype, device=device
    )
    group_idx = torch.tensor(restype_atom14_to_rigid_group, device=device)
    atom_mask = torch.tensor(restype_atom14_mask, dtype=dtype, device=device)
    lit_positions = torch.tensor(
        restype_atom14_rigid_group_positions, dtype=dtype, device=device
    )

    # Rebuild the frames with an explicit rotation-matrix representation.
    backb_to_global = Rigid(
        Rotation(rot_mats=frames.get_rots().get_rot_mats(), quats=None),
        frames.get_trans(),
    )

    # Placeholder torsions: every entry is ~1.0 plus tiny Gaussian noise.
    angles = torch.randn(frames.shape + (7, 2), device=device) * 0.001 + 1.0

    # Placeholder residue type (index 1 everywhere). It is allocated on the
    # frames' device so the table lookups in the OpenFold helpers do not index
    # device-resident tables with a CPU tensor.
    aatype = torch.ones(frames.shape, dtype=torch.long, device=device)

    all_frames_to_global = torsion_angles_to_frames(
        backb_to_global, angles, aatype, default_frames
    )
    coords_atom14 = frames_and_literature_positions_to_atom14_pos(
        all_frames_to_global,
        aatype,
        default_frames,
        group_idx,
        atom_mask,
        lit_positions,
    )

    # The atom14 -> atom37 scatter uses residue type 0 everywhere, not the
    # type-1 placeholder above. NOTE(review): presumably deliberate, so only
    # atoms defined for residue type 0 are kept. Confirm.
    atom37_aatype = torch.zeros(frames.shape, dtype=torch.int32, device=device)
    coords_atom37 = atom14_to_atom37(coords_atom14, atom37_aatype)[0]

    if not impute_ox:
        return coords_atom37
    # Clone because the oxygen-imputation routine writes into its input.
    return batch_adjust_oxygen_pos(coords_atom37.clone())


def batch_adjust_oxygen_pos(atom_37):
    """Impute carbonyl-oxygen positions for every structure in a batch.

    Args:
        atom_37: Tensor of shape ``(B, N, 37, 3)``. A single unbatched
            ``(N, 37, 3)`` structure is also accepted and handled directly.

    Returns:
        For batched input, a new tensor stacked from the per-structure results
        of ``adjust_oxygen_pos``. For unbatched input, whatever
        ``adjust_oxygen_pos`` returns. An empty batch (``B == 0``) yields an
        empty copy instead of failing inside ``torch.stack``.

    Note:
        Each structure is passed to ``adjust_oxygen_pos`` as a view of
        ``atom_37``, so any in-place writes it performs also land in
        ``atom_37``.
    """
    assert atom_37.ndim in (3, 4), (
        f"expected (B, N, 37, 3) or (N, 37, 3), got shape {tuple(atom_37.shape)}"
    )
    if atom_37.ndim == 3:
        return adjust_oxygen_pos(atom_37)
    if atom_37.shape[0] == 0:
        # torch.stack rejects an empty sequence, and there is nothing to adjust.
        return atom_37.clone()
    # Index explicitly instead of iterating the tensor, which unbinds it.
    # Autograd can reject in-place writes to unbind outputs, while the select
    # views produced by atom_37[i] allow them.
    return torch.stack(
        [adjust_oxygen_pos(atom_37[i]) for i in range(atom_37.shape[0])]
    )


def adjust_oxygen_pos(atom_37: torch.Tensor, pos_is_known=None) -> torch.Tensor:
    """Place carbonyl oxygens (atom37 slot 4) geometrically, writing in place.

    Slot usage: 0 = N, 1 = CA, 2 = C, 4 = O. Every O is put 1.23 units
    (presumably Angstrom) from its residue's C, along one of two directions:

    * interior rule (residue i whose successor is known): the normalised sum
      of the unit vectors CA(i)->C(i) and N(i+1)->C(i);
    * fallback rule (the last residue, or residue i whose successor is flagged
      unknown): the normalised sum of the unit vectors CA(i)->C(i) and
      CA(i)->N(i).

    Args:
        atom_37: Tensor of shape ``(N, 37, 3)``. It is modified in place.
        pos_is_known: Optional length-N per-residue flags (nonzero = known).
            Defaults to all residues known.

    Returns:
        ``atom_37`` itself, with slot 4 rewritten.
    """
    n_res = atom_37.shape[0]
    assert atom_37.shape == (n_res, 37, 3)

    def unit(vec: torch.Tensor) -> torch.Tensor:
        # Row-wise normalisation; the epsilon keeps zero-length rows finite.
        return vec / (torch.norm(vec, keepdim=True, dim=1) + 1e-7)

    # Interior rule, applied to every residue that has a successor.
    interior_dir = unit(
        unit(atom_37[:-1, 2, :] - atom_37[:-1, 1, :])
        + unit(atom_37[:-1, 2, :] - atom_37[1:, 0, :])
    )
    atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + interior_dir * 1.23

    # Fallback rule, using only residue i's own N, CA and C.
    fallback_dir = unit(
        unit(atom_37[:, 2, :] - atom_37[:, 1, :])
        + unit(atom_37[:, 0, :] - atom_37[:, 1, :])
    )

    if pos_is_known is None:
        pos_is_known = torch.ones(
            (n_res,), dtype=torch.int64, device=atom_37.device
        )

    # Residue i needs the fallback when residue i+1 is unknown. The final
    # residue has no successor, so it always does.
    successor_missing = torch.cat(
        [
            ~pos_is_known.bool(),
            torch.ones((1,), device=pos_is_known.device).bool(),
        ],
        dim=0,
    )[1:]

    atom_37[successor_missing, 4, :] = (
        atom_37[successor_missing, 2, :]
        + fallback_dir[successor_missing, :] * 1.23
    )
    return atom_37


def sample_uniform_rotation(
    shape=tuple(), dtype=None, device=None
) -> Float[Tensor, "*batch 3 3"]:
    """Draw random rotation matrices with SciPy's ``Rotation.random``.

    Args:
        shape: Batch shape of the result. The default ``()`` yields a single
            ``(3, 3)`` matrix.
        dtype: Target torch dtype. ``None`` keeps the dtype inferred from
            SciPy's NumPy output.
        device: Target torch device.

    Returns:
        Tensor of shape ``(*shape, 3, 3)``.

    NOTE(review): sampling goes through SciPy, so it is presumably driven by
    NumPy's RNG state rather than torch's. Confirm if reproducibility matters.
    """
    num_rotations = prod(shape)
    matrices = Scipy_Rotation.random(num_rotations).as_matrix()
    rotations = torch.tensor(matrices, device=device, dtype=dtype)
    return rotations.reshape(*shape, 3, 3)
