import torch
from torch import nn

from .rapidash_fiber_grids import  matrix_to_quat


class Invariants(nn.Module):
    """Compute geometric invariants between sender and receiver points.

    The configuration (``base_space``, ``group``, ``fiber_space``) selects
    which invariant computation :meth:`forward` dispatches to. Most paths
    return a pair ``(spatial_invariants, fiber_invariants)``; the pure
    ``"R3S2"`` base-space path returns a single tensor.

    Args:
        base_space: One of ``"R2"``, ``"R3"``, ``"Rn"``, ``"SE2"``, ``"R2S1"``,
            ``"R3S2"``, ``"SE3"``.
        group: One of ``"T2"``, ``"T3"``, ``"Tn"``, ``"SE2"``, ``"SE3"``,
            ``"SEn"``.
        fiber_space: ``None`` or one of ``"S1"``, ``"SO2"``, ``"S2"``,
            ``"SO3"``.

    Raises:
        ValueError: If any of the three settings is not in its allowed list.
    """

    def __init__(
        self,
        base_space,
        group,
        fiber_space = None,
    ):
        super().__init__()

        if base_space not in ["R2", "R3", "Rn", "SE2", "R2S1", "R3S2", "SE3"]:
            raise ValueError(f"Unknown base space: {base_space}")
        self.base_space = base_space

        if group not in ["T2", "T3", "Tn", "SE2", "SE3", "SEn"]:
            raise ValueError(f"Unknown group: {group}")
        self.group = group

        if fiber_space not in [None, "S1", "SO2", "S2", "SO3"]:
            raise ValueError(f"Unknown fiber space: {fiber_space}")
        self.fiber_space = fiber_space

    def forward(self, pos_send, pos_receive, ori_grid=None):
        """Dispatch to the invariant computation matching the configuration.

        Args:
            pos_send: Sender features per edge (positions; for base space
                ``"R3S2"`` the first three columns are the position and the
                remaining columns the orientation).
            pos_receive: Receiver features per edge, same layout as
                ``pos_send``.
            ori_grid: Fiber grid, only used when ``fiber_space`` is set.

        Returns:
            Whatever the selected ``compute_*`` method returns.

        Raises:
            NotImplementedError: For base space ``"SE3"`` without a fiber
                space, or when the selected method does not support ``group``.
            ValueError: If the base-space / fiber-space combination is not
                handled here.
        """
        euclidean_base = self.base_space in ["R2", "R3"]

        if self.fiber_space is None:
            # Standard group convolution invariants
            if euclidean_base:
                return self.compute_Rn_invariants(pos_send, pos_receive)
            if self.base_space == "R3S2":
                return self.compute_R3S2_invariants(pos_send, pos_receive)
            if self.base_space == "SE3":
                raise NotImplementedError("SE3 base space not implemented")
        elif euclidean_base:
            # Separable group convolution invariants over space Rn x F, with F the fiber space
            if self.fiber_space in ["S1", "SO2"]:
                # NOTE(review): compute_separable_R3S1_invariants is not defined in the
                # visible part of this class, and it is called with ori_grid first,
                # unlike the sibling methods — confirm it exists and its signature.
                return self.compute_separable_R3S1_invariants(ori_grid, pos_send, pos_receive)
            if self.fiber_space == "S2":
                return self.compute_R3S2_invariants_fiber_bundle(pos_send, pos_receive, ori_grid)
            if self.fiber_space == "SO3":
                return self.compute_R3SO3_invariants_fiber_bundle(pos_send, pos_receive, ori_grid)

        # Invalid specification. The message reports fiber_space; the class has
        # no num_ori attribute, so formatting it would raise AttributeError.
        raise ValueError(
            f"Unknown base space or invariants combination: {self.base_space} "
            f"with fiber_space = {self.fiber_space}"
        )

    def compute_Rn_invariants(self, pos_send, pos_receive):
        """Return invariants for a plain Euclidean base space.

        For translation groups the relative position itself is returned; for
        the ``SE*`` groups only its norm (last dim kept with size 1).

        Returns:
            Tuple ``(invariants, None)``; there are no fiber invariants.

        Raises:
            NotImplementedError: If ``self.group`` is not one of the groups
                handled here.
        """
        rel_pos = pos_send - pos_receive  # [num_edges, 3]
        if self.group in ["T2", "T3", "Tn"]:
            return rel_pos, None
        if self.group in ["SE2", "SE3", "SEn"]:
            return rel_pos.norm(dim=-1, keepdim=True), None
        raise NotImplementedError(f"Symmetry {self.group} not implemented for space {self.base_space}")

    def compute_R3S2_invariants_fiber_bundle(self, pos_send, pos_receive, fiber_grid):
        """Return separable invariants for positions with a grid of directions.

        Args:
            pos_send: Sender positions, ``[num_edges, 3]``.
            pos_receive: Receiver positions, ``[num_edges, 3]``.
            fiber_grid: Grid of direction vectors, ``[num_ori, 3]``
                (assumed unit length — TODO confirm).

        Returns:
            Tuple ``(spatial_invariants, fiber_invariants)`` where
            ``spatial_invariants`` is ``[num_edges, num_ori, 2]`` (component of
            the relative position along each grid direction, and the norm of
            the remainder) and ``fiber_invariants`` is ``[num_ori, num_ori, 1]``
            (pairwise dot products of grid directions).

        Raises:
            NotImplementedError: If ``self.group`` is not ``"SE3"``/``"SEn"``.
        """
        if self.group not in ["SE3", "SEn"]:
            raise NotImplementedError(f"Symmetry {self.group} not implemented for space {self.base_space}")

        rel_pos = (pos_send - pos_receive)[:, None, :]  # [num_edges, 1, 3]
        ori_grid_a = fiber_grid[None, :, :]  # [1, num_ori, 3]
        ori_grid_b = fiber_grid[:, None, :]  # [num_ori, 1, 3]

        invariant1 = (rel_pos * ori_grid_a).sum(dim=-1, keepdim=True)  # [num_edges, num_ori, 1]
        invariant2 = self.compute_orthogonal_displacement(rel_pos, invariant1, ori_grid_a, 3)
        invariant3 = (ori_grid_a * ori_grid_b).sum(dim=-1, keepdim=True)  # [num_ori, num_ori, 1]

        spatial_invariants = torch.cat([invariant1, invariant2], dim=-1)  # [num_edges, num_ori, 2]
        fiber_invariants = invariant3  # [num_ori, num_ori, 1]

        return spatial_invariants, fiber_invariants

    @staticmethod
    def _rel_pos_in_fiber_frames(fiber_grid, rel_pos):
        """Express each relative position in every grid frame: ``R_n^T x_b``.

        Args:
            fiber_grid: Matrices ``[num_ori, 3, 3]``.
            rel_pos: Relative positions ``[num_edges, 3]``.

        Returns:
            Tensor ``[num_edges, num_ori, 3]``.
        """
        # Contract over the FIRST matrix index (i) so that the result is
        # R^T x; the free index j must be the output index.
        return torch.einsum('nij,bi->bnj', fiber_grid, rel_pos)

    def compute_R3SO3_invariants_fiber_bundle(self, pos_send, pos_receive, fiber_grid):
        """Return separable invariants for positions with a grid of 3x3 matrices.

        Args:
            pos_send: Sender positions, ``[num_edges, 3]``.
            pos_receive: Receiver positions, ``[num_edges, 3]``.
            fiber_grid: Grid of matrices, ``[num_ori, 3, 3]`` (presumably
                rotation matrices — confirm against the grid generator).

        Returns:
            Tuple ``(spatial_invariants, fiber_invariants)``:
            ``spatial_invariants`` is ``R_n^T (x_send - x_receive)`` with shape
            ``[num_edges, num_ori, 3]``; ``fiber_invariants`` is
            ``matrix_to_quat`` applied to every relative matrix
            ``R_m^T R_n``, reshaped to ``[num_ori, num_ori, -1]``.

        Raises:
            NotImplementedError: If ``self.group`` is not ``"SE3"``/``"SEn"``.
        """
        if self.group not in ["SE3", "SEn"]:
            raise NotImplementedError(f"Symmetry {self.group} not implemented for space {self.base_space}")

        rel_pos = pos_send - pos_receive  # [num_edges, 3]

        # Ri^T (xj - xi)
        spatial_invariants = self._rel_pos_in_fiber_frames(fiber_grid, rel_pos)  # [num_edges, num_ori, 3]

        # Ri^T Rj
        relative_rotations = torch.einsum('mij,nik->mnjk', fiber_grid, fiber_grid)  # [num_ori, num_ori, 3, 3]
        num_ori = relative_rotations.size(0)
        # Last dimension is whatever matrix_to_quat yields per matrix
        # (presumably 4 quaternion components — TODO confirm).
        fiber_invariants = matrix_to_quat(relative_rotations.reshape(-1, 3, 3)).reshape(num_ori, num_ori, -1)

        return spatial_invariants, fiber_invariants

    def compute_R3S2_invariants(self, pos_send, pos_receive):
        """Return invariants for points that carry their own orientation.

        Args:
            pos_send: ``[num_edges, 3 + k]``; columns ``:3`` are the position,
                columns ``3:`` the orientation (k is presumably 3, since it is
                contracted with a 3x3 frame).
            pos_receive: Same layout as ``pos_send``.

        Returns:
            A single tensor (not a pair) concatenating, along the last dim:
            the component of the relative position along the receiver
            orientation, the norm of the remainder, and the sender orientation
            expressed in the receiver's reference frame.

        Raises:
            NotImplementedError: If ``self.group`` is not ``"SE3"``/``"SEn"``.
        """
        if self.group not in ["SE3", "SEn"]:
            raise NotImplementedError(f"Symmetry {self.group} not implemented for space {self.base_space}")

        pos_send, ori_send = pos_send[:, :3], pos_send[:, 3:]  # [num_edges, 3], [num_edges, 3]
        pos_receive, ori_receive = pos_receive[:, :3], pos_receive[:, 3:]

        rel_pos = pos_send - pos_receive  # [num_edges, 3]
        invariant1 = torch.sum(rel_pos * ori_receive, dim=-1, keepdim=True)
        invariant2 = (rel_pos - ori_receive * invariant1).norm(dim=-1, keepdim=True)

        frame = self.construct_reference_frame(rel_pos, ori_receive)
        # frame^T @ ori_send: sender orientation in the receiver frame.
        rel_ori = torch.einsum("bji,bj->bi", frame, ori_send)

        return torch.cat([invariant1, invariant2, rel_ori], dim=-1)

    def compute_orthogonal_displacement(self, rel_pos, invariant1, ori_grid_a, dim):
        """Reduce the part of ``rel_pos`` not explained by ``invariant1 * ori_grid_a``.

        Args:
            rel_pos: Relative positions, broadcastable against ``ori_grid_a``.
            invariant1: Coefficient of ``rel_pos`` along ``ori_grid_a``.
            ori_grid_a: Direction grid.
            dim: Spatial dimension. ``2`` sums the residual's components,
                ``3`` takes the residual's Euclidean norm.

        Returns:
            Tensor with the last dimension reduced to size 1.

        Raises:
            ValueError: If ``dim`` is neither 2 nor 3.
        """
        residual = rel_pos - invariant1 * ori_grid_a
        if dim == 2:
            return residual.sum(dim=-1, keepdim=True)  # [num_edges, num_ori, 1]
        if dim == 3:
            return residual.norm(dim=-1, keepdim=True)  # [num_edges, num_ori, 1]
        raise ValueError(f"Unsupported spatial dimension: {dim}")

    def construct_reference_frame(self, rel_pos, ori_receive):
        """Build a per-edge frame whose columns are ``(v1, v2, v3)``.

        ``v3`` is the receiver orientation, ``v1`` is ``rel_pos`` with its
        component along ``v3`` removed, and ``v2 = v3 x v1``.

        NOTE(review): ``v1`` (and hence ``v2``) is orthogonalised but not
        normalised, and the projection assumes ``ori_receive`` is unit
        length — confirm this is intended.

        Returns:
            Tensor ``[num_edges, 3, 3]`` with the three vectors stacked along
            the last dimension.
        """
        v3 = ori_receive  # reference z-axis
        # reference x-axis: remove the component of rel_pos along v3
        v1 = rel_pos - torch.sum(rel_pos * v3, dim=-1, keepdim=True) * v3
        v2 = torch.cross(v3, v1, dim=-1)
        return torch.stack([v1, v2, v3], dim=-1)  # [num_edges, 3, 3]