import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from project.utils.deepinteract_constants import DEFAULT_MISSING_NORM_VEC


def min_max_normalize_tensor(tensor: torch.Tensor, device=None):
    """Linearly rescale ``tensor`` so its values span [0, 1].

    Args:
        tensor: tensor of any shape; the min and max are taken over all elements.
        device: optional device for the result. When ``None`` the input's device is kept.

    Returns:
        Tensor of the same shape. A constant tensor (max == min) maps to all zeros
        instead of the NaNs that 0 / 0 would produce. An empty tensor is returned as an
        (empty) copy instead of raising.
    """
    tensor = torch.as_tensor(tensor, device=device)
    if tensor.numel() == 0:
        return tensor.clone()
    min_value = tensor.min()
    value_range = tensor.max() - min_value
    # Vectorized instead of a Python loop over elements.
    normalized = (tensor - min_value) / value_range
    if value_range == 0:
        # No spread to normalize: avoid returning NaN from 0 / 0.
        normalized = torch.zeros_like(normalized)
    return normalized



def gather_nodes(nodes, neighbor_idx):
    """Look up the node features of every neighbor index.

    Args:
        nodes: node features, [B, N, C].
        neighbor_idx: neighbor indices into the node axis, [B, N, K].

    Returns:
        Neighbor features of shape [B, N, K, C].
    """
    batch_size = neighbor_idx.shape[0]
    # [B, N, K] -> [B, N*K] so one gather along the node axis covers every neighbor
    flat_idx = neighbor_idx.view(batch_size, -1)
    num_channels = nodes.size(2)
    # Broadcast each index across the channel axis: [B, N*K] -> [B, N*K, C]
    flat_idx = flat_idx.unsqueeze(-1).expand(-1, -1, num_channels)
    gathered = torch.gather(nodes, 1, flat_idx)
    # Restore the per-node neighbor axis: [B, N*K, C] -> [B, N, K, C]
    target_shape = list(neighbor_idx.shape)[:3] + [-1]
    return gathered.view(target_shape)


def gather_edges(edges, neighbor_idx):
    """Select, for each node, the edge features toward its listed neighbors.

    Args:
        edges: dense edge features, [B, N, N, C].
        neighbor_idx: neighbor indices, [B, N, K].

    Returns:
        Edge features of shape [B, N, K, C].
    """
    num_channels = edges.size(-1)
    # Repeat every index over the channel axis so gather picks whole feature vectors
    index = neighbor_idx.unsqueeze(-1).expand(-1, -1, -1, num_channels)
    return torch.gather(edges, 2, index)



