# irreps_mask.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple, Set
from e3nn.o3 import Irreps


class IrrepsMask(nn.Module):
    """
    Irreps-based channel mask layer.

    Keeps the channels belonging to the selected (l, parity) irreps blocks
    and zeros out all others by multiplying the input with a fixed 0/1 vector.

    Args:
        hidden_irreps: e3nn.o3.Irreps (or anything ``Irreps(...)`` accepts)
                       describing the channel layout.
        lp_choose: list like ['0e', '1o'] specifying which (l, parity) blocks to keep.
                   A single string is treated as one token. If None, empty, or made
                   only of blank tokens, all channels are kept (no masking).
        device: device the module (and therefore its mask buffer) is moved to.

    Raises:
        ValueError: if a token in ``lp_choose`` is malformed.
    """
    def __init__(
        self,
        hidden_irreps: "Irreps",
        lp_choose: List[str],
        device: str = 'cpu',
    ):
        super().__init__()
        self.hidden_irreps = Irreps(hidden_irreps)
        self.lp_choose = self._normalize_tokens(lp_choose)

        # Non-persistent buffer: it follows .to(device) / dtype casts but is
        # left out of the state_dict.
        self.register_buffer("mask", self._build_mask(), persistent=False)

        self.to(device)

    # ---------------- Core ----------------
    def forward(self, graph_feat: Tensor) -> Tensor:
        """
        Apply irreps-based channel mask.

        Args:
            graph_feat: [*, C] where C == hidden_irreps.dim
        Returns:
            masked_feat: same shape as input, with selected channels preserved
        Raises:
            ValueError: if the last dimension of ``graph_feat`` is not
                hidden_irreps.dim.
        """
        if graph_feat.size(-1) != self.hidden_irreps.dim:
            raise ValueError(
                f"Channel mismatch: got {graph_feat.size(-1)} vs hidden_irreps.dim={self.hidden_irreps.dim}"
            )
        # Cast the mask to the input dtype so the product keeps that dtype;
        # detach() keeps the constant mask out of the autograd graph.
        return graph_feat * self.mask.to(graph_feat.dtype).detach()

    # ---------------- Mask building ----------------
    @staticmethod
    def _normalize_tokens(tokens) -> List[str]:
        """
        Turn the user-supplied selection into a fresh list of tokens.

        None becomes an empty list. A bare string is wrapped as a single token
        rather than being split into characters (``list('0e')`` would give
        ``['0', 'e']``).
        """
        if tokens is None:
            return []
        if isinstance(tokens, str):
            return [tokens]
        return list(tokens)

    def _parse_tokens(self, tokens: List[str]) -> Set[Tuple[int, str]]:
        """
        Parse tokens like '0e' / '2o' into (l, 'e'|'o') pairs.

        Tokens are stripped and lower-cased first; blank tokens are skipped.

        Raises:
            ValueError: if a token does not end in 'e'/'o' or its degree part
                is not a non-negative integer.
        """
        keep: Set[Tuple[int, str]] = set()
        for tok in tokens:
            t = tok.strip().lower()
            if not t:
                continue
            message = f"Invalid token '{tok}', expected like '0e' or '1o'."
            degree_str, parity = t[:-1], t[-1]
            if parity not in ("e", "o"):
                raise ValueError(message)
            try:
                # int('') also raises, which covers a bare 'e' / 'o' token.
                degree = int(degree_str)
            except ValueError as err:
                raise ValueError(message) from err
            if degree < 0:
                # A negative degree can never match an irreps block.
                raise ValueError(message)
            keep.add((degree, parity))
        return keep

    def _mask_for(self, tokens: List[str]) -> Tensor:
        """
        Build the float32 mask for ``tokens`` without touching module state.

        Returns a 1D tensor of length hidden_irreps.dim. It is all ones when
        no (l, parity) pair is selected; otherwise channels of selected blocks
        are 1 and all others 0.
        """
        total_dim = self.hidden_irreps.dim
        keep = self._parse_tokens(tokens)
        # Decide on the *parsed* selection so that blank-only input behaves
        # like an empty list instead of zeroing every channel.
        if not keep:
            return torch.ones(total_dim, dtype=torch.float32)

        mask = torch.zeros(total_dim, dtype=torch.float32)
        offset = 0
        for mul, ir in self.hidden_irreps:  # each ir has fields l (int), p (±1)
            parity = "e" if ir.p == 1 else "o"
            # mul copies of a (2l + 1)-dimensional block, laid out contiguously.
            block_dim_total = mul * (2 * ir.l + 1)

            if (ir.l, parity) in keep:
                mask[offset: offset + block_dim_total] = 1.0

            offset += block_dim_total

        if offset != total_dim:
            raise RuntimeError("Irreps dim mismatch while constructing mask.")
        return mask

    def _build_mask(self) -> Tensor:
        """
        Build a 1D mask vector of length hidden_irreps.dim from self.lp_choose.
        Channels belonging to selected irreps blocks are 1; others 0.
        """
        return self._mask_for(self.lp_choose)

    # ---------------- Optional API ----------------
    @torch.no_grad()
    def set_lp_choose(self, tokens: List[str]) -> None:
        """
        Update mask at runtime (e.g., finetuning).

        The new mask is built before any state is changed, so a malformed
        token raises ValueError and leaves both ``lp_choose`` and ``mask``
        as they were.
        """
        new_tokens = self._normalize_tokens(tokens)
        new_mask = self._mask_for(new_tokens)

        self.lp_choose = new_tokens
        # In-place copy keeps the registered buffer object (and its current
        # device / dtype) intact.
        self.mask.copy_(new_mask.to(self.mask.device, dtype=self.mask.dtype))
