# scalarization.py
# -*- coding: utf-8 -*-
from typing import List, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from e3nn.o3 import Irreps


class Scalarization(nn.Module):
    """
    Map equivariant features onto invariant scalars, one irreps term at a time.

    Each (mul, ir) term of ``hidden_irreps`` owns ``mul * (2l + 1)`` consecutive
    input channels, which ``forward`` views as ``[B, mul, 2l + 1]``:
      - 0e terms (l == 0, even parity) pass their ``mul`` scalar channels
        through unchanged.
      - Every other term is replaced by the Gram matrix of dot products over
        its ``2l + 1`` components, ``[B, mul, mul]``, flattened to
        ``[B, mul * mul]``. The Gram matrix is symmetric, so the flattened
        output contains each off-diagonal product twice.
    The per-term outputs are concatenated along the last dimension.

    Args:
        hidden_irreps: irreps specification; anything ``e3nn.o3.Irreps``
            accepts (it is passed straight to that constructor).
        device: forwarded to ``self.to``. This class registers no parameters
            or buffers of its own, so there is no state for it to move.

    Attributes:
        hidden_irreps: the parsed ``Irreps``.
        output_dim: sum over terms of ``mul`` (0e terms) or ``mul * mul``
            (all other terms).

    Input:
        emb: [B, C] features with C == hidden_irreps.dim.

    Output:
        [B, output_dim]
    """

    def __init__(self, hidden_irreps: Irreps, device: str = 'cpu'):
        super().__init__()
        self.hidden_irreps = Irreps(hidden_irreps)

        # One record per term: (start, end, mul, l, is_even, d_rep), where
        # [start, end) is the term's channel range in the input.
        slices: List[Tuple[int, int, int, int, bool, int]] = []
        cursor = 0
        for mul, ir in self.hidden_irreps:
            d_rep = 2 * ir.l + 1
            stop = cursor + mul * d_rep
            slices.append((cursor, stop, mul, ir.l, ir.p == 1, d_rep))
            cursor = stop
        self._slices = slices

        # The ranges must tile the whole feature vector exactly.
        assert cursor == self.hidden_irreps.dim, "Irreps dim mismatch."

        # 0e terms contribute mul columns; every other term contributes mul*mul.
        self.output_dim = sum(
            mul if (l == 0 and is_even) else mul * mul
            for _, _, mul, l, is_even, _ in self._slices
        )
        self.to(device)

    def forward(self, emb: Tensor) -> Tensor:
        """
        Scalarize a batch of feature vectors.

        Args:
            emb: [B, C] tensor with C == hidden_irreps.dim.
        Returns:
            [B, output_dim] tensor; a [B, 0] tensor of zeros (same dtype and
            device as ``emb``) when there are no irreps terms.
        Raises:
            ValueError: if ``emb`` is not 2-D or its channel count is wrong.
        """
        if emb.dim() != 2:
            raise ValueError(f"`emb` must be [B, C], got shape {tuple(emb.shape)}")
        if emb.size(-1) != self.hidden_irreps.dim:
            raise ValueError(
                f"Channel mismatch: got C={emb.size(-1)} vs hidden_irreps.dim={self.hidden_irreps.dim}"
            )

        batch = emb.size(0)
        pieces: List[Tensor] = []

        for start, end, mul, l, is_even, d_rep in self._slices:
            # Splitting only the last axis, so this view is valid for any strides.
            block = emb[:, start:end].view(batch, mul, d_rep)

            if not (l == 0 and is_even):
                # Pairwise dot products between the mul copies:
                # gram[b, i, j] = sum_k block[b, i, k] * block[b, j, k]
                gram = torch.einsum("bik,bjk->bij", block, block)
                pieces.append(gram.reshape(batch, mul * mul))
                continue

            # 0e: d_rep == 1, so dropping the component axis yields [B, mul].
            pieces.append(block[..., 0])

        if not pieces:
            return emb.new_zeros(batch, 0)
        return torch.cat(pieces, dim=-1)