class GeometricProteinFeatures(nn.Module):
    """Extract geometric features of proteins.

    Produces backbone dihedrals, RBF-expanded distances, neighbor orientation
    features and angles between per-residue normals. The module registers no
    parameters or sub-modules.
    """

    def __init__(self,
        num_pos_embed: int = 20,
        num_rbf: int = 18,
        dropout_rate: float = 0.1,
    ):
        """Store feature hyper-parameters.

        Args:
            num_pos_embed: stored as ``self.num_pos_embed``; not read by any method of this class.
            num_rbf: number of radial basis centres used by ``compute_dist_rbf``.
            dropout_rate: stored as ``self.dropout``; not read by any method of this class.
        """
        super().__init__()

        self.num_pos_embed = num_pos_embed
        self.num_rbf = num_rbf
        self.dropout = dropout_rate

    def get_dihedrals(self, X: torch.Tensor, eps=1e-7):
        """Compute backbone dihedrals (phi, psi, omega) lifted onto the unit circle.

        Args:
            X: coordinates of shape [B, N_res, M_atom, 3] with the atom axis ordered
                (N, CA, C, O, CB); only N, CA and C are used.
            eps: margin keeping the cosine strictly inside (-1, 1) before ``acos``.

        Returns:
            Tensor of shape [B, N_res, 6]: cosines of the three per-residue angles
            followed by their sines. phi[0], psi[-1] and omega[-1] are undefined and
            are padded with angle 0.
        """
        batch_size, num_res = X.shape[:2]
        # Interleave backbone atoms into one chain: N0, CA0, C0, N1, CA1, C1, ...
        X = X[:, :, :3, :].reshape(batch_size, num_res * 3, 3)

        # Unit bond vectors and three consecutive shifted views of them
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]

        # Backbone normals. dim=-1 must be explicit: without it torch.cross uses the
        # first axis of size 3, which is the batch axis whenever B == 3.
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Signed angle between consecutive normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)

        # 3N - 3 torsions exist; pad to 3N so every residue gets (phi, psi, omega).
        # This scheme leaves phi[0], psi[-1], omega[-1] at 0.
        D = F.pad(D, (1, 2), 'constant', 0)
        D = D.view(D.size(0), D.size(1) // 3, 3)

        # Lift angle representations to the circle
        return torch.cat((torch.cos(D), torch.sin(D)), 2)

    def compute_dist_rbf(self, dist: torch.Tensor):
        """Expand distances with Gaussian radial basis functions.

        ``self.num_rbf`` centres are spaced evenly on [0, 20] and share the width
        (20 - 0) / num_rbf.

        Args:
            dist: distances of shape [B, N_res, K, 1]; the trailing singleton axis
                broadcasts against the RBF centres.

        Returns:
            Tensor of shape [B, N_res, K, num_rbf] with values in [0, 1].
        """
        D_min, D_max, D_count = 0., 20., self.num_rbf
        # Create the centres on dist's device so GPU inputs do not hit a device mismatch.
        D_mu = torch.linspace(D_min, D_max, D_count, device=dist.device).view(1, 1, 1, -1)
        D_sigma = (D_max - D_min) / D_count
        return torch.exp(-((dist - D_mu) / D_sigma) ** 2)

    def convert_rotations_into_quaternions(self, R: torch.Tensor):
        """Convert rotation matrices to unit quaternions ordered (x, y, z, w).

        Args:
            R: rotation matrices of shape [..., 3, 3], e.g. [B, N_res, K, 3, 3].

        Returns:
            Tensor of shape [..., 4], normalized along the last axis.
        """
        # For the simple Wikipedia version, see: en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        # For other options, see math.stackexchange.com/questions/2074316/calculating-rotation-axis-from-rotation-matrix
        diag = torch.diagonal(R, dim1=-2, dim2=-1)
        Rxx, Ryy, Rzz = diag.unbind(-1)
        # |x|, |y|, |z| from the diagonal; abs guards against small negative round-off
        magnitudes = 0.5 * torch.sqrt(torch.abs(1 + torch.stack([
            Rxx - Ryy - Rzz,
            - Rxx + Ryy - Rzz,
            - Rxx - Ryy + Rzz
        ], -1)))
        # Signs of x, y, z come from the antisymmetric part of R
        signs = torch.sign(torch.stack([
            R[..., 2, 1] - R[..., 1, 2],
            R[..., 0, 2] - R[..., 2, 0],
            R[..., 1, 0] - R[..., 0, 1],
        ], -1))
        xyz = signs * magnitudes
        # The relu enforces a non-negative trace so the real component w stays real
        w = torch.sqrt(F.relu(1 + diag.sum(-1, keepdim=True))) / 2.
        Q = torch.cat((xyz, w), -1)
        return F.normalize(Q, dim=-1)

    def get_coarse_orientation_feats(self, X: torch.Tensor, E_idx: torch.Tensor, eps=1e-6):
        """Derive per-node angle features and per-edge orientation features.

        Args:
            X: coordinates of one central atom per residue (e.g. CA), [B, N_res, 3].
            E_idx: neighbor indices, [B, N_res, K].
            eps: margin keeping cosines strictly inside (-1, 1) before ``acos``.

        Returns:
            AD_features: [B, N_res, 3] bond-angle / dihedral features (zero-padded at
                the chain ends).
            O_features: [B, N_res, K, 7] = unit direction to each neighbor in the node's
                local frame (3) concatenated with the relative-rotation quaternion (4).
        """
        # Shifted slices of unit vectors
        dX = X[:, 1:, :] - X[:, :-1, :]
        U = F.normalize(dX, dim=-1)
        u_2 = U[:, :-2, :]
        u_1 = U[:, 1:-1, :]
        u_0 = U[:, 2:, :]
        # Backbone normals (dim=-1 explicit so a batch of size 3 is not crossed over)
        n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
        n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)

        # Bond angle calculation
        cosA = -(u_1 * u_0).sum(-1)
        cosA = torch.clamp(cosA, -1 + eps, 1 - eps)
        A = torch.acos(cosA)
        # Angle between normals
        cosD = (n_2 * n_1).sum(-1)
        cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
        D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
        # Backbone features; pad the N - 3 interior values back to N nodes
        AD_features = torch.stack(
            (torch.cos(A), torch.sin(A) * torch.cos(D), torch.sin(A) * torch.sin(D)), 2
        )
        AD_features = F.pad(AD_features, (0, 0, 1, 2), 'constant', 0)

        # Build relative orientations: one 3x3 frame per node, flattened to 9 values
        o_1 = F.normalize(u_2 - u_1, dim=-1)
        O = torch.stack((o_1, n_2, torch.cross(o_1, n_2, dim=-1)), 2)
        O = O.view(list(O.shape[:2]) + [9])
        O = F.pad(O, (0, 0, 1, 2), 'constant', 0)

        O_neighbors = gather_nodes(O, E_idx)
        X_neighbors = gather_nodes(X, E_idx)

        # Re-view as rotation matrices
        O = O.view(list(O.shape[:2]) + [3, 3])  # O can map from a global ref. frame to a node's local ref. frame
        O_neighbors = O_neighbors.view(list(O_neighbors.shape[:3]) + [3, 3])

        # Rotate into local reference frames
        dX = X_neighbors - X.unsqueeze(-2)
        dU = torch.matmul(O.unsqueeze(2), dX.unsqueeze(-1)).squeeze(-1)
        # dU represents the relative direction to neighboring node j from node i's ref. frame
        dU = F.normalize(dU, dim=-1)
        R = torch.matmul(O.unsqueeze(2).transpose(-1, -2), O_neighbors)
        Q = self.convert_rotations_into_quaternions(R)

        O_features = torch.cat((dU, Q), dim=-1)

        return AD_features, O_features

    def get_amide_angles(self, X: torch.Tensor, edge_ids: torch.Tensor, edges):
        """Angle between the per-residue normals at both ends of each edge, min-max scaled.

        The normal of each residue is ``cross(CA - CB, CB - N)``, with the atom axis
        ordered (N, CA, C, O, CB). For every edge the angle between the source and
        destination normals is computed. NaNs (e.g. from zero-length normals) are
        replaced by 0, and the angles are min-max normalized over all edges.

        Args:
            X: coordinates of shape [B, N_res, M_atom, 3].
            edge_ids: unused; kept for interface compatibility.
            edges: ``(source_indices, destination_indices)`` — assumed to index the
                residue axis of X; TODO confirm against the caller.

        Returns:
            Tensor of shape [B * E, 1] (i.e. [E, 1] for a single structure).
        """
        ca_minus_cb = X[:, :, 1] - X[:, :, 4]
        cb_minus_n = X[:, :, 4] - X[:, :, 0]
        normals = torch.cross(ca_minus_cb, cb_minus_n, dim=-1).float()

        src_vec = normals[:, edges[0], :]
        dst_vec = normals[:, edges[1], :]

        # Vectorized over all edges (works on any device, and for a single edge)
        cos_angle = (src_vec * dst_vec).sum(-1) / (
            torch.linalg.norm(src_vec, dim=-1) * torch.linalg.norm(dst_vec, dim=-1)
        )
        # Clamp so rounding on (anti)parallel normals cannot push acos out of its
        # domain; NaN from zero-length normals passes through clamp and is zeroed below.
        angles = torch.acos(torch.clamp(cos_angle, -1.0, 1.0))
        angles = torch.nan_to_num(angles, nan=0.0).reshape(-1)
        if angles.numel() == 0:
            return angles.reshape(-1, 1)

        amide_angles = torch.nan_to_num(min_max_normalize_tensor(angles), nan=0.0)
        return amide_angles.reshape(-1, 1)

    def forward(self,
        coords: torch.Tensor,
        dist: torch.Tensor,
        edge_ids: torch.Tensor,
        edges,
    ):
        """Compute all geometric features.

        Args:
            coords: [B, N_res, M_atom, 3] coordinates, atom axis (N, CA, C, O, CB).
            dist: [B, N_res, K, 1] neighbor distances.
            edge_ids: [B, N_res, K] neighbor indices.
            edges: (source, destination) residue indices for the normal angles.

        Returns:
            Tuple ``(bb_angles, dist_rbf, dir_ori, amide_angles)`` with shapes
            [B, N_res, 6], [B, N_res, K, num_rbf], [B, N_res, K, 7], [B * E, 1].
        """
        bb_angles = self.get_dihedrals(coords)
        dist_rbf = self.compute_dist_rbf(dist)
        # Orientation features use atom index 1 (CA) as the per-residue position
        _, dir_ori = self.get_coarse_orientation_feats(coords[:, :, 1, :], edge_ids)
        amide_angles = self.get_amide_angles(coords, edge_ids, edges)

        return bb_angles, dist_rbf, dir_ori, amide_angles



