import numpy as np
import torch as tc
import torch.nn as nn
import torch.nn.functional as F


class Normalizer(nn.Module):
    """Map (q, p) to translation- and/or rotation-invariant representations.

    Translation invariance subtracts, per sample, the mean of q over all object
    axes. Rotation invariance expresses q and p each in an orthonormal frame
    built from that tensor's own reference nodes (see ``get_ref_nodes``).
    """

    def __init__(self, dof: int, make_translation_invariant=True, make_rotation_invariant=True):
        super().__init__()
        self.dof = dof  # spatial dimension; get_rot_inv supports 1, 2 or 3
        self.make_translation_invariant = make_translation_invariant
        self.make_rotation_invariant = make_rotation_invariant

    def forward(self, q: tc.Tensor, p: tc.Tensor) -> tuple[tc.Tensor, tc.Tensor]:
        """Normalize q and p according to the invariance specifications.

        Args:
            q: tensor of shape (n_samples, *n_obj, dof).
            p: tensor of shape (n_samples, *n_obj, dof).
        Returns:
            (q, p) of the same shapes. q is centered over the object axes when
            make_translation_invariant is set; both are rotated into a local
            frame when make_rotation_invariant is set.
        """
        if self.make_translation_invariant:
            # Center q over every axis between the sample axis and the dof axis.
            # NOTE(review): with no object axes, dims_to_avg is empty; depending on the
            # PyTorch version an empty dim list may reduce over all dims — confirm intended.
            dims_to_avg = list(range(1, q.ndim - 1))
            q = q - q.mean(dim=dims_to_avg, keepdim=True)

        if self.make_rotation_invariant:
            # get_ref_nodes returns `dof` reference vectors per tensor (its first nodes,
            # padded with canonical axes); get_rot_inv uses the first one (2D) or two (3D).
            # NOTE(review): q and p are each rotated into their *own* frame, so their
            # relative orientation is not preserved — confirm this is intended.
            q = self.get_rot_inv(q, self.get_ref_nodes(q))
            p = self.get_rot_inv(p, self.get_ref_nodes(p))

        return q, p

    @staticmethod
    def _second_axis_3d(e_0: tc.Tensor, ref_1: tc.Tensor) -> tc.Tensor:
        """Return a unit vector orthogonal to e_0 (both (n_samples, 3)), derived from ref_1."""
        e_1 = F.normalize(ref_1, dim=1)

        # Heuristic kept from the original: if ref_1 is nearly (anti)parallel to e_0,
        # use e_0 x ref_1 instead.
        dot = (e_0 * e_1).sum(dim=1, keepdim=True)
        e_1 = tc.where(dot.abs() > 0.98, tc.cross(e_0, e_1, dim=1), e_1)

        # Gram-Schmidt: remove the e_0 component.
        e_1 = e_1 - (e_1 * e_0).sum(dim=1, keepdim=True) * e_0

        # Exactly collinear (or zero) reference nodes leave nothing after Gram-Schmidt,
        # and the cross product above is zero as well. Fall back to the canonical axis
        # least aligned with e_0 so the frame stays orthonormal and no component of the
        # projected vectors is discarded.
        degenerate = e_1.norm(dim=1, keepdim=True) < 1e-6
        axis = F.one_hot(e_0.abs().argmin(dim=1), num_classes=3).to(e_0.dtype)
        axis = axis - (axis * e_0).sum(dim=1, keepdim=True) * e_0
        e_1 = tc.where(degenerate, axis, e_1)

        return F.normalize(e_1, dim=1)

    def get_rot_inv(self, v: tc.Tensor, ref_nodes: list[tc.Tensor]) -> tc.Tensor:
        """Express v in an orthonormal frame derived from ref_nodes.

        The first axis is ref_nodes[0] normalized. In 2D the second axis is that
        vector rotated by +90 degrees. In 3D the second axis comes from ref_nodes[1]
        (orthogonalized against the first, with a fallback for collinear nodes) and
        the third is the cross product of the first two.

        Args:
            v: tensor of shape (n_samples, *n_obj, dof).
            ref_nodes: list of tensors of shape (n_samples, dof); at least one is
                needed for dof=2 and at least two for dof=3.
        Returns:
            Tensor of the same shape as v. Samples whose reference nodes all have
            norm < 1e-8 are returned unchanged; for dof == 1, v is returned as-is.
        Raises:
            ValueError: if the last dimension of v is not 2 or 3 (and dof != 1).
        """
        if self.dof == 1:
            return v

        n_samples, *n_obj, dof = v.shape
        n_nodes = int(np.prod(n_obj))

        # Samples whose reference nodes are all ~0 have no usable frame; keep them as-is.
        max_norm = tc.stack([node.norm(dim=1) for node in ref_nodes], dim=1).max(dim=1).values
        keep = (max_norm < 1e-8).reshape(n_samples, *([1] * (v.dim() - 1)))

        e_0 = F.normalize(ref_nodes[0], dim=1)
        if dof == 2:
            e_1 = tc.stack([-e_0[:, 1], e_0[:, 0]], dim=1)  # e_0 rotated by +90 degrees
            basis_vectors = [e_0, e_1]
        elif dof == 3:
            e_1 = self._second_axis_3d(e_0, ref_nodes[1])
            e_2 = tc.cross(e_0, e_1, dim=1)
            basis_vectors = [e_0, e_1, e_2]
        else:
            raise ValueError(f"DOF={self.dof} is more than 3 spatial dimensions.")

        basis = tc.stack(basis_vectors, dim=1)  # (n_samples, dof, dof); row i = i-th axis
        # reshape (not view) so non-contiguous inputs are accepted.
        v_flat = v.reshape(n_samples, n_nodes, dof)
        v_local = tc.einsum("nij,nkj->nki", basis, v_flat).reshape(v.shape)
        return tc.where(keep, v, v_local)  # (n_samples, *n_obj, dof)

    def get_ref_nodes(self, v: tc.Tensor) -> list[tc.Tensor]:
        """Return `self.dof` reference vectors, each of shape (n_samples, dof).

        The i-th reference is node i of v (object axes flattened in row-major
        order). When v has fewer than `self.dof` nodes, the missing references are
        the canonical unit vectors e_i.
        """
        n_samples, *n_obj, dof = v.shape
        n_nodes = int(np.prod(n_obj))
        nodes = v.reshape(n_samples, n_nodes, dof)  # reshape also handles non-contiguous v

        ref_nodes = []
        for i in range(self.dof):
            if i < n_nodes:
                ref_nodes.append(nodes[:, i, :])
            else:
                e_i = tc.zeros(n_samples, dof, device=v.device, dtype=v.dtype)
                e_i[:, i] = 1.0
                ref_nodes.append(e_i)
        return ref_nodes
