# Clifford algebra

from .cwrap import build
import math
import functools
import numpy as np
import torch
from torch import Tensor


class CliffordAlgebra:
    """
    Clifford (geometric) algebra over an N-dimensional space with a diagonal metric.

    Multivectors are tensors whose last axis holds one coefficient per basis
    blade (length 2**N unless a `blades` subset is given). The multiplication
    table and basis bookkeeping (`index2bitmap`, `bitmap2index`, `grades`) are
    produced by the compiled `build` helper.

    NOTE(review): `project_to_grade`, `embed_grade`, `norms` and
    `squared_norms` slice the last axis using `_grade_stops`. They therefore
    assume basis indices are sorted by grade (all grade-0 blades, then all
    grade-1 blades, ...). Presumably `build` guarantees this ordering; confirm.
    """

    num_bases: int
    """Dim of the 1-vector space. For O(N) equivariant models, it's N."""
    dim: int
    """Dim of the Clifford algebra. For O(N) equivariant models, it's 2**N."""
    metric: Tensor
    """Length-N vector of metric."""
    index2bitmap: Tensor
    """Length-2**N: the i-th basis of Clifford algebra."""
    bitmap2index: Tensor
    """Length-2**N: inverse of `index2bitmap`."""
    grades: Tensor
    """Length-2**N: the grade of the i-th basis."""
    table: Tensor
    r"""2**N * 2**N * 2**N: i-th basis * k-th basis = \sum_j table[i, j, k] * j-th basis."""

    def __init__(self, metric):
        """
        Build the algebra for a diagonal metric.

        metric: length-N sequence of numbers. It is converted to float32 and
            passed, together with N, to `build`, which returns the basis
            bookkeeping arrays and the multiplication table.
        """
        self.num_bases = len(metric)
        self.dim = 1 << self.num_bases
        metric_np = np.array(metric, dtype=np.float32)

        (self.index2bitmap, self.bitmap2index, self.grades, self.table
         ) = map(torch.from_numpy, build(self.num_bases, metric_np))
        self.metric = torch.from_numpy(metric_np)
        # Swap the last two axes of the table returned by `build` so that the
        # output blade sits on axis 1, as the `table` description and the
        # einsum in `geometric_product` expect.
        self.table = self.table.permute(0, 2, 1)
        self.device = torch.device("cpu")

    def to(self, device):
        """
        Move every tensor held by the algebra to `device`.

        This includes the cached sign / grade-stop / path tensors; any that
        were not computed yet are computed here first. Returns `self` so the
        call can be chained, mirroring the `torch` `.to` convention.
        """
        self.device = torch.device(device)
        for attr in ['metric', 'index2bitmap', 'bitmap2index',
                     'grades', 'table', '_alpha_signs',
                     '_beta_signs', '_gamma_signs', '_grade_stops',
                     '_geometric_product_paths']:
            setattr(self, attr, getattr(self, attr).to(device))
        return self

    def geometric_product(self, a: Tensor, b: Tensor, blades=None) -> Tensor:
        """
        Compute geometric product of `a` and `b`:
            (..., D_l) * (..., D_r) -> (..., D_o)

        blades = (indices of D_l,
                  indices of D_o,
                  indices of D_r)
        blades default to (range(2**N), range(2**N), range(2**N))
        """
        table = self.table

        if blades is not None:
            blades_l, blades_o, blades_r = blades
            assert isinstance(blades_l, Tensor)
            assert isinstance(blades_o, Tensor)
            assert isinstance(blades_r, Tensor)
            # Broadcast (L,1,1), (O,1), (R,) -> (L, O, R) sub-table.
            table = table[blades_l[:, None, None], blades_o[:, None], blades_r]

        return torch.einsum("...i,ijk,...k->...j", a, table, b)

    def project(self, mv: Tensor, index) -> Tensor:
        """
        Select the components `index` of the last axis.

        (..., 2**N) -> (..., len(index))
        """
        return mv[..., index]

    def project_to_grade(self, mv: Tensor, grade: int) -> Tensor:
        """
        Select the contiguous slice of components belonging to `grade`.

        (..., 2**N) -> (..., size of grade)
        """
        return mv[..., self._grade_stops[grade]:
                       self._grade_stops[grade + 1]]

    def embed(self, tensor: Tensor, index) -> Tensor:
        """
        Scatter `tensor` into components `index` of a zero multivector.

        (..., any size) -> (..., 2**N)
        """
        mv = torch.zeros(
            *tensor.shape[:-1], self.dim, device=tensor.device, dtype=tensor.dtype
        )
        mv[..., index] = tensor
        return mv

    def embed_grade(self, tensor: Tensor, grade: int) -> Tensor:
        """
        Write `tensor` into the `grade` slice of a zero multivector.

        (..., size of grade) -> (..., 2**N)
        """
        mv = torch.zeros(
            *tensor.shape[:-1], self.dim, device=tensor.device, dtype=tensor.dtype
        )
        mv[..., self._grade_stops[grade]:self._grade_stops[grade + 1]] = tensor
        return mv

    def alpha(self, mv: Tensor, blades=None) -> Tensor:
        """
        Multiply each grade-k component by (-1)**k (grade involution).

        `blades` optionally selects which sign entries match the last axis.
        """
        signs = self._alpha_signs
        if blades is not None:
            signs = signs[blades]
        return signs * mv

    def beta(self, mv: Tensor, blades=None) -> Tensor:
        """
        Multiply each grade-k component by (-1)**(k*(k-1)/2) (reversion).

        `blades` optionally selects which sign entries match the last axis.
        """
        signs = self._beta_signs
        if blades is not None:
            signs = signs[blades]
        return signs * mv

    def gamma(self, mv: Tensor, blades=None) -> Tensor:
        """
        Multiply each grade-k component by (-1)**(k*(k+1)/2) (Clifford conjugation).

        `blades` optionally selects which sign entries match the last axis.
        """
        signs = self._gamma_signs
        if blades is not None:
            signs = signs[blades]
        return signs * mv

    def dot(self, x: Tensor, y: Tensor, blades=None) -> Tensor:
        """
        Compute (beta(x) * y)(0-component).

            (..., D_l) * (..., D_r) -> (..., 1)

        blades = (indices of D_l,
                  indices of D_r)

        blades default to (range(2**N), range(2**N))
        """
        # Only the scalar (index 0) output component is needed.
        scalar_out = torch.tensor([0], device=x.device)
        if blades is None:
            everything = torch.arange(self.dim, device=x.device)
            beta_blades = None
            gp_blades = (everything, scalar_out, everything)
        else:
            assert len(blades) == 2
            beta_blades = blades[0]
            gp_blades = (blades[0], scalar_out, blades[1])

        return self.geometric_product(
            self.beta(x, blades=beta_blades),
            y,
            blades=gp_blades,
        )

    def squared_norm(self, mv: Tensor, blades=None) -> Tensor:
        """
        Return `dot(mv, mv)`.

        (..., D) -> (..., 1)

        D default to 2**N.
        """
        if blades is not None:
            blades = (blades, blades)
        return self.dot(mv, mv, blades=blades)

    def _smooth_abs_sqrt(self, input: Tensor, eps=1e-16) -> Tensor:
        """
        Return (input**2 + eps)**0.25, a smooth approximation of sqrt(|input|).

        The `eps` keeps the expression and its gradient finite at 0.
        """
        return (input**2 + eps) ** 0.25

    def norm(self, mv: Tensor, blades=None) -> Tensor:
        """
        Smoothed square root of `squared_norm`.

        (..., D) -> (..., 1)

        D default to 2**N.
        """
        return self._smooth_abs_sqrt(self.squared_norm(mv, blades=blades))

    def _grade_blades(self, grade: int, device) -> Tensor:
        """Basis indices belonging to `grade`, as a 1-D tensor on `device`."""
        return torch.arange(self._grade_stops[grade],
                            self._grade_stops[grade + 1],
                            device=device)

    def norms(self, mv: Tensor, grades=None) -> list:
        """
        Compute norm separately for each grade (in grades, if grades
        is not None).

        Returns a list with one (..., 1) tensor per requested grade.
        """
        if grades is None:
            grades = range(self.num_bases + 1)
        return [
            self.norm(self.project_to_grade(mv, grade),
                      blades=self._grade_blades(grade, mv.device))
            for grade in grades
        ]

    def squared_norms(self, mv: Tensor, grades=None) -> list:
        """
        Compute squared norm separately for each grade (in grades, if
        grades is not None).

        Returns a list with one (..., 1) tensor per requested grade.
        """
        if grades is None:
            grades = range(self.num_bases + 1)
        return [
            self.squared_norm(self.project_to_grade(mv, grade),
                              blades=self._grade_blades(grade, mv.device))
            for grade in grades
        ]

    @functools.cached_property
    def _alpha_signs(self):
        """Per-blade sign (-1)**grade."""
        return torch.pow(-1, self.grades)

    @functools.cached_property
    def _beta_signs(self):
        """Per-blade sign (-1)**(grade*(grade-1)/2)."""
        return torch.pow(-1, self.grades * (self.grades - 1) // 2)

    @functools.cached_property
    def _gamma_signs(self):
        """Per-blade sign (-1)**(grade*(grade+1)/2)."""
        return torch.pow(-1, self.grades * (self.grades + 1) // 2)

    @functools.cached_property
    def _grade_stops(self):
        """
        Length-(N+2) cumulative count of blades per grade.

        Grade g occupies indices [stops[g], stops[g+1]).
        """
        return torch.tensor([0] +
                            [math.comb(self.num_bases, i)
                             for i in range(self.num_bases + 1)],
                            dtype=torch.int32,
                            device=self.device).cumsum(0)

    @functools.cached_property
    def _geometric_product_paths(self):
        """
        Boolean (N+1, N+1, N+1) tensor.

        Entry [i, j, k] is True iff `table` has any nonzero entry for a
        grade-i left blade, grade-j output blade and grade-k right blade.
        """
        num_grades = self.num_bases + 1
        # Plain Python ints: avoids a device sync per slice when on GPU.
        stops = self._grade_stops.tolist()
        nonzero = self.table != 0
        gp_paths = torch.zeros((num_grades, num_grades, num_grades),
                               dtype=torch.bool, device=self.device)
        for i in range(num_grades):
            for j in range(num_grades):
                for k in range(num_grades):
                    gp_paths[i, j, k] = torch.any(
                        nonzero[stops[i]:stops[i + 1],
                                stops[j]:stops[j + 1],
                                stops[k]:stops[k + 1]]
                    )

        return gp_paths
    

